hyper/utils/
system_address.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::str::FromStr;
10
11use anyhow;
12use hyperactor::channel::ChannelAddr;
13
/// Extended type to represent a system address. It is parsed (via `FromStr`) from either a
/// `ChannelAddr` string or, in fbcode builds, a MAST job name, and wraps the resulting
/// `ChannelAddr`.
#[derive(Clone, Debug)]
pub struct SystemAddr(ChannelAddr);
17
18impl From<SystemAddr> for ChannelAddr {
19    fn from(system_addr: SystemAddr) -> ChannelAddr {
20        system_addr.0
21    }
22}
23
24impl FromStr for SystemAddr {
25    type Err = anyhow::Error;
26
27    #[cfg(fbcode_build)]
28    fn from_str(s: &str) -> Result<Self, Self::Err> {
29        let handle = tokio::runtime::Handle::try_current()?;
30        tokio::task::block_in_place(|| handle.block_on(parse_system_address_or_mast_job(s)))
31            .map(Self)
32    }
33    #[cfg(not(fbcode_build))]
34    fn from_str(s: &str) -> Result<Self, Self::Err> {
35        ChannelAddr::from_str(s).map(SystemAddr)
36    }
37}
38
/// Resolve `address` into a `ChannelAddr`.
///
/// The input is first parsed as a plain channel address. If that fails, it is treated as a
/// MAST job name: the job's SMC tier is looked up with `get_smc_tier`, SMC is queried for the
/// system's host and port, and the result is returned as a `ChannelAddr::MetaTls` address.
///
/// # Errors
///
/// Fails when the job is not found (the message carries the original channel parse error),
/// when the SMC tier lookup fails, or when the SMC client cannot be created or queried.
#[cfg(fbcode_build)]
async fn parse_system_address_or_mast_job(address: &str) -> Result<ChannelAddr, anyhow::Error> {
    use hyperactor_meta_lib::system_resolution::SMCClient;
    use hyperactor_meta_lib::system_resolution::canonicalize_hostname;

    // Plain channel addresses are returned as-is; keep the parse error around so it can be
    // reported if the MAST lookup finds no such job either.
    let parse_err = match ChannelAddr::from_str(address) {
        Ok(parsed) => return Ok(parsed),
        Err(err) => err,
    };

    // Not a channel address: interpret the input as a MAST job name.
    let tier = get_smc_tier(address).await?.ok_or_else(|| {
        anyhow::anyhow!(
            "could not resolve system address from {}: {}",
            address,
            parse_err
        )
    })?;

    let (host, port) = SMCClient::new(fbinit::expect_init(), tier)?
        .get_system_address()
        .await?;
    Ok(ChannelAddr::MetaTls(canonicalize_hostname(&host), port))
}
70
/// Look up the SMC tier stored in the application metadata of the MAST job `job_name`.
///
/// Returns `Ok(None)` when the scheduler reports `JOB_NOT_FOUND` for the status request.
///
/// # Errors
///
/// Fails when the scheduler client cannot be created, when a scheduler request fails for any
/// other reason, when the job is not in the `RUNNING` state, or when the job definition has
/// no application metadata or the metadata lacks the SMC tier key.
#[cfg(fbcode_build)]
async fn get_smc_tier(job_name: &str) -> Result<Option<String>, anyhow::Error> {
    use hpcscheduler;
    use hpcscheduler_srclients;
    use hpcscheduler_srclients::thrift;

    /// Metadata key under which the SMC tier is stored. This should match the key used in
    /// the MAST job definition when the job was created.
    /// For example: https://github.com/fairinternal/xlformers/blob/5db99239e7fa2cc08ca16232edc670b13003e172/core/monarch/mast.py#L446
    const SMC_TIER_APPLICATION_METADATA_KEY: &str = "monarch_system_smc_tier";

    let scheduler = hpcscheduler_srclients::make_HpcSchedulerService_srclient!(
        fbinit::expect_init(),
        tiername = "mast.api.read"
    )?;

    // Step 1: make sure the job exists and is running.
    let status_request = hpcscheduler::GetHpcJobStatusRequest {
        hpcJobName: job_name.to_string(),
        ..Default::default()
    };
    let status = match scheduler.getHpcJobStatus(&status_request).await {
        Ok(status) => status,
        // An unknown job is not an error here; the caller decides how to report it.
        Err(thrift::errors::GetHpcJobStatusError::e(e))
            if e.errorCode == hpcscheduler::HpcSchedulerErrorCode::JOB_NOT_FOUND =>
        {
            return Ok(None);
        }
        // Any other scheduler exception is propagated as the inner exception value.
        Err(thrift::errors::GetHpcJobStatusError::e(e)) => anyhow::bail!(e),
        Err(e) => anyhow::bail!(e),
    };
    anyhow::ensure!(
        status.state == hpcscheduler::HpcJobState::RUNNING,
        "job {} is not running",
        job_name
    );

    // Step 2: fetch the job definition and read the SMC tier from its application metadata.
    let definition_request = hpcscheduler::GetHpcJobDefinitionRequest {
        hpcJobName: job_name.to_string(),
        ..Default::default()
    };
    let definition = scheduler.getHpcJobDefinition(&definition_request).await?;
    let metadata = definition
        .jobDefinition
        .applicationMetadata
        .ok_or_else(|| anyhow::anyhow!("no application metadata found in job definition"))?;
    let smc_tier = metadata
        .get(SMC_TIER_APPLICATION_METADATA_KEY)
        .ok_or_else(|| anyhow::anyhow!("did not find smc tier in application metadata"))?;
    Ok(Some(smc_tier.to_string()))
}