Skip to main content

monarch_types/
python.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use std::collections::HashMap;
10
11use pyo3::Bound;
12use pyo3::IntoPyObject;
13use pyo3::IntoPyObjectExt;
14use pyo3::PyAny;
15use pyo3::PyResult;
16use pyo3::Python;
17use pyo3::prelude::*;
18use pyo3::types::PyDict;
19use pyo3::types::PyNone;
20use pyo3::types::PyTuple;
21use serde::Deserialize;
22use serde::Serialize;
23
24/// A variant of `pyo3::IntoPyObject` used to wrap unsafe impls and propagates the
25/// unsafety to the caller.
26pub trait TryIntoPyObjectUnsafe<'py, P> {
27    /// # Safety
28    ///
29    /// The caller must ensure self is valid for the duration of the call
30    /// and that the type-erased pointer invariants of the trait are upheld.
31    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, P>>;
32}
33
34/// Helper impl for casting into args for python functions calls.
35impl<'a, 'py, T> TryIntoPyObjectUnsafe<'py, PyTuple> for &'a Vec<T>
36where
37    &'a T: TryIntoPyObjectUnsafe<'py, PyAny>,
38    T: 'a,
39{
40    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
41        PyTuple::new(
42            py,
43            self.iter()
44                // SAFETY: Safety requirements are propagated via the `unsafe`
45                // tag on this method.
46                .map(|v| unsafe { v.try_to_object_unsafe(py) })
47                .collect::<Result<Vec<_>, _>>()?,
48        )
49    }
50}
51
52/// Helper impl for casting into kwargs for python functions calls.
53impl<'a, 'py, K, V> TryIntoPyObjectUnsafe<'py, PyDict> for &'a HashMap<K, V>
54where
55    &'a K: IntoPyObject<'py> + std::cmp::Eq + std::hash::Hash,
56    &'a V: TryIntoPyObjectUnsafe<'py, PyAny>,
57    K: 'a,
58    V: 'a,
59{
60    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
61        let dict = PyDict::new(py);
62        for (key, val) in self {
63            // SAFETY: Safety requirements are propagated via the `unsafe`
64            // tag on this method.
65            dict.set_item(key, unsafe { val.try_to_object_unsafe(py) }?)?;
66        }
67        Ok(dict)
68    }
69}
70
71/// A wrapper around `PyErr` that contains a serialized traceback.
72#[derive(Debug, Clone, Serialize, Deserialize, derive_more::Error)]
73pub struct SerializablePyErr {
74    pub message: String,
75}
76
77impl SerializablePyErr {
78    pub fn from(py: Python, err: &PyErr) -> Self {
79        // first construct the full traceback including any python frames that were used
80        // to invoke where we currently are. This is pre-pended to the traceback of the
81        // currently unwinded frames (err.traceback())
82        let inspect = py.import("inspect").unwrap();
83        let types = py.import("types").unwrap();
84        let traceback_type = types.getattr("TracebackType").unwrap();
85        let traceback = py.import("traceback").unwrap();
86
87        let mut f = inspect
88            .call_method0("currentframe")
89            .unwrap_or(PyNone::get(py).to_owned().into_any());
90        let mut tb: Bound<'_, PyAny> = err.traceback(py).into_bound_py_any(py).unwrap();
91        while !f.is_none() {
92            let lasti = f.getattr("f_lasti").unwrap();
93            let lineno = f.getattr("f_lineno").unwrap();
94            let back = f.getattr("f_back").unwrap();
95            tb = traceback_type.call1((tb, f, lasti, lineno)).unwrap();
96            f = back;
97        }
98
99        let traceback_exception = traceback.getattr("TracebackException").unwrap();
100
101        let tb = traceback_exception
102            .call1((err.get_type(py), err.value(py), tb))
103            .unwrap();
104
105        let message: String = tb
106            .getattr("format")
107            .unwrap()
108            .call0()
109            .unwrap()
110            .try_iter()
111            .unwrap()
112            .map(|x| -> String { x.unwrap().extract().unwrap() })
113            .collect::<Vec<String>>()
114            .join("");
115
116        Self { message }
117    }
118
119    pub fn from_fn<'py>(py: Python<'py>) -> impl Fn(PyErr) -> Self + 'py {
120        move |err| Self::from(py, &err)
121    }
122}
123
124impl std::fmt::Display for SerializablePyErr {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        write!(f, "{}", self.message)
127    }
128}
129
130impl<T> From<T> for SerializablePyErr
131where
132    T: Into<PyErr>,
133{
134    fn from(value: T) -> Self {
135        Python::attach(|py| SerializablePyErr::from(py, &value.into()))
136    }
137}
138
#[cfg(test)]
mod tests {
    use pyo3::Python;
    use pyo3::ffi::c_str;
    use pyo3::indoc::indoc;
    use pyo3::prelude::*;
    use timed_test::async_timed_test;

    use crate::SerializablePyErr;

    /// Checks that `SerializablePyErr::from` renders the full traceback of a nested
    /// Python call chain, including file names, line numbers, and the exception line.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_serializable_py_err() {
        Python::initialize();
        Python::attach(|py| {
            let module = PyModule::from_code(
                py,
                c_str!(indoc! {r#"
                        def func1():
                            raise Exception("test")

                        def func2():
                            func1()

                        def func3():
                            func2()
                    "#}),
                c_str!("test_helpers.py"),
                c_str!("test_helpers"),
            )?;

            let err = SerializablePyErr::from(py, &module.call_method0("func3").unwrap_err());
            assert_eq!(
                err.message.as_str(),
                indoc! {r#"
                    Traceback (most recent call last):
                      File "test_helpers.py", line 8, in func3
                      File "test_helpers.py", line 5, in func2
                      File "test_helpers.py", line 2, in func1
                    Exception: test
                "#}
            );

            PyResult::Ok(())
        })
        // Fail loudly if module setup raises: discarding this result would let the
        // test return early via `?` and pass without ever reaching the assertion.
        .expect("failed to set up Python test module");
    }
}