Skip to main content

hyperactor_telemetry/
task.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#[inline]
10pub fn current_task_id() -> u64 {
11    tokio::task::try_id().map_or(0, |x| x.to_string().parse::<u64>().unwrap_or(0))
12}
13
/// Spawns a future as a tokio task wrapped in a tracing span.
///
/// Accepted forms:
/// - `spawn!(fut)`: the span name is `concat!(file!(), ":", line!())` and the
///   task is spawned on `tokio::runtime::Handle::current()`.
/// - `spawn!(name, fut)`: uses `name` and `tokio::runtime::Handle::current()`.
/// - `spawn!(name, rt, fut)`: uses `name` and spawns through `rt.spawn(..)`.
///
/// Evaluates to whatever `rt.spawn(..)` returns.
#[macro_export]
macro_rules! spawn {
    // Unnamed form: derive the name from the call site and forward.
    ($f:expr) => {{
        $crate::spawn!(
            concat!(file!(), ":", line!()),
            tokio::runtime::Handle::current(),
            $f
        )
    }};
    // Named form on the current runtime handle.
    ($name:expr, $f:expr) => {{ $crate::spawn!($name, tokio::runtime::Handle::current(), $f) }};
    // Full form: everything else forwards here.
    // NOTE(review): this arm mixes `tracing::` and `$crate::tracing::` paths; confirm
    // whether callers in other crates are expected to depend on `tracing` directly.
    ($name:expr, $rt:expr, $f:expr) => {{
        // Captured at the call site, before the new task exists: the id of the span
        // that is current here and the id of the spawning task.
        let current = tracing::span::Span::current().id();
        let parent_task = $crate::task::current_task_id();
        let ft = $f;
        $rt.spawn(async move {
            // The span is created inside the spawned task and records the spawner's
            // task id as the `parent_tokio_task_id` field.
            let span = tracing::debug_span!($name, parent_tokio_task_id = parent_task);
            // Link back to the span that was current at the call site.
            span.follows_from(current);
            // Emit a debug event inside the new span.
            span.in_scope(|| tracing::debug!("spawned_tokio_task"));
            // Run the user's future with the span attached.
            $crate::tracing::Instrument::instrument(ft, span).await
        })
    }};
}
36
37#[cfg(test)]
38mod tests {
39    // #[traced_test] holds a tracing::Entered guard across await points; suppress the false positive.
40    #![expect(
41        clippy::await_holding_invalid_type,
42        reason = "tracing_test::traced_test macro expansion holds tracing::span::Entered across awaits; can't be fixed in our code"
43    )]
44
45    use std::sync::Arc;
46    use std::sync::atomic::AtomicBool;
47    use std::sync::atomic::AtomicUsize;
48    use std::sync::atomic::Ordering;
49    use std::time::Duration;
50
51    use tokio::time::timeout;
52    use tracing_test::traced_test;
53
54    use super::*;
55
56    /// Like the `logs_assert` injected by `#[traced_test]`, but without scope
57    /// filtering. Use when asserting on events emitted outside the test's span
58    /// (e.g. from spawned tasks or panic hooks).
59    fn logs_assert_unscoped(f: impl Fn(&[&str]) -> Result<(), String>) {
60        let buf = tracing_test::internal::global_buf().lock().unwrap();
61        let logs_str = std::str::from_utf8(&buf).expect("Logs contain invalid UTF8");
62        let lines: Vec<&str> = logs_str.lines().collect();
63        match f(&lines) {
64            Ok(()) => {}
65            Err(msg) => panic!("{}", msg),
66        }
67    }
68
69    #[traced_test]
70    #[tokio::test]
71    async fn test_current_task_id_returns_valid_id() {
72        let handle = spawn!("test", async move {
73            let task_id = current_task_id();
74            // Task ID should be non-zero when called from within a tokio task
75            assert!(task_id > 0, "Task ID should be greater than 0");
76        });
77        handle.await.unwrap();
78    }
79
80    #[traced_test]
81    #[tokio::test]
82    async fn test_current_task_id_different_tasks() {
83        let task1_id = Arc::new(std::sync::Mutex::new(0u64));
84        let task2_id = Arc::new(std::sync::Mutex::new(0u64));
85
86        let task1_id_clone = task1_id.clone();
87        let task2_id_clone = task2_id.clone();
88
89        let handle1 = crate::spawn!(async move {
90            *task1_id_clone.lock().unwrap() = current_task_id();
91        });
92
93        let handle2 = crate::spawn!(async move {
94            *task2_id_clone.lock().unwrap() = current_task_id();
95        });
96
97        handle1.await.unwrap();
98        handle2.await.unwrap();
99
100        let id1 = *task1_id.lock().unwrap();
101        let id2 = *task2_id.lock().unwrap();
102
103        assert!(id1 > 0, "Task 1 ID should be greater than 0");
104        assert!(id2 > 0, "Task 2 ID should be greater than 0");
105        assert_ne!(id1, id2, "Different tasks should have different IDs");
106    }
107
108    #[traced_test]
109    #[tokio::test]
110    async fn test_spawn_macro_basic_functionality() {
111        let completed = Arc::new(AtomicBool::new(false));
112        let completed_clone = completed.clone();
113
114        let handle = spawn!("test_task", async move {
115            completed_clone.store(true, Ordering::SeqCst);
116            42
117        });
118
119        let result = handle.await.unwrap();
120        assert_eq!(result, 42);
121        assert!(completed.load(Ordering::SeqCst));
122    }
123
124    #[traced_test]
125    #[tokio::test]
126    async fn test_spawn_macro_with_runtime_handle() {
127        let rt = tokio::runtime::Handle::current();
128        let completed = Arc::new(AtomicBool::new(false));
129        let completed_clone = completed.clone();
130
131        let handle = spawn!("test_task_with_rt", rt, async move {
132            completed_clone.store(true, Ordering::SeqCst);
133            "success"
134        });
135
136        let result = handle.await.unwrap();
137        assert_eq!(result, "success");
138        assert!(completed.load(Ordering::SeqCst));
139    }
140
141    #[traced_test]
142    #[tokio::test]
143    async fn test_spawn_macro_with_async_operation() {
144        let handle = spawn!("async_operation", async {
145            tokio::time::sleep(Duration::from_millis(10)).await;
146            "async_result"
147        });
148
149        let result = timeout(Duration::from_secs(1), handle)
150            .await
151            .expect("Task should complete within timeout")
152            .expect("Task should not panic");
153
154        assert_eq!(result, "async_result");
155    }
156
157    #[traced_test]
158    #[tokio::test]
159    async fn test_spawn_macro_error_handling() {
160        let handle = spawn!("error_task", async {
161            panic!("intentional panic");
162        });
163
164        let result = handle.await;
165        assert!(result.is_err(), "Task should panic and return an error");
166    }
167
168    #[traced_test]
169    #[tokio::test]
170    async fn test_spawn_macro_multiple_tasks() {
171        let num_tasks = 5;
172        let completed_count = Arc::new(AtomicUsize::new(0));
173
174        let mut handles = Vec::new();
175        for i in 0..num_tasks {
176            let count_clone = completed_count.clone();
177            let handle = spawn!("parallel_task", async move {
178                count_clone.fetch_add(1, Ordering::SeqCst);
179                i
180            });
181            handles.push(handle);
182        }
183
184        // Wait for all tasks to complete
185        let mut results = Vec::new();
186        for handle in handles {
187            let result = handle.await.expect("Task should complete successfully");
188            results.push(result);
189        }
190
191        assert_eq!(results.len(), num_tasks);
192        assert_eq!(completed_count.load(Ordering::SeqCst), num_tasks);
193
194        // Results should contain values 0 through num_tasks-1
195        let mut sorted_results = results;
196        sorted_results.sort();
197        let expected: Vec<usize> = (0..num_tasks).collect();
198        assert_eq!(sorted_results, expected);
199    }
200
    // Asserts that at least one captured log line satisfies the predicate `$expr`.
    // The check runs through `logs_assert_unscoped`, which panics with the error
    // message when no line matches.
    macro_rules! logs_match {
        // One-argument form: the failure message is "<predicate source> not in logs".
        ($expr:expr) => {
            logs_match!($expr, format!("{} not in logs", stringify!($expr)));
        };
        // Two-argument form: `$msg` is converted with `.into()` into the error value.
        ($expr:expr, $msg:expr) => {
            logs_assert_unscoped(|lines| {
                if lines.iter().any($expr) {
                    Ok(())
                } else {
                    Err($msg.into())
                }
            })
        };
    }
215
216    #[traced_test]
217    #[tokio::test]
218    async fn test_spawn_macro_creates_proper_span() {
219        let completed = Arc::new(AtomicBool::new(false));
220        let completed_clone = completed.clone();
221
222        let parent_span = tracing::debug_span!("parent_span");
223        let _guard = parent_span.enter();
224
225        let handle = spawn!("child_task", async move {
226            tracing::debug!(task_data = "test_value", "task_execution");
227            completed_clone.store(true, Ordering::SeqCst);
228            "completed"
229        });
230
231        let result = handle.await.unwrap();
232        assert_eq!(result, "completed");
233        assert!(completed.load(Ordering::SeqCst));
234
235        // Check that spawn event was logged
236        logs_match!(
237            |line| line.contains("spawned_tokio_task"),
238            "task logging never occured"
239        );
240
241        // Check that task execution event was logged
242        logs_match!(|line| line.contains("task_execution"));
243    }
244
245    #[traced_test]
246    #[tokio::test]
247    async fn test_spawn_macro_preserves_parent_context() {
248        let parent_span = tracing::debug_span!("parent", operation = "test_context");
249        let _guard = parent_span.enter();
250
251        let handle = spawn!("context_child", async {
252            tracing::debug!(child_data = "value", "child_operation");
253            "context_preserved"
254        });
255
256        let result = handle.await.unwrap();
257        assert_eq!(result, "context_preserved");
258
259        // Verify the spawn event was logged
260        logs_match!(|line| line.contains("spawned_tokio_task"));
261
262        // Verify child operation was logged
263        logs_match!(|line| line.contains("child_operation"));
264    }
265
266    #[traced_test]
267    #[tokio::test]
268    async fn test_spawn_macro_with_instrumentation() {
269        let handle = spawn!("instrumented_task", async {
270            tracing::info!("inside_spawned_task");
271            42
272        });
273
274        let result = handle.await.unwrap();
275        assert_eq!(result, 42);
276
277        // Check for spawned task event
278        logs_match!(|line| line.contains("spawned_tokio_task"));
279
280        // Check for task execution event
281        logs_match!(|line| line.contains("inside_spawned_task"));
282    }
283
284    #[traced_test]
285    #[tokio::test]
286    async fn test_spawn_macro_span_hierarchy() {
287        let outer_span = tracing::info_span!("outer_operation", test_id = 123);
288        let _outer_guard = outer_span.enter();
289
290        let handle = spawn!("nested_task", async {
291            let inner_span = tracing::debug_span!("inner_operation", step = "processing");
292            let _inner_guard = inner_span.enter();
293
294            tracing::warn!(data_size = 1024, "processing_data");
295            "nested_complete"
296        });
297
298        let result = handle.await.unwrap();
299        assert_eq!(result, "nested_complete");
300
301        // Verify spawn debug event was captured
302        logs_match!(|line| line.contains("spawned_tokio_task"));
303
304        // Verify the warning event was captured
305        logs_match!(|line| line.contains("processing_data"));
306
307        // Verify structured data is present
308        logs_match!(|line| line.contains("data_size"));
309    }
310
311    #[traced_test]
312    #[tokio::test]
313    async fn test_spawn_macro_error_tracing() {
314        let handle = spawn!("error_prone_task", async {
315            tracing::error!(reason = "intentional", "task_about_to_fail");
316            panic!("deliberate failure");
317        });
318
319        let result = handle.await;
320        assert!(result.is_err(), "Task should fail");
321
322        // Verify spawn event was logged
323        logs_match!(|line| line.contains("spawned_tokio_task"));
324
325        // Verify error event was logged before panic
326        logs_match!(|line| line.contains("task_about_to_fail"));
327    }
328
329    #[traced_test]
330    #[tokio::test]
331    async fn test_spawn_macro_concurrent_tracing() {
332        let barrier = Arc::new(tokio::sync::Barrier::new(3));
333
334        let handles = (0..3)
335            .map(|i| {
336                let barrier = barrier.clone();
337                spawn!("concurrent_task", async move {
338                    barrier.wait().await;
339                    tracing::info!(task_num = i, "concurrent_execution");
340                    i * 10
341                })
342            })
343            .collect::<Vec<_>>();
344
345        // Await all handles manually
346        let mut results = Vec::new();
347        for handle in handles {
348            let result = handle.await.expect("Task should complete");
349            results.push(result);
350        }
351
352        results.sort();
353        assert_eq!(results, vec![0, 10, 20]);
354
355        // Verify spawn events were logged (at least 3)
356        logs_assert_unscoped(|lines| {
357            let spawn_count = lines
358                .iter()
359                .filter(|line| line.contains("spawned_tokio_task"))
360                .count();
361            match spawn_count {
362                3.. => Ok(()),
363                _ => Err("wrong count".into()),
364            }
365        });
366
367        // Verify execution events were logged
368        logs_assert_unscoped(|lines| {
369            let exec_count = lines
370                .iter()
371                .filter(|line| line.contains("concurrent_execution"))
372                .count();
373            if exec_count >= 3 {
374                Ok(())
375            } else {
376                Err(format!(
377                    "Expected at least 3 concurrent execution events, found {}",
378                    exec_count
379                ))
380            }
381        });
382    }
383
384    #[traced_test]
385    #[tokio::test]
386    async fn test_spawn_macro_with_fields() {
387        let handle = spawn!("field_task", async {
388            tracing::info!(user_id = 42, session = "abc123", "user_action");
389            "field_test_complete"
390        });
391
392        let result = handle.await.unwrap();
393        assert_eq!(result, "field_test_complete");
394
395        // Verify spawn event was logged
396        logs_match!(|line| line.contains("spawned_tokio_task"));
397
398        // Verify user action event with fields was logged
399        logs_match!(|line| line.contains("user_action")
400            && line.contains("user_id")
401            && line.contains("session"));
402    }
403
404    #[traced_test]
405    #[tokio::test]
406    async fn test_spawn_macro_nested_spans() {
407        let outer_span = tracing::info_span!("request_handler", request_id = "req-123");
408        let _outer_guard = outer_span.enter();
409
410        let handle = spawn!("database_query", async {
411            let db_span = tracing::debug_span!("db_operation", table = "users");
412            let _db_guard = db_span.enter();
413
414            tracing::debug!(query = "SELECT * FROM users", "executing_query");
415
416            let cache_span = tracing::debug_span!("cache_check", cache_key = "user:42");
417            let _cache_guard = cache_span.enter();
418
419            tracing::debug!("cache_miss");
420
421            "query_complete"
422        });
423
424        let result = handle.await.unwrap();
425        assert_eq!(result, "query_complete");
426        // Verify spawn event
427        logs_match!(|line| line.contains("spawned_tokio_task"));
428
429        // Verify query execution event
430        logs_match!(|line| line.contains("executing_query"));
431
432        // Verify cache miss event
433        logs_match!(|line| line.contains("cache_miss"));
434
435        // Verify structured fields are present
436        logs_match!(|line| line.contains("table"));
437    }
438
439    #[traced_test]
440    #[tokio::test]
441    async fn test_spawn_macro_performance_tracing() {
442        let handle = spawn!("performance_task", async {
443            let start = std::time::Instant::now();
444
445            // Simulate some work
446            tokio::time::sleep(Duration::from_millis(50)).await;
447
448            let duration = start.elapsed();
449            tracing::info!(duration_ms = duration.as_millis(), "task_completed");
450
451            duration.as_millis()
452        });
453
454        let duration_ms = handle.await.unwrap();
455        assert!(duration_ms >= 50, "Task should take at least 50ms");
456
457        // Verify spawn event was logged
458        logs_match!(|line| line.contains("spawned_tokio_task"));
459
460        // Verify task completion event with duration
461        logs_match!(|line| line.contains("task_completed") && line.contains("duration_ms"));
462    }
463
464    #[traced_test]
465    #[tokio::test]
466    async fn test_spawn_macro_error_with_context() {
467        let outer_span = tracing::error_span!("error_context", operation = "critical_task");
468        let _guard = outer_span.enter();
469
470        let handle = spawn!("failing_task", async {
471            tracing::warn!(retry_count = 1, "attempting_risky_operation");
472            tracing::error!(
473                error_code = "E001",
474                message = "Operation failed",
475                "critical_error"
476            );
477            panic!("Critical failure occurred");
478        });
479
480        let result = handle.await;
481        assert!(result.is_err(), "Task should fail");
482
483        // Verify spawn event
484        logs_match!(|line| line.contains("spawned_tokio_task"));
485
486        logs_match!(|line| line.contains("attempting_risky_operation"));
487
488        // Verify critical error event
489        logs_assert_unscoped(|lines| {
490            if lines
491                .iter()
492                .any(|line| line.contains("critical_error") && line.contains("error_code"))
493            {
494                Ok(())
495            } else {
496                Err("Critical error event with error_code not found in logs".into())
497            }
498        });
499    }
500}