1use std::cell::RefCell;
16use std::collections::VecDeque;
17
18use monarch_types::py_global;
19use pyo3::IntoPyObjectExt;
20use pyo3::prelude::*;
21use pyo3::types::PyList;
22use pyo3::types::PyTuple;
23use serde_multipart::Part;
24
25use crate::actor::PythonMessage;
26use crate::actor::PythonMessageKind;
27use crate::buffers::Buffer;
28use crate::pytokio::PyShared;
29
// `py_global!(name, module, attr)` declares an accessor `name(py)` for the Python
// object `module.attr`. Presumably the import is resolved lazily and cached —
// confirm against the macro definition in `monarch_types`.

// `monarch._src.actor.pickle.unflatten` / `.flatten`: not referenced in this
// visible section (presumably used elsewhere — verify before removing).
py_global!(unflatten, "monarch._src.actor.pickle", "unflatten");

py_global!(flatten, "monarch._src.actor.pickle", "flatten");

// The `cloudpickle.cloudpickle` submodule; this file uses its `loads` and
// `Pickler` attributes.
py_global!(cloudpickle, "cloudpickle", "cloudpickle");

// Standard-library `pickle.loads`, used by the crate-level `unpickle` helper.
py_global!(_unpickle, "pickle", "loads");

// `_function_getstate` from `monarch._src.actor.pickle`. `pickle_into_buffer`
// calls this accessor and discards the result, so it appears to be used only for
// the side effect of importing that module (presumably to install a pickling
// monkeypatch — confirm).
py_global!(
    pickle_monkeypatch,
    "monarch._src.actor.pickle",
    "_function_getstate"
);

// Python predicate whose truthiness selects the torch-based serializer
// (`torch_dump` / `torch_loads`) instead of cloudpickle.
py_global!(maybe_torch_fn, "monarch._src.actor.pickle", "maybe_torch");

// Serializer used when `maybe_torch()` is truthy; called as `torch_dump(obj, buffer)`.
py_global!(torch_dump_fn, "monarch._src.actor.pickle", "torch_dump");

// Deserializer used when `maybe_torch()` is truthy; called as `torch_loads(buffer)`.
py_global!(torch_loads_fn, "monarch._src.actor.pickle", "torch_loads");

// Python-visible `Shared` class. Its `from_value` attribute is the reconstructor
// that `reduce_shared` emits for already-available values.
py_global!(
    shared_class,
    "monarch._rust_bindings.monarch_hyperactor.pytokio",
    "Shared"
);

// This module's own `pop_pending_pickle` function. `reduce_shared` emits it as
// the reconstructor for values deferred as pending pickles.
py_global!(
    pop_pending_pickle_fn,
    "monarch._rust_bindings.monarch_hyperactor.pickle",
    "pop_pending_pickle"
);
84
thread_local! {
    /// Bookkeeping for the pickle/unpickle call currently running on this
    /// thread, or `None` when no such call is active.
    ///
    /// `ActivePicklingGuard` installs and restores it. The push/pop helpers in
    /// this module (`push_pending_pickle`, `pop_pending_pickle`,
    /// `push_tensor_engine_reference_if_active`, `pop_tensor_engine_reference`)
    /// read and mutate it.
    static ACTIVE_PICKLING_STATE: RefCell<Option<ActivePicklingState>> = const { RefCell::new(None) };
}
91
92struct ActivePicklingGuard {
97 previous: Option<ActivePicklingState>,
98}
99
100impl ActivePicklingGuard {
101 fn enter(state: ActivePicklingState) -> Self {
103 let previous = ACTIVE_PICKLING_STATE.with(|cell| cell.borrow_mut().replace(state));
104 Self { previous }
105 }
106}
107
108impl Drop for ActivePicklingGuard {
109 fn drop(&mut self) {
110 ACTIVE_PICKLING_STATE.with(|cell| {
111 *cell.borrow_mut() = self.previous.take();
112 });
113 }
114}
115
116struct ActivePicklingState {
121 tensor_engine_references: VecDeque<Py<PyAny>>,
123 pending_pickles: VecDeque<Py<PyShared>>,
125 allow_pending_pickles: bool,
127 allow_tensor_engine_references: bool,
129}
130
131impl ActivePicklingState {
132 fn new(allow_pending_pickles: bool, allow_tensor_engine_references: bool) -> Self {
134 Self {
135 tensor_engine_references: VecDeque::new(),
136 pending_pickles: VecDeque::new(),
137 allow_pending_pickles,
138 allow_tensor_engine_references,
139 }
140 }
141
142 fn into_pickling_state(self, buffer: crate::buffers::FrozenBuffer) -> PicklingStateInner {
144 PicklingStateInner {
145 buffer,
146 tensor_engine_references: self.tensor_engine_references,
147 pending_pickles: self.pending_pickles,
148 }
149 }
150}
151
152pub struct PicklingStateInner {
157 buffer: crate::buffers::FrozenBuffer,
159 tensor_engine_references: VecDeque<Py<PyAny>>,
161 pending_pickles: VecDeque<Py<PyShared>>,
163}
164
165impl PicklingStateInner {
166 pub fn pending_pickles(&self) -> &VecDeque<Py<PyShared>> {
168 &self.pending_pickles
169 }
170
171 pub fn take_buffer(self) -> crate::buffers::FrozenBuffer {
173 self.buffer
174 }
175}
176
177#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.pickle")]
182pub struct PicklingState {
183 inner: Option<PicklingStateInner>,
184}
185
186impl PicklingState {
187 pub fn take_inner(&mut self) -> PyResult<PicklingStateInner> {
188 self.inner.take().ok_or_else(|| {
189 pyo3::exceptions::PyRuntimeError::new_err("PicklingState has already been consumed")
190 })
191 }
192
193 fn inner_ref(&self) -> PyResult<&PicklingStateInner> {
194 self.inner.as_ref().ok_or_else(|| {
195 pyo3::exceptions::PyRuntimeError::new_err("PicklingState has already been consumed")
196 })
197 }
198}
199
200#[pymethods]
201impl PicklingState {
202 #[new]
207 #[pyo3(signature = (buffer, tensor_engine_references=None))]
208 fn py_new(
209 buffer: PyRef<'_, crate::buffers::FrozenBuffer>,
210 tensor_engine_references: Option<&Bound<'_, PyList>>,
211 ) -> PyResult<Self> {
212 let refs: VecDeque<Py<PyAny>> = tensor_engine_references
213 .map(|list| list.iter().map(|item| item.unbind()).collect())
214 .unwrap_or_default();
215
216 Ok(Self {
217 inner: Some(PicklingStateInner {
218 buffer: buffer.clone(),
219 tensor_engine_references: refs,
220 pending_pickles: VecDeque::new(),
221 }),
222 })
223 }
224
225 fn tensor_engine_references(&self, py: Python<'_>) -> PyResult<Py<PyList>> {
229 let inner = self.inner_ref()?;
230 let refs: Vec<Py<PyAny>> = inner
231 .tensor_engine_references
232 .iter()
233 .map(|r| r.clone_ref(py))
234 .collect();
235 Ok(PyList::new(py, refs)?.unbind())
236 }
237
238 fn buffer(&self) -> PyResult<crate::buffers::FrozenBuffer> {
243 let inner = self.inner_ref()?;
244 Ok(inner.buffer.clone())
245 }
246
247 fn unpickle(&mut self, py: Python<'_>) -> PyResult<Py<PyAny>> {
252 let inner = self.take_inner()?;
253
254 for pending in &inner.pending_pickles {
256 if pending.borrow(py).poll()?.is_none() {
257 return Err(pyo3::exceptions::PyRuntimeError::new_err(
258 "Cannot unpickle: there are unresolved pending pickles",
259 ));
260 }
261 }
262
263 let mut active = ActivePicklingState::new(false, false);
266 active.pending_pickles = inner.pending_pickles;
267 active.tensor_engine_references = inner.tensor_engine_references;
268
269 let _guard = ActivePicklingGuard::enter(active);
270
271 let result = if maybe_torch_fn(py).call0()?.is_truthy()? {
274 torch_loads_fn(py).call1((inner.buffer,))
275 } else {
276 cloudpickle(py).getattr("loads")?.call1((inner.buffer,))
277 };
278
279 result.map(|obj| obj.unbind())
280 }
281}
282
283impl PicklingState {
284 pub async fn resolve(mut self) -> PyResult<PicklingState> {
292 if self.inner_ref()?.pending_pickles.is_empty() {
294 return Ok(self);
295 }
296
297 let pending: Vec<Py<PyShared>> = Python::attach(|py| {
299 self.inner_ref().map(|inner| {
300 inner
301 .pending_pickles
302 .iter()
303 .map(|p| p.clone_ref(py))
304 .collect()
305 })
306 })?;
307
308 for pending_pickle in pending {
309 let mut task = Python::attach(|py| pending_pickle.borrow(py).task())?;
310 task.take_task()?.await?;
311 }
312
313 Python::attach(|py| {
315 let obj = self.unpickle(py)?;
316 pickle(py, obj, false, true)
317 })
318 }
319}
320
321#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.pickle")]
327pub struct PendingMessage {
328 pub(crate) kind: PythonMessageKind,
329 state: PicklingState,
330}
331
332impl PendingMessage {
333 pub fn new(kind: PythonMessageKind, state: PicklingState) -> Self {
335 Self { kind, state }
336 }
337
338 pub fn take(&mut self) -> PyResult<PendingMessage> {
343 let inner = self.state.take_inner()?;
344 Ok(PendingMessage {
345 kind: std::mem::take(&mut self.kind),
346 state: PicklingState { inner: Some(inner) },
347 })
348 }
349
350 pub async fn resolve(self) -> PyResult<PythonMessage> {
357 let mut resolved_state = self.state.resolve().await?;
359
360 let inner = resolved_state.take_inner()?;
362 Ok(PythonMessage::new_from_buf(self.kind, inner.take_buffer()))
363 }
364}
365
366#[pymethods]
367impl PendingMessage {
368 #[new]
370 pub fn py_new(
371 kind: PythonMessageKind,
372 mut state: PyRefMut<'_, PicklingState>,
373 ) -> PyResult<Self> {
374 let inner = state.take_inner()?;
376 Ok(Self {
377 kind,
378 state: PicklingState { inner: Some(inner) },
379 })
380 }
381
382 #[getter]
384 fn kind(&self) -> PythonMessageKind {
385 self.kind.clone()
386 }
387}
388
389#[pyfunction]
398fn push_tensor_engine_reference_if_active(obj: Py<PyAny>) -> PyResult<bool> {
399 ACTIVE_PICKLING_STATE.with(|cell| {
400 let mut state = cell.borrow_mut();
401 match state.as_mut() {
402 Some(s) => {
403 if !s.allow_tensor_engine_references {
404 return Err(pyo3::exceptions::PyRuntimeError::new_err(
405 "Tensor engine references are not allowed in the current pickling context",
406 ));
407 }
408 s.tensor_engine_references.push_back(obj);
409 Ok(true)
410 }
411 None => Ok(false),
412 }
413 })
414}
415
416#[pyfunction]
421fn pop_tensor_engine_reference(py: Python<'_>) -> PyResult<Py<PyAny>> {
422 ACTIVE_PICKLING_STATE
423 .with(|cell| {
424 let mut state = cell.borrow_mut();
425 match state.as_mut() {
426 Some(s) => s.tensor_engine_references.pop_front().ok_or_else(|| {
427 pyo3::exceptions::PyRuntimeError::new_err(
428 "No tensor engine references remaining",
429 )
430 }),
431 None => Err(pyo3::exceptions::PyRuntimeError::new_err(
432 "No active pickling state",
433 )),
434 }
435 })
436 .map(|obj| obj.clone_ref(py))
437}
438
439#[pyfunction]
444fn pop_pending_pickle(py: Python<'_>) -> PyResult<Py<PyShared>> {
445 ACTIVE_PICKLING_STATE.with(|cell| {
446 let mut state = cell.borrow_mut();
447 match state.as_mut() {
448 Some(s) => {
449 let shared = s.pending_pickles.pop_front().ok_or_else(|| {
450 pyo3::exceptions::PyRuntimeError::new_err("No pending pickles remaining")
451 })?;
452 Ok(shared.clone_ref(py))
453 }
454 None => Err(pyo3::exceptions::PyRuntimeError::new_err(
455 "No active pickling state",
456 )),
457 }
458 })
459}
460
461pub fn push_pending_pickle(py_shared: Py<PyShared>) -> PyResult<()> {
469 ACTIVE_PICKLING_STATE.with(|cell| {
470 let mut state = cell.borrow_mut();
471 match state.as_mut() {
472 Some(s) => {
473 if !s.allow_pending_pickles {
474 return Err(pyo3::exceptions::PyRuntimeError::new_err(
475 "Pending pickles are not allowed in the current pickling context",
476 ));
477 }
478 s.pending_pickles.push_back(py_shared);
479 Ok(())
480 }
481 None => Err(pyo3::exceptions::PyRuntimeError::new_err(
482 "No active pickling state",
483 )),
484 }
485 })
486}
487
488pub fn reduce_shared<'py>(
495 py: Python<'py>,
496 py_shared: &Bound<'py, PyShared>,
497) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyTuple>)> {
498 if let Some(value) = py_shared.borrow().poll()? {
500 let from_value = shared_class(py).getattr("from_value")?;
501 let args = PyTuple::new(py, [value])?;
502 return Ok((from_value, args));
503 }
504
505 let py_shared_py: Py<PyShared> = py_shared.clone().unbind();
507 if push_pending_pickle(py_shared_py).is_ok() {
508 let pop_fn = pop_pending_pickle_fn(py);
509 let args = PyTuple::empty(py);
510 return Ok((pop_fn, args));
511 }
512
513 let value = PyShared::block_on(py_shared.borrow(), py)?;
515 let from_value = shared_class(py).getattr("from_value")?;
516 let args = PyTuple::new(py, [value])?;
517 Ok((from_value, args))
518}
519
520fn pickle_into_buffer(py: Python<'_>, obj: &Py<PyAny>, buffer: &Py<Buffer>) -> PyResult<()> {
525 pickle_monkeypatch(py);
527
528 if maybe_torch_fn(py).call0()?.is_truthy()? {
531 torch_dump_fn(py).call1((obj, buffer.bind(py)))?;
532 } else {
533 let pickler = cloudpickle(py)
534 .getattr("Pickler")?
535 .call1((buffer.bind(py),))?;
536 pickler.call_method1("dump", (obj,))?;
537 }
538
539 Ok(())
540}
541
542pub fn pickle_to_part(py: Python<'_>, obj: &Py<PyAny>) -> PyResult<Part> {
548 let active = ActivePicklingState::new(false, false);
549 let buffer = Py::new(py, Buffer::default())?;
550 let _guard = ActivePicklingGuard::enter(active);
551
552 pickle_into_buffer(py, obj, &buffer)?;
553
554 Ok(buffer.borrow_mut(py).take_part())
555}
556
557#[pyfunction]
571#[pyo3(signature = (obj, allow_pending_pickles=true, allow_tensor_engine_references=true))]
572pub fn pickle(
573 py: Python<'_>,
574 obj: Py<PyAny>,
575 allow_pending_pickles: bool,
576 allow_tensor_engine_references: bool,
577) -> PyResult<PicklingState> {
578 let active = ActivePicklingState::new(allow_pending_pickles, allow_tensor_engine_references);
579 let buffer = Py::new(py, Buffer::default())?;
580 let _guard = ActivePicklingGuard::enter(active);
581
582 pickle_into_buffer(py, &obj, &buffer)?;
583
584 let active = ACTIVE_PICKLING_STATE
587 .with(|cell| cell.borrow_mut().take())
588 .expect("active pickling state should still be set");
589
590 let frozen_buffer = buffer.borrow_mut(py).freeze();
592 let inner = active.into_pickling_state(frozen_buffer);
593 Ok(PicklingState { inner: Some(inner) })
594}
595
596pub(crate) fn unpickle<'py>(
597 py: Python<'py>,
598 buffer: crate::buffers::FrozenBuffer,
599) -> PyResult<Bound<'py, PyAny>> {
600 _unpickle(py).call1((buffer.into_py_any(py)?,))
601}
602
603pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
605 module.add_class::<PicklingState>()?;
606 module.add_class::<PendingMessage>()?;
607 module.add_function(wrap_pyfunction!(pickle, module)?)?;
608 module.add_function(wrap_pyfunction!(
609 push_tensor_engine_reference_if_active,
610 module
611 )?)?;
612 module.add_function(wrap_pyfunction!(pop_tensor_engine_reference, module)?)?;
613 module.add_function(wrap_pyfunction!(pop_pending_pickle, module)?)?;
614 Ok(())
615}