1#![feature(assert_matches)]
10#![feature(duration_constructors)]
11#![feature(exit_status_error)]
12#![allow(unsafe_op_in_unsafe_fn)]
15
16mod borrow;
31mod comm;
32pub mod device_mesh;
33pub mod stream;
34pub mod test_util;
35
36use std::collections::HashMap;
37use std::collections::HashSet;
38use std::collections::hash_map::Entry;
39use std::sync::Arc;
40
41use anyhow::Context;
42use anyhow::Result;
43use anyhow::anyhow;
44use anyhow::bail;
45use anyhow::ensure;
46use async_trait::async_trait;
47use borrow::Borrow;
48use comm::CommMessageClient;
49use comm::CommParams;
50use comm::NcclCommActor;
51use derive_more::TryInto;
52use device_mesh::DeviceMesh;
53use futures::future::try_join_all;
54use hyperactor::Actor;
55use hyperactor::Bind;
56use hyperactor::Handler;
57use hyperactor::RemoteSpawn;
58use hyperactor::Unbind;
59use hyperactor::actor::ActorHandle;
60use hyperactor::context;
61use hyperactor::reference;
62use hyperactor_config::Flattrs;
63use hyperactor_mesh::comm::multicast::CastInfo;
64use itertools::Itertools;
65use monarch_hyperactor::shape::PyPoint;
66use monarch_messages::controller::ControllerActor;
67use monarch_messages::controller::ControllerMessageClient;
68use monarch_messages::controller::Seq;
69use monarch_messages::wire_value::WireValue;
70use monarch_messages::worker::ActorCallParams;
71use monarch_messages::worker::ActorMethodParams;
72use monarch_messages::worker::ArgsKwargs;
73use monarch_messages::worker::CallFunctionParams;
74use monarch_messages::worker::Factory;
75use monarch_messages::worker::Reduction;
76use monarch_messages::worker::Ref;
77use monarch_messages::worker::ResolvableFunction;
78use monarch_messages::worker::StreamCreationMode;
79use monarch_messages::worker::StreamRef;
80use monarch_messages::worker::WorkerMessage;
81use monarch_messages::worker::WorkerMessageHandler;
82use monarch_messages::worker::WorkerParams;
83use ndslice::Slice;
84use pyo3::Python;
85use pyo3::types::PyAnyMethods;
86use serde::Deserialize;
87use serde::Serialize;
88use sorted_vec::SortedVec;
89use stream::StreamActor;
90use stream::StreamMessageClient;
91use stream::StreamParams;
92use torch_sys_cuda::nccl::ReduceOp;
93use torch_sys_cuda::nccl::UniqueId;
94use torch_sys2::CudaDevice;
95use torch_sys2::DeviceIndex;
96use torch_sys2::Layout;
97use torch_sys2::ScalarType;
98use torch_sys2::TensorCell;
99use torch_sys2::factory_zeros;
100use typeuri::Named;
101
/// Bookkeeping for a remote process group created via
/// `CreateRemoteProcessGroup`: which device mesh it lives on, which mesh
/// dimensions it spans, and the per-stream NCCL communicators split for it.
#[derive(Debug)]
struct RemoteProcessGroupState {
    // Reference to the device mesh this process group was created over.
    device_mesh_ref: Ref,
    // Mesh dimensions the group spans; kept sorted so keys are canonical.
    dims: SortedVec<String>,
    // One split communicator per stream that has run
    // SplitCommForProcessGroup for this group.
    comms: HashMap<StreamRef, Arc<ActorHandle<NcclCommActor>>>,
}
108
109impl RemoteProcessGroupState {
110 fn new(device_mesh_ref: Ref, dims: SortedVec<String>) -> Self {
111 Self {
112 device_mesh_ref,
113 dims,
114 comms: HashMap::new(),
115 }
116 }
117}
118
/// State of a recording as chunked `DefineRecording` messages arrive.
#[derive(Debug)]
enum Recording {
    /// Still accumulating command chunks; more messages are expected.
    PartialRecording {
        // Index of the last chunk received; used to enforce in-order arrival.
        last_index: usize,
        // Commands accumulated so far, in arrival order.
        commands: Vec<WorkerMessage>,
    },

    /// All chunks received; tracks the streams that participate in replay.
    CompleteRecording {
        // Streams that executed at least one command of this recording.
        streams: HashSet<StreamRef>,
    },
}
136
/// Actor owning all per-rank tensor-worker state: stream actors, device
/// meshes, borrows, NCCL communicators, and in-flight recordings. Exported
/// with castable handlers for `WorkerMessage` and `AssignRankMessage`.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        WorkerMessage {cast = true},
        AssignRankMessage {cast = true},
    ],
)]
pub struct WorkerActor {
    // CUDA device this worker drives; None for non-CUDA workers.
    device: Option<CudaDevice>,
    // Stream actors registered via CreateStream, keyed by their ref.
    streams: HashMap<StreamRef, Arc<ActorHandle<StreamActor>>>,
    // Each device mesh plus, per mesh, the comms split for every
    // (stream, sorted dims) pair together with the group size.
    device_meshes: HashMap<
        Ref,
        (
            DeviceMesh,
            HashMap<(StreamRef, SortedVec<String>), (usize, Arc<ActorHandle<NcclCommActor>>)>,
        ),
    >,
    world_size: usize,
    // May be overwritten by AssignRankMessage based on cast position.
    rank: usize,
    // Active borrows keyed by borrow id.
    borrows: HashMap<u64, Borrow>,
    // Global NCCL communicator; set by BackendNetworkInit.
    comm: Option<ActorHandle<NcclCommActor>>,
    controller_actor: reference::ActorRef<ControllerActor>,
    remote_process_groups: HashMap<Ref, RemoteProcessGroupState>,
    // Dedicated comms for point-to-point transfers, keyed by
    // (from_stream, to_stream).
    send_recv_comms: HashMap<(StreamRef, StreamRef), Arc<ActorHandle<NcclCommActor>>>,
    // Recordings by ref, partial or complete.
    recordings: HashMap<Ref, Recording>,
    // Set while a DefineRecording sequence for this ref is in flight.
    defining_recording: Option<Ref>,
    // Set to true when AssignRank is handled (mesh-controller mode).
    respond_with_python_message: bool,
}
180
181impl WorkerActor {
182 fn try_get_stream(&self, stream: StreamRef) -> Result<&Arc<ActorHandle<StreamActor>>> {
183 self.streams
184 .get(&stream)
185 .ok_or(anyhow::anyhow!("invalid stream id: {:#?}", stream))
186 }
187
188 async fn maybe_add_stream_to_recording(
189 &mut self,
190 cx: &impl context::Actor,
191 stream: StreamRef,
192 ) -> Result<()> {
193 if let Some(defining_recording) = self.defining_recording {
196 let recording = self.recordings.get_mut(&defining_recording).unwrap();
197 let fut = match recording {
198 Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
199 Recording::CompleteRecording { streams } => {
200 streams.insert(stream).then(|| -> Result<_, anyhow::Error> {
201 Ok(self
202 .try_get_stream(stream)?
203 .define_recording(cx, defining_recording))
204 })
205 }
206 }
207 .transpose()?;
208 match fut {
209 Some(fut) => fut.await,
210 None => Ok(()),
211 }
212 } else {
213 Ok(())
214 }
215 }
216}
217
// No custom lifecycle hooks needed; the default Actor impl suffices.
impl Actor for WorkerActor {}
219
#[async_trait]
impl RemoteSpawn for WorkerActor {
    type Params = WorkerParams;

    /// Build a fresh worker from its spawn parameters; all tables start
    /// empty, and NCCL/device mesh state is populated later by messages.
    async fn new(
        WorkerParams {
            world_size,
            rank,
            device_index,
            controller_actor,
        }: Self::Params,
        _environment: Flattrs,
    ) -> Result<Self> {
        Python::attach(|py| {
            // Import at spawn time so a broken Python environment fails fast,
            // before any message is handled.
            py.import("monarch.safe_torch").unwrap();
        });
        Ok(Self {
            // No device index means a non-CUDA worker.
            device: device_index.map(|i| CudaDevice::new(DeviceIndex(i))),
            streams: HashMap::new(),
            device_meshes: HashMap::new(),
            world_size,
            rank,
            borrows: HashMap::new(),
            comm: None,
            controller_actor,
            remote_process_groups: HashMap::new(),
            send_recv_comms: HashMap::new(),
            recordings: HashMap::new(),
            defining_recording: None,
            respond_with_python_message: false,
        })
    }
}
255
#[async_trait]
impl Handler<AssignRankMessage> for WorkerActor {
    /// Adopt the rank implied by this actor's position in the cast, switch to
    /// Python-message responses, and initialize the Python-side environment.
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        _: AssignRankMessage,
    ) -> anyhow::Result<()> {
        // The cast point encodes where this worker sits in the mesh.
        let point = cx.cast_point();
        self.rank = point.rank();
        self.respond_with_python_message = true;
        Python::attach(|py| {
            let mesh_controller = py.import("monarch.mesh_controller").unwrap();
            let p: PyPoint = point.into();
            mesh_controller
                .call_method1("_initialize_env", (p, cx.proc().proc_id().to_string()))
                .unwrap();
        });
        Ok(())
    }
}
276
/// Tells a worker to take its rank from its position in the cast and switch
/// to Python-message responses.
#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub enum AssignRankMessage {
    AssignRank(),
}
wirevalue::register_type!(AssignRankMessage);
284
285#[async_trait]
286impl Handler<WorkerMessage> for WorkerActor {
287 async fn handle(
288 &mut self,
289 cx: &hyperactor::Context<Self>,
290 message: WorkerMessage,
291 ) -> anyhow::Result<()> {
292 <Self as WorkerMessageHandler>::handle(self, cx, message).await
293 }
294}
295
296#[async_trait]
297impl WorkerMessageHandler for WorkerActor {
    /// Bootstrap NCCL for this worker: create the global communicator from
    /// `unique_id`, exercise it once, and hand a dedicated split to every
    /// existing stream. Panics on non-CUDA workers.
    async fn backend_network_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        unique_id: UniqueId,
    ) -> Result<()> {
        let device = self
            .device
            .expect("tried to init backend network on a non-CUDA worker");
        let comm = NcclCommActor::new(CommParams::New {
            device,
            unique_id,
            world_size: self.world_size.try_into().unwrap(),
            rank: self.rank.try_into().unwrap(),
        })
        .await?
        .spawn(cx)?;

        // One throwaway all-reduce on a 1-element tensor, presumably to force
        // communicator establishment up front — TODO confirm intent.
        let tensor = factory_zeros(&[1], ScalarType::Float, Layout::Strided, device.into());
        let cell = TensorCell::new(tensor);

        comm.all_reduce(
            cx,
            cell,
            ReduceOp::Sum,
            torch_sys_cuda::cuda::Stream::get_current_stream(),
        )
        .await?;

        // Iterate streams in sorted-by-ref order so every rank performs the
        // same sequence of splits.
        let sorted_streams = self
            .streams
            .iter()
            .sorted_by_key(|(k, _)| *k)
            .map(|(_, v)| v.as_ref());

        // Perform the splits sequentially (one per stream), then fan out the
        // per-stream init_comm calls concurrently.
        let mut splits = Vec::new();
        for _ in 0..sorted_streams.len() {
            splits.push(comm.split_all(cx).await?);
        }
        let _: Vec<()> = try_join_all(
            sorted_streams
                .into_iter()
                .zip(splits.into_iter())
                .map(|(stream, split)| stream.init_comm(cx, split)),
        )
        .await?;

        self.comm = Some(comm);

        Ok(())
    }
358
359 async fn backend_network_point_to_point_init(
360 &mut self,
361 cx: &hyperactor::Context<Self>,
362 from_stream: StreamRef,
363 to_stream: StreamRef,
364 ) -> Result<()> {
365 if !self.streams.contains_key(&from_stream) {
366 bail!("invalid from_stream id: {:#?}", from_stream);
367 }
368 if !self.streams.contains_key(&to_stream) {
369 bail!("invalid to_stream id: {:#?}", to_stream);
370 }
371 let global_comm = self
372 .comm
373 .as_ref()
374 .context("tried to call Reduce before BackendNetworkInit")?;
375 let comm = global_comm.split_all(cx).await?;
376 self.send_recv_comms
377 .insert((from_stream, to_stream), Arc::new(comm));
378 Ok(())
379 }
380
    /// Execute a remote function call on the target stream, supplying copies
    /// of the device meshes and the resolved remote process groups it needs.
    async fn call_function(
        &mut self,
        cx: &hyperactor::Context<Self>,
        params: CallFunctionParams,
    ) -> Result<()> {
        let stream = self.try_get_stream(params.stream)?.clone();
        // Track this stream if a recording is currently being defined.
        self.maybe_add_stream_to_recording(cx, params.stream)
            .await?;

        let device_meshes = self
            .device_meshes
            .iter()
            .map(|(k, v)| (k.clone(), v.0.clone()))
            .collect();

        // Resolve each referenced remote process group to a
        // (mesh, dims, comm) triple for the target stream.
        let mut remote_process_groups = HashMap::new();
        for remote_process_group_ref in &params.remote_process_groups {
            if let Some(state) = self.remote_process_groups.get(remote_process_group_ref) {
                let dims_vec = state.dims.iter().cloned().collect();
                let (device_mesh, _) = self
                    .device_meshes
                    .get(&state.device_mesh_ref)
                    .ok_or_else(|| {
                        anyhow::anyhow!("invalid device mesh id: {:#?}", state.device_mesh_ref)
                    })?
                    .clone();
                // The comm must have been split for this stream beforehand
                // (see split_comm_for_process_group).
                let comm = state.comms
                    .get(&params.stream)
                    .ok_or_else(|| {
                        anyhow::anyhow!("no comm found for remote process group {remote_process_group_ref:#?} stream {stream:#?}")
                    })?
                    .clone();
                remote_process_groups.insert(
                    remote_process_group_ref.clone(),
                    (device_mesh, dims_vec, comm),
                );
            }
        }

        stream
            .call_function(cx, params, device_meshes, remote_process_groups)
            .await?;

        Ok(())
    }
426
427 async fn command_group(
428 &mut self,
429 cx: &hyperactor::Context<Self>,
430 params: Vec<WorkerMessage>,
431 ) -> Result<()> {
432 for msg in params {
433 WorkerMessageHandler::handle(self, cx, msg).await?;
434 }
435 Ok(())
436 }
437
    /// Spawn a new `StreamActor` and register it under `result`.
    async fn create_stream(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: StreamRef,
        creation_mode: StreamCreationMode,
    ) -> Result<()> {
        let handle: ActorHandle<StreamActor> = StreamActor::new(StreamParams {
            world_size: self.world_size,
            rank: self.rank,
            creation_mode,
            id: result,
            device: self.device,
            controller_actor: self.controller_actor.clone(),
            // Streams report results the same way the worker does.
            respond_with_python_message: self.respond_with_python_message,
        })
        .spawn(cx)?;
        self.streams.insert(result, Arc::new(handle));
        Ok(())
    }
457
458 async fn create_device_mesh(
459 &mut self,
460 _cx: &hyperactor::Context<Self>,
461 result: Ref,
462 names: Vec<String>,
463 ranks: Slice,
464 ) -> Result<()> {
465 self.device_meshes.insert(
466 result,
467 (DeviceMesh::new(names, ranks, self.rank)?, HashMap::new()),
468 );
469 Ok(())
470 }
471
472 async fn create_remote_process_group(
473 &mut self,
474 _cx: &hyperactor::Context<Self>,
475 result: Ref,
476 device_mesh: Ref,
477 dims: Vec<String>,
478 ) -> Result<()> {
479 self.device_meshes
480 .get(&device_mesh)
481 .with_context(|| format!("invalid device mesh id: {:#?}", device_mesh))?;
482 match self.remote_process_groups.entry(result) {
483 Entry::Vacant(ent) => ent.insert(RemoteProcessGroupState::new(
484 device_mesh,
485 SortedVec::from_unsorted(dims),
486 )),
487 Entry::Occupied(ent) => bail!("remote process group {:?} already create", ent.key()),
488 };
489 Ok(())
490 }
491
    /// Create a borrow of `tensor_ref` from `from_stream` into `to_stream`,
    /// registering it under `borrow_id` for later first-use/last-use/drop.
    async fn borrow_create(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        borrow_id: u64,
        tensor_ref: Ref,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        // Both endpoints participate in any recording being defined.
        self.maybe_add_stream_to_recording(cx, from_stream).await?;
        self.maybe_add_stream_to_recording(cx, to_stream).await?;
        // Shadow the refs with their resolved actor handles.
        let from_stream = self.try_get_stream(from_stream)?.clone();
        let to_stream = self.try_get_stream(to_stream)?.clone();

        let borrow =
            Borrow::create(cx, borrow_id, tensor_ref, result, from_stream, to_stream).await?;
        self.borrows.insert(borrow_id, borrow);
        Ok(())
    }
511
512 async fn borrow_first_use(
513 &mut self,
514 cx: &hyperactor::Context<Self>,
515 borrow: u64,
516 ) -> Result<()> {
517 let borrow = self
518 .borrows
519 .get_mut(&borrow)
520 .ok_or_else(|| anyhow!("invalid borrow id: {:#?}", borrow))?;
521
522 borrow.first_use(cx).await?;
523 Ok(())
524 }
525
526 async fn borrow_last_use(&mut self, cx: &hyperactor::Context<Self>, borrow: u64) -> Result<()> {
527 let borrow = self
528 .borrows
529 .get_mut(&borrow)
530 .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow))?;
531
532 borrow.last_use(cx).await?;
533 Ok(())
534 }
535
536 async fn borrow_drop(&mut self, cx: &hyperactor::Context<Self>, borrow_id: u64) -> Result<()> {
537 let borrow = self
538 .borrows
539 .get_mut(&borrow_id)
540 .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow_id))?;
541
542 borrow.drop(cx).await?;
543 self.borrows.remove(&borrow_id);
544 Ok(())
545 }
546
547 async fn delete_refs(&mut self, cx: &hyperactor::Context<Self>, refs: Vec<Ref>) -> Result<()> {
548 let _: Vec<()> = try_join_all(
555 self.streams
556 .values()
557 .map(|s| s.delete_refs(cx, refs.clone())),
558 )
559 .await?;
560 Ok(())
561 }
562
    /// Query every stream for status, then report `seq.next()` back to the
    /// controller. Joining all streams first presumably acts as a completion
    /// barrier for previously enqueued work — TODO confirm stream semantics.
    async fn request_status(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        controller: bool,
    ) -> Result<()> {
        let _: Vec<()> = try_join_all(
            self.streams
                .values()
                .map(|stream| stream.request_status(cx)),
        )
        .await?;

        ControllerMessageClient::status(
            &self.controller_actor,
            cx,
            seq.next(),
            cx.self_id().clone(),
            controller,
        )
        .await?;
        Ok(())
    }
590
    /// Run a reduction of `local_tensor` over the mesh dimensions `dims`,
    /// using the communicator previously split for (`stream_ref`, `dims`)
    /// via SplitComm.
    async fn reduce(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        local_tensor: Ref,
        factory: Factory,
        source_mesh: Ref,
        stream_ref: StreamRef,
        dims: Vec<String>,
        reduction: Reduction,
        scatter: bool,
        in_place: bool,
        out: Option<Ref>,
    ) -> Result<()> {
        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        // Canonicalize dims so the lookup key matches the one used at split.
        let dims = SortedVec::from_unsorted(dims);
        let stream = self.try_get_stream(stream_ref)?.clone();

        let (_, comm_map) = self
            .device_meshes
            .get_mut(&source_mesh)
            .ok_or_else(|| anyhow::anyhow!("invalid device mesh id: {:#?}", source_mesh))?;

        // `size` is the number of ranks in the split group.
        let (size, comm) = comm_map
            .get(&(stream_ref, dims.clone()))
            .ok_or_else(|| anyhow::anyhow!("no comm found for stream {stream:#?}, dims {dims:#?}"))?
            .clone();

        stream
            .reduce(
                cx,
                comm,
                size.try_into()?,
                result,
                local_tensor,
                factory,
                reduction,
                scatter,
                in_place,
                out,
            )
            .await?;

        Ok(())
    }
638
    /// Transfer `tensor` from the ranks in `from_ranks` to the corresponding
    /// ranks in `to_ranks`, over the dedicated point-to-point communicator
    /// for (`from_stream`, `to_stream`).
    async fn send_tensor(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        let comm = self
            .send_recv_comms
            .get(&(from_stream, to_stream))
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "could not find stream to stream comm for: {:#?}",
                    (from_stream, to_stream)
                )
            })?
            .clone();

        // If this rank is a sender, `to_rank` is its destination: our
        // position in `from_ranks` maps to the same position in `to_ranks`.
        let to_rank = from_ranks
            .index(self.rank)
            .map(|index| to_ranks.get(index).ok())
            .ok()
            .flatten();
        // Symmetrically, if this rank is a receiver, `from_rank` is its source.
        let from_rank = to_ranks
            .index(self.rank)
            .map(|index| from_ranks.get(index).ok())
            .ok()
            .flatten();

        // Receiver-only ranks run on `to_stream`; sender-only ranks (or
        // same-stream transfers) run on `from_stream`. A rank that is both
        // sender and receiver across different streams is unsupported.
        let (stream, stream_ref) = if to_rank.is_none() {
            (self.try_get_stream(to_stream)?.clone(), to_stream)
        } else if from_rank.is_none() || from_stream == to_stream {
            (self.try_get_stream(from_stream)?.clone(), from_stream)
        } else {
            unimplemented!(
                "We haven't implemented to_mesh between streams if a rank participates as both a sender and receiver. \
                It is possible, but would require the recv stream to send the output buffer tensor to the send stream and sync. \
                Then the send stream would do the nccl op, and then sync with sending stream again."
            );
        };

        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        stream
            .send_tensor(cx, result, from_rank, to_rank, tensor, factory, comm)
            .await?;

        Ok(())
    }
692
    /// Tear down the worker: stop and join all stream actors, then either
    /// hard-exit the process with an error code (when an error was reported)
    /// or stop this actor cleanly.
    async fn exit(
        &mut self,
        cx: &hyperactor::Context<Self>,
        error: Option<(Option<reference::ActorId>, String)>,
    ) -> Result<()> {
        // Drain and join every stream; the worker's table is expected to hold
        // the only strong reference to each handle.
        for (_, stream) in self.streams.drain() {
            stream.drain_and_stop("tensor worker exit cleanup")?;
            Arc::into_inner(stream)
                .expect("there should be no owners of this stream handle except the worker stream table")
                .await;
        }

        // Exit codes are overridable via env vars so a supervisor can tell
        // "this worker failed" apart from "a peer failed"; both default to 1.
        let self_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_SELF_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);
        let peer_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_PEER_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);

        let exit_code = match error {
            Some((Some(actor_id), reason)) => {
                tracing::error!(
                    "stopping the worker, actor {} failed with error: {}",
                    actor_id,
                    reason
                );
                // Distinguish our own failure from a peer's.
                if *cx.self_id() == actor_id {
                    self_error_exit_code
                } else {
                    peer_error_exit_code
                }
            }
            Some((None, reason)) => {
                tracing::error!("stopping the worker, reason: {}", reason);
                1
            }
            None => 0,
        };

        if exit_code != 0 {
            tracing::info!("stopping the worker process, exit code: {}", exit_code);
            // Hard process exit on error: skips normal actor shutdown.
            std::process::exit(exit_code);
        }
        cx.stop("tensor worker exit")?;
        Ok(())
    }
742
    /// Ask `stream` to evaluate `function` (or, when `function` is None,
    /// just resolve the value) and send the result to the controller for
    /// `seq`. Pipe destinations are no longer supported.
    async fn send_value(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        destination: Option<Ref>,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        stream: StreamRef,
    ) -> Result<()> {
        let stream = self.try_get_stream(stream)?;

        // Device meshes are only needed when an actual function will run.
        let device_meshes = if function.is_none() {
            HashMap::new()
        } else {
            self.device_meshes
                .iter()
                .map(|(k, v)| (k.clone(), v.0.clone()))
                .collect()
        };

        // Pipes were removed; fail loudly rather than silently ignoring.
        if destination.is_some() {
            panic!("send_value with pipe destination is no longer implemented")
        }

        stream
            .send_value(
                cx,
                seq,
                cx.self_id().clone(),
                mutates,
                function,
                args_kwargs,
                device_meshes,
            )
            .await
    }
782
783 async fn send_result_of_actor_call(
784 &mut self,
785 cx: &hyperactor::Context<Self>,
786 params: ActorCallParams,
787 ) -> Result<()> {
788 let stream = self.try_get_stream(params.stream)?;
789 stream.send_result_of_actor_call(cx, params).await?;
790 Ok(())
791 }
792 async fn call_actor_method(
793 &mut self,
794 cx: &hyperactor::Context<Self>,
795 params: ActorMethodParams,
796 ) -> Result<()> {
797 let stream = self.try_get_stream(params.call.stream)?;
798 stream.call_actor_method(cx, params).await?;
799 Ok(())
800 }
    /// Split a communicator off the global comm for the sub-mesh spanned by
    /// `dims` on `device_mesh`, stored under the key (`stream_ref`, dims)
    /// for later use by `reduce`.
    async fn split_comm(
        &mut self,
        cx: &hyperactor::Context<Self>,
        dims: Vec<String>,
        device_mesh: Ref,
        stream_ref: StreamRef,
    ) -> Result<()> {
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call SplitComm before BackendNetworkInit")?;
        match self.device_meshes.get_mut(&device_mesh) {
            Some((device_mesh, comm_map)) => {
                let stream = self
                    .streams
                    .get(&stream_ref)
                    .ok_or_else(|| anyhow::anyhow!("invalid stream id: {:#?}", stream_ref))?;

                // Canonical ordering so the (stream, dims) key is stable.
                let dims = SortedVec::from_unsorted(dims);

                anyhow::ensure!(
                    !comm_map.contains_key(&(stream_ref, dims.clone())),
                    "comm already exists for stream {stream:#?}, dims {dims:#?}"
                );
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&dims)?;
                let size = ranks_for_group.len();
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                    )
                    .await?
                    // This rank belongs to the mesh, so it must be in the split.
                    .context("split comm should include self rank")?;
                comm_map.insert((stream_ref, dims), (size, Arc::new(split_comm)));
            }
            None => {
                // Split appears to be collective: ranks that don't have this
                // mesh still participate with an empty rank list, presumably
                // so peers don't block — TODO confirm.
                global_comm.split_from(cx, vec![]).await?;
            }
        }
        Ok(())
    }
849
850 async fn split_comm_for_process_group(
851 &mut self,
852 cx: &hyperactor::Context<Self>,
853 remote_process_group_ref: Ref,
854 stream_ref: StreamRef,
855 ) -> Result<()> {
856 ensure!(
857 self.streams.contains_key(&stream_ref),
858 "invalid stream id: {:#?}",
859 stream_ref
860 );
861 let global_comm = self
862 .comm
863 .as_ref()
864 .context("tried to call SplitComm before BackendNetworkInit")?;
865 let state = self
866 .remote_process_groups
867 .get_mut(&remote_process_group_ref)
868 .with_context(|| format!("invalid remote process group id: {:#?}", stream_ref))?;
869 match self.device_meshes.get_mut(&state.device_mesh_ref) {
870 Some((device_mesh, _)) => {
871 let entry = match state.comms.entry(stream_ref) {
874 Entry::Vacant(entry) => entry,
875 Entry::Occupied(_) => bail!(
876 "comm already exists for remote process group {:#?} on stream {:#?}",
877 remote_process_group_ref,
878 stream_ref,
879 ),
880 };
881 let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&state.dims)?;
882 let split_comm = global_comm
883 .split_from(
884 cx,
885 ranks_for_group
886 .into_iter()
887 .map(|v| v.clone().try_into())
888 .collect::<Result<Vec<_>, _>>()?,
889 )
890 .await?
891 .context("split comm should include self rank")?;
892 entry.insert(Arc::new(split_comm));
893 }
894 None => {
895 global_comm.split_from(cx, vec![]).await?;
898 }
899 }
900 Ok(())
901 }
902
    /// Pipes were removed from the worker; any PipeRecv message is a
    /// protocol error and panics deliberately.
    async fn pipe_recv(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        _seq: Seq,
        _results: Vec<Option<Ref>>,
        _pipe: Ref,
        _stream: StreamRef,
    ) -> Result<()> {
        panic!("pipe_recv is no longer implemented")
    }
913
914 async fn set_ref_unit_tests_only(
915 &mut self,
916 cx: &hyperactor::Context<Self>,
917 reference: Ref,
918 value: WireValue,
919 stream: StreamRef,
920 ) -> Result<()> {
921 let stream = self.try_get_stream(stream)?;
922
923 stream.set_ref_unit_tests_only(cx, reference, value).await
924 }
925
926 async fn get_ref_unit_tests_only(
927 &mut self,
928 cx: &hyperactor::Context<Self>,
929 ref_id: Ref,
930 stream: StreamRef,
931 ) -> Result<Option<Result<WireValue, String>>> {
932 let stream = self.try_get_stream(stream)?;
933 Ok(stream.get_ref_unit_tests_only(cx, ref_id.clone()).await?)
934 }
935
    /// Accumulate (possibly chunked) DefineRecording messages for `result`;
    /// once all `ntotal_messages` chunks have arrived, replay the recorded
    /// commands and finalize the recording on every participating stream.
    async fn define_recording(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        _nresults: usize,
        _nformals: usize,
        commands: Vec<WorkerMessage>,
        ntotal_messages: usize,
        index: usize,
    ) -> Result<()> {
        // Only one recording may be defined at a time.
        if self.defining_recording.is_some() && self.defining_recording.unwrap() != result {
            bail!("already defining a different recording");
        }
        self.defining_recording = Some(result);

        // Append this chunk, enforcing strictly in-order chunk delivery.
        match self.recordings.entry(result) {
            Entry::Vacant(entry) => {
                ensure!(
                    index == 0,
                    "got DefineRecording message with (index = {:?}) > 0 for previously unseen recording",
                    index
                );
                entry.insert(Recording::PartialRecording {
                    last_index: 0,
                    commands,
                });
            }
            Entry::Occupied(mut entry) => match entry.get_mut() {
                Recording::CompleteRecording { .. } => {
                    bail!("got DefineRecording message for already complete recording")
                }
                Recording::PartialRecording {
                    last_index,
                    commands: existing_commands,
                } => {
                    ensure!(
                        index == *last_index + 1,
                        "Got DefineRecording message with index = {:?}, but \
                        last seen index for recording is {:?}",
                        index,
                        last_index
                    );
                    *last_index = index;
                    existing_commands.extend(commands);
                }
            },
        };

        // More chunks outstanding: nothing to replay yet.
        if index < ntotal_messages - 1 {
            return Ok(());
        }
        // All chunks received: promote to CompleteRecording (so replayed
        // commands can register their streams) and take the command list.
        let commands = match self.recordings.remove(&result).unwrap() {
            Recording::CompleteRecording { .. } => panic!("unreachable, in theory"),
            Recording::PartialRecording { commands, .. } => {
                self.recordings.insert(
                    result,
                    Recording::CompleteRecording {
                        streams: HashSet::new(),
                    },
                );
                commands
            }
        };

        // Replay every command; `defining_recording` is still set, so
        // handlers add their streams to the recording as they execute.
        for command in commands {
            WorkerMessageHandler::handle(self, cx, command).await?;
        }

        // Tell each participating stream the definition is complete.
        match self.recordings.get(&result).unwrap() {
            Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
            Recording::CompleteRecording { streams, .. } => {
                for stream in streams {
                    self.try_get_stream(*stream)?
                        .finalize_recording(cx, result)
                        .await?;
                }
            }
        }

        self.defining_recording = None;
        Ok(())
    }
1018
1019 async fn recording_formal(
1020 &mut self,
1021 cx: &hyperactor::Context<Self>,
1022 result: Ref,
1023 argument_index: usize,
1024 stream: StreamRef,
1025 ) -> Result<()> {
1026 ensure!(self.defining_recording.is_some());
1027 self.maybe_add_stream_to_recording(cx, stream).await?;
1028 self.try_get_stream(stream)?
1029 .recording_formal(cx, result, argument_index)
1030 .await
1031 }
1032
1033 async fn recording_result(
1034 &mut self,
1035 cx: &hyperactor::Context<Self>,
1036 result: Ref,
1037 output_index: usize,
1038 stream: StreamRef,
1039 ) -> Result<()> {
1040 ensure!(self.defining_recording.is_some());
1041 self.maybe_add_stream_to_recording(cx, stream).await?;
1042 self.try_get_stream(stream)?
1043 .recording_result(cx, result, output_index)
1044 .await
1045 }
1046
1047 async fn call_recording(
1048 &mut self,
1049 cx: &hyperactor::Context<Self>,
1050 seq: Seq,
1051 recording: Ref,
1052 results: Vec<Ref>,
1053 actuals: Vec<Ref>,
1054 ) -> Result<()> {
1055 ensure!(self.defining_recording.is_none());
1056 let recording_ref = recording;
1057 let recording = self.recordings.get(&recording).ok_or(anyhow::anyhow!(
1058 "could not find recording: {:#?}",
1059 recording
1060 ))?;
1061 match recording {
1062 Recording::PartialRecording { .. } => {
1063 bail!("cannot call recording because it is incomplete")
1064 }
1065 Recording::CompleteRecording { streams } => try_join_all(
1066 streams
1067 .iter()
1068 .map(|stream| self.try_get_stream(*stream))
1069 .collect::<Result<Vec<_>>>()?
1070 .into_iter()
1071 .map(|stream| {
1072 stream.call_recording(
1073 cx,
1074 seq,
1075 recording_ref,
1076 results.clone(),
1077 actuals.clone(),
1078 )
1079 }),
1080 )
1081 .await
1082 .map(|_| ()),
1083 }
1084 }
1085}
1086
1087#[cfg(test)]
1088mod tests {
1089 use std::assert_matches::assert_matches;
1090
1091 use anyhow::Result;
1092 use hyperactor::RemoteSpawn;
1093 use hyperactor::channel::ChannelAddr;
1094 use hyperactor::proc::Proc;
1095 use monarch_messages::controller::ControllerMessage;
1096 use monarch_messages::controller::WorkerError;
1097 use monarch_messages::worker::WorkerMessageClient;
1098 use monarch_types::PickledPyObject;
1099 use pyo3::Python;
1100 use pyo3::prelude::*;
1101 use pyo3::types::PyList;
1102 use pyo3::types::PyString;
1103 use rand::Rng;
1104 use rand::distr::Alphanumeric;
1105 use timed_test::async_timed_test;
1106
1107 use super::*;
1108 use crate::test_util::test_setup;
1109
    // End-to-end smoke test: create a stream, run a few aten ops through it,
    // and verify ones(2,3) - 1 == zeros(2,3) via allclose, with no errors
    // reported to the controller.
    #[async_timed_test(timeout_secs = 60)]
    async fn basic_worker() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn(
                "worker",
                WorkerActor::new(
                    WorkerParams {
                        world_size: 1,
                        rank: 0,
                        device_index: None,
                        controller_actor: controller_ref,
                    },
                    Flattrs::default(),
                )
                .await
                .unwrap(),
            )
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // ref 0 = ones(2, 3)
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::IntList(vec![2, 3])],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // in place: ref 0 -= 1
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 2.into(),
                        results: vec![Some(Ref { id: 2 })],
                        mutates: vec![0.into()],
                        function: "torch.ops.aten.sub_.Scalar".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::Ref(0.into()), WireValue::Int(1)],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 3 = zeros(2, 3)
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 3.into(),
                        results: vec![Some(Ref { id: 3 })],
                        mutates: vec![],
                        function: "torch.ops.aten.zeros.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::IntList(vec![2, 3])],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    // ref 4 = allclose(ref 0, ref 3)
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 4.into(),
                        results: vec![Some(Ref { id: 4 })],
                        mutates: vec![],
                        function: "torch.ops.aten.allclose.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::Ref(0.into()), WireValue::Ref(Ref { id: 3 })],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result: bool = worker_handle
            .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
            .await
            .unwrap()
            .unwrap()
            .unwrap()
            .try_into()
            .unwrap();
        worker_handle.drain_and_stop("test").unwrap();
        worker_handle.await;
        // The controller should have received no failure reports.
        let error_responses = controller_rx.drain();
        assert!(
            error_responses.is_empty(),
            "Expected no error responses, got: {:#?}",
            error_responses
        );
        assert!(result);

        Ok(())
    }
1218
    // A failing remote function call (missing required argument) must be
    // reported to the controller as RemoteFunctionFailed with the right seq.
    #[async_timed_test(timeout_secs = 60)]
    async fn error_sends_response() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn(
                "worker",
                WorkerActor::new(
                    WorkerParams {
                        world_size: 1,
                        rank: 0,
                        device_index: None,
                        controller_actor: controller_ref,
                    },
                    Flattrs::default(),
                )
                .await
                .unwrap(),
            )
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // aten::rand requires a size argument; calling with no
                    // args must fail.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(0.into())],
                        mutates: vec![],
                        function: "torch.ops.aten.rand.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                    WorkerMessage::Exit { error: None },
                ],
            )
            .await
            .unwrap();

        worker_handle.drain_and_stop("test").unwrap();
        worker_handle.await;
        let response_message = controller_rx.recv().await.unwrap();
        match response_message {
            ControllerMessage::RemoteFunctionFailed {
                seq,
                error: WorkerError { backtrace: msg, .. },
            } => {
                assert_eq!(seq, 0.into());
                assert!(msg.contains("aten::rand() is missing value for argument 'size'"))
            }
            _ => panic!("unexpected response {:#?}", response_message),
        }

        Ok(())
    }
1281
    // When a call fails, refs it would have mutated are poisoned with the
    // error, so later reads of those refs surface the failure.
    #[async_timed_test(timeout_secs = 60)]
    async fn mutated_refs_are_updated_with_error() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn(
                "worker",
                WorkerActor::new(
                    WorkerParams {
                        world_size: 1,
                        rank: 0,
                        device_index: None,
                        controller_actor: controller_ref,
                    },
                    Flattrs::default(),
                )
                .await
                .unwrap(),
            )
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::UseDefaultStream,
                    },
                    // Seed ref 0 with a value.
                    WorkerMessage::SetRefUnitTestsOnly {
                        reference: 0.into(),
                        value: WireValue::Int(1),
                        stream: 1.into(),
                    },
                    // This call mutates ref 0 but the function can't resolve.
                    WorkerMessage::CallFunction(CallFunctionParams {
                        seq: 0.into(),
                        results: vec![Some(Ref { id: 2 })],
                        mutates: vec![0.into()],
                        function: "i.dont.exist".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
                        stream: 1.into(),
                        remote_process_groups: vec![],
                    }),
                ],
            )
            .await
            .unwrap();

        let result = worker_handle
            .get_ref_unit_tests_only(&client, 0.into(), 1.into())
            .await?;

        worker_handle.drain_and_stop("test").unwrap();
        worker_handle.await;

        // Ref 0 should now hold the resolution error, not its old value.
        let mutated_ref = result
            .context("no such ref")?
            .err()
            .context("expected error")?;
        assert!(mutated_ref.contains("failed to resolve function"));

        let responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );
        Ok(())
    }
1355
1356 #[async_timed_test(timeout_secs = 60)]
1357 async fn accessing_errored_dependency() -> Result<()> {
1358 test_setup()?;
1359
1360 let proc = Proc::local();
1361 let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1362
1363 let worker_handle = proc
1364 .spawn(
1365 "worker",
1366 WorkerActor::new(
1367 WorkerParams {
1368 world_size: 1,
1369 rank: 0,
1370 device_index: None,
1371 controller_actor: controller_ref,
1372 },
1373 Flattrs::default(),
1374 )
1375 .await
1376 .unwrap(),
1377 )
1378 .unwrap();
1379 worker_handle
1380 .command_group(
1381 &client,
1382 vec![
1383 WorkerMessage::CreateStream {
1384 id: 1.into(),
1385 stream_creation: StreamCreationMode::UseDefaultStream,
1386 },
1387 WorkerMessage::CallFunction(CallFunctionParams {
1388 seq: 0.into(),
1389 results: vec![Some(0.into())],
1390 mutates: vec![],
1391 function: "i.dont.exist".into(),
1392 args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
1393 stream: 1.into(),
1394 remote_process_groups: vec![],
1395 }),
1396 WorkerMessage::CallFunction(CallFunctionParams {
1397 seq: 1.into(),
1398 results: vec![Some(1.into())],
1399 mutates: vec![],
1400 function: "torch.ops.aten.sub_.Scalar".into(),
1401 args_kwargs: ArgsKwargs::from_wire_values(
1402 vec![WireValue::Ref(0.into())],
1403 HashMap::new(),
1404 )
1405 .unwrap(),
1406 stream: 1.into(),
1407 remote_process_groups: vec![],
1408 }),
1409 WorkerMessage::Exit { error: None },
1410 ],
1411 )
1412 .await
1413 .unwrap();
1414
1415 worker_handle.drain_and_stop("test").unwrap();
1416 worker_handle.await;
1417
1418 let responses = controller_rx.drain();
1419 assert_eq!(
1420 responses.len(),
1421 1,
1422 "Expected one response, got: {:#?}",
1423 responses
1424 );
1425
1426 match &responses[0] {
1427 ControllerMessage::RemoteFunctionFailed { seq, .. } => {
1428 assert_eq!(seq, &0.into())
1429 }
1430 _ => panic!("unexpected response {:#?}", responses[0]),
1431 };
1432 Ok(())
1433 }
1434
1435 #[async_timed_test(timeout_secs = 60)]
1436 async fn py_remote_function_calls() -> Result<()> {
1437 test_setup()?;
1438
1439 let proc = Proc::local();
1440 let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1441
1442 let worker_handle = proc
1443 .spawn(
1444 "worker",
1445 WorkerActor::new(
1446 WorkerParams {
1447 world_size: 1,
1448 rank: 0,
1449 device_index: None,
1450 controller_actor: controller_ref,
1451 },
1452 Flattrs::default(),
1453 )
1454 .await
1455 .unwrap(),
1456 )
1457 .unwrap();
1458 let (split_arg, sort_list, dim, layout, none, scalar, device, memory_format) =
1459 Python::attach(|py| {
1460 let split_arg: PickledPyObject = PyString::new(py, "/fbs/fbc/foo/bar")
1461 .into_any()
1462 .try_into()?;
1463 let sort_list: PickledPyObject =
1464 PyList::new(py, [65, 34, 79, 1, 5])?.into_any().try_into()?;
1465 let dim: PickledPyObject = PyString::new(py, "x").into_any().try_into()?;
1466 let layout: PickledPyObject = py.import("torch")?.getattr("strided")?.try_into()?;
1467 let none: PickledPyObject = py.None().into_any().into_bound(py).try_into()?;
1468 let scalar: PickledPyObject = py.import("torch")?.getattr("float32")?.try_into()?;
1469 let device: PickledPyObject = py
1470 .import("torch")?
1471 .getattr("device")?
1472 .call1(("cuda:1",))?
1473 .try_into()?;
1474 let memory_format: PickledPyObject = py
1475 .import("torch")?
1476 .getattr("contiguous_format")?
1477 .try_into()?;
1478 PyResult::Ok((
1479 split_arg,
1480 sort_list,
1481 dim,
1482 layout,
1483 none,
1484 scalar,
1485 device,
1486 memory_format,
1487 ))
1488 })?;
1489
1490 worker_handle
1491 .command_group(
1492 &client,
1493 vec![
1494 WorkerMessage::CreateStream {
1495 id: 1.into(),
1496 stream_creation: StreamCreationMode::UseDefaultStream,
1497 },
1498 WorkerMessage::CallFunction(CallFunctionParams {
1499 seq: 0.into(),
1500 results: vec![Some(0.into()), Some(Ref { id: 2 })],
1501 mutates: vec![],
1502 function: "os.path.split".into(),
1503 args_kwargs: ArgsKwargs::from_wire_values(
1504 vec![split_arg.into()],
1505 HashMap::new(),
1506 )
1507 .unwrap(),
1508 stream: 1.into(),
1509 remote_process_groups: vec![],
1510 }),
1511 WorkerMessage::CallFunction(CallFunctionParams {
1512 seq: 2.into(),
1513 results: vec![Some(4.into()), None, None, None, None],
1514 mutates: vec![],
1515 function: "builtins.sorted".into(),
1516 args_kwargs: ArgsKwargs::from_wire_values(
1517 vec![sort_list.into()],
1518 HashMap::new(),
1519 )
1520 .unwrap(),
1521 stream: 1.into(),
1522 remote_process_groups: vec![],
1523 }),
1524 WorkerMessage::CreateDeviceMesh {
1525 result: 5.into(),
1526 names: vec!["x".into()],
1527 ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
1528 },
1529 WorkerMessage::CallFunction(CallFunctionParams {
1530 seq: 2.into(),
1531 results: vec![Some(6.into())],
1532 mutates: vec![],
1533 function: "monarch.monarch_tensor_worker.test_utils.mesh_rank".into(),
1534 args_kwargs: ArgsKwargs::from_wire_values(
1535 vec![WireValue::Ref(Ref { id: 5 }), dim.into()],
1536 HashMap::new(),
1537 )
1538 .unwrap(),
1539 stream: 1.into(),
1540 remote_process_groups: vec![],
1541 }),
1542 WorkerMessage::CallFunction(CallFunctionParams {
1543 seq: 4.into(),
1544 results: vec![Some(7.into())],
1545 mutates: vec![],
1546 function: "monarch.monarch_tensor_worker.test_utils.test_scalar_type"
1547 .into(),
1548 args_kwargs: ArgsKwargs::from_wire_values(
1549 vec![scalar.into()],
1550 HashMap::new(),
1551 )
1552 .unwrap(),
1553 stream: 1.into(),
1554 remote_process_groups: vec![],
1555 }),
1556 WorkerMessage::CallFunction(CallFunctionParams {
1557 seq: 5.into(),
1558 results: vec![Some(8.into())],
1559 mutates: vec![],
1560 function: "monarch.monarch_tensor_worker.test_utils.test_layout".into(),
1561 args_kwargs: ArgsKwargs::from_wire_values(
1562 vec![layout.into()],
1563 HashMap::new(),
1564 )
1565 .unwrap(),
1566 stream: 1.into(),
1567 remote_process_groups: vec![],
1568 }),
1569 WorkerMessage::CallFunction(CallFunctionParams {
1570 seq: 6.into(),
1571 results: vec![Some(9.into())],
1572 mutates: vec![],
1573 function: "monarch.monarch_tensor_worker.test_utils.test_none".into(),
1574 args_kwargs: ArgsKwargs::from_wire_values(
1575 vec![none.into()],
1576 HashMap::new(),
1577 )
1578 .unwrap(),
1579 stream: 1.into(),
1580 remote_process_groups: vec![],
1581 }),
1582 WorkerMessage::CallFunction(CallFunctionParams {
1585 seq: 7.into(),
1586 results: vec![None],
1587 mutates: vec![],
1588 function: "monarch.monarch_tensor_worker.test_utils.none".into(),
1589 args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
1590 stream: 1.into(),
1591 remote_process_groups: vec![],
1592 }),
1593 WorkerMessage::CallFunction(CallFunctionParams {
1594 seq: 8.into(),
1595 results: vec![Some(10.into())],
1596 mutates: vec![],
1597 function: "monarch.monarch_tensor_worker.test_utils.test_device".into(),
1598 args_kwargs: ArgsKwargs::from_wire_values(
1599 vec![device.into()],
1600 HashMap::new(),
1601 )
1602 .unwrap(),
1603 stream: 1.into(),
1604 remote_process_groups: vec![],
1605 }),
1606 WorkerMessage::CallFunction(CallFunctionParams {
1607 seq: 9.into(),
1608 results: vec![Some(11.into())],
1609 mutates: vec![],
1610 function: "monarch.monarch_tensor_worker.test_utils.test_memory_format"
1611 .into(),
1612 args_kwargs: ArgsKwargs::from_wire_values(
1613 vec![memory_format.into()],
1614 HashMap::new(),
1615 )
1616 .unwrap(),
1617 stream: 1.into(),
1618 remote_process_groups: vec![],
1619 }),
1620 WorkerMessage::CallFunction(CallFunctionParams {
1622 seq: 10.into(),
1623 results: vec![Some(12.into())],
1624 mutates: vec![],
1625 function: "torch.ops.aten.ones.default".into(),
1626 args_kwargs: ArgsKwargs::from_wire_values(
1627 vec![WireValue::IntList(vec![2, 3])],
1628 HashMap::new(),
1629 )
1630 .unwrap(),
1631 stream: 1.into(),
1632 remote_process_groups: vec![],
1633 }),
1634 WorkerMessage::CallFunction(CallFunctionParams {
1635 seq: 11.into(),
1636 results: vec![Some(13.into())],
1637 mutates: vec![],
1638 function: "torch.ops.aten.stack.default".into(),
1639 args_kwargs: ArgsKwargs::from_wire_values(
1640 vec![WireValue::RefList(vec![12.into(), 12.into()])],
1641 HashMap::new(),
1642 )
1643 .unwrap(),
1644 stream: 1.into(),
1645 remote_process_groups: vec![],
1646 }),
1647 ],
1648 )
1649 .await
1650 .unwrap();
1651
1652 let result1: String = worker_handle
1653 .get_ref_unit_tests_only(&client, 0.into(), 1.into())
1654 .await
1655 .unwrap()
1656 .unwrap()
1657 .unwrap()
1658 .try_into()
1659 .unwrap();
1660 let result2: String = worker_handle
1661 .get_ref_unit_tests_only(&client, 2.into(), 1.into())
1662 .await
1663 .unwrap()
1664 .unwrap()
1665 .unwrap()
1666 .try_into()
1667 .unwrap();
1668 let result3: i64 = worker_handle
1669 .get_ref_unit_tests_only(&client, 4.into(), 1.into())
1670 .await
1671 .unwrap()
1672 .unwrap()
1673 .unwrap()
1674 .try_into()
1675 .unwrap();
1676 let result4: i64 = worker_handle
1677 .get_ref_unit_tests_only(&client, 6.into(), 1.into())
1678 .await
1679 .unwrap()
1680 .unwrap()
1681 .unwrap()
1682 .try_into()
1683 .unwrap();
1684 worker_handle
1685 .get_ref_unit_tests_only(&client, 7.into(), 1.into())
1686 .await
1687 .unwrap()
1688 .unwrap()
1689 .unwrap();
1690
1691 worker_handle
1692 .get_ref_unit_tests_only(&client, 8.into(), 1.into())
1693 .await
1694 .unwrap()
1695 .unwrap()
1696 .unwrap();
1697
1698 assert_matches!(
1699 worker_handle
1700 .get_ref_unit_tests_only(&client, 9.into(), 1.into())
1701 .await
1702 .unwrap()
1703 .unwrap()
1704 .unwrap(),
1705 WireValue::None(()),
1706 );
1707 worker_handle
1708 .get_ref_unit_tests_only(&client, 10.into(), 1.into())
1709 .await
1710 .unwrap()
1711 .unwrap()
1712 .unwrap();
1713 worker_handle
1714 .get_ref_unit_tests_only(&client, 11.into(), 1.into())
1715 .await
1716 .unwrap()
1717 .unwrap()
1718 .unwrap();
1719
1720 worker_handle.drain_and_stop("test").unwrap();
1721 worker_handle.await;
1722 let error_responses = controller_rx.drain();
1723 assert!(
1724 error_responses.is_empty(),
1725 "Expected no error responses, got: {:#?}",
1726 error_responses
1727 );
1728
1729 assert_eq!(result1, "/fbs/fbc/foo");
1730 assert_eq!(result2, "bar");
1731 assert_eq!(result3, 1);
1732 assert_eq!(result4, 0);
1733
1734 Ok(())
1735 }
1736
1737 #[async_timed_test(timeout_secs = 60)]
1738 async fn delete_refs() -> Result<()> {
1739 test_setup()?;
1740
1741 let proc = Proc::local();
1742 let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();
1743
1744 let worker_handle = proc
1745 .spawn(
1746 "worker",
1747 WorkerActor::new(
1748 WorkerParams {
1749 world_size: 1,
1750 rank: 0,
1751 device_index: None,
1752 controller_actor: controller_ref,
1753 },
1754 Flattrs::default(),
1755 )
1756 .await
1757 .unwrap(),
1758 )
1759 .unwrap();
1760 worker_handle
1761 .command_group(
1762 &client,
1763 vec![
1764 WorkerMessage::CreateStream {
1765 id: 0.into(),
1766 stream_creation: StreamCreationMode::CreateNewStream,
1767 },
1768 WorkerMessage::CreateStream {
1769 id: 1.into(),
1770 stream_creation: StreamCreationMode::CreateNewStream,
1771 },
1772 WorkerMessage::SetRefUnitTestsOnly {
1773 reference: Ref { id: 2 },
1774 value: WireValue::Bool(false),
1775 stream: 0.into(),
1776 },
1777 WorkerMessage::SetRefUnitTestsOnly {
1778 reference: Ref { id: 3 },
1779 value: WireValue::Bool(true),
1780 stream: 0.into(),
1781 },
1782 WorkerMessage::SetRefUnitTestsOnly {
1783 reference: Ref { id: 4 },
1784 value: WireValue::Int(0),
1785 stream: 1.into(),
1786 },
1787 WorkerMessage::DeleteRefs(vec![Ref { id: 2 }, Ref { id: 4 }]),
1788 ],
1789 )
1790 .await
1791 .unwrap();
1792
1793 let result: bool = worker_handle
1794 .get_ref_unit_tests_only(&client, Ref { id: 3 }, 0.into())
1795 .await
1796 .unwrap()
1797 .unwrap()
1798 .unwrap()
1799 .try_into()
1800 .unwrap();
1801 let fail_result = worker_handle
1802 .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
1803 .await
1804 .unwrap();
1805
1806 worker_handle.drain_and_stop("test").unwrap();
1807 worker_handle.await;
1808
1809 assert!(result, "should be able to get a non-deleted ref");
1810 assert!(fail_result.is_none(), "should fail to get a deleted ref");
1811
1812 Ok(())
1813 }
1814
    /// After 100 calls spread across two streams, a single `RequestStatus`
    /// must yield exactly one `Status` response from the controller, with
    /// the acknowledged seq one past the requested seq.
    #[async_timed_test(timeout_secs = 60)]
    async fn request_status() -> Result<()> {
        test_setup()?;

        let proc = Proc::local();
        let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();

        let worker_handle = proc
            .spawn(
                "worker",
                WorkerActor::new(
                    WorkerParams {
                        world_size: 1,
                        rank: 0,
                        device_index: None,
                        controller_actor: controller_ref,
                    },
                    Flattrs::default(),
                )
                .await
                .unwrap(),
            )
            .unwrap();
        worker_handle
            .command_group(
                &client,
                vec![
                    WorkerMessage::CreateStream {
                        id: 0.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                    WorkerMessage::CreateStream {
                        id: 1.into(),
                        stream_creation: StreamCreationMode::CreateNewStream,
                    },
                ],
            )
            .await
            .unwrap();

        // Interleave 100 calls: even seqs land on stream 0, odd seqs on
        // stream 1, so the status request must wait on both streams.
        for i in 0..100 {
            worker_handle
                .call_function(
                    &client,
                    CallFunctionParams {
                        seq: i.into(),
                        results: vec![Some(Ref { id: i + 2 })],
                        mutates: vec![],
                        function: "torch.ops.aten.ones.default".into(),
                        args_kwargs: ArgsKwargs::from_wire_values(
                            vec![WireValue::IntList(vec![2, 3])],
                            HashMap::new(),
                        )
                        .unwrap(),
                        stream: (i % 2).into(),
                        remote_process_groups: vec![],
                    },
                )
                .await
                .unwrap();
        }

        // Request status at seq 100 (one past the last issued call).
        worker_handle
            .request_status(&client, 100.into(), false)
            .await
            .unwrap();

        worker_handle.drain_and_stop("test").unwrap();
        worker_handle.await;

        // Exactly one controller message: the Status ack. None of the 100
        // calls may have produced an error response.
        let mut responses = controller_rx.drain();
        assert_eq!(
            responses.len(),
            1,
            "Expected one response, got: {:#?}",
            responses
        );

        // The ack carries the requested seq plus one (100 -> 101).
        let response = responses.pop().unwrap();
        match response {
            ControllerMessage::Status { seq, .. } => {
                assert_eq!(seq, 101.into())
            }
            _ => panic!("unexpected response {:#?}", response),
        };

        Ok(())
    }
1904
1905 #[async_timed_test(timeout_secs = 60)]
1906 async fn backend_network_init() {
1907 test_setup().unwrap();
1908 let proc = Proc::local();
1909 let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();
1910
1911 let worker_handle1 = proc
1912 .spawn(
1913 "worker0",
1914 WorkerActor::new(
1915 WorkerParams {
1916 world_size: 2,
1917 rank: 0,
1918 device_index: Some(0),
1919 controller_actor: controller_ref.clone(),
1920 },
1921 Flattrs::default(),
1922 )
1923 .await
1924 .unwrap(),
1925 )
1926 .unwrap();
1927 let worker_handle2 = proc
1928 .spawn(
1929 "worker1",
1930 WorkerActor::new(
1931 WorkerParams {
1932 world_size: 2,
1933 rank: 1,
1934 device_index: Some(1),
1935 controller_actor: controller_ref,
1936 },
1937 Flattrs::default(),
1938 )
1939 .await
1940 .unwrap(),
1941 )
1942 .unwrap();
1943
1944 let unique_id = UniqueId::new().unwrap();
1945 worker_handle1
1946 .backend_network_init(&client, unique_id.clone())
1947 .await
1948 .unwrap();
1949 worker_handle2
1950 .backend_network_init(&client, unique_id)
1951 .await
1952 .unwrap();
1953
1954 worker_handle1.drain_and_stop("test").unwrap();
1955 worker_handle1.await;
1956 worker_handle2.drain_and_stop("test").unwrap();
1957 worker_handle2.await;
1958 }
1959
1960 #[allow(dead_code)]
1961 fn get_random_channel_addr() -> ChannelAddr {
1962 let random_string = rand::rng()
1963 .sample_iter(&Alphanumeric)
1964 .take(24)
1965 .map(char::from)
1966 .collect::<String>();
1967 format!("unix!@{random_string}").parse().unwrap()
1968 }
1969}