// monarch_hyperactor/pickle.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//! Deferred pickling support for Monarch.
//!
//! This module provides utilities for deferring the pickling of objects
//! that contain async values (futures/tasks) that must be resolved before
//! the final pickle can be produced.
14
15use std::cell::RefCell;
16use std::collections::VecDeque;
17
18use monarch_types::py_global;
19use pyo3::IntoPyObjectExt;
20use pyo3::prelude::*;
21use pyo3::types::PyList;
22use pyo3::types::PyTuple;
23use serde_multipart::Part;
24
25use crate::actor::PythonMessage;
26use crate::actor::PythonMessageKind;
27use crate::buffers::Buffer;
28use crate::pytokio::PyShared;
29
// Python helper used to reconstruct an object graph from a pickled
// buffer plus a list of "unflatten values" (including placeholders).
py_global!(unflatten, "monarch._src.actor.pickle", "unflatten");

// Python helper used to pickle an object graph, optionally using a
// filter to replace certain values with placeholders (e.g.
// `PendingPickle`).
//
// We use `flatten`/`unflatten` to support "deferred pickling":
// initially pickle with placeholders, then later resolve futures and
// re-pickle with concrete values.
py_global!(flatten, "monarch._src.actor.pickle", "flatten");

// cloudpickle module for serialization.
// NOTE(review): this resolves attribute "cloudpickle" from package
// "cloudpickle" (i.e. the `cloudpickle.cloudpickle` implementation
// submodule) rather than the package root — confirm `Pickler` and
// `loads` are both exposed there.
py_global!(cloudpickle, "cloudpickle", "cloudpickle");

// Stdlib `pickle.loads`; used by the module-level `unpickle` helper below.
py_global!(_unpickle, "pickle", "loads");

// Importing monarch._src.actor.pickle applies a monkeypatch to cloudpickle
// that injects RemoteImportLoader into pickled function globals, enabling
// source loading for pickle-by-value code on remote hosts (needed for
// debugger and tracebacks). We access this before pickling to ensure
// the monkeypatch is applied.
py_global!(
    pickle_monkeypatch,
    "monarch._src.actor.pickle",
    "_function_getstate"
);

// Check if torch has been loaded into the current Python process.
// Returns the torch module if loaded, otherwise None.
py_global!(maybe_torch_fn, "monarch._src.actor.pickle", "maybe_torch");

// Torch-aware dump function: uses a Pickler subclass with dispatch_table
// entries for torch storage types (UntypedStorage, TypedStorage, etc.).
py_global!(torch_dump_fn, "monarch._src.actor.pickle", "torch_dump");

// Torch-aware loads function: wraps cloudpickle.loads with
// torch.utils._python_dispatch._disable_current_modes().
py_global!(torch_loads_fn, "monarch._src.actor.pickle", "torch_loads");

// Shared class for pickling PyShared values (provides `from_value`,
// used by `reduce_shared`).
py_global!(
    shared_class,
    "monarch._rust_bindings.monarch_hyperactor.pytokio",
    "Shared"
);

// pop_pending_pickle function for unpickling deferred PyShared values;
// used as the reconstructor callable returned by `reduce_shared`.
py_global!(
    pop_pending_pickle_fn,
    "monarch._rust_bindings.monarch_hyperactor.pickle",
    "pop_pending_pickle"
);

// Thread-local storage for the active pickling state.
// Set by pickle/unpickle operations so free functions used in __reduce__
// implementations can access it.
thread_local! {
    static ACTIVE_PICKLING_STATE: RefCell<Option<ActivePicklingState>> = const { RefCell::new(None) };
}
91
92/// RAII guard that sets the thread-local `ACTIVE_PICKLING_STATE` on creation
93/// and restores the previous state (if any) on drop. This supports nesting:
94/// if a guard already exists, the new guard saves the old state and restores
95/// it when dropped, even on panic.
96struct ActivePicklingGuard {
97    previous: Option<ActivePicklingState>,
98}
99
100impl ActivePicklingGuard {
101    /// Set `state` as the active pickling state, saving any existing state.
102    fn enter(state: ActivePicklingState) -> Self {
103        let previous = ACTIVE_PICKLING_STATE.with(|cell| cell.borrow_mut().replace(state));
104        Self { previous }
105    }
106}
107
108impl Drop for ActivePicklingGuard {
109    fn drop(&mut self) {
110        ACTIVE_PICKLING_STATE.with(|cell| {
111            *cell.borrow_mut() = self.previous.take();
112        });
113    }
114}
115
116/// State maintained during active pickling/unpickling operations.
117///
118/// This is the thread-local state used while cloudpickle is running.
119/// It collects tensor engine references and pending pickles during serialization.
120struct ActivePicklingState {
121    /// References to tensor engine objects that need special handling.
122    tensor_engine_references: VecDeque<Py<PyAny>>,
123    /// Pending pickles (PyShared values) that must be resolved.
124    pending_pickles: VecDeque<Py<PyShared>>,
125    /// Whether pending pickles are allowed in this pickling context.
126    allow_pending_pickles: bool,
127    /// Whether tensor engine references are allowed in this pickling context.
128    allow_tensor_engine_references: bool,
129}
130
131impl ActivePicklingState {
132    /// Create a new ActivePicklingState.
133    fn new(allow_pending_pickles: bool, allow_tensor_engine_references: bool) -> Self {
134        Self {
135            tensor_engine_references: VecDeque::new(),
136            pending_pickles: VecDeque::new(),
137            allow_pending_pickles,
138            allow_tensor_engine_references,
139        }
140    }
141
142    /// Convert this active state into a frozen PicklingState.
143    fn into_pickling_state(self, buffer: crate::buffers::FrozenBuffer) -> PicklingStateInner {
144        PicklingStateInner {
145            buffer,
146            tensor_engine_references: self.tensor_engine_references,
147            pending_pickles: self.pending_pickles,
148        }
149    }
150}
151
152/// Inner data for a completed pickling operation.
153///
154/// This contains the frozen pickled bytes and any collected references.
155/// Does not require GIL for access to the FrozenBuffer.
156pub struct PicklingStateInner {
157    /// The pickled bytes as a FrozenBuffer (zero-copy).
158    buffer: crate::buffers::FrozenBuffer,
159    /// References to tensor engine objects that need special handling.
160    tensor_engine_references: VecDeque<Py<PyAny>>,
161    /// Pending pickles (PyShared values) that must be resolved.
162    pending_pickles: VecDeque<Py<PyShared>>,
163}
164
165impl PicklingStateInner {
166    /// Get a reference to the pending pickles.
167    pub fn pending_pickles(&self) -> &VecDeque<Py<PyShared>> {
168        &self.pending_pickles
169    }
170
171    /// Take the FrozenBuffer (pickled bytes) from this inner state.
172    pub fn take_buffer(self) -> crate::buffers::FrozenBuffer {
173        self.buffer
174    }
175}
176
177/// Python-visible wrapper for the result of a pickling operation.
178///
179/// Contains the pickled bytes and any tensor engine references or pending
180/// pickles that were collected during serialization.
181#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.pickle")]
182pub struct PicklingState {
183    inner: Option<PicklingStateInner>,
184}
185
186impl PicklingState {
187    pub fn take_inner(&mut self) -> PyResult<PicklingStateInner> {
188        self.inner.take().ok_or_else(|| {
189            pyo3::exceptions::PyRuntimeError::new_err("PicklingState has already been consumed")
190        })
191    }
192
193    fn inner_ref(&self) -> PyResult<&PicklingStateInner> {
194        self.inner.as_ref().ok_or_else(|| {
195            pyo3::exceptions::PyRuntimeError::new_err("PicklingState has already been consumed")
196        })
197    }
198}
199
200#[pymethods]
201impl PicklingState {
202    /// Create a new PicklingState from a buffer and optional tensor engine references.
203    ///
204    /// This is used for unpickling received messages that may contain tensor engine
205    /// references that need to be restored during deserialization.
206    #[new]
207    #[pyo3(signature = (buffer, tensor_engine_references=None))]
208    fn py_new(
209        buffer: PyRef<'_, crate::buffers::FrozenBuffer>,
210        tensor_engine_references: Option<&Bound<'_, PyList>>,
211    ) -> PyResult<Self> {
212        let refs: VecDeque<Py<PyAny>> = tensor_engine_references
213            .map(|list| list.iter().map(|item| item.unbind()).collect())
214            .unwrap_or_default();
215
216        Ok(Self {
217            inner: Some(PicklingStateInner {
218                buffer: buffer.clone(),
219                tensor_engine_references: refs,
220                pending_pickles: VecDeque::new(),
221            }),
222        })
223    }
224
225    /// Get a copy of all tensor engine references from this pickling state.
226    ///
227    /// Returns a Python list containing copies of the tensor engine references.
228    fn tensor_engine_references(&self, py: Python<'_>) -> PyResult<Py<PyList>> {
229        let inner = self.inner_ref()?;
230        let refs: Vec<Py<PyAny>> = inner
231            .tensor_engine_references
232            .iter()
233            .map(|r| r.clone_ref(py))
234            .collect();
235        Ok(PyList::new(py, refs)?.unbind())
236    }
237
238    /// Get the buffer from this pickling state.
239    ///
240    /// Returns a FrozenBuffer containing the pickled bytes.
241    /// This does not consume the PicklingState.
242    fn buffer(&self) -> PyResult<crate::buffers::FrozenBuffer> {
243        let inner = self.inner_ref()?;
244        Ok(inner.buffer.clone())
245    }
246
247    /// Unpickle the buffer contents.
248    ///
249    /// This consumes the PicklingState. It will fail if there are any pending
250    /// pickles that haven't been resolved.
251    fn unpickle(&mut self, py: Python<'_>) -> PyResult<Py<PyAny>> {
252        let inner = self.take_inner()?;
253
254        // Verify all pending pickles are resolved before unpickling
255        for pending in &inner.pending_pickles {
256            if pending.borrow(py).poll()?.is_none() {
257                return Err(pyo3::exceptions::PyRuntimeError::new_err(
258                    "Cannot unpickle: there are unresolved pending pickles",
259                ));
260            }
261        }
262
263        // Set up an active state for unpickling (to handle pop calls).
264        // The guard restores any previous state on drop (including on panic).
265        let mut active = ActivePicklingState::new(false, false);
266        active.pending_pickles = inner.pending_pickles;
267        active.tensor_engine_references = inner.tensor_engine_references;
268
269        let _guard = ActivePicklingGuard::enter(active);
270
271        // Unpickle the object. If torch is loaded, use torch_loads which
272        // disables dispatch modes during unpickling.
273        let result = if maybe_torch_fn(py).call0()?.is_truthy()? {
274            torch_loads_fn(py).call1((inner.buffer,))
275        } else {
276            cloudpickle(py).getattr("loads")?.call1((inner.buffer,))
277        };
278
279        result.map(|obj| obj.unbind())
280    }
281}
282
283impl PicklingState {
284    /// Resolve all pending pickles and return a new PicklingState without pending pickles.
285    ///
286    /// This consumes the PicklingState. It:
287    /// 1. If there are no pending pickles, returns self immediately
288    /// 2. Otherwise, awaits all pending pickles until they're finished
289    /// 3. Calls unpickle to reconstruct the object
290    /// 4. Calls pickle again to get a new PicklingState without pending pickles
291    pub async fn resolve(mut self) -> PyResult<PicklingState> {
292        // Short-circuit if there are no pending pickles
293        if self.inner_ref()?.pending_pickles.is_empty() {
294            return Ok(self);
295        }
296
297        // Await all pending pickles to ensure they're resolved
298        let pending: Vec<Py<PyShared>> = Python::attach(|py| {
299            self.inner_ref().map(|inner| {
300                inner
301                    .pending_pickles
302                    .iter()
303                    .map(|p| p.clone_ref(py))
304                    .collect()
305            })
306        })?;
307
308        for pending_pickle in pending {
309            let mut task = Python::attach(|py| pending_pickle.borrow(py).task())?;
310            task.take_task()?.await?;
311        }
312
313        // Unpickle (pending pickles are now resolved) and re-pickle without allowing new ones
314        Python::attach(|py| {
315            let obj = self.unpickle(py)?;
316            pickle(py, obj, false, true)
317        })
318    }
319}
320
321/// A message that is pending resolution of async values before it can be sent.
322///
323/// Contains a `PythonMessageKind` and a `PicklingState`. The `PicklingState` may contain
324/// pending pickles (unresolved async values) that must be resolved before the message
325/// can be converted into a `PythonMessage`.
326#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.pickle")]
327pub struct PendingMessage {
328    pub(crate) kind: PythonMessageKind,
329    state: PicklingState,
330}
331
332impl PendingMessage {
333    /// Create a new PendingMessage from a kind and pickling state.
334    pub fn new(kind: PythonMessageKind, state: PicklingState) -> Self {
335        Self { kind, state }
336    }
337
338    /// Take ownership of the inner state from a mutable reference.
339    ///
340    /// This is used by pyo3 pymethods that receive `&mut PendingMessage`
341    /// but need to pass ownership to the trait method.
342    pub fn take(&mut self) -> PyResult<PendingMessage> {
343        let inner = self.state.take_inner()?;
344        Ok(PendingMessage {
345            kind: std::mem::take(&mut self.kind),
346            state: PicklingState { inner: Some(inner) },
347        })
348    }
349
350    /// Resolve all pending pickles and convert this into a PythonMessage.
351    ///
352    /// This is an async method that:
353    /// 1. Awaits all pending pickles in the PicklingState
354    /// 2. Re-pickles the resolved object
355    /// 3. Returns a PythonMessage with the resolved bytes (no GIL needed for final step)
356    pub async fn resolve(self) -> PyResult<PythonMessage> {
357        // Resolve the pickling state (awaits all pending pickles and re-pickles)
358        let mut resolved_state = self.state.resolve().await?;
359
360        // Take the FrozenBuffer directly - no GIL needed since FrozenBuffer doesn't contain Py<>
361        let inner = resolved_state.take_inner()?;
362        Ok(PythonMessage::new_from_buf(self.kind, inner.take_buffer()))
363    }
364}
365
366#[pymethods]
367impl PendingMessage {
368    /// Create a new PendingMessage from a kind and pickling state.
369    #[new]
370    pub fn py_new(
371        kind: PythonMessageKind,
372        mut state: PyRefMut<'_, PicklingState>,
373    ) -> PyResult<Self> {
374        // Take the inner state from the PicklingState
375        let inner = state.take_inner()?;
376        Ok(Self {
377            kind,
378            state: PicklingState { inner: Some(inner) },
379        })
380    }
381
382    /// Get the message kind.
383    #[getter]
384    fn kind(&self) -> PythonMessageKind {
385        self.kind.clone()
386    }
387}
388
389/// Push a tensor engine reference to the active pickling state if one is active.
390///
391/// This is called from Python during pickling when a tensor engine object
392/// is encountered that needs special handling.
393///
394/// Returns False if there is no active pickling state.
395/// Returns True if the reference was successfully pushed.
396/// Raises an error if tensor engine references are not allowed in the current pickling context.
397#[pyfunction]
398fn push_tensor_engine_reference_if_active(obj: Py<PyAny>) -> PyResult<bool> {
399    ACTIVE_PICKLING_STATE.with(|cell| {
400        let mut state = cell.borrow_mut();
401        match state.as_mut() {
402            Some(s) => {
403                if !s.allow_tensor_engine_references {
404                    return Err(pyo3::exceptions::PyRuntimeError::new_err(
405                        "Tensor engine references are not allowed in the current pickling context",
406                    ));
407                }
408                s.tensor_engine_references.push_back(obj);
409                Ok(true)
410            }
411            None => Ok(false),
412        }
413    })
414}
415
416/// Pop a tensor engine reference from the active pickling state.
417///
418/// This is called from Python during unpickling to retrieve tensor engine
419/// objects in the order they were pushed.
420#[pyfunction]
421fn pop_tensor_engine_reference(py: Python<'_>) -> PyResult<Py<PyAny>> {
422    ACTIVE_PICKLING_STATE
423        .with(|cell| {
424            let mut state = cell.borrow_mut();
425            match state.as_mut() {
426                Some(s) => s.tensor_engine_references.pop_front().ok_or_else(|| {
427                    pyo3::exceptions::PyRuntimeError::new_err(
428                        "No tensor engine references remaining",
429                    )
430                }),
431                None => Err(pyo3::exceptions::PyRuntimeError::new_err(
432                    "No active pickling state",
433                )),
434            }
435        })
436        .map(|obj| obj.clone_ref(py))
437}
438
439/// Pop a pending pickle from the active pickling state.
440///
441/// This is called from Python during unpickling to retrieve the PyShared
442/// object that was deferred during pickling.
443#[pyfunction]
444fn pop_pending_pickle(py: Python<'_>) -> PyResult<Py<PyShared>> {
445    ACTIVE_PICKLING_STATE.with(|cell| {
446        let mut state = cell.borrow_mut();
447        match state.as_mut() {
448            Some(s) => {
449                let shared = s.pending_pickles.pop_front().ok_or_else(|| {
450                    pyo3::exceptions::PyRuntimeError::new_err("No pending pickles remaining")
451                })?;
452                Ok(shared.clone_ref(py))
453            }
454            None => Err(pyo3::exceptions::PyRuntimeError::new_err(
455                "No active pickling state",
456            )),
457        }
458    })
459}
460
461/// Push a pending pickle to the active pickling state (Rust-only).
462///
463/// This is used by __reduce__ implementations to register a PyShared
464/// that must be resolved before the pickle is complete.
465///
466/// Returns an error if there is no active pickling state or if pending
467/// pickles are not allowed in the current pickling context.
468pub fn push_pending_pickle(py_shared: Py<PyShared>) -> PyResult<()> {
469    ACTIVE_PICKLING_STATE.with(|cell| {
470        let mut state = cell.borrow_mut();
471        match state.as_mut() {
472            Some(s) => {
473                if !s.allow_pending_pickles {
474                    return Err(pyo3::exceptions::PyRuntimeError::new_err(
475                        "Pending pickles are not allowed in the current pickling context",
476                    ));
477                }
478                s.pending_pickles.push_back(py_shared);
479                Ok(())
480            }
481            None => Err(pyo3::exceptions::PyRuntimeError::new_err(
482                "No active pickling state",
483            )),
484        }
485    })
486}
487
488/// Reduce a PyShared for pickling.
489///
490/// This function implements the pickle protocol for PyShared:
491/// 1. If the shared is already finished, return (Shared.from_value, (value,))
492/// 2. If pending pickles are allowed, push it as a pending pickle and return (pop_pending_pickle, ())
493/// 3. Otherwise, block on the shared and return (Shared.from_value, (value,))
494pub fn reduce_shared<'py>(
495    py: Python<'py>,
496    py_shared: &Bound<'py, PyShared>,
497) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyTuple>)> {
498    // First, check if the shared is already finished
499    if let Some(value) = py_shared.borrow().poll()? {
500        let from_value = shared_class(py).getattr("from_value")?;
501        let args = PyTuple::new(py, [value])?;
502        return Ok((from_value, args));
503    }
504
505    // Try to push as a pending pickle (will fail if not allowed or no active state)
506    let py_shared_py: Py<PyShared> = py_shared.clone().unbind();
507    if push_pending_pickle(py_shared_py).is_ok() {
508        let pop_fn = pop_pending_pickle_fn(py);
509        let args = PyTuple::empty(py);
510        return Ok((pop_fn, args));
511    }
512
513    // Fall back to blocking on the shared
514    let value = PyShared::block_on(py_shared.borrow(), py)?;
515    let from_value = shared_class(py).getattr("from_value")?;
516    let args = PyTuple::new(py, [value])?;
517    Ok((from_value, args))
518}
519
520/// Pickle a Python object into a [`Buffer`].
521///
522/// This is the shared pickling core. The caller is responsible for setting up
523/// the [`ActivePicklingGuard`] before calling this function.
524fn pickle_into_buffer(py: Python<'_>, obj: &Py<PyAny>, buffer: &Py<Buffer>) -> PyResult<()> {
525    // Ensure the cloudpickle monkeypatch for RemoteImportLoader is applied.
526    pickle_monkeypatch(py);
527
528    // If torch is loaded, use the torch-aware pickler that handles
529    // torch storage types via dispatch_table.
530    if maybe_torch_fn(py).call0()?.is_truthy()? {
531        torch_dump_fn(py).call1((obj, buffer.bind(py)))?;
532    } else {
533        let pickler = cloudpickle(py)
534            .getattr("Pickler")?
535            .call1((buffer.bind(py),))?;
536        pickler.call_method1("dump", (obj,))?;
537    }
538
539    Ok(())
540}
541
542/// Pickle a Python object and return the serialized data as a [`Part`].
543///
544/// This is a simplified variant of [`pickle`] that disallows pending pickles
545/// and tensor engine references, and returns the raw serialized bytes instead
546/// of a [`PicklingState`].
547pub fn pickle_to_part(py: Python<'_>, obj: &Py<PyAny>) -> PyResult<Part> {
548    let active = ActivePicklingState::new(false, false);
549    let buffer = Py::new(py, Buffer::default())?;
550    let _guard = ActivePicklingGuard::enter(active);
551
552    pickle_into_buffer(py, obj, &buffer)?;
553
554    Ok(buffer.borrow_mut(py).take_part())
555}
556
557/// Pickle an object with support for pending pickles and tensor engine references.
558///
559/// This function creates a PicklingState and calls cloudpickle.dumps with
560/// an active thread-local PicklingState, allowing __reduce__ implementations
561/// to push tensor engine references and pending pickles.
562///
563/// # Arguments
564/// * `obj` - The Python object to pickle
565/// * `allow_pending_pickles` - If true, allow PyShared values to be registered as pending
566/// * `allow_tensor_engine_references` - If true, allow tensor engine references to be registered
567///
568/// # Returns
569/// A PicklingState containing the pickled buffer and any registered references/pending pickles
570#[pyfunction]
571#[pyo3(signature = (obj, allow_pending_pickles=true, allow_tensor_engine_references=true))]
572pub fn pickle(
573    py: Python<'_>,
574    obj: Py<PyAny>,
575    allow_pending_pickles: bool,
576    allow_tensor_engine_references: bool,
577) -> PyResult<PicklingState> {
578    let active = ActivePicklingState::new(allow_pending_pickles, allow_tensor_engine_references);
579    let buffer = Py::new(py, Buffer::default())?;
580    let _guard = ActivePicklingGuard::enter(active);
581
582    pickle_into_buffer(py, &obj, &buffer)?;
583
584    // Take the state (which may have been modified during pickling).
585    // The guard will restore the previous state on drop.
586    let active = ACTIVE_PICKLING_STATE
587        .with(|cell| cell.borrow_mut().take())
588        .expect("active pickling state should still be set");
589
590    // Convert to frozen PicklingState
591    let frozen_buffer = buffer.borrow_mut(py).freeze();
592    let inner = active.into_pickling_state(frozen_buffer);
593    Ok(PicklingState { inner: Some(inner) })
594}
595
596pub(crate) fn unpickle<'py>(
597    py: Python<'py>,
598    buffer: crate::buffers::FrozenBuffer,
599) -> PyResult<Bound<'py, PyAny>> {
600    _unpickle(py).call1((buffer.into_py_any(py)?,))
601}
602
603/// Register the pickle Python bindings into the given module.
604pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
605    module.add_class::<PicklingState>()?;
606    module.add_class::<PendingMessage>()?;
607    module.add_function(wrap_pyfunction!(pickle, module)?)?;
608    module.add_function(wrap_pyfunction!(
609        push_tensor_engine_reference_if_active,
610        module
611    )?)?;
612    module.add_function(wrap_pyfunction!(pop_tensor_engine_reference, module)?)?;
613    module.add_function(wrap_pyfunction!(pop_pending_pickle, module)?)?;
614    Ok(())
615}