// hyperactor_telemetry/task.rs
task.rs1#[inline]
10pub fn current_task_id() -> u64 {
11 tokio::task::try_id().map_or(0, |x| x.to_string().parse::<u64>().unwrap_or(0))
12}
13
/// Spawns a future on a Tokio runtime inside a dedicated `tracing` span.
///
/// Forms:
/// - `spawn!(future)`: the span is named after the call site (`file:line`) and
///   the future is spawned on `tokio::runtime::Handle::current()`.
/// - `spawn!(name, future)`: uses `name` as the span name, on the current runtime.
/// - `spawn!(name, runtime, future)`: uses `name` and spawns via `runtime.spawn(..)`.
///
/// `name` is passed directly to `debug_span!`, so it must be accepted there as a
/// span name (e.g. a string literal).
///
/// The future expression is evaluated in the caller's context before spawning.
/// Inside the task, a debug-level span is created carrying the spawning task's
/// id as `parent_tokio_task_id`, with a `follows_from` link to the span that was
/// current at the call site. A `spawned_tokio_task` debug event is emitted in
/// that span, then the future runs instrumented with it. The macro evaluates to
/// whatever `runtime.spawn(..)` returns (a `JoinHandle` for a Tokio `Handle`).
///
/// # Panics
///
/// The forms without an explicit runtime call `tokio::runtime::Handle::current()`,
/// which panics when invoked outside a Tokio runtime context.
#[macro_export]
macro_rules! spawn {
    ($f:expr) => {{
        $crate::spawn!(
            concat!(file!(), ":", line!()),
            tokio::runtime::Handle::current(),
            $f
        )
    }};
    ($name:expr, $f:expr) => {{
        $crate::spawn!($name, tokio::runtime::Handle::current(), $f)
    }};
    ($name:expr, $rt:expr, $f:expr) => {{
        // Keep the spawning span itself (not just its id) so it cannot close --
        // and its id cannot be invalidated or reused -- before the `follows_from`
        // link below is recorded when the task is first polled.
        let current = $crate::tracing::Span::current();
        let parent_task = $crate::task::current_task_id();
        let ft = $f;
        $rt.spawn(async move {
            let span = $crate::tracing::debug_span!($name, parent_tokio_task_id = parent_task);
            span.follows_from(&current);
            // Release the spawning span right away; otherwise the task would keep
            // it open for as long as the task runs.
            ::core::mem::drop(current);
            span.in_scope(|| $crate::tracing::debug!("spawned_tokio_task"));
            $crate::tracing::Instrument::instrument(ft, span).await
        })
    }};
}

#[cfg(test)]
mod tests {
    // Tests for `current_task_id` and the `spawn!` macro.
    //
    // NOTE(review): several log assertions below look for events emitted *inside*
    // spawned tasks. They presumably rely on `#[traced_test]` keeping its per-test
    // span entered across `.await` points on the single-threaded `#[tokio::test]`
    // runtime, so that spawned tasks observe it as their contextual span -- confirm
    // before moving these tests to a multi-threaded runtime.

    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use std::time::Duration;

    use tokio::time::timeout;
    use tracing_test::traced_test;

    use super::*;

    /// Asserts that at least one captured log line satisfies the predicate `$expr`.
    ///
    /// Expands to a call to `logs_assert`, which is not defined in this module
    /// (presumably supplied by `#[traced_test]`), so it is only usable inside
    /// `#[traced_test]` tests. The optional `$msg` overrides the failure message.
    macro_rules! logs_match {
        ($expr:expr) => {
            logs_match!($expr, format!("{} not in logs", stringify!($expr)));
        };
        ($expr:expr, $msg:expr) => {
            logs_assert(|lines| {
                if lines.iter().any($expr) {
                    Ok(())
                } else {
                    Err($msg.into())
                }
            })
        };
    }

    #[traced_test]
    #[tokio::test]
    async fn test_current_task_id_returns_valid_id() {
        let handle = spawn!("test", async move {
            let task_id = current_task_id();
            assert!(task_id > 0, "Task ID should be greater than 0");
        });
        handle.await.unwrap();
    }

    #[traced_test]
    #[tokio::test]
    async fn test_current_task_id_different_tasks() {
        let task1_id = Arc::new(std::sync::Mutex::new(0u64));
        let task2_id = Arc::new(std::sync::Mutex::new(0u64));

        let task1_id_clone = task1_id.clone();
        let task2_id_clone = task2_id.clone();

        // Exercise the path-qualified form of the exported macro as well.
        let handle1 = crate::spawn!(async move {
            *task1_id_clone.lock().unwrap() = current_task_id();
        });

        let handle2 = crate::spawn!(async move {
            *task2_id_clone.lock().unwrap() = current_task_id();
        });

        handle1.await.unwrap();
        handle2.await.unwrap();

        let id1 = *task1_id.lock().unwrap();
        let id2 = *task2_id.lock().unwrap();

        assert!(id1 > 0, "Task 1 ID should be greater than 0");
        assert!(id2 > 0, "Task 2 ID should be greater than 0");
        assert_ne!(id1, id2, "Different tasks should have different IDs");
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_basic_functionality() {
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        let handle = spawn!("test_task", async move {
            completed_clone.store(true, Ordering::SeqCst);
            42
        });

        let result = handle.await.unwrap();
        assert_eq!(result, 42);
        assert!(completed.load(Ordering::SeqCst));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_with_runtime_handle() {
        let rt = tokio::runtime::Handle::current();
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        let handle = spawn!("test_task_with_rt", rt, async move {
            completed_clone.store(true, Ordering::SeqCst);
            "success"
        });

        let result = handle.await.unwrap();
        assert_eq!(result, "success");
        assert!(completed.load(Ordering::SeqCst));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_with_async_operation() {
        let handle = spawn!("async_operation", async {
            tokio::time::sleep(Duration::from_millis(10)).await;
            "async_result"
        });

        let result = timeout(Duration::from_secs(1), handle)
            .await
            .expect("Task should complete within timeout")
            .expect("Task should not panic");

        assert_eq!(result, "async_result");
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_error_handling() {
        let handle = spawn!("error_task", async {
            panic!("intentional panic");
        });

        let result = handle.await;
        assert!(result.is_err(), "Task should panic and return an error");
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_multiple_tasks() {
        let num_tasks = 5;
        let completed_count = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::with_capacity(num_tasks);
        for i in 0..num_tasks {
            let count_clone = completed_count.clone();
            handles.push(spawn!("parallel_task", async move {
                count_clone.fetch_add(1, Ordering::SeqCst);
                i
            }));
        }

        let mut results = Vec::with_capacity(num_tasks);
        for handle in handles {
            results.push(handle.await.expect("Task should complete successfully"));
        }

        assert_eq!(results.len(), num_tasks);
        assert_eq!(completed_count.load(Ordering::SeqCst), num_tasks);

        // Results are collected in spawn order, but compare sorted so the check
        // does not depend on that detail.
        results.sort_unstable();
        let expected: Vec<usize> = (0..num_tasks).collect();
        assert_eq!(results, expected);
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_creates_proper_span() {
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        let parent_span = tracing::debug_span!("parent_span");
        let _guard = parent_span.enter();

        let handle = spawn!("child_task", async move {
            tracing::debug!(task_data = "test_value", "task_execution");
            completed_clone.store(true, Ordering::SeqCst);
            "completed"
        });

        let result = handle.await.unwrap();
        assert_eq!(result, "completed");
        assert!(completed.load(Ordering::SeqCst));

        logs_match!(
            |line| line.contains("spawned_tokio_task"),
            "task logging never occurred"
        );

        logs_match!(|line| line.contains("task_execution"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_preserves_parent_context() {
        let parent_span = tracing::debug_span!("parent", operation = "test_context");
        let _guard = parent_span.enter();

        let handle = spawn!("context_child", async {
            tracing::debug!(child_data = "value", "child_operation");
            "context_preserved"
        });

        let result = handle.await.unwrap();
        assert_eq!(result, "context_preserved");

        logs_match!(|line| line.contains("spawned_tokio_task"));
        logs_match!(|line| line.contains("child_operation"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_with_instrumentation() {
        let handle = spawn!("instrumented_task", async {
            tracing::info!("inside_spawned_task");
            42
        });

        let result = handle.await.unwrap();
        assert_eq!(result, 42);

        logs_match!(|line| line.contains("spawned_tokio_task"));
        logs_match!(|line| line.contains("inside_spawned_task"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_span_hierarchy() {
        let outer_span = tracing::info_span!("outer_operation", test_id = 123);
        let _outer_guard = outer_span.enter();

        let handle = spawn!("nested_task", async {
            let inner_span = tracing::debug_span!("inner_operation", step = "processing");
            let _inner_guard = inner_span.enter();

            tracing::warn!(data_size = 1024, "processing_data");
            "nested_complete"
        });

        let result = handle.await.unwrap();
        assert_eq!(result, "nested_complete");

        logs_match!(|line| line.contains("spawned_tokio_task"));
        logs_match!(|line| line.contains("processing_data"));
        logs_match!(|line| line.contains("data_size"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_error_tracing() {
        let handle = spawn!("error_prone_task", async {
            tracing::error!(reason = "intentional", "task_about_to_fail");
            panic!("deliberate failure");
        });

        let result = handle.await;
        assert!(result.is_err(), "Task should fail");

        logs_match!(|line| line.contains("spawned_tokio_task"));
        logs_match!(|line| line.contains("task_about_to_fail"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_concurrent_tracing() {
        let barrier = Arc::new(tokio::sync::Barrier::new(3));

        let handles = (0..3)
            .map(|i| {
                let barrier = barrier.clone();
                spawn!("concurrent_task", async move {
                    // All three tasks must be live at once to pass the barrier.
                    barrier.wait().await;
                    tracing::info!(task_num = i, "concurrent_execution");
                    i * 10
                })
            })
            .collect::<Vec<_>>();

        let mut results = Vec::with_capacity(handles.len());
        for handle in handles {
            results.push(handle.await.expect("Task should complete"));
        }

        results.sort();
        assert_eq!(results, vec![0, 10, 20]);

        logs_assert(|lines| {
            let spawn_count = lines
                .iter()
                .filter(|line| line.contains("spawned_tokio_task"))
                .count();
            if spawn_count >= 3 {
                Ok(())
            } else {
                Err(format!(
                    "Expected at least 3 spawned_tokio_task events, found {}",
                    spawn_count
                ))
            }
        });

        logs_assert(|lines| {
            let exec_count = lines
                .iter()
                .filter(|line| line.contains("concurrent_execution"))
                .count();
            if exec_count >= 3 {
                Ok(())
            } else {
                Err(format!(
                    "Expected at least 3 concurrent execution events, found {}",
                    exec_count
                ))
            }
        });
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_with_fields() {
        let handle = spawn!("field_task", async {
            tracing::info!(user_id = 42, session = "abc123", "user_action");
            "field_test_complete"
        });

        let result = handle.await.unwrap();
        assert_eq!(result, "field_test_complete");

        logs_match!(|line| line.contains("spawned_tokio_task"));
        logs_match!(|line| line.contains("user_action")
            && line.contains("user_id")
            && line.contains("session"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_nested_spans() {
        let outer_span = tracing::info_span!("request_handler", request_id = "req-123");
        let _outer_guard = outer_span.enter();

        let handle = spawn!("database_query", async {
            let db_span = tracing::debug_span!("db_operation", table = "users");
            let _db_guard = db_span.enter();

            tracing::debug!(query = "SELECT * FROM users", "executing_query");

            let cache_span = tracing::debug_span!("cache_check", cache_key = "user:42");
            let _cache_guard = cache_span.enter();

            tracing::debug!("cache_miss");

            "query_complete"
        });

        let result = handle.await.unwrap();
        assert_eq!(result, "query_complete");

        logs_match!(|line| line.contains("spawned_tokio_task"));
        logs_match!(|line| line.contains("executing_query"));
        logs_match!(|line| line.contains("cache_miss"));
        logs_match!(|line| line.contains("table"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_performance_tracing() {
        let handle = spawn!("performance_task", async {
            let start = std::time::Instant::now();

            tokio::time::sleep(Duration::from_millis(50)).await;

            let duration = start.elapsed();
            tracing::info!(duration_ms = duration.as_millis(), "task_completed");

            duration.as_millis()
        });

        let duration_ms = handle.await.unwrap();
        assert!(duration_ms >= 50, "Task should take at least 50ms");

        logs_match!(|line| line.contains("spawned_tokio_task"));
        logs_match!(|line| line.contains("task_completed") && line.contains("duration_ms"));
    }

    #[traced_test]
    #[tokio::test]
    async fn test_spawn_macro_error_with_context() {
        let outer_span = tracing::error_span!("error_context", operation = "critical_task");
        let _guard = outer_span.enter();

        let handle = spawn!("failing_task", async {
            tracing::warn!(retry_count = 1, "attempting_risky_operation");
            tracing::error!(
                error_code = "E001",
                message = "Operation failed",
                "critical_error"
            );
            panic!("Critical failure occurred");
        });

        let result = handle.await;
        assert!(result.is_err(), "Task should fail");

        logs_match!(|line| line.contains("spawned_tokio_task"));
        logs_match!(|line| line.contains("attempting_risky_operation"));
        logs_match!(
            |line| line.contains("critical_error") && line.contains("error_code"),
            "Critical error event with error_code not found in logs"
        );
    }
}