#![feature(assert_matches)]
#![feature(duration_constructors)]
#![feature(exit_status_error)]
#![allow(unsafe_op_in_unsafe_fn)]

pub mod bootstrap;
mod borrow;
mod comm;
pub mod device_mesh;
pub mod pipe;
pub mod py_pipe;
pub mod stream;
pub mod test_util;

use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::hash_map::Entry;
use std::sync::Arc;

use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;
use anyhow::ensure;
use async_trait::async_trait;
use borrow::Borrow;
use comm::CommMessageClient;
use comm::CommParams;
use comm::NcclCommActor;
use derive_more::TryInto;
use device_mesh::DeviceMesh;
use futures::future::try_join_all;
use hyperactor::Actor;
use hyperactor::ActorRef;
use hyperactor::Bind;
use hyperactor::Handler;
use hyperactor::Named;
use hyperactor::Unbind;
use hyperactor::actor::ActorHandle;
use hyperactor::context;
use hyperactor::reference::ActorId;
use hyperactor_mesh::comm::multicast::CastInfo;
use itertools::Itertools;
use monarch_hyperactor::shape::PyPoint;
use monarch_messages::controller::ControllerActor;
use monarch_messages::controller::ControllerMessageClient;
use monarch_messages::controller::Seq;
use monarch_messages::wire_value::WireValue;
use monarch_messages::worker::ActorCallParams;
use monarch_messages::worker::ActorMethodParams;
use monarch_messages::worker::CallFunctionError;
use monarch_messages::worker::CallFunctionParams;
use monarch_messages::worker::Factory;
use monarch_messages::worker::Reduction;
use monarch_messages::worker::Ref;
use monarch_messages::worker::ResolvableFunction;
use monarch_messages::worker::StreamCreationMode;
use monarch_messages::worker::StreamRef;
use monarch_messages::worker::WorkerMessage;
use monarch_messages::worker::WorkerMessageHandler;
use monarch_messages::worker::WorkerParams;
use monarch_types::PyTree;
use ndslice::Slice;
use pipe::PipeActor;
use pipe::PipeParams;
use pyo3::Python;
use pyo3::types::PyAnyMethods;
use serde::Deserialize;
use serde::Serialize;
use sorted_vec::SortedVec;
use stream::StreamActor;
use stream::StreamMessageClient;
use stream::StreamParams;
use torch_sys::CudaDevice;
use torch_sys::DeviceIndex;
use torch_sys::Layout;
use torch_sys::RValue;
use torch_sys::ScalarType;
use torch_sys::TensorCell;
use torch_sys::factory_zeros;
use torch_sys_cuda::nccl::NcclConfig;
use torch_sys_cuda::nccl::ReduceOp;
use torch_sys_cuda::nccl::UniqueId;

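/// State tracked per remote process group: the device mesh the group was
/// created over, the (sorted) mesh dims it spans, and a cache of split NCCL
/// communicators keyed by stream.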
#[derive(Debug)]
struct RemoteProcessGroupState {
    device_mesh_ref: Ref,
    dims: SortedVec<String>,
    comms: HashMap<StreamRef, Arc<ActorHandle<NcclCommActor>>>,
}

impl RemoteProcessGroupState {
    fn new(device_mesh_ref: Ref, dims: SortedVec<String>) -> Self {
        Self {
            device_mesh_ref,
            dims,
            comms: HashMap::new(),
        }
    }
}

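/// A recording of a sequence of worker messages, assembled from one or more
/// `DefineRecording` messages and replayed later via `CallRecording`.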
#[derive(Debug)]
enum Recording {
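    /// A recording whose `DefineRecording` chunks have not all arrived yet.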
    PartialRecording {
        last_index: usize,
        commands: Vec<WorkerMessage>,
    },

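    /// A fully-defined recording that is ready to be called.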
    CompleteRecording {
        streams: HashSet<StreamRef>,
    },
}

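/// A worker actor that executes `WorkerMessage` commands on behalf of the
/// controller. It owns the streams, device meshes, borrows, pipes, and NCCL
/// communicators used to run tensor operations on a single rank.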
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        WorkerMessage {cast = true},
        AssignRankMessage {cast = true},
    ],
)]
pub struct WorkerActor {
    device: Option<CudaDevice>,
    streams: HashMap<StreamRef, Arc<ActorHandle<StreamActor>>>,
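    /// For each device mesh: the mesh itself, plus the size and split
    /// communicator for every (stream, dims) pair that a comm has been split
    /// for.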
    device_meshes: HashMap<
        Ref,
        (
            DeviceMesh,
            HashMap<(StreamRef, SortedVec<String>), (usize, Arc<ActorHandle<NcclCommActor>>)>,
        ),
    >,
    world_size: usize,
    rank: usize,
    borrows: HashMap<u64, Borrow>,
    comm: Option<ActorHandle<NcclCommActor>>,
    controller_actor: ActorRef<ControllerActor>,
    pipes: HashMap<Ref, ActorHandle<PipeActor>>,
    remote_process_groups: HashMap<Ref, RemoteProcessGroupState>,
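    /// Split communicators used for stream-to-stream tensor sends, keyed by
    /// (from_stream, to_stream).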
    send_recv_comms: HashMap<(StreamRef, StreamRef), Arc<ActorHandle<NcclCommActor>>>,
    recordings: HashMap<Ref, Recording>,
    defining_recording: Option<Ref>,
    respond_with_python_message: bool,
}

impl WorkerActor {
    fn try_get_stream(&self, stream: StreamRef) -> Result<&Arc<ActorHandle<StreamActor>>> {
        self.streams
            .get(&stream)
            .ok_or(anyhow::anyhow!("invalid stream id: {:#?}", stream))
    }

    async fn maybe_add_stream_to_recording(
        &mut self,
        cx: &impl context::Actor,
        stream: StreamRef,
    ) -> Result<()> {
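        // If a recording is currently being defined, make sure `stream` is
        // part of it; the first time we see a stream we tell it to start
        // defining the recording.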
        if let Some(defining_recording) = self.defining_recording {
            let recording = self.recordings.get_mut(&defining_recording).unwrap();
            let fut = match recording {
                Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
                Recording::CompleteRecording { streams } => {
                    streams.insert(stream).then(|| -> Result<_, anyhow::Error> {
                        Ok(self
                            .try_get_stream(stream)?
                            .define_recording(cx, defining_recording))
                    })
                }
            }
            .transpose()?;
            match fut {
                Some(fut) => fut.await,
                None => Ok(()),
            }
        } else {
            Ok(())
        }
    }
}

#[async_trait]
impl Actor for WorkerActor {
    type Params = WorkerParams;

    async fn new(
        WorkerParams {
            world_size,
            rank,
            device_index,
            controller_actor,
        }: Self::Params,
    ) -> Result<Self> {
        Ok(Self {
            device: device_index.map(|i| CudaDevice::new(DeviceIndex(i))),
            streams: HashMap::new(),
            device_meshes: HashMap::new(),
            world_size,
            rank,
            borrows: HashMap::new(),
            comm: None,
            controller_actor,
            pipes: HashMap::new(),
            remote_process_groups: HashMap::new(),
            send_recv_comms: HashMap::new(),
            recordings: HashMap::new(),
            defining_recording: None,
            respond_with_python_message: false,
        })
    }
}

#[async_trait]
impl Handler<AssignRankMessage> for WorkerActor {
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        _: AssignRankMessage,
    ) -> anyhow::Result<()> {
        let point = cx.cast_point();
        self.rank = point.rank();
        self.respond_with_python_message = true;
        Python::with_gil(|py| {
            let mesh_controller = py.import("monarch.mesh_controller").unwrap();
            let p: PyPoint = point.into();
            mesh_controller
                .call_method1("_initialize_env", (p, cx.proc().proc_id().to_string()))
                .unwrap();
        });
        Ok(())
    }
}

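/// Assigns this worker its rank from its position in the cast, and switches
/// it to responding with Python messages (used by the mesh controller).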
#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub enum AssignRankMessage {
    AssignRank(),
}

#[async_trait]
impl Handler<WorkerMessage> for WorkerActor {
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        message: WorkerMessage,
    ) -> anyhow::Result<()> {
        <Self as WorkerMessageHandler>::handle(self, cx, message).await
    }
}

#[async_trait]
impl WorkerMessageHandler for WorkerActor {
    async fn backend_network_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        unique_id: UniqueId,
    ) -> Result<()> {
        let device = self
            .device
            .expect("tried to init backend network on a non-CUDA worker");
        let comm = NcclCommActor::spawn(
            cx,
            CommParams::New {
                device,
                unique_id,
                world_size: self.world_size.try_into().unwrap(),
                rank: self.rank.try_into().unwrap(),
            },
        )
        .await?;

        let tensor = factory_zeros(&[1], ScalarType::Float, Layout::Strided, device.into());
        let cell = TensorCell::new(tensor);

        comm.all_reduce(
            cx,
            cell,
            ReduceOp::Sum,
            torch_sys_cuda::cuda::Stream::get_current_stream(),
        )
        .await?;

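        // Give every stream its own split communicator. Iterate over the
        // streams in sorted order so that each rank performs the same
        // sequence of split operations.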
        let sorted_streams = self
            .streams
            .iter()
            .sorted_by_key(|(k, _)| *k)
            .map(|(_, v)| v.as_ref());

        let mut splits = Vec::new();
        for _ in 0..sorted_streams.len() {
            splits.push(comm.split_all(cx, None).await?);
        }
        let _: Vec<()> = try_join_all(
            sorted_streams
                .into_iter()
                .zip(splits.into_iter())
                .map(|(stream, split)| stream.init_comm(cx, split)),
        )
        .await?;

        self.comm = Some(comm);

        Ok(())
    }

    async fn backend_network_point_to_point_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        if !self.streams.contains_key(&from_stream) {
            bail!("invalid from_stream id: {:#?}", from_stream);
        }
        if !self.streams.contains_key(&to_stream) {
            bail!("invalid to_stream id: {:#?}", to_stream);
        }
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call BackendNetworkPointToPointInit before BackendNetworkInit")?;
        let comm = global_comm.split_all(cx, None).await?;
        self.send_recv_comms
            .insert((from_stream, to_stream), Arc::new(comm));
        Ok(())
    }

    async fn call_function(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: CallFunctionParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.stream)?.clone();
        self.maybe_add_stream_to_recording(cx, params.stream)
            .await?;

        let device_meshes = if params.function.as_torch_op().is_some() {
            HashMap::new()
        } else {
            self.device_meshes
                .iter()
                .map(|(k, v)| (k.clone(), v.0.clone()))
                .collect()
        };

        let mut remote_process_groups = HashMap::new();
        for remote_process_group_ref in &params.remote_process_groups {
            if let Some(state) = self.remote_process_groups.get(remote_process_group_ref) {
                let dims_vec = state.dims.iter().cloned().collect();
                let (device_mesh, _) = self
                    .device_meshes
                    .get(&state.device_mesh_ref)
                    .ok_or_else(|| {
                        anyhow::anyhow!("invalid device mesh id: {:#?}", state.device_mesh_ref)
                    })?
                    .clone();
                let comm = state.comms
                    .get(&params.stream)
                    .ok_or_else(|| {
                        anyhow::anyhow!("no comm found for remote process group {remote_process_group_ref:#?} stream {stream:#?}")
                    })?
                    .clone();
                remote_process_groups.insert(
                    remote_process_group_ref.clone(),
                    (device_mesh, dims_vec, comm),
                );
            }
        }

        stream
            .call_function(cx, params, device_meshes, remote_process_groups)
            .await?;

        Ok(())
    }

    async fn command_group(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: Vec<WorkerMessage>,
    ) -> Result<()> {
        for msg in params {
            WorkerMessageHandler::handle(self, cx, msg).await?;
        }
        Ok(())
    }

    async fn create_stream(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: StreamRef,
        creation_mode: StreamCreationMode,
    ) -> Result<()> {
        let handle: ActorHandle<StreamActor> = StreamActor::spawn(
            cx,
            StreamParams {
                world_size: self.world_size,
                rank: self.rank,
                creation_mode,
                id: result,
                device: self.device,
                controller_actor: self.controller_actor.clone(),
                respond_with_python_message: self.respond_with_python_message,
            },
        )
        .await?;
        self.streams.insert(result, Arc::new(handle));
        Ok(())
    }

    async fn create_device_mesh(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        result: Ref,
        names: Vec<String>,
        ranks: Slice,
    ) -> Result<()> {
        self.device_meshes.insert(
            result,
            (DeviceMesh::new(names, ranks, self.rank)?, HashMap::new()),
        );
        Ok(())
    }

    async fn create_remote_process_group(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        result: Ref,
        device_mesh: Ref,
        dims: Vec<String>,
    ) -> Result<()> {
        self.device_meshes
            .get(&device_mesh)
            .with_context(|| format!("invalid device mesh id: {:#?}", device_mesh))?;
        match self.remote_process_groups.entry(result) {
            Entry::Vacant(ent) => ent.insert(RemoteProcessGroupState::new(
                device_mesh,
                SortedVec::from_unsorted(dims),
            )),
            Entry::Occupied(ent) => bail!("remote process group {:?} already created", ent.key()),
        };
        Ok(())
    }

    async fn borrow_create(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        borrow_id: u64,
        tensor_ref: Ref,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, from_stream).await?;
        self.maybe_add_stream_to_recording(cx, to_stream).await?;
        let from_stream = self.try_get_stream(from_stream)?.clone();
        let to_stream = self.try_get_stream(to_stream)?.clone();

        let borrow =
            Borrow::create(cx, borrow_id, tensor_ref, result, from_stream, to_stream).await?;
        self.borrows.insert(borrow_id, borrow);
        Ok(())
    }

    async fn borrow_first_use(
        &mut self,
        cx: &hyperactor::Context<Self>,
        borrow: u64,
    ) -> Result<()> {
        let borrow = self
            .borrows
            .get_mut(&borrow)
            .ok_or_else(|| anyhow!("invalid borrow id: {:#?}", borrow))?;

        borrow.first_use(cx).await?;
        Ok(())
    }

    async fn borrow_last_use(&mut self, cx: &hyperactor::Context<Self>, borrow: u64) -> Result<()> {
        let borrow = self
            .borrows
            .get_mut(&borrow)
            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow))?;

        borrow.last_use(cx).await?;
        Ok(())
    }

    async fn borrow_drop(&mut self, cx: &hyperactor::Context<Self>, borrow_id: u64) -> Result<()> {
        let borrow = self
            .borrows
            .get_mut(&borrow_id)
            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow_id))?;

        borrow.drop(cx).await?;
        self.borrows.remove(&borrow_id);
        Ok(())
    }

    async fn delete_refs(&mut self, cx: &hyperactor::Context<Self>, refs: Vec<Ref>) -> Result<()> {
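        // Broadcast the deletion to every stream; each stream drops whichever
        // of these refs it owns.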
        let _: Vec<()> = try_join_all(
            self.streams
                .values()
                .map(|s| s.delete_refs(cx, refs.clone())),
        )
        .await?;
        Ok(())
    }

    async fn request_status(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        controller: bool,
    ) -> Result<()> {
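        // Wait for every stream to report status before telling the
        // controller that everything up to `seq` has been processed.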
        let _: Vec<()> = try_join_all(
            self.streams
                .values()
                .map(|stream| stream.request_status(cx)),
        )
        .await?;

        ControllerMessageClient::status(
            &self.controller_actor,
            cx,
            seq.next(),
            cx.self_id().clone(),
            controller,
        )
        .await?;
        Ok(())
    }

    async fn reduce(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        local_tensor: Ref,
        factory: Factory,
        source_mesh: Ref,
        stream_ref: StreamRef,
        dims: Vec<String>,
        reduction: Reduction,
        scatter: bool,
        in_place: bool,
        out: Option<Ref>,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        let dims = SortedVec::from_unsorted(dims);
        let stream = self.try_get_stream(stream_ref)?.clone();

        let (_, comm_map) = self
            .device_meshes
            .get_mut(&source_mesh)
            .ok_or_else(|| anyhow::anyhow!("invalid device mesh id: {:#?}", source_mesh))?;

        let (size, comm) = comm_map
            .get(&(stream_ref, dims.clone()))
            .ok_or_else(|| anyhow::anyhow!("no comm found for stream {stream:#?}, dims {dims:#?}"))?
            .clone();

        stream
            .reduce(
                cx,
                comm,
                size.try_into()?,
                result,
                local_tensor,
                factory,
                reduction,
                scatter,
                in_place,
                out,
            )
            .await?;

        Ok(())
    }

    async fn create_pipe(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        _key: String,
        function: ResolvableFunction,
        max_messages: i64,
        device_mesh: Ref,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
    ) -> Result<()> {
        tracing::debug!("creating pipe {}", result);
        let args: Vec<PyTree<RValue>> = args
            .into_iter()
            .map(|object| RValue::PyObject(object.into_py_object().unwrap()).into())
            .collect();
        let kwargs: HashMap<_, PyTree<RValue>> = kwargs
            .into_iter()
            .map(|(k, object)| (k, RValue::PyObject(object.into_py_object().unwrap()).into()))
            .collect();
        let device_mesh = self.device_meshes.get(&device_mesh).ok_or_else(|| {
            CallFunctionError::Error(anyhow::anyhow!("ref not found: {}", device_mesh))
        })?;
        tracing::debug!("spawning pipe actor for {}", result);
        let pipe = PipeActor::spawn(
            cx,
            PipeParams {
                function,
                max_messages,
                ranks: device_mesh.0.ranks(),
                sizes: device_mesh.0.sizes(),
                args,
                kwargs,
            },
        )
        .await?;
        tracing::debug!("created pipe {}", result);

        self.pipes.insert(result, pipe);
        Ok(())
    }

    async fn send_tensor(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        let comm = self
            .send_recv_comms
            .get(&(from_stream, to_stream))
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "could not find stream to stream comm for: {:#?}",
                    (from_stream, to_stream)
                )
            })?
            .clone();

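        // Work out this rank's role: `to_rank` is the destination rank if
        // this rank is a sender (it appears in `from_ranks`), and `from_rank`
        // is the source rank if this rank is a receiver (it appears in
        // `to_ranks`).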
        let to_rank = from_ranks
            .index(self.rank)
            .map(|index| to_ranks.get(index).ok())
            .ok()
            .flatten();
        let from_rank = to_ranks
            .index(self.rank)
            .map(|index| from_ranks.get(index).ok())
            .ok()
            .flatten();

        let (stream, stream_ref) = if to_rank.is_none() {
            (self.try_get_stream(to_stream)?.clone(), to_stream)
        } else if from_rank.is_none() || from_stream == to_stream {
            (self.try_get_stream(from_stream)?.clone(), from_stream)
        } else {
            unimplemented!(
                "We haven't implemented to_mesh between streams if a rank participates as both a sender and receiver. \
                It is possible, but would require the recv stream to send the output buffer tensor to the send stream and sync. \
                Then the send stream would do the nccl op, and then sync with the sending stream again."
            );
        };

        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        stream
            .send_tensor(cx, result, from_rank, to_rank, tensor, factory, comm)
            .await?;

        Ok(())
    }

    async fn exit(
        &mut self,
        cx: &hyperactor::Context<Self>,
        error: Option<(Option<ActorId>, String)>,
    ) -> Result<()> {
        for (_, stream) in self.streams.drain() {
            stream.drain_and_stop()?;
            Arc::into_inner(stream)
                .expect("there should be no owners of this stream handle except the worker stream table")
                .await;
        }

        let self_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_SELF_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);
        let peer_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_PEER_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);

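        // Pick the exit code: a failure in this actor uses the self-error
        // exit code, a failure in some other actor uses the peer-error exit
        // code.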
        let exit_code = match error {
            Some((Some(actor_id), reason)) => {
                tracing::error!(
                    "stopping the worker, actor {} failed with error: {}",
                    actor_id,
                    reason
                );
                if *cx.self_id() == actor_id {
                    self_error_exit_code
                } else {
                    peer_error_exit_code
                }
            }
            Some((None, reason)) => {
                tracing::error!("stopping the worker, reason: {}", reason);
                1
            }
            None => 0,
        };

        if exit_code != 0 {
            tracing::info!("stopping the worker process, exit code: {}", exit_code);
            std::process::exit(exit_code);
        }
        cx.stop()?;
        Ok(())
    }

    async fn send_value(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        destination: Option<Ref>,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
        stream: StreamRef,
    ) -> Result<()> {
        let stream = self.try_get_stream(stream)?;

        let device_meshes = if function.as_ref().is_none_or(|f| f.as_torch_op().is_some()) {
            HashMap::new()
        } else {
            self.device_meshes
                .iter()
                .map(|(k, v)| (k.clone(), v.0.clone()))
                .collect()
        };

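        // If a destination pipe was given, the value is sent there; otherwise
        // it goes back to the controller.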
        let pipe = if let Some(destination) = destination {
            let pipe = self
                .pipes
                .get(&destination)
                .ok_or_else(|| anyhow::anyhow!("invalid pipe id: {:#?}", destination))?
                .port();
            Some(pipe)
        } else {
            None
        };

        stream
            .send_value(
                cx,
                seq,
                cx.self_id().clone(),
                mutates,
                function,
                args,
                kwargs,
                device_meshes,
                pipe,
            )
            .await
    }

    async fn send_result_of_actor_call(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: ActorCallParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.stream)?;
        stream
            .send_result_of_actor_call(cx, cx.self_id().clone(), params)
            .await?;
        Ok(())
    }

    async fn call_actor_method(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: ActorMethodParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.call.stream)?;
        stream.call_actor_method(cx, params).await?;
        Ok(())
    }

    async fn split_comm(
        &mut self,
        cx: &hyperactor::Context<Self>,
        dims: Vec<String>,
        device_mesh: Ref,
        stream_ref: StreamRef,
        config: Option<NcclConfig>,
    ) -> Result<()> {
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call SplitComm before BackendNetworkInit")?;
        match self.device_meshes.get_mut(&device_mesh) {
            Some((device_mesh, comm_map)) => {
                let stream = self
                    .streams
                    .get(&stream_ref)
                    .ok_or_else(|| anyhow::anyhow!("invalid stream id: {:#?}", stream_ref))?;

                let dims = SortedVec::from_unsorted(dims);

                anyhow::ensure!(
                    !comm_map.contains_key(&(stream_ref, dims.clone())),
                    "comm already exists for stream {stream:#?}, dims {dims:#?}"
                );
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&dims)?;
                let size = ranks_for_group.len();
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                        config,
                    )
                    .await?
                    .context("split comm should include self rank")?;
                comm_map.insert((stream_ref, dims), (size, Arc::new(split_comm)));
            }
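            // Even when this rank's mesh is not involved, it still has to take
            // part in the split on the global comm; the empty rank list marks
            // it as a non-member.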
            None => {
                global_comm.split_from(cx, vec![], config).await?;
            }
        }
        Ok(())
    }

    async fn split_comm_for_process_group(
        &mut self,
        cx: &hyperactor::Context<Self>,
        remote_process_group_ref: Ref,
        stream_ref: StreamRef,
        config: Option<NcclConfig>,
    ) -> Result<()> {
        ensure!(
            self.streams.contains_key(&stream_ref),
            "invalid stream id: {:#?}",
            stream_ref
        );
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call SplitCommForProcessGroup before BackendNetworkInit")?;
        let state = self
            .remote_process_groups
            .get_mut(&remote_process_group_ref)
            .with_context(|| {
                format!(
                    "invalid remote process group id: {:#?}",
                    remote_process_group_ref
                )
            })?;
        match self.device_meshes.get_mut(&state.device_mesh_ref) {
            Some((device_mesh, _)) => {
                let entry = match state.comms.entry(stream_ref) {
                    Entry::Vacant(entry) => entry,
                    Entry::Occupied(_) => bail!(
                        "comm already exists for remote process group {:#?} on stream {:#?}",
                        remote_process_group_ref,
                        stream_ref,
                    ),
                };
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&state.dims)?;
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                        config,
                    )
                    .await?
                    .context("split comm should include self rank")?;
                entry.insert(Arc::new(split_comm));
            }
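            // As in `split_comm`: ranks whose mesh is not involved still take
            // part in the split, passing an empty rank list.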
            None => {
                global_comm.split_from(cx, vec![], config).await?;
            }
        }
        Ok(())
    }

    async fn pipe_recv(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        results: Vec<Option<Ref>>,
        pipe: Ref,
        stream: StreamRef,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, stream).await?;

        let pipe = self
            .pipes
            .get(&pipe)
            .ok_or_else(|| anyhow::anyhow!("ref not found: {}", pipe))?;
        let pipe = pipe.port();
        let stream = self.try_get_stream(stream)?;
        stream.set_value(cx, seq, results, pipe).await
    }

    async fn set_ref_unit_tests_only(
        &mut self,
        cx: &hyperactor::Context<Self>,
        reference: Ref,
        value: WireValue,
        stream: StreamRef,
    ) -> Result<()> {
        let stream = self.try_get_stream(stream)?;

        stream.set_ref_unit_tests_only(cx, reference, value).await
    }

    async fn get_ref_unit_tests_only(
        &mut self,
        cx: &hyperactor::Context<Self>,
        ref_id: Ref,
        stream: StreamRef,
    ) -> Result<Option<Result<WireValue, String>>> {
        let stream = self.try_get_stream(stream)?;
        Ok(stream.get_ref_unit_tests_only(cx, ref_id.clone()).await?)
    }

    async fn define_recording(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        _nresults: usize,
        _nformals: usize,
        commands: Vec<WorkerMessage>,
        ntotal_messages: usize,
        index: usize,
    ) -> Result<()> {
        if self.defining_recording.is_some() && self.defining_recording.unwrap() != result {
            bail!("already defining a different recording");
        }
        self.defining_recording = Some(result);

        match self.recordings.entry(result) {
            Entry::Vacant(entry) => {
                ensure!(
                    index == 0,
                    "got DefineRecording message with (index = {:?}) > 0 for previously unseen recording",
                    index
                );
                entry.insert(Recording::PartialRecording {
                    last_index: 0,
                    commands,
                });
            }
            Entry::Occupied(mut entry) => match entry.get_mut() {
                Recording::CompleteRecording { .. } => {
                    bail!("got DefineRecording message for already complete recording")
                }
                Recording::PartialRecording {
                    last_index,
                    commands: existing_commands,
                } => {
                    ensure!(
                        index == *last_index + 1,
                        "Got DefineRecording message with index = {:?}, but \
                        last seen index for recording is {:?}",
                        index,
                        last_index
                    );
                    *last_index = index;
                    existing_commands.extend(commands.into_iter());
                }
            },
        };

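        // All chunks have arrived: promote the recording to
        // `CompleteRecording`, then run its commands. Handling the commands
        // re-enters this handler with `defining_recording` set, which is how
        // the participating streams get registered.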
        if index < ntotal_messages - 1 {
            return Ok(());
        }
        let commands = match self.recordings.remove(&result).unwrap() {
            Recording::CompleteRecording { .. } => panic!("unreachable, in theory"),
            Recording::PartialRecording { commands, .. } => {
                self.recordings.insert(
                    result,
                    Recording::CompleteRecording {
                        streams: HashSet::new(),
                    },
                );
                commands
            }
        };

        for command in commands {
            WorkerMessageHandler::handle(self, cx, command).await?;
        }

        match self.recordings.get(&result).unwrap() {
            Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
            Recording::CompleteRecording { streams, .. } => {
                for stream in streams {
                    self.try_get_stream(*stream)?
                        .finalize_recording(cx, result)
                        .await?;
                }
            }
        }

        self.defining_recording = None;
        Ok(())
    }

    async fn recording_formal(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        argument_index: usize,
        stream: StreamRef,
    ) -> Result<()> {
        ensure!(self.defining_recording.is_some());
        self.maybe_add_stream_to_recording(cx, stream).await?;
        self.try_get_stream(stream)?
            .recording_formal(cx, result, argument_index)
            .await
    }

    async fn recording_result(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        output_index: usize,
        stream: StreamRef,
    ) -> Result<()> {
        ensure!(self.defining_recording.is_some());
        self.maybe_add_stream_to_recording(cx, stream).await?;
        self.try_get_stream(stream)?
            .recording_result(cx, result, output_index)
            .await
    }

    async fn call_recording(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        recording: Ref,
        results: Vec<Ref>,
        actuals: Vec<Ref>,
    ) -> Result<()> {
        ensure!(self.defining_recording.is_none());
        let recording_ref = recording;
        let recording = self.recordings.get(&recording).ok_or(anyhow::anyhow!(
            "could not find recording: {:#?}",
            recording
        ))?;
        match recording {
            Recording::PartialRecording { .. } => {
                bail!("cannot call recording because it is incomplete")
            }
            Recording::CompleteRecording { streams } => try_join_all(
                streams
                    .iter()
                    .map(|stream| self.try_get_stream(*stream))
                    .collect::<Result<Vec<_>>>()?
                    .into_iter()
                    .map(|stream| {
                        stream.call_recording(
                            cx,
                            seq,
                            recording_ref,
                            results.clone(),
                            actuals.clone(),
                        )
                    }),
            )
            .await
            .map(|_| ()),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::process::Stdio;
    use std::time::Duration;

    use anyhow::Result;
    use hyperactor::Instance;
    use hyperactor::Named;
    use hyperactor::WorldId;
    use hyperactor::actor::ActorStatus;
    use hyperactor::channel::ChannelAddr;
    use hyperactor::id;
    use hyperactor::mailbox::open_port;
    use hyperactor::proc::Proc;
    use hyperactor_multiprocess::System;
    use hyperactor_multiprocess::proc_actor::Environment;
    use hyperactor_multiprocess::proc_actor::ProcActor;
    use hyperactor_multiprocess::proc_actor::ProcMessageClient;
    use hyperactor_multiprocess::system_actor::SYSTEM_ACTOR_REF;
    use hyperactor_multiprocess::system_actor::Shape;
    use hyperactor_multiprocess::system_actor::SystemMessageClient;
    use hyperactor_multiprocess::system_actor::SystemSnapshotFilter;
    use hyperactor_multiprocess::system_actor::WorldStatus;
    use monarch_messages::controller::ControllerMessage;
    use monarch_messages::controller::WorkerError;
    use monarch_messages::worker::WorkerMessageClient;
    use monarch_types::PickledPyObject;
    use monarch_types::PyTree;
    use pyo3::IntoPyObjectExt;
    use pyo3::Python;
    use pyo3::prelude::*;
    use pyo3::types::PyList;
    use pyo3::types::PyString;
    use rand::Rng;
    use rand::distributions::Alphanumeric;
    use timed_test::async_timed_test;
    use tokio::io::BufReader;
    use tokio::process::Command;
    use tokio_retry::Retry;
    use tokio_retry::strategy::FixedInterval;
    use torch_sys::Device;
    use torch_sys::DeviceIndex;
    use torch_sys::MemoryFormat;

    use super::*;
    use crate::test_util::test_setup;

    #[async_timed_test(timeout_secs = 60)]
    async fn basic_worker() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(Ref { id: 2 })],
                        mutates: vec![0.into()],
                        function: "torch.ops.aten.sub_.Scalar".into(),
                        args: vec![WireValue::Ref(0.into()), WireValue::Int(1)],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 3.into(),
                        results: vec![Some(Ref { id: 3 })],
                        mutates: vec![],
                        function: "torch.ops.aten.zeros.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 4.into(),
                        results: vec![Some(Ref { id: 4 })],
                        mutates: vec![],
                        function: "torch.ops.aten.allclose.default".into(),
                        args: vec![WireValue::Ref(0.into()), WireValue::Ref(Ref { id: 3 })],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result: bool = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );
        assert!(result);

        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn error_sends_response() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.rand.default".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let response_message = controller_rx.recv().await.unwrap();
        match response_message {
            ControllerMessage::RemoteFunctionFailed {
                seq,
                error: WorkerError { backtrace: msg, .. },
            } => {
                assert_eq!(seq, 0.into());
                assert!(msg.contains("aten::rand() is missing value for argument 'size'"))
            }
            _ => panic!("unexpected response {:#?}", response_message),
        }

        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn mutated_refs_are_updated_with_error() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: 0.into(),
                        value: WireValue::Int(1),
                        stream: 1.into(),
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(Ref { id: 2 })],
                        mutates: vec![0.into()],
                        function: "i.dont.exist".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result = worker_handle
            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
            .await?;

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        let mutated_ref = result
            .context("no such ref")?
            .err()
            .context("expected error")?;
        assert!(mutated_ref.contains("failed to resolve function"));

        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );
        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn accessing_errored_dependency() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "i.dont.exist".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 1.into(),
                        results: vec![Some(1.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.sub_.Scalar".into(),
                        args: vec![WireValue::Ref(0.into())],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );

        match &responses[0] {
            ControllerMessage::RemoteFunctionFailed { seq, .. } => {
                assert_eq!(seq, &0.into())
            }
            _ => panic!("unexpected response {:#?}", responses[0]),
        };
        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn py_remote_function_calls() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        let (split_arg, sort_list, mesh_ref, dim, layout, none, scalar, device, memory_format) =
            Python::with_gil(|py| {
                let split_arg: PickledPyObject = PyString::new(py, "/fbs/fbc/foo/bar")
                    .into_any()
                    .try_into()?;
                let sort_list: PickledPyObject =
                    PyList::new(py, [65, 34, 79, 1, 5])?.into_any().try_into()?;
                let mesh_ref: PickledPyObject = Ref { id: 5 }.into_bound_py_any(py)?.try_into()?;
                let dim: PickledPyObject = PyString::new(py, "x").into_any().try_into()?;
                let layout: PickledPyObject = py.import("torch")?.getattr("strided")?.try_into()?;
                let none: PickledPyObject = py.None().into_any().into_bound(py).try_into()?;
                let scalar: PickledPyObject = py.import("torch")?.getattr("float32")?.try_into()?;
                let device: PickledPyObject = py
                    .import("torch")?
                    .getattr("device")?
                    .call1(("cuda:1",))?
                    .try_into()?;
                let memory_format: PickledPyObject = py
                    .import("torch")?
                    .getattr("contiguous_format")?
                    .try_into()?;
                PyResult::Ok((
                    split_arg,
                    sort_list,
                    mesh_ref,
                    dim,
                    layout,
                    none,
                    scalar,
                    device,
                    memory_format,
                ))
            })?;

        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into()), Some(Ref { id: 2 })],
                        mutates: vec![],
                        function: "os.path.split".into(),
                        args: vec![split_arg.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(4.into()), None, None, None, None],
                        mutates: vec![],
                        function: "builtins.sorted".into(),
                        args: vec![sort_list.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CreateDeviceMesh {
                        result: 5.into(),
                        names: vec!["x".into()],
                        ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(6.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.mesh_rank".into(),
                        args: vec![mesh_ref.into(), dim.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 4.into(),
                        results: vec![Some(7.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_scalar_type"
                            .into(),
                        args: vec![scalar.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 5.into(),
                        results: vec![Some(8.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_layout".into(),
                        args: vec![layout.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 6.into(),
                        results: vec![Some(9.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_none".into(),
                        args: vec![none.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 7.into(),
                        results: vec![None],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.none".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 8.into(),
                        results: vec![Some(10.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_device".into(),
                        args: vec![device.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 9.into(),
                        results: vec![Some(11.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_memory_format"
                            .into(),
                        args: vec![memory_format.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 10.into(),
                        results: vec![Some(12.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 11.into(),
                        results: vec![Some(13.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.stack.default".into(),
                        args: vec![WireValue::RefList(vec![12.into(), 12.into()])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result1: String = worker_handle
            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result2: String = worker_handle
            .get_ref_unit_tests_only(&client, 2.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result3: i64 = worker_handle
            .get_ref_unit_tests_only(&client, 4.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result4: i64 = worker_handle
            .get_ref_unit_tests_only(&client, 6.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        assert_eq!(
            ScalarType::Float,
            worker_handle
                .get_ref_unit_tests_only(&client, 7.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap()
                .try_into()
                .unwrap()
        );
        assert_eq!(
            Layout::Strided,
            worker_handle
                .get_ref_unit_tests_only(&client, 8.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap()
                .try_into()
                .unwrap()
        );
        assert_matches!(
            worker_handle
                .get_ref_unit_tests_only(&client, 9.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap(),
            WireValue::None(()),
        );
        let device: Device = CudaDevice::new(DeviceIndex(1)).into();
        assert_eq!(
            device,
            worker_handle
                .get_ref_unit_tests_only(&client, 10.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap()
                .try_into()
                .unwrap()
        );
        assert_matches!(
            worker_handle
                .get_ref_unit_tests_only(&client, 11.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap(),
            WireValue::MemoryFormat(MemoryFormat::Contiguous),
        );

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );

        assert_eq!(result1, "/fbs/fbc/foo");
        assert_eq!(result2, "bar");
        assert_eq!(result3, 1);
        assert_eq!(result4, 0);

        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn delete_refs() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 2 },
                        value: WireValue::Bool(false),
                        stream: 0.into(),
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 3 },
                        value: WireValue::Bool(true),
                        stream: 0.into(),
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 4 },
                        value: WireValue::Int(0),
                        stream: 1.into(),
                    },
                    WorkerMessage::DeleteRefs(vec![Ref { id: 2 }, Ref { id: 4 }]),
                ],
            )
            .await
            .unwrap();

        let result: bool = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 3 }, 0.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let fail_result = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        assert!(result, "should be able to get a non-deleted ref");
        assert!(fail_result.is_none(), "should fail to get a deleted ref");

        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn request_status() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                ],
            )
            .await
            .unwrap();

        for i in 0..100 {
            worker_handle
                .call_function(
                    &client,
                    CallFunctionParams {
                        seq: i.into(),
                        results: vec![Some(Ref { id: i + 2 })],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: (i % 2).into(),
                        remote_process_groups: vec![],
                    },
                )
                .await
                .unwrap();
        }

        worker_handle
            .request_status(&client, 100.into(), false)
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );

        let response = responses.pop().unwrap();
        match response {
            ControllerMessage::Status { seq, .. } => {
                assert_eq!(seq, 101.into())
            }
            _ => panic!("unexpected response {:#?}", response),
        };

        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn backend_network_init() {
        let proc = Proc::local();
        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();

        let worker_handle1 = proc
            .spawn::<WorkerActor>(
                "worker0",
                WorkerParams {
                    world_size: 2,
                    rank: 0,
                    device_index: Some(0),
                    controller_actor: controller_ref.clone(),
                },
            )
            .await
            .unwrap();
        let worker_handle2 = proc
            .spawn::<WorkerActor>(
                "worker1",
                WorkerParams {
                    world_size: 2,
                    rank: 1,
                    device_index: Some(1),
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();

        let unique_id = UniqueId::new().unwrap();
        worker_handle1
            .backend_network_init(&client, unique_id.clone())
            .await
            .unwrap();
        worker_handle2
            .backend_network_init(&client, unique_id)
            .await
            .unwrap();

        worker_handle1.drain_and_stop().unwrap();
        worker_handle1.await;
        worker_handle2.drain_and_stop().unwrap();
        worker_handle2.await;
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn send_value() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::SendValue {
                        seq: 1.into(),
                        destination: None,
                        mutates: vec![],
                        function: None,
                        args: vec![WireValue::Ref(0.into())],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::SendValue {
                        seq: 2.into(),
                        destination: None,
                        mutates: vec![],
                        function: Some("torch.ops.aten.var_mean.default".into()),
                        args: vec![WireValue::Ref(0.into())],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop()?;
        assert_matches!(worker_handle.await, ActorStatus::Stopped);

        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            3,
            "Expected three responses, got: {:#?}",
            responses
        );

        match responses.pop().unwrap() {
            ControllerMessage::FetchResult { seq, value } => {
                assert_eq!(seq, 2.into());
                let value = value.unwrap().deserialized::<PyTree<RValue>>().unwrap();
                assert_eq!(value.leaves().len(), 2);
            }
            resp => panic!("unexpected response {:#?}", resp),
        };
        match responses.pop().unwrap() {
            ControllerMessage::FetchResult { seq, .. } => {
                assert_eq!(seq, 1.into())
            }
            resp => panic!("unexpected response {:#?}", resp),
        };
        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn send_value_err_result() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();

        let ref_arg: PickledPyObject =
            Python::with_gil(|py| Ref { id: 2 }.into_bound_py_any(py)?.try_into())?;

        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 2 },
                        value: WireValue::Bool(false),
                        stream: 1.into(),
                    },
                    WorkerMessage::SendValue {
                        seq: 1.into(),
                        destination: None,
                        mutates: vec![Ref { id: 2 }],
                        function: Some("non.existent.function".into()),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::SendValue {
                        seq: 2.into(),
                        destination: None,
                        mutates: vec![],
                        function: None,
                        args: vec![ref_arg.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop()?;
        assert_matches!(worker_handle.await, ActorStatus::Stopped);

        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            3,
            "Expected three responses, got: {:#?}",
            responses
        );
        match responses.pop() {
            Some(ControllerMessage::FetchResult { seq, value }) => {
                assert_eq!(seq, 2.into());
                assert!(value.is_err());
                assert!(
                    value
                        .unwrap_err()
                        .backtrace
                        .contains("failed to resolve function")
                );
            }
            _ => panic!("unexpected response {:#?}", responses),
        }
        match responses.pop() {
            Some(ControllerMessage::FetchResult { seq, value }) => {
                assert_eq!(seq, 1.into());
                assert!(value.is_err());
                assert!(
                    value
                        .unwrap_err()
                        .backtrace
                        .contains("failed to resolve function")
                );
            }
            _ => panic!("unexpected response {:#?}", responses),
        }
        Ok(())
    }

2193 #[async_timed_test(timeout_secs = 60)]
2194 async fn pipe_send_recv() -> Result<()> {
2195 test_setup()?;
2196
2197 let proc = Proc::local();
2198 let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
2199
2200 let handle = proc
2201 .spawn::<WorkerActor>(
2202 "worker",
2203 WorkerParams {
2204 world_size: 1,
2205 rank: 0,
2206 device_index: None,
2207 controller_actor: controller_ref,
2208 },
2209 )
2210 .await
2211 .unwrap();
2212 let (resolve_value_arg, torch_eq_arg1, torch_eq_arg2): (
2213 PickledPyObject,
2214 PickledPyObject,
2215 PickledPyObject,
2216 ) = Python::with_gil(|py| {
2217 PyResult::Ok((
2218 PyList::new(py, [2, 3])?.into_any().try_into()?,
2219 Ref { id: 2 }.into_bound_py_any(py)?.try_into()?,
2220 Ref { id: 4 }.into_bound_py_any(py)?.try_into()?,
2221 ))
2222 })?;
2223
        handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CreateDeviceMesh {
                        result: 1.into(),
                        names: vec!["x".into()],
                        ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(2.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 0.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CreatePipe {
                        result: 3.into(),
                        key: "unused".into(),
                        function: "monarch.monarch_tensor_worker.test_utils.handler".into(),
                        max_messages: 1,
                        mesh: 1.into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                    },
                    WorkerMessage::SendValue {
                        seq: 1.into(),
                        destination: Some(3.into()),
                        mutates: vec![],
                        function: Some(
                            "monarch.monarch_tensor_worker.test_utils.resolve_value".into(),
                        ),
                        args: vec![resolve_value_arg.into()],
                        kwargs: HashMap::new(),
                        stream: 0.into(),
                    },
                    WorkerMessage::PipeRecv {
                        seq: 2.into(),
                        results: vec![Some(4.into())],
                        pipe: 3.into(),
                        stream: 0.into(),
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(5.into())],
                        mutates: vec![],
                        function: "torch.equal".into(),
                        args: vec![torch_eq_arg1.into(), torch_eq_arg2.into()],
                        kwargs: HashMap::new(),
                        stream: 0.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

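        // Ref 5 holds the boolean result of torch.equal; the chained unwraps
        // peel the layered Result/Option wrappers before converting to bool.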
        let matches: bool = handle
            .get_ref_unit_tests_only(&client, 5.into(), 0.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        assert!(matches);

        handle.drain_and_stop()?;
        assert_matches!(handle.await, ActorStatus::Stopped);

        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            0,
            "Expected no responses, got: {:#?}",
            responses
        );

        Ok(())
    }

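    // Build a random abstract unix-socket channel address (the leading `@`
    // keeps it out of the filesystem) so concurrent tests don't collide.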
    fn get_random_channel_addr() -> ChannelAddr {
        let random_string = rand::thread_rng()
            .sample_iter(&Alphanumeric)
            .take(24)
            .map(char::from)
            .collect::<String>();
        format!("unix!@{random_string}").parse().unwrap()
    }

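    // Poll the system actor until the world reports Live; worker processes
    // register asynchronously after they are spawned.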
    async fn ensure_world_ready(client: &Instance<()>, world: WorldId) -> Result<()> {
        tracing::info!("checking whether world {world} is ready");
        let retry_strategy = FixedInterval::from_millis(1000).take(100);
        Retry::spawn(retry_strategy, async || {
            let snapshot = SYSTEM_ACTOR_REF
                .snapshot(&client, SystemSnapshotFilter::default())
                .await?;
            let world_snapshot = snapshot.worlds.get(&world).ok_or(anyhow!("no world"))?;
            tracing::info!("world status: {:?}", world_snapshot.status);
            match world_snapshot.status {
                WorldStatus::Live => Ok(()),
                _ => Err(anyhow!("world is not live")),
            }
        })
        .await?;
        Ok(())
    }

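    // End-to-end remote process group test: spawns two real worker processes,
    // initializes the NCCL backend with a shared unique id, splits a
    // communicator for a remote process group, and invokes the Python
    // test_remote_process_group helper with that group on both ranks.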
    #[async_timed_test(timeout_secs = 60)]
    async fn remote_process_group() -> Result<()> {
        test_setup()?;

        let timeout: Duration = Duration::from_secs(10);
        let system_addr = get_random_channel_addr();
        let _system_handle = System::serve(system_addr.clone(), timeout, timeout).await?;

        let client = System::new(system_addr.clone()).attach().await?;
        let (handle, mut controller_rx) = client.open_port::<ControllerMessage>();
        handle.bind_to(ControllerMessage::port());
        let controller_ref: ActorRef<ControllerActor> = ActorRef::attest(client.self_id().clone());

        let world_size = 2;
        SYSTEM_ACTOR_REF
            .upsert_world(
                &client,
                id!(world),
                Shape::Definite(vec![world_size]),
                4,
                Environment::Local,
                HashMap::new(),
            )
            .await?;

        let mut worker_process_handles = vec![];
        let mut worker_procs: Vec<ActorRef<ProcActor>> = vec![];
        for rank in 0..world_size {
            let world_id = "world".to_string();
            let proc_id = format!("{world_id}[{rank}]");
            worker_procs.push(ActorRef::attest(format!("world[{rank}].proc[0]").parse()?));

            let mut handle =
                Command::new(std::env::var("MONARCH_TENSOR_WORKER_EXE").map_err(|e| {
                    anyhow::anyhow!("could not get var MONARCH_TENSOR_WORKER_EXE: {}", e)
                })?)
                .arg("worker")
                .arg(format!("--bootstrap-addr={system_addr}"))
                .arg(format!("--world-id={world_id}"))
                .arg(format!("--proc-id={proc_id}"))
                .stdout(Stdio::piped())
                .stdin(Stdio::piped())
                .kill_on_drop(true)
                .spawn()?;

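            // Mirror each worker's stdout onto the test's stderr so process
            // output shows up in the test logs.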
            let out = handle.stdout.take().unwrap();
            tokio::spawn(async move {
                let mut reader = BufReader::new(out);
                tokio::io::copy(&mut reader, &mut tokio::io::stderr())
                    .await
                    .unwrap();
            });
            worker_process_handles.push(handle);
        }

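        // Wait for both worker procs to join the world, then spawn a
        // WorkerActor on each proc with bincode-serialized parameters.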
        ensure_world_ready(&client, id!(world)).await?;

        let (spawned_port, mut spawned_receiver) = open_port(&client);
        for (rank, worker_proc) in worker_procs.iter().enumerate() {
            let params = WorkerParams {
                world_size,
                rank,
                device_index: Some(rank.try_into().unwrap()),
                controller_actor: controller_ref.clone(),
            };
            worker_proc
                .spawn(
                    &client,
                    "monarch_tensor_worker::WorkerActor".to_owned(),
                    "worker".to_owned(),
                    bincode::serialize(&params)?,
                    spawned_port.bind(),
                )
                .await?;
        }
        let mut spawned = vec![];
        while spawned.len() < world_size {
            spawned.push(spawned_receiver.recv().await?);
        }
        tracing::info!("spawned {} worker actors", world_size);
        let workers: Vec<ActorRef<WorkerActor>> = (0..world_size)
            .map(|rank| format!("world[{rank}].worker[0]"))
            .map(|name| ActorRef::attest(name.parse().unwrap()))
            .collect();

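        // Send the same command sequence to both ranks: initialize the NCCL
        // backend with a shared UniqueId, build a 2-rank device mesh, create a
        // remote process group over dim "x", split a communicator for it, and
        // call the Python test_remote_process_group helper with the group.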
        let remote_proc_grp_ref: PickledPyObject =
            Python::with_gil(|py| Ref { id: 2 }.into_bound_py_any(py)?.try_into())?;

        let unique_id = UniqueId::new()?;
        let messages = vec![
            WorkerMessage::CreateStream {
                id: 0.into(),
                stream_creation: StreamCreationMode::UseDefaultStream,
            },
            WorkerMessage::BackendNetworkInit(unique_id.clone()),
            WorkerMessage::CreateDeviceMesh {
                result: 1.into(),
                names: vec!["x".into()],
                ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
            },
            WorkerMessage::CreateRemoteProcessGroup {
                result: 2.into(),
                device_mesh: 1.into(),
                dims: vec!["x".into()],
            },
            WorkerMessage::SplitCommForProcessGroup {
                remote_process_group: 2.into(),
                stream: 0.into(),
                config: None,
            },
            WorkerMessage::CallFunction(CallFunctionParams {
                seq: 0.into(),
                results: vec![Some(3.into())],
                mutates: vec![],
                function: "monarch.monarch_tensor_worker.test_utils.test_remote_process_group"
                    .into(),
                args: vec![remote_proc_grp_ref.into()],
                kwargs: HashMap::new(),
                stream: 0.into(),
                remote_process_groups: vec![2.into()],
            }),
        ];

        workers[0].command_group(&client, messages.clone()).await?;
        workers[1].command_group(&client, messages).await?;

        let _ = workers[0]
            .get_ref_unit_tests_only(&client, 3.into(), 0.into())
            .await?
            .unwrap()
            .unwrap();

        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );

        Ok(())
    }
}