1use std::collections::HashMap;
10
11use pyo3::Bound;
12use pyo3::IntoPyObject;
13use pyo3::IntoPyObjectExt;
14use pyo3::PyAny;
15use pyo3::PyResult;
16use pyo3::Python;
17use pyo3::prelude::*;
18use pyo3::types::PyDict;
19use pyo3::types::PyNone;
20use pyo3::types::PyTuple;
21use serde::Deserialize;
22use serde::Serialize;
23
24pub trait TryIntoPyObjectUnsafe<'py, P> {
27 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, P>>;
32}
33
34impl<'a, 'py, T> TryIntoPyObjectUnsafe<'py, PyTuple> for &'a Vec<T>
36where
37 &'a T: TryIntoPyObjectUnsafe<'py, PyAny>,
38 T: 'a,
39{
40 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
41 PyTuple::new(
42 py,
43 self.iter()
44 .map(|v| unsafe { v.try_to_object_unsafe(py) })
47 .collect::<Result<Vec<_>, _>>()?,
48 )
49 }
50}
51
52impl<'a, 'py, K, V> TryIntoPyObjectUnsafe<'py, PyDict> for &'a HashMap<K, V>
54where
55 &'a K: IntoPyObject<'py> + std::cmp::Eq + std::hash::Hash,
56 &'a V: TryIntoPyObjectUnsafe<'py, PyAny>,
57 K: 'a,
58 V: 'a,
59{
60 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
61 let dict = PyDict::new(py);
62 for (key, val) in self {
63 dict.set_item(key, unsafe { val.try_to_object_unsafe(py) }?)?;
66 }
67 Ok(dict)
68 }
69}
70
/// A Python exception rendered to plain text.
///
/// Only the rendered message string is stored, which is what lets the type
/// derive `Clone`, `Serialize` and `Deserialize`. `derive_more::Error`
/// supplies the `std::error::Error` impl; `Display` is implemented by hand.
#[derive(Debug, Clone, Serialize, Deserialize, derive_more::Error)]
pub struct SerializablePyErr {
    /// The formatted exception text. When built by `SerializablePyErr::from`
    /// this is the output of `traceback.TracebackException(...).format()`
    /// joined into one string.
    pub message: String,
}
76
77impl SerializablePyErr {
78 pub fn from(py: Python, err: &PyErr) -> Self {
79 let inspect = py.import("inspect").unwrap();
83 let types = py.import("types").unwrap();
84 let traceback_type = types.getattr("TracebackType").unwrap();
85 let traceback = py.import("traceback").unwrap();
86
87 let mut f = inspect
88 .call_method0("currentframe")
89 .unwrap_or(PyNone::get(py).to_owned().into_any());
90 let mut tb: Bound<'_, PyAny> = err.traceback(py).into_bound_py_any(py).unwrap();
91 while !f.is_none() {
92 let lasti = f.getattr("f_lasti").unwrap();
93 let lineno = f.getattr("f_lineno").unwrap();
94 let back = f.getattr("f_back").unwrap();
95 tb = traceback_type.call1((tb, f, lasti, lineno)).unwrap();
96 f = back;
97 }
98
99 let traceback_exception = traceback.getattr("TracebackException").unwrap();
100
101 let tb = traceback_exception
102 .call1((err.get_type(py), err.value(py), tb))
103 .unwrap();
104
105 let message: String = tb
106 .getattr("format")
107 .unwrap()
108 .call0()
109 .unwrap()
110 .try_iter()
111 .unwrap()
112 .map(|x| -> String { x.unwrap().extract().unwrap() })
113 .collect::<Vec<String>>()
114 .join("");
115
116 Self { message }
117 }
118
119 pub fn from_fn<'py>(py: Python<'py>) -> impl Fn(PyErr) -> Self + 'py {
120 move |err| Self::from(py, &err)
121 }
122}
123
124impl std::fmt::Display for SerializablePyErr {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 write!(f, "{}", self.message)
127 }
128}
129
130impl<T> From<T> for SerializablePyErr
131where
132 T: Into<PyErr>,
133{
134 fn from(value: T) -> Self {
135 Python::attach(|py| SerializablePyErr::from(py, &value.into()))
136 }
137}
138
#[cfg(test)]
mod tests {
    use pyo3::Python;
    use pyo3::ffi::c_str;
    use pyo3::indoc::indoc;
    use pyo3::prelude::*;
    use timed_test::async_timed_test;

    use crate::SerializablePyErr;

    /// A Python exception raised three calls deep is rendered as a full
    /// traceback listing every frame, outermost first.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_serializable_py_err() {
        Python::initialize();
        // Unwrap the closure's result so that a failure while building the
        // helper module fails the test instead of being silently dropped.
        Python::attach(|py| {
            let module = PyModule::from_code(
                py,
                c_str!(indoc! {r#"
                    def func1():
                        raise Exception("test")

                    def func2():
                        func1()

                    def func3():
                        func2()
                "#}),
                c_str!("test_helpers.py"),
                c_str!("test_helpers"),
            )?;

            let err = SerializablePyErr::from(py, &module.call_method0("func3").unwrap_err());
            assert_eq!(
                err.message.as_str(),
                indoc! {r#"
                    Traceback (most recent call last):
                      File "test_helpers.py", line 8, in func3
                      File "test_helpers.py", line 5, in func2
                      File "test_helpers.py", line 2, in func1
                    Exception: test
                "#}
            );

            PyResult::Ok(())
        })
        .expect("setting up the Python test module should succeed");
    }
}