hyperactor/
use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::io;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::ptr;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::OnceLock;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;

use nix::libc;
use nix::sys::signal;
use tokio_stream::StreamExt;
23
/// How the process currently disposes of a given signal, as reported by
/// [`query_signal_disposition`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SignalDisposition {
    /// The signal is ignored (`SIG_IGN`).
    Ignored,
    /// The signal performs its default action (`SIG_DFL`).
    Default,
    /// Some other handler is installed; see [`query_signal_disposition`] for how
    /// this is classified.
    Custom,
}

impl fmt::Display for SignalDisposition {
    /// Writes a short, lowercase, human-readable description of the disposition.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            Self::Ignored => "ignored",
            Self::Default => "default",
            Self::Custom => "custom handler",
        };
        f.write_str(text)
    }
}
49
50pub fn query_signal_disposition(signum: libc::c_int) -> io::Result<SignalDisposition> {
65 unsafe {
77 let mut old = MaybeUninit::<libc::sigaction>::uninit();
79 if libc::sigaction(signum, ptr::null(), old.as_mut_ptr()) != 0 {
80 return Err(io::Error::last_os_error());
81 }
82 let old = old.assume_init();
83
84 if (old.sa_flags & libc::SA_SIGINFO) != 0 {
87 return Ok(SignalDisposition::Custom);
88 }
89
90 let handler = old.sa_sigaction;
94 let ignore = libc::SIG_IGN;
95 let default = libc::SIG_DFL;
96
97 match handler {
98 h if h == ignore => Ok(SignalDisposition::Ignored),
99 h if h == default => Ok(SignalDisposition::Default),
100 _ => Ok(SignalDisposition::Custom),
101 }
102 }
103}
104
105pub fn sigpipe_disposition() -> io::Result<SignalDisposition> {
114 query_signal_disposition(libc::SIGPIPE)
115}
116
/// A boxed, pinned, `Send` future run once as a signal-triggered cleanup action.
type AsyncCleanupCallback = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
118
119pub(crate) struct GlobalSignalManager {
121 cleanup_callbacks: Arc<Mutex<HashMap<u64, AsyncCleanupCallback>>>,
122 next_id: Arc<Mutex<u64>>,
123 _listener: tokio::task::JoinHandle<()>,
124}
125
126impl GlobalSignalManager {
127 fn new() -> Self {
128 let listener = tokio::spawn(async move {
129 if let Ok(mut signals) =
130 signal_hook_tokio::Signals::new([signal::SIGINT as i32, signal::SIGTERM as i32])
131 {
132 if let Some(signal) = signals.next().await {
133 tracing::info!("received signal: {}", signal);
134
135 get_signal_manager().execute_all_cleanups().await;
136
137 match signal::Signal::try_from(signal) {
138 Ok(sig) => {
139 if let Err(err) =
140 unsafe { signal::signal(sig, signal::SigHandler::SigDfl) }
142 {
143 tracing::error!(
144 "failed to restore default signal handler for {}: {}",
145 sig,
146 err
147 );
148 }
149
150 if let Err(err) = signal::raise(sig) {
152 tracing::error!("failed to re-raise signal {}: {}", sig, err);
153 }
154 }
155 Err(err) => {
156 tracing::error!("failed to convert signal {}: {}", signal, err);
157 }
158 }
159 }
160 }
161 });
162 Self {
163 cleanup_callbacks: Arc::new(Mutex::new(HashMap::new())),
164 next_id: Arc::new(Mutex::new(0)),
165 _listener: listener,
166 }
167 }
168
169 fn register_cleanup(&self, callback: AsyncCleanupCallback) -> u64 {
171 let mut next_id = self.next_id.lock().unwrap_or_else(|e| e.into_inner());
172 let id = *next_id;
173 *next_id += 1;
174 drop(next_id);
175
176 let mut callbacks = self
177 .cleanup_callbacks
178 .lock()
179 .unwrap_or_else(|e| e.into_inner());
180 callbacks.insert(id, callback);
181 tracing::info!("registered signal cleanup callback with ID: {}", id);
182 id
183 }
184
185 fn unregister_cleanup(&self, id: u64) {
187 let mut callbacks = self
188 .cleanup_callbacks
189 .lock()
190 .unwrap_or_else(|e| e.into_inner());
191 if callbacks.remove(&id).is_some() {
192 tracing::info!("unregistered signal cleanup callback with ID: {}", id);
193 } else {
194 tracing::warn!(
195 "attempted to unregister non-existent cleanup callback with ID: {}",
196 id
197 );
198 }
199 }
200
201 async fn execute_all_cleanups(&self) {
203 let callbacks = {
204 let mut callbacks = self
205 .cleanup_callbacks
206 .lock()
207 .unwrap_or_else(|e| e.into_inner());
208 std::mem::take(&mut *callbacks)
209 };
210
211 let futures = callbacks.into_iter().map(|(id, future)| async move {
212 tracing::debug!("executing cleanup callback with ID: {}", id);
213 future.await;
214 });
215
216 futures::future::join_all(futures).await;
217 }
218}
219
220static SIGNAL_MANAGER: OnceLock<GlobalSignalManager> = OnceLock::new();
222
223pub(crate) fn get_signal_manager() -> &'static GlobalSignalManager {
225 SIGNAL_MANAGER.get_or_init(GlobalSignalManager::new)
226}
227
228pub struct SignalCleanupGuard {
230 id: u64,
231}
232
233impl SignalCleanupGuard {
234 fn new(id: u64) -> Self {
235 Self { id }
236 }
237
238 pub fn id(&self) -> u64 {
240 self.id
241 }
242}
243
244impl Drop for SignalCleanupGuard {
245 fn drop(&mut self) {
246 get_signal_manager().unregister_cleanup(self.id);
247 }
248}
249
250pub fn register_signal_cleanup(callback: AsyncCleanupCallback) -> u64 {
253 get_signal_manager().register_cleanup(callback)
254}
255
256pub fn register_signal_cleanup_scoped(callback: AsyncCleanupCallback) -> SignalCleanupGuard {
259 let id = get_signal_manager().register_cleanup(callback);
260 SignalCleanupGuard::new(id)
261}
262
263pub fn unregister_signal_cleanup(id: u64) {
265 get_signal_manager().unregister_cleanup(id);
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// On Unix the Rust runtime sets SIGPIPE to `SIG_IGN` before `main` by default.
    /// This test pins that assumption.
    #[test]
    fn sigpipe_is_ignored_by_default() {
        let disp = sigpipe_disposition().expect("query failed");
        assert_eq!(
            disp,
            SignalDisposition::Ignored,
            "expected SIGPIPE to be ignored by default"
        );
    }

    /// An invalid signal number must surface as an error (`EINVAL` → `InvalidInput`),
    /// not as a bogus disposition.
    #[test]
    fn query_rejects_invalid_signal_number() {
        let err = query_signal_disposition(-1).expect_err("signal -1 should be rejected");
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    #[test]
    fn disposition_display_is_human_readable() {
        assert_eq!(SignalDisposition::Ignored.to_string(), "ignored");
        assert_eq!(SignalDisposition::Default.to_string(), "default");
        assert_eq!(SignalDisposition::Custom.to_string(), "custom handler");
    }
}