hyperactor_telemetry/
task.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#[inline]
10pub fn current_task_id() -> u64 {
11    tokio::task::try_id().map_or(0, |x| x.to_string().parse::<u64>().unwrap_or(0))
12}
13
/// Spawns a future on a tokio runtime handle, instrumented with a tracing span.
///
/// Accepted forms:
/// - `spawn!(fut)`: the span is named `file:line` of the call site and the
///   future is spawned on `tokio::runtime::Handle::current()`.
/// - `spawn!(name, fut)`: as above, with an explicit span name.
/// - `spawn!(name, rt, fut)`: spawns via `rt.spawn(..)`.
///
/// At the call site the macro records the id of the current tracing span and
/// the current tokio task id. Inside the spawned task it creates a debug-level
/// span `name` with the field `parent_tokio_task_id`, marks it as following
/// from the recorded span, emits a `spawned_tokio_task` debug event inside it,
/// and then awaits the future instrumented with that span.
///
/// The macro evaluates to whatever `rt.spawn(..)` returns.
#[macro_export]
macro_rules! spawn {
    ($f:expr) => {{
        $crate::spawn!(
            concat!(file!(), ":", line!()),
            tokio::runtime::Handle::current(),
            $f
        )
    }};
    ($name:expr, $f:expr) => {{ $crate::spawn!($name, tokio::runtime::Handle::current(), $f) }};
    ($name:expr, $rt:expr, $f:expr) => {{
        // Captured on the spawning side, before the new task starts running.
        // All tracing items are reached through `$crate::tracing`, matching the
        // `Instrument` path below, so callers do not need `tracing` in scope.
        let current = $crate::tracing::span::Span::current().id();
        let parent_task = $crate::task::current_task_id();
        let ft = $f;
        $rt.spawn(async move {
            let span = $crate::tracing::debug_span!($name, parent_tokio_task_id = parent_task);
            span.follows_from(current);
            span.in_scope(|| $crate::tracing::debug!("spawned_tokio_task"));
            $crate::tracing::Instrument::instrument(ft, span).await
        })
    }};
}
36
37#[cfg(test)]
38mod tests {
39    use std::sync::Arc;
40    use std::sync::atomic::AtomicBool;
41    use std::sync::atomic::AtomicUsize;
42    use std::sync::atomic::Ordering;
43    use std::time::Duration;
44
45    use tokio::time::timeout;
46    use tracing_test::traced_test;
47
48    use super::*;
49
50    /// Like the `logs_assert` injected by `#[traced_test]`, but without scope
51    /// filtering. Use when asserting on events emitted outside the test's span
52    /// (e.g. from spawned tasks or panic hooks).
53    fn logs_assert_unscoped(f: impl Fn(&[&str]) -> Result<(), String>) {
54        let buf = tracing_test::internal::global_buf().lock().unwrap();
55        let logs_str = std::str::from_utf8(&buf).expect("Logs contain invalid UTF8");
56        let lines: Vec<&str> = logs_str.lines().collect();
57        match f(&lines) {
58            Ok(()) => {}
59            Err(msg) => panic!("{}", msg),
60        }
61    }
62
63    #[traced_test]
64    #[tokio::test]
65    async fn test_current_task_id_returns_valid_id() {
66        let handle = spawn!("test", async move {
67            let task_id = current_task_id();
68            // Task ID should be non-zero when called from within a tokio task
69            assert!(task_id > 0, "Task ID should be greater than 0");
70        });
71        handle.await.unwrap();
72    }
73
74    #[traced_test]
75    #[tokio::test]
76    async fn test_current_task_id_different_tasks() {
77        let task1_id = Arc::new(std::sync::Mutex::new(0u64));
78        let task2_id = Arc::new(std::sync::Mutex::new(0u64));
79
80        let task1_id_clone = task1_id.clone();
81        let task2_id_clone = task2_id.clone();
82
83        let handle1 = crate::spawn!(async move {
84            *task1_id_clone.lock().unwrap() = current_task_id();
85        });
86
87        let handle2 = crate::spawn!(async move {
88            *task2_id_clone.lock().unwrap() = current_task_id();
89        });
90
91        handle1.await.unwrap();
92        handle2.await.unwrap();
93
94        let id1 = *task1_id.lock().unwrap();
95        let id2 = *task2_id.lock().unwrap();
96
97        assert!(id1 > 0, "Task 1 ID should be greater than 0");
98        assert!(id2 > 0, "Task 2 ID should be greater than 0");
99        assert_ne!(id1, id2, "Different tasks should have different IDs");
100    }
101
102    #[traced_test]
103    #[tokio::test]
104    async fn test_spawn_macro_basic_functionality() {
105        let completed = Arc::new(AtomicBool::new(false));
106        let completed_clone = completed.clone();
107
108        let handle = spawn!("test_task", async move {
109            completed_clone.store(true, Ordering::SeqCst);
110            42
111        });
112
113        let result = handle.await.unwrap();
114        assert_eq!(result, 42);
115        assert!(completed.load(Ordering::SeqCst));
116    }
117
118    #[traced_test]
119    #[tokio::test]
120    async fn test_spawn_macro_with_runtime_handle() {
121        let rt = tokio::runtime::Handle::current();
122        let completed = Arc::new(AtomicBool::new(false));
123        let completed_clone = completed.clone();
124
125        let handle = spawn!("test_task_with_rt", rt, async move {
126            completed_clone.store(true, Ordering::SeqCst);
127            "success"
128        });
129
130        let result = handle.await.unwrap();
131        assert_eq!(result, "success");
132        assert!(completed.load(Ordering::SeqCst));
133    }
134
135    #[traced_test]
136    #[tokio::test]
137    async fn test_spawn_macro_with_async_operation() {
138        let handle = spawn!("async_operation", async {
139            tokio::time::sleep(Duration::from_millis(10)).await;
140            "async_result"
141        });
142
143        let result = timeout(Duration::from_secs(1), handle)
144            .await
145            .expect("Task should complete within timeout")
146            .expect("Task should not panic");
147
148        assert_eq!(result, "async_result");
149    }
150
151    #[traced_test]
152    #[tokio::test]
153    async fn test_spawn_macro_error_handling() {
154        let handle = spawn!("error_task", async {
155            panic!("intentional panic");
156        });
157
158        let result = handle.await;
159        assert!(result.is_err(), "Task should panic and return an error");
160    }
161
162    #[traced_test]
163    #[tokio::test]
164    async fn test_spawn_macro_multiple_tasks() {
165        let num_tasks = 5;
166        let completed_count = Arc::new(AtomicUsize::new(0));
167
168        let mut handles = Vec::new();
169        for i in 0..num_tasks {
170            let count_clone = completed_count.clone();
171            let handle = spawn!("parallel_task", async move {
172                count_clone.fetch_add(1, Ordering::SeqCst);
173                i
174            });
175            handles.push(handle);
176        }
177
178        // Wait for all tasks to complete
179        let mut results = Vec::new();
180        for handle in handles {
181            let result = handle.await.expect("Task should complete successfully");
182            results.push(result);
183        }
184
185        assert_eq!(results.len(), num_tasks);
186        assert_eq!(completed_count.load(Ordering::SeqCst), num_tasks);
187
188        // Results should contain values 0 through num_tasks-1
189        let mut sorted_results = results;
190        sorted_results.sort();
191        let expected: Vec<usize> = (0..num_tasks).collect();
192        assert_eq!(sorted_results, expected);
193    }
194
195    macro_rules! logs_match {
196        ($expr:expr) => {
197            logs_match!($expr, format!("{} not in logs", stringify!($expr)));
198        };
199        ($expr:expr, $msg:expr) => {
200            logs_assert_unscoped(|lines| {
201                if lines.iter().any($expr) {
202                    Ok(())
203                } else {
204                    Err($msg.into())
205                }
206            })
207        };
208    }
209
210    #[traced_test]
211    #[tokio::test]
212    async fn test_spawn_macro_creates_proper_span() {
213        let completed = Arc::new(AtomicBool::new(false));
214        let completed_clone = completed.clone();
215
216        let parent_span = tracing::debug_span!("parent_span");
217        let _guard = parent_span.enter();
218
219        let handle = spawn!("child_task", async move {
220            tracing::debug!(task_data = "test_value", "task_execution");
221            completed_clone.store(true, Ordering::SeqCst);
222            "completed"
223        });
224
225        let result = handle.await.unwrap();
226        assert_eq!(result, "completed");
227        assert!(completed.load(Ordering::SeqCst));
228
229        // Check that spawn event was logged
230        logs_match!(
231            |line| line.contains("spawned_tokio_task"),
232            "task logging never occured"
233        );
234
235        // Check that task execution event was logged
236        logs_match!(|line| line.contains("task_execution"));
237    }
238
239    #[traced_test]
240    #[tokio::test]
241    async fn test_spawn_macro_preserves_parent_context() {
242        let parent_span = tracing::debug_span!("parent", operation = "test_context");
243        let _guard = parent_span.enter();
244
245        let handle = spawn!("context_child", async {
246            tracing::debug!(child_data = "value", "child_operation");
247            "context_preserved"
248        });
249
250        let result = handle.await.unwrap();
251        assert_eq!(result, "context_preserved");
252
253        // Verify the spawn event was logged
254        logs_match!(|line| line.contains("spawned_tokio_task"));
255
256        // Verify child operation was logged
257        logs_match!(|line| line.contains("child_operation"));
258    }
259
260    #[traced_test]
261    #[tokio::test]
262    async fn test_spawn_macro_with_instrumentation() {
263        let handle = spawn!("instrumented_task", async {
264            tracing::info!("inside_spawned_task");
265            42
266        });
267
268        let result = handle.await.unwrap();
269        assert_eq!(result, 42);
270
271        // Check for spawned task event
272        logs_match!(|line| line.contains("spawned_tokio_task"));
273
274        // Check for task execution event
275        logs_match!(|line| line.contains("inside_spawned_task"));
276    }
277
278    #[traced_test]
279    #[tokio::test]
280    async fn test_spawn_macro_span_hierarchy() {
281        let outer_span = tracing::info_span!("outer_operation", test_id = 123);
282        let _outer_guard = outer_span.enter();
283
284        let handle = spawn!("nested_task", async {
285            let inner_span = tracing::debug_span!("inner_operation", step = "processing");
286            let _inner_guard = inner_span.enter();
287
288            tracing::warn!(data_size = 1024, "processing_data");
289            "nested_complete"
290        });
291
292        let result = handle.await.unwrap();
293        assert_eq!(result, "nested_complete");
294
295        // Verify spawn debug event was captured
296        logs_match!(|line| line.contains("spawned_tokio_task"));
297
298        // Verify the warning event was captured
299        logs_match!(|line| line.contains("processing_data"));
300
301        // Verify structured data is present
302        logs_match!(|line| line.contains("data_size"));
303    }
304
305    #[traced_test]
306    #[tokio::test]
307    async fn test_spawn_macro_error_tracing() {
308        let handle = spawn!("error_prone_task", async {
309            tracing::error!(reason = "intentional", "task_about_to_fail");
310            panic!("deliberate failure");
311        });
312
313        let result = handle.await;
314        assert!(result.is_err(), "Task should fail");
315
316        // Verify spawn event was logged
317        logs_match!(|line| line.contains("spawned_tokio_task"));
318
319        // Verify error event was logged before panic
320        logs_match!(|line| line.contains("task_about_to_fail"));
321    }
322
323    #[traced_test]
324    #[tokio::test]
325    async fn test_spawn_macro_concurrent_tracing() {
326        let barrier = Arc::new(tokio::sync::Barrier::new(3));
327
328        let handles = (0..3)
329            .map(|i| {
330                let barrier = barrier.clone();
331                spawn!("concurrent_task", async move {
332                    barrier.wait().await;
333                    tracing::info!(task_num = i, "concurrent_execution");
334                    i * 10
335                })
336            })
337            .collect::<Vec<_>>();
338
339        // Await all handles manually
340        let mut results = Vec::new();
341        for handle in handles {
342            let result = handle.await.expect("Task should complete");
343            results.push(result);
344        }
345
346        results.sort();
347        assert_eq!(results, vec![0, 10, 20]);
348
349        // Verify spawn events were logged (at least 3)
350        logs_assert_unscoped(|lines| {
351            let spawn_count = lines
352                .iter()
353                .filter(|line| line.contains("spawned_tokio_task"))
354                .count();
355            match spawn_count {
356                3.. => Ok(()),
357                _ => Err("wrong count".into()),
358            }
359        });
360
361        // Verify execution events were logged
362        logs_assert_unscoped(|lines| {
363            let exec_count = lines
364                .iter()
365                .filter(|line| line.contains("concurrent_execution"))
366                .count();
367            if exec_count >= 3 {
368                Ok(())
369            } else {
370                Err(format!(
371                    "Expected at least 3 concurrent execution events, found {}",
372                    exec_count
373                ))
374            }
375        });
376    }
377
378    #[traced_test]
379    #[tokio::test]
380    async fn test_spawn_macro_with_fields() {
381        let handle = spawn!("field_task", async {
382            tracing::info!(user_id = 42, session = "abc123", "user_action");
383            "field_test_complete"
384        });
385
386        let result = handle.await.unwrap();
387        assert_eq!(result, "field_test_complete");
388
389        // Verify spawn event was logged
390        logs_match!(|line| line.contains("spawned_tokio_task"));
391
392        // Verify user action event with fields was logged
393        logs_match!(|line| line.contains("user_action")
394            && line.contains("user_id")
395            && line.contains("session"));
396    }
397
398    #[traced_test]
399    #[tokio::test]
400    async fn test_spawn_macro_nested_spans() {
401        let outer_span = tracing::info_span!("request_handler", request_id = "req-123");
402        let _outer_guard = outer_span.enter();
403
404        let handle = spawn!("database_query", async {
405            let db_span = tracing::debug_span!("db_operation", table = "users");
406            let _db_guard = db_span.enter();
407
408            tracing::debug!(query = "SELECT * FROM users", "executing_query");
409
410            let cache_span = tracing::debug_span!("cache_check", cache_key = "user:42");
411            let _cache_guard = cache_span.enter();
412
413            tracing::debug!("cache_miss");
414
415            "query_complete"
416        });
417
418        let result = handle.await.unwrap();
419        assert_eq!(result, "query_complete");
420        // Verify spawn event
421        logs_match!(|line| line.contains("spawned_tokio_task"));
422
423        // Verify query execution event
424        logs_match!(|line| line.contains("executing_query"));
425
426        // Verify cache miss event
427        logs_match!(|line| line.contains("cache_miss"));
428
429        // Verify structured fields are present
430        logs_match!(|line| line.contains("table"));
431    }
432
433    #[traced_test]
434    #[tokio::test]
435    async fn test_spawn_macro_performance_tracing() {
436        let handle = spawn!("performance_task", async {
437            let start = std::time::Instant::now();
438
439            // Simulate some work
440            tokio::time::sleep(Duration::from_millis(50)).await;
441
442            let duration = start.elapsed();
443            tracing::info!(duration_ms = duration.as_millis(), "task_completed");
444
445            duration.as_millis()
446        });
447
448        let duration_ms = handle.await.unwrap();
449        assert!(duration_ms >= 50, "Task should take at least 50ms");
450
451        // Verify spawn event was logged
452        logs_match!(|line| line.contains("spawned_tokio_task"));
453
454        // Verify task completion event with duration
455        logs_match!(|line| line.contains("task_completed") && line.contains("duration_ms"));
456    }
457
458    #[traced_test]
459    #[tokio::test]
460    async fn test_spawn_macro_error_with_context() {
461        let outer_span = tracing::error_span!("error_context", operation = "critical_task");
462        let _guard = outer_span.enter();
463
464        let handle = spawn!("failing_task", async {
465            tracing::warn!(retry_count = 1, "attempting_risky_operation");
466            tracing::error!(
467                error_code = "E001",
468                message = "Operation failed",
469                "critical_error"
470            );
471            panic!("Critical failure occurred");
472        });
473
474        let result = handle.await;
475        assert!(result.is_err(), "Task should fail");
476
477        // Verify spawn event
478        logs_match!(|line| line.contains("spawned_tokio_task"));
479
480        logs_match!(|line| line.contains("attempting_risky_operation"));
481
482        // Verify critical error event
483        logs_assert_unscoped(|lines| {
484            if lines
485                .iter()
486                .any(|line| line.contains("critical_error") && line.contains("error_code"))
487            {
488                Ok(())
489            } else {
490                Err("Critical error event with error_code not found in logs".into())
491            }
492        });
493    }
494}