hyperactor_mesh/
namespace.rs1use std::collections::HashMap;
30use std::sync::RwLock;
31
32use async_trait::async_trait;
33use hyperactor::actor::Referable;
34use serde::Serialize;
35use serde::de::DeserializeOwned;
36
37use crate::ActorMeshRef;
38use crate::HostMeshRef;
39use crate::ProcMeshRef;
40
/// The kind of mesh a [`Registrable`] reference describes.
///
/// [`InMemoryNamespace`] embeds the kind in its storage key, so meshes of
/// different kinds can be registered under the same name without colliding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MeshKind {
    /// A host mesh ([`HostMeshRef`]).
    Host,
    /// A proc mesh ([`ProcMeshRef`]).
    Proc,
    /// An actor mesh ([`ActorMeshRef`]).
    Actor,
}

impl MeshKind {
    /// Returns the lowercase name of this kind: `"host"`, `"proc"` or `"actor"`.
    ///
    /// This string is used as the kind segment of the keys built by
    /// [`InMemoryNamespace`].
    pub fn as_str(&self) -> &'static str {
        match self {
            MeshKind::Host => "host",
            MeshKind::Proc => "proc",
            MeshKind::Actor => "actor",
        }
    }
}

impl std::fmt::Display for MeshKind {
    /// Writes [`MeshKind::as_str`], honoring the caller's width, fill,
    /// alignment and precision flags (e.g. `format!("{:<6}", kind)`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Formatter::pad` applies the formatting flags; `write!(f, "{}", ..)`
        // would start a fresh format spec and silently drop them.
        f.pad(self.as_str())
    }
}
65
66#[derive(Debug, thiserror::Error)]
68pub enum NamespaceError {
69 #[error("serialization failed: {0}")]
70 SerializationError(String),
71 #[error("deserialization failed: {0}")]
72 DeserializationError(String),
73 #[error("operation failed: {0}")]
74 OperationError(String),
75 #[error("not found: {0}")]
76 NotFound(String),
77}
78
79pub trait Registrable: Serialize + DeserializeOwned + Send + Sync {
85 fn kind() -> MeshKind;
87}
88
89impl Registrable for HostMeshRef {
90 fn kind() -> MeshKind {
91 MeshKind::Host
92 }
93}
94
95impl Registrable for ProcMeshRef {
96 fn kind() -> MeshKind {
97 MeshKind::Proc
98 }
99}
100
101impl<A: Referable> Registrable for ActorMeshRef<A> {
102 fn kind() -> MeshKind {
103 MeshKind::Actor
104 }
105}
106
107#[async_trait]
115pub trait Namespace {
116 fn name(&self) -> &str;
118
119 async fn register<T: Registrable>(&self, name: &str, mesh: &T) -> Result<(), NamespaceError>;
124
125 async fn get<T: Registrable>(&self, name: &str) -> Result<T, NamespaceError>;
129
130 async fn unregister<T: Registrable>(&self, name: &str) -> Result<(), NamespaceError>;
132
133 async fn contains<T: Registrable>(&self, name: &str) -> Result<bool, NamespaceError>;
135}
136
/// A [`Namespace`] that keeps serialized meshes in an in-process map.
///
/// Each mesh is stored as JSON bytes under the key
/// `"{namespace}.{kind}.{name}"`. Nothing is persisted; entries live only as
/// long as this value.
#[derive(Debug)]
pub struct InMemoryNamespace {
    /// The namespace name, which is also the prefix of every storage key.
    namespace_name: String,
    /// Full storage key -> JSON-encoded mesh, behind a `std` reader/writer lock.
    data: RwLock<HashMap<String, Vec<u8>>>,
}
143
144impl InMemoryNamespace {
145 pub fn new(name: impl Into<String>) -> Self {
147 Self {
148 namespace_name: name.into(),
149 data: RwLock::new(HashMap::new()),
150 }
151 }
152
153 fn full_key(&self, kind: MeshKind, name: &str) -> String {
155 format!("{}.{}.{}", self.namespace_name, kind.as_str(), name)
156 }
157}
158
159#[async_trait]
160impl Namespace for InMemoryNamespace {
161 fn name(&self) -> &str {
162 &self.namespace_name
163 }
164
165 async fn register<T: Registrable>(&self, name: &str, mesh: &T) -> Result<(), NamespaceError> {
166 let data = serde_json::to_vec(mesh)
167 .map_err(|e| NamespaceError::SerializationError(e.to_string()))?;
168 let key = self.full_key(T::kind(), name);
169 self.data
170 .write()
171 .map_err(|e| NamespaceError::OperationError(e.to_string()))?
172 .insert(key.clone(), data);
173 tracing::debug!(
174 key = %key,
175 "registered mesh to in-memory namespace"
176 );
177 Ok(())
178 }
179
180 async fn get<T: Registrable>(&self, name: &str) -> Result<T, NamespaceError> {
181 let key = self.full_key(T::kind(), name);
182 let data = self
183 .data
184 .read()
185 .map_err(|e| NamespaceError::OperationError(e.to_string()))?
186 .get(&key)
187 .cloned()
188 .ok_or(NamespaceError::NotFound(key))?;
189 serde_json::from_slice(&data)
190 .map_err(|e| NamespaceError::DeserializationError(e.to_string()))
191 }
192
193 async fn unregister<T: Registrable>(&self, name: &str) -> Result<(), NamespaceError> {
194 let key = self.full_key(T::kind(), name);
195 self.data
196 .write()
197 .map_err(|e| NamespaceError::OperationError(e.to_string()))?
198 .remove(&key);
199 tracing::debug!(
200 key = %key,
201 "unregistered mesh from in-memory namespace"
202 );
203 Ok(())
204 }
205
206 async fn contains<T: Registrable>(&self, name: &str) -> Result<bool, NamespaceError> {
207 let key = self.full_key(T::kind(), name);
208 Ok(self
209 .data
210 .read()
211 .map_err(|e| NamespaceError::OperationError(e.to_string()))?
212 .contains_key(&key))
213 }
214}
215
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;

    /// Parses a `HostMeshRef` named `name` with two tcp host addresses.
    fn make_host_mesh_ref(name: &str) -> HostMeshRef {
        let s = format!("{}:tcp:127.0.0.1:1234,tcp:127.0.0.1:1235@replica=2/1", name);
        HostMeshRef::from_str(&s).unwrap()
    }

    #[tokio::test]
    async fn test_register_and_get() {
        let ns = InMemoryNamespace::new("test.namespace");

        let mesh = make_host_mesh_ref("test_mesh");

        ns.register("my_hosts", &mesh).await.unwrap();

        // The stored value round-trips back to an equal reference.
        let retrieved: HostMeshRef = ns.get("my_hosts").await.unwrap();
        assert_eq!(retrieved, mesh);
    }

    #[tokio::test]
    async fn test_contains() {
        let ns = InMemoryNamespace::new("test.namespace");

        let mesh = make_host_mesh_ref("workers");

        assert!(!ns.contains::<HostMeshRef>("my_hosts").await.unwrap());

        ns.register("my_hosts", &mesh).await.unwrap();

        assert!(ns.contains::<HostMeshRef>("my_hosts").await.unwrap());

        // Other names are unaffected.
        assert!(!ns.contains::<HostMeshRef>("other").await.unwrap());
    }

    #[tokio::test]
    async fn test_unregister() {
        let ns = InMemoryNamespace::new("test.namespace");

        let mesh = make_host_mesh_ref("workers");

        ns.register("my_hosts", &mesh).await.unwrap();
        assert!(ns.contains::<HostMeshRef>("my_hosts").await.unwrap());

        ns.unregister::<HostMeshRef>("my_hosts").await.unwrap();

        assert!(!ns.contains::<HostMeshRef>("my_hosts").await.unwrap());

        let result: Result<HostMeshRef, _> = ns.get("my_hosts").await;
        assert!(matches!(result, Err(NamespaceError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_unregister_missing_is_ok() {
        let ns = InMemoryNamespace::new("test");

        // Removing an entry that was never registered is not an error.
        ns.unregister::<HostMeshRef>("never_registered")
            .await
            .unwrap();
        assert!(!ns.contains::<HostMeshRef>("never_registered").await.unwrap());
    }

    #[tokio::test]
    async fn test_get_not_found() {
        let ns = InMemoryNamespace::new("test.namespace");

        let result: Result<HostMeshRef, _> = ns.get("nonexistent").await;
        assert!(matches!(result, Err(NamespaceError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_not_found_reports_full_key() {
        let ns = InMemoryNamespace::new("test.namespace");

        let err = ns.get::<HostMeshRef>("nonexistent").await.unwrap_err();
        assert!(matches!(
            &err,
            NamespaceError::NotFound(key) if key == "test.namespace.host.nonexistent"
        ));
        assert_eq!(err.to_string(), "not found: test.namespace.host.nonexistent");
    }

    #[tokio::test]
    async fn test_same_name_different_kinds_are_isolated() {
        let ns = InMemoryNamespace::new("test");

        ns.register("shared", &make_host_mesh_ref("hosts"))
            .await
            .unwrap();

        // A host registration is invisible to lookups for other kinds.
        assert!(ns.contains::<HostMeshRef>("shared").await.unwrap());
        assert!(!ns.contains::<ProcMeshRef>("shared").await.unwrap());
        let result = ns.get::<ProcMeshRef>("shared").await;
        assert!(matches!(result, Err(NamespaceError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_multiple_meshes() {
        let ns = InMemoryNamespace::new("test");

        let mesh1 = make_host_mesh_ref("mesh1");
        let mesh2 = make_host_mesh_ref("mesh2");

        ns.register("hosts_a", &mesh1).await.unwrap();
        ns.register("hosts_b", &mesh2).await.unwrap();

        let retrieved1: HostMeshRef = ns.get("hosts_a").await.unwrap();
        let retrieved2: HostMeshRef = ns.get("hosts_b").await.unwrap();

        assert_eq!(retrieved1, mesh1);
        assert_eq!(retrieved2, mesh2);
    }

    #[tokio::test]
    async fn test_overwrite_registration() {
        let ns = InMemoryNamespace::new("test");

        let mesh1 = make_host_mesh_ref("mesh1");
        let mesh2 = make_host_mesh_ref("mesh2");

        ns.register("hosts", &mesh1).await.unwrap();
        let retrieved: HostMeshRef = ns.get("hosts").await.unwrap();
        assert_eq!(retrieved, mesh1);

        // Re-registering the same kind and name replaces the earlier entry.
        ns.register("hosts", &mesh2).await.unwrap();
        let retrieved: HostMeshRef = ns.get("hosts").await.unwrap();
        assert_eq!(retrieved, mesh2);
    }

    #[test]
    fn test_mesh_kind_as_str() {
        assert_eq!(MeshKind::Host.as_str(), "host");
        assert_eq!(MeshKind::Proc.as_str(), "proc");
        assert_eq!(MeshKind::Actor.as_str(), "actor");
    }

    #[test]
    fn test_mesh_kind_display_matches_as_str() {
        for kind in [MeshKind::Host, MeshKind::Proc, MeshKind::Actor] {
            assert_eq!(kind.to_string(), kind.as_str());
        }
    }

    #[test]
    fn test_full_key_format() {
        let ns = InMemoryNamespace::new("my.ns");
        assert_eq!(ns.full_key(MeshKind::Host, "m"), "my.ns.host.m");
        assert_eq!(ns.full_key(MeshKind::Proc, "m"), "my.ns.proc.m");
        assert_eq!(ns.full_key(MeshKind::Actor, "m"), "my.ns.actor.m");
    }

    #[test]
    fn test_name() {
        let ns = InMemoryNamespace::new("my.namespace");
        assert_eq!(ns.name(), "my.namespace");
    }

    #[test]
    fn test_registrable_impl_for_host_mesh_ref() {
        assert_eq!(HostMeshRef::kind(), MeshKind::Host);
    }

    #[test]
    fn test_registrable_impl_for_proc_mesh_ref() {
        assert_eq!(ProcMeshRef::kind(), MeshKind::Proc);
    }
}