hyperactor/
signal_handler.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#![cfg(unix)]
10
11use std::collections::HashMap;
12use std::fmt;
13use std::future::Future;
14use std::io;
15use std::mem::MaybeUninit;
16use std::pin::Pin;
17use std::ptr;
18use std::sync::Arc;
19use std::sync::Mutex;
20use std::sync::OnceLock;
21
22use nix::libc;
23use nix::sys::signal;
24use tokio_stream::StreamExt;
25
/// How the process currently handles a given signal.
///
/// The value is derived from the kernel's `sigaction` record for the
/// signal and collapsed into one of three categories: ignored,
/// default action, or custom handler.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SignalDisposition {
    /// The signal is explicitly ignored (`SIG_IGN`).
    Ignored,
    /// The signal triggers its default action (`SIG_DFL`).
    Default,
    /// A user-supplied handler is installed, through either
    /// `sa_handler` or `sa_sigaction`.
    Custom,
}
41
42impl fmt::Display for SignalDisposition {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        match self {
45            SignalDisposition::Ignored => write!(f, "ignored"),
46            SignalDisposition::Default => write!(f, "default"),
47            SignalDisposition::Custom => write!(f, "custom handler"),
48        }
49    }
50}
51
52/// Query the current disposition of a signal (`signum`).
53///
54/// This inspects the kernel's `sigaction` state for the given signal
55/// without changing it (by passing `act = NULL`).
56///
57/// Returns:
58/// - [`SignalDisposition::Ignored`] if the handler is `SIG_IGN`
59/// - [`SignalDisposition::Default`] if the handler is `SIG_DFL`
60/// - [`SignalDisposition::Custom`] if a user-installed handler is
61///   present
62///
63/// # Errors
64/// Returns an `io::Error` if the underlying `sigaction` call fails,
65/// for example if `signum` is invalid.
66pub fn query_signal_disposition(signum: libc::c_int) -> io::Result<SignalDisposition> {
67    // SAFETY:
68    // - We call `libc::sigaction` with `act = NULL` to query state
69    //   only.
70    // - `old` is a properly allocated `MaybeUninit<sigaction>`, large
71    //    enough to hold the kernel response.
72    // - `sigaction` will write to `old` before we read it.
73    // - Interpreting the union field (`sa_sigaction`) as a function
74    //   pointer is safe here because we only compare it against the
75    //   constants `SIG_IGN` and `SIG_DFL`.
76    // - No undefined behavior results because we never call the
77    //   pointer, we only compare its value.
78    unsafe {
79        // Query-only: act = NULL, oldact = &old.
80        let mut old = MaybeUninit::<libc::sigaction>::uninit();
81        if libc::sigaction(signum, ptr::null(), old.as_mut_ptr()) != 0 {
82            return Err(io::Error::last_os_error());
83        }
84        let old = old.assume_init();
85
86        // If SA_SIGINFO is set, the union stores a 3-arg handler =>
87        // custom handler.
88        if (old.sa_flags & libc::SA_SIGINFO) != 0 {
89            return Ok(SignalDisposition::Custom);
90        }
91
92        // Otherwise the union stores the 1-arg handler. `libc`
93        // exposes it as `sa_sigaction` in Rust. Compare the
94        // function-pointer value against `SIG_IGN`/`SIG_DFL`.
95        let handler = old.sa_sigaction;
96        let ignore = libc::SIG_IGN;
97        let default = libc::SIG_DFL;
98
99        match handler {
100            h if h == ignore => Ok(SignalDisposition::Ignored),
101            h if h == default => Ok(SignalDisposition::Default),
102            _ => Ok(SignalDisposition::Custom),
103        }
104    }
105}
106
107/// Returns the current [`SignalDisposition`] of `SIGPIPE`.
108///
109/// This is a convenience wrapper around [`query_signal_disposition`]
110/// that checks specifically for the `SIGPIPE` signal. By default,
111/// Rust's runtime startup code installs `SIG_IGN` for `SIGPIPE` (see
112/// <https://github.com/rust-lang/rust/issues/62569>), but this
113/// function lets you confirm whether it is currently ignored, set to
114/// the default action, or handled by a custom handler.
115pub fn sigpipe_disposition() -> io::Result<SignalDisposition> {
116    query_signal_disposition(libc::SIGPIPE)
117}
118
/// A boxed, pinned, `Send` future that resolves to `()`.
///
/// Despite the "callback" name this is an already-constructed future,
/// not a closure: it is awaited once when cleanup runs.
type AsyncCleanupCallback = Pin<Box<dyn Future<Output = ()> + Send + 'static>>;
120
121/// Global signal manager that coordinates cleanup across all signal handlers
122pub(crate) struct GlobalSignalManager {
123    cleanup_callbacks: Arc<Mutex<HashMap<u64, AsyncCleanupCallback>>>,
124    next_id: Arc<Mutex<u64>>,
125    _listener: tokio::task::JoinHandle<()>,
126}
127
128impl GlobalSignalManager {
129    fn new() -> Self {
130        let listener = tokio::spawn(async move {
131            if let Ok(mut signals) =
132                signal_hook_tokio::Signals::new([signal::SIGINT as i32, signal::SIGTERM as i32])
133                && let Some(signal) = signals.next().await
134            {
135                // If parent died, stdout/stderr are broken pipes
136                // that cause uninterruptible sleep on write.
137                // Detect and redirect to file to prevent hanging.
138                crate::stdio_redirect::handle_broken_pipes();
139
140                tracing::info!("received signal: {}", signal);
141
142                get_signal_manager().execute_all_cleanups().await;
143
144                match signal::Signal::try_from(signal) {
145                    Ok(sig) => {
146                        if let Err(err) =
147                            // SAFETY: We're setting the handle to SigDfl (default system behaviour)
148                            unsafe { signal::signal(sig, signal::SigHandler::SigDfl) }
149                        {
150                            tracing::error!(
151                                "failed to restore default signal handler for {}: {}",
152                                sig,
153                                err
154                            );
155                        }
156
157                        // Re-raise the signal to trigger default behavior (process termination)
158                        if let Err(err) = signal::raise(sig) {
159                            tracing::error!("failed to re-raise signal {}: {}", sig, err);
160                        }
161                    }
162                    Err(err) => {
163                        tracing::error!("failed to convert signal {}: {}", signal, err);
164                    }
165                }
166            }
167        });
168        Self {
169            cleanup_callbacks: Arc::new(Mutex::new(HashMap::new())),
170            next_id: Arc::new(Mutex::new(0)),
171            _listener: listener,
172        }
173    }
174
175    /// Register a cleanup callback and return a unique ID for later unregistration
176    fn register_cleanup(&self, callback: AsyncCleanupCallback) -> u64 {
177        let mut next_id = self.next_id.lock().unwrap_or_else(|e| e.into_inner());
178        let id = *next_id;
179        *next_id += 1;
180        drop(next_id);
181
182        let mut callbacks = self
183            .cleanup_callbacks
184            .lock()
185            .unwrap_or_else(|e| e.into_inner());
186        callbacks.insert(id, callback);
187        tracing::info!(
188            "process {} registered signal cleanup callback with ID: {}",
189            std::process::id(),
190            id
191        );
192        id
193    }
194
195    /// Unregister a cleanup callback by ID
196    fn unregister_cleanup(&self, id: u64) {
197        let mut callbacks = self
198            .cleanup_callbacks
199            .lock()
200            .unwrap_or_else(|e| e.into_inner());
201        if callbacks.remove(&id).is_some() {
202            tracing::info!("unregistered signal cleanup callback with ID: {}", id);
203        } else {
204            tracing::warn!(
205                "attempted to unregister non-existent cleanup callback with ID: {}",
206                id
207            );
208        }
209    }
210
211    /// Execute all registered cleanup callbacks asynchronously
212    async fn execute_all_cleanups(&self) {
213        let callbacks = {
214            let mut callbacks = self
215                .cleanup_callbacks
216                .lock()
217                .unwrap_or_else(|e| e.into_inner());
218            std::mem::take(&mut *callbacks)
219        };
220
221        let futures = callbacks.into_iter().map(|(id, future)| async move {
222            tracing::debug!("executing cleanup callback with ID: {}", id);
223            future.await;
224        });
225
226        futures::future::join_all(futures).await;
227    }
228}
229
230/// Global instance of the signal manager
231static SIGNAL_MANAGER: OnceLock<GlobalSignalManager> = OnceLock::new();
232
233/// Get the global signal manager instance
234pub(crate) fn get_signal_manager() -> &'static GlobalSignalManager {
235    SIGNAL_MANAGER.get_or_init(GlobalSignalManager::new)
236}
237
/// RAII guard that automatically unregisters a signal cleanup callback
/// when dropped.
///
/// Keep the guard alive for as long as the callback should stay
/// registered. A guard that is discarded right away unregisters the
/// callback immediately, hence the `#[must_use]`.
#[must_use = "dropping the guard immediately unregisters the cleanup callback"]
pub struct SignalCleanupGuard {
    /// ID of the callback this guard unregisters on drop.
    id: u64,
}
242
243impl SignalCleanupGuard {
244    fn new(id: u64) -> Self {
245        Self { id }
246    }
247
248    /// Get the ID of the registered cleanup callback
249    pub fn id(&self) -> u64 {
250        self.id
251    }
252}
253
254impl Drop for SignalCleanupGuard {
255    fn drop(&mut self) {
256        get_signal_manager().unregister_cleanup(self.id);
257    }
258}
259
260/// Register a cleanup callback to be executed on SIGINT/SIGTERM
261/// Returns a unique ID that can be used to unregister the callback
262pub fn register_signal_cleanup(callback: AsyncCleanupCallback) -> u64 {
263    get_signal_manager().register_cleanup(callback)
264}
265
266/// Register a scoped cleanup callback to be executed on SIGINT/SIGTERM
267/// Returns a guard that automatically unregisters the callback when dropped
268pub fn register_signal_cleanup_scoped(callback: AsyncCleanupCallback) -> SignalCleanupGuard {
269    let id = get_signal_manager().register_cleanup(callback);
270    SignalCleanupGuard::new(id)
271}
272
273/// Unregister a previously registered cleanup callback
274pub fn unregister_signal_cleanup(id: u64) {
275    get_signal_manager().unregister_cleanup(id);
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// A test binary that has not touched `SIGPIPE` should report it
    /// as ignored.
    #[test]
    fn sigpipe_is_ignored_by_default() {
        let observed = sigpipe_disposition().expect("query failed");
        let expected = SignalDisposition::Ignored;
        assert_eq!(
            observed, expected,
            "expected SIGPIPE to be ignored by default"
        );
    }
}