#![feature(assert_matches)]
#![feature(duration_constructors)]
#![feature(exit_status_error)]
#![allow(unsafe_op_in_unsafe_fn)]

mod borrow;
mod comm;
pub mod device_mesh;
pub mod stream;
pub mod test_util;

use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::hash_map::Entry;
use std::sync::Arc;

use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use anyhow::bail;
use anyhow::ensure;
use async_trait::async_trait;
use borrow::Borrow;
use comm::CommMessageClient;
use comm::CommParams;
use comm::NcclCommActor;
use derive_more::TryInto;
use device_mesh::DeviceMesh;
use futures::future::try_join_all;
use hyperactor::Actor;
use hyperactor::ActorRef;
use hyperactor::Bind;
use hyperactor::Handler;
use hyperactor::Named;
use hyperactor::Unbind;
use hyperactor::actor::ActorHandle;
use hyperactor::context;
use hyperactor::reference::ActorId;
use hyperactor_mesh::comm::multicast::CastInfo;
use itertools::Itertools;
use monarch_hyperactor::shape::PyPoint;
use monarch_messages::controller::ControllerActor;
use monarch_messages::controller::ControllerMessageClient;
use monarch_messages::controller::Seq;
use monarch_messages::wire_value::WireValue;
use monarch_messages::worker::ActorCallParams;
use monarch_messages::worker::ActorMethodParams;
use monarch_messages::worker::CallFunctionParams;
use monarch_messages::worker::Factory;
use monarch_messages::worker::Reduction;
use monarch_messages::worker::Ref;
use monarch_messages::worker::ResolvableFunction;
use monarch_messages::worker::StreamCreationMode;
use monarch_messages::worker::StreamRef;
use monarch_messages::worker::WorkerMessage;
use monarch_messages::worker::WorkerMessageHandler;
use monarch_messages::worker::WorkerParams;
use monarch_types::PyTree;
use ndslice::Slice;
use pyo3::Python;
use pyo3::types::PyAnyMethods;
use serde::Deserialize;
use serde::Serialize;
use sorted_vec::SortedVec;
use stream::StreamActor;
use stream::StreamMessageClient;
use stream::StreamParams;
use torch_sys::CudaDevice;
use torch_sys::DeviceIndex;
use torch_sys::Layout;
use torch_sys::RValue;
use torch_sys::ScalarType;
use torch_sys::TensorCell;
use torch_sys::factory_zeros;
use torch_sys_cuda::nccl::NcclConfig;
use torch_sys_cuda::nccl::ReduceOp;
use torch_sys_cuda::nccl::UniqueId;

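/// Per-process-group state tracked by the worker: the device mesh the group
/// was created on, the (sorted) mesh dimensions it spans, and the NCCL
/// communicators split off for it, keyed by the stream they serve.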
#[derive(Debug)]
struct RemoteProcessGroupState {
    device_mesh_ref: Ref,
    dims: SortedVec<String>,
    comms: HashMap<StreamRef, Arc<ActorHandle<NcclCommActor>>>,
}

impl RemoteProcessGroupState {
    fn new(device_mesh_ref: Ref, dims: SortedVec<String>) -> Self {
        Self {
            device_mesh_ref,
            dims,
            comms: HashMap::new(),
        }
    }
}

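/// A recording arrives as a sequence of `DefineRecording` messages. It stays
/// `PartialRecording` while commands are still being accumulated, and becomes
/// `CompleteRecording` (tracking the set of participating streams) once the
/// final message has been received and replayed.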
#[derive(Debug)]
enum Recording {
    PartialRecording {
        last_index: usize,
        commands: Vec<WorkerMessage>,
    },

    CompleteRecording {
        streams: HashSet<StreamRef>,
    },
}

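/// Actor that owns all per-rank worker state: streams, device meshes,
/// borrows, recordings, and NCCL communicators. It handles `WorkerMessage`s,
/// mostly by dispatching them to the owned `StreamActor`s.
///
/// A minimal sketch of spawning one on a local proc, mirroring the unit
/// tests at the bottom of this file:
///
/// ```ignore
/// let worker = proc
///     .spawn::<WorkerActor>(
///         "worker",
///         WorkerParams {
///             world_size: 1,
///             rank: 0,
///             device_index: None, // CPU-only worker
///             controller_actor: controller_ref,
///         },
///     )
///     .await?;
/// ```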
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        WorkerMessage {cast = true},
        AssignRankMessage {cast = true},
    ],
)]
pub struct WorkerActor {
    device: Option<CudaDevice>,
    streams: HashMap<StreamRef, Arc<ActorHandle<StreamActor>>>,
    device_meshes: HashMap<
        Ref,
        (
            DeviceMesh,
            HashMap<(StreamRef, SortedVec<String>), (usize, Arc<ActorHandle<NcclCommActor>>)>,
        ),
    >,
    world_size: usize,
    rank: usize,
    borrows: HashMap<u64, Borrow>,
    comm: Option<ActorHandle<NcclCommActor>>,
    controller_actor: ActorRef<ControllerActor>,
    remote_process_groups: HashMap<Ref, RemoteProcessGroupState>,
    send_recv_comms: HashMap<(StreamRef, StreamRef), Arc<ActorHandle<NcclCommActor>>>,
    recordings: HashMap<Ref, Recording>,
    defining_recording: Option<Ref>,
    respond_with_python_message: bool,
}

impl WorkerActor {
    fn try_get_stream(&self, stream: StreamRef) -> Result<&Arc<ActorHandle<StreamActor>>> {
        self.streams
            .get(&stream)
            .ok_or(anyhow::anyhow!("invalid stream id: {:#?}", stream))
    }

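    /// If a recording is currently being defined, register `stream` with it
    /// and tell the stream actor to start recording the commands it receives.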
    async fn maybe_add_stream_to_recording(
        &mut self,
        cx: &impl context::Actor,
        stream: StreamRef,
    ) -> Result<()> {
        if let Some(defining_recording) = self.defining_recording {
            let recording = self.recordings.get_mut(&defining_recording).unwrap();
            let fut = match recording {
                Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
                Recording::CompleteRecording { streams } => {
                    streams.insert(stream).then(|| -> Result<_, anyhow::Error> {
                        Ok(self
                            .try_get_stream(stream)?
                            .define_recording(cx, defining_recording))
                    })
                }
            }
            .transpose()?;
            match fut {
                Some(fut) => fut.await,
                None => Ok(()),
            }
        } else {
            Ok(())
        }
    }
}

#[async_trait]
impl Actor for WorkerActor {
    type Params = WorkerParams;

    async fn new(
        WorkerParams {
            world_size,
            rank,
            device_index,
            controller_actor,
        }: Self::Params,
    ) -> Result<Self> {
        Ok(Self {
            device: device_index.map(|i| CudaDevice::new(DeviceIndex(i))),
            streams: HashMap::new(),
            device_meshes: HashMap::new(),
            world_size,
            rank,
            borrows: HashMap::new(),
            comm: None,
            controller_actor,
            remote_process_groups: HashMap::new(),
            send_recv_comms: HashMap::new(),
            recordings: HashMap::new(),
            defining_recording: None,
            respond_with_python_message: false,
        })
    }
}

#[async_trait]
impl Handler<AssignRankMessage> for WorkerActor {
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        _: AssignRankMessage,
    ) -> anyhow::Result<()> {
        let point = cx.cast_point();
        self.rank = point.rank();
        self.respond_with_python_message = true;
        Python::with_gil(|py| {
            let mesh_controller = py.import("monarch.mesh_controller").unwrap();
            let p: PyPoint = point.into();
            mesh_controller
                .call_method1("_initialize_env", (p, cx.proc().proc_id().to_string()))
                .unwrap();
        });
        Ok(())
    }
}

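/// Broadcast to every worker after the mesh is created; tells each worker its
/// rank within the cast and switches it to responding with python messages.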
#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub enum AssignRankMessage {
    AssignRank(),
}

#[async_trait]
impl Handler<WorkerMessage> for WorkerActor {
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        message: WorkerMessage,
    ) -> anyhow::Result<()> {
        <Self as WorkerMessageHandler>::handle(self, cx, message).await
    }
}

#[async_trait]
impl WorkerMessageHandler for WorkerActor {
    async fn backend_network_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        unique_id: UniqueId,
    ) -> Result<()> {
        let device = self
            .device
            .expect("tried to init backend network on a non-CUDA worker");
        let comm = NcclCommActor::spawn(
            cx,
            CommParams::New {
                device,
                unique_id,
                world_size: self.world_size.try_into().unwrap(),
                rank: self.rank.try_into().unwrap(),
            },
        )
        .await?;

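        // Warm up the new communicator with a trivial one-element all-reduce
        // so the global comm is fully established on every rank before any
        // splits are taken from it.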
        let tensor = factory_zeros(&[1], ScalarType::Float, Layout::Strided, device.into());
        let cell = TensorCell::new(tensor);

        comm.all_reduce(
            cx,
            cell,
            ReduceOp::Sum,
            torch_sys_cuda::cuda::Stream::get_current_stream(),
        )
        .await?;

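        // Split a child comm off the global comm for each stream. Iterate the
        // streams in sorted order: comm splits are collective operations, so
        // every rank must perform them in the same sequence.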
        let sorted_streams = self
            .streams
            .iter()
            .sorted_by_key(|(k, _)| *k)
            .map(|(_, v)| v.as_ref());

        let mut splits = Vec::new();
        for _ in 0..sorted_streams.len() {
            splits.push(comm.split_all(cx, None).await?);
        }
        let _: Vec<()> = try_join_all(
            sorted_streams
                .into_iter()
                .zip(splits.into_iter())
                .map(|(stream, split)| stream.init_comm(cx, split)),
        )
        .await?;

        self.comm = Some(comm);

        Ok(())
    }

    async fn backend_network_point_to_point_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        if !self.streams.contains_key(&from_stream) {
            bail!("invalid from_stream id: {:#?}", from_stream);
        }
        if !self.streams.contains_key(&to_stream) {
            bail!("invalid to_stream id: {:#?}", to_stream);
        }
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call BackendNetworkPointToPointInit before BackendNetworkInit")?;
        let comm = global_comm.split_all(cx, None).await?;
        self.send_recv_comms
            .insert((from_stream, to_stream), Arc::new(comm));
        Ok(())
    }

    async fn call_function(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: CallFunctionParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.stream)?.clone();
        self.maybe_add_stream_to_recording(cx, params.stream)
            .await?;

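        // Plain torch ops cannot reference device meshes, so only collect
        // them for python functions.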
        let device_meshes = if params.function.as_torch_op().is_some() {
            HashMap::new()
        } else {
            self.device_meshes
                .iter()
                .map(|(k, v)| (k.clone(), v.0.clone()))
                .collect()
        };

        let mut remote_process_groups = HashMap::new();
        for remote_process_group_ref in &params.remote_process_groups {
            if let Some(state) = self.remote_process_groups.get(remote_process_group_ref) {
                let dims_vec = state.dims.iter().cloned().collect();
                let (device_mesh, _) = self
                    .device_meshes
                    .get(&state.device_mesh_ref)
                    .ok_or_else(|| {
                        anyhow::anyhow!("invalid device mesh id: {:#?}", state.device_mesh_ref)
                    })?
                    .clone();
                let comm = state.comms
                    .get(&params.stream)
                    .ok_or_else(|| {
                        anyhow::anyhow!("no comm found for remote process group {remote_process_group_ref:#?} stream {stream:#?}")
                    })?
                    .clone();
                remote_process_groups.insert(
                    remote_process_group_ref.clone(),
                    (device_mesh, dims_vec, comm),
                );
            }
        }

        stream
            .call_function(cx, params, device_meshes, remote_process_groups)
            .await?;

        Ok(())
    }

    async fn command_group(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: Vec<WorkerMessage>,
    ) -> Result<()> {
        for msg in params {
            WorkerMessageHandler::handle(self, cx, msg).await?;
        }
        Ok(())
    }

    async fn create_stream(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: StreamRef,
        creation_mode: StreamCreationMode,
    ) -> Result<()> {
        let handle: ActorHandle<StreamActor> = StreamActor::spawn(
            cx,
            StreamParams {
                world_size: self.world_size,
                rank: self.rank,
                creation_mode,
                id: result,
                device: self.device,
                controller_actor: self.controller_actor.clone(),
                respond_with_python_message: self.respond_with_python_message,
            },
        )
        .await?;
        self.streams.insert(result, Arc::new(handle));
        Ok(())
    }

    async fn create_device_mesh(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        result: Ref,
        names: Vec<String>,
        ranks: Slice,
    ) -> Result<()> {
        self.device_meshes.insert(
            result,
            (DeviceMesh::new(names, ranks, self.rank)?, HashMap::new()),
        );
        Ok(())
    }

    async fn create_remote_process_group(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        result: Ref,
        device_mesh: Ref,
        dims: Vec<String>,
    ) -> Result<()> {
        self.device_meshes
            .get(&device_mesh)
            .with_context(|| format!("invalid device mesh id: {:#?}", device_mesh))?;
        match self.remote_process_groups.entry(result) {
            Entry::Vacant(ent) => ent.insert(RemoteProcessGroupState::new(
                device_mesh,
                SortedVec::from_unsorted(dims),
            )),
            Entry::Occupied(ent) => bail!("remote process group {:?} already created", ent.key()),
        };
        Ok(())
    }

    async fn borrow_create(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        borrow_id: u64,
        tensor_ref: Ref,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, from_stream).await?;
        self.maybe_add_stream_to_recording(cx, to_stream).await?;
        let from_stream = self.try_get_stream(from_stream)?.clone();
        let to_stream = self.try_get_stream(to_stream)?.clone();

        let borrow =
            Borrow::create(cx, borrow_id, tensor_ref, result, from_stream, to_stream).await?;
        self.borrows.insert(borrow_id, borrow);
        Ok(())
    }

    async fn borrow_first_use(
        &mut self,
        cx: &hyperactor::Context<Self>,
        borrow: u64,
    ) -> Result<()> {
        let borrow = self
            .borrows
            .get_mut(&borrow)
            .ok_or_else(|| anyhow!("invalid borrow id: {:#?}", borrow))?;

        borrow.first_use(cx).await?;
        Ok(())
    }

    async fn borrow_last_use(&mut self, cx: &hyperactor::Context<Self>, borrow: u64) -> Result<()> {
        let borrow = self
            .borrows
            .get_mut(&borrow)
            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow))?;

        borrow.last_use(cx).await?;
        Ok(())
    }

    async fn borrow_drop(&mut self, cx: &hyperactor::Context<Self>, borrow_id: u64) -> Result<()> {
        let borrow = self
            .borrows
            .get_mut(&borrow_id)
            .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow_id))?;

        borrow.drop(cx).await?;
        self.borrows.remove(&borrow_id);
        Ok(())
    }

    async fn delete_refs(&mut self, cx: &hyperactor::Context<Self>, refs: Vec<Ref>) -> Result<()> {
        let _: Vec<()> = try_join_all(
            self.streams
                .values()
                .map(|s| s.delete_refs(cx, refs.clone())),
        )
        .await?;
        Ok(())
    }

    async fn request_status(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        controller: bool,
    ) -> Result<()> {
        let _: Vec<()> = try_join_all(
            self.streams
                .values()
                .map(|stream| stream.request_status(cx)),
        )
        .await?;

        ControllerMessageClient::status(
            &self.controller_actor,
            cx,
            seq.next(),
            cx.self_id().clone(),
            controller,
        )
        .await?;
        Ok(())
    }

    async fn reduce(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        local_tensor: Ref,
        factory: Factory,
        source_mesh: Ref,
        stream_ref: StreamRef,
        dims: Vec<String>,
        reduction: Reduction,
        scatter: bool,
        in_place: bool,
        out: Option<Ref>,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        let dims = SortedVec::from_unsorted(dims);
        let stream = self.try_get_stream(stream_ref)?.clone();

        let (_, comm_map) = self
            .device_meshes
            .get_mut(&source_mesh)
            .ok_or_else(|| anyhow::anyhow!("invalid device mesh id: {:#?}", source_mesh))?;

        let (size, comm) = comm_map
            .get(&(stream_ref, dims.clone()))
            .ok_or_else(|| anyhow::anyhow!("no comm found for stream {stream:#?}, dims {dims:#?}"))?
            .clone();

        stream
            .reduce(
                cx,
                comm,
                size.try_into()?,
                result,
                local_tensor,
                factory,
                reduction,
                scatter,
                in_place,
                out,
            )
            .await?;

        Ok(())
    }

    async fn create_pipe(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        _result: Ref,
        _key: String,
        _function: ResolvableFunction,
        _max_messages: i64,
        _device_mesh: Ref,
        _args: Vec<WireValue>,
        _kwargs: HashMap<String, WireValue>,
    ) -> Result<()> {
        panic!("create_pipe is no longer implemented")
    }

    async fn send_tensor(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        let comm = self
            .send_recv_comms
            .get(&(from_stream, to_stream))
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "could not find stream to stream comm for: {:#?}",
                    (from_stream, to_stream)
                )
            })?
            .clone();

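        // Derive this rank's peers: we send to `to_rank` if this rank appears
        // in `from_ranks`, and receive from `from_rank` if it appears in
        // `to_ranks`.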
        let to_rank = from_ranks
            .index(self.rank)
            .map(|index| to_ranks.get(index).ok())
            .ok()
            .flatten();
        let from_rank = to_ranks
            .index(self.rank)
            .map(|index| from_ranks.get(index).ok())
            .ok()
            .flatten();

        let (stream, stream_ref) = if to_rank.is_none() {
            (self.try_get_stream(to_stream)?.clone(), to_stream)
        } else if from_rank.is_none() || from_stream == to_stream {
            (self.try_get_stream(from_stream)?.clone(), from_stream)
        } else {
            unimplemented!(
                "We haven't implemented to_mesh between streams if a rank participates as both a sender and receiver. \
                It is possible, but would require the recv stream to send the output buffer tensor to the send stream and sync. \
                Then the send stream would do the nccl op, and then sync with the sending stream again."
            );
        };

        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        stream
            .send_tensor(cx, result, from_rank, to_rank, tensor, factory, comm)
            .await?;

        Ok(())
    }

    async fn exit(
        &mut self,
        cx: &hyperactor::Context<Self>,
        error: Option<(Option<ActorId>, String)>,
    ) -> Result<()> {
        for (_, stream) in self.streams.drain() {
            stream.drain_and_stop()?;
            Arc::into_inner(stream)
                .expect("there should be no owners of this stream handle except the worker stream table")
                .await;
        }

        let self_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_SELF_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);
        let peer_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_PEER_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);

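        // Distinguish "this worker failed" from "a peer failed" so the two
        // cases can be given different (env-configurable) exit codes.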
        let exit_code = match error {
            Some((Some(actor_id), reason)) => {
                tracing::error!(
                    "stopping the worker, actor {} failed with error: {}",
                    actor_id,
                    reason
                );
                if *cx.self_id() == actor_id {
                    self_error_exit_code
                } else {
                    peer_error_exit_code
                }
            }
            Some((None, reason)) => {
                tracing::error!("stopping the worker, reason: {}", reason);
                1
            }
            None => 0,
        };

        if exit_code != 0 {
            tracing::info!("stopping the worker process, exit code: {}", exit_code);
            std::process::exit(exit_code);
        }
        cx.stop()?;
        Ok(())
    }

    async fn send_value(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        destination: Option<Ref>,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args: Vec<WireValue>,
        kwargs: HashMap<String, WireValue>,
        stream: StreamRef,
    ) -> Result<()> {
        let stream = self.try_get_stream(stream)?;

        let device_meshes = if function.as_ref().is_none_or(|f| f.as_torch_op().is_some()) {
            HashMap::new()
        } else {
            self.device_meshes
                .iter()
                .map(|(k, v)| (k.clone(), v.0.clone()))
                .collect()
        };

        if destination.is_some() {
            panic!("send_value with pipe destination is no longer implemented")
        }

        stream
            .send_value(
                cx,
                seq,
                cx.self_id().clone(),
                mutates,
                function,
                args,
                kwargs,
                device_meshes,
            )
            .await
    }

    async fn send_result_of_actor_call(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: ActorCallParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.stream)?;
        stream.send_result_of_actor_call(cx, params).await?;
        Ok(())
    }

    async fn call_actor_method(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: ActorMethodParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.call.stream)?;
        stream.call_actor_method(cx, params).await?;
        Ok(())
    }

    async fn split_comm(
        &mut self,
        cx: &hyperactor::Context<Self>,
        dims: Vec<String>,
        device_mesh: Ref,
        stream_ref: StreamRef,
        config: Option<NcclConfig>,
    ) -> Result<()> {
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call SplitComm before BackendNetworkInit")?;
        match self.device_meshes.get_mut(&device_mesh) {
            Some((device_mesh, comm_map)) => {
                let stream = self
                    .streams
                    .get(&stream_ref)
                    .ok_or_else(|| anyhow::anyhow!("invalid stream id: {:#?}", stream_ref))?;

                let dims = SortedVec::from_unsorted(dims);

                anyhow::ensure!(
                    !comm_map.contains_key(&(stream_ref, dims.clone())),
                    "comm already exists for stream {stream:#?}, dims {dims:#?}"
                );
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&dims)?;
                let size = ranks_for_group.len();
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                        config,
                    )
                    .await?
                    .context("split comm should include self rank")?;
                comm_map.insert((stream_ref, dims), (size, Arc::new(split_comm)));
            }
            None => {
                // This rank is not part of the mesh, but comm splits are
                // collective: participate with an empty rank list so the
                // other ranks can make progress.
                global_comm.split_from(cx, vec![], config).await?;
            }
        }
        Ok(())
    }

    async fn split_comm_for_process_group(
        &mut self,
        cx: &hyperactor::Context<Self>,
        remote_process_group_ref: Ref,
        stream_ref: StreamRef,
        config: Option<NcclConfig>,
    ) -> Result<()> {
        ensure!(
            self.streams.contains_key(&stream_ref),
            "invalid stream id: {:#?}",
            stream_ref
        );
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call SplitComm before BackendNetworkInit")?;
        let state = self
            .remote_process_groups
            .get_mut(&remote_process_group_ref)
            .with_context(|| {
                format!(
                    "invalid remote process group id: {:#?}",
                    remote_process_group_ref
                )
            })?;
        match self.device_meshes.get_mut(&state.device_mesh_ref) {
            Some((device_mesh, _)) => {
                let entry = match state.comms.entry(stream_ref) {
                    Entry::Vacant(entry) => entry,
                    Entry::Occupied(_) => bail!(
                        "comm already exists for remote process group {:#?} on stream {:#?}",
                        remote_process_group_ref,
                        stream_ref,
                    ),
                };
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&state.dims)?;
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                        config,
                    )
                    .await?
                    .context("split comm should include self rank")?;
                entry.insert(Arc::new(split_comm));
            }
            None => {
                // As above: ranks outside the mesh still take part in the
                // collective split, with an empty rank list.
                global_comm.split_from(cx, vec![], config).await?;
            }
        }
        Ok(())
    }

    async fn pipe_recv(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        _seq: Seq,
        _results: Vec<Option<Ref>>,
        _pipe: Ref,
        _stream: StreamRef,
    ) -> Result<()> {
        panic!("pipe_recv is no longer implemented")
    }

    async fn set_ref_unit_tests_only(
        &mut self,
        cx: &hyperactor::Context<Self>,
        reference: Ref,
        value: WireValue,
        stream: StreamRef,
    ) -> Result<()> {
        let stream = self.try_get_stream(stream)?;

        stream.set_ref_unit_tests_only(cx, reference, value).await
    }

    async fn get_ref_unit_tests_only(
        &mut self,
        cx: &hyperactor::Context<Self>,
        ref_id: Ref,
        stream: StreamRef,
    ) -> Result<Option<Result<WireValue, String>>> {
        let stream = self.try_get_stream(stream)?;
        Ok(stream.get_ref_unit_tests_only(cx, ref_id.clone()).await?)
    }

    async fn define_recording(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        _nresults: usize,
        _nformals: usize,
        commands: Vec<WorkerMessage>,
        ntotal_messages: usize,
        index: usize,
    ) -> Result<()> {
        if self.defining_recording.is_some() && self.defining_recording.unwrap() != result {
            bail!("already defining a different recording");
        }
        self.defining_recording = Some(result);

        match self.recordings.entry(result) {
            Entry::Vacant(entry) => {
                ensure!(
                    index == 0,
                    "got DefineRecording message with (index = {:?}) > 0 for previously unseen recording",
                    index
                );
                entry.insert(Recording::PartialRecording {
                    last_index: 0,
                    commands,
                });
            }
            Entry::Occupied(mut entry) => match entry.get_mut() {
                Recording::CompleteRecording { .. } => {
                    bail!("got DefineRecording message for already complete recording")
                }
                Recording::PartialRecording {
                    last_index,
                    commands: existing_commands,
                } => {
                    ensure!(
                        index == *last_index + 1,
                        "Got DefineRecording message with index = {:?}, but \
                        last seen index for recording is {:?}",
                        index,
                        last_index
                    );
                    *last_index = index;
                    existing_commands.extend(commands.into_iter());
                }
            },
        };

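        // Not the last DefineRecording message for this recording: just keep
        // accumulating commands.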
        if index < ntotal_messages - 1 {
            return Ok(());
        }
        let commands = match self.recordings.remove(&result).unwrap() {
            Recording::CompleteRecording { .. } => panic!("unreachable, in theory"),
            Recording::PartialRecording { commands, .. } => {
                self.recordings.insert(
                    result,
                    Recording::CompleteRecording {
                        streams: HashSet::new(),
                    },
                );
                commands
            }
        };

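        // Replay the recorded commands through the normal message handler;
        // every stream touched during replay registers itself on the
        // recording via maybe_add_stream_to_recording.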
        for command in commands {
            WorkerMessageHandler::handle(self, cx, command).await?;
        }

        match self.recordings.get(&result).unwrap() {
            Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
            Recording::CompleteRecording { streams, .. } => {
                for stream in streams {
                    self.try_get_stream(*stream)?
                        .finalize_recording(cx, result)
                        .await?;
                }
            }
        }

        self.defining_recording = None;
        Ok(())
    }

    async fn recording_formal(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        argument_index: usize,
        stream: StreamRef,
    ) -> Result<()> {
        ensure!(self.defining_recording.is_some());
        self.maybe_add_stream_to_recording(cx, stream).await?;
        self.try_get_stream(stream)?
            .recording_formal(cx, result, argument_index)
            .await
    }

    async fn recording_result(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        output_index: usize,
        stream: StreamRef,
    ) -> Result<()> {
        ensure!(self.defining_recording.is_some());
        self.maybe_add_stream_to_recording(cx, stream).await?;
        self.try_get_stream(stream)?
            .recording_result(cx, result, output_index)
            .await
    }

    async fn call_recording(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        recording: Ref,
        results: Vec<Ref>,
        actuals: Vec<Ref>,
    ) -> Result<()> {
        ensure!(self.defining_recording.is_none());
        let recording_ref = recording;
        let recording = self.recordings.get(&recording).ok_or(anyhow::anyhow!(
            "could not find recording: {:#?}",
            recording
        ))?;
        match recording {
            Recording::PartialRecording { .. } => {
                bail!("cannot call recording because it is incomplete")
            }
            Recording::CompleteRecording { streams } => try_join_all(
                streams
                    .iter()
                    .map(|stream| self.try_get_stream(*stream))
                    .collect::<Result<Vec<_>>>()?
                    .into_iter()
                    .map(|stream| {
                        stream.call_recording(
                            cx,
                            seq,
                            recording_ref,
                            results.clone(),
                            actuals.clone(),
                        )
                    }),
            )
            .await
            .map(|_| ()),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;
    use std::process::Stdio;
    use std::time::Duration;

    use anyhow::Result;
    use hyperactor::Instance;
    use hyperactor::WorldId;
    use hyperactor::actor::ActorStatus;
    use hyperactor::channel::ChannelAddr;
    use hyperactor::id;
    use hyperactor::mailbox::open_port;
    use hyperactor::proc::Proc;
    use hyperactor_multiprocess::System;
    use hyperactor_multiprocess::proc_actor::Environment;
    use hyperactor_multiprocess::proc_actor::ProcActor;
    use hyperactor_multiprocess::proc_actor::ProcMessageClient;
    use hyperactor_multiprocess::system_actor::SYSTEM_ACTOR_REF;
    use hyperactor_multiprocess::system_actor::Shape;
    use hyperactor_multiprocess::system_actor::SystemMessageClient;
    use hyperactor_multiprocess::system_actor::SystemSnapshotFilter;
    use hyperactor_multiprocess::system_actor::WorldStatus;
    use monarch_messages::controller::ControllerMessage;
    use monarch_messages::controller::WorkerError;
    use monarch_messages::worker::WorkerMessageClient;
    use monarch_types::PickledPyObject;
    use monarch_types::PyTree;
    use pyo3::IntoPyObjectExt;
    use pyo3::Python;
    use pyo3::prelude::*;
    use pyo3::types::PyList;
    use pyo3::types::PyString;
    use rand::Rng;
    use rand::distributions::Alphanumeric;
    use timed_test::async_timed_test;
    use tokio::io::BufReader;
    use tokio::process::Command;
    use tokio_retry::Retry;
    use tokio_retry::strategy::FixedInterval;
    use torch_sys::Device;
    use torch_sys::DeviceIndex;
    use torch_sys::MemoryFormat;

    use super::*;
    use crate::test_util::test_setup;

    #[async_timed_test(timeout_secs = 60)]
    async fn basic_worker() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(Ref { id: 2 })],
                        mutates: vec![0.into()],
                        function: "torch.ops.aten.sub_.Scalar".into(),
                        args: vec![WireValue::Ref(0.into()), WireValue::Int(1)],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 3.into(),
                        results: vec![Some(Ref { id: 3 })],
                        mutates: vec![],
                        function: "torch.ops.aten.zeros.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 4.into(),
                        results: vec![Some(Ref { id: 4 })],
                        mutates: vec![],
                        function: "torch.ops.aten.allclose.default".into(),
                        args: vec![WireValue::Ref(0.into()), WireValue::Ref(Ref { id: 3 })],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result: bool = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );
        assert!(result);

        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn error_sends_response() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.rand.default".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let response_message = controller_rx.recv().await.unwrap();
        match response_message {
            ControllerMessage::RemoteFunctionFailed {
                seq,
                error: WorkerError { backtrace: msg, .. },
            } => {
                assert_eq!(seq, 0.into());
                assert!(msg.contains("aten::rand() is missing value for argument 'size'"))
            }
            _ => panic!("unexpected response {:#?}", response_message),
        }

        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn mutated_refs_are_updated_with_error() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: 0.into(),
                        value: WireValue::Int(1),
                        stream: 1.into(),
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(Ref { id: 2 })],
                        mutates: vec![0.into()],
                        function: "i.dont.exist".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result = worker_handle
            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
            .await?;

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        let mutated_ref = result
            .context("no such ref")?
            .err()
            .context("expected error")?;
        assert!(mutated_ref.contains("failed to resolve function"));

        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );
        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn accessing_errored_dependency() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "i.dont.exist".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 1.into(),
                        results: vec![Some(1.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.sub_.Scalar".into(),
                        args: vec![WireValue::Ref(0.into())],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );

        match &responses[0] {
            ControllerMessage::RemoteFunctionFailed { seq, .. } => {
                assert_eq!(seq, &0.into())
            }
            _ => panic!("unexpected response {:#?}", responses[0]),
        };
        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn py_remote_function_calls() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        let (split_arg, sort_list, mesh_ref, dim, layout, none, scalar, device, memory_format) =
            Python::with_gil(|py| {
                let split_arg: PickledPyObject = PyString::new(py, "/fbs/fbc/foo/bar")
                    .into_any()
                    .try_into()?;
                let sort_list: PickledPyObject =
                    PyList::new(py, [65, 34, 79, 1, 5])?.into_any().try_into()?;
                let mesh_ref: PickledPyObject = Ref { id: 5 }.into_bound_py_any(py)?.try_into()?;
                let dim: PickledPyObject = PyString::new(py, "x").into_any().try_into()?;
                let layout: PickledPyObject = py.import("torch")?.getattr("strided")?.try_into()?;
                let none: PickledPyObject = py.None().into_any().into_bound(py).try_into()?;
                let scalar: PickledPyObject = py.import("torch")?.getattr("float32")?.try_into()?;
                let device: PickledPyObject = py
                    .import("torch")?
                    .getattr("device")?
                    .call1(("cuda:1",))?
                    .try_into()?;
                let memory_format: PickledPyObject = py
                    .import("torch")?
                    .getattr("contiguous_format")?
                    .try_into()?;
                PyResult::Ok((
                    split_arg,
                    sort_list,
                    mesh_ref,
                    dim,
                    layout,
                    none,
                    scalar,
                    device,
                    memory_format,
                ))
            })?;

        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into()), Some(Ref { id: 2 })],
                        mutates: vec![],
                        function: "os.path.split".into(),
                        args: vec![split_arg.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(4.into()), None, None, None, None],
                        mutates: vec![],
                        function: "builtins.sorted".into(),
                        args: vec![sort_list.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CreateDeviceMesh {
                        result: 5.into(),
                        names: vec!["x".into()],
                        ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(6.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.mesh_rank".into(),
                        args: vec![mesh_ref.into(), dim.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 4.into(),
                        results: vec![Some(7.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_scalar_type"
                            .into(),
                        args: vec![scalar.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 5.into(),
                        results: vec![Some(8.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_layout".into(),
                        args: vec![layout.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 6.into(),
                        results: vec![Some(9.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_none".into(),
                        args: vec![none.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 7.into(),
                        results: vec![None],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.none".into(),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 8.into(),
                        results: vec![Some(10.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_device".into(),
                        args: vec![device.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 9.into(),
                        results: vec![Some(11.into())],
                        mutates: vec![],
                        function: "monarch.monarch_tensor_worker.test_utils.test_memory_format"
                            .into(),
                        args: vec![memory_format.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 10.into(),
                        results: vec![Some(12.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 11.into(),
                        results: vec![Some(13.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.stack.default".into(),
                        args: vec![WireValue::RefList(vec![12.into(), 12.into()])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result1: String = worker_handle
            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result2: String = worker_handle
            .get_ref_unit_tests_only(&client, 2.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result3: i64 = worker_handle
            .get_ref_unit_tests_only(&client, 4.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let result4: i64 = worker_handle
            .get_ref_unit_tests_only(&client, 6.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        assert_eq!(
            ScalarType::Float,
            worker_handle
                .get_ref_unit_tests_only(&client, 7.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap()
                .try_into()
                .unwrap()
        );
        assert_eq!(
            Layout::Strided,
            worker_handle
                .get_ref_unit_tests_only(&client, 8.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap()
                .try_into()
                .unwrap()
        );
        assert_matches!(
            worker_handle
                .get_ref_unit_tests_only(&client, 9.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap(),
            WireValue::None(()),
        );
        let device: Device = CudaDevice::new(DeviceIndex(1)).into();
        assert_eq!(
            device,
            worker_handle
                .get_ref_unit_tests_only(&client, 10.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap()
                .try_into()
                .unwrap()
        );
        assert_matches!(
            worker_handle
                .get_ref_unit_tests_only(&client, 11.into(), 1.into())
                .await
                .unwrap()
                .unwrap()
                .unwrap(),
            WireValue::MemoryFormat(MemoryFormat::Contiguous),
        );

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;
        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );

        assert_eq!(result1, "/fbs/fbc/foo");
        assert_eq!(result2, "bar");
        assert_eq!(result3, 1);
        assert_eq!(result4, 0);

        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn delete_refs() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 2 },
                        value: WireValue::Bool(false),
                        stream: 0.into(),
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 3 },
                        value: WireValue::Bool(true),
                        stream: 0.into(),
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 4 },
                        value: WireValue::Int(0),
                        stream: 1.into(),
                    },
                    WorkerMessage::DeleteRefs(vec![Ref { id: 2 }, Ref { id: 4 }]),
                ],
            )
            .await
            .unwrap();

        let result: bool = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 3 }, 0.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        let fail_result = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        assert!(result, "should be able to get a non-deleted ref");
        assert!(fail_result.is_none(), "should fail to get a deleted ref");

        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn request_status() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                ],
            )
            .await
            .unwrap();

        for i in 0..100 {
            worker_handle
                .call_function(
                    &client,
                    CallFunctionParams {
                        seq: i.into(),
                        results: vec![Some(Ref { id: i + 2 })],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: (i % 2).into(),
                        remote_process_groups: vec![],
                    },
                )
                .await
                .unwrap();
        }

        worker_handle
            .request_status(&client, 100.into(), false)
            .await
            .unwrap();

        worker_handle.drain_and_stop().unwrap();
        worker_handle.await;

        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );

        let response = responses.pop().unwrap();
        match response {
            ControllerMessage::Status { seq, .. } => {
                assert_eq!(seq, 101.into())
            }
            _ => panic!("unexpected response {:#?}", response),
        };

        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn backend_network_init() {
        let proc = Proc::local();
        let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();

        let worker_handle1 = proc
            .spawn::<WorkerActor>(
                "worker0",
                WorkerParams {
                    world_size: 2,
                    rank: 0,
                    device_index: Some(0),
                    controller_actor: controller_ref.clone(),
                },
            )
            .await
            .unwrap();
        let worker_handle2 = proc
            .spawn::<WorkerActor>(
                "worker1",
                WorkerParams {
                    world_size: 2,
                    rank: 1,
                    device_index: Some(1),
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();

        let unique_id = UniqueId::new().unwrap();
        worker_handle1
            .backend_network_init(&client, unique_id.clone())
            .await
            .unwrap();
        worker_handle2
            .backend_network_init(&client, unique_id)
            .await
            .unwrap();

        worker_handle1.drain_and_stop().unwrap();
        worker_handle1.await;
        worker_handle2.drain_and_stop().unwrap();
        worker_handle2.await;
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn send_value() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args: vec![WireValue::IntList(vec![2, 3])],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::SendValue {
                        seq: 1.into(),
                        destination: None,
                        mutates: vec![],
                        function: None,
                        args: vec![WireValue::Ref(0.into())],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::SendValue {
                        seq: 2.into(),
                        destination: None,
                        mutates: vec![],
                        function: Some("torch.ops.aten.var_mean.default".into()),
                        args: vec![WireValue::Ref(0.into())],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop()?;
        assert_matches!(worker_handle.await, ActorStatus::Stopped);

        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            3,
            "Expected three responses, got: {:#?}",
            responses
        );

        match responses.pop().unwrap() {
            ControllerMessage::FetchResult { seq, value } => {
                assert_eq!(seq, 2.into());
                let value = value.unwrap().deserialized::<PyTree<RValue>>().unwrap();
                assert_eq!(value.leaves().len(), 2);
            }
            resp => panic!("unexpected response {:#?}", resp),
        };
        match responses.pop().unwrap() {
            ControllerMessage::FetchResult { seq, .. } => {
                assert_eq!(seq, 1.into())
            }
            resp => panic!("unexpected response {:#?}", resp),
        };
        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn send_value_err_result() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn::<WorkerActor>(
                "worker",
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
            )
            .await
            .unwrap();

        let ref_arg: PickledPyObject =
            Python::with_gil(|py| Ref { id: 2 }.into_bound_py_any(py)?.try_into())?;

        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: Ref { id: 2 },
                        value: WireValue::Bool(false),
                        stream: 1.into(),
                    },
                    WorkerMessage::SendValue {
                        seq: 1.into(),
                        destination: None,
                        mutates: vec![Ref { id: 2 }],
                        function: Some("non.existent.function".into()),
                        args: vec![],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::SendValue {
                        seq: 2.into(),
                        destination: None,
                        mutates: vec![],
                        function: None,
                        args: vec![ref_arg.into()],
                        kwargs: HashMap::new(),
                        stream: 1.into(),
                    },
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop()?;
        assert_matches!(worker_handle.await, ActorStatus::Stopped);

        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            3,
            "Expected three responses, got: {:#?}",
            responses
        );
        match responses.pop() {
            Some(ControllerMessage::FetchResult { seq, value }) => {
                assert_eq!(seq, 2.into());
                assert!(value.is_err());
                assert!(
                    value
                        .unwrap_err()
                        .backtrace
                        .contains("failed to resolve function")
                );
            }
            _ => panic!("unexpected response {:#?}", responses),
        }
        match responses.pop() {
            Some(ControllerMessage::FetchResult { seq, value }) => {
                assert_eq!(seq, 1.into());
                assert!(value.is_err());
                assert!(
                    value
                        .unwrap_err()
                        .backtrace
                        .contains("failed to resolve function")
                );
            }
            _ => panic!("unexpected response {:#?}", responses),
        }
        Ok(())
    }

    fn get_random_channel_addr() -> ChannelAddr {
        let random_string = rand::thread_rng()
            .sample_iter(&Alphanumeric)
            .take(24)
            .map(char::from)
            .collect::<String>();
        format!("unix!@{random_string}").parse().unwrap()
    }

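    /// Poll the system actor until `world` reports live, retrying once per
    /// second for up to 100 attempts.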
    async fn ensure_world_ready(client: &Instance<()>, world: WorldId) -> Result<()> {
        tracing::info!("checking whether world {world} is ready");
        let retry_strategy = FixedInterval::from_millis(1000).take(100);
        Retry::spawn(retry_strategy, async || {
            let snapshot = SYSTEM_ACTOR_REF
                .snapshot(&client, SystemSnapshotFilter::default())
                .await?;
            let world_snapshot = snapshot.worlds.get(&world).ok_or(anyhow!("no world"))?;
            tracing::info!("world status: {:?}", world_snapshot.status);
            match world_snapshot.status {
                WorldStatus::Live => Ok(()),
                _ => Err(anyhow!("world is not live")),
            }
        })
        .await?;
        Ok(())
    }

    #[async_timed_test(timeout_secs = 60)]
    async fn remote_process_group() -> Result<()> {
        test_setup()?;

        let timeout: Duration = Duration::from_secs(10);
        let system_addr = get_random_channel_addr();
        let _system_handle = System::serve(system_addr.clone(), timeout, timeout).await?;

        let client = System::new(system_addr.clone()).attach().await?;
        let (_handle, mut controller_rx) = client.bind_actor_port::<ControllerMessage>();
        let controller_ref: ActorRef<ControllerActor> = ActorRef::attest(client.self_id().clone());

        let world_size = 2;
        SYSTEM_ACTOR_REF
            .upsert_world(
                &client,
                id!(world),
                Shape::Definite(vec![world_size]),
                4,
                Environment::Local,
                HashMap::new(),
            )
            .await?;

        let mut worker_process_handles = vec![];
        let mut worker_procs: Vec<ActorRef<ProcActor>> = vec![];
        for rank in 0..world_size {
            let world_id = "world".to_string();
            let proc_id = format!("{world_id}[{rank}]");
            worker_procs.push(ActorRef::attest(format!("world[{rank}].proc[0]").parse()?));

            let mut handle =
                Command::new(std::env::var("MONARCH_TENSOR_WORKER_EXE").map_err(|e| {
                    anyhow::anyhow!("could not get var MONARCH_TENSOR_WORKER_EXE: {}", e)
                })?)
                .arg("worker")
                .arg(format!("--bootstrap-addr={system_addr}"))
                .arg(format!("--world-id={world_id}"))
                .arg(format!("--proc-id={proc_id}"))
                .stdout(Stdio::piped())
                .stdin(Stdio::piped())
                .kill_on_drop(true)
                .spawn()?;

            let out = handle.stdout.take().unwrap();
            tokio::spawn(async move {
                let mut reader = BufReader::new(out);
                tokio::io::copy(&mut reader, &mut tokio::io::stderr())
                    .await
                    .unwrap();
            });
            worker_process_handles.push(handle);
        }

        ensure_world_ready(&client, id!(world)).await?;

        let (spawned_port, mut spawned_receiver) = open_port(&client);
        for (rank, worker_proc) in worker_procs.iter().enumerate() {
            let params = WorkerParams {
                world_size,
                rank,
                device_index: Some(rank.try_into().unwrap()),
                controller_actor: controller_ref.clone(),
            };
            worker_proc
                .spawn(
                    &client,
                    "monarch_tensor_worker::WorkerActor".to_owned(),
                    "worker".to_owned(),
                    bincode::serialize(&params)?,
                    spawned_port.bind(),
                )
                .await?;
        }
        let mut spawned = vec![];
        while spawned.len() < world_size {
            spawned.push(spawned_receiver.recv().await?);
        }
        tracing::info!("spawned {} worker actors", world_size);
        let workers: Vec<ActorRef<WorkerActor>> = (0..world_size)
            .map(|rank| format!("world[{rank}].worker[0]"))
            .map(|name| ActorRef::attest(name.parse().unwrap()))
            .collect();

        let remote_proc_grp_ref: PickledPyObject =
            Python::with_gil(|py| Ref { id: 2 }.into_bound_py_any(py)?.try_into())?;

        let unique_id = UniqueId::new()?;
        let messages = vec![
            WorkerMessage::CreateStream {
                id: 0.into(),
                stream_creation: StreamCreationMode::UseDefaultStream,
            },
            WorkerMessage::BackendNetworkInit(unique_id.clone()),
            WorkerMessage::CreateDeviceMesh {
                result: 1.into(),
                names: vec!["x".into()],
                ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
            },
            WorkerMessage::CreateRemoteProcessGroup {
                result: 2.into(),
                device_mesh: 1.into(),
                dims: vec!["x".into()],
            },
            WorkerMessage::SplitCommForProcessGroup {
                remote_process_group: 2.into(),
                stream: 0.into(),
                config: None,
            },
            WorkerMessage::CallFunction(CallFunctionParams {
                seq: 0.into(),
                results: vec![Some(3.into())],
                mutates: vec![],
                function: "monarch.monarch_tensor_worker.test_utils.test_remote_process_group"
                    .into(),
                args: vec![remote_proc_grp_ref.into()],
                kwargs: HashMap::new(),
                stream: 0.into(),
                remote_process_groups: vec![2.into()],
            }),
        ];

        workers[0].command_group(&client, messages.clone()).await?;
        workers[1].command_group(&client, messages).await?;

        let _ = workers[0]
            .get_ref_unit_tests_only(&client, 3.into(), 0.into())
            .await?
            .unwrap()
            .unwrap();

        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );

        Ok(())
    }
}