hyper/utils/
system_address.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::str::FromStr;
10
11use anyhow;
12use hyperactor::channel::ChannelAddr;
13use hyperactor::channel::MetaTlsAddr;
14
/// Extended type to represent a system address, which can be specified either as a
/// `ChannelAddr` or as a MAST job name. The wrapped value is always a `ChannelAddr`.
#[derive(Clone, Debug)]
pub struct SystemAddr(ChannelAddr);
18
19impl From<SystemAddr> for ChannelAddr {
20    fn from(system_addr: SystemAddr) -> ChannelAddr {
21        system_addr.0
22    }
23}
24
25impl FromStr for SystemAddr {
26    type Err = anyhow::Error;
27
28    #[cfg(fbcode_build)]
29    fn from_str(s: &str) -> Result<Self, Self::Err> {
30        let handle = tokio::runtime::Handle::try_current()?;
31        tokio::task::block_in_place(|| handle.block_on(parse_system_address_or_mast_job(s)))
32            .map(Self)
33    }
34    #[cfg(not(fbcode_build))]
35    fn from_str(s: &str) -> Result<Self, Self::Err> {
36        ChannelAddr::from_str(s).map(SystemAddr)
37    }
38}
39
/// Resolve `address` into a ChannelAddr. The string is first tried as a channel address; if
/// that fails it is treated as a MAST job name: the SMC tier is looked up via `get_smc_tier`,
/// and SMC is then queried for the system's host and port.
///
/// When the job is not found, the returned error carries the original channel parse error.
#[cfg(fbcode_build)]
async fn parse_system_address_or_mast_job(address: &str) -> Result<ChannelAddr, anyhow::Error> {
    use hyperactor_meta_lib::system_resolution::SMCClient;
    use hyperactor_meta_lib::system_resolution::canonicalize_hostname;

    // A string that already parses as a channel address needs no further lookup.
    let channel_err = match ChannelAddr::from_str(address) {
        Ok(addr) => return Ok(addr),
        Err(err) => err,
    };

    // Lookup failures propagate as-is; a job that is not found is reported together with
    // the channel parse error.
    let Some(smc_tier) = get_smc_tier(address).await? else {
        anyhow::bail!(
            "could not resolve system address from {}: {}",
            address,
            channel_err
        );
    };

    let (host, port) = SMCClient::new(fbinit::expect_init(), smc_tier)?
        .get_system_address()
        .await?;
    Ok(ChannelAddr::MetaTls(MetaTlsAddr::Host {
        hostname: canonicalize_hostname(&host),
        port,
    }))
}
74
/// Look up the SMC tier stored in the application metadata of the MAST job `job_name`.
///
/// Returns `Ok(None)` when the scheduler reports that the job is not found. Returns an error
/// when the job is not running, when the job definition has no application metadata, when the
/// metadata lacks the SMC tier key, or when a scheduler call fails.
#[cfg(fbcode_build)]
async fn get_smc_tier(job_name: &str) -> Result<Option<String>, anyhow::Error> {
    use hpcscheduler;
    use hpcscheduler_srclients;
    use hpcscheduler_srclients::thrift;

    /// Metadata key under which the SMC tier is stored; it should match the key used in the
    /// MAST job definition when the job was created.
    /// For example: https://github.com/fairinternal/xlformers/blob/5db99239e7fa2cc08ca16232edc670b13003e172/core/monarch/mast.py#L446
    const SMC_TIER_APPLICATION_METADATA_KEY: &str = "monarch_system_smc_tier";

    let client = hpcscheduler_srclients::make_HpcSchedulerService_srclient!(
        fbinit::expect_init(),
        tiername = "mast.api.read"
    )?;

    // Step 1: check the job's status; an unknown job is not an error for the caller.
    let status_request = hpcscheduler::GetHpcJobStatusRequest {
        hpcJobName: job_name.to_string(),
        ..Default::default()
    };
    let status = match client.getHpcJobStatus(&status_request).await {
        Ok(status) => status,
        Err(thrift::errors::GetHpcJobStatusError::e(scheduler_err)) => {
            if scheduler_err.errorCode != hpcscheduler::HpcSchedulerErrorCode::JOB_NOT_FOUND {
                anyhow::bail!(scheduler_err);
            }
            return Ok(None);
        }
        Err(other) => anyhow::bail!(other),
    };
    if status.state != hpcscheduler::HpcJobState::RUNNING {
        anyhow::bail!("job {} is not running", job_name);
    }

    // Step 2: fetch the job definition and read the SMC tier out of its metadata.
    let definition_request = hpcscheduler::GetHpcJobDefinitionRequest {
        hpcJobName: job_name.to_string(),
        ..Default::default()
    };
    let definition = client.getHpcJobDefinition(&definition_request).await?;
    let Some(metadata) = definition.jobDefinition.applicationMetadata else {
        anyhow::bail!("no application metadata found in job definition");
    };
    metadata
        .get(SMC_TIER_APPLICATION_METADATA_KEY)
        .map(|smc_tier| Some(smc_tier.to_string()))
        .ok_or_else(|| anyhow::anyhow!("did not find smc tier in application metadata"))
}