monarch_types/
python.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10
11use pyo3::Bound;
12use pyo3::IntoPyObject;
13use pyo3::IntoPyObjectExt;
14use pyo3::PyAny;
15use pyo3::PyResult;
16use pyo3::Python;
17use pyo3::prelude::*;
18use pyo3::types::PyDict;
19use pyo3::types::PyNone;
20use pyo3::types::PyTuple;
21use serde::Deserialize;
22use serde::Serialize;
23
24/// A variant of `pyo3::IntoPyObject` used to wrap unsafe impls and propagates the
25/// unsafety to the caller.
26pub trait TryIntoPyObjectUnsafe<'py, P> {
27    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, P>>;
28}
29
30/// Helper impl for casting into args for python functions calls.
31impl<'a, 'py, T> TryIntoPyObjectUnsafe<'py, PyTuple> for &'a Vec<T>
32where
33    &'a T: TryIntoPyObjectUnsafe<'py, PyAny>,
34    T: 'a,
35{
36    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
37        PyTuple::new(
38            py,
39            self.iter()
40                // SAFETY: Safety requirements are propagated via the `unsafe`
41                // tag on this method.
42                .map(|v| unsafe { v.try_to_object_unsafe(py) })
43                .collect::<Result<Vec<_>, _>>()?,
44        )
45    }
46}
47
48/// Helper impl for casting into kwargs for python functions calls.
49impl<'a, 'py, K, V> TryIntoPyObjectUnsafe<'py, PyDict> for &'a HashMap<K, V>
50where
51    &'a K: IntoPyObject<'py> + std::cmp::Eq + std::hash::Hash,
52    &'a V: TryIntoPyObjectUnsafe<'py, PyAny>,
53    K: 'a,
54    V: 'a,
55{
56    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
57        let dict = PyDict::new(py);
58        for (key, val) in self {
59            // SAFETY: Safety requirements are propagated via the `unsafe`
60            // tag on this method.
61            dict.set_item(key, unsafe { val.try_to_object_unsafe(py) }?)?;
62        }
63        Ok(dict)
64    }
65}
66
67/// A wrapper around `PyErr` that contains a serialized traceback.
68#[derive(Debug, Clone, Serialize, Deserialize, derive_more::Error)]
69pub struct SerializablePyErr {
70    pub message: String,
71}
72
73impl SerializablePyErr {
74    pub fn from(py: Python, err: &PyErr) -> Self {
75        // first construct the full traceback including any python frames that were used
76        // to invoke where we currently are. This is pre-pended to the traceback of the
77        // currently unwinded frames (err.traceback())
78        let inspect = py.import("inspect").unwrap();
79        let types = py.import("types").unwrap();
80        let traceback_type = types.getattr("TracebackType").unwrap();
81        let traceback = py.import("traceback").unwrap();
82
83        let mut f = inspect
84            .call_method0("currentframe")
85            .unwrap_or(PyNone::get(py).to_owned().into_any());
86        let mut tb: Bound<'_, PyAny> = err.traceback(py).into_bound_py_any(py).unwrap();
87        while !f.is_none() {
88            let lasti = f.getattr("f_lasti").unwrap();
89            let lineno = f.getattr("f_lineno").unwrap();
90            let back = f.getattr("f_back").unwrap();
91            tb = traceback_type.call1((tb, f, lasti, lineno)).unwrap();
92            f = back;
93        }
94
95        let traceback_exception = traceback.getattr("TracebackException").unwrap();
96
97        let tb = traceback_exception
98            .call1((err.get_type(py), err.value(py), tb))
99            .unwrap();
100
101        let message: String = tb
102            .getattr("format")
103            .unwrap()
104            .call0()
105            .unwrap()
106            .try_iter()
107            .unwrap()
108            .map(|x| -> String { x.unwrap().extract().unwrap() })
109            .collect::<Vec<String>>()
110            .join("");
111
112        Self { message }
113    }
114
115    pub fn from_fn<'py>(py: Python<'py>) -> impl Fn(PyErr) -> Self + 'py {
116        move |err| Self::from(py, &err)
117    }
118}
119
120impl std::fmt::Display for SerializablePyErr {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        write!(f, "{}", self.message)
123    }
124}
125
126impl<T> From<T> for SerializablePyErr
127where
128    T: Into<PyErr>,
129{
130    fn from(value: T) -> Self {
131        Python::with_gil(|py| SerializablePyErr::from(py, &value.into()))
132    }
133}
134
#[cfg(test)]
mod tests {
    use pyo3::Python;
    use pyo3::ffi::c_str;
    use pyo3::indoc::indoc;
    use pyo3::prelude::*;
    use timed_test::async_timed_test;

    use crate::SerializablePyErr;

    /// Checks that an exception raised through several nested Python calls is
    /// rendered as the standard Python traceback text.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_serializable_py_err() {
        pyo3::prepare_freethreaded_python();
        // The closure's `PyResult` is checked rather than discarded, so a
        // failure to build the helper module fails the test instead of
        // silently skipping every assertion.
        Python::with_gil(|py| -> PyResult<()> {
            let module = PyModule::from_code(
                py,
                c_str!(indoc! {r#"
                        def func1():
                            raise Exception("test")

                        def func2():
                            func1()

                        def func3():
                            func2()
                    "#}),
                c_str!("test_helpers.py"),
                c_str!("test_helpers"),
            )?;

            let py_err = module
                .call_method0("func3")
                .expect_err("func3 is expected to raise");
            let err = SerializablePyErr::from(py, &py_err);
            assert_eq!(
                err.message.as_str(),
                indoc! {r#"
                    Traceback (most recent call last):
                      File "test_helpers.py", line 8, in func3
                      File "test_helpers.py", line 5, in func2
                      File "test_helpers.py", line 2, in func1
                    Exception: test
                "#}
            );

            Ok(())
        })
        .expect("failed to set up the Python test module");
    }
}