monarch_tensor_worker/
test_util.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::io::IsTerminal;
10
11use anyhow::Result;
12use pyo3::Python;
13use pyo3::ffi::c_str;
14use tracing_subscriber::fmt::format::FmtSpan;
15
16pub fn test_setup() -> Result<()> {
17    let _ = tracing_subscriber::fmt()
18        .with_thread_ids(true)
19        .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
20        .with_max_level(tracing::Level::DEBUG)
21        .with_ansi(std::io::stderr().is_terminal())
22        .with_writer(std::io::stderr)
23        .try_init();
24
25    // Redirect NCCL_DEBUG log output to a file so it doesn't clash on stdout.
26    // TestX requires stdout to have JSON output on individual lines, and
27    // the NCCL output is not JSON. Because it runs in a different thread, it'll
28    // race on writing to stdout.
29    // Do this regardless of whether NCCL_DEBUG is set or not, because it can
30    // be set after this point in the test. If it doesn't get set, NCCL_DEBUG_FILE
31    // will be ignored.
32    // %h becomes hostname, %p becomes pid.
33    let nccl_debug_file = std::env::temp_dir().join("nccl_debug.%h.%p");
34    tracing::debug!("Set NCCL_DEBUG_FILE to {:?}", nccl_debug_file);
35    // Safety: Can be unsound if there are multiple threads
36    // reading and writing the environment.
37    unsafe {
38        std::env::set_var("NCCL_DEBUG_FILE", nccl_debug_file);
39    }
40    // NOTE(agallagher): Calling `prepare_freethreaded_python` appears to
41    // clear `PYTHONPATH` in the env, which we need for test subprocesses
42    // to work.  So, manually preserve it.
43    let py_path = std::env::var("PYTHONPATH");
44    pyo3::prepare_freethreaded_python();
45    if let Ok(py_path) = py_path {
46        // SAFETY: Re-setting env var cleard by `prepare_freethreaded_python`.
47        unsafe { std::env::set_var("PYTHONPATH", py_path) }
48    }
49
50    // We need to load torch to initialize some internal structures used by
51    // the FFI funcs we use to convert ivalues to/from py objects.
52    Python::with_gil(|py| py.run(c_str!("import torch"), None, None))?;
53
54    Ok(())
55}