1#![feature(assert_matches)]
10#![feature(duration_constructors)]
11#![feature(exit_status_error)]
12#![allow(unsafe_op_in_unsafe_fn)]
15
16mod borrow;
31mod comm;
32pub mod device_mesh;
33pub mod stream;
34pub mod test_util;
35
36use std::collections::HashMap;
37use std::collections::HashSet;
38use std::collections::hash_map::Entry;
39use std::sync::Arc;
40
41use anyhow::Context;
42use anyhow::Result;
43use anyhow::anyhow;
44use anyhow::bail;
45use anyhow::ensure;
46use async_trait::async_trait;
47use borrow::Borrow;
48use comm::CommMessageClient;
49use comm::CommParams;
50use comm::NcclCommActor;
51use derive_more::TryInto;
52use device_mesh::DeviceMesh;
53use futures::future::try_join_all;
54use hyperactor::Actor;
55use hyperactor::ActorRef;
56use hyperactor::Bind;
57use hyperactor::Handler;
58use hyperactor::RemoteSpawn;
59use hyperactor::Unbind;
60use hyperactor::actor::ActorHandle;
61use hyperactor::context;
62use hyperactor::reference::ActorId;
63use hyperactor_mesh::comm::multicast::CastInfo;
64use itertools::Itertools;
65use monarch_hyperactor::shape::PyPoint;
66use monarch_messages::controller::ControllerActor;
67use monarch_messages::controller::ControllerMessageClient;
68use monarch_messages::controller::Seq;
69use monarch_messages::wire_value::WireValue;
70use monarch_messages::worker::ActorCallParams;
71use monarch_messages::worker::ActorMethodParams;
72use monarch_messages::worker::ArgsKwargs;
73use monarch_messages::worker::CallFunctionParams;
74use monarch_messages::worker::Factory;
75use monarch_messages::worker::Reduction;
76use monarch_messages::worker::Ref;
77use monarch_messages::worker::ResolvableFunction;
78use monarch_messages::worker::StreamCreationMode;
79use monarch_messages::worker::StreamRef;
80use monarch_messages::worker::WorkerMessage;
81use monarch_messages::worker::WorkerMessageHandler;
82use monarch_messages::worker::WorkerParams;
83use ndslice::Slice;
84use pyo3::Python;
85use pyo3::types::PyAnyMethods;
86use serde::Deserialize;
87use serde::Serialize;
88use sorted_vec::SortedVec;
89use stream::StreamActor;
90use stream::StreamMessageClient;
91use stream::StreamParams;
92use torch_sys_cuda::nccl::ReduceOp;
93use torch_sys_cuda::nccl::UniqueId;
94use torch_sys2::CudaDevice;
95use torch_sys2::DeviceIndex;
96use torch_sys2::Layout;
97use torch_sys2::ScalarType;
98use torch_sys2::TensorCell;
99use torch_sys2::factory_zeros;
100use typeuri::Named;
101
/// Book-keeping for a remote (Python-visible) process group: the device mesh
/// it was created over, the mesh dims it spans, and the NCCL communicators
/// split off for each stream that uses it.
#[derive(Debug)]
struct RemoteProcessGroupState {
    // Ref of the device mesh this process group was created over.
    device_mesh_ref: Ref,
    // Mesh dimension names the group spans (sorted for stable lookup keys).
    dims: SortedVec<String>,
    // Per-stream split communicators, populated lazily by
    // `split_comm_for_process_group`.
    comms: HashMap<StreamRef, Arc<ActorHandle<NcclCommActor>>>,
}
108
109impl RemoteProcessGroupState {
110 fn new(device_mesh_ref: Ref, dims: SortedVec<String>) -> Self {
111 Self {
112 device_mesh_ref,
113 dims,
114 comms: HashMap::new(),
115 }
116 }
117}
118
/// State of a recording as `DefineRecording` messages arrive; a recording
/// may be split across multiple messages.
#[derive(Debug)]
enum Recording {
    /// Still accumulating `DefineRecording` chunks.
    PartialRecording {
        // Index of the last chunk applied; chunks must arrive in order.
        last_index: usize,
        // Commands collected so far, replayed once the recording completes.
        commands: Vec<WorkerMessage>,
    },

    /// All chunks received and commands replayed.
    CompleteRecording {
        // Streams that participate in this recording.
        streams: HashSet<StreamRef>,
    },
}
136
/// Actor servicing `WorkerMessage`s on a single rank: owns the rank's
/// streams, device meshes, borrows, NCCL communicators and recordings.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        WorkerMessage {cast = true},
        AssignRankMessage {cast = true},
    ],
)]
pub struct WorkerActor {
    // CUDA device for this worker; `None` for non-CUDA workers.
    device: Option<CudaDevice>,
    // Stream actors owned by this worker, keyed by their ref.
    streams: HashMap<StreamRef, Arc<ActorHandle<StreamActor>>>,
    // For each device mesh ref: the mesh plus, per (stream, sorted dims),
    // the split comm and its group size.
    device_meshes: HashMap<
        Ref,
        (
            DeviceMesh,
            HashMap<(StreamRef, SortedVec<String>), (usize, Arc<ActorHandle<NcclCommActor>>)>,
        ),
    >,
    world_size: usize,
    // Rank of this worker; may be reassigned by `AssignRankMessage`.
    rank: usize,
    // Active borrows keyed by borrow id.
    borrows: HashMap<u64, Borrow>,
    // Global communicator; set by `backend_network_init`.
    comm: Option<ActorHandle<NcclCommActor>>,
    controller_actor: ActorRef<ControllerActor>,
    remote_process_groups: HashMap<Ref, RemoteProcessGroupState>,
    // Split comms for point-to-point sends, keyed by (from, to) stream refs.
    send_recv_comms: HashMap<(StreamRef, StreamRef), Arc<ActorHandle<NcclCommActor>>>,
    recordings: HashMap<Ref, Recording>,
    // Set while a DefineRecording replay is in flight; commands handled in
    // that window register their streams with the recording.
    defining_recording: Option<Ref>,
    // When true (after rank assignment), results go back as Python messages.
    respond_with_python_message: bool,
}
180
181impl WorkerActor {
182 fn try_get_stream(&self, stream: StreamRef) -> Result<&Arc<ActorHandle<StreamActor>>> {
183 self.streams
184 .get(&stream)
185 .ok_or(anyhow::anyhow!("invalid stream id: {:#?}", stream))
186 }
187
188 async fn maybe_add_stream_to_recording(
189 &mut self,
190 cx: &impl context::Actor,
191 stream: StreamRef,
192 ) -> Result<()> {
193 if let Some(defining_recording) = self.defining_recording {
196 let recording = self.recordings.get_mut(&defining_recording).unwrap();
197 let fut = match recording {
198 Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
199 Recording::CompleteRecording { streams } => {
200 streams.insert(stream).then(|| -> Result<_, anyhow::Error> {
201 Ok(self
202 .try_get_stream(stream)?
203 .define_recording(cx, defining_recording))
204 })
205 }
206 }
207 .transpose()?;
208 match fut {
209 Some(fut) => fut.await,
210 None => Ok(()),
211 }
212 } else {
213 Ok(())
214 }
215 }
216}
217
// WorkerActor needs no custom lifecycle hooks; the default Actor impl suffices.
impl Actor for WorkerActor {}
219
#[async_trait]
impl RemoteSpawn for WorkerActor {
    type Params = WorkerParams;

    /// Builds a worker from its spawn params. Imports `monarch.safe_torch`
    /// up front so torch is initialized in this process before any stream
    /// work runs.
    async fn new(
        WorkerParams {
            world_size,
            rank,
            device_index,
            controller_actor,
        }: Self::Params,
    ) -> Result<Self> {
        Python::with_gil(|py| {
            // Panic if the import fails: the worker is unusable without it.
            py.import("monarch.safe_torch").unwrap();
        });
        Ok(Self {
            device: device_index.map(|i| CudaDevice::new(DeviceIndex(i))),
            streams: HashMap::new(),
            device_meshes: HashMap::new(),
            world_size,
            rank,
            borrows: HashMap::new(),
            comm: None,
            controller_actor,
            remote_process_groups: HashMap::new(),
            send_recv_comms: HashMap::new(),
            recordings: HashMap::new(),
            defining_recording: None,
            respond_with_python_message: false,
        })
    }

    }
254
#[async_trait]
impl Handler<AssignRankMessage> for WorkerActor {
    /// Assigns this worker its rank from the cast point and initializes the
    /// Python-side environment for the mesh controller.
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        _: AssignRankMessage,
    ) -> anyhow::Result<()> {
        let point = cx.cast_point();
        self.rank = point.rank();
        // From here on, results are sent back as Python messages.
        self.respond_with_python_message = true;
        Python::with_gil(|py| {
            // Unwraps are deliberate: a worker that cannot initialize its
            // Python environment cannot make progress.
            let mesh_controller = py.import("monarch.mesh_controller").unwrap();
            let p: PyPoint = point.into();
            mesh_controller
                .call_method1("_initialize_env", (p, cx.proc().proc_id().to_string()))
                .unwrap();
        });
        Ok(())
    }
}
275
/// Cast to workers to assign each its rank from the cast point and switch
/// it into Python-message response mode.
#[derive(Handler, Clone, Serialize, Deserialize, Debug, Named, Bind, Unbind)]
pub enum AssignRankMessage {
    AssignRank(),
}
wirevalue::register_type!(AssignRankMessage);
283
#[async_trait]
impl Handler<WorkerMessage> for WorkerActor {
    /// Delegates every `WorkerMessage` to the generated
    /// `WorkerMessageHandler` dispatch.
    async fn handle(
        &mut self,
        cx: &hyperactor::Context<Self>,
        message: WorkerMessage,
    ) -> anyhow::Result<()> {
        <Self as WorkerMessageHandler>::handle(self, cx, message).await
    }
}
294
295#[async_trait]
296impl WorkerMessageHandler for WorkerActor {
    /// Creates the global NCCL communicator for this worker, warms it up,
    /// then creates one split communicator per stream and hands each stream
    /// its comm.
    async fn backend_network_init(
        &mut self,
        cx: &hyperactor::Context<Self>,
        unique_id: UniqueId,
    ) -> Result<()> {
        let device = self
            .device
            .expect("tried to init backend network on a non-CUDA worker");
        let comm = NcclCommActor::new(CommParams::New {
            device,
            unique_id,
            world_size: self.world_size.try_into().unwrap(),
            rank: self.rank.try_into().unwrap(),
        })
        .await?
        .spawn(cx)?;

        // Warm-up: run a throwaway all-reduce on a one-element tensor.
        let tensor = factory_zeros(&[1], ScalarType::Float, Layout::Strided, device.into());
        let cell = TensorCell::new(tensor);

        comm.all_reduce(
            cx,
            cell,
            ReduceOp::Sum,
            torch_sys_cuda::cuda::Stream::get_current_stream(),
        )
        .await?;

        // Iterate streams in sorted-by-ref order so all ranks split the
        // communicator in the same sequence (splits are collective).
        let sorted_streams = self
            .streams
            .iter()
            .sorted_by_key(|(k, _)| *k)
            .map(|(_, v)| v.as_ref());

        // Splits are performed serially, in order; only handing the comms
        // to the streams happens concurrently.
        let mut splits = Vec::new();
        for _ in 0..sorted_streams.len() {
            splits.push(comm.split_all(cx).await?);
        }
        let _: Vec<()> = try_join_all(
            sorted_streams
                .into_iter()
                .zip(splits.into_iter())
                .map(|(stream, split)| stream.init_comm(cx, split)),
        )
        .await?;

        self.comm = Some(comm);

        Ok(())
    }
357
358 async fn backend_network_point_to_point_init(
359 &mut self,
360 cx: &hyperactor::Context<Self>,
361 from_stream: StreamRef,
362 to_stream: StreamRef,
363 ) -> Result<()> {
364 if !self.streams.contains_key(&from_stream) {
365 bail!("invalid from_stream id: {:#?}", from_stream);
366 }
367 if !self.streams.contains_key(&to_stream) {
368 bail!("invalid to_stream id: {:#?}", to_stream);
369 }
370 let global_comm = self
371 .comm
372 .as_ref()
373 .context("tried to call Reduce before BackendNetworkInit")?;
374 let comm = global_comm.split_all(cx).await?;
375 self.send_recv_comms
376 .insert((from_stream, to_stream), Arc::new(comm));
377 Ok(())
378 }
379
380 async fn call_function(
381 &mut self,
382 cx: &hyperactor::Context<Self>,
383 params: CallFunctionParams,
384 ) -> Result<()> {
385 let stream = self.try_get_stream(params.stream)?.clone();
386 self.maybe_add_stream_to_recording(cx, params.stream)
387 .await?;
388
389 let device_meshes = self
390 .device_meshes
391 .iter()
392 .map(|(k, v)| (k.clone(), v.0.clone()))
393 .collect();
394
395 let mut remote_process_groups = HashMap::new();
396 for remote_process_group_ref in ¶ms.remote_process_groups {
397 if let Some(state) = self.remote_process_groups.get(remote_process_group_ref) {
398 let dims_vec = state.dims.iter().cloned().collect();
399 let (device_mesh, _) = self
400 .device_meshes
401 .get(&state.device_mesh_ref)
402 .ok_or_else(|| {
403 anyhow::anyhow!("invalid device mesh id: {:#?}", state.device_mesh_ref)
404 })?
405 .clone();
406 let comm = state.comms
407 .get(¶ms.stream)
408 .ok_or_else(|| {
409 anyhow::anyhow!("no comm found for remote process group {remote_process_group_ref:#?} stream {stream:#?}")
410 })?
411 .clone();
412 remote_process_groups.insert(
413 remote_process_group_ref.clone(),
414 (device_mesh, dims_vec, comm),
415 );
416 }
417 }
418
419 stream
420 .call_function(cx, params, device_meshes, remote_process_groups)
421 .await?;
422
423 Ok(())
424 }
425
426 async fn command_group(
427 &mut self,
428 cx: &hyperactor::Context<Self>,
429 params: Vec<WorkerMessage>,
430 ) -> Result<()> {
431 for msg in params {
432 WorkerMessageHandler::handle(self, cx, msg).await?;
433 }
434 Ok(())
435 }
436
437 async fn create_stream(
438 &mut self,
439 cx: &hyperactor::Context<Self>,
440 result: StreamRef,
441 creation_mode: StreamCreationMode,
442 ) -> Result<()> {
443 let handle: ActorHandle<StreamActor> = StreamActor::new(StreamParams {
444 world_size: self.world_size,
445 rank: self.rank,
446 creation_mode,
447 id: result,
448 device: self.device,
449 controller_actor: self.controller_actor.clone(),
450 respond_with_python_message: self.respond_with_python_message,
451 })
452 .spawn(cx)?;
453 self.streams.insert(result, Arc::new(handle));
454 Ok(())
455 }
456
457 async fn create_device_mesh(
458 &mut self,
459 _cx: &hyperactor::Context<Self>,
460 result: Ref,
461 names: Vec<String>,
462 ranks: Slice,
463 ) -> Result<()> {
464 self.device_meshes.insert(
465 result,
466 (DeviceMesh::new(names, ranks, self.rank)?, HashMap::new()),
467 );
468 Ok(())
469 }
470
471 async fn create_remote_process_group(
472 &mut self,
473 _cx: &hyperactor::Context<Self>,
474 result: Ref,
475 device_mesh: Ref,
476 dims: Vec<String>,
477 ) -> Result<()> {
478 self.device_meshes
479 .get(&device_mesh)
480 .with_context(|| format!("invalid device mesh id: {:#?}", device_mesh))?;
481 match self.remote_process_groups.entry(result) {
482 Entry::Vacant(ent) => ent.insert(RemoteProcessGroupState::new(
483 device_mesh,
484 SortedVec::from_unsorted(dims),
485 )),
486 Entry::Occupied(ent) => bail!("remote process group {:?} already create", ent.key()),
487 };
488 Ok(())
489 }
490
491 async fn borrow_create(
492 &mut self,
493 cx: &hyperactor::Context<Self>,
494 result: Ref,
495 borrow_id: u64,
496 tensor_ref: Ref,
497 from_stream: StreamRef,
498 to_stream: StreamRef,
499 ) -> Result<()> {
500 self.maybe_add_stream_to_recording(cx, from_stream).await?;
501 self.maybe_add_stream_to_recording(cx, to_stream).await?;
502 let from_stream = self.try_get_stream(from_stream)?.clone();
503 let to_stream = self.try_get_stream(to_stream)?.clone();
504
505 let borrow =
506 Borrow::create(cx, borrow_id, tensor_ref, result, from_stream, to_stream).await?;
507 self.borrows.insert(borrow_id, borrow);
508 Ok(())
509 }
510
511 async fn borrow_first_use(
512 &mut self,
513 cx: &hyperactor::Context<Self>,
514 borrow: u64,
515 ) -> Result<()> {
516 let borrow = self
517 .borrows
518 .get_mut(&borrow)
519 .ok_or_else(|| anyhow!("invalid borrow id: {:#?}", borrow))?;
520
521 borrow.first_use(cx).await?;
522 Ok(())
523 }
524
525 async fn borrow_last_use(&mut self, cx: &hyperactor::Context<Self>, borrow: u64) -> Result<()> {
526 let borrow = self
527 .borrows
528 .get_mut(&borrow)
529 .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow))?;
530
531 borrow.last_use(cx).await?;
532 Ok(())
533 }
534
535 async fn borrow_drop(&mut self, cx: &hyperactor::Context<Self>, borrow_id: u64) -> Result<()> {
536 let borrow = self
537 .borrows
538 .get_mut(&borrow_id)
539 .ok_or_else(|| anyhow::anyhow!("invalid borrow id: {:#?}", borrow_id))?;
540
541 borrow.drop(cx).await?;
542 self.borrows.remove(&borrow_id);
543 Ok(())
544 }
545
546 async fn delete_refs(&mut self, cx: &hyperactor::Context<Self>, refs: Vec<Ref>) -> Result<()> {
547 let _: Vec<()> = try_join_all(
554 self.streams
555 .values()
556 .map(|s| s.delete_refs(cx, refs.clone())),
557 )
558 .await?;
559 Ok(())
560 }
561
    /// Waits for every stream to answer a status request, then reports
    /// status to the controller.
    async fn request_status(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        controller: bool,
    ) -> Result<()> {
        // Acts as a barrier: resolves only once each stream has processed
        // its preceding messages.
        let _: Vec<()> = try_join_all(
            self.streams
                .values()
                .map(|stream| stream.request_status(cx)),
        )
        .await?;

        // NOTE(review): reports seq.next() rather than seq — presumably the
        // controller expects the next sequence number; confirm against the
        // controller protocol.
        ControllerMessageClient::status(
            &self.controller_actor,
            cx,
            seq.next(),
            cx.self_id().clone(),
            controller,
        )
        .await?;
        Ok(())
    }
589
590 async fn reduce(
591 &mut self,
592 cx: &hyperactor::Context<Self>,
593 result: Ref,
594 local_tensor: Ref,
595 factory: Factory,
596 source_mesh: Ref,
597 stream_ref: StreamRef,
598 dims: Vec<String>,
599 reduction: Reduction,
600 scatter: bool,
601 in_place: bool,
602 out: Option<Ref>,
603 ) -> Result<()> {
604 self.maybe_add_stream_to_recording(cx, stream_ref).await?;
605
606 let dims = SortedVec::from_unsorted(dims);
608 let stream = self.try_get_stream(stream_ref)?.clone();
609
610 let (_, comm_map) = self
611 .device_meshes
612 .get_mut(&source_mesh)
613 .ok_or_else(|| anyhow::anyhow!("invalid device mesh id: {:#?}", source_mesh))?;
614
615 let (size, comm) = comm_map
616 .get(&(stream_ref, dims.clone()))
617 .ok_or_else(|| anyhow::anyhow!("no comm found for stream {stream:#?}, dims {dims:#?}"))?
618 .clone();
619
620 stream
621 .reduce(
622 cx,
623 comm,
624 size.try_into()?,
625 result,
626 local_tensor,
627 factory,
628 reduction,
629 scatter,
630 in_place,
631 out,
632 )
633 .await?;
634
635 Ok(())
636 }
637
    /// Point-to-point tensor transfer between ranks, using the comm created
    /// by `backend_network_point_to_point_init` for this stream pair.
    async fn send_tensor(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        from_ranks: Slice,
        to_ranks: Slice,
        tensor: Ref,
        factory: Factory,
        from_stream: StreamRef,
        to_stream: StreamRef,
    ) -> Result<()> {
        let comm = self
            .send_recv_comms
            .get(&(from_stream, to_stream))
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "could not find stream to stream comm for: {:#?}",
                    (from_stream, to_stream)
                )
            })?
            .clone();

        // If this rank appears in `from_ranks`, `to_rank` is the peer it
        // sends to; if it appears in `to_ranks`, `from_rank` is the peer it
        // receives from. Either may be None when this rank does not
        // participate in that direction.
        let to_rank = from_ranks
            .index(self.rank)
            .map(|index| to_ranks.get(index).ok())
            .ok()
            .flatten();
        let from_rank = to_ranks
            .index(self.rank)
            .map(|index| from_ranks.get(index).ok())
            .ok()
            .flatten();

        // Choose which stream actor runs the transfer: receive-only work
        // runs on the destination stream, otherwise the source stream.
        let (stream, stream_ref) = if to_rank.is_none() {
            (self.try_get_stream(to_stream)?.clone(), to_stream)
        } else if from_rank.is_none() || from_stream == to_stream {
            (self.try_get_stream(from_stream)?.clone(), from_stream)
        } else {
            unimplemented!(
                "We haven't implemented to_mesh between streams if a rank participates as both a sender and receiver. \
                It is possible, but would require the recv stream to send the output buffer tensor to the send stream and sync. \
                Then the send stream would do the nccl op, and then sync with sending stream again."
            );
        };

        self.maybe_add_stream_to_recording(cx, stream_ref).await?;

        stream
            .send_tensor(cx, result, from_rank, to_rank, tensor, factory, comm)
            .await?;

        Ok(())
    }
691
    /// Shuts down the worker: drains and awaits all streams, then either
    /// stops the actor (clean exit) or exits the whole process with a
    /// configured error code.
    async fn exit(
        &mut self,
        cx: &hyperactor::Context<Self>,
        error: Option<(Option<ActorId>, String)>,
    ) -> Result<()> {
        for (_, stream) in self.streams.drain() {
            stream.drain_and_stop()?;
            // The worker's table held the only strong reference, so
            // unwrapping the Arc must succeed; then await full shutdown.
            Arc::into_inner(stream)
                .expect("there should be no owners of this stream handle except the worker stream table")
                .await;
        }

        // Exit codes are overridable via env vars; default to 1.
        let self_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_SELF_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);
        let peer_error_exit_code = std::env::var("MONARCH_TENSOR_WORKER_PEER_ERROR_EXIT_CODE")
            .ok()
            .and_then(|val| val.parse::<i32>().ok())
            .unwrap_or(1);

        // Distinguish "this worker failed" from "a peer failed" so process
        // supervisors can react differently.
        let exit_code = match error {
            Some((Some(actor_id), reason)) => {
                tracing::error!(
                    "stopping the worker, actor {} failed with error: {}",
                    actor_id,
                    reason
                );
                if *cx.self_id() == actor_id {
                    self_error_exit_code
                } else {
                    peer_error_exit_code
                }
            }
            Some((None, reason)) => {
                tracing::error!("stopping the worker, reason: {}", reason);
                1
            }
            None => 0,
        };

        if exit_code != 0 {
            tracing::info!("stopping the worker process, exit code: {}", exit_code);
            // Hard process exit: deliberately skips normal actor teardown
            // on the error path.
            std::process::exit(exit_code);
        }
        cx.stop()?;
        Ok(())
    }
741
    /// Sends a value (optionally produced by applying `function`) back from
    /// the given stream, identified by `seq`.
    async fn send_value(
        &mut self,
        cx: &hyperactor::Context<Self>,
        seq: Seq,
        destination: Option<Ref>,
        mutates: Vec<Ref>,
        function: Option<ResolvableFunction>,
        args_kwargs: ArgsKwargs,
        stream: StreamRef,
    ) -> Result<()> {
        let stream = self.try_get_stream(stream)?;

        // Device meshes are only needed when a function will be resolved
        // and applied before sending.
        let device_meshes = if function.is_none() {
            HashMap::new()
        } else {
            self.device_meshes
                .iter()
                .map(|(k, v)| (k.clone(), v.0.clone()))
                .collect()
        };

        // Pipe destinations were removed from the protocol.
        if destination.is_some() {
            panic!("send_value with pipe destination is no longer implemented")
        }

        stream
            .send_value(
                cx,
                seq,
                cx.self_id().clone(),
                mutates,
                function,
                args_kwargs,
                device_meshes,
            )
            .await
    }
781
782 async fn send_result_of_actor_call(
783 &mut self,
784 cx: &hyperactor::Context<Self>,
785 params: ActorCallParams,
786 ) -> Result<()> {
787 let stream = self.try_get_stream(params.stream)?;
788 stream.send_result_of_actor_call(cx, params).await?;
789 Ok(())
790 }
791 async fn call_actor_method(
792 &mut self,
793 cx: &hyperactor::Context<Self>,
794 params: ActorMethodParams,
795 ) -> Result<()> {
796 let stream = self.try_get_stream(params.call.stream)?;
797 stream.call_actor_method(cx, params).await?;
798 Ok(())
799 }
    /// Collectively splits a communicator off the global comm for the given
    /// (device mesh dims, stream). Ranks whose meshes don't include them
    /// still participate in the split with an empty rank list, because NCCL
    /// comm splits are collective operations.
    async fn split_comm(
        &mut self,
        cx: &hyperactor::Context<Self>,
        dims: Vec<String>,
        device_mesh: Ref,
        stream_ref: StreamRef,
    ) -> Result<()> {
        let global_comm = self
            .comm
            .as_ref()
            .context("tried to call SplitComm before BackendNetworkInit")?;
        match self.device_meshes.get_mut(&device_mesh) {
            Some((device_mesh, comm_map)) => {
                let stream = self
                    .streams
                    .get(&stream_ref)
                    .ok_or_else(|| anyhow::anyhow!("invalid stream id: {:#?}", stream_ref))?;

                // Comm-map keys use sorted dims; normalize before insert.
                let dims = SortedVec::from_unsorted(dims);

                anyhow::ensure!(
                    !comm_map.contains_key(&(stream_ref, dims.clone())),
                    "comm already exists for stream {stream:#?}, dims {dims:#?}"
                );
                let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&dims)?;
                let size = ranks_for_group.len();
                let split_comm = global_comm
                    .split_from(
                        cx,
                        ranks_for_group
                            .into_iter()
                            .map(|v| v.clone().try_into())
                            .collect::<Result<Vec<_>, _>>()?,
                    )
                    .await?
                    // split_from yields None for ranks excluded from the new
                    // comm; self must be included here.
                    .context("split comm should include self rank")?;
                comm_map.insert((stream_ref, dims), (size, Arc::new(split_comm)));
            }
            None => {
                // Unknown mesh on this rank: still join the collective split
                // with an empty rank list.
                global_comm.split_from(cx, vec![]).await?;
            }
        }
        Ok(())
    }
848
849 async fn split_comm_for_process_group(
850 &mut self,
851 cx: &hyperactor::Context<Self>,
852 remote_process_group_ref: Ref,
853 stream_ref: StreamRef,
854 ) -> Result<()> {
855 ensure!(
856 self.streams.contains_key(&stream_ref),
857 "invalid stream id: {:#?}",
858 stream_ref
859 );
860 let global_comm = self
861 .comm
862 .as_ref()
863 .context("tried to call SplitComm before BackendNetworkInit")?;
864 let state = self
865 .remote_process_groups
866 .get_mut(&remote_process_group_ref)
867 .with_context(|| format!("invalid remote process group id: {:#?}", stream_ref))?;
868 match self.device_meshes.get_mut(&state.device_mesh_ref) {
869 Some((device_mesh, _)) => {
870 let entry = match state.comms.entry(stream_ref) {
873 Entry::Vacant(entry) => entry,
874 Entry::Occupied(_) => bail!(
875 "comm already exists for remote process group {:#?} on stream {:#?}",
876 remote_process_group_ref,
877 stream_ref,
878 ),
879 };
880 let ranks_for_group = device_mesh.get_ranks_for_dim_slice(&state.dims)?;
881 let split_comm = global_comm
882 .split_from(
883 cx,
884 ranks_for_group
885 .into_iter()
886 .map(|v| v.clone().try_into())
887 .collect::<Result<Vec<_>, _>>()?,
888 )
889 .await?
890 .context("split comm should include self rank")?;
891 entry.insert(Arc::new(split_comm));
892 }
893 None => {
894 global_comm.split_from(cx, vec![]).await?;
897 }
898 }
899 Ok(())
900 }
901
    /// Pipes were removed from the worker; this handler remains for
    /// protocol compatibility and always panics if invoked.
    async fn pipe_recv(
        &mut self,
        _cx: &hyperactor::Context<Self>,
        _seq: Seq,
        _results: Vec<Option<Ref>>,
        _pipe: Ref,
        _stream: StreamRef,
    ) -> Result<()> {
        panic!("pipe_recv is no longer implemented")
    }
912
913 async fn set_ref_unit_tests_only(
914 &mut self,
915 cx: &hyperactor::Context<Self>,
916 reference: Ref,
917 value: WireValue,
918 stream: StreamRef,
919 ) -> Result<()> {
920 let stream = self.try_get_stream(stream)?;
921
922 stream.set_ref_unit_tests_only(cx, reference, value).await
923 }
924
925 async fn get_ref_unit_tests_only(
926 &mut self,
927 cx: &hyperactor::Context<Self>,
928 ref_id: Ref,
929 stream: StreamRef,
930 ) -> Result<Option<Result<WireValue, String>>> {
931 let stream = self.try_get_stream(stream)?;
932 Ok(stream.get_ref_unit_tests_only(cx, ref_id.clone()).await?)
933 }
934
    /// Handles one chunk of a (possibly multi-part) DefineRecording.
    /// Chunks must arrive in index order; when the final chunk arrives
    /// (index == ntotal_messages - 1) the accumulated commands are replayed
    /// while `defining_recording` is set — which registers each touched
    /// stream — and then every registered stream finalizes the recording.
    async fn define_recording(
        &mut self,
        cx: &hyperactor::Context<Self>,
        result: Ref,
        _nresults: usize,
        _nformals: usize,
        commands: Vec<WorkerMessage>,
        ntotal_messages: usize,
        index: usize,
    ) -> Result<()> {
        if self.defining_recording.is_some() && self.defining_recording.unwrap() != result {
            bail!("already defining a different recording");
        }
        self.defining_recording = Some(result);

        // Append this chunk, enforcing in-order delivery.
        match self.recordings.entry(result) {
            Entry::Vacant(entry) => {
                ensure!(
                    index == 0,
                    "got DefineRecording message with (index = {:?}) > 0 for previously unseen recording",
                    index
                );
                entry.insert(Recording::PartialRecording {
                    last_index: 0,
                    commands,
                });
            }
            Entry::Occupied(mut entry) => match entry.get_mut() {
                Recording::CompleteRecording { .. } => {
                    bail!("got DefineRecording message for already complete recording")
                }
                Recording::PartialRecording {
                    last_index,
                    commands: existing_commands,
                } => {
                    ensure!(
                        index == *last_index + 1,
                        "Got DefineRecording message with index = {:?}, but \
                        last seen index for recording is {:?}",
                        index,
                        last_index
                    );
                    *last_index = index;
                    existing_commands.extend(commands);
                }
            },
        };

        // More chunks still to come: nothing further to do yet.
        if index < ntotal_messages - 1 {
            return Ok(());
        }
        // Final chunk: swap the partial recording for a complete one and
        // take ownership of the accumulated commands.
        let commands = match self.recordings.remove(&result).unwrap() {
            Recording::CompleteRecording { .. } => panic!("unreachable, in theory"),
            Recording::PartialRecording { commands, .. } => {
                self.recordings.insert(
                    result,
                    Recording::CompleteRecording {
                        streams: HashSet::new(),
                    },
                );
                commands
            }
        };

        // Replaying populates `streams` on the complete recording via
        // maybe_add_stream_to_recording.
        for command in commands {
            WorkerMessageHandler::handle(self, cx, command).await?;
        }

        match self.recordings.get(&result).unwrap() {
            Recording::PartialRecording { .. } => panic!("unreachable, in theory"),
            Recording::CompleteRecording { streams, .. } => {
                for stream in streams {
                    self.try_get_stream(*stream)?
                        .finalize_recording(cx, result)
                        .await?;
                }
            }
        }

        self.defining_recording = None;
        Ok(())
    }
1017
1018 async fn recording_formal(
1019 &mut self,
1020 cx: &hyperactor::Context<Self>,
1021 result: Ref,
1022 argument_index: usize,
1023 stream: StreamRef,
1024 ) -> Result<()> {
1025 ensure!(self.defining_recording.is_some());
1026 self.maybe_add_stream_to_recording(cx, stream).await?;
1027 self.try_get_stream(stream)?
1028 .recording_formal(cx, result, argument_index)
1029 .await
1030 }
1031
1032 async fn recording_result(
1033 &mut self,
1034 cx: &hyperactor::Context<Self>,
1035 result: Ref,
1036 output_index: usize,
1037 stream: StreamRef,
1038 ) -> Result<()> {
1039 ensure!(self.defining_recording.is_some());
1040 self.maybe_add_stream_to_recording(cx, stream).await?;
1041 self.try_get_stream(stream)?
1042 .recording_result(cx, result, output_index)
1043 .await
1044 }
1045
1046 async fn call_recording(
1047 &mut self,
1048 cx: &hyperactor::Context<Self>,
1049 seq: Seq,
1050 recording: Ref,
1051 results: Vec<Ref>,
1052 actuals: Vec<Ref>,
1053 ) -> Result<()> {
1054 ensure!(self.defining_recording.is_none());
1055 let recording_ref = recording;
1056 let recording = self.recordings.get(&recording).ok_or(anyhow::anyhow!(
1057 "could not find recording: {:#?}",
1058 recording
1059 ))?;
1060 match recording {
1061 Recording::PartialRecording { .. } => {
1062 bail!("cannot call recording because it is incomplete")
1063 }
1064 Recording::CompleteRecording { streams } => try_join_all(
1065 streams
1066 .iter()
1067 .map(|stream| self.try_get_stream(*stream))
1068 .collect::<Result<Vec<_>>>()?
1069 .into_iter()
1070 .map(|stream| {
1071 stream.call_recording(
1072 cx,
1073 seq,
1074 recording_ref,
1075 results.clone(),
1076 actuals.clone(),
1077 )
1078 }),
1079 )
1080 .await
1081 .map(|_| ()),
1082 }
1083 }
1084}
1085
1086#[cfg(test)]
1087mod tests {
1088 use std::assert_matches::assert_matches;
1089
1090 use anyhow::Result;
1091 use hyperactor::RemoteSpawn;
1092 use hyperactor::channel::ChannelAddr;
1093 use hyperactor::proc::Proc;
1094 use monarch_messages::controller::ControllerMessage;
1095 use monarch_messages::controller::WorkerError;
1096 use monarch_messages::worker::WorkerMessageClient;
1097 use monarch_types::PickledPyObject;
1098 use pyo3::Python;
1099 use pyo3::prelude::*;
1100 use pyo3::types::PyList;
1101 use pyo3::types::PyString;
1102 use rand::Rng;
1103 use rand::distributions::Alphanumeric;
1104 use timed_test::async_timed_test;
1105
1106 use super::*;
1107 use crate::test_util::test_setup;
1108
1109 #[async_timed_test(timeout_secs = 60)]
1110 async fn basic_worker() -> Result<()> {
1111 test_setup()?;
1112
1113 let proc = Proc::local();
1114 let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1115
1116 let worker_handle = proc
1117 .spawn(
1118 "worker",
1119 WorkerActor::new(WorkerParams {
1120 world_size: 1,
1121 rank: 0,
1122 device_index: None,
1123 controller_actor: controller_ref,
1124 })
1125 .await
1126 .unwrap(),
1127 )
1128 .unwrap();
1129 worker_handle
1130 .command_group(
1131 &client,
1132 vec![
1133 WorkerMessage::CreateStream {
1134 id: 1.into(),
1135 stream_creation: StreamCreationMode::UseDefaultStream,
1136 },
1137 WorkerMessage::CallFunction(CallFunctionParams {
1138 seq: 0.into(),
1139 results: vec![Some(0.into())],
1140 mutates: vec![],
1141 function: "torch.ops.aten.ones.default".into(),
1142 args_kwargs: ArgsKwargs::from_wire_values(
1143 vec![WireValue::IntList(vec![2, 3])],
1144 HashMap::new(),
1145 )
1146 .unwrap(),
1147 stream: 1.into(),
1148 remote_process_groups: vec![],
1149 }),
1150 WorkerMessage::CallFunction(CallFunctionParams {
1151 seq: 2.into(),
1152 results: vec![Some(Ref { id: 2 })],
1153 mutates: vec![0.into()],
1154 function: "torch.ops.aten.sub_.Scalar".into(),
1155 args_kwargs: ArgsKwargs::from_wire_values(
1156 vec![WireValue::Ref(0.into()), WireValue::Int(1)],
1157 HashMap::new(),
1158 )
1159 .unwrap(),
1160 stream: 1.into(),
1161 remote_process_groups: vec![],
1162 }),
1163 WorkerMessage::CallFunction(CallFunctionParams {
1164 seq: 3.into(),
1165 results: vec![Some(Ref { id: 3 })],
1166 mutates: vec![],
1167 function: "torch.ops.aten.zeros.default".into(),
1168 args_kwargs: ArgsKwargs::from_wire_values(
1169 vec![WireValue::IntList(vec![2, 3])],
1170 HashMap::new(),
1171 )
1172 .unwrap(),
1173 stream: 1.into(),
1174 remote_process_groups: vec![],
1175 }),
1176 WorkerMessage::CallFunction(CallFunctionParams {
1177 seq: 4.into(),
1178 results: vec![Some(Ref { id: 4 })],
1179 mutates: vec![],
1180 function: "torch.ops.aten.allclose.default".into(),
1181 args_kwargs: ArgsKwargs::from_wire_values(
1182 vec![WireValue::Ref(0.into()), WireValue::Ref(Ref { id: 3 })],
1183 HashMap::new(),
1184 )
1185 .unwrap(),
1186 stream: 1.into(),
1187 remote_process_groups: vec![],
1188 }),
1189 ],
1190 )
1191 .await
1192 .unwrap();
1193
1194 let result: bool = worker_handle
1195 .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
1196 .await
1197 .unwrap()
1198 .unwrap()
1199 .unwrap()
1200 .try_into()
1201 .unwrap();
1202 worker_handle.drain_and_stop().unwrap();
1203 worker_handle.await;
1204 let error_responses = controller_rx.drain();
1205 assert!(
1206 error_responses.is_empty(),
1207 "Expected no error responses, got: {:#?}",
1208 error_responses
1209 );
1210 assert!(result);
1211
1212 Ok(())
1213 }
1214
1215 #[async_timed_test(timeout_secs = 60)]
1216 async fn error_sends_response() -> Result<()> {
1217 test_setup()?;
1218
1219 let proc = Proc::local();
1220 let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1221
1222 let worker_handle = proc
1223 .spawn(
1224 "worker",
1225 WorkerActor::new(WorkerParams {
1226 world_size: 1,
1227 rank: 0,
1228 device_index: None,
1229 controller_actor: controller_ref,
1230 })
1231 .await
1232 .unwrap(),
1233 )
1234 .unwrap();
1235 worker_handle
1236 .command_group(
1237 &client,
1238 vec![
1239 WorkerMessage::CreateStream {
1240 id: 1.into(),
1241 stream_creation: StreamCreationMode::UseDefaultStream,
1242 },
1243 WorkerMessage::CallFunction(CallFunctionParams {
1244 seq: 0.into(),
1245 results: vec![Some(0.into())],
1246 mutates: vec![],
1247 function: "torch.ops.aten.rand.default".into(),
1248 args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
1249 stream: 1.into(),
1250 remote_process_groups: vec![],
1251 }),
1252 WorkerMessage::Exit { error: None },
1253 ],
1254 )
1255 .await
1256 .unwrap();
1257
1258 worker_handle.drain_and_stop().unwrap();
1259 worker_handle.await;
1260 let response_message = controller_rx.recv().await.unwrap();
1261 match response_message {
1262 ControllerMessage::RemoteFunctionFailed {
1263 seq,
1264 error: WorkerError { backtrace: msg, .. },
1265 } => {
1266 assert_eq!(seq, 0.into());
1267 assert!(msg.contains("aten::rand() is missing value for argument 'size'"))
1268 }
1269 _ => panic!("unexpected response {:#?}", response_message),
1270 }
1271
1272 Ok(())
1273 }
1274
1275 #[async_timed_test(timeout_secs = 60)]
1276 async fn mutated_refs_are_updated_with_error() -> Result<()> {
1277 test_setup()?;
1278
1279 let proc = Proc::local();
1280 let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1281
1282 let worker_handle = proc
1283 .spawn(
1284 "worker",
1285 WorkerActor::new(WorkerParams {
1286 world_size: 1,
1287 rank: 0,
1288 device_index: None,
1289 controller_actor: controller_ref,
1290 })
1291 .await
1292 .unwrap(),
1293 )
1294 .unwrap();
1295 worker_handle
1296 .command_group(
1297 &client,
1298 vec![
1299 WorkerMessage::CreateStream {
1300 id: 1.into(),
1301 stream_creation: StreamCreationMode::UseDefaultStream,
1302 },
1303 WorkerMessage::SetRefUnitTestsOnly {
1304 reference: 0.into(),
1305 value: WireValue::Int(1),
1306 stream: 1.into(),
1307 },
1308 WorkerMessage::CallFunction(CallFunctionParams {
1309 seq: 0.into(),
1310 results: vec![Some(Ref { id: 2 })],
1311 mutates: vec![0.into()],
1312 function: "i.dont.exist".into(),
1313 args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
1314 stream: 1.into(),
1315 remote_process_groups: vec![],
1316 }),
1317 ],
1318 )
1319 .await
1320 .unwrap();
1321
1322 let result = worker_handle
1323 .get_ref_unit_tests_only(&client, 0.into(), 1.into())
1324 .await?;
1325
1326 worker_handle.drain_and_stop().unwrap();
1328 worker_handle.await;
1329
1330 let mutated_ref = result
1331 .context("no such ref")?
1332 .err()
1333 .context("expected error")?;
1334 assert!(mutated_ref.contains("failed to resolve function"));
1335
1336 let responses = controller_rx.drain();
1337 assert_eq!(
1338 responses.len(),
1339 1,
1340 "Expected one response, got: {:#?}",
1341 responses
1342 );
1343 Ok(())
1344 }
1345
1346 #[async_timed_test(timeout_secs = 60)]
1347 async fn accessing_errored_dependency() -> Result<()> {
1348 test_setup()?;
1349
1350 let proc = Proc::local();
1351 let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1352
1353 let worker_handle = proc
1354 .spawn(
1355 "worker",
1356 WorkerActor::new(WorkerParams {
1357 world_size: 1,
1358 rank: 0,
1359 device_index: None,
1360 controller_actor: controller_ref,
1361 })
1362 .await
1363 .unwrap(),
1364 )
1365 .unwrap();
1366 worker_handle
1367 .command_group(
1368 &client,
1369 vec![
1370 WorkerMessage::CreateStream {
1371 id: 1.into(),
1372 stream_creation: StreamCreationMode::UseDefaultStream,
1373 },
1374 WorkerMessage::CallFunction(CallFunctionParams {
1375 seq: 0.into(),
1376 results: vec![Some(0.into())],
1377 mutates: vec![],
1378 function: "i.dont.exist".into(),
1379 args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
1380 stream: 1.into(),
1381 remote_process_groups: vec![],
1382 }),
1383 WorkerMessage::CallFunction(CallFunctionParams {
1384 seq: 1.into(),
1385 results: vec![Some(1.into())],
1386 mutates: vec![],
1387 function: "torch.ops.aten.sub_.Scalar".into(),
1388 args_kwargs: ArgsKwargs::from_wire_values(
1389 vec![WireValue::Ref(0.into())],
1390 HashMap::new(),
1391 )
1392 .unwrap(),
1393 stream: 1.into(),
1394 remote_process_groups: vec![],
1395 }),
1396 WorkerMessage::Exit { error: None },
1397 ],
1398 )
1399 .await
1400 .unwrap();
1401
1402 worker_handle.drain_and_stop().unwrap();
1403 worker_handle.await;
1404
1405 let responses = controller_rx.drain();
1406 assert_eq!(
1407 responses.len(),
1408 1,
1409 "Expected one response, got: {:#?}",
1410 responses
1411 );
1412
1413 match &responses[0] {
1414 ControllerMessage::RemoteFunctionFailed { seq, .. } => {
1415 assert_eq!(seq, &0.into())
1416 }
1417 _ => panic!("unexpected response {:#?}", responses[0]),
1418 };
1419 Ok(())
1420 }
1421
1422 #[async_timed_test(timeout_secs = 60)]
1423 async fn py_remote_function_calls() -> Result<()> {
1424 test_setup()?;
1425
1426 let proc = Proc::local();
1427 let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1428
1429 let worker_handle = proc
1430 .spawn(
1431 "worker",
1432 WorkerActor::new(WorkerParams {
1433 world_size: 1,
1434 rank: 0,
1435 device_index: None,
1436 controller_actor: controller_ref,
1437 })
1438 .await
1439 .unwrap(),
1440 )
1441 .unwrap();
1442 let (split_arg, sort_list, dim, layout, none, scalar, device, memory_format) =
1443 Python::with_gil(|py| {
1444 let split_arg: PickledPyObject = PyString::new(py, "/fbs/fbc/foo/bar")
1445 .into_any()
1446 .try_into()?;
1447 let sort_list: PickledPyObject =
1448 PyList::new(py, [65, 34, 79, 1, 5])?.into_any().try_into()?;
1449 let dim: PickledPyObject = PyString::new(py, "x").into_any().try_into()?;
1450 let layout: PickledPyObject = py.import("torch")?.getattr("strided")?.try_into()?;
1451 let none: PickledPyObject = py.None().into_any().into_bound(py).try_into()?;
1452 let scalar: PickledPyObject = py.import("torch")?.getattr("float32")?.try_into()?;
1453 let device: PickledPyObject = py
1454 .import("torch")?
1455 .getattr("device")?
1456 .call1(("cuda:1",))?
1457 .try_into()?;
1458 let memory_format: PickledPyObject = py
1459 .import("torch")?
1460 .getattr("contiguous_format")?
1461 .try_into()?;
1462 PyResult::Ok((
1463 split_arg,
1464 sort_list,
1465 dim,
1466 layout,
1467 none,
1468 scalar,
1469 device,
1470 memory_format,
1471 ))
1472 })?;
1473
1474 worker_handle
1475 .command_group(
1476 &client,
1477 vec![
1478 WorkerMessage::CreateStream {
1479 id: 1.into(),
1480 stream_creation: StreamCreationMode::UseDefaultStream,
1481 },
1482 WorkerMessage::CallFunction(CallFunctionParams {
1483 seq: 0.into(),
1484 results: vec![Some(0.into()), Some(Ref { id: 2 })],
1485 mutates: vec![],
1486 function: "os.path.split".into(),
1487 args_kwargs: ArgsKwargs::from_wire_values(
1488 vec![split_arg.into()],
1489 HashMap::new(),
1490 )
1491 .unwrap(),
1492 stream: 1.into(),
1493 remote_process_groups: vec![],
1494 }),
1495 WorkerMessage::CallFunction(CallFunctionParams {
1496 seq: 2.into(),
1497 results: vec![Some(4.into()), None, None, None, None],
1498 mutates: vec![],
1499 function: "builtins.sorted".into(),
1500 args_kwargs: ArgsKwargs::from_wire_values(
1501 vec![sort_list.into()],
1502 HashMap::new(),
1503 )
1504 .unwrap(),
1505 stream: 1.into(),
1506 remote_process_groups: vec![],
1507 }),
1508 WorkerMessage::CreateDeviceMesh {
1509 result: 5.into(),
1510 names: vec!["x".into()],
1511 ranks: Slice::new(0, vec![2], vec![1]).unwrap(),
1512 },
1513 WorkerMessage::CallFunction(CallFunctionParams {
1514 seq: 2.into(),
1515 results: vec![Some(6.into())],
1516 mutates: vec![],
1517 function: "monarch.monarch_tensor_worker.test_utils.mesh_rank".into(),
1518 args_kwargs: ArgsKwargs::from_wire_values(
1519 vec![WireValue::Ref(Ref { id: 5 }), dim.into()],
1520 HashMap::new(),
1521 )
1522 .unwrap(),
1523 stream: 1.into(),
1524 remote_process_groups: vec![],
1525 }),
1526 WorkerMessage::CallFunction(CallFunctionParams {
1527 seq: 4.into(),
1528 results: vec![Some(7.into())],
1529 mutates: vec![],
1530 function: "monarch.monarch_tensor_worker.test_utils.test_scalar_type"
1531 .into(),
1532 args_kwargs: ArgsKwargs::from_wire_values(
1533 vec![scalar.into()],
1534 HashMap::new(),
1535 )
1536 .unwrap(),
1537 stream: 1.into(),
1538 remote_process_groups: vec![],
1539 }),
1540 WorkerMessage::CallFunction(CallFunctionParams {
1541 seq: 5.into(),
1542 results: vec![Some(8.into())],
1543 mutates: vec![],
1544 function: "monarch.monarch_tensor_worker.test_utils.test_layout".into(),
1545 args_kwargs: ArgsKwargs::from_wire_values(
1546 vec![layout.into()],
1547 HashMap::new(),
1548 )
1549 .unwrap(),
1550 stream: 1.into(),
1551 remote_process_groups: vec![],
1552 }),
1553 WorkerMessage::CallFunction(CallFunctionParams {
1554 seq: 6.into(),
1555 results: vec![Some(9.into())],
1556 mutates: vec![],
1557 function: "monarch.monarch_tensor_worker.test_utils.test_none".into(),
1558 args_kwargs: ArgsKwargs::from_wire_values(
1559 vec![none.into()],
1560 HashMap::new(),
1561 )
1562 .unwrap(),
1563 stream: 1.into(),
1564 remote_process_groups: vec![],
1565 }),
1566 WorkerMessage::CallFunction(CallFunctionParams {
1569 seq: 7.into(),
1570 results: vec![None],
1571 mutates: vec![],
1572 function: "monarch.monarch_tensor_worker.test_utils.none".into(),
1573 args_kwargs: ArgsKwargs::from_wire_values(vec![], HashMap::new()).unwrap(),
1574 stream: 1.into(),
1575 remote_process_groups: vec![],
1576 }),
1577 WorkerMessage::CallFunction(CallFunctionParams {
1578 seq: 8.into(),
1579 results: vec![Some(10.into())],
1580 mutates: vec![],
1581 function: "monarch.monarch_tensor_worker.test_utils.test_device".into(),
1582 args_kwargs: ArgsKwargs::from_wire_values(
1583 vec![device.into()],
1584 HashMap::new(),
1585 )
1586 .unwrap(),
1587 stream: 1.into(),
1588 remote_process_groups: vec![],
1589 }),
1590 WorkerMessage::CallFunction(CallFunctionParams {
1591 seq: 9.into(),
1592 results: vec![Some(11.into())],
1593 mutates: vec![],
1594 function: "monarch.monarch_tensor_worker.test_utils.test_memory_format"
1595 .into(),
1596 args_kwargs: ArgsKwargs::from_wire_values(
1597 vec![memory_format.into()],
1598 HashMap::new(),
1599 )
1600 .unwrap(),
1601 stream: 1.into(),
1602 remote_process_groups: vec![],
1603 }),
1604 WorkerMessage::CallFunction(CallFunctionParams {
1606 seq: 10.into(),
1607 results: vec![Some(12.into())],
1608 mutates: vec![],
1609 function: "torch.ops.aten.ones.default".into(),
1610 args_kwargs: ArgsKwargs::from_wire_values(
1611 vec![WireValue::IntList(vec![2, 3])],
1612 HashMap::new(),
1613 )
1614 .unwrap(),
1615 stream: 1.into(),
1616 remote_process_groups: vec![],
1617 }),
1618 WorkerMessage::CallFunction(CallFunctionParams {
1619 seq: 11.into(),
1620 results: vec![Some(13.into())],
1621 mutates: vec![],
1622 function: "torch.ops.aten.stack.default".into(),
1623 args_kwargs: ArgsKwargs::from_wire_values(
1624 vec![WireValue::RefList(vec![12.into(), 12.into()])],
1625 HashMap::new(),
1626 )
1627 .unwrap(),
1628 stream: 1.into(),
1629 remote_process_groups: vec![],
1630 }),
1631 ],
1632 )
1633 .await
1634 .unwrap();
1635
1636 let result1: String = worker_handle
1637 .get_ref_unit_tests_only(&client, 0.into(), 1.into())
1638 .await
1639 .unwrap()
1640 .unwrap()
1641 .unwrap()
1642 .try_into()
1643 .unwrap();
1644 let result2: String = worker_handle
1645 .get_ref_unit_tests_only(&client, 2.into(), 1.into())
1646 .await
1647 .unwrap()
1648 .unwrap()
1649 .unwrap()
1650 .try_into()
1651 .unwrap();
1652 let result3: i64 = worker_handle
1653 .get_ref_unit_tests_only(&client, 4.into(), 1.into())
1654 .await
1655 .unwrap()
1656 .unwrap()
1657 .unwrap()
1658 .try_into()
1659 .unwrap();
1660 let result4: i64 = worker_handle
1661 .get_ref_unit_tests_only(&client, 6.into(), 1.into())
1662 .await
1663 .unwrap()
1664 .unwrap()
1665 .unwrap()
1666 .try_into()
1667 .unwrap();
1668 worker_handle
1669 .get_ref_unit_tests_only(&client, 7.into(), 1.into())
1670 .await
1671 .unwrap()
1672 .unwrap()
1673 .unwrap();
1674
1675 worker_handle
1676 .get_ref_unit_tests_only(&client, 8.into(), 1.into())
1677 .await
1678 .unwrap()
1679 .unwrap()
1680 .unwrap();
1681
1682 assert_matches!(
1683 worker_handle
1684 .get_ref_unit_tests_only(&client, 9.into(), 1.into())
1685 .await
1686 .unwrap()
1687 .unwrap()
1688 .unwrap(),
1689 WireValue::None(()),
1690 );
1691 worker_handle
1692 .get_ref_unit_tests_only(&client, 10.into(), 1.into())
1693 .await
1694 .unwrap()
1695 .unwrap()
1696 .unwrap();
1697 worker_handle
1698 .get_ref_unit_tests_only(&client, 11.into(), 1.into())
1699 .await
1700 .unwrap()
1701 .unwrap()
1702 .unwrap();
1703
1704 worker_handle.drain_and_stop().unwrap();
1705 worker_handle.await;
1706 let error_responses = controller_rx.drain();
1707 assert!(
1708 error_responses.is_empty(),
1709 "Expected no error responses, got: {:#?}",
1710 error_responses
1711 );
1712
1713 assert_eq!(result1, "/fbs/fbc/foo");
1714 assert_eq!(result2, "bar");
1715 assert_eq!(result3, 1);
1716 assert_eq!(result4, 0);
1717
1718 Ok(())
1719 }
1720
1721 #[async_timed_test(timeout_secs = 60)]
1722 async fn delete_refs() -> Result<()> {
1723 test_setup()?;
1724
1725 let proc = Proc::local();
1726 let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();
1727
1728 let worker_handle = proc
1729 .spawn(
1730 "worker",
1731 WorkerActor::new(WorkerParams {
1732 world_size: 1,
1733 rank: 0,
1734 device_index: None,
1735 controller_actor: controller_ref,
1736 })
1737 .await
1738 .unwrap(),
1739 )
1740 .unwrap();
1741 worker_handle
1742 .command_group(
1743 &client,
1744 vec![
1745 WorkerMessage::CreateStream {
1746 id: 0.into(),
1747 stream_creation: StreamCreationMode::CreateNewStream,
1748 },
1749 WorkerMessage::CreateStream {
1750 id: 1.into(),
1751 stream_creation: StreamCreationMode::CreateNewStream,
1752 },
1753 WorkerMessage::SetRefUnitTestsOnly {
1754 reference: Ref { id: 2 },
1755 value: WireValue::Bool(false),
1756 stream: 0.into(),
1757 },
1758 WorkerMessage::SetRefUnitTestsOnly {
1759 reference: Ref { id: 3 },
1760 value: WireValue::Bool(true),
1761 stream: 0.into(),
1762 },
1763 WorkerMessage::SetRefUnitTestsOnly {
1764 reference: Ref { id: 4 },
1765 value: WireValue::Int(0),
1766 stream: 1.into(),
1767 },
1768 WorkerMessage::DeleteRefs(vec![Ref { id: 2 }, Ref { id: 4 }]),
1769 ],
1770 )
1771 .await
1772 .unwrap();
1773
1774 let result: bool = worker_handle
1775 .get_ref_unit_tests_only(&client, Ref { id: 3 }, 0.into())
1776 .await
1777 .unwrap()
1778 .unwrap()
1779 .unwrap()
1780 .try_into()
1781 .unwrap();
1782 let fail_result = worker_handle
1783 .get_ref_unit_tests_only(&client, Ref { id: 4 }, 1.into())
1784 .await
1785 .unwrap();
1786
1787 worker_handle.drain_and_stop().unwrap();
1788 worker_handle.await;
1789
1790 assert!(result, "should be able to get a non-deleted ref");
1791 assert!(fail_result.is_none(), "should fail to get a deleted ref");
1792
1793 Ok(())
1794 }
1795
1796 #[async_timed_test(timeout_secs = 60)]
1797 async fn request_status() -> Result<()> {
1798 test_setup()?;
1799
1800 let proc = Proc::local();
1801 let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller").unwrap();
1802
1803 let worker_handle = proc
1804 .spawn(
1805 "worker",
1806 WorkerActor::new(WorkerParams {
1807 world_size: 1,
1808 rank: 0,
1809 device_index: None,
1810 controller_actor: controller_ref,
1811 })
1812 .await
1813 .unwrap(),
1814 )
1815 .unwrap();
1816 worker_handle
1817 .command_group(
1818 &client,
1819 vec![
1820 WorkerMessage::CreateStream {
1821 id: 0.into(),
1822 stream_creation: StreamCreationMode::CreateNewStream,
1823 },
1824 WorkerMessage::CreateStream {
1825 id: 1.into(),
1826 stream_creation: StreamCreationMode::CreateNewStream,
1827 },
1828 ],
1829 )
1830 .await
1831 .unwrap();
1832
1833 for i in 0..100 {
1834 worker_handle
1836 .call_function(
1837 &client,
1838 CallFunctionParams {
1839 seq: i.into(),
1840 results: vec![Some(Ref { id: i + 2 })],
1841 mutates: vec![],
1842 function: "torch.ops.aten.ones.default".into(),
1843 args_kwargs: ArgsKwargs::from_wire_values(
1844 vec![WireValue::IntList(vec![2, 3])],
1845 HashMap::new(),
1846 )
1847 .unwrap(),
1848 stream: (i % 2).into(),
1849 remote_process_groups: vec![],
1850 },
1851 )
1852 .await
1853 .unwrap();
1854 }
1855
1856 worker_handle
1857 .request_status(&client, 100.into(), false)
1858 .await
1859 .unwrap();
1860
1861 worker_handle.drain_and_stop().unwrap();
1862 worker_handle.await;
1863
1864 let mut responses = controller_rx.drain();
1865 assert_eq!(
1866 responses.len(),
1867 1,
1868 "Expected one response, got: {:#?}",
1869 responses
1870 );
1871
1872 let response = responses.pop().unwrap();
1873 match response {
1874 ControllerMessage::Status { seq, .. } => {
1875 assert_eq!(seq, 101.into())
1876 }
1877 _ => panic!("unexpected response {:#?}", response),
1878 };
1879
1880 Ok(())
1881 }
1882
1883 #[async_timed_test(timeout_secs = 60)]
1884 async fn backend_network_init() {
1885 test_setup().unwrap();
1886 let proc = Proc::local();
1887 let (client, controller_ref, _) = proc.attach_actor("controller").unwrap();
1888
1889 let worker_handle1 = proc
1890 .spawn(
1891 "worker0",
1892 WorkerActor::new(WorkerParams {
1893 world_size: 2,
1894 rank: 0,
1895 device_index: Some(0),
1896 controller_actor: controller_ref.clone(),
1897 })
1898 .await
1899 .unwrap(),
1900 )
1901 .unwrap();
1902 let worker_handle2 = proc
1903 .spawn(
1904 "worker1",
1905 WorkerActor::new(WorkerParams {
1906 world_size: 2,
1907 rank: 1,
1908 device_index: Some(1),
1909 controller_actor: controller_ref,
1910 })
1911 .await
1912 .unwrap(),
1913 )
1914 .unwrap();
1915
1916 let unique_id = UniqueId::new().unwrap();
1917 worker_handle1
1918 .backend_network_init(&client, unique_id.clone())
1919 .await
1920 .unwrap();
1921 worker_handle2
1922 .backend_network_init(&client, unique_id)
1923 .await
1924 .unwrap();
1925
1926 worker_handle1.drain_and_stop().unwrap();
1927 worker_handle1.await;
1928 worker_handle2.drain_and_stop().unwrap();
1929 worker_handle2.await;
1930 }
1931
1932 #[allow(dead_code)]
1933 fn get_random_channel_addr() -> ChannelAddr {
1934 let random_string = rand::thread_rng()
1935 .sample_iter(&Alphanumeric)
1936 .take(24)
1937 .map(char::from)
1938 .collect::<String>();
1939 format!("unix!@{random_string}").parse().unwrap()
1940 }
1941}