1use std::any::type_name;
10use std::collections::HashMap;
11use std::fmt;
12use std::ops::Deref;
13use std::sync::Arc;
14
15use hyperactor::Actor;
16use hyperactor::ActorId;
17use hyperactor::ActorRef;
18use hyperactor::Named;
19use hyperactor::ProcId;
20use hyperactor::RemoteMessage;
21use hyperactor::actor::Referable;
22use hyperactor::actor::remote::Remote;
23use hyperactor::channel;
24use hyperactor::channel::ChannelAddr;
25use hyperactor::context;
26use hyperactor::mailbox::DialMailboxRouter;
27use hyperactor::mailbox::MailboxServer;
28use ndslice::Extent;
29use ndslice::ViewExt as _;
30use ndslice::view;
31use ndslice::view::CollectMeshExt;
32use ndslice::view::MapIntoExt;
33use ndslice::view::Ranked;
34use ndslice::view::Region;
35use serde::Deserialize;
36use serde::Serialize;
37
38use crate::CommActor;
39use crate::alloc::Alloc;
40use crate::alloc::AllocExt;
41use crate::alloc::AllocatedProc;
42use crate::assign::Ranks;
43use crate::comm::CommActorMode;
44use crate::proc_mesh::mesh_agent;
45use crate::proc_mesh::mesh_agent::ActorState;
46use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
47use crate::proc_mesh::mesh_agent::ProcMeshAgent;
48use crate::proc_mesh::mesh_agent::ReconfigurableMailboxSender;
49use crate::resource;
50use crate::resource::RankedValues;
51use crate::v1;
52use crate::v1::ActorMesh;
53use crate::v1::ActorMeshRef;
54use crate::v1::Error;
55use crate::v1::HostMeshRef;
56use crate::v1::Name;
57use crate::v1::ValueMesh;
58
59#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
61pub struct ProcRef {
62 proc_id: ProcId,
63 create_rank: usize,
65 agent: ActorRef<ProcMeshAgent>,
67}
68
69impl ProcRef {
70 pub(crate) fn new(proc_id: ProcId, create_rank: usize, agent: ActorRef<ProcMeshAgent>) -> Self {
71 Self {
72 proc_id,
73 create_rank,
74 agent,
75 }
76 }
77
78 pub(crate) async fn status(&self, cx: &impl context::Actor) -> v1::Result<bool> {
81 let (port, mut rx) = cx.mailbox().open_port();
82 self.agent
83 .status(cx, port.bind())
84 .await
85 .map_err(|e| Error::CallError(self.agent.actor_id().clone(), e))?;
86 loop {
87 let (rank, status) = rx
88 .recv()
89 .await
90 .map_err(|e| Error::CallError(self.agent.actor_id().clone(), e.into()))?;
91 if rank == self.create_rank {
92 break Ok(status);
93 }
94 }
95 }
96
97 #[allow(dead_code)]
99 async fn actor_state(
100 &self,
101 cx: &impl context::Actor,
102 name: Name,
103 ) -> v1::Result<resource::State<ActorState>> {
104 let (port, mut rx) = cx.mailbox().open_port::<resource::State<ActorState>>();
105 self.agent
106 .send(
107 cx,
108 resource::GetState::<ActorState> {
109 name: name.clone(),
110 reply: port.bind(),
111 },
112 )
113 .map_err(|e| Error::CallError(self.agent.actor_id().clone(), e.into()))?;
114 let state = rx
115 .recv()
116 .await
117 .map_err(|e| Error::CallError(self.agent.actor_id().clone(), e.into()))?;
118 if let Some(ref inner) = state.state {
119 let rank = inner.create_rank;
120 if rank == self.create_rank {
121 Ok(state)
122 } else {
123 Err(Error::CallError(
124 self.agent.actor_id().clone(),
125 anyhow::anyhow!(
126 "Rank on mesh agent not matching for Actor {}: returned {}, expected {}",
127 name,
128 rank,
129 self.create_rank
130 ),
131 ))
132 }
133 } else {
134 Err(Error::CallError(
135 self.agent.actor_id().clone(),
136 anyhow::anyhow!("Actor {} does not exist", name),
137 ))
138 }
139 }
140
141 pub(crate) fn actor_id(&self, name: &Name) -> ActorId {
142 self.proc_id.actor_id(name.to_string(), 0)
143 }
144
145 pub(crate) fn attest<A: Referable>(&self, name: &Name) -> ActorRef<A> {
148 ActorRef::attest(self.actor_id(name))
149 }
150}
151
152#[allow(dead_code)]
154#[derive(Debug)]
155pub struct ProcMesh {
156 name: Name,
157 allocation: ProcMeshAllocation,
158 comm_actor_name: Option<Name>,
159 current_ref: ProcMeshRef,
160}
161
162impl ProcMesh {
163 async fn create(
164 cx: &impl context::Actor,
165 name: Name,
166 allocation: ProcMeshAllocation,
167 spawn_comm_actor: bool,
168 ) -> v1::Result<Self> {
169 let comm_actor_name = if spawn_comm_actor {
170 Some(Name::new("comm"))
171 } else {
172 None
173 };
174
175 let region = allocation.extent().clone().into();
176 let ranks = allocation.ranks();
177 let root_comm_actor = comm_actor_name.as_ref().map(|name| {
178 ActorRef::attest(
179 ranks
180 .first()
181 .expect("root mesh cannot be empty")
182 .actor_id(name),
183 )
184 });
185 let current_ref = ProcMeshRef::new(
186 name.clone(),
187 region,
188 ranks,
189 None, None, )
192 .unwrap();
193
194 let mut proc_mesh = Self {
195 name,
196 allocation,
197 comm_actor_name: comm_actor_name.clone(),
198 current_ref,
199 };
200
201 if let Some(comm_actor_name) = comm_actor_name {
202 let comm_actor_mesh = proc_mesh
205 .spawn_with_name::<CommActor>(cx, comm_actor_name, &Default::default())
206 .await?;
207 let address_book: HashMap<_, _> = comm_actor_mesh
208 .iter()
209 .map(|(point, actor_ref)| (point.rank(), actor_ref))
210 .collect();
211 for (rank, comm_actor) in &address_book {
214 comm_actor
215 .send(cx, CommActorMode::Mesh(*rank, address_book.clone()))
216 .map_err(|e| Error::SendingError(comm_actor.actor_id().clone(), Box::new(e)))?
217 }
218
219 proc_mesh.current_ref.root_comm_actor = root_comm_actor;
221 }
222
223 Ok(proc_mesh)
224 }
225
226 pub(crate) async fn create_owned_unchecked(
227 cx: &impl context::Actor,
228 name: Name,
229 extent: Extent,
230 hosts: HostMeshRef,
231 ranks: Vec<ProcRef>,
232 ) -> v1::Result<Self> {
233 Self::create(
234 cx,
235 name,
236 ProcMeshAllocation::Owned {
237 hosts,
238 extent,
239 ranks: Arc::new(ranks),
240 },
241 true,
242 )
243 .await
244 }
245
246 pub async fn allocate(
250 cx: &impl context::Actor,
251 mut alloc: Box<dyn Alloc + Send + Sync + 'static>,
252 name: &str,
253 ) -> v1::Result<Self> {
254 let running = alloc.initialize().await?;
255
256 let proc = cx.instance().proc();
262
263 let (proc_channel_addr, rx) = channel::serve(ChannelAddr::any(alloc.transport()))?;
265 proc.clone().serve(rx);
266
267 let bind_allocated_procs = |router: &DialMailboxRouter| {
268 for AllocatedProc { proc_id, addr, .. } in running.iter() {
270 if proc_id.is_direct() {
271 continue;
272 }
273 router.bind(proc_id.clone().into(), addr.clone());
274 }
275 };
276
277 if let Some(router) = proc.forwarder().downcast_ref() {
282 bind_allocated_procs(router);
283 } else if let Some(router) = proc
284 .forwarder()
285 .downcast_ref::<ReconfigurableMailboxSender>()
286 {
287 bind_allocated_procs(
288 router
289 .as_inner()
290 .map_err(|_| Error::UnroutableMesh())?
291 .as_configured()
292 .ok_or(Error::UnroutableMesh())?
293 .downcast_ref()
294 .ok_or(Error::UnroutableMesh())?,
295 );
296 } else {
297 return Err(Error::UnroutableMesh());
298 }
299
300 let address_book: HashMap<_, _> = running
303 .iter()
304 .map(
305 |AllocatedProc {
306 addr, mesh_agent, ..
307 }| { (mesh_agent.actor_id().proc_id().clone(), addr.clone()) },
308 )
309 .collect();
310
311 let (config_handle, mut config_receiver) = cx.mailbox().open_port();
312 for (rank, AllocatedProc { mesh_agent, .. }) in running.iter().enumerate() {
313 mesh_agent
314 .configure(
315 cx,
316 rank,
317 proc_channel_addr.clone(),
318 None, address_book.clone(),
320 config_handle.bind(),
321 true,
322 )
323 .await
324 .map_err(Error::ConfigurationError)?;
325 }
326 let mut completed = Ranks::new(running.len());
327 while !completed.is_full() {
328 let rank = config_receiver
329 .recv()
330 .await
331 .map_err(|err| Error::ConfigurationError(err.into()))?;
332 if completed.insert(rank, rank).is_some() {
333 tracing::warn!("multiple completions received for rank {}", rank);
334 }
335 }
336
337 let ranks: Vec<_> = running
338 .into_iter()
339 .enumerate()
340 .map(|(create_rank, allocated)| ProcRef {
341 proc_id: allocated.proc_id,
342 create_rank,
343 agent: allocated.mesh_agent,
344 })
345 .collect();
346
347 Self::create(
348 cx,
349 Name::new(name),
350 ProcMeshAllocation::Allocated {
351 alloc,
352 ranks: Arc::new(ranks),
353 },
354 true, )
356 .await
357 }
358}
359
360impl Deref for ProcMesh {
361 type Target = ProcMeshRef;
362
363 fn deref(&self) -> &Self::Target {
364 &self.current_ref
365 }
366}
367
368enum ProcMeshAllocation {
370 Allocated {
372 alloc: Box<dyn Alloc + Send + Sync + 'static>,
375
376 ranks: Arc<Vec<ProcRef>>,
378 },
379
380 Owned {
382 hosts: HostMeshRef,
384 extent: Extent,
387 ranks: Arc<Vec<ProcRef>>,
389 },
390}
391
392impl ProcMeshAllocation {
393 fn extent(&self) -> &Extent {
394 match self {
395 ProcMeshAllocation::Allocated { alloc, .. } => alloc.extent(),
396 ProcMeshAllocation::Owned { extent, .. } => extent,
397 }
398 }
399
400 fn ranks(&self) -> Arc<Vec<ProcRef>> {
401 Arc::clone(match self {
402 ProcMeshAllocation::Allocated { ranks, .. } => ranks,
403 ProcMeshAllocation::Owned { ranks, .. } => ranks,
404 })
405 }
406}
407
408impl fmt::Debug for ProcMeshAllocation {
409 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
410 match self {
411 ProcMeshAllocation::Allocated { ranks, .. } => f
412 .debug_struct("ProcMeshAllocation::Allocated")
413 .field("alloc", &"<dyn Alloc>")
414 .field("ranks", ranks)
415 .finish(),
416 ProcMeshAllocation::Owned {
417 hosts,
418 ranks,
419 extent: _,
420 } => f
421 .debug_struct("ProcMeshAllocation::Owned")
422 .field("hosts", hosts)
423 .field("ranks", ranks)
424 .finish(),
425 }
426 }
427}
428
429#[derive(Debug, Clone, PartialEq, Eq, Hash, Named, Serialize, Deserialize)]
436pub struct ProcMeshRef {
437 name: Name,
438 region: Region,
439 ranks: Arc<Vec<ProcRef>>,
440 pub(crate) root_region: Option<Region>,
444 pub(crate) root_comm_actor: Option<ActorRef<CommActor>>,
449}
450
451impl ProcMeshRef {
452 fn new(
454 name: Name,
455 region: Region,
456 ranks: Arc<Vec<ProcRef>>,
457 root_region: Option<Region>,
458 root_comm_actor: Option<ActorRef<CommActor>>,
459 ) -> v1::Result<Self> {
460 if region.num_ranks() != ranks.len() {
461 return Err(v1::Error::InvalidRankCardinality {
462 expected: region.num_ranks(),
463 actual: ranks.len(),
464 });
465 }
466 Ok(Self {
467 name,
468 region,
469 ranks,
470 root_region,
471 root_comm_actor,
472 })
473 }
474
475 pub(crate) fn root_comm_actor(&self) -> Option<&ActorRef<CommActor>> {
476 self.root_comm_actor.as_ref()
477 }
478
479 pub async fn status(&self, cx: &impl context::Actor) -> v1::Result<ValueMesh<bool>> {
481 let vm: ValueMesh<_> = self.map_into(|proc_ref| {
482 let proc_ref = proc_ref.clone();
483 async move { proc_ref.status(cx).await }
484 });
485 vm.join().await.transpose()
486 }
487
488 fn agent_mesh(&self) -> ActorMeshRef<ProcMeshAgent> {
489 let agent_name = self.ranks.first().unwrap().agent.actor_id().name();
490 ActorMeshRef::new(Name::new_reserved(agent_name), self.clone())
492 }
493
494 pub async fn actor_states(
496 &self,
497 cx: &impl context::Actor,
498 name: Name,
499 ) -> v1::Result<ValueMesh<resource::State<ActorState>>> {
500 let agent_mesh = self.agent_mesh();
501 let (port, mut rx) = cx.mailbox().open_port::<resource::State<ActorState>>();
502 agent_mesh.cast(
505 cx,
506 resource::GetState::<ActorState> {
507 name: name.clone(),
508 reply: port.bind(),
509 },
510 )?;
511 let expected = self.ranks.len();
512 let mut states = Vec::with_capacity(expected);
513 for _ in 0..expected {
514 let state = rx.recv().await?;
515 match state.state {
516 Some(ref inner) => {
517 states.push((inner.create_rank, state));
518 }
519 None => {
520 return Err(Error::NotExist(state.name));
521 }
522 }
523 }
524 states.sort_by_key(|(rank, _)| *rank);
526 let vm = states
527 .into_iter()
528 .map(|(_, state)| state)
529 .collect_mesh::<ValueMesh<_>>(self.region.clone())?;
530 Ok(vm)
531 }
532
533 pub async fn spawn<A: Actor + Referable>(
543 &self,
544 cx: &impl context::Actor,
545 name: &str,
546 params: &A::Params,
547 ) -> v1::Result<ActorMesh<A>>
548 where
549 A::Params: RemoteMessage,
550 {
551 self.spawn_with_name(cx, Name::new(name), params).await
552 }
553
554 pub(crate) async fn spawn_with_name<A: Actor + Referable>(
568 &self,
569 cx: &impl context::Actor,
570 name: Name,
571 params: &A::Params,
572 ) -> v1::Result<ActorMesh<A>>
573 where
574 A::Params: RemoteMessage,
575 {
576 let remote = Remote::collect();
577 let actor_type = remote
580 .name_of::<A>()
581 .ok_or(Error::ActorTypeNotRegistered(type_name::<A>().to_string()))?
582 .to_string();
583
584 let serialized_params = bincode::serialize(params)?;
585
586 self.agent_mesh().cast(
587 cx,
588 resource::CreateOrUpdate::<mesh_agent::ActorSpec> {
589 name: name.clone(),
590 rank: Default::default(),
591 spec: mesh_agent::ActorSpec {
592 actor_type: actor_type.clone(),
593 params_data: serialized_params.clone(),
594 },
595 },
596 )?;
597
598 let (port, mut rx) = cx.mailbox().open_accum_port(RankedValues::default());
599
600 self.agent_mesh().cast(
601 cx,
602 resource::GetRankStatus {
603 name: name.clone(),
604 reply: port.bind(),
605 },
606 )?;
607
608 let statuses = loop {
611 let statuses = rx.recv().await?;
612 if statuses.rank(self.ranks.len()) == self.ranks.len() {
613 break statuses;
614 }
615 };
616
617 let failed: Vec<_> = statuses
618 .iter()
619 .filter_map(|(ranks, status)| {
620 if status.is_terminating() {
621 Some(ranks.clone())
622 } else {
623 None
624 }
625 })
626 .flatten()
627 .collect();
628 if !failed.is_empty() {
629 return Err(Error::GspawnError(
630 name,
631 format!("failed ranks: {:?}", failed,),
632 ));
633 }
634
635 Ok(ActorMesh::new(self.clone(), name))
636 }
637}
638
639impl view::Ranked for ProcMeshRef {
640 type Item = ProcRef;
641
642 fn region(&self) -> &Region {
643 &self.region
644 }
645
646 fn get(&self, rank: usize) -> Option<&Self::Item> {
647 self.ranks.get(rank)
648 }
649}
650
651impl view::RankedSliceable for ProcMeshRef {
652 fn sliced(&self, region: Region) -> Self {
653 debug_assert!(region.is_subset(view::Ranked::region(self)));
654 let ranks = self
655 .region()
656 .remap(®ion)
657 .unwrap()
658 .map(|index| self.get(index).unwrap().clone())
659 .collect();
660 Self::new(
661 self.name.clone(),
662 region,
663 Arc::new(ranks),
664 Some(self.root_region.as_ref().unwrap_or(&self.region).clone()),
665 self.root_comm_actor.clone(),
666 )
667 .unwrap()
668 }
669}
670
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use ndslice::ViewExt;
    use ndslice::extent;
    use timed_test::async_timed_test;

    use crate::v1;
    use crate::v1::testactor;
    use crate::v1::testing;

    /// A locally allocated mesh has the requested shape, a non-empty router,
    /// and reports a true status for every proc, both individually and
    /// through the mesh-wide query.
    #[tokio::test]
    async fn test_proc_mesh_allocate() {
        let (mesh, actor, router) = testing::local_proc_mesh(extent!(replica = 4)).await;
        assert_eq!(mesh.extent(), extent!(replica = 4));
        assert_eq!(mesh.ranks.len(), 4);
        assert!(!router.prefixes().is_empty());

        // Per-proc status, one proc at a time.
        for proc_ref in mesh.values() {
            let proc_status = proc_ref.status(&actor).await.unwrap();
            assert!(proc_status);
        }

        // Mesh-wide status in a single call.
        assert!(
            mesh.status(&actor)
                .await
                .unwrap()
                .values()
                .all(|status| status)
        );
    }

    /// Spawning a test actor yields an actor mesh that passes
    /// `testactor::assert_mesh_shape` for every test proc mesh.
    #[async_timed_test(timeout_secs = 30)]
    async fn test_spawn_actor() {
        hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default());

        let instance = testing::instance().await;
        let meshes = testing::proc_meshes(&instance, extent!(replicas = 4, hosts = 2)).await;

        for proc_mesh in meshes {
            testactor::assert_mesh_shape(proc_mesh.spawn(instance, "test", &()).await.unwrap())
                .await;
        }
    }

    /// Spawning `FailingCreateTestActor` is reported as a `GspawnError`.
    #[tokio::test]
    async fn test_failing_spawn_actor() {
        hyperactor_telemetry::initialize_logging(hyperactor::clock::ClockKind::default());

        let instance = testing::instance().await;
        let meshes = testing::proc_meshes(&instance, extent!(replicas = 4, hosts = 2)).await;

        for proc_mesh in meshes {
            let spawn_result = proc_mesh
                .spawn::<testactor::FailingCreateTestActor>(instance, "testfail", &())
                .await;
            let err = spawn_result.unwrap_err();
            assert_matches!(err, v1::Error::GspawnError(_, _));
        }
    }
}