monarch_tensor_worker/test_util.rs
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::io::IsTerminal;
10
11use anyhow::Result;
12use pyo3::Python;
13use pyo3::ffi::c_str;
14use tracing_subscriber::fmt::format::FmtSpan;
15
16pub fn test_setup() -> Result<()> {
17 let _ = tracing_subscriber::fmt()
18 .with_thread_ids(true)
19 .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
20 .with_max_level(tracing::Level::DEBUG)
21 .with_ansi(std::io::stderr().is_terminal())
22 .with_writer(std::io::stderr)
23 .try_init();
24
25 // Redirect NCCL_DEBUG log output to a file so it doesn't clash on stdout.
26 // TestX requires stdout to have JSON output on individual lines, and
27 // the NCCL output is not JSON. Because it runs in a different thread, it'll
28 // race on writing to stdout.
29 // Do this regardless of whether NCCL_DEBUG is set or not, because it can
30 // be set after this point in the test. If it doesn't get set, NCCL_DEBUG_FILE
31 // will be ignored.
32 // %h becomes hostname, %p becomes pid.
33 let nccl_debug_file = std::env::temp_dir().join("nccl_debug.%h.%p");
34 tracing::debug!("Set NCCL_DEBUG_FILE to {:?}", nccl_debug_file);
35 // Safety: Can be unsound if there are multiple threads
36 // reading and writing the environment.
37 unsafe {
38 std::env::set_var("NCCL_DEBUG_FILE", nccl_debug_file);
39 }
40 // NOTE(agallagher): Calling `prepare_freethreaded_python` appears to
41 // clear `PYTHONPATH` in the env, which we need for test subprocesses
42 // to work. So, manually preserve it.
43 let py_path = std::env::var("PYTHONPATH");
44 pyo3::prepare_freethreaded_python();
45 if let Ok(py_path) = py_path {
46 // SAFETY: Re-setting env var cleard by `prepare_freethreaded_python`.
47 unsafe { std::env::set_var("PYTHONPATH", py_path) }
48 }
49
50 // We need to load torch to initialize some internal structures used by
51 // the FFI funcs we use to convert ivalues to/from py objects.
52 Python::with_gil(|py| py.run(c_str!("import torch"), None, None))?;
53
54 Ok(())
55}