#![feature(assert_matches)]
#![feature(duration_constructors)]
#![feature(exit_status_error)]
#![allow(unsafe_op_in_unsafe_fn)]

pub mod bootstrap;
mod borrow;
mod comm;
pub mod device_mesh;
pub mod pipe;
pub mod py_pipe;
pub mod stream;
pub mod test_util;

use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::hash_map::Entry;
use std::sync::Arc;

use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;
use anyhow::ensure;
use async_trait::async_trait;
use borrow::Borrow;
use comm::CommMessageClient;
use comm::CommParams;
use comm::NcclCommActor;
use derive_more::TryInto;
use device_mesh::DeviceMesh;
use futures::future::try_join_all;
use hyperactor::Actor;
use hyperactor::ActorRef;
use hyperactor::Bind;
use hyperactor::Handler;
use hyperactor::Named;
use hyperactor::Unbind;
use hyperactor::actor::ActorHandle;
use hyperactor::cap;
use hyperactor::reference::ActorId;
use hyperactor_mesh::comm::multicast::CastInfo;
use itertools::Itertools;
use monarch_hyperactor::shape::PyPoint;
use monarch_hyperactor::shape::PyShape;
use monarch_messages::controller::ControllerActor;
use monarch_messages::controller::ControllerMessageClient;
use monarch_messages::controller::Seq;
use monarch_messages::wire_value::WireValue;
use monarch_messages::worker::ActorCallParams;
use monarch_messages::worker::ActorMethodParams;
use monarch_messages::worker::CallFunctionError;
use monarch_messages::worker::CallFunctionParams;
use monarch_messages::worker::Factory;
use monarch_messages::worker::Reduction;
use monarch_messages::worker::Ref;
use monarch_messages::worker::ResolvableFunction;
use monarch_messages::worker::StreamCreationMode;
use monarch_messages::worker::StreamRef;
use monarch_messages::worker::WorkerMessage;
use monarch_messages::worker::WorkerMessageHandler;
use monarch_messages::worker::WorkerParams;
use monarch_types::PyTree;
use ndslice::Slice;
use pipe::PipeActor;
use pipe::PipeParams;
use pyo3::Py;
use pyo3::Python;
use pyo3::types::PyAnyMethods;
use serde::Deserialize;
use serde::Serialize;
use sorted_vec::SortedVec;
use stream::StreamActor;
use stream::StreamMessageClient;
use stream::StreamParams;
use torch_sys::CudaDevice;
use torch_sys::DeviceIndex;
use torch_sys::Layout;
use torch_sys::RValue;
use torch_sys::ScalarType;
use torch_sys::TensorCell;
use torch_sys::factory_zeros;
use torch_sys_cuda::nccl::NcclConfig;
use torch_sys_cuda::nccl::ReduceOp;
use torch_sys_cuda::nccl::UniqueId;

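/// Book-keeping for a remote process group: the device mesh it was created
/// against, the (sorted) mesh dimensions it spans, and the per-stream split
/// communicators created for it by `split_comm_for_process_group`.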
#[derive(Debug)]
struct RemoteProcessGroupState {
    device_mesh_ref: Ref,
    dims: SortedVec<String>,
    comms: HashMap<StreamRef, Arc<ActorHandle<NcclCommActor>>>,
}

impl RemoteProcessGroupState {
    fn new(device_mesh_ref: Ref, dims: SortedVec<String>) -> Self {
        Self {
            device_mesh_ref,
            dims,
            comms: HashMap::new(),
        }
    }
}

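/// A recording under construction or ready to call. `DefineRecording`
/// messages may arrive in multiple chunks, so a recording starts out
/// partial and becomes complete once the final chunk has been applied.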
#[derive(Debug)]
enum Recording {
    PartialRecording {
        last_index: usize,
        commands: Vec<WorkerMessage>,
    },

    CompleteRecording {
        streams: HashSet<StreamRef>,
    },
}

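/// A worker actor on a single rank. It owns the streams, device meshes,
/// borrows, pipes, communicators, and recordings for that rank, and services
/// `WorkerMessage`s largely by delegating to the appropriate `StreamActor`.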
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        WorkerMessage { cast = true },
        AssignRankMessage { cast = true },
    ],
)]
pub struct WorkerActor {
    device: Option<CudaDevice>,
    streams: HashMap<StreamRef, Arc<ActorHandle<StreamActor>>>,
    device_meshes: HashMap<
        Ref,
        (
            DeviceMesh,
            HashMap<(StreamRef, SortedVec<String>), (usize, Arc<ActorHandle<NcclCommActor>>)>,
        ),
    >,
    world_size: usize,
    rank: usize,
    borrows: HashMap<u64, Borrow>,
    comm: Option<ActorHandle<NcclCommActor>>,
    controller_actor: ActorRef<ControllerActor>,
    pipes: HashMap<Ref, ActorHandle<PipeActor>>,
    remote_process_groups: HashMap<Ref, RemoteProcessGroupState>,
    send_recv_comms: HashMap<(StreamRef, StreamRef), Arc<ActorHandle<NcclCommActor>>>,
    recordings: HashMap<Ref, Recording>,
    defining_recording: Option<Ref>,
    respond_with_python_message: bool,
}

impl WorkerActor {
    fn try_get_stream(&self, stream: StreamRef) -> Result<&Arc<ActorHandle<StreamActor>>> {
        self.streams
            .get(&stream)
            .ok_or_else(|| anyhow::anyhow!("invalid stream id: {:#?}", stream))
    }

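    /// If a recording is currently being defined, ensure `stream` is part of
    /// it, telling the stream to start its own recording the first time it
    /// is seen.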
    async fn maybe_add_stream_to_recording(
        &mut self,
        caps: &impl cap::CanSend,
        stream: StreamRef,
    ) -> Result<()> {
        if let Some(defining_recording) = self.defining_recording {
            let recording = self.recordings.get_mut(&defining_recording).unwrap();
            let fut = match recording {
                Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
                Recording::CompleteRecording { streams } => {
                    streams.insert(stream).then(|| -> Result<_, anyhow::Error> {
                        Ok(self
                            .try_get_stream(stream)?
                            .define_recording(caps, defining_recording))
                    })
                }
            }
            .transpose()?;
            match fut {
                Some(fut) => fut.await,
                None => Ok(()),
            }
        } else {
            Ok(())
        }
    }
}

#[async_trait]
impl Actor for WorkerActor {
    type Params = WorkerParams;

    async fn new(
        WorkerParams {
            world_size,
            rank,
            device_index,
            controller_actor,
        }: Self::Params,
    ) -> Result<Self> {
        Ok(Self {
            device: device_index.map(|i| CudaDevice::new(DeviceIndex(i))),
            streams: HashMap::new(),
            device_meshes: HashMap::new(),
            world_size,
            rank,
            borrows: HashMap::new(),
            comm: None,
            controller_actor,
            pipes: HashMap::new(),
            remote_process_groups: HashMap::new(),
            send_recv_comms: HashMap::new(),
            recordings: HashMap::new(),
            defining_recording: None,
            respond_with_python_message: false,
        })
    }
}

#[async_trait]
impl Handler<AssignRankMessage> for WorkerActor {
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        _: AssignRankMessage,
    ) -> anyhow::Result<()> {
        let (rank, shape) = cx.cast_info();
        self.rank = rank;
        self.respond_with_python_message = true;
        Python::with_gil(|py| {
            let mesh_controller = py.import("monarch.mesh_controller").unwrap();
            let shape: PyShape = shape.into();
            let shape: Py<PyShape> = Py::new(py, shape).unwrap();
            let p: PyPoint = PyPoint::new(rank, shape);
            mesh_controller
                .call_method1("_initialize_env", (p, cx.proc().proc_id().to_string()))
                .unwrap();
        });
        Ok(())
    }
}

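/// Broadcast to workers so that each one learns its rank within the cast
/// shape and initializes its Python environment accordingly.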
#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub enum AssignRankMessage {
    AssignRank(),
}

#[async_trait]
impl Handler<WorkerMessage> for WorkerActor {
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        message: WorkerMessage,
    ) -> anyhow::Result<()> {
        <Self as WorkerMessageHandler>::handle(self, cx, message).await
    }
}

#[async_trait]
impl WorkerMessageHandler for WorkerActor {
    async fn backend_network_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        unique_id: UniqueId,
    ) -> Result<()> {
        let device = self
            .device
            .expect("tried to init backend network on a non-CUDA worker");
        let comm = NcclCommActor::spawn(
            cx,
            CommParams::New {
                device,
                unique_id,
                world_size: self.world_size.try_into().unwrap(),
                rank: self.rank.try_into().unwrap(),
            },
        )
        .await?;

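        // Warm-up all-reduce on a one-element tensor: this exercises the
        // freshly created communicator (likely serving as a cross-rank
        // rendezvous) before the per-stream splits below.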
        let tensor = factory_zeros(&[1], ScalarType::Float, Layout::Strided, device.into());
        let cell = TensorCell::new(tensor);

        comm.all_reduce(
            cx,
            cell,
            ReduceOp::Sum,
            torch_sys_cuda::cuda::Stream::get_current_stream(),
        )
        .await?;

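        // Give every stream its own split of the global comm. Streams are
        // visited in sorted order so that all ranks perform the splits in
        // the same sequence, as collective split operations require.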
        let sorted_streams = self
            .streams
            .iter()
            .sorted_by_key(|(k, _)| *k)
            .map(|(_, v)| v.as_ref());

        let mut splits = Vec::new();
        for _ in 0..sorted_streams.len() {
            splits.push(comm.split_all(cx, None).await?);
        }
        let _: Vec<()> = try_join_all(
            sorted_streams
                .into_iter()
                .zip(splits.into_iter())
                .map(|(stream, split)| stream.init_comm(cx, split)),
        )
        .await?;

        self.comm = Some(comm);

        Ok(())
    }

    async fn backend_network_point_to_point_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        if !self.streams.contains_key(&from_stream) {
            bail!("invalid from_stream id: {:#?}", from_stream);
        }
        if !self.streams.contains_key(&to_stream) {
            bail!("invalid to_stream id: {:#?}", to_stream);
        }
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call BackendNetworkPointToPointInit before BackendNetworkInit")?;
        let comm = global_comm.split_all(cx, None).await?;
        self.send_recv_comms
            .insert((from_stream, to_stream), Arc::new(comm));
        Ok(())
    }

    async fn call_function(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: CallFunctionParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.stream)?.clone();
        self.maybe_add_stream_to_recording(cx, params.stream)
            .await?;

        let device_meshes = if params.function.as_torch_op().is_some() {
            HashMap::new()
        } else {
            self.device_meshes
                .iter()
                .map(|(k, v)| (k.clone(), v.0.clone()))
                .collect()
        };

        let mut remote_process_groups = HashMap::new();
        for remote_process_group_ref in &params.remote_process_groups {
            if let Some(state) = self.remote_process_groups.get(remote_process_group_ref) {
                let dims_vec = state.dims.iter().cloned().collect();
                let (device_mesh, _) = self
                    .device_meshes
                    .get(&state.device_mesh_ref)
                    .ok_or_else(|| {
                        anyhow::anyhow!("invalid device mesh id: {:#?}", state.device_mesh_ref)
                    })?
                    .clone();
                let comm = state
                    .comms
                    .get(&params.stream)
                    .ok_or_else(|| {
                        anyhow::anyhow!("no comm found for remote process group {remote_process_group_ref:#?} stream {stream:#?}")
                    })?
                    .clone();
                remote_process_groups.insert(
                    remote_process_group_ref.clone(),
                    (device_mesh, dims_vec, comm),
                );
            }
        }

        stream
            .call_function(cx, params, device_meshes, remote_process_groups)
            .await?;

        Ok(())
    }

    async fn command_group(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: Vec<WorkerMessage>,
    ) -> Result<()> {
        for msg in params {
            WorkerMessageHandler::handle(self, cx, msg).await?;
        }
        Ok(())
    }

    async fn create_stream(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: StreamRef,
        creation_mode: StreamCreationMode,
    ) -> Result<()> {
        let handle: ActorHandle<StreamActor> = StreamActor::spawn(
            cx,
            StreamParams {
                world_size: self.world_size,
                rank: self.rank,
                creation_mode,
                id: result,
                device: self.device,
                controller_actor: self.controller_actor.clone(),
                respond_with_python_message: self.respond_with_python_message,
            },
        )
        .await?;
        self.streams.insert(result, Arc::new(handle));
        Ok(())
    }

    async fn create_device_mesh(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        result: Ref,
        names: Vec<String>,
        ranks: Slice,
    ) -> Result<()> {
        self.device_meshes.insert(
            result,
            (DeviceMesh::new(names, ranks, self.rank)?, HashMap::new()),
        );
        Ok(())
    }

    async fn create_remote_process_group(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        result: Ref,
        device_mesh: Ref,
        dims: Vec<String>,
    ) -> Result<()> {
        self.device_meshes
            .get(&device_mesh)
            .with_context(|| format!("invalid device mesh id: {:#?}", device_mesh))?;
        match self.remote_process_groups.entry(result) {
            Entry::Vacant(ent) => ent.insert(RemoteProcessGroupState::new(
                device_mesh,
                SortedVec::from_unsorted(dims),
            )),
            Entry::Occupied(ent) => bail!("remote process group {:?} already created", ent.key()),
        };
        Ok(())
    }

    async fn borrow_create(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        borrow_id: u64,
        tensor_ref: Ref,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, from_stream).await?;
        self.maybe_add_stream_to_recording(cx, to_stream).await?;
        let from_stream = self.try_get_stream(from_stream)?.clone();
        let to_stream = self.try_get_stream(to_stream)?.clone();

        let borrow =
            Borrow::create(cx, borrow_id, tensor_ref, result, from_stream, to_stream).await?;
        self.borrows.insert(borrow_id, borrow);
        Ok(())
    }

    async fn borrow_first_use(
        &mut self,
        cx: &hyperactor::Context<Self>,
        borrow: u64,
    ) -> Result<()> {
        let borrow = self
            .borrows
            .get_mut(&borrow)
            .ok_or_else(|| anyhow!("invalid borrow id: {:#?}", borrow))?;

        borrow.first_use(cx).await?;
        Ok(())
    }

    async fn borrow_last_use(&mut self, cx: &hyperactor::Context<Self>, borrow: u64) -> Result<()> {
        let borrow = self
            .borrows
            .get_mut(&borrow)
            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow))?;

        borrow.last_use(cx).await?;
        Ok(())
    }

    async fn borrow_drop(&mut self, cx: &hyperactor::Context<Self>, borrow_id: u64) -> Result<()> {
        let borrow = self
            .borrows
            .get_mut(&borrow_id)
            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow_id))?;

        borrow.drop(cx).await?;
        self.borrows.remove(&borrow_id);
        Ok(())
    }

    async fn delete_refs(&mut self, cx: &hyperactor::Context<Self>, refs: Vec<Ref>) -> Result<()> {
        let _: Vec<()> = try_join_all(
            self.streams
                .values()
                .map(|s| s.delete_refs(cx, refs.clone())),
        )
        .await?;
        Ok(())
    }

    async fn request_status(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        controller: bool,
    ) -> Result<()> {
        let _: Vec<()> = try_join_all(
            self.streams
                .values()
                .map(|stream| stream.request_status(cx)),
        )
        .await?;

        ControllerMessageClient::status(
            &self.controller_actor,
            cx,
            seq.next(),
            cx.self_id().clone(),
            controller,
        )
        .await?;
        Ok(())
    }

    async fn reduce(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        local_tensor: Ref,
        factory: Factory,
        source_mesh: Ref,
        stream_ref: StreamRef,
        dims: Vec<String>,
        reduction: Reduction,
        scatter: bool,
        in_place: bool,
        out: Option<Ref>,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

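        // Dims are sorted here so the (stream, dims) lookup key matches the
        // canonical sorted form under which `split_comm` registered the comm.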
        let dims = SortedVec::from_unsorted(dims);
        let stream = self.try_get_stream(stream_ref)?.clone();

        let (_, comm_map) = self
            .device_meshes
            .get_mut(&source_mesh)
            .ok_or_else(|| anyhow::anyhow!("invalid device mesh id: {:#?}", source_mesh))?;

        let (size, comm) = comm_map
            .get(&(stream_ref, dims.clone()))
            .ok_or_else(|| anyhow::anyhow!("no comm found for stream {stream:#?}, dims {dims:#?}"))?
            .clone();

        stream
            .reduce(
                cx,
                comm,
                size.try_into()?,
                result,
                local_tensor,
                factory,
                reduction,
                scatter,
                in_place,
                out,
            )
            .await?;

        Ok(())
    }

    async fn create_pipe(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        _key: String,
        function: ResolvableFunction,
        max_messages: i64,
        device_mesh: Ref,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
    ) -> Result<()> {
        tracing::debug!("creating pipe {}", result);
        let args: Vec<PyTree<RValue>> = args
            .into_iter()
            .map(|object| RValue::PyObject(object.into_py_object().unwrap()).into())
            .collect();
        let kwargs: HashMap<_, PyTree<RValue>> = kwargs
            .into_iter()
            .map(|(k, object)| (k, RValue::PyObject(object.into_py_object().unwrap()).into()))
            .collect();
        let device_mesh = self.device_meshes.get(&device_mesh).ok_or_else(|| {
            CallFunctionError::Error(anyhow::anyhow!("ref not found: {}", device_mesh))
        })?;
        tracing::debug!("spawning pipe actor {}", result);
        let pipe = PipeActor::spawn(
            cx,
            PipeParams {
                function,
                max_messages,
                ranks: device_mesh.0.ranks(),
                sizes: device_mesh.0.sizes(),
                args,
                kwargs,
            },
        )
        .await?;
        tracing::debug!("created pipe {}", result);

        self.pipes.insert(result, pipe);
        Ok(())
    }

    async fn send_tensor(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        let comm = self
            .send_recv_comms
            .get(&(from_stream, to_stream))
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "could not find stream to stream comm for: {:#?}",
                    (from_stream, to_stream)
                )
            })?
            .clone();

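        // Map this rank through the two slices: if we appear at index i in
        // from_ranks we send to to_ranks[i], and if we appear at index i in
        // to_ranks we receive from from_ranks[i]. A None on either side means
        // this rank does not participate in that direction.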
        let to_rank = from_ranks
            .index(self.rank)
            .map(|index| to_ranks.get(index).ok())
            .ok()
            .flatten();
        let from_rank = to_ranks
            .index(self.rank)
            .map(|index| from_ranks.get(index).ok())
            .ok()
            .flatten();

        let (stream, stream_ref) = if to_rank.is_none() {
            (self.try_get_stream(to_stream)?.clone(), to_stream)
        } else if from_rank.is_none() || from_stream == to_stream {
            (self.try_get_stream(from_stream)?.clone(), from_stream)
        } else {
            unimplemented!(
                "We haven't implemented to_mesh between streams if a rank participates as both a sender and a receiver. \
                It is possible, but would require the recv stream to send the output buffer tensor to the send stream and sync. \
                Then the send stream would do the nccl op, and then sync with the sending stream again."
            );
        };

        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        stream
            .send_tensor(cx, result, from_rank, to_rank, tensor, factory, comm)
            .await?;

        Ok(())
    }

    async fn exit(
        &mut self,
        cx: &hyperactor::Context<Self>,
        error: Option<(Option<ActorId>, String)>,
    ) -> Result<()> {
        for (_, stream) in self.streams.drain() {
            stream.drain_and_stop()?;
            Arc::into_inner(stream)
                .expect("there should be no owners of this stream handle except the worker stream table")
                .await;
        }

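        // Exit codes for self- vs. peer-failure are overridable via
        // environment variables; both default to 1.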
        let self_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_SELF_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);
        let peer_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_PEER_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);

        let exit_code = match error {
            Some((Some(actor_id), reason)) => {
                tracing::error!(
                    "stopping the worker, actor {} failed with error: {}",
                    actor_id,
                    reason
                );
                if *cx.self_id() == actor_id {
                    self_error_exit_code
                } else {
                    peer_error_exit_code
                }
            }
            Some((None, reason)) => {
                tracing::error!("stopping the worker, reason: {}", reason);
                1
            }
            None => 0,
        };

        if exit_code != 0 {
            tracing::info!("stopping the worker process, exit code: {}", exit_code);
            std::process::exit(exit_code);
        }
        cx.stop()?;
        Ok(())
    }

    async fn send_value(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        destination: Option<Ref>,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
        stream: StreamRef,
    ) -> Result<()> {
        let stream = self.try_get_stream(stream)?;

        let device_meshes = if function.as_ref().is_none_or(|f| f.as_torch_op().is_some()) {
            HashMap::new()
        } else {
            self.device_meshes
                .iter()
                .map(|(k, v)| (k.clone(), v.0.clone()))
                .collect()
        };

        let pipe = if let Some(destination) = destination {
            let pipe = self
                .pipes
                .get(&destination)
                .ok_or_else(|| anyhow::anyhow!("invalid pipe id: {:#?}", destination))?
                .port();
            Some(pipe)
        } else {
            None
        };
        stream
            .send_value(
                cx,
                seq,
                cx.self_id().clone(),
                mutates,
                function,
                args,
                kwargs,
                device_meshes,
                pipe,
            )
            .await
    }

    async fn send_result_of_actor_call(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: ActorCallParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.stream)?;
        stream
            .send_result_of_actor_call(cx, cx.self_id().clone(), params)
            .await?;
        Ok(())
    }

    async fn call_actor_method(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: ActorMethodParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.call.stream)?;
        stream.call_actor_method(cx, params).await?;
        Ok(())
    }

    async fn split_comm(
        &mut self,
        cx: &hyperactor::Context<Self>,
        dims: Vec<String>,
        device_mesh: Ref,
        stream_ref: StreamRef,
        config: Option<NcclConfig>,
    ) -> Result<()> {
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call SplitComm before BackendNetworkInit")?;
        match self.device_meshes.get_mut(&device_mesh) {
            Some((device_mesh, comm_map)) => {
                let stream = self
                    .streams
                    .get(&stream_ref)
                    .ok_or_else(|| anyhow::anyhow!("invalid stream id: {:#?}", stream_ref))?;

                let dims = SortedVec::from_unsorted(dims);

                anyhow::ensure!(
                    !comm_map.contains_key(&(stream_ref, dims.clone())),
                    "comm already exists for stream {stream:#?}, dims {dims:#?}"
                );
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&dims)?;
                let size = ranks_for_group.len();
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                        config,
                    )
                    .await?
                    .context("split comm should include self rank")?;
                comm_map.insert((stream_ref, dims), (size, Arc::new(split_comm)));
            }
            None => {
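                // This rank is not part of the mesh being split, but comm
                // splits are collective over the parent communicator, so it
                // must still participate with an empty rank list.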
                global_comm.split_from(cx, vec![], config).await?;
            }
        }
        Ok(())
    }

    async fn split_comm_for_process_group(
        &mut self,
        cx: &hyperactor::Context<Self>,
        remote_process_group_ref: Ref,
        stream_ref: StreamRef,
        config: Option<NcclConfig>,
    ) -> Result<()> {
        ensure!(
            self.streams.contains_key(&stream_ref),
            "invalid stream id: {:#?}",
            stream_ref
        );
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call SplitCommForProcessGroup before BackendNetworkInit")?;
        let state = self
            .remote_process_groups
            .get_mut(&remote_process_group_ref)
            .with_context(|| {
                format!(
                    "invalid remote process group id: {:#?}",
                    remote_process_group_ref
                )
            })?;
        match self.device_meshes.get_mut(&state.device_mesh_ref) {
            Some((device_mesh, _)) => {
                let entry = match state.comms.entry(stream_ref) {
                    Entry::Vacant(entry) => entry,
                    Entry::Occupied(_) => bail!(
                        "comm already exists for remote process group {:#?} on stream {:#?}",
                        remote_process_group_ref,
                        stream_ref,
                    ),
                };
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&state.dims)?;
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                        config,
                    )
                    .await?
                    .context("split comm should include self rank")?;
                entry.insert(Arc::new(split_comm));
            }
            None => {
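                // As in `split_comm`: ranks outside the mesh still have to
                // join the collective split, with an empty rank list.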
                global_comm.split_from(cx, vec![], config).await?;
            }
        }
        Ok(())
    }

    async fn pipe_recv(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        results: Vec<Option<Ref>>,
        pipe: Ref,
        stream: StreamRef,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, stream).await?;

        let pipe = self
            .pipes
            .get(&pipe)
            .ok_or_else(|| anyhow::anyhow!("ref not found: {}", pipe))?;
        let pipe = pipe.port();
        let stream = self.try_get_stream(stream)?;
        stream.set_value(cx, seq, results, pipe).await
    }

    async fn set_ref_unit_tests_only(
        &mut self,
        cx: &hyperactor::Context<Self>,
        reference: Ref,
        value: WireValue,
        stream: StreamRef,
    ) -> Result<()> {
        let stream = self.try_get_stream(stream)?;

        stream.set_ref_unit_tests_only(cx, reference, value).await
    }

    async fn get_ref_unit_tests_only(
        &mut self,
        cx: &hyperactor::Context<Self>,
        ref_id: Ref,
        stream: StreamRef,
    ) -> Result<Option<Result<WireValue, String>>> {
        let stream = self.try_get_stream(stream)?;
        Ok(stream
            .get_ref_unit_tests_only(cx, ref_id.clone())
            .await?
            .map(|o| Ok(o?)))
    }

    async fn define_recording(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        _nresults: usize,
        _nformals: usize,
        commands: Vec<WorkerMessage>,
        ntotal_messages: usize,
        index: usize,
    ) -> Result<()> {
        if self.defining_recording.is_some() && self.defining_recording.unwrap() != result {
            bail!("already defining a different recording");
        }
        self.defining_recording = Some(result);

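        // A recording's command list can be too large for one message, so it
        // arrives as `ntotal_messages` chunks tagged with a running `index`.
        // Accumulate chunks in a `PartialRecording`; once the last chunk
        // lands, replay the commands and finalize on every involved stream.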
        match self.recordings.entry(result) {
            Entry::Vacant(entry) => {
                ensure!(
                    index == 0,
                    "got DefineRecording message with (index = {:?}) > 0 for previously unseen recording",
                    index
                );
                entry.insert(Recording::PartialRecording {
                    last_index: 0,
                    commands,
                });
            }
            Entry::Occupied(mut entry) => match entry.get_mut() {
                Recording::CompleteRecording { .. } => {
                    bail!("got DefineRecording message for already complete recording")
                }
                Recording::PartialRecording {
                    last_index,
                    commands: existing_commands,
                } => {
                    ensure!(
                        index == *last_index + 1,
                        "got DefineRecording message with index = {:?}, but \
                        last seen index for recording is {:?}",
                        index,
                        last_index
                    );
                    *last_index = index;
                    existing_commands.extend(commands.into_iter());
                }
            },
        };

        if index < ntotal_messages - 1 {
            return Ok(());
        }
        let commands = match self.recordings.remove(&result).unwrap() {
            Recording::CompleteRecording { .. } => panic!("unreachable, in theory"),
            Recording::PartialRecording { commands, .. } => {
                self.recordings.insert(
                    result,
                    Recording::CompleteRecording {
                        streams: HashSet::new(),
                    },
                );
                commands
            }
        };

        for command in commands {
            WorkerMessageHandler::handle(self, cx, command).await?;
        }

        match self.recordings.get(&result).unwrap() {
            Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
            Recording::CompleteRecording { streams, .. } => {
                for stream in streams {
                    self.try_get_stream(*stream)?
                        .finalize_recording(cx, result)
                        .await?;
                }
            }
        }

        self.defining_recording = None;
        Ok(())
    }

    async fn recording_formal(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        argument_index: usize,
        stream: StreamRef,
    ) -> Result<()> {
        ensure!(self.defining_recording.is_some());
        self.maybe_add_stream_to_recording(cx, stream).await?;
        self.try_get_stream(stream)?
            .recording_formal(cx, result, argument_index)
            .await
    }

    async fn recording_result(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        output_index: usize,
        stream: StreamRef,
    ) -> Result<()> {
        ensure!(self.defining_recording.is_some());
        self.maybe_add_stream_to_recording(cx, stream).await?;
        self.try_get_stream(stream)?
            .recording_result(cx, result, output_index)
            .await
    }

    async fn call_recording(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        recording: Ref,
        results: Vec<Ref>,
        actuals: Vec<Ref>,
    ) -> Result<()> {
        ensure!(self.defining_recording.is_none());
        let recording_ref = recording;
        let recording = self
            .recordings
            .get(&recording)
            .ok_or_else(|| anyhow::anyhow!("could not find recording: {:#?}", recording))?;
        match recording {
            Recording::PartialRecording { .. } => {
                bail!("cannot call recording because it is incomplete")
            }
            Recording::CompleteRecording { streams } => try_join_all(
                streams
                    .iter()
                    .map(|stream| self.try_get_stream(*stream))
                    .collect::<Result<Vec<_>>>()?
                    .into_iter()
                    .map(|stream| {
                        stream.call_recording(
                            cx,
                            seq,
                            recording_ref,
                            results.clone(),
                            actuals.clone(),
                        )
                    }),
            )
            .await
            .map(|_| ()),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::process::Stdio;
    use std::time::Duration;

    use anyhow::Result;
    use hyperactor::Mailbox;
    use hyperactor::Named;
    use hyperactor::WorldId;
    use hyperactor::actor::ActorStatus;
    use hyperactor::channel::ChannelAddr;
    use hyperactor::id;
    use hyperactor::mailbox::open_port;
    use hyperactor::proc::Proc;
    use hyperactor_multiprocess::System;
    use hyperactor_multiprocess::proc_actor::Environment;
    use hyperactor_multiprocess::proc_actor::ProcActor;
    use hyperactor_multiprocess::proc_actor::ProcMessageClient;
    use hyperactor_multiprocess::system_actor::SYSTEM_ACTOR_REF;
    use hyperactor_multiprocess::system_actor::Shape;
    use hyperactor_multiprocess::system_actor::SystemMessageClient;
    use hyperactor_multiprocess::system_actor::SystemSnapshotFilter;
    use hyperactor_multiprocess::system_actor::WorldStatus;
    use monarch_messages::controller::ControllerMessage;
    use monarch_messages::controller::WorkerError;
    use monarch_messages::worker::WorkerMessageClient;
    use monarch_types::PickledPyObject;
    use monarch_types::PyTree;
    use pyo3::IntoPyObjectExt;
    use pyo3::Python;
    use pyo3::prelude::*;
    use pyo3::types::PyList;
    use pyo3::types::PyString;
    use rand::Rng;
    use rand::distributions::Alphanumeric;
    use timed_test::async_timed_test;
    use tokio::io::BufReader;
    use tokio::process::Command;
    use tokio_retry::Retry;
    use tokio_retry::strategy::FixedInterval;
    use torch_sys::Device;
    use torch_sys::DeviceIndex;
    use torch_sys::MemoryFormat;

    use super::*;
    use crate::test_util::test_setup;

    #[async_timed_test(timeout_secs = 60)]
    async fn basic_worker() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(Ref { id: 2 })],
                        mutates: vec![0.into()],
                        function: "torch.ops.aten.sub_.Scalar".into(),
                        args: vec![WireValue::Ref(0.into()), WireValue::Int(1)],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 3.into(),
                        results: vec![Some(Ref { id: 3 })],
                        mutates: vec![],
                        function: "torch.ops.aten.zeros.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 4.into(),
                        results: vec![Some(Ref { id: 4 })],
                        mutates: vec![],
                        function: "torch.ops.aten.allclose.default".into(),
                        args: vec![WireValue::Ref(0.into()), WireValue::Ref(Ref { id: 3 })],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result: bool = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );
        assert!(result);

        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn error_sends_response() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.rand.default".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let response_message = controller_rx.recv().await.unwrap();
        match response_message {
            ControllerMessage::RemoteFunctionFailed {
                seq,
                error: WorkerError { backtrace: msg, .. },
            } => {
                assert_eq!(seq, 0.into());
                assert!(msg.contains("aten::rand() is missing value for argument 'size'"))
            }
            _ => panic!("unexpected response {:#?}", response_message),
        }

        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn mutated_refs_are_updated_with_error() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: 0.into(),
                        value: WireValue::Int(1),
                        stream: 1.into(),
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(Ref { id: 2 })],
                        mutates: vec![0.into()],
                        function: "i.dont.exist".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result = worker_handle
            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
            .await?;

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        let mutated_ref = result
            .context("no such ref")?
            .err()
            .context("expected error")?;
        assert!(mutated_ref.contains("failed to resolve function"));

        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );
        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn accessing_errored_dependency() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "i.dont.exist".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 1.into(),
                        results: vec![Some(1.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.sub_.Scalar".into(),
                        args: vec![WireValue::Ref(0.into())],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );

        match &responses[0] {
            ControllerMessage::RemoteFunctionFailed { seq, .. } => {
                assert_eq!(seq, &0.into())
            }
            _ => panic!("unexpected response {:#?}", responses[0]),
        };
        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn py_remote_function_calls() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        let (split_arg, sort_list, mesh_ref, dim, layout, none, scalar, device, memory_format) =
            Python::with_gil(|py| {
                let split_arg: PickledPyObject = PyString::new(py, "/fbs/fbc/foo/bar")
                    .into_any()
                    .try_into()?;
                let sort_list: PickledPyObject =
                    PyList::new(py, [65, 34, 79, 1, 5])?.into_any().try_into()?;
                let mesh_ref: PickledPyObject = Ref { id: 5 }.into_bound_py_any(py)?.try_into()?;
                let dim: PickledPyObject = PyString::new(py, "x").into_any().try_into()?;
                let layout: PickledPyObject = py.import("torch")?.getattr("strided")?.try_into()?;
                let none: PickledPyObject = py.None().into_any().into_bound(py).try_into()?;
                let scalar: PickledPyObject = py.import("torch")?.getattr("float32")?.try_into()?;
                let device: PickledPyObject = py
                    .import("torch")?
                    .getattr("device")?
                    .call1(("cuda:1",))?
                    .try_into()?;
                let memory_format: PickledPyObject = py
                    .import("torch")?
                    .getattr("contiguous_format")?
                    .try_into()?;
                PyResult::Ok((
                    split_arg,
                    sort_list,
                    mesh_ref,
                    dim,
                    layout,
                    none,
                    scalar,
                    device,
                    memory_format,
                ))
            })?;

        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into()), Some(Ref { id: 2 })],
                        mutates: vec![],
                        function: "os.path.split".into(),
                        args: vec![split_arg.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(4.into()), None, None, None, None],
                        mutates: vec![],
                        function: "builtins.sorted".into(),
                        args: vec![sort_list.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CreateDeviceMesh {
                        result: 5.into(),
                        names: vec!["x".into()],
                        ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(6.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.mesh_rank".into(),
                        args: vec![mesh_ref.into(), dim.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 4.into(),
                        results: vec![Some(7.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_scalar_type"
                            .into(),
                        args: vec![scalar.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 5.into(),
                        results: vec![Some(8.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_layout".into(),
                        args: vec![layout.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 6.into(),
                        results: vec![Some(9.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_none".into(),
                        args: vec![none.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 7.into(),
                        results: vec![None],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.none".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 8.into(),
                        results: vec![Some(10.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_device".into(),
                        args: vec![device.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 9.into(),
                        results: vec![Some(11.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_memory_format"
                            .into(),
                        args: vec![memory_format.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 10.into(),
                        results: vec![Some(12.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 11.into(),
                        results: vec![Some(13.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.stack.default".into(),
                        args: vec![WireValue::RefList(vec![12.into(), 12.into()])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result1: String = worker_handle
            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result2: String = worker_handle
            .get_ref_unit_tests_only(&client, 2.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result3: i64 = worker_handle
            .get_ref_unit_tests_only(&client, 4.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result4: i64 = worker_handle
            .get_ref_unit_tests_only(&client, 6.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        assert_eq!(
            ScalarType::Float,
            worker_handle
                .get_ref_unit_tests_only(&client, 7.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap()
                .try_into()
                .unwrap()
        );
        assert_eq!(
            Layout::Strided,
            worker_handle
                .get_ref_unit_tests_only(&client, 8.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap()
                .try_into()
                .unwrap()
        );
        assert_matches!(
            worker_handle
                .get_ref_unit_tests_only(&client, 9.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap(),
            WireValue::None(()),
        );
        let device: Device = CudaDevice::new(DeviceIndex(1)).into();
        assert_eq!(
            device,
            worker_handle
                .get_ref_unit_tests_only(&client, 10.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap()
                .try_into()
                .unwrap()
        );
        assert_matches!(
            worker_handle
                .get_ref_unit_tests_only(&client, 11.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap(),
            WireValue::MemoryFormat(MemoryFormat::Contiguous),
        );

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );

        assert_eq!(result1, "/fbs/fbc/foo");
        assert_eq!(result2, "bar");
        assert_eq!(result3, 1);
        assert_eq!(result4, 0);

        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn delete_refs() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 2 },
                        value: WireValue::Bool(false),
                        stream: 0.into(),
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 3 },
                        value: WireValue::Bool(true),
                        stream: 0.into(),
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 4 },
                        value: WireValue::Int(0),
                        stream: 1.into(),
                    },
                    WorkerMessage::DeleteRefs(vec![Ref { id: 2 }, Ref { id: 4 }]),
                ],
            )
            .await
            .unwrap();

        let result: bool = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 3 }, 0.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let fail_result = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        assert!(result, "should be able to get a non-deleted ref");
        assert!(fail_result.is_none(), "should fail to get a deleted ref");

        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn request_status() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                ],
            )
            .await
            .unwrap();

        for i in 0..100 {
            worker_handle
                .call_function(
                    &client,
                    CallFunctionParams {
                        seq: i.into(),
                        results: vec![Some(Ref { id: i + 2 })],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: (i % 2).into(),
                        remote_process_groups: vec![],
                    },
                )
                .await
                .unwrap();
        }

        worker_handle
            .request_status(&client, 100.into(), false)
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );

        let response = responses.pop().unwrap();
        match response {
            ControllerMessage::Status { seq, .. } => {
                assert_eq!(seq, 101.into())
            }
            _ => panic!("unexpected response {:#?}", response),
        };

        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn backend_network_init() {
        let proc = Proc::local();
        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();

        let worker_handle1 = proc
            .spawn::<WorkerActor>(
                "worker0",
                WorkerParams {
                    world_size: 2,
                    rank: 0,
                    device_index: Some(0),
                    controller_actor: controller_ref.clone(),
                },
            )
            .await
            .unwrap();
        let worker_handle2 = proc
            .spawn::<WorkerActor>(
                "worker1",
                WorkerParams {
                    world_size: 2,
                    rank: 1,
                    device_index: Some(1),
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();

        let unique_id = UniqueId::new().unwrap();
        worker_handle1
            .backend_network_init(&client, unique_id.clone())
            .await
            .unwrap();
        worker_handle2
            .backend_network_init(&client, unique_id)
            .await
            .unwrap();

        worker_handle1.drain_and_stop().unwrap();
        worker_handle1.await;
        worker_handle2.drain_and_stop().unwrap();
        worker_handle2.await;
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn send_value() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::SendValue {
                        seq: 1.into(),
                        destination: None,
                        mutates: vec![],
                        function: None,
                        args: vec![WireValue::Ref(0.into())],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::SendValue {
                        seq: 2.into(),
                        destination: None,
                        mutates: vec![],
                        function: Some("torch.ops.aten.var_mean.default".into()),
                        args: vec![WireValue::Ref(0.into())],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop()?;
        assert_matches!(worker_handle.await, ActorStatus::Stopped);

        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            3,
            "Expected three responses, got: {:#?}",
            responses
        );

        match responses.pop().unwrap() {
            ControllerMessage::FetchResult { seq, value } => {
                assert_eq!(seq, 2.into());
                let value = value.unwrap().deserialized::<PyTree<RValue>>().unwrap();
                assert_eq!(value.leaves().len(), 2);
            }
            resp => panic!("unexpected response {:#?}", resp),
        };
        match responses.pop().unwrap() {
            ControllerMessage::FetchResult { seq, .. } => {
                assert_eq!(seq, 1.into())
            }
            resp => panic!("unexpected response {:#?}", resp),
        };
        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn send_value_err_result() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();

        let ref_arg: PickledPyObject =
            Python::with_gil(|py| Ref { id: 2 }.into_bound_py_any(py)?.try_into())?;

        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 2 },
                        value: WireValue::Bool(false),
                        stream: 1.into(),
                    },
                    WorkerMessage::SendValue {
                        seq: 1.into(),
                        destination: None,
                        mutates: vec![Ref { id: 2 }],
                        function: Some("non.existent.function".into()),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::SendValue {
                        seq: 2.into(),
                        destination: None,
                        mutates: vec![],
                        function: None,
                        args: vec![ref_arg.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop()?;
        assert_matches!(worker_handle.await, ActorStatus::Stopped);

        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            3,
            "Expected three responses, got: {:#?}",
            responses
        );
        match responses.pop() {
            Some(ControllerMessage::FetchResult { seq, value }) => {
                assert_eq!(seq, 2.into());
                assert!(value.is_err());
                assert!(
                    value
                        .unwrap_err()
                        .backtrace
                        .contains("failed to resolve function")
                );
            }
            _ => panic!("unexpected response {:#?}", responses),
        }
        match responses.pop() {
            Some(ControllerMessage::FetchResult { seq, value }) => {
                assert_eq!(seq, 1.into());
                assert!(value.is_err());
                assert!(
                    value
                        .unwrap_err()
                        .backtrace
                        .contains("failed to resolve function")
                );
            }
            _ => panic!("unexpected response {:#?}", responses),
        }
        Ok(())
    }

2200 #[async_timed_test(timeout_secs = 60)]
2201 async fn pipe_send_recv() -> Result<()> {
2202 test_setup()?;
2203
2204 let proc = Proc::local();
2205 let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
2206
2207 let handle = proc
2208 .spawn::<WorkerActor>(
2209 "worker",
2210 WorkerParams {
2211 world_size: 1,
2212 rank: 0,
2213 device_index: None,
2214 controller_actor: controller_ref,
2215 },
2216 )
2217 .await
2218 .unwrap();
2219 let (resolve_value_arg, torch_eq_arg1, torch_eq_arg2): (
2220 PickledPyObject,
2221 PickledPyObject,
2222 PickledPyObject,
2223 ) = Python::with_gil(|py| {
2224 PyResult::Ok((
2225 PyList::new(py, [2, 3])?.into_any().try_into()?,
2226 Ref { id: 2 }.into_bound_py_any(py)?.try_into()?,
2227 Ref { id: 4 }.into_bound_py_any(py)?.try_into()?,
2228 ))
2229 })?;
2230
        handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CreateDeviceMesh {
                        result: 1.into(),
                        names: vec!["x".into()],
                        ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(2.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 0.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CreatePipe {
                        result: 3.into(),
                        key: "unused".into(),
                        function: "monarch.monarch_tensor_worker.test_utils.handler".into(),
                        max_messages: 1,
                        mesh: 1.into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                    },
                    WorkerMessage::SendValue {
                        seq: 1.into(),
                        destination: Some(3.into()),
                        mutates: vec![],
                        function: Some(
                            "monarch.monarch_tensor_worker.test_utils.resolve_value".into(),
                        ),
                        args: vec![resolve_value_arg.into()],
                        kwargs: HashMap::new(),
                        stream: 0.into(),
                    },
                    WorkerMessage::PipeRecv {
                        seq: 2.into(),
                        results: vec![Some(4.into())],
                        pipe: 3.into(),
                        stream: 0.into(),
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(5.into())],
                        mutates: vec![],
                        function: "torch.equal".into(),
                        args: vec![torch_eq_arg1.into(), torch_eq_arg2.into()],
                        kwargs: HashMap::new(),
                        stream: 0.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let matches: bool = handle
            .get_ref_unit_tests_only(&client, 5.into(), 0.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        assert!(matches);

        handle.drain_and_stop()?;
        assert_matches!(handle.await, ActorStatus::Stopped);

        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            0,
            "Expected no responses, got: {:#?}",
            responses
        );

        Ok(())
    }

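    // Builds a channel address on an abstract unix-domain socket
    // ("unix!@<random>") with a random 24-character name, so concurrent
    // tests cannot collide on a fixed path.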
    fn get_random_channel_addr() -> ChannelAddr {
        let random_string = rand::thread_rng()
            .sample_iter(&Alphanumeric)
            .take(24)
            .map(char::from)
            .collect::<String>();
        format!("unix!@{random_string}").parse().unwrap()
    }

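    // Polls the system actor's snapshot until the given world reports
    // WorldStatus::Live, retrying once a second for up to 100 attempts.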
    async fn ensure_world_ready(client: Mailbox, world: WorldId) -> Result<()> {
        tracing::info!("checking whether world {world} is ready");
        let retry_strategy = FixedInterval::from_millis(1000).take(100);
        Retry::spawn(retry_strategy, async || {
            let snapshot = SYSTEM_ACTOR_REF
                .snapshot(&client, SystemSnapshotFilter::default())
                .await?;
            let world_snapshot = snapshot.worlds.get(&world).ok_or(anyhow!("no world"))?;
            tracing::info!("world status: {:?}", world_snapshot.status);
            match world_snapshot.status {
                WorldStatus::Live => Ok(()),
                _ => Err(anyhow!("world is not live")),
            }
        })
        .await?;
        Ok(())
    }

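    // End-to-end, multi-process test: boots a system, launches two worker
    // processes from the MONARCH_TENSOR_WORKER_EXE binary, spawns a
    // WorkerActor in each, and drives both through splitting an NCCL
    // communicator for a remote process group.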
    #[async_timed_test(timeout_secs = 60)]
    async fn remote_process_group() -> Result<()> {
        test_setup()?;

        let timeout: Duration = Duration::from_secs(10);
        let system_addr = get_random_channel_addr();
        let _system_handle = System::serve(system_addr.clone(), timeout, timeout).await?;

        let client = System::new(system_addr.clone()).attach().await?;
        let (handle, mut controller_rx) = client.open_port::<ControllerMessage>();
        handle.bind_to(ControllerMessage::port());
        let controller_ref: ActorRef<ControllerActor> = ActorRef::attest(client.actor_id().clone());

        let world_size = 2;
        SYSTEM_ACTOR_REF
            .upsert_world(
                &client,
                id!(world),
                Shape::Definite(vec![world_size]),
                4,
                Environment::Local,
                HashMap::new(),
            )
            .await?;

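        // Launch one worker process per rank from the prebuilt worker binary,
        // pointing each at the system's bootstrap address; forward each
        // child's stdout to stderr so its output stays visible in test logs.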
        let mut worker_process_handles = vec![];
        let mut worker_procs: Vec<ActorRef<ProcActor>> = vec![];
        for rank in 0..world_size {
            let world_id = "world".to_string();
            let proc_id = format!("{world_id}[{rank}]");
            worker_procs.push(ActorRef::attest(format!("world[{rank}].proc[0]").parse()?));

            let mut handle =
                Command::new(std::env::var("MONARCH_TENSOR_WORKER_EXE").map_err(|e| {
                    anyhow::anyhow!("could not get var MONARCH_TENSOR_WORKER_EXE: {}", e)
                })?)
                .arg("worker")
                .arg(format!("--bootstrap-addr={system_addr}"))
                .arg(format!("--world-id={world_id}"))
                .arg(format!("--proc-id={proc_id}"))
                .stdout(Stdio::piped())
                .stdin(Stdio::piped())
                .kill_on_drop(true)
                .spawn()?;

            let out = handle.stdout.take().unwrap();
            tokio::spawn(async move {
                let mut reader = BufReader::new(out);
                tokio::io::copy(&mut reader, &mut tokio::io::stderr())
                    .await
                    .unwrap();
            });
            worker_process_handles.push(handle);
        }

        ensure_world_ready(client.clone(), id!(world)).await?;

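        // Spawn a WorkerActor on every proc, wait until each reports back on
        // the spawned port, then address the actors by their well-known names.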
        let (spawned_port, mut spawned_receiver) = open_port(&client);
        for (rank, worker_proc) in worker_procs.iter().enumerate() {
            let params = WorkerParams {
                world_size,
                rank,
                device_index: Some(rank.try_into().unwrap()),
                controller_actor: controller_ref.clone(),
            };
            worker_proc
                .spawn(
                    &client,
                    "monarch_tensor_worker::WorkerActor".to_owned(),
                    "worker".to_owned(),
                    bincode::serialize(&params)?,
                    spawned_port.bind(),
                )
                .await?;
        }
        let mut spawned = vec![];
        while spawned.len() < world_size {
            spawned.push(spawned_receiver.recv().await?);
        }
        tracing::info!("spawned {} worker actors", world_size);
        let workers: Vec<ActorRef<WorkerActor>> = (0..world_size)
            .map(|rank| format!("world[{rank}].worker[0]"))
            .map(|name| ActorRef::attest(name.parse().unwrap()))
            .collect();

        let remote_proc_grp_ref: PickledPyObject =
            Python::with_gil(|py| Ref { id: 2 }.into_bound_py_any(py)?.try_into())?;

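        // Both ranks run the same script: initialize the NCCL backend with a
        // shared UniqueId, build a 2-rank device mesh, register a remote
        // process group over dim "x", split a communicator for it, and call a
        // Python helper that exercises the resulting process group.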
        let unique_id = UniqueId::new()?;
        let messages = vec![
            WorkerMessage::CreateStream {
                id: 0.into(),
                stream_creation: StreamCreationMode::UseDefaultStream,
            },
            WorkerMessage::BackendNetworkInit(unique_id.clone()),
            WorkerMessage::CreateDeviceMesh {
                result: 1.into(),
                names: vec!["x".into()],
                ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
            },
            WorkerMessage::CreateRemoteProcessGroup {
                result: 2.into(),
                device_mesh: 1.into(),
                dims: vec!["x".into()],
            },
            WorkerMessage::SplitCommForProcessGroup {
                remote_process_group: 2.into(),
                stream: 0.into(),
                config: None,
            },
            WorkerMessage::CallFunction(CallFunctionParams {
                seq: 0.into(),
                results: vec![Some(3.into())],
                mutates: vec![],
                function: "monarch.monarch_tensor_worker.test_utils.test_remote_process_group"
                    .into(),
                args: vec![remote_proc_grp_ref.into()],
                kwargs: HashMap::new(),
                stream: 0.into(),
                remote_process_groups: vec![2.into()],
            }),
        ];

        workers[0].command_group(&client, messages.clone()).await?;
        workers[1].command_group(&client, messages).await?;

        let _ = workers[0]
            .get_ref_unit_tests_only(&client, 3.into(), 0.into())
            .await?
            .unwrap()
            .unwrap();

        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );

        Ok(())
    }
}