1use std::collections::HashMap;
10
11use pyo3::Bound;
12use pyo3::IntoPyObject;
13use pyo3::IntoPyObjectExt;
14use pyo3::PyAny;
15use pyo3::PyResult;
16use pyo3::Python;
17use pyo3::prelude::*;
18use pyo3::types::PyDict;
19use pyo3::types::PyNone;
20use pyo3::types::PyTuple;
21use serde::Deserialize;
22use serde::Serialize;
23
24pub trait TryIntoPyObjectUnsafe<'py, P> {
27 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, P>>;
28}
29
30impl<'a, 'py, T> TryIntoPyObjectUnsafe<'py, PyTuple> for &'a Vec<T>
32where
33 &'a T: TryIntoPyObjectUnsafe<'py, PyAny>,
34 T: 'a,
35{
36 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
37 PyTuple::new(
38 py,
39 self.iter()
40 .map(|v| unsafe { v.try_to_object_unsafe(py) })
43 .collect::<Result<Vec<_>, _>>()?,
44 )
45 }
46}
47
48impl<'a, 'py, K, V> TryIntoPyObjectUnsafe<'py, PyDict> for &'a HashMap<K, V>
50where
51 &'a K: IntoPyObject<'py> + std::cmp::Eq + std::hash::Hash,
52 &'a V: TryIntoPyObjectUnsafe<'py, PyAny>,
53 K: 'a,
54 V: 'a,
55{
56 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
57 let dict = PyDict::new(py);
58 for (key, val) in self {
59 dict.set_item(key, unsafe { val.try_to_object_unsafe(py) }?)?;
62 }
63 Ok(dict)
64 }
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize, derive_more::Error)]
69pub struct SerializablePyErr {
70 pub message: String,
71}
72
73impl SerializablePyErr {
74 pub fn from(py: Python, err: &PyErr) -> Self {
75 let inspect = py.import("inspect").unwrap();
79 let types = py.import("types").unwrap();
80 let traceback_type = types.getattr("TracebackType").unwrap();
81 let traceback = py.import("traceback").unwrap();
82
83 let mut f = inspect
84 .call_method0("currentframe")
85 .unwrap_or(PyNone::get(py).to_owned().into_any());
86 let mut tb: Bound<'_, PyAny> = err.traceback(py).into_bound_py_any(py).unwrap();
87 while !f.is_none() {
88 let lasti = f.getattr("f_lasti").unwrap();
89 let lineno = f.getattr("f_lineno").unwrap();
90 let back = f.getattr("f_back").unwrap();
91 tb = traceback_type.call1((tb, f, lasti, lineno)).unwrap();
92 f = back;
93 }
94
95 let traceback_exception = traceback.getattr("TracebackException").unwrap();
96
97 let tb = traceback_exception
98 .call1((err.get_type(py), err.value(py), tb))
99 .unwrap();
100
101 let message: String = tb
102 .getattr("format")
103 .unwrap()
104 .call0()
105 .unwrap()
106 .try_iter()
107 .unwrap()
108 .map(|x| -> String { x.unwrap().extract().unwrap() })
109 .collect::<Vec<String>>()
110 .join("");
111
112 Self { message }
113 }
114
115 pub fn from_fn<'py>(py: Python<'py>) -> impl Fn(PyErr) -> Self + 'py {
116 move |err| Self::from(py, &err)
117 }
118}
119
120impl std::fmt::Display for SerializablePyErr {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 write!(f, "{}", self.message)
123 }
124}
125
126impl<T> From<T> for SerializablePyErr
127where
128 T: Into<PyErr>,
129{
130 fn from(value: T) -> Self {
131 Python::with_gil(|py| SerializablePyErr::from(py, &value.into()))
132 }
133}
134
#[cfg(test)]
mod tests {
    use pyo3::Python;
    use pyo3::ffi::c_str;
    use pyo3::indoc::indoc;
    use pyo3::prelude::*;
    use timed_test::async_timed_test;

    use crate::SerializablePyErr;

    /// An exception raised three Python calls deep must render as a complete traceback
    /// (outermost call first), followed by the exception line.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_serializable_py_err() {
        pyo3::prepare_freethreaded_python();
        // Check the closure's `PyResult` instead of binding it to an ignored variable.
        // Otherwise a failure in `PyModule::from_code` would make this test pass vacuously.
        Python::with_gil(|py| -> PyResult<()> {
            let module = PyModule::from_code(
                py,
                c_str!(indoc! {r#"
                    def func1():
                        raise Exception("test")

                    def func2():
                        func1()

                    def func3():
                        func2()
                "#}),
                c_str!("test_helpers.py"),
                c_str!("test_helpers"),
            )?;

            let err = SerializablePyErr::from(py, &module.call_method0("func3").unwrap_err());
            assert_eq!(
                err.message.as_str(),
                indoc! {r#"
                    Traceback (most recent call last):
                      File "test_helpers.py", line 8, in func3
                      File "test_helpers.py", line 5, in func2
                      File "test_helpers.py", line 2, in func1
                    Exception: test
                "#}
            );

            Ok(())
        })
        .expect("failed to set up the Python test module");
    }
}