hyperactor/
signal_handler.rs1#![cfg(unix)]
10
11use std::collections::HashMap;
12use std::fmt;
13use std::future::Future;
14use std::io;
15use std::mem::MaybeUninit;
16use std::pin::Pin;
17use std::ptr;
18use std::sync::Arc;
19use std::sync::Mutex;
20use std::sync::OnceLock;
21
22use nix::libc;
23use nix::sys::signal;
24use tokio_stream::StreamExt;
25
/// How the current process is set up to react to a particular signal.
///
/// Produced by [`query_signal_disposition`], which reads the action
/// currently installed for a signal.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SignalDisposition {
    /// The installed handler is `SIG_IGN`.
    Ignored,
    /// The installed handler is `SIG_DFL`.
    Default,
    /// Any other handler is installed, or the action carries `SA_SIGINFO`.
    Custom,
}
41
42impl fmt::Display for SignalDisposition {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 match self {
45 SignalDisposition::Ignored => write!(f, "ignored"),
46 SignalDisposition::Default => write!(f, "default"),
47 SignalDisposition::Custom => write!(f, "custom handler"),
48 }
49 }
50}
51
52pub fn query_signal_disposition(signum: libc::c_int) -> io::Result<SignalDisposition> {
67 unsafe {
79 let mut old = MaybeUninit::<libc::sigaction>::uninit();
81 if libc::sigaction(signum, ptr::null(), old.as_mut_ptr()) != 0 {
82 return Err(io::Error::last_os_error());
83 }
84 let old = old.assume_init();
85
86 if (old.sa_flags & libc::SA_SIGINFO) != 0 {
89 return Ok(SignalDisposition::Custom);
90 }
91
92 let handler = old.sa_sigaction;
96 let ignore = libc::SIG_IGN;
97 let default = libc::SIG_DFL;
98
99 match handler {
100 h if h == ignore => Ok(SignalDisposition::Ignored),
101 h if h == default => Ok(SignalDisposition::Default),
102 _ => Ok(SignalDisposition::Custom),
103 }
104 }
105}
106
107pub fn sigpipe_disposition() -> io::Result<SignalDisposition> {
116 query_signal_disposition(libc::SIGPIPE)
117}
118
119type AsyncCleanupCallback = Pin<Box<dyn Future<Output = ()> + Send>>;
120
121pub(crate) struct GlobalSignalManager {
123 cleanup_callbacks: Arc<Mutex<HashMap<u64, AsyncCleanupCallback>>>,
124 next_id: Arc<Mutex<u64>>,
125 _listener: tokio::task::JoinHandle<()>,
126}
127
128impl GlobalSignalManager {
129 fn new() -> Self {
130 let listener = tokio::spawn(async move {
131 if let Ok(mut signals) =
132 signal_hook_tokio::Signals::new([signal::SIGINT as i32, signal::SIGTERM as i32])
133 && let Some(signal) = signals.next().await
134 {
135 crate::stdio_redirect::handle_broken_pipes();
139
140 tracing::info!("received signal: {}", signal);
141
142 get_signal_manager().execute_all_cleanups().await;
143
144 match signal::Signal::try_from(signal) {
145 Ok(sig) => {
146 if let Err(err) =
147 unsafe { signal::signal(sig, signal::SigHandler::SigDfl) }
149 {
150 tracing::error!(
151 "failed to restore default signal handler for {}: {}",
152 sig,
153 err
154 );
155 }
156
157 if let Err(err) = signal::raise(sig) {
159 tracing::error!("failed to re-raise signal {}: {}", sig, err);
160 }
161 }
162 Err(err) => {
163 tracing::error!("failed to convert signal {}: {}", signal, err);
164 }
165 }
166 }
167 });
168 Self {
169 cleanup_callbacks: Arc::new(Mutex::new(HashMap::new())),
170 next_id: Arc::new(Mutex::new(0)),
171 _listener: listener,
172 }
173 }
174
175 fn register_cleanup(&self, callback: AsyncCleanupCallback) -> u64 {
177 let mut next_id = self.next_id.lock().unwrap_or_else(|e| e.into_inner());
178 let id = *next_id;
179 *next_id += 1;
180 drop(next_id);
181
182 let mut callbacks = self
183 .cleanup_callbacks
184 .lock()
185 .unwrap_or_else(|e| e.into_inner());
186 callbacks.insert(id, callback);
187 tracing::info!(
188 "process {} registered signal cleanup callback with ID: {}",
189 std::process::id(),
190 id
191 );
192 id
193 }
194
195 fn unregister_cleanup(&self, id: u64) {
197 let mut callbacks = self
198 .cleanup_callbacks
199 .lock()
200 .unwrap_or_else(|e| e.into_inner());
201 if callbacks.remove(&id).is_some() {
202 tracing::info!("unregistered signal cleanup callback with ID: {}", id);
203 } else {
204 tracing::warn!(
205 "attempted to unregister non-existent cleanup callback with ID: {}",
206 id
207 );
208 }
209 }
210
211 async fn execute_all_cleanups(&self) {
213 let callbacks = {
214 let mut callbacks = self
215 .cleanup_callbacks
216 .lock()
217 .unwrap_or_else(|e| e.into_inner());
218 std::mem::take(&mut *callbacks)
219 };
220
221 let futures = callbacks.into_iter().map(|(id, future)| async move {
222 tracing::debug!("executing cleanup callback with ID: {}", id);
223 future.await;
224 });
225
226 futures::future::join_all(futures).await;
227 }
228}
229
230static SIGNAL_MANAGER: OnceLock<GlobalSignalManager> = OnceLock::new();
232
233pub(crate) fn get_signal_manager() -> &'static GlobalSignalManager {
235 SIGNAL_MANAGER.get_or_init(GlobalSignalManager::new)
236}
237
238pub struct SignalCleanupGuard {
240 id: u64,
241}
242
243impl SignalCleanupGuard {
244 fn new(id: u64) -> Self {
245 Self { id }
246 }
247
248 pub fn id(&self) -> u64 {
250 self.id
251 }
252}
253
254impl Drop for SignalCleanupGuard {
255 fn drop(&mut self) {
256 get_signal_manager().unregister_cleanup(self.id);
257 }
258}
259
260pub fn register_signal_cleanup(callback: AsyncCleanupCallback) -> u64 {
263 get_signal_manager().register_cleanup(callback)
264}
265
266pub fn register_signal_cleanup_scoped(callback: AsyncCleanupCallback) -> SignalCleanupGuard {
269 let id = get_signal_manager().register_cleanup(callback);
270 SignalCleanupGuard::new(id)
271}
272
273pub fn unregister_signal_cleanup(id: u64) {
275 get_signal_manager().unregister_cleanup(id);
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// The test process is expected to start with SIGPIPE set to "ignore".
    // NOTE(review): presumably relies on the Rust runtime's startup handling
    // of SIGPIPE; confirm for the build configuration in use.
    #[test]
    fn sigpipe_is_ignored_by_default() {
        let observed = sigpipe_disposition().expect("query failed");
        let expected = SignalDisposition::Ignored;
        assert_eq!(
            observed, expected,
            "expected SIGPIPE to be ignored by default"
        );
    }
}
291}