hyperactor/
signal_handler.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::fmt;
11use std::future::Future;
12use std::io;
13use std::mem::MaybeUninit;
14use std::pin::Pin;
15use std::ptr;
16use std::sync::Arc;
17use std::sync::Mutex;
18use std::sync::OnceLock;
19
20use nix::libc;
21use nix::sys::signal;
22use tokio_stream::StreamExt;
23
/// How the process currently reacts to a given signal.
///
/// The value is derived from the kernel's `sigaction` record for the
/// signal and collapsed into one of three categories: ignored,
/// default action, or custom handler.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SignalDisposition {
    /// The signal is discarded: its handler is `SIG_IGN`.
    Ignored,
    /// The signal triggers its default action: its handler is
    /// `SIG_DFL`.
    Default,
    /// A user-supplied handler is installed, through either
    /// `sa_handler` or `sa_sigaction`.
    Custom,
}
39
40impl fmt::Display for SignalDisposition {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        match self {
43            SignalDisposition::Ignored => write!(f, "ignored"),
44            SignalDisposition::Default => write!(f, "default"),
45            SignalDisposition::Custom => write!(f, "custom handler"),
46        }
47    }
48}
49
50/// Query the current disposition of a signal (`signum`).
51///
52/// This inspects the kernel's `sigaction` state for the given signal
53/// without changing it (by passing `act = NULL`).
54///
55/// Returns:
56/// - [`SignalDisposition::Ignored`] if the handler is `SIG_IGN`
57/// - [`SignalDisposition::Default`] if the handler is `SIG_DFL`
58/// - [`SignalDisposition::Custom`] if a user-installed handler is
59///   present
60///
61/// # Errors
62/// Returns an `io::Error` if the underlying `sigaction` call fails,
63/// for example if `signum` is invalid.
64pub fn query_signal_disposition(signum: libc::c_int) -> io::Result<SignalDisposition> {
65    // SAFETY:
66    // - We call `libc::sigaction` with `act = NULL` to query state
67    //   only.
68    // - `old` is a properly allocated `MaybeUninit<sigaction>`, large
69    //    enough to hold the kernel response.
70    // - `sigaction` will write to `old` before we read it.
71    // - Interpreting the union field (`sa_sigaction`) as a function
72    //   pointer is safe here because we only compare it against the
73    //   constants `SIG_IGN` and `SIG_DFL`.
74    // - No undefined behavior results because we never call the
75    //   pointer, we only compare its value.
76    unsafe {
77        // Query-only: act = NULL, oldact = &old.
78        let mut old = MaybeUninit::<libc::sigaction>::uninit();
79        if libc::sigaction(signum, ptr::null(), old.as_mut_ptr()) != 0 {
80            return Err(io::Error::last_os_error());
81        }
82        let old = old.assume_init();
83
84        // If SA_SIGINFO is set, the union stores a 3-arg handler =>
85        // custom handler.
86        if (old.sa_flags & libc::SA_SIGINFO) != 0 {
87            return Ok(SignalDisposition::Custom);
88        }
89
90        // Otherwise the union stores the 1-arg handler. `libc`
91        // exposes it as `sa_sigaction` in Rust. Compare the
92        // function-pointer value against `SIG_IGN`/`SIG_DFL`.
93        let handler = old.sa_sigaction;
94        let ignore = libc::SIG_IGN;
95        let default = libc::SIG_DFL;
96
97        match handler {
98            h if h == ignore => Ok(SignalDisposition::Ignored),
99            h if h == default => Ok(SignalDisposition::Default),
100            _ => Ok(SignalDisposition::Custom),
101        }
102    }
103}
104
105/// Returns the current [`SignalDisposition`] of `SIGPIPE`.
106///
107/// This is a convenience wrapper around [`query_signal_disposition`]
108/// that checks specifically for the `SIGPIPE` signal. By default,
109/// Rust's runtime startup code installs `SIG_IGN` for `SIGPIPE` (see
110/// <https://github.com/rust-lang/rust/issues/62569>), but this
111/// function lets you confirm whether it is currently ignored, set to
112/// the default action, or handled by a custom handler.
113pub fn sigpipe_disposition() -> io::Result<SignalDisposition> {
114    query_signal_disposition(libc::SIGPIPE)
115}
116
/// A boxed, pinned, sendable future that performs one cleanup action
/// and produces no value. The `'static` bound is the one a boxed
/// trait object carries by default; it is spelled out for clarity.
type AsyncCleanupCallback = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
118
119/// Global signal manager that coordinates cleanup across all signal handlers
120pub(crate) struct GlobalSignalManager {
121    cleanup_callbacks: Arc<Mutex<HashMap<u64, AsyncCleanupCallback>>>,
122    next_id: Arc<Mutex<u64>>,
123    _listener: tokio::task::JoinHandle<()>,
124}
125
126impl GlobalSignalManager {
127    fn new() -> Self {
128        let listener = tokio::spawn(async move {
129            if let Ok(mut signals) =
130                signal_hook_tokio::Signals::new([signal::SIGINT as i32, signal::SIGTERM as i32])
131            {
132                if let Some(signal) = signals.next().await {
133                    tracing::info!("received signal: {}", signal);
134
135                    get_signal_manager().execute_all_cleanups().await;
136
137                    match signal::Signal::try_from(signal) {
138                        Ok(sig) => {
139                            if let Err(err) =
140                                // SAFETY: We're setting the handle to SigDfl (default system behaviour)
141                                unsafe { signal::signal(sig, signal::SigHandler::SigDfl) }
142                            {
143                                tracing::error!(
144                                    "failed to restore default signal handler for {}: {}",
145                                    sig,
146                                    err
147                                );
148                            }
149
150                            // Re-raise the signal to trigger default behavior (process termination)
151                            if let Err(err) = signal::raise(sig) {
152                                tracing::error!("failed to re-raise signal {}: {}", sig, err);
153                            }
154                        }
155                        Err(err) => {
156                            tracing::error!("failed to convert signal {}: {}", signal, err);
157                        }
158                    }
159                }
160            }
161        });
162        Self {
163            cleanup_callbacks: Arc::new(Mutex::new(HashMap::new())),
164            next_id: Arc::new(Mutex::new(0)),
165            _listener: listener,
166        }
167    }
168
169    /// Register a cleanup callback and return a unique ID for later unregistration
170    fn register_cleanup(&self, callback: AsyncCleanupCallback) -> u64 {
171        let mut next_id = self.next_id.lock().unwrap_or_else(|e| e.into_inner());
172        let id = *next_id;
173        *next_id += 1;
174        drop(next_id);
175
176        let mut callbacks = self
177            .cleanup_callbacks
178            .lock()
179            .unwrap_or_else(|e| e.into_inner());
180        callbacks.insert(id, callback);
181        tracing::info!("registered signal cleanup callback with ID: {}", id);
182        id
183    }
184
185    /// Unregister a cleanup callback by ID
186    fn unregister_cleanup(&self, id: u64) {
187        let mut callbacks = self
188            .cleanup_callbacks
189            .lock()
190            .unwrap_or_else(|e| e.into_inner());
191        if callbacks.remove(&id).is_some() {
192            tracing::info!("unregistered signal cleanup callback with ID: {}", id);
193        } else {
194            tracing::warn!(
195                "attempted to unregister non-existent cleanup callback with ID: {}",
196                id
197            );
198        }
199    }
200
201    /// Execute all registered cleanup callbacks asynchronously
202    async fn execute_all_cleanups(&self) {
203        let callbacks = {
204            let mut callbacks = self
205                .cleanup_callbacks
206                .lock()
207                .unwrap_or_else(|e| e.into_inner());
208            std::mem::take(&mut *callbacks)
209        };
210
211        let futures = callbacks.into_iter().map(|(id, future)| async move {
212            tracing::debug!("executing cleanup callback with ID: {}", id);
213            future.await;
214        });
215
216        futures::future::join_all(futures).await;
217    }
218}
219
220/// Global instance of the signal manager
221static SIGNAL_MANAGER: OnceLock<GlobalSignalManager> = OnceLock::new();
222
223/// Get the global signal manager instance
224pub(crate) fn get_signal_manager() -> &'static GlobalSignalManager {
225    SIGNAL_MANAGER.get_or_init(GlobalSignalManager::new)
226}
227
/// RAII guard that automatically unregisters a signal cleanup
/// callback when dropped.
///
/// The guard must be kept alive for as long as the callback should
/// stay registered. Discarding it right away unregisters the callback
/// immediately, hence the `#[must_use]` marker.
#[must_use = "dropping the guard immediately unregisters the cleanup callback"]
pub struct SignalCleanupGuard {
    /// ID of the registered cleanup callback this guard owns.
    id: u64,
}
232
233impl SignalCleanupGuard {
234    fn new(id: u64) -> Self {
235        Self { id }
236    }
237
238    /// Get the ID of the registered cleanup callback
239    pub fn id(&self) -> u64 {
240        self.id
241    }
242}
243
244impl Drop for SignalCleanupGuard {
245    fn drop(&mut self) {
246        get_signal_manager().unregister_cleanup(self.id);
247    }
248}
249
250/// Register a cleanup callback to be executed on SIGINT/SIGTERM
251/// Returns a unique ID that can be used to unregister the callback
252pub fn register_signal_cleanup(callback: AsyncCleanupCallback) -> u64 {
253    get_signal_manager().register_cleanup(callback)
254}
255
256/// Register a scoped cleanup callback to be executed on SIGINT/SIGTERM
257/// Returns a guard that automatically unregisters the callback when dropped
258pub fn register_signal_cleanup_scoped(callback: AsyncCleanupCallback) -> SignalCleanupGuard {
259    let id = get_signal_manager().register_cleanup(callback);
260    SignalCleanupGuard::new(id)
261}
262
263/// Unregister a previously registered cleanup callback
264pub fn unregister_signal_cleanup(id: u64) {
265    get_signal_manager().unregister_cleanup(id);
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// The test process is expected to start with `SIGPIPE` set to
    /// `SIG_IGN`, so the query must report `Ignored`.
    #[test]
    fn sigpipe_is_ignored_by_default() {
        let disposition = sigpipe_disposition().expect("query failed");
        assert_eq!(
            disposition,
            SignalDisposition::Ignored,
            "expected SIGPIPE to be ignored by default"
        );
    }
}