// monarch_hyperactor/pickle.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//! Deferred pickling support for Monarch.
//!
//! This module provides utilities for deferring the pickling of objects
//! that contain async values (futures/tasks) that must be resolved before
//! the final pickle can be produced.

use std::cell::RefCell;
use std::collections::VecDeque;

use monarch_types::py_global;
use pyo3::IntoPyObjectExt;
use pyo3::prelude::*;
use pyo3::types::PyList;
use pyo3::types::PyTuple;
use serde_multipart::Part;

use crate::actor::PythonMessage;
use crate::actor::PythonMessageKind;
use crate::buffers::Buffer;
use crate::pytokio::PyShared;

// Python helper used to reconstruct an object graph from a pickled
// buffer plus a list of "unflatten values" (including placeholders).
py_global!(unflatten, "monarch._src.actor.pickle", "unflatten");

// Python helper used to pickle an object graph, optionally using a
// filter to replace certain values with placeholders (e.g.
// `PendingPickle`).
//
// We use `flatten`/`unflatten` to support "deferred pickling":
// initially pickle with placeholders, then later resolve futures and
// re-pickle with concrete values.
py_global!(flatten, "monarch._src.actor.pickle", "flatten");

// The `cloudpickle.cloudpickle` submodule; provides the `Pickler`
// class and `loads` used for serialization below.
py_global!(cloudpickle, "cloudpickle", "cloudpickle");

// Plain stdlib `pickle.loads`, used by the free `unpickle` function at
// the bottom of this file.
py_global!(_unpickle, "pickle", "loads");

// Importing monarch._src.actor.pickle applies a monkeypatch to cloudpickle
// that injects RemoteImportLoader into pickled function globals, enabling
// source loading for pickle-by-value code on remote hosts (needed for
// debugger and tracebacks). We access this before pickling to ensure
// the monkeypatch is applied.
py_global!(
    pickle_monkeypatch,
    "monarch._src.actor.pickle",
    "_function_getstate"
);

// Check if torch has been loaded into the current Python process.
// Returns the torch module if loaded, otherwise None.
py_global!(maybe_torch_fn, "monarch._src.actor.pickle", "maybe_torch");

// Torch-aware dump function: uses a Pickler subclass with dispatch_table
// entries for torch storage types (UntypedStorage, TypedStorage, etc.).
py_global!(torch_dump_fn, "monarch._src.actor.pickle", "torch_dump");

// Torch-aware loads function: wraps cloudpickle.loads with
// torch.utils._python_dispatch._disable_current_modes().
py_global!(torch_loads_fn, "monarch._src.actor.pickle", "torch_loads");

// The Python-side `Shared` class; its `from_value` classmethod is used
// when reducing an already-resolved `PyShared` (see `reduce_shared`).
py_global!(
    shared_class,
    "monarch._rust_bindings.monarch_hyperactor.pytokio",
    "Shared"
);

// pop_pending_pickle function for unpickling deferred PyShared values;
// this is the same `pop_pending_pickle` registered by this module in
// `register_python_bindings`, fetched back through its Python binding so
// it can be embedded in a `__reduce__` result.
py_global!(
    pop_pending_pickle_fn,
    "monarch._rust_bindings.monarch_hyperactor.pickle",
    "pop_pending_pickle"
);

// Thread-local storage for the active pickling state.
//
// Set (via `ActivePicklingGuard`) for the duration of a pickle/unpickle
// operation so that the free functions called from Python `__reduce__`
// implementations (push/pop below) can reach the state without it being
// threaded through cloudpickle's call stack.
thread_local! {
    static ACTIVE_PICKLING_STATE: RefCell<Option<ActivePicklingState>> = const { RefCell::new(None) };
}

/// RAII guard that sets the thread-local `ACTIVE_PICKLING_STATE` on creation
/// and restores the previous state (if any) on drop. This supports nesting:
/// if a guard already exists, the new guard saves the old state and restores
/// it when dropped, even on panic (drop runs during unwinding).
struct ActivePicklingGuard {
    /// The state that was active before this guard was entered, if any.
    previous: Option<ActivePicklingState>,
}

100impl ActivePicklingGuard {
101    /// Set `state` as the active pickling state, saving any existing state.
102    fn enter(state: ActivePicklingState) -> Self {
103        let previous = ACTIVE_PICKLING_STATE.with(|cell| cell.borrow_mut().replace(state));
104        Self { previous }
105    }
106}
107
108impl Drop for ActivePicklingGuard {
109    fn drop(&mut self) {
110        ACTIVE_PICKLING_STATE.with(|cell| {
111            *cell.borrow_mut() = self.previous.take();
112        });
113    }
114}
115
/// State maintained during active pickling/unpickling operations.
///
/// This is the thread-local state used while cloudpickle is running.
/// It collects tensor engine references and pending pickles during
/// serialization, and serves them back (in FIFO order) during
/// deserialization via the `pop_*` functions below.
struct ActivePicklingState {
    /// References to tensor engine objects that need special handling,
    /// in the order they were pushed.
    tensor_engine_references: VecDeque<Py<PyAny>>,
    /// Pending pickles (PyShared values) that must be resolved before
    /// the final pickle can be produced.
    pending_pickles: VecDeque<Py<PyShared>>,
    /// Whether pending pickles are allowed in this pickling context.
    allow_pending_pickles: bool,
    /// Whether tensor engine references are allowed in this pickling context.
    allow_tensor_engine_references: bool,
}

131impl ActivePicklingState {
132    /// Create a new ActivePicklingState.
133    fn new(allow_pending_pickles: bool, allow_tensor_engine_references: bool) -> Self {
134        Self {
135            tensor_engine_references: VecDeque::new(),
136            pending_pickles: VecDeque::new(),
137            allow_pending_pickles,
138            allow_tensor_engine_references,
139        }
140    }
141
142    /// Convert this active state into a frozen PicklingState.
143    fn into_pickling_state(self, buffer: Part) -> PicklingStateInner {
144        PicklingStateInner {
145            buffer,
146            tensor_engine_references: self.tensor_engine_references,
147            pending_pickles: self.pending_pickles,
148        }
149    }
150}
151
/// Inner data for a completed pickling operation.
///
/// This contains the pickled bytes as a fragmented [`Part`] (zero-copy)
/// and any references collected while the pickle ran.
pub struct PicklingStateInner {
    /// The pickled bytes as a fragmented Part (zero-copy).
    buffer: Part,
    /// References to tensor engine objects that need special handling.
    tensor_engine_references: VecDeque<Py<PyAny>>,
    /// Pending pickles (PyShared values) that must be resolved.
    pending_pickles: VecDeque<Py<PyShared>>,
}

165impl PicklingStateInner {
166    /// Get a reference to the pending pickles.
167    pub fn pending_pickles(&self) -> &VecDeque<Py<PyShared>> {
168        &self.pending_pickles
169    }
170
171    /// Take the Part (pickled bytes) from this inner state.
172    pub fn take_buffer(self) -> Part {
173        self.buffer
174    }
175}
176
/// Python-visible wrapper for the result of a pickling operation.
///
/// Contains the pickled bytes and any tensor engine references or pending
/// pickles that were collected during serialization.
///
/// The inner state is `Option` so consuming operations (`unpickle`,
/// `take_inner`) can move it out; subsequent access yields a clear
/// "already consumed" error instead of a panic.
#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.pickle")]
pub struct PicklingState {
    // `None` once the state has been consumed.
    inner: Option<PicklingStateInner>,
}

186impl PicklingState {
187    pub fn take_inner(&mut self) -> PyResult<PicklingStateInner> {
188        self.inner.take().ok_or_else(|| {
189            pyo3::exceptions::PyRuntimeError::new_err("PicklingState has already been consumed")
190        })
191    }
192
193    fn inner_ref(&self) -> PyResult<&PicklingStateInner> {
194        self.inner.as_ref().ok_or_else(|| {
195            pyo3::exceptions::PyRuntimeError::new_err("PicklingState has already been consumed")
196        })
197    }
198}
199
#[pymethods]
impl PicklingState {
    /// Create a new PicklingState from a buffer and optional tensor engine references.
    ///
    /// This is used for unpickling received messages that may contain tensor engine
    /// references that need to be restored during deserialization.
    #[new]
    #[pyo3(signature = (buffer, tensor_engine_references=None))]
    fn py_new(
        buffer: PyRef<'_, crate::buffers::FrozenBuffer>,
        tensor_engine_references: Option<&Bound<'_, PyList>>,
    ) -> PyResult<Self> {
        // Unbind each list element so the references can be stored without
        // holding a GIL-tied `Bound` handle.
        let refs: VecDeque<Py<PyAny>> = tensor_engine_references
            .map(|list| list.iter().map(|item| item.unbind()).collect())
            .unwrap_or_default();

        Ok(Self {
            inner: Some(PicklingStateInner {
                // NOTE(review): assumes `FrozenBuffer.inner.clone()` is a
                // cheap (refcounted/shallow) clone of the byte payload —
                // confirm against the `buffers` module.
                buffer: Part::from(buffer.inner.clone()),
                tensor_engine_references: refs,
                // A state constructed from raw bytes has no pending pickles.
                pending_pickles: VecDeque::new(),
            }),
        })
    }

    /// Get a copy of all tensor engine references from this pickling state.
    ///
    /// Returns a Python list containing copies of the tensor engine references.
    /// Does not consume the state; errors if it was already consumed.
    fn tensor_engine_references(&self, py: Python<'_>) -> PyResult<Py<PyList>> {
        let inner = self.inner_ref()?;
        // clone_ref bumps the Python refcount; the queue itself is untouched.
        let refs: Vec<Py<PyAny>> = inner
            .tensor_engine_references
            .iter()
            .map(|r| r.clone_ref(py))
            .collect();
        Ok(PyList::new(py, refs)?.unbind())
    }

    /// Get the buffer from this pickling state.
    ///
    /// Returns a FrozenBuffer containing the pickled bytes.
    /// This does not consume the PicklingState.
    fn buffer(&self) -> PyResult<crate::buffers::FrozenBuffer> {
        let inner = self.inner_ref()?;
        // Clone the Part so the state keeps its copy of the bytes.
        Ok(crate::buffers::FrozenBuffer {
            inner: inner.buffer.clone().into_bytes(),
        })
    }

    /// Unpickle the buffer contents.
    ///
    /// This consumes the PicklingState. It will fail if there are any pending
    /// pickles that haven't been resolved.
    fn unpickle(&mut self, py: Python<'_>) -> PyResult<Py<PyAny>> {
        let inner = self.take_inner()?;

        // Verify all pending pickles are resolved before unpickling:
        // poll() returning None means the underlying future is not finished.
        for pending in &inner.pending_pickles {
            if pending.borrow(py).poll()?.is_none() {
                return Err(pyo3::exceptions::PyRuntimeError::new_err(
                    "Cannot unpickle: there are unresolved pending pickles",
                ));
            }
        }

        // Set up an active state for unpickling, seeded with the queues
        // collected at pickle time so `pop_pending_pickle` /
        // `pop_tensor_engine_reference` calls made during deserialization
        // hand items back in the same FIFO order they were pushed.
        // Both allow flags are false: nothing may be *pushed* while unpickling.
        // The guard restores any previous state on drop (including on panic).
        let mut active = ActivePicklingState::new(false, false);
        active.pending_pickles = inner.pending_pickles;
        active.tensor_engine_references = inner.tensor_engine_references;

        let _guard = ActivePicklingGuard::enter(active);

        let frozen = crate::buffers::FrozenBuffer {
            inner: inner.buffer.into_bytes(),
        };

        // Unpickle the object. If torch is loaded, use torch_loads which
        // disables dispatch modes during unpickling.
        let result = if maybe_torch_fn(py).call0()?.is_truthy()? {
            torch_loads_fn(py).call1((frozen,))
        } else {
            cloudpickle(py).getattr("loads")?.call1((frozen,))
        };

        result.map(|obj| obj.unbind())
    }
}

impl PicklingState {
    /// Resolve all pending pickles and return a new PicklingState without pending pickles.
    ///
    /// This consumes the PicklingState. It:
    /// 1. If there are no pending pickles, returns self immediately
    /// 2. Otherwise, awaits all pending pickles until they're finished
    /// 3. Calls unpickle to reconstruct the object
    /// 4. Calls pickle again to get a new PicklingState without pending pickles
    pub async fn resolve(mut self) -> PyResult<PicklingState> {
        // Short-circuit if there are no pending pickles
        if self.inner_ref()?.pending_pickles.is_empty() {
            return Ok(self);
        }

        // Snapshot the pending pickles (refcount bumps only) so the GIL can
        // be released while awaiting them below.
        let pending: Vec<Py<PyShared>> = Python::attach(|py| {
            self.inner_ref().map(|inner| {
                inner
                    .pending_pickles
                    .iter()
                    .map(|p| p.clone_ref(py))
                    .collect()
            })
        })?;

        // Await each pending pickle's underlying task sequentially. The GIL
        // is only taken briefly to extract the task, not held across awaits.
        for pending_pickle in pending {
            let mut task = Python::attach(|py| pending_pickle.borrow(py).task())?;
            task.take_task()?.await?;
        }

        // Unpickle (pending pickles are now resolved) and re-pickle without
        // allowing new pending pickles (tensor engine references stay allowed).
        Python::attach(|py| {
            let obj = self.unpickle(py)?;
            pickle(py, obj, false, true)
        })
    }
}

/// A message that is pending resolution of async values before it can be sent.
///
/// Contains a `PythonMessageKind` and a `PicklingState`. The `PicklingState` may contain
/// pending pickles (unresolved async values) that must be resolved (via
/// [`PendingMessage::resolve`]) before the message can be converted into a
/// `PythonMessage`.
#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.pickle")]
pub struct PendingMessage {
    /// The kind/routing metadata of the eventual `PythonMessage`.
    pub(crate) kind: PythonMessageKind,
    /// The (possibly still pending) pickled payload.
    state: PicklingState,
}

338impl PendingMessage {
339    /// Create a new PendingMessage from a kind and pickling state.
340    pub fn new(kind: PythonMessageKind, state: PicklingState) -> Self {
341        Self { kind, state }
342    }
343
344    /// Take ownership of the inner state from a mutable reference.
345    ///
346    /// This is used by pyo3 pymethods that receive `&mut PendingMessage`
347    /// but need to pass ownership to the trait method.
348    pub fn take(&mut self) -> PyResult<PendingMessage> {
349        let inner = self.state.take_inner()?;
350        Ok(PendingMessage {
351            kind: std::mem::take(&mut self.kind),
352            state: PicklingState { inner: Some(inner) },
353        })
354    }
355
356    /// Resolve all pending pickles and convert this into a PythonMessage.
357    ///
358    /// This is an async method that:
359    /// 1. Awaits all pending pickles in the PicklingState
360    /// 2. Re-pickles the resolved object
361    /// 3. Returns a PythonMessage with the resolved bytes (no GIL needed for final step)
362    pub async fn resolve(self) -> PyResult<PythonMessage> {
363        // Resolve the pickling state (awaits all pending pickles and re-pickles)
364        let mut resolved_state = self.state.resolve().await?;
365
366        // Take the Part directly - no GIL needed since Part doesn't contain Py<>
367        let inner = resolved_state.take_inner()?;
368        Ok(PythonMessage::new_from_buf(self.kind, inner.take_buffer()))
369    }
370}
371
372#[pymethods]
373impl PendingMessage {
374    /// Create a new PendingMessage from a kind and pickling state.
375    #[new]
376    pub fn py_new(
377        kind: PythonMessageKind,
378        mut state: PyRefMut<'_, PicklingState>,
379    ) -> PyResult<Self> {
380        // Take the inner state from the PicklingState
381        let inner = state.take_inner()?;
382        Ok(Self {
383            kind,
384            state: PicklingState { inner: Some(inner) },
385        })
386    }
387
388    /// Get the message kind.
389    #[getter]
390    fn kind(&self) -> PythonMessageKind {
391        self.kind.clone()
392    }
393}
394
395/// Push a tensor engine reference to the active pickling state if one is active.
396///
397/// This is called from Python during pickling when a tensor engine object
398/// is encountered that needs special handling.
399///
400/// Returns False if there is no active pickling state.
401/// Returns True if the reference was successfully pushed.
402/// Raises an error if tensor engine references are not allowed in the current pickling context.
403#[pyfunction]
404fn push_tensor_engine_reference_if_active(obj: Py<PyAny>) -> PyResult<bool> {
405    ACTIVE_PICKLING_STATE.with(|cell| {
406        let mut state = cell.borrow_mut();
407        match state.as_mut() {
408            Some(s) => {
409                if !s.allow_tensor_engine_references {
410                    return Err(pyo3::exceptions::PyRuntimeError::new_err(
411                        "Tensor engine references are not allowed in the current pickling context",
412                    ));
413                }
414                s.tensor_engine_references.push_back(obj);
415                Ok(true)
416            }
417            None => Ok(false),
418        }
419    })
420}
421
422/// Pop a tensor engine reference from the active pickling state.
423///
424/// This is called from Python during unpickling to retrieve tensor engine
425/// objects in the order they were pushed.
426#[pyfunction]
427fn pop_tensor_engine_reference(py: Python<'_>) -> PyResult<Py<PyAny>> {
428    ACTIVE_PICKLING_STATE
429        .with(|cell| {
430            let mut state = cell.borrow_mut();
431            match state.as_mut() {
432                Some(s) => s.tensor_engine_references.pop_front().ok_or_else(|| {
433                    pyo3::exceptions::PyRuntimeError::new_err(
434                        "No tensor engine references remaining",
435                    )
436                }),
437                None => Err(pyo3::exceptions::PyRuntimeError::new_err(
438                    "No active pickling state",
439                )),
440            }
441        })
442        .map(|obj| obj.clone_ref(py))
443}
444
445/// Pop a pending pickle from the active pickling state.
446///
447/// This is called from Python during unpickling to retrieve the PyShared
448/// object that was deferred during pickling.
449#[pyfunction]
450fn pop_pending_pickle(py: Python<'_>) -> PyResult<Py<PyShared>> {
451    ACTIVE_PICKLING_STATE.with(|cell| {
452        let mut state = cell.borrow_mut();
453        match state.as_mut() {
454            Some(s) => {
455                let shared = s.pending_pickles.pop_front().ok_or_else(|| {
456                    pyo3::exceptions::PyRuntimeError::new_err("No pending pickles remaining")
457                })?;
458                Ok(shared.clone_ref(py))
459            }
460            None => Err(pyo3::exceptions::PyRuntimeError::new_err(
461                "No active pickling state",
462            )),
463        }
464    })
465}
466
467/// Push a pending pickle to the active pickling state (Rust-only).
468///
469/// This is used by __reduce__ implementations to register a PyShared
470/// that must be resolved before the pickle is complete.
471///
472/// Returns an error if there is no active pickling state or if pending
473/// pickles are not allowed in the current pickling context.
474pub fn push_pending_pickle(py_shared: Py<PyShared>) -> PyResult<()> {
475    ACTIVE_PICKLING_STATE.with(|cell| {
476        let mut state = cell.borrow_mut();
477        match state.as_mut() {
478            Some(s) => {
479                if !s.allow_pending_pickles {
480                    return Err(pyo3::exceptions::PyRuntimeError::new_err(
481                        "Pending pickles are not allowed in the current pickling context",
482                    ));
483                }
484                s.pending_pickles.push_back(py_shared);
485                Ok(())
486            }
487            None => Err(pyo3::exceptions::PyRuntimeError::new_err(
488                "No active pickling state",
489            )),
490        }
491    })
492}
493
494/// Reduce a PyShared for pickling.
495///
496/// This function implements the pickle protocol for PyShared:
497/// 1. If the shared is already finished, return (Shared.from_value, (value,))
498/// 2. If pending pickles are allowed, push it as a pending pickle and return (pop_pending_pickle, ())
499/// 3. Otherwise, block on the shared and return (Shared.from_value, (value,))
500pub fn reduce_shared<'py>(
501    py: Python<'py>,
502    py_shared: &Bound<'py, PyShared>,
503) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyTuple>)> {
504    // First, check if the shared is already finished
505    if let Some(value) = py_shared.borrow().poll()? {
506        let from_value = shared_class(py).getattr("from_value")?;
507        let args = PyTuple::new(py, [value])?;
508        return Ok((from_value, args));
509    }
510
511    // Try to push as a pending pickle (will fail if not allowed or no active state)
512    let py_shared_py: Py<PyShared> = py_shared.clone().unbind();
513    if push_pending_pickle(py_shared_py).is_ok() {
514        let pop_fn = pop_pending_pickle_fn(py);
515        let args = PyTuple::empty(py);
516        return Ok((pop_fn, args));
517    }
518
519    // Fall back to blocking on the shared
520    let value = PyShared::block_on(py_shared.borrow(), py)?;
521    let from_value = shared_class(py).getattr("from_value")?;
522    let args = PyTuple::new(py, [value])?;
523    Ok((from_value, args))
524}
525
526/// Pickle a Python object into a [`Buffer`].
527///
528/// This is the shared pickling core. The caller is responsible for setting up
529/// the [`ActivePicklingGuard`] before calling this function.
530fn pickle_into_buffer(py: Python<'_>, obj: &Py<PyAny>, buffer: &Py<Buffer>) -> PyResult<()> {
531    // Ensure the cloudpickle monkeypatch for RemoteImportLoader is applied.
532    pickle_monkeypatch(py);
533
534    // If torch is loaded, use the torch-aware pickler that handles
535    // torch storage types via dispatch_table.
536    if maybe_torch_fn(py).call0()?.is_truthy()? {
537        torch_dump_fn(py).call1((obj, buffer.bind(py)))?;
538    } else {
539        let pickler = cloudpickle(py)
540            .getattr("Pickler")?
541            .call1((buffer.bind(py),))?;
542        pickler.call_method1("dump", (obj,))?;
543    }
544
545    Ok(())
546}
547
548/// Pickle a Python object and return the serialized data as a [`Part`].
549///
550/// This is a simplified variant of [`pickle`] that disallows pending pickles
551/// and tensor engine references, and returns the raw serialized bytes instead
552/// of a [`PicklingState`].
553pub fn pickle_to_part(py: Python<'_>, obj: &Py<PyAny>) -> PyResult<Part> {
554    let active = ActivePicklingState::new(false, false);
555    let buffer = Py::new(py, Buffer::default())?;
556    let _guard = ActivePicklingGuard::enter(active);
557
558    pickle_into_buffer(py, obj, &buffer)?;
559
560    Ok(buffer.borrow_mut(py).take_part())
561}
562
/// Pickle an object with support for pending pickles and tensor engine references.
///
/// This function creates a PicklingState and calls cloudpickle.dumps with
/// an active thread-local PicklingState, allowing __reduce__ implementations
/// to push tensor engine references and pending pickles.
///
/// # Arguments
/// * `obj` - The Python object to pickle
/// * `allow_pending_pickles` - If true, allow PyShared values to be registered as pending
/// * `allow_tensor_engine_references` - If true, allow tensor engine references to be registered
///
/// # Returns
/// A PicklingState containing the pickled buffer and any registered references/pending pickles
#[pyfunction]
#[pyo3(signature = (obj, allow_pending_pickles=true, allow_tensor_engine_references=true))]
pub fn pickle(
    py: Python<'_>,
    obj: Py<PyAny>,
    allow_pending_pickles: bool,
    allow_tensor_engine_references: bool,
) -> PyResult<PicklingState> {
    let active = ActivePicklingState::new(allow_pending_pickles, allow_tensor_engine_references);
    let buffer = Py::new(py, Buffer::default())?;
    let _guard = ActivePicklingGuard::enter(active);

    pickle_into_buffer(py, &obj, &buffer)?;

    // Take the state (which may have been modified during pickling).
    // The guard is still alive here; its Drop will later restore the
    // previous state over the now-empty slot, so taking is safe.
    // Note: if `pickle_into_buffer` returned Err above, the guard alone
    // cleans up and the pushed references are discarded with the state.
    let active = ACTIVE_PICKLING_STATE
        .with(|cell| cell.borrow_mut().take())
        .expect("active pickling state should still be set");

    // Take the Part (zero-copy fragmented buffer) directly.
    let part = buffer.borrow_mut(py).take_part();
    let inner = active.into_pickling_state(part);
    Ok(PicklingState { inner: Some(inner) })
}

602pub(crate) fn unpickle<'py>(
603    py: Python<'py>,
604    buffer: crate::buffers::FrozenBuffer,
605) -> PyResult<Bound<'py, PyAny>> {
606    _unpickle(py).call1((buffer.into_py_any(py)?,))
607}
608
609/// Register the pickle Python bindings into the given module.
610pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
611    module.add_class::<PicklingState>()?;
612    module.add_class::<PendingMessage>()?;
613    module.add_function(wrap_pyfunction!(pickle, module)?)?;
614    module.add_function(wrap_pyfunction!(
615        push_tensor_engine_reference_if_active,
616        module
617    )?)?;
618    module.add_function(wrap_pyfunction!(pop_tensor_engine_reference, module)?)?;
619    module.add_function(wrap_pyfunction!(pop_pending_pickle, module)?)?;
620    Ok(())
621}