1use std::cell::RefCell;
16use std::collections::VecDeque;
17
18use monarch_types::py_global;
19use pyo3::IntoPyObjectExt;
20use pyo3::prelude::*;
21use pyo3::types::PyList;
22use pyo3::types::PyTuple;
23use serde_multipart::Part;
24
25use crate::actor::PythonMessage;
26use crate::actor::PythonMessageKind;
27use crate::buffers::Buffer;
28use crate::pytokio::PyShared;
29
30py_global!(unflatten, "monarch._src.actor.pickle", "unflatten");
33
34py_global!(flatten, "monarch._src.actor.pickle", "flatten");
42
43py_global!(cloudpickle, "cloudpickle", "cloudpickle");
45
46py_global!(_unpickle, "pickle", "loads");
47
48py_global!(
54 pickle_monkeypatch,
55 "monarch._src.actor.pickle",
56 "_function_getstate"
57);
58
59py_global!(maybe_torch_fn, "monarch._src.actor.pickle", "maybe_torch");
62
63py_global!(torch_dump_fn, "monarch._src.actor.pickle", "torch_dump");
66
67py_global!(torch_loads_fn, "monarch._src.actor.pickle", "torch_loads");
70
71py_global!(
73 shared_class,
74 "monarch._rust_bindings.monarch_hyperactor.pytokio",
75 "Shared"
76);
77
78py_global!(
80 pop_pending_pickle_fn,
81 "monarch._rust_bindings.monarch_hyperactor.pickle",
82 "pop_pending_pickle"
83);
84
85thread_local! {
89 static ACTIVE_PICKLING_STATE: RefCell<Option<ActivePicklingState>> = const { RefCell::new(None) };
90}
91
92struct ActivePicklingGuard {
97 previous: Option<ActivePicklingState>,
98}
99
100impl ActivePicklingGuard {
101 fn enter(state: ActivePicklingState) -> Self {
103 let previous = ACTIVE_PICKLING_STATE.with(|cell| cell.borrow_mut().replace(state));
104 Self { previous }
105 }
106}
107
108impl Drop for ActivePicklingGuard {
109 fn drop(&mut self) {
110 ACTIVE_PICKLING_STATE.with(|cell| {
111 *cell.borrow_mut() = self.previous.take();
112 });
113 }
114}
115
116struct ActivePicklingState {
121 tensor_engine_references: VecDeque<Py<PyAny>>,
123 pending_pickles: VecDeque<Py<PyShared>>,
125 allow_pending_pickles: bool,
127 allow_tensor_engine_references: bool,
129}
130
131impl ActivePicklingState {
132 fn new(allow_pending_pickles: bool, allow_tensor_engine_references: bool) -> Self {
134 Self {
135 tensor_engine_references: VecDeque::new(),
136 pending_pickles: VecDeque::new(),
137 allow_pending_pickles,
138 allow_tensor_engine_references,
139 }
140 }
141
142 fn into_pickling_state(self, buffer: Part) -> PicklingStateInner {
144 PicklingStateInner {
145 buffer,
146 tensor_engine_references: self.tensor_engine_references,
147 pending_pickles: self.pending_pickles,
148 }
149 }
150}
151
152pub struct PicklingStateInner {
157 buffer: Part,
159 tensor_engine_references: VecDeque<Py<PyAny>>,
161 pending_pickles: VecDeque<Py<PyShared>>,
163}
164
165impl PicklingStateInner {
166 pub fn pending_pickles(&self) -> &VecDeque<Py<PyShared>> {
168 &self.pending_pickles
169 }
170
171 pub fn take_buffer(self) -> Part {
173 self.buffer
174 }
175}
176
177#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.pickle")]
182pub struct PicklingState {
183 inner: Option<PicklingStateInner>,
184}
185
186impl PicklingState {
187 pub fn take_inner(&mut self) -> PyResult<PicklingStateInner> {
188 self.inner.take().ok_or_else(|| {
189 pyo3::exceptions::PyRuntimeError::new_err("PicklingState has already been consumed")
190 })
191 }
192
193 fn inner_ref(&self) -> PyResult<&PicklingStateInner> {
194 self.inner.as_ref().ok_or_else(|| {
195 pyo3::exceptions::PyRuntimeError::new_err("PicklingState has already been consumed")
196 })
197 }
198}
199
200#[pymethods]
201impl PicklingState {
202 #[new]
207 #[pyo3(signature = (buffer, tensor_engine_references=None))]
208 fn py_new(
209 buffer: PyRef<'_, crate::buffers::FrozenBuffer>,
210 tensor_engine_references: Option<&Bound<'_, PyList>>,
211 ) -> PyResult<Self> {
212 let refs: VecDeque<Py<PyAny>> = tensor_engine_references
213 .map(|list| list.iter().map(|item| item.unbind()).collect())
214 .unwrap_or_default();
215
216 Ok(Self {
217 inner: Some(PicklingStateInner {
218 buffer: Part::from(buffer.inner.clone()),
219 tensor_engine_references: refs,
220 pending_pickles: VecDeque::new(),
221 }),
222 })
223 }
224
225 fn tensor_engine_references(&self, py: Python<'_>) -> PyResult<Py<PyList>> {
229 let inner = self.inner_ref()?;
230 let refs: Vec<Py<PyAny>> = inner
231 .tensor_engine_references
232 .iter()
233 .map(|r| r.clone_ref(py))
234 .collect();
235 Ok(PyList::new(py, refs)?.unbind())
236 }
237
238 fn buffer(&self) -> PyResult<crate::buffers::FrozenBuffer> {
243 let inner = self.inner_ref()?;
244 Ok(crate::buffers::FrozenBuffer {
245 inner: inner.buffer.clone().into_bytes(),
246 })
247 }
248
249 fn unpickle(&mut self, py: Python<'_>) -> PyResult<Py<PyAny>> {
254 let inner = self.take_inner()?;
255
256 for pending in &inner.pending_pickles {
258 if pending.borrow(py).poll()?.is_none() {
259 return Err(pyo3::exceptions::PyRuntimeError::new_err(
260 "Cannot unpickle: there are unresolved pending pickles",
261 ));
262 }
263 }
264
265 let mut active = ActivePicklingState::new(false, false);
268 active.pending_pickles = inner.pending_pickles;
269 active.tensor_engine_references = inner.tensor_engine_references;
270
271 let _guard = ActivePicklingGuard::enter(active);
272
273 let frozen = crate::buffers::FrozenBuffer {
274 inner: inner.buffer.into_bytes(),
275 };
276
277 let result = if maybe_torch_fn(py).call0()?.is_truthy()? {
280 torch_loads_fn(py).call1((frozen,))
281 } else {
282 cloudpickle(py).getattr("loads")?.call1((frozen,))
283 };
284
285 result.map(|obj| obj.unbind())
286 }
287}
288
289impl PicklingState {
290 pub async fn resolve(mut self) -> PyResult<PicklingState> {
298 if self.inner_ref()?.pending_pickles.is_empty() {
300 return Ok(self);
301 }
302
303 let pending: Vec<Py<PyShared>> = Python::attach(|py| {
305 self.inner_ref().map(|inner| {
306 inner
307 .pending_pickles
308 .iter()
309 .map(|p| p.clone_ref(py))
310 .collect()
311 })
312 })?;
313
314 for pending_pickle in pending {
315 let mut task = Python::attach(|py| pending_pickle.borrow(py).task())?;
316 task.take_task()?.await?;
317 }
318
319 Python::attach(|py| {
321 let obj = self.unpickle(py)?;
322 pickle(py, obj, false, true)
323 })
324 }
325}
326
327#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.pickle")]
333pub struct PendingMessage {
334 pub(crate) kind: PythonMessageKind,
335 state: PicklingState,
336}
337
338impl PendingMessage {
339 pub fn new(kind: PythonMessageKind, state: PicklingState) -> Self {
341 Self { kind, state }
342 }
343
344 pub fn take(&mut self) -> PyResult<PendingMessage> {
349 let inner = self.state.take_inner()?;
350 Ok(PendingMessage {
351 kind: std::mem::take(&mut self.kind),
352 state: PicklingState { inner: Some(inner) },
353 })
354 }
355
356 pub async fn resolve(self) -> PyResult<PythonMessage> {
363 let mut resolved_state = self.state.resolve().await?;
365
366 let inner = resolved_state.take_inner()?;
368 Ok(PythonMessage::new_from_buf(self.kind, inner.take_buffer()))
369 }
370}
371
372#[pymethods]
373impl PendingMessage {
374 #[new]
376 pub fn py_new(
377 kind: PythonMessageKind,
378 mut state: PyRefMut<'_, PicklingState>,
379 ) -> PyResult<Self> {
380 let inner = state.take_inner()?;
382 Ok(Self {
383 kind,
384 state: PicklingState { inner: Some(inner) },
385 })
386 }
387
388 #[getter]
390 fn kind(&self) -> PythonMessageKind {
391 self.kind.clone()
392 }
393}
394
395#[pyfunction]
404fn push_tensor_engine_reference_if_active(obj: Py<PyAny>) -> PyResult<bool> {
405 ACTIVE_PICKLING_STATE.with(|cell| {
406 let mut state = cell.borrow_mut();
407 match state.as_mut() {
408 Some(s) => {
409 if !s.allow_tensor_engine_references {
410 return Err(pyo3::exceptions::PyRuntimeError::new_err(
411 "Tensor engine references are not allowed in the current pickling context",
412 ));
413 }
414 s.tensor_engine_references.push_back(obj);
415 Ok(true)
416 }
417 None => Ok(false),
418 }
419 })
420}
421
422#[pyfunction]
427fn pop_tensor_engine_reference(py: Python<'_>) -> PyResult<Py<PyAny>> {
428 ACTIVE_PICKLING_STATE
429 .with(|cell| {
430 let mut state = cell.borrow_mut();
431 match state.as_mut() {
432 Some(s) => s.tensor_engine_references.pop_front().ok_or_else(|| {
433 pyo3::exceptions::PyRuntimeError::new_err(
434 "No tensor engine references remaining",
435 )
436 }),
437 None => Err(pyo3::exceptions::PyRuntimeError::new_err(
438 "No active pickling state",
439 )),
440 }
441 })
442 .map(|obj| obj.clone_ref(py))
443}
444
445#[pyfunction]
450fn pop_pending_pickle(py: Python<'_>) -> PyResult<Py<PyShared>> {
451 ACTIVE_PICKLING_STATE.with(|cell| {
452 let mut state = cell.borrow_mut();
453 match state.as_mut() {
454 Some(s) => {
455 let shared = s.pending_pickles.pop_front().ok_or_else(|| {
456 pyo3::exceptions::PyRuntimeError::new_err("No pending pickles remaining")
457 })?;
458 Ok(shared.clone_ref(py))
459 }
460 None => Err(pyo3::exceptions::PyRuntimeError::new_err(
461 "No active pickling state",
462 )),
463 }
464 })
465}
466
467pub fn push_pending_pickle(py_shared: Py<PyShared>) -> PyResult<()> {
475 ACTIVE_PICKLING_STATE.with(|cell| {
476 let mut state = cell.borrow_mut();
477 match state.as_mut() {
478 Some(s) => {
479 if !s.allow_pending_pickles {
480 return Err(pyo3::exceptions::PyRuntimeError::new_err(
481 "Pending pickles are not allowed in the current pickling context",
482 ));
483 }
484 s.pending_pickles.push_back(py_shared);
485 Ok(())
486 }
487 None => Err(pyo3::exceptions::PyRuntimeError::new_err(
488 "No active pickling state",
489 )),
490 }
491 })
492}
493
494pub fn reduce_shared<'py>(
501 py: Python<'py>,
502 py_shared: &Bound<'py, PyShared>,
503) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyTuple>)> {
504 if let Some(value) = py_shared.borrow().poll()? {
506 let from_value = shared_class(py).getattr("from_value")?;
507 let args = PyTuple::new(py, [value])?;
508 return Ok((from_value, args));
509 }
510
511 let py_shared_py: Py<PyShared> = py_shared.clone().unbind();
513 if push_pending_pickle(py_shared_py).is_ok() {
514 let pop_fn = pop_pending_pickle_fn(py);
515 let args = PyTuple::empty(py);
516 return Ok((pop_fn, args));
517 }
518
519 let value = PyShared::block_on(py_shared.borrow(), py)?;
521 let from_value = shared_class(py).getattr("from_value")?;
522 let args = PyTuple::new(py, [value])?;
523 Ok((from_value, args))
524}
525
526fn pickle_into_buffer(py: Python<'_>, obj: &Py<PyAny>, buffer: &Py<Buffer>) -> PyResult<()> {
531 pickle_monkeypatch(py);
533
534 if maybe_torch_fn(py).call0()?.is_truthy()? {
537 torch_dump_fn(py).call1((obj, buffer.bind(py)))?;
538 } else {
539 let pickler = cloudpickle(py)
540 .getattr("Pickler")?
541 .call1((buffer.bind(py),))?;
542 pickler.call_method1("dump", (obj,))?;
543 }
544
545 Ok(())
546}
547
548pub fn pickle_to_part(py: Python<'_>, obj: &Py<PyAny>) -> PyResult<Part> {
554 let active = ActivePicklingState::new(false, false);
555 let buffer = Py::new(py, Buffer::default())?;
556 let _guard = ActivePicklingGuard::enter(active);
557
558 pickle_into_buffer(py, obj, &buffer)?;
559
560 Ok(buffer.borrow_mut(py).take_part())
561}
562
563#[pyfunction]
577#[pyo3(signature = (obj, allow_pending_pickles=true, allow_tensor_engine_references=true))]
578pub fn pickle(
579 py: Python<'_>,
580 obj: Py<PyAny>,
581 allow_pending_pickles: bool,
582 allow_tensor_engine_references: bool,
583) -> PyResult<PicklingState> {
584 let active = ActivePicklingState::new(allow_pending_pickles, allow_tensor_engine_references);
585 let buffer = Py::new(py, Buffer::default())?;
586 let _guard = ActivePicklingGuard::enter(active);
587
588 pickle_into_buffer(py, &obj, &buffer)?;
589
590 let active = ACTIVE_PICKLING_STATE
593 .with(|cell| cell.borrow_mut().take())
594 .expect("active pickling state should still be set");
595
596 let part = buffer.borrow_mut(py).take_part();
598 let inner = active.into_pickling_state(part);
599 Ok(PicklingState { inner: Some(inner) })
600}
601
602pub(crate) fn unpickle<'py>(
603 py: Python<'py>,
604 buffer: crate::buffers::FrozenBuffer,
605) -> PyResult<Bound<'py, PyAny>> {
606 _unpickle(py).call1((buffer.into_py_any(py)?,))
607}
608
609pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
611 module.add_class::<PicklingState>()?;
612 module.add_class::<PendingMessage>()?;
613 module.add_function(wrap_pyfunction!(pickle, module)?)?;
614 module.add_function(wrap_pyfunction!(
615 push_tensor_engine_reference_if_active,
616 module
617 )?)?;
618 module.add_function(wrap_pyfunction!(pop_tensor_engine_reference, module)?)?;
619 module.add_function(wrap_pyfunction!(pop_pending_pickle, module)?)?;
620 Ok(())
621}