1#![feature(assert_matches)]
10#![feature(duration_constructors)]
11#![feature(exit_status_error)]
12#![allow(unsafe_op_in_unsafe_fn)]
15
16mod borrow;
31mod comm;
32pub mod device_mesh;
33pub mod stream;
34pub mod test_util;
35
36use std::collections::HashMap;
37use std::collections::HashSet;
38use std::collections::hash_map::Entry;
39use std::sync::Arc;
40
41use anyhow::Context;
42use anyhow::Result;
43use anyhow::anyhow;
44use anyhow::bail;
45use anyhow::ensure;
46use async_trait::async_trait;
47use borrow::Borrow;
48use comm::CommMessageClient;
49use comm::CommParams;
50use comm::NcclCommActor;
51use derive_more::TryInto;
52use device_mesh::DeviceMesh;
53use futures::future::try_join_all;
54use hyperactor as reference;
55use hyperactor::Actor;
56use hyperactor::Bind;
57use hyperactor::Handler;
58use hyperactor::RemoteSpawn;
59use hyperactor::Unbind;
60use hyperactor::actor::ActorHandle;
61use hyperactor::context;
62use hyperactor_config::Flattrs;
63use hyperactor_mesh::comm::multicast::CastInfo;
64use itertools::Itertools;
65use monarch_hyperactor::shape::PyPoint;
66use monarch_messages::controller::ControllerActor;
67use monarch_messages::controller::ControllerMessageClient;
68use monarch_messages::controller::Seq;
69use monarch_messages::wire_value::WireValue;
70use monarch_messages::worker::ActorCallParams;
71use monarch_messages::worker::ActorMethodParams;
72use monarch_messages::worker::ArgsKwargs;
73use monarch_messages::worker::CallFunctionParams;
74use monarch_messages::worker::Factory;
75use monarch_messages::worker::Reduction;
76use monarch_messages::worker::Ref;
77use monarch_messages::worker::ResolvableFunction;
78use monarch_messages::worker::StreamCreationMode;
79use monarch_messages::worker::StreamRef;
80use monarch_messages::worker::WorkerMessage;
81use monarch_messages::worker::WorkerMessageHandler;
82use monarch_messages::worker::WorkerParams;
83use monarch_types::ReduceOp;
84use monarch_types::UniqueId;
85use ndslice::Slice;
86use pyo3::Python;
87use pyo3::types::PyAnyMethods;
88use serde::Deserialize;
89use serde::Serialize;
90use sorted_vec::SortedVec;
91use stream::StreamActor;
92use stream::StreamMessageClient;
93use stream::StreamParams;
94use torch_sys2::CudaDevice;
95use torch_sys2::DeviceIndex;
96use torch_sys2::Layout;
97use torch_sys2::ScalarType;
98use torch_sys2::TensorCell;
99use torch_sys2::factory_zeros;
100use typeuri::Named;
101
102#[derive(Debug)]
103struct RemoteProcessGroupState {
104 device_mesh_ref: Ref,
105 dims: SortedVec<String>,
106 comms: HashMap<StreamRef, Arc<ActorHandle<NcclCommActor>>>,
107}
108
109impl RemoteProcessGroupState {
110 fn new(device_mesh_ref: Ref, dims: SortedVec<String>) -> Self {
111 Self {
112 device_mesh_ref,
113 dims,
114 comms: HashMap::new(),
115 }
116 }
117}
118
/// Progress of a recording definition held by the worker.
#[derive(Debug)]
enum Recording {
    /// A recording for which not all `DefineRecording` messages have been
    /// received yet.
    PartialRecording {
        /// Index of the most recently accepted `DefineRecording` message.
        last_index: usize,
        /// Commands accumulated so far, in the order they arrived.
        commands: Vec<WorkerMessage>,
    },

    /// A recording whose commands have all been received.
    CompleteRecording {
        /// Streams that were touched while the recording was being defined.
        streams: HashSet<StreamRef>,
    },
}
136
/// Actor that owns the per-worker state (streams, device meshes,
/// communicators, borrows and recordings) and handles [`WorkerMessage`]
/// and [`AssignRankMessage`].
#[derive(Debug)]
#[hyperactor::spawnable]
#[hyperactor::export(
    handlers = [
        WorkerMessage {cast = true},
        AssignRankMessage {cast = true},
    ],
)]
pub struct WorkerActor {
    /// CUDA device used by this worker; `None` when CUDA is unavailable or
    /// no device index was supplied at construction.
    device: Option<CudaDevice>,
    /// Stream actors created via `create_stream`, keyed by their ref.
    streams: HashMap<StreamRef, Arc<ActorHandle<StreamActor>>>,
    /// Device meshes keyed by ref. Each mesh carries the communicators split
    /// for it, keyed by `(stream, sorted dims)` and stored together with the
    /// number of ranks in the group.
    device_meshes: HashMap<
        Ref,
        (
            DeviceMesh,
            HashMap<(StreamRef, SortedVec<String>), (usize, Arc<ActorHandle<NcclCommActor>>)>,
        ),
    >,
    /// Total number of workers.
    world_size: usize,
    /// Rank of this worker; overwritten when an `AssignRankMessage` arrives.
    rank: usize,
    /// Live borrows, keyed by borrow id.
    borrows: HashMap<u64, Borrow>,
    /// Global communicator; set by `backend_network_init`.
    comm: Option<ActorHandle<NcclCommActor>>,
    /// Controller that receives status reports from this worker.
    controller_actor: reference::ActorRef<ControllerActor>,
    /// Remote process groups created on this worker, keyed by ref.
    remote_process_groups: HashMap<Ref, RemoteProcessGroupState>,
    /// Communicators used by `send_tensor`, keyed by
    /// `(from_stream, to_stream)`.
    send_recv_comms: HashMap<(StreamRef, StreamRef), Arc<ActorHandle<NcclCommActor>>>,
    /// Recordings known to this worker, keyed by ref.
    recordings: HashMap<Ref, Recording>,
    /// Ref of the recording currently being defined, if any.
    defining_recording: Option<Ref>,
    /// Forwarded to every stream created by this worker; switched on when
    /// an `AssignRankMessage` is handled.
    respond_with_python_message: bool,
}
180
181impl WorkerActor {
182 fn runtime_has_cuda() -> bool {
183 Python::attach(|py| {
184 py.import("torch")
185 .expect("torch must be importable in a worker")
186 .getattr("cuda")
187 .expect("torch.cuda attribute must exist")
188 .call_method0("is_available")
189 .expect("torch.cuda.is_available() must be callable")
190 .extract::<bool>()
191 .expect("torch.cuda.is_available() must return bool")
192 })
193 }
194
195 fn try_get_stream(&self, stream: StreamRef) -> Result<&Arc<ActorHandle<StreamActor>>> {
196 self.streams
197 .get(&stream)
198 .ok_or(anyhow::anyhow!("invalid stream id: {:#?}", stream))
199 }
200
201 async fn maybe_add_stream_to_recording(
202 &mut self,
203 cx: &impl context::Actor,
204 stream: StreamRef,
205 ) -> Result<()> {
206 if let Some(defining_recording) = self.defining_recording {
209 let recording = self.recordings.get_mut(&defining_recording).unwrap();
210 let fut = match recording {
211 Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
212 Recording::CompleteRecording { streams } => {
213 streams.insert(stream).then(|| -> Result<_, anyhow::Error> {
214 Ok(self
215 .try_get_stream(stream)?
216 .define_recording(cx, defining_recording))
217 })
218 }
219 }
220 .transpose()?;
221 match fut {
222 Some(fut) => fut.await,
223 None => Ok(()),
224 }
225 } else {
226 Ok(())
227 }
228 }
229}
230
// `WorkerActor` overrides nothing from the `Actor` trait; the impl body is empty.
impl Actor for WorkerActor {}
232
233#[async_trait]
234impl RemoteSpawn for WorkerActor {
235 type Params = WorkerParams;
236
237 async fn new(
238 WorkerParams {
239 world_size,
240 rank,
241 device_index,
242 controller_actor,
243 }: Self::Params,
244 _environment: Flattrs,
245 ) -> Result<Self> {
246 Python::attach(|py| {
247 py.import("monarch.safe_torch").unwrap();
248 });
249 let device = if Self::runtime_has_cuda() {
250 device_index.map(|i| CudaDevice::new(DeviceIndex(i)))
251 } else {
252 None
253 };
254 Ok(Self {
255 device,
256 streams: HashMap::new(),
257 device_meshes: HashMap::new(),
258 world_size,
259 rank,
260 borrows: HashMap::new(),
261 comm: None,
262 controller_actor,
263 remote_process_groups: HashMap::new(),
264 send_recv_comms: HashMap::new(),
265 recordings: HashMap::new(),
266 defining_recording: None,
267 respond_with_python_message: false,
268 })
269 }
270
271 }
273
274#[async_trait]
275impl Handler<AssignRankMessage> for WorkerActor {
276 async fn handle(
277 &mut self,
278 cx: &hyperactor::Context<Self>,
279 _: AssignRankMessage,
280 ) -> anyhow::Result<()> {
281 let point = cx.cast_point();
282 self.rank = point.rank();
283 self.respond_with_python_message = true;
284 Python::attach(|py| {
285 let mesh_controller = py.import("monarch.mesh_controller").unwrap();
286 let p: PyPoint = point.into();
287 mesh_controller
288 .call_method1("_initialize_env", (p, cx.proc().proc_addr().to_string()))
289 .unwrap();
290 });
291 Ok(())
292 }
293}
294
/// Message telling a worker to adopt its rank from the cast point it was
/// delivered to (see the `Handler<AssignRankMessage>` impl).
#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub enum AssignRankMessage {
    /// The only request; carries no payload.
    AssignRank(),
}
wirevalue::register_type!(AssignRankMessage);
302
303#[async_trait]
304impl Handler<WorkerMessage> for WorkerActor {
305 async fn handle(
306 &mut self,
307 cx: &hyperactor::Context<Self>,
308 message: WorkerMessage,
309 ) -> anyhow::Result<()> {
310 <Self as WorkerMessageHandler>::handle(self, cx, message).await
311 }
312}
313
314#[async_trait]
315impl WorkerMessageHandler for WorkerActor {
316 async fn backend_network_init(
317 &mut self,
318 cx: &hyperactor::Context<Self>,
319 unique_id: UniqueId,
320 ) -> Result<()> {
321 let Some(device) = self.device else {
322 return Ok(());
323 };
324 let comm = NcclCommActor::new(CommParams::New {
325 device,
326 unique_id,
327 world_size: self.world_size.try_into().unwrap(),
328 rank: self.rank.try_into().unwrap(),
329 })
330 .await?
331 .spawn(cx)?;
332
333 let tensor = factory_zeros(&[1], ScalarType::Float, Layout::Strided, device.into());
334 let cell = TensorCell::new(tensor);
335
336 comm.all_reduce(
337 cx,
338 cell,
339 ReduceOp::Sum,
340 torch_sys_cuda::cuda::Stream::get_current_stream(),
341 )
342 .await?;
343
344 let sorted_streams = self
353 .streams
354 .iter()
355 .sorted_by_key(|(k, _)| *k)
356 .map(|(_, v)| v.as_ref());
357
358 let mut splits = Vec::new();
359 for _ in 0..sorted_streams.len() {
360 splits.push(comm.split_all(cx).await?);
363 }
364 let _: Vec<()> = try_join_all(
365 sorted_streams
366 .into_iter()
367 .zip(splits)
368 .map(|(stream, split)| stream.init_comm(cx, split)),
369 )
370 .await?;
371
372 self.comm = Some(comm);
373
374 Ok(())
375 }
376
377 async fn backend_network_point_to_point_init(
378 &mut self,
379 cx: &hyperactor::Context<Self>,
380 from_stream: StreamRef,
381 to_stream: StreamRef,
382 ) -> Result<()> {
383 if !self.streams.contains_key(&from_stream) {
384 bail!("invalid from_stream id: {:#?}", from_stream);
385 }
386 if !self.streams.contains_key(&to_stream) {
387 bail!("invalid to_stream id: {:#?}", to_stream);
388 }
389 let Some(global_comm) = self.comm.as_ref() else {
390 return Ok(());
391 };
392 let comm = global_comm.split_all(cx).await?;
393 self.send_recv_comms
394 .insert((from_stream, to_stream), Arc::new(comm));
395 Ok(())
396 }
397
398 async fn call_function(
399 &mut self,
400 cx: &hyperactor::Context<Self>,
401 params: CallFunctionParams,
402 ) -> Result<()> {
403 let stream = self.try_get_stream(params.stream)?.clone();
404 self.maybe_add_stream_to_recording(cx, params.stream)
405 .await?;
406
407 let device_meshes = self
408 .device_meshes
409 .iter()
410 .map(|(k, v)| (k.clone(), v.0.clone()))
411 .collect();
412
413 let mut remote_process_groups = HashMap::new();
414 for remote_process_group_ref in ¶ms.remote_process_groups {
415 if let Some(state) = self.remote_process_groups.get(remote_process_group_ref) {
416 let dims_vec = state.dims.iter().cloned().collect();
417 let (device_mesh, _) = self
418 .device_meshes
419 .get(&state.device_mesh_ref)
420 .ok_or_else(|| {
421 anyhow::anyhow!("invalid device mesh id: {:#?}", state.device_mesh_ref)
422 })?
423 .clone();
424 let comm = state.comms
425 .get(¶ms.stream)
426 .ok_or_else(|| {
427 anyhow::anyhow!("no comm found for remote process group {remote_process_group_ref:#?} stream {stream:#?}")
428 })?
429 .clone();
430 remote_process_groups.insert(
431 remote_process_group_ref.clone(),
432 (device_mesh, dims_vec, comm),
433 );
434 }
435 }
436
437 stream
438 .call_function(cx, params, device_meshes, remote_process_groups)
439 .await?;
440
441 Ok(())
442 }
443
444 async fn command_group(
445 &mut self,
446 cx: &hyperactor::Context<Self>,
447 params: Vec<WorkerMessage>,
448 ) -> Result<()> {
449 for msg in params {
450 WorkerMessageHandler::handle(self, cx, msg).await?;
451 }
452 Ok(())
453 }
454
455 async fn create_stream(
456 &mut self,
457 cx: &hyperactor::Context<Self>,
458 result: StreamRef,
459 creation_mode: StreamCreationMode,
460 ) -> Result<()> {
461 let handle: ActorHandle<StreamActor> = StreamActor::new(StreamParams {
462 world_size: self.world_size,
463 rank: self.rank,
464 creation_mode,
465 id: result,
466 device: self.device,
467 controller_actor: self.controller_actor.clone(),
468 respond_with_python_message: self.respond_with_python_message,
469 })
470 .spawn(cx)?;
471 self.streams.insert(result, Arc::new(handle));
472 Ok(())
473 }
474
475 async fn create_device_mesh(
476 &mut self,
477 _cx: &hyperactor::Context<Self>,
478 result: Ref,
479 names: Vec<String>,
480 ranks: Slice,
481 ) -> Result<()> {
482 self.device_meshes.insert(
483 result,
484 (DeviceMesh::new(names, ranks, self.rank)?, HashMap::new()),
485 );
486 Ok(())
487 }
488
489 async fn create_remote_process_group(
490 &mut self,
491 _cx: &hyperactor::Context<Self>,
492 result: Ref,
493 device_mesh: Ref,
494 dims: Vec<String>,
495 ) -> Result<()> {
496 self.device_meshes
497 .get(&device_mesh)
498 .with_context(|| format!("invalid device mesh id: {:#?}", device_mesh))?;
499 match self.remote_process_groups.entry(result) {
500 Entry::Vacant(ent) => ent.insert(RemoteProcessGroupState::new(
501 device_mesh,
502 SortedVec::from_unsorted(dims),
503 )),
504 Entry::Occupied(ent) => bail!("remote process group {:?} already create", ent.key()),
505 };
506 Ok(())
507 }
508
509 async fn borrow_create(
510 &mut self,
511 cx: &hyperactor::Context<Self>,
512 result: Ref,
513 borrow_id: u64,
514 tensor_ref: Ref,
515 from_stream: StreamRef,
516 to_stream: StreamRef,
517 ) -> Result<()> {
518 self.maybe_add_stream_to_recording(cx, from_stream).await?;
519 self.maybe_add_stream_to_recording(cx, to_stream).await?;
520 let from_stream = self.try_get_stream(from_stream)?.clone();
521 let to_stream = self.try_get_stream(to_stream)?.clone();
522
523 let borrow =
524 Borrow::create(cx, borrow_id, tensor_ref, result, from_stream, to_stream).await?;
525 self.borrows.insert(borrow_id, borrow);
526 Ok(())
527 }
528
529 async fn borrow_first_use(
530 &mut self,
531 cx: &hyperactor::Context<Self>,
532 borrow: u64,
533 ) -> Result<()> {
534 let borrow = self
535 .borrows
536 .get_mut(&borrow)
537 .ok_or_else(|| anyhow!("invalid borrow id: {:#?}", borrow))?;
538
539 borrow.first_use(cx).await?;
540 Ok(())
541 }
542
543 async fn borrow_last_use(&mut self, cx: &hyperactor::Context<Self>, borrow: u64) -> Result<()> {
544 let borrow = self
545 .borrows
546 .get_mut(&borrow)
547 .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow))?;
548
549 borrow.last_use(cx).await?;
550 Ok(())
551 }
552
553 async fn borrow_drop(&mut self, cx: &hyperactor::Context<Self>, borrow_id: u64) -> Result<()> {
554 let borrow = self
555 .borrows
556 .get_mut(&borrow_id)
557 .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow_id))?;
558
559 borrow.drop(cx).await?;
560 self.borrows.remove(&borrow_id);
561 Ok(())
562 }
563
564 async fn delete_refs(&mut self, cx: &hyperactor::Context<Self>, refs: Vec<Ref>) -> Result<()> {
565 let _: Vec<()> = try_join_all(
572 self.streams
573 .values()
574 .map(|s| s.delete_refs(cx, refs.clone())),
575 )
576 .await?;
577 Ok(())
578 }
579
580 async fn request_status(
581 &mut self,
582 cx: &hyperactor::Context<Self>,
583 seq: Seq,
584 controller: bool,
585 ) -> Result<()> {
586 let _: Vec<()> = try_join_all(
591 self.streams
592 .values()
593 .map(|stream| stream.request_status(cx)),
594 )
595 .await?;
596
597 ControllerMessageClient::status(
598 &self.controller_actor,
599 cx,
600 seq.next(),
601 cx.self_addr().clone(),
602 controller,
603 )
604 .await?;
605 Ok(())
606 }
607
608 async fn reduce(
609 &mut self,
610 cx: &hyperactor::Context<Self>,
611 result: Ref,
612 local_tensor: Ref,
613 factory: Factory,
614 source_mesh: Ref,
615 stream_ref: StreamRef,
616 dims: Vec<String>,
617 reduction: Reduction,
618 scatter: bool,
619 in_place: bool,
620 out: Option<Ref>,
621 ) -> Result<()> {
622 self.maybe_add_stream_to_recording(cx, stream_ref).await?;
623
624 let dims = SortedVec::from_unsorted(dims);
626 let stream = self.try_get_stream(stream_ref)?.clone();
627
628 let (_, comm_map) = self
629 .device_meshes
630 .get_mut(&source_mesh)
631 .ok_or_else(|| anyhow::anyhow!("invalid device mesh id: {:#?}", source_mesh))?;
632
633 let (size, comm) = comm_map
634 .get(&(stream_ref, dims.clone()))
635 .ok_or_else(|| anyhow::anyhow!("no comm found for stream {stream:#?}, dims {dims:#?}"))?
636 .clone();
637
638 stream
639 .reduce(
640 cx,
641 comm,
642 size.try_into()?,
643 result,
644 local_tensor,
645 factory,
646 reduction,
647 scatter,
648 in_place,
649 out,
650 )
651 .await?;
652
653 Ok(())
654 }
655
656 async fn send_tensor(
657 &mut self,
658 cx: &hyperactor::Context<Self>,
659 result: Ref,
660 from_ranks: Slice,
661 to_ranks: Slice,
662 tensor: Ref,
663 factory: Factory,
664 from_stream: StreamRef,
665 to_stream: StreamRef,
666 ) -> Result<()> {
667 let to_rank = from_ranks
668 .index(self.rank)
669 .map(|index| to_ranks.get(index).ok())
670 .ok()
671 .flatten();
672 let from_rank = to_ranks
673 .index(self.rank)
674 .map(|index| from_ranks.get(index).ok())
675 .ok()
676 .flatten();
677
678 let (stream, stream_ref) = if to_rank.is_none() {
679 (self.try_get_stream(to_stream)?.clone(), to_stream)
680 } else if from_rank.is_none() || from_stream == to_stream {
681 (self.try_get_stream(from_stream)?.clone(), from_stream)
682 } else {
683 unimplemented!(
684 "We haven't implemented to_mesh between streams if a rank participates as both a sender and receiver. \
685 It is possible, but would require the recv stream to send the output buffer tensor to the send stream and sync. \
686 Then the send stream would do the nccl op, and then sync with sending stream again."
687 );
688 };
689 let comm = if from_rank == to_rank {
690 None
691 } else {
692 Some(
693 self.send_recv_comms
694 .get(&(from_stream, to_stream))
695 .ok_or_else(|| {
696 anyhow::anyhow!(
697 "could not find stream to stream comm for: {:#?}",
698 (from_stream, to_stream)
699 )
700 })?
701 .clone(),
702 )
703 };
704
705 self.maybe_add_stream_to_recording(cx, stream_ref).await?;
706
707 stream
708 .send_tensor(cx, result, from_rank, to_rank, tensor, factory, comm)
709 .await?;
710
711 Ok(())
712 }
713
714 async fn exit(
715 &mut self,
716 cx: &hyperactor::Context<Self>,
717 error: Option<(Option<reference::ActorAddr>, String)>,
718 ) -> Result<()> {
719 for (_, stream) in self.streams.drain() {
720 stream.drain_and_stop("tensor worker exit cleanup")?;
721 Arc::into_inner(stream)
722 .expect("there should be no owners of this stream handle except the worker stream table")
723 .await;
724 }
725
726 let self_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_SELF_ERROR_EXIT_CODE")
727 .ok()
728 .and_then(|val| val.parse::<i32>().ok())
729 .unwrap_or(1);
730 let peer_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_PEER_ERROR_EXIT_CODE")
731 .ok()
732 .and_then(|val| val.parse::<i32>().ok())
733 .unwrap_or(1);
734
735 let exit_code = match error {
737 Some((Some(actor_id), reason)) => {
738 tracing::error!(
739 "stopping the worker, actor {} failed with error: {}",
740 actor_id,
741 reason
742 );
743 if cx.self_addr() == &actor_id {
744 self_error_exit_code
745 } else {
746 peer_error_exit_code
747 }
748 }
749 Some((None, reason)) => {
750 tracing::error!("stopping the worker, reason: {}", reason);
751 1
752 }
753 None => 0,
754 };
755
756 if exit_code != 0 {
757 tracing::info!("stopping the worker process, exit code: {}", exit_code);
758 std::process::exit(exit_code);
759 }
760 cx.stop("tensor worker exit")?;
761 Ok(())
762 }
763
764 async fn send_value(
765 &mut self,
766 cx: &hyperactor::Context<Self>,
767 seq: Seq,
768 destination: Option<Ref>,
769 mutates: Vec<Ref>,
770 function: Option<ResolvableFunction>,
771 args_kwargs: ArgsKwargs,
772 stream: StreamRef,
773 ) -> Result<()> {
774 let stream = self.try_get_stream(stream)?;
776
777 let device_meshes = if function.is_none() {
778 HashMap::new()
779 } else {
780 self.device_meshes
781 .iter()
782 .map(|(k, v)| (k.clone(), v.0.clone()))
783 .collect()
784 };
785
786 if destination.is_some() {
787 panic!("send_value with pipe destination is no longer implemented")
788 }
789
790 stream
792 .send_value(
793 cx,
794 seq,
795 cx.self_addr().clone(),
796 mutates,
797 function,
798 args_kwargs,
799 device_meshes,
800 )
801 .await
802 }
803
804 async fn send_result_of_actor_call(
805 &mut self,
806 cx: &hyperactor::Context<Self>,
807 params: ActorCallParams,
808 ) -> Result<()> {
809 let stream = self.try_get_stream(params.stream)?;
810 stream.send_result_of_actor_call(cx, params).await?;
811 Ok(())
812 }
813 async fn call_actor_method(
814 &mut self,
815 cx: &hyperactor::Context<Self>,
816 params: ActorMethodParams,
817 ) -> Result<()> {
818 let stream = self.try_get_stream(params.call.stream)?;
819 stream.call_actor_method(cx, params).await?;
820 Ok(())
821 }
822 async fn split_comm(
823 &mut self,
824 cx: &hyperactor::Context<Self>,
825 dims: Vec<String>,
826 device_mesh: Ref,
827 stream_ref: StreamRef,
828 ) -> Result<()> {
829 let Some(global_comm) = self.comm.as_ref() else {
830 return Ok(());
831 };
832 match self.device_meshes.get_mut(&device_mesh) {
833 Some((device_mesh, comm_map)) => {
834 let stream = self
837 .streams
838 .get(&stream_ref)
839 .ok_or_else(|| anyhow::anyhow!("invalid stream id: {:#?}", stream_ref))?;
840
841 let dims = SortedVec::from_unsorted(dims);
842
843 anyhow::ensure!(
844 !comm_map.contains_key(&(stream_ref, dims.clone())),
845 "comm already exists for stream {stream:#?}, dims {dims:#?}"
846 );
847 let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&dims)?;
848 let size = ranks_for_group.len();
849 let split_comm = global_comm
850 .split_from(
851 cx,
852 ranks_for_group
853 .into_iter()
854 .map(|v| v.clone().try_into())
855 .collect::<Result<Vec<_>, _>>()?,
856 )
857 .await?
858 .context("split comm should include self rank")?;
859 comm_map.insert((stream_ref, dims), (size, Arc::new(split_comm)));
860 }
861 None => {
862 global_comm.split_from(cx, vec![]).await?;
865 }
866 }
867 Ok(())
868 }
869
870 async fn split_comm_for_process_group(
871 &mut self,
872 cx: &hyperactor::Context<Self>,
873 remote_process_group_ref: Ref,
874 stream_ref: StreamRef,
875 ) -> Result<()> {
876 ensure!(
877 self.streams.contains_key(&stream_ref),
878 "invalid stream id: {:#?}",
879 stream_ref
880 );
881 let Some(global_comm) = self.comm.as_ref() else {
882 return Ok(());
883 };
884 let state = self
885 .remote_process_groups
886 .get_mut(&remote_process_group_ref)
887 .with_context(|| format!("invalid remote process group id: {:#?}", stream_ref))?;
888 match self.device_meshes.get_mut(&state.device_mesh_ref) {
889 Some((device_mesh, _)) => {
890 let entry = match state.comms.entry(stream_ref) {
893 Entry::Vacant(entry) => entry,
894 Entry::Occupied(_) => bail!(
895 "comm already exists for remote process group {:#?} on stream {:#?}",
896 remote_process_group_ref,
897 stream_ref,
898 ),
899 };
900 let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&state.dims)?;
901 let split_comm = global_comm
902 .split_from(
903 cx,
904 ranks_for_group
905 .into_iter()
906 .map(|v| v.clone().try_into())
907 .collect::<Result<Vec<_>, _>>()?,
908 )
909 .await?
910 .context("split comm should include self rank")?;
911 entry.insert(Arc::new(split_comm));
912 }
913 None => {
914 global_comm.split_from(cx, vec![]).await?;
917 }
918 }
919 Ok(())
920 }
921
    /// Retired operation kept only to satisfy the handler trait.
    ///
    /// # Panics
    ///
    /// Always panics: `pipe_recv` is no longer implemented.
    async fn pipe_recv(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        _seq: Seq,
        _results: Vec<Option<Ref>>,
        _pipe: Ref,
        _stream: StreamRef,
    ) -> Result<()> {
        panic!("pipe_recv is no longer implemented")
    }
932
933 async fn set_ref_unit_tests_only(
934 &mut self,
935 cx: &hyperactor::Context<Self>,
936 reference: Ref,
937 value: WireValue,
938 stream: StreamRef,
939 ) -> Result<()> {
940 let stream = self.try_get_stream(stream)?;
941
942 stream.set_ref_unit_tests_only(cx, reference, value).await
943 }
944
945 async fn get_ref_unit_tests_only(
946 &mut self,
947 cx: &hyperactor::Context<Self>,
948 ref_id: Ref,
949 stream: StreamRef,
950 ) -> Result<Option<Result<WireValue, String>>> {
951 let stream = self.try_get_stream(stream)?;
952 Ok(stream.get_ref_unit_tests_only(cx, ref_id.clone()).await?)
953 }
954
955 async fn define_recording(
956 &mut self,
957 cx: &hyperactor::Context<Self>,
958 result: Ref,
959 _nresults: usize,
960 _nformals: usize,
961 commands: Vec<WorkerMessage>,
962 ntotal_messages: usize,
963 index: usize,
964 ) -> Result<()> {
965 if self.defining_recording.is_some() && self.defining_recording.unwrap() != result {
966 bail!("already defining a different recording");
967 }
968 self.defining_recording = Some(result);
969
970 match self.recordings.entry(result) {
971 Entry::Vacant(entry) => {
972 ensure!(
973 index == 0,
974 "got DefineRecording message with (index = {:?}) > 0 for previously unseen recording",
975 index
976 );
977 entry.insert(Recording::PartialRecording {
978 last_index: 0,
979 commands,
980 });
981 }
982 Entry::Occupied(mut entry) => match entry.get_mut() {
983 Recording::CompleteRecording { .. } => {
984 bail!("got DefineRecording message for already complete recording")
985 }
986 Recording::PartialRecording {
987 last_index,
988 commands: existing_commands,
989 } => {
990 ensure!(
991 index == *last_index + 1,
992 "Got DefineRecording message with index = {:?}, but \
993 last seen index for recording is {:?}",
994 index,
995 last_index
996 );
997 *last_index = index;
998 existing_commands.extend(commands);
999 }
1000 },
1001 };
1002
1003 if index < ntotal_messages - 1 {
1004 return Ok(());
1005 }
1006 let commands = match self.recordings.remove(&result).unwrap() {
1007 Recording::CompleteRecording { .. } => panic!("unreachable, in theory"),
1008 Recording::PartialRecording { commands, .. } => {
1009 self.recordings.insert(
1010 result,
1011 Recording::CompleteRecording {
1012 streams: HashSet::new(),
1013 },
1014 );
1015 commands
1016 }
1017 };
1018
1019 for command in commands {
1020 WorkerMessageHandler::handle(self, cx, command).await?;
1021 }
1022
1023 match self.recordings.get(&result).unwrap() {
1024 Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
1025 Recording::CompleteRecording { streams, .. } => {
1026 for stream in streams {
1027 self.try_get_stream(*stream)?
1028 .finalize_recording(cx, result)
1029 .await?;
1030 }
1031 }
1032 }
1033
1034 self.defining_recording = None;
1035 Ok(())
1036 }
1037
1038 async fn recording_formal(
1039 &mut self,
1040 cx: &hyperactor::Context<Self>,
1041 result: Ref,
1042 argument_index: usize,
1043 stream: StreamRef,
1044 ) -> Result<()> {
1045 ensure!(self.defining_recording.is_some());
1046 self.maybe_add_stream_to_recording(cx, stream).await?;
1047 self.try_get_stream(stream)?
1048 .recording_formal(cx, result, argument_index)
1049 .await
1050 }
1051
1052 async fn recording_result(
1053 &mut self,
1054 cx: &hyperactor::Context<Self>,
1055 result: Ref,
1056 output_index: usize,
1057 stream: StreamRef,
1058 ) -> Result<()> {
1059 ensure!(self.defining_recording.is_some());
1060 self.maybe_add_stream_to_recording(cx, stream).await?;
1061 self.try_get_stream(stream)?
1062 .recording_result(cx, result, output_index)
1063 .await
1064 }
1065
1066 async fn call_recording(
1067 &mut self,
1068 cx: &hyperactor::Context<Self>,
1069 seq: Seq,
1070 recording: Ref,
1071 results: Vec<Ref>,
1072 actuals: Vec<Ref>,
1073 ) -> Result<()> {
1074 ensure!(self.defining_recording.is_none());
1075 let recording_ref = recording;
1076 let recording = self.recordings.get(&recording).ok_or(anyhow::anyhow!(
1077 "could not find recording: {:#?}",
1078 recording
1079 ))?;
1080 match recording {
1081 Recording::PartialRecording { .. } => {
1082 bail!("cannot call recording because it is incomplete")
1083 }
1084 Recording::CompleteRecording { streams } => try_join_all(
1085 streams
1086 .iter()
1087 .map(|stream| self.try_get_stream(*stream))
1088 .collect::<Result<Vec<_>>>()?
1089 .into_iter()
1090 .map(|stream| {
1091 stream.call_recording(
1092 cx,
1093 seq,
1094 recording_ref,
1095 results.clone(),
1096 actuals.clone(),
1097 )
1098 }),
1099 )
1100 .await
1101 .map(|_| ()),
1102 }
1103 }
1104}
1105
1106#[cfg(all(test, fbcode_build))]
1107mod tests {
1108 use std::assert_matches;
1109
1110 use anyhow::Result;
1111 use hyperactor::RemoteSpawn;
1112 use hyperactor::channel::ChannelAddr;
1113 use hyperactor::proc::Proc;
1114 use monarch_messages::controller::ControllerMessage;
1115 use monarch_messages::controller::WorkerError;
1116 use monarch_messages::worker::WorkerMessageClient;
1117 use monarch_types::PickledPyObject;
1118 use pyo3::Python;
1119 use pyo3::prelude::*;
1120 use pyo3::types::PyList;
1121 use pyo3::types::PyString;
1122 use rand::RngExt as _;
1123 use rand::distr::Alphanumeric;
1124 use timed_test::async_timed_test;
1125 use torch_sys_cuda::nccl::UniqueIdExt;
1126
1127 use super::*;
1128 use crate::test_util::test_setup;
1129
1130 #[async_timed_test(timeout_secs = 60)]
1131 async fn basic_worker() -> Result<()> {
1132 test_setup()?;
1133
1134 let proc = Proc::isolated();
1135 let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1136
1137 let worker_handle = proc
1138 .spawn(
1139 "worker",
1140 WorkerActor::new(
1141 WorkerParams {
1142 world_size: 1,
1143 rank: 0,
1144 device_index: None,
1145 controller_actor: controller_ref,
1146 },
1147 Flattrs::default(),
1148 )
1149 .await
1150 .unwrap(),
1151 )
1152 .unwrap();
1153 worker_handle
1154 .command_group(
1155 &client,
1156 vec![
1157 WorkerMessage::CreateStream {
1158 id: 1.into(),
1159 stream_creation: StreamCreationMode::UseDefaultStream,
1160 },
1161 WorkerMessage::CallFunction(CallFunctionParams {
1162 seq: 0.into(),
1163 results: vec![Some(0.into())],
1164 mutates: vec![],
1165 function: "torch.ops.aten.ones.default".into(),
1166 args_kwargs: ArgsKwargs::from_wire_values(
1167 vec![WireValue::IntList(vec![2, 3])],
1168 HashMap::new(),
1169 )
1170 .unwrap(),
1171 stream: 1.into(),
1172 remote_process_groups: vec![],
1173 }),
1174 WorkerMessage::CallFunction(CallFunctionParams {
1175 seq: 2.into(),
1176 results: vec![Some(Ref { id: 2 })],
1177 mutates: vec![0.into()],
1178 function: "torch.ops.aten.sub_.Scalar".into(),
1179 args_kwargs: ArgsKwargs::from_wire_values(
1180 vec![WireValue::Ref(0.into()), WireValue::Int(1)],
1181 HashMap::new(),
1182 )
1183 .unwrap(),
1184 stream: 1.into(),
1185 remote_process_groups: vec![],
1186 }),
1187 WorkerMessage::CallFunction(CallFunctionParams {
1188 seq: 3.into(),
1189 results: vec![Some(Ref { id: 3 })],
1190 mutates: vec![],
1191 function: "torch.ops.aten.zeros.default".into(),
1192 args_kwargs: ArgsKwargs::from_wire_values(
1193 vec![WireValue::IntList(vec![2, 3])],
1194 HashMap::new(),
1195 )
1196 .unwrap(),
1197 stream: 1.into(),
1198 remote_process_groups: vec![],
1199 }),
1200 WorkerMessage::CallFunction(CallFunctionParams {
1201 seq: 4.into(),
1202 results: vec![Some(Ref { id: 4 })],
1203 mutates: vec![],
1204 function: "torch.ops.aten.allclose.default".into(),
1205 args_kwargs: ArgsKwargs::from_wire_values(
1206 vec![WireValue::Ref(0.into()), WireValue::Ref(Ref { id: 3 })],
1207 HashMap::new(),
1208 )
1209 .unwrap(),
1210 stream: 1.into(),
1211 remote_process_groups: vec![],
1212 }),
1213 ],
1214 )
1215 .await
1216 .unwrap();
1217
1218 let result: bool = worker_handle
1219 .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
1220 .await
1221 .unwrap()
1222 .unwrap()
1223 .unwrap()
1224 .try_into()
1225 .unwrap();
1226 worker_handle.drain_and_stop("test").unwrap();
1227 worker_handle.await;
1228 let error_responses = controller_rx.drain();
1229 assert!(
1230 error_responses.is_empty(),
1231 "Expected no error responses, got: {:#?}",
1232 error_responses
1233 );
1234 assert!(result);
1235
1236 Ok(())
1237 }
1238
1239 #[async_timed_test(timeout_secs = 60)]
1240 async fn error_sends_response() -> Result<()> {
1241 test_setup()?;
1242
1243 let proc = Proc::isolated();
1244 let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1245
1246 let worker_handle = proc
1247 .spawn(
1248 "worker",
1249 WorkerActor::new(
1250 WorkerParams {
1251 world_size: 1,
1252 rank: 0,
1253 device_index: None,
1254 controller_actor: controller_ref,
1255 },
1256 Flattrs::default(),
1257 )
1258 .await
1259 .unwrap(),
1260 )
1261 .unwrap();
1262 worker_handle
1263 .command_group(
1264 &client,
1265 vec![
1266 WorkerMessage::CreateStream {
1267 id: 1.into(),
1268 stream_creation: StreamCreationMode::UseDefaultStream,
1269 },
1270 WorkerMessage::CallFunction(CallFunctionParams {
1271 seq: 0.into(),
1272 results: vec![Some(0.into())],
1273 mutates: vec![],
1274 function: "torch.ops.aten.rand.default".into(),
1275 args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
1276 stream: 1.into(),
1277 remote_process_groups: vec![],
1278 }),
1279 WorkerMessage::Exit { error: None },
1280 ],
1281 )
1282 .await
1283 .unwrap();
1284
1285 worker_handle.drain_and_stop("test").unwrap();
1286 worker_handle.await;
1287 let response_message = controller_rx.recv().await.unwrap();
1288 match response_message {
1289 ControllerMessage::RemoteFunctionFailed {
1290 seq,
1291 error: WorkerError { backtrace: msg, .. },
1292 } => {
1293 assert_eq!(seq, 0.into());
1294 assert!(msg.contains("aten::rand() is missing value for argument 'size'"))
1295 }
1296 _ => panic!("unexpected response {:#?}", response_message),
1297 }
1298
1299 Ok(())
1300 }
1301
1302 #[async_timed_test(timeout_secs = 60)]
1303 async fn mutated_refs_are_updated_with_error() -> Result<()> {
1304 test_setup()?;
1305
1306 let proc = Proc::isolated();
1307 let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1308
1309 let worker_handle = proc
1310 .spawn(
1311 "worker",
1312 WorkerActor::new(
1313 WorkerParams {
1314 world_size: 1,
1315 rank: 0,
1316 device_index: None,
1317 controller_actor: controller_ref,
1318 },
1319 Flattrs::default(),
1320 )
1321 .await
1322 .unwrap(),
1323 )
1324 .unwrap();
1325 worker_handle
1326 .command_group(
1327 &client,
1328 vec![
1329 WorkerMessage::CreateStream {
1330 id: 1.into(),
1331 stream_creation: StreamCreationMode::UseDefaultStream,
1332 },
1333 WorkerMessage::SetRefUnitTestsOnly {
1334 reference: 0.into(),
1335 value: WireValue::Int(1),
1336 stream: 1.into(),
1337 },
1338 WorkerMessage::CallFunction(CallFunctionParams {
1339 seq: 0.into(),
1340 results: vec![Some(Ref { id: 2 })],
1341 mutates: vec![0.into()],
1342 function: "i.dont.exist".into(),
1343 args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
1344 stream: 1.into(),
1345 remote_process_groups: vec![],
1346 }),
1347 ],
1348 )
1349 .await
1350 .unwrap();
1351
1352 let result = worker_handle
1353 .get_ref_unit_tests_only(&client, 0.into(), 1.into())
1354 .await?;
1355
1356 worker_handle.drain_and_stop("test").unwrap();
1358 worker_handle.await;
1359
1360 let mutated_ref = result
1361 .context("no such ref")?
1362 .err()
1363 .context("expected error")?;
1364 assert!(mutated_ref.contains("failed to resolve function"));
1365
1366 let responses = controller_rx.drain();
1367 assert_eq!(
1368 responses.len(),
1369 1,
1370 "Expected one response, got: {:#?}",
1371 responses
1372 );
1373 Ok(())
1374 }
1375
1376 #[async_timed_test(timeout_secs = 60)]
1377 async fn accessing_errored_dependency() -> Result<()> {
1378 test_setup()?;
1379
1380 let proc = Proc::isolated();
1381 let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1382
1383 let worker_handle = proc
1384 .spawn(
1385 "worker",
1386 WorkerActor::new(
1387 WorkerParams {
1388 world_size: 1,
1389 rank: 0,
1390 device_index: None,
1391 controller_actor: controller_ref,
1392 },
1393 Flattrs::default(),
1394 )
1395 .await
1396 .unwrap(),
1397 )
1398 .unwrap();
1399 worker_handle
1400 .command_group(
1401 &client,
1402 vec![
1403 WorkerMessage::CreateStream {
1404 id: 1.into(),
1405 stream_creation: StreamCreationMode::UseDefaultStream,
1406 },
1407 WorkerMessage::CallFunction(CallFunctionParams {
1408 seq: 0.into(),
1409 results: vec![Some(0.into())],
1410 mutates: vec![],
1411 function: "i.dont.exist".into(),
1412 args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
1413 stream: 1.into(),
1414 remote_process_groups: vec![],
1415 }),
1416 WorkerMessage::CallFunction(CallFunctionParams {
1417 seq: 1.into(),
1418 results: vec![Some(1.into())],
1419 mutates: vec![],
1420 function: "torch.ops.aten.sub_.Scalar".into(),
1421 args_kwargs: ArgsKwargs::from_wire_values(
1422 vec![WireValue::Ref(0.into())],
1423 HashMap::new(),
1424 )
1425 .unwrap(),
1426 stream: 1.into(),
1427 remote_process_groups: vec![],
1428 }),
1429 WorkerMessage::Exit { error: None },
1430 ],
1431 )
1432 .await
1433 .unwrap();
1434
1435 worker_handle.drain_and_stop("test").unwrap();
1436 worker_handle.await;
1437
1438 let responses = controller_rx.drain();
1439 assert_eq!(
1440 responses.len(),
1441 1,
1442 "Expected one response, got: {:#?}",
1443 responses
1444 );
1445
1446 match &responses[0] {
1447 ControllerMessage::RemoteFunctionFailed { seq, .. } => {
1448 assert_eq!(seq, &0.into())
1449 }
1450 _ => panic!("unexpected response {:#?}", responses[0]),
1451 };
1452 Ok(())
1453 }
1454
/// Exercises `CallFunction` against a mix of Python callables (`os.path.split`,
/// `builtins.sorted`, helpers under `monarch.monarch_tensor_worker.test_utils`)
/// and torch ops, passing pickled Python objects and refs as arguments.
///
/// Afterwards the test reads back the produced refs and asserts that no
/// responses were sent to the controller.
#[async_timed_test(timeout_secs = 60)]
async fn py_remote_function_calls() -> Result<()> {
    test_setup()?;

    let proc = Proc::isolated();
    let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

    // Single-rank worker with no device index.
    let worker_handle = proc
        .spawn(
            "worker",
            WorkerActor::new(
                WorkerParams {
                    world_size: 1,
                    rank: 0,
                    device_index: None,
                    controller_actor: controller_ref,
                },
                Flattrs::default(),
            )
            .await
            .unwrap(),
        )
        .unwrap();
    // Build every Python-side argument up front and convert each one into a
    // `PickledPyObject` so it can be placed into the worker messages below.
    let (split_arg, sort_list, dim, layout, none, scalar, device, memory_format) =
        Python::attach(|py| {
            let split_arg: PickledPyObject = PyString::new(py, "/fbs/fbc/foo/bar")
                .into_any()
                .try_into()?;
            let sort_list: PickledPyObject =
                PyList::new(py, [65, 34, 79, 1, 5])?.into_any().try_into()?;
            let dim: PickledPyObject = PyString::new(py, "x").into_any().try_into()?;
            let layout: PickledPyObject = py.import("torch")?.getattr("strided")?.try_into()?;
            let none: PickledPyObject = py.None().into_any().into_bound(py).try_into()?;
            let scalar: PickledPyObject = py.import("torch")?.getattr("float32")?.try_into()?;
            let device: PickledPyObject = py
                .import("torch")?
                .getattr("device")?
                .call1(("cuda:1",))?
                .try_into()?;
            let memory_format: PickledPyObject = py
                .import("torch")?
                .getattr("contiguous_format")?
                .try_into()?;
            PyResult::Ok((
                split_arg,
                sort_list,
                dim,
                layout,
                none,
                scalar,
                device,
                memory_format,
            ))
        })?;

    worker_handle
        .command_group(
            &client,
            vec![
                WorkerMessage::CreateStream {
                    id: 1.into(),
                    stream_creation: StreamCreationMode::UseDefaultStream,
                },
                // `os.path.split` returns two values; they are bound to refs 0 and 2.
                WorkerMessage::CallFunction(CallFunctionParams {
                    seq: 0.into(),
                    results: vec![Some(0.into()), Some(Ref { id: 2 })],
                    mutates: vec![],
                    function: "os.path.split".into(),
                    args_kwargs: ArgsKwargs::from_wire_values(
                        vec![split_arg.into()],
                        HashMap::new(),
                    )
                    .unwrap(),
                    stream: 1.into(),
                    remote_process_groups: vec![],
                }),
                // Only the first of the five sorted elements is bound (ref 4);
                // the remaining result slots are `None`.
                WorkerMessage::CallFunction(CallFunctionParams {
                    seq: 2.into(),
                    results: vec![Some(4.into()), None, None, None, None],
                    mutates: vec![],
                    function: "builtins.sorted".into(),
                    args_kwargs: ArgsKwargs::from_wire_values(
                        vec![sort_list.into()],
                        HashMap::new(),
                    )
                    .unwrap(),
                    stream: 1.into(),
                    remote_process_groups: vec![],
                }),
                // Device mesh with a single dimension named "x", stored as ref 5.
                WorkerMessage::CreateDeviceMesh {
                    result: 5.into(),
                    names: vec!["x".into()],
                    ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
                },
                // NOTE(review): seq 2 is also used by the `builtins.sorted` call
                // above, while seqs 1 and 3 are never used — confirm whether the
                // duplicate is intentional.
                WorkerMessage::CallFunction(CallFunctionParams {
                    seq: 2.into(),
                    results: vec![Some(6.into())],
                    mutates: vec![],
                    function: "monarch.monarch_tensor_worker.test_utils.mesh_rank".into(),
                    args_kwargs: ArgsKwargs::from_wire_values(
                        vec![WireValue::Ref(Ref { id: 5 }), dim.into()],
                        HashMap::new(),
                    )
                    .unwrap(),
                    stream: 1.into(),
                    remote_process_groups: vec![],
                }),
                // The following calls each pass one pickled torch/Python object
                // (dtype, layout, None, device, memory format) to a test helper.
                WorkerMessage::CallFunction(CallFunctionParams {
                    seq: 4.into(),
                    results: vec![Some(7.into())],
                    mutates: vec![],
                    function: "monarch.monarch_tensor_worker.test_utils.test_scalar_type"
                        .into(),
                    args_kwargs: ArgsKwargs::from_wire_values(
                        vec![scalar.into()],
                        HashMap::new(),
                    )
                    .unwrap(),
                    stream: 1.into(),
                    remote_process_groups: vec![],
                }),
                WorkerMessage::CallFunction(CallFunctionParams {
                    seq: 5.into(),
                    results: vec![Some(8.into())],
                    mutates: vec![],
                    function: "monarch.monarch_tensor_worker.test_utils.test_layout".into(),
                    args_kwargs: ArgsKwargs::from_wire_values(
                        vec![layout.into()],
                        HashMap::new(),
                    )
                    .unwrap(),
                    stream: 1.into(),
                    remote_process_groups: vec![],
                }),
                WorkerMessage::CallFunction(CallFunctionParams {
                    seq: 6.into(),
                    results: vec![Some(9.into())],
                    mutates: vec![],
                    function: "monarch.monarch_tensor_worker.test_utils.test_none".into(),
                    args_kwargs: ArgsKwargs::from_wire_values(
                        vec![none.into()],
                        HashMap::new(),
                    )
                    .unwrap(),
                    stream: 1.into(),
                    remote_process_groups: vec![],
                }),
                // Called with no arguments; its single result slot is `None`,
                // so nothing is bound to a ref.
                WorkerMessage::CallFunction(CallFunctionParams {
                    seq: 7.into(),
                    results: vec![None],
                    mutates: vec![],
                    function: "monarch.monarch_tensor_worker.test_utils.none".into(),
                    args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
                    stream: 1.into(),
                    remote_process_groups: vec![],
                }),
                WorkerMessage::CallFunction(CallFunctionParams {
                    seq: 8.into(),
                    results: vec![Some(10.into())],
                    mutates: vec![],
                    function: "monarch.monarch_tensor_worker.test_utils.test_device".into(),
                    args_kwargs: ArgsKwargs::from_wire_values(
                        vec![device.into()],
                        HashMap::new(),
                    )
                    .unwrap(),
                    stream: 1.into(),
                    remote_process_groups: vec![],
                }),
                WorkerMessage::CallFunction(CallFunctionParams {
                    seq: 9.into(),
                    results: vec![Some(11.into())],
                    mutates: vec![],
                    function: "monarch.monarch_tensor_worker.test_utils.test_memory_format"
                        .into(),
                    args_kwargs: ArgsKwargs::from_wire_values(
                        vec![memory_format.into()],
                        HashMap::new(),
                    )
                    .unwrap(),
                    stream: 1.into(),
                    remote_process_groups: vec![],
                }),
                // Create a tensor as ref 12, then pass it twice in a `RefList`
                // to `aten.stack` to exercise list-of-refs arguments.
                WorkerMessage::CallFunction(CallFunctionParams {
                    seq: 10.into(),
                    results: vec![Some(12.into())],
                    mutates: vec![],
                    function: "torch.ops.aten.ones.default".into(),
                    args_kwargs: ArgsKwargs::from_wire_values(
                        vec![WireValue::IntList(vec![2, 3])],
                        HashMap::new(),
                    )
                    .unwrap(),
                    stream: 1.into(),
                    remote_process_groups: vec![],
                }),
                WorkerMessage::CallFunction(CallFunctionParams {
                    seq: 11.into(),
                    results: vec![Some(13.into())],
                    mutates: vec![],
                    function: "torch.ops.aten.stack.default".into(),
                    args_kwargs: ArgsKwargs::from_wire_values(
                        vec![WireValue::RefList(vec![12.into(), 12.into()])],
                        HashMap::new(),
                    )
                    .unwrap(),
                    stream: 1.into(),
                    remote_process_groups: vec![],
                }),
            ],
        )
        .await
        .unwrap();

    // Read the results back. Each fetch unwraps three layers; presumably these
    // are the request itself, the ref's existence, and the stored value not
    // being an error — confirm against `get_ref_unit_tests_only`.
    let result1: String = worker_handle
        .get_ref_unit_tests_only(&client, 0.into(), 1.into())
        .await
        .unwrap()
        .unwrap()
        .unwrap()
        .try_into()
        .unwrap();
    let result2: String = worker_handle
        .get_ref_unit_tests_only(&client, 2.into(), 1.into())
        .await
        .unwrap()
        .unwrap()
        .unwrap()
        .try_into()
        .unwrap();
    let result3: i64 = worker_handle
        .get_ref_unit_tests_only(&client, 4.into(), 1.into())
        .await
        .unwrap()
        .unwrap()
        .unwrap()
        .try_into()
        .unwrap();
    let result4: i64 = worker_handle
        .get_ref_unit_tests_only(&client, 6.into(), 1.into())
        .await
        .unwrap()
        .unwrap()
        .unwrap()
        .try_into()
        .unwrap();
    // Refs 7, 8, 10 and 11 are fetched only to check that the fetch succeeds;
    // their values are not inspected.
    worker_handle
        .get_ref_unit_tests_only(&client, 7.into(), 1.into())
        .await
        .unwrap()
        .unwrap()
        .unwrap();

    worker_handle
        .get_ref_unit_tests_only(&client, 8.into(), 1.into())
        .await
        .unwrap()
        .unwrap()
        .unwrap();

    // Ref 9 (result of `test_none`) must come back as `WireValue::None`.
    assert_matches!(
        worker_handle
            .get_ref_unit_tests_only(&client, 9.into(), 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap(),
        WireValue::None(()),
    );
    worker_handle
        .get_ref_unit_tests_only(&client, 10.into(), 1.into())
        .await
        .unwrap()
        .unwrap()
        .unwrap();
    worker_handle
        .get_ref_unit_tests_only(&client, 11.into(), 1.into())
        .await
        .unwrap()
        .unwrap()
        .unwrap();

    // Stop the worker, then verify nothing was sent to the controller.
    worker_handle.drain_and_stop("test").unwrap();
    worker_handle.await;
    let error_responses = controller_rx.drain();
    assert!(
        error_responses.is_empty(),
        "Expected no error responses, got: {:#?}",
        error_responses
    );

    // `os.path.split("/fbs/fbc/foo/bar")` yields the head and tail components.
    assert_eq!(result1, "/fbs/fbc/foo");
    assert_eq!(result2, "bar");
    // Smallest element of the sorted list.
    assert_eq!(result3, 1);
    assert_eq!(result4, 0);

    Ok(())
}
1756
1757 #[async_timed_test(timeout_secs = 60)]
1758 async fn delete_refs() -> Result<()> {
1759 test_setup()?;
1760
1761 let proc = Proc::isolated();
1762 let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();
1763
1764 let worker_handle = proc
1765 .spawn(
1766 "worker",
1767 WorkerActor::new(
1768 WorkerParams {
1769 world_size: 1,
1770 rank: 0,
1771 device_index: None,
1772 controller_actor: controller_ref,
1773 },
1774 Flattrs::default(),
1775 )
1776 .await
1777 .unwrap(),
1778 )
1779 .unwrap();
1780 worker_handle
1781 .command_group(
1782 &client,
1783 vec![
1784 WorkerMessage::CreateStream {
1785 id: 0.into(),
1786 stream_creation: StreamCreationMode::CreateNewStream,
1787 },
1788 WorkerMessage::CreateStream {
1789 id: 1.into(),
1790 stream_creation: StreamCreationMode::CreateNewStream,
1791 },
1792 WorkerMessage::SetRefUnitTestsOnly {
1793 reference: Ref { id: 2 },
1794 value: WireValue::Bool(false),
1795 stream: 0.into(),
1796 },
1797 WorkerMessage::SetRefUnitTestsOnly {
1798 reference: Ref { id: 3 },
1799 value: WireValue::Bool(true),
1800 stream: 0.into(),
1801 },
1802 WorkerMessage::SetRefUnitTestsOnly {
1803 reference: Ref { id: 4 },
1804 value: WireValue::Int(0),
1805 stream: 1.into(),
1806 },
1807 WorkerMessage::DeleteRefs(vec![Ref { id: 2 }, Ref { id: 4 }]),
1808 ],
1809 )
1810 .await
1811 .unwrap();
1812
1813 let result: bool = worker_handle
1814 .get_ref_unit_tests_only(&client, Ref { id: 3 }, 0.into())
1815 .await
1816 .unwrap()
1817 .unwrap()
1818 .unwrap()
1819 .try_into()
1820 .unwrap();
1821 let fail_result = worker_handle
1822 .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
1823 .await
1824 .unwrap();
1825
1826 worker_handle.drain_and_stop("test").unwrap();
1827 worker_handle.await;
1828
1829 assert!(result, "should be able to get a non-deleted ref");
1830 assert!(fail_result.is_none(), "should fail to get a deleted ref");
1831
1832 Ok(())
1833 }
1834
1835 #[async_timed_test(timeout_secs = 60)]
1836 async fn request_status() -> Result<()> {
1837 test_setup()?;
1838
1839 let proc = Proc::isolated();
1840 let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1841
1842 let worker_handle = proc
1843 .spawn(
1844 "worker",
1845 WorkerActor::new(
1846 WorkerParams {
1847 world_size: 1,
1848 rank: 0,
1849 device_index: None,
1850 controller_actor: controller_ref,
1851 },
1852 Flattrs::default(),
1853 )
1854 .await
1855 .unwrap(),
1856 )
1857 .unwrap();
1858 worker_handle
1859 .command_group(
1860 &client,
1861 vec![
1862 WorkerMessage::CreateStream {
1863 id: 0.into(),
1864 stream_creation: StreamCreationMode::CreateNewStream,
1865 },
1866 WorkerMessage::CreateStream {
1867 id: 1.into(),
1868 stream_creation: StreamCreationMode::CreateNewStream,
1869 },
1870 ],
1871 )
1872 .await
1873 .unwrap();
1874
1875 for i in 0..100 {
1876 worker_handle
1878 .call_function(
1879 &client,
1880 CallFunctionParams {
1881 seq: i.into(),
1882 results: vec![Some(Ref { id: i + 2 })],
1883 mutates: vec![],
1884 function: "torch.ops.aten.ones.default".into(),
1885 args_kwargs: ArgsKwargs::from_wire_values(
1886 vec![WireValue::IntList(vec![2, 3])],
1887 HashMap::new(),
1888 )
1889 .unwrap(),
1890 stream: (i % 2).into(),
1891 remote_process_groups: vec![],
1892 },
1893 )
1894 .await
1895 .unwrap();
1896 }
1897
1898 worker_handle
1899 .request_status(&client, 100.into(), false)
1900 .await
1901 .unwrap();
1902
1903 worker_handle.drain_and_stop("test").unwrap();
1904 worker_handle.await;
1905
1906 let mut responses = controller_rx.drain();
1907 assert_eq!(
1908 responses.len(),
1909 1,
1910 "Expected one response, got: {:#?}",
1911 responses
1912 );
1913
1914 let response = responses.pop().unwrap();
1915 match response {
1916 ControllerMessage::Status { seq, .. } => {
1917 assert_eq!(seq, 101.into())
1918 }
1919 _ => panic!("unexpected response {:#?}", response),
1920 };
1921
1922 Ok(())
1923 }
1924
1925 #[async_timed_test(timeout_secs = 60)]
1926 async fn backend_network_init() {
1927 test_setup().unwrap();
1928 let proc = Proc::isolated();
1929 let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();
1930
1931 let worker_handle1 = proc
1932 .spawn(
1933 "worker0",
1934 WorkerActor::new(
1935 WorkerParams {
1936 world_size: 2,
1937 rank: 0,
1938 device_index: Some(0),
1939 controller_actor: controller_ref.clone(),
1940 },
1941 Flattrs::default(),
1942 )
1943 .await
1944 .unwrap(),
1945 )
1946 .unwrap();
1947 let worker_handle2 = proc
1948 .spawn(
1949 "worker1",
1950 WorkerActor::new(
1951 WorkerParams {
1952 world_size: 2,
1953 rank: 1,
1954 device_index: Some(1),
1955 controller_actor: controller_ref,
1956 },
1957 Flattrs::default(),
1958 )
1959 .await
1960 .unwrap(),
1961 )
1962 .unwrap();
1963
1964 let unique_id = UniqueId::new_nccl().unwrap();
1965 worker_handle1
1966 .backend_network_init(&client, unique_id.clone())
1967 .await
1968 .unwrap();
1969 worker_handle2
1970 .backend_network_init(&client, unique_id)
1971 .await
1972 .unwrap();
1973
1974 worker_handle1.drain_and_stop("test").unwrap();
1975 worker_handle1.await;
1976 worker_handle2.drain_and_stop("test").unwrap();
1977 worker_handle2.await;
1978 }
1979
1980 #[allow(dead_code)]
1981 fn get_random_channel_addr() -> ChannelAddr {
1982 let random_string = rand::rng()
1983 .sample_iter(&Alphanumeric)
1984 .take(24)
1985 .map(char::from)
1986 .collect::<String>();
1987 format!("unix!@{random_string}").parse().unwrap()
1988 }
1989}