// hyperactor_telemetry/task.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

9#[inline]
10pub fn current_task_id() -> u64 {
11    tokio::task::try_id().map_or(0, |x| x.to_string().parse::<u64>().unwrap_or(0))
12}
13
/// Spawns a future on a tokio runtime, instrumented with a new `debug`-level span.
///
/// Three forms are accepted:
///
/// * `spawn!(future)` — the span is named after the call site
///   (`concat!(file!(), ":", line!())`). The future is spawned on
///   `tokio::runtime::Handle::current()`.
/// * `spawn!(name, future)` — the span is named `name`. The future is spawned on
///   `tokio::runtime::Handle::current()`.
/// * `spawn!(name, runtime, future)` — the span is named `name`. The future is
///   spawned with `runtime.spawn(..)`.
///
/// The caller's current span ID and tokio task ID (via `current_task_id()`) are
/// captured at the call site, and the future expression is evaluated there too.
///
/// Inside the spawned task, the macro then:
/// * creates the span, with a `parent_tokio_task_id` field;
/// * marks it as following from the caller's span;
/// * emits a `"spawned_tokio_task"` debug event inside it;
/// * awaits the future instrumented by that span.
///
/// The macro evaluates to whatever `runtime.spawn(..)` returns (for a tokio
/// `Handle`, a `JoinHandle` yielding the future's output).
///
/// `name` is forwarded to `debug_span!`, so presumably it must be a constant
/// string expression such as a literal or `concat!(..)` — TODO confirm.
#[macro_export]
macro_rules! spawn {
    ($f:expr) => {{
        $crate::spawn!(
            concat!(file!(), ":", line!()),
            tokio::runtime::Handle::current(),
            $f
        )
    }};
    ($name:expr, $f:expr) => {{ $crate::spawn!($name, tokio::runtime::Handle::current(), $f) }};
    ($name:expr, $rt:expr, $f:expr) => {{
        // Every `tracing` item is reached through `$crate::tracing`, the same
        // path the `Instrument` call below already relies on. The expansion
        // therefore resolves against this crate instead of requiring each
        // caller to depend on `tracing` directly. (The default-runtime arms
        // above still name `tokio` at the call site.)
        let current = $crate::tracing::span::Span::current().id();
        let parent_task = $crate::task::current_task_id();
        let ft = $f;
        $rt.spawn(async move {
            let span = $crate::tracing::debug_span!($name, parent_tokio_task_id = parent_task);
            // `current` may be `None` (no active span at the call site);
            // `follows_from` accepts that and records nothing.
            span.follows_from(current);
            span.in_scope(|| $crate::tracing::debug!("spawned_tokio_task"));
            $crate::tracing::Instrument::instrument(ft, span).await
        })
    }};
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use std::time::Duration;

    use tokio::time::timeout;
    use tracing_test::traced_test;

    use super::*;

    /// Asserts that at least one log line passed to `logs_assert` (the helper
    /// generated by `#[traced_test]`) satisfies the predicate `$expr`.
    ///
    /// The optional second argument replaces the default failure message.
    /// Defined at the top of the module so it is in scope for every test below.
    macro_rules! logs_match {
        ($expr:expr) => {
            logs_match!($expr, format!("{} not in logs", stringify!($expr)));
        };
        ($expr:expr, $msg:expr) => {
            logs_assert(|lines| {
                if lines.iter().any($expr) {
                    Ok(())
                } else {
                    Err($msg.into())
                }
            })
        };
    }

    #[traced_test]
    #[tokio::test]
    async fn test_current_task_id_returns_valid_id() {
        let handle = spawn!("test", async move {
            let task_id = current_task_id();
            // Task ID should be non-zero when called from within a tokio task
            assert!(task_id > 0, "Task ID should be greater than 0");
        });
        handle.await.unwrap();
    }

    #[traced_test]
    #[tokio::test]
    async fn test_current_task_id_different_tasks() {
        // Both tasks are spawned before either is awaited. Each one reports
        // its own ID as the output of its join handle.
        let handle1 = spawn!(async { current_task_id() });
        let handle2 = spawn!(async { current_task_id() });

        let id1 = handle1.await.unwrap();
        let id2 = handle2.await.unwrap();

        assert!(id1 > 0, "Task 1 ID should be greater than 0");
        assert!(id2 > 0, "Task 2 ID should be greater than 0");
        assert_ne!(id1, id2, "Different tasks should have different IDs");
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_basic_functionality() {
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        let handle = spawn!("test_task", async move {
            completed_clone.store(true, Ordering::SeqCst);
            42
        });

        let result = handle.await.unwrap();
        assert_eq!(result, 42);
        assert!(completed.load(Ordering::SeqCst));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_with_runtime_handle() {
        let rt = tokio::runtime::Handle::current();
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        let handle = spawn!("test_task_with_rt", rt, async move {
            completed_clone.store(true, Ordering::SeqCst);
            "success"
        });

        let result = handle.await.unwrap();
        assert_eq!(result, "success");
        assert!(completed.load(Ordering::SeqCst));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_with_async_operation() {
        let handle = spawn!("async_operation", async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            "async_result"
        });

        let result = timeout(Duration::from_secs(1), handle)
            .await
            .expect("Task should complete within timeout")
            .expect("Task should not panic");

        assert_eq!(result, "async_result");
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_error_handling() {
        let handle = spawn!("error_task", async {
            panic!("intentional panic");
        });

        let result = handle.await;
        assert!(result.is_err(), "Task should panic and return an error");
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_multiple_tasks() {
        let num_tasks = 5;
        let completed_count = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::with_capacity(num_tasks);
        for i in 0..num_tasks {
            let count_clone = completed_count.clone();
            let handle = spawn!("parallel_task", async move {
                count_clone.fetch_add(1, Ordering::SeqCst);
                i
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        let mut results = Vec::with_capacity(num_tasks);
        for handle in handles {
            let result = handle.await.expect("Task should complete successfully");
            results.push(result);
        }

        assert_eq!(results.len(), num_tasks);
        assert_eq!(completed_count.load(Ordering::SeqCst), num_tasks);

        // Results should contain values 0 through num_tasks-1
        let mut sorted_results = results;
        sorted_results.sort();
        let expected: Vec<usize> = (0..num_tasks).collect();
        assert_eq!(sorted_results, expected);
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_creates_proper_span() {
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        let parent_span = tracing::debug_span!("parent_span");
        // Enter the parent only around the spawn. `spawn!` captures the current
        // span at the call site, and holding an `enter()` guard across an
        // `.await` is discouraged by the `tracing` documentation.
        let handle = {
            let _guard = parent_span.enter();
            spawn!("child_task", async move {
                tracing::debug!(task_data = "test_value", "task_execution");
                completed_clone.store(true, Ordering::SeqCst);
                "completed"
            })
        };

        let result = handle.await.unwrap();
        assert_eq!(result, "completed");
        assert!(completed.load(Ordering::SeqCst));

        // Check that spawn event was logged
        logs_match!(
            |line| line.contains("spawned_tokio_task"),
            "task logging never occurred"
        );

        // Check that task execution event was logged
        logs_match!(|line| line.contains("task_execution"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_preserves_parent_context() {
        let parent_span = tracing::debug_span!("parent", operation = "test_context");
        // Scope the guard to the spawn so it is not held across `.await`.
        let handle = {
            let _guard = parent_span.enter();
            spawn!("context_child", async {
                tracing::debug!(child_data = "value", "child_operation");
                "context_preserved"
            })
        };

        let result = handle.await.unwrap();
        assert_eq!(result, "context_preserved");

        // Verify the spawn event was logged
        logs_match!(|line| line.contains("spawned_tokio_task"));

        // Verify child operation was logged
        logs_match!(|line| line.contains("child_operation"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_with_instrumentation() {
        let handle = spawn!("instrumented_task", async {
            tracing::info!("inside_spawned_task");
            42
        });

        let result = handle.await.unwrap();
        assert_eq!(result, 42);

        // Check for spawned task event
        logs_match!(|line| line.contains("spawned_tokio_task"));

        // Check for task execution event
        logs_match!(|line| line.contains("inside_spawned_task"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_span_hierarchy() {
        let outer_span = tracing::info_span!("outer_operation", test_id = 123);
        // Scope the guard to the spawn so it is not held across `.await`.
        let handle = {
            let _outer_guard = outer_span.enter();
            spawn!("nested_task", async {
                let inner_span = tracing::debug_span!("inner_operation", step = "processing");
                let _inner_guard = inner_span.enter();

                tracing::warn!(data_size = 1024, "processing_data");
                "nested_complete"
            })
        };

        let result = handle.await.unwrap();
        assert_eq!(result, "nested_complete");

        // Verify spawn debug event was captured
        logs_match!(|line| line.contains("spawned_tokio_task"));

        // Verify the warning event was captured
        logs_match!(|line| line.contains("processing_data"));

        // Verify structured data is present
        logs_match!(|line| line.contains("data_size"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_error_tracing() {
        let handle = spawn!("error_prone_task", async {
            tracing::error!(reason = "intentional", "task_about_to_fail");
            panic!("deliberate failure");
        });

        let result = handle.await;
        assert!(result.is_err(), "Task should fail");

        // Verify spawn event was logged
        logs_match!(|line| line.contains("spawned_tokio_task"));

        // Verify error event was logged before panic
        logs_match!(|line| line.contains("task_about_to_fail"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_concurrent_tracing() {
        let barrier = Arc::new(tokio::sync::Barrier::new(3));

        let handles = (0..3)
            .map(|i| {
                let barrier = barrier.clone();
                spawn!("concurrent_task", async move {
                    barrier.wait().await;
                    tracing::info!(task_num = i, "concurrent_execution");
                    i * 10
                })
            })
            .collect::<Vec<_>>();

        // Await all handles manually
        let mut results = Vec::with_capacity(handles.len());
        for handle in handles {
            let result = handle.await.expect("Task should complete");
            results.push(result);
        }

        results.sort();
        assert_eq!(results, vec![0, 10, 20]);

        // Verify spawn events were logged (at least 3)
        logs_assert(|lines| {
            let spawn_count = lines
                .iter()
                .filter(|line| line.contains("spawned_tokio_task"))
                .count();
            if spawn_count >= 3 {
                Ok(())
            } else {
                Err(format!(
                    "Expected at least 3 spawned_tokio_task events, found {}",
                    spawn_count
                ))
            }
        });

        // Verify execution events were logged
        logs_assert(|lines| {
            let exec_count = lines
                .iter()
                .filter(|line| line.contains("concurrent_execution"))
                .count();
            if exec_count >= 3 {
                Ok(())
            } else {
                Err(format!(
                    "Expected at least 3 concurrent execution events, found {}",
                    exec_count
                ))
            }
        });
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_with_fields() {
        let handle = spawn!("field_task", async {
            tracing::info!(user_id = 42, session = "abc123", "user_action");
            "field_test_complete"
        });

        let result = handle.await.unwrap();
        assert_eq!(result, "field_test_complete");

        // Verify spawn event was logged
        logs_match!(|line| line.contains("spawned_tokio_task"));

        // Verify user action event with fields was logged
        logs_match!(|line| line.contains("user_action")
            && line.contains("user_id")
            && line.contains("session"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_nested_spans() {
        let outer_span = tracing::info_span!("request_handler", request_id = "req-123");
        // Scope the guard to the spawn so it is not held across `.await`.
        let handle = {
            let _outer_guard = outer_span.enter();
            spawn!("database_query", async {
                let db_span = tracing::debug_span!("db_operation", table = "users");
                let _db_guard = db_span.enter();

                tracing::debug!(query = "SELECT * FROM users", "executing_query");

                let cache_span = tracing::debug_span!("cache_check", cache_key = "user:42");
                let _cache_guard = cache_span.enter();

                tracing::debug!("cache_miss");

                "query_complete"
            })
        };

        let result = handle.await.unwrap();
        assert_eq!(result, "query_complete");

        // Verify spawn event
        logs_match!(|line| line.contains("spawned_tokio_task"));

        // Verify query execution event
        logs_match!(|line| line.contains("executing_query"));

        // Verify cache miss event
        logs_match!(|line| line.contains("cache_miss"));

        // Verify structured fields are present
        logs_match!(|line| line.contains("table"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_performance_tracing() {
        let handle = spawn!("performance_task", async {
            let start = std::time::Instant::now();

            // Simulate some work
            tokio::time::sleep(Duration::from_millis(50)).await;

            let duration = start.elapsed();
            tracing::info!(duration_ms = duration.as_millis(), "task_completed");

            duration.as_millis()
        });

        let duration_ms = handle.await.unwrap();
        assert!(duration_ms >= 50, "Task should take at least 50ms");

        // Verify spawn event was logged
        logs_match!(|line| line.contains("spawned_tokio_task"));

        // Verify task completion event with duration
        logs_match!(|line| line.contains("task_completed") && line.contains("duration_ms"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_error_with_context() {
        let outer_span = tracing::error_span!("error_context", operation = "critical_task");
        // Scope the guard to the spawn so it is not held across `.await`.
        let handle = {
            let _guard = outer_span.enter();
            spawn!("failing_task", async {
                tracing::warn!(retry_count = 1, "attempting_risky_operation");
                tracing::error!(
                    error_code = "E001",
                    message = "Operation failed",
                    "critical_error"
                );
                panic!("Critical failure occurred");
            })
        };

        let result = handle.await;
        assert!(result.is_err(), "Task should fail");

        // Verify spawn event
        logs_match!(|line| line.contains("spawned_tokio_task"));

        logs_match!(|line| line.contains("attempting_risky_operation"));

        // Verify critical error event
        logs_match!(
            |line| line.contains("critical_error") && line.contains("error_code"),
            "Critical error event with error_code not found in logs"
        );
    }
}