hyperactor_mesh/
shared_cell.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::fmt::Debug;
10use std::ops::Deref;
11use std::sync::Arc;
12use std::sync::atomic::AtomicUsize;
13use std::sync::atomic::Ordering;
14
15use async_trait::async_trait;
16use dashmap::DashMap;
17use futures::future::join_all;
18use futures::future::try_join_all;
19use preempt_rwlock::OwnedPreemptibleRwLockReadGuard;
20use preempt_rwlock::PreemptibleRwLock;
21use tokio::sync::TryLockError;
22
23#[derive(thiserror::Error, Debug)]
24pub struct EmptyCellError {}
25
26impl From<TryLockError> for EmptyCellError {
27    fn from(_err: TryLockError) -> Self {
28        Self {}
29    }
30}
31
32impl std::fmt::Display for EmptyCellError {
33    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
34        write!(f, "already taken")
35    }
36}
37
38#[derive(thiserror::Error, Debug)]
39pub enum TryTakeError {
40    #[error("already taken")]
41    Empty,
42    #[error("cannot lock: {0}")]
43    TryLockError(#[from] TryLockError),
44}
45
46struct PoolRef {
47    map: Arc<DashMap<usize, Arc<dyn SharedCellDiscard + Send + Sync>>>,
48    key: usize,
49}
50
51impl Debug for PoolRef {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("PoolRef").field("key", &self.key).finish()
54    }
55}
56
57#[derive(Debug)]
58struct Inner<T> {
59    value: Option<T>,
60    pool: Option<PoolRef>,
61}
62
63impl<T> Drop for Inner<T> {
64    fn drop(&mut self) {
65        if let Some(pool) = &self.pool {
66            pool.map.remove(&pool.key);
67        }
68    }
69}
70
71/// A wrapper class that facilitates sharing an item across different users, supporting:
72/// - Ability grab a reference-counted reference to the item
73/// - Ability to consume the item, leaving the cell in an unusable state
74#[derive(Debug)]
75pub struct SharedCell<T> {
76    inner: Arc<PreemptibleRwLock<Inner<T>>>,
77}
78
79impl<T> Clone for SharedCell<T> {
80    fn clone(&self) -> Self {
81        Self {
82            inner: self.inner.clone(),
83        }
84    }
85}
86
87impl<T> From<T> for SharedCell<T> {
88    fn from(value: T) -> Self {
89        Self {
90            inner: Arc::new(PreemptibleRwLock::new(Inner {
91                value: Some(value),
92                pool: None,
93            })),
94        }
95    }
96}
97
98impl<T> SharedCell<T> {
99    fn with_pool(value: T, pool: PoolRef) -> Self {
100        Self {
101            inner: Arc::new(PreemptibleRwLock::new(Inner {
102                value: Some(value),
103                pool: Some(pool),
104            })),
105        }
106    }
107}
108
109pub struct SharedCellRef<T, U = T> {
110    guard: OwnedPreemptibleRwLockReadGuard<Inner<T>, U>,
111}
112
113impl<T> SharedCellRef<T> {
114    fn from(guard: OwnedPreemptibleRwLockReadGuard<Inner<T>>) -> Result<Self, EmptyCellError> {
115        if guard.value.is_none() {
116            return Err(EmptyCellError {});
117        }
118        Ok(Self {
119            guard: OwnedPreemptibleRwLockReadGuard::map(guard, |guard| {
120                guard.value.as_ref().unwrap()
121            }),
122        })
123    }
124
125    pub fn map<F, U>(self, f: F) -> SharedCellRef<T, U>
126    where
127        F: FnOnce(&T) -> &U,
128    {
129        SharedCellRef {
130            guard: OwnedPreemptibleRwLockReadGuard::map(self.guard, f),
131        }
132    }
133
134    pub async fn preempted(&self) {
135        self.guard.preempted().await
136    }
137}
138
139impl<T, U: Debug> Debug for SharedCellRef<T, U> {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        Debug::fmt(&**self, f)
142    }
143}
144
145impl<T, U> Deref for SharedCellRef<T, U> {
146    type Target = U;
147
148    fn deref(&self) -> &Self::Target {
149        &self.guard
150    }
151}
152
153impl<T> SharedCell<T> {
154    /// Borrow the cell, returning a reference to the item. If the cell is empty, returns an error.
155    /// While references are held, the cell cannot be taken below.
156    pub fn borrow(&self) -> Result<SharedCellRef<T>, EmptyCellError> {
157        SharedCellRef::from(self.inner.clone().try_read_owned()?)
158    }
159
160    /// Execute given closure with write access to the underlying data. If the cell is empty, returns an error.
161    pub async fn with_mut<F, R>(&self, f: F) -> Result<R, EmptyCellError>
162    where
163        F: FnOnce(&mut T) -> R,
164    {
165        let mut inner = self.inner.write(true).await;
166        let value = inner.value.as_mut().ok_or(EmptyCellError {})?;
167        Ok(f(value))
168    }
169
170    /// Take the item out of the cell, leaving it in an unusable state.
171    pub async fn take(&self) -> Result<T, EmptyCellError> {
172        let mut inner = self.inner.write(true).await;
173        inner.value.take().ok_or(EmptyCellError {})
174    }
175
176    pub fn blocking_take(&self) -> Result<T, EmptyCellError> {
177        let mut inner = self.inner.blocking_write(true);
178        inner.value.take().ok_or(EmptyCellError {})
179    }
180
181    pub fn try_take(&self) -> Result<T, TryTakeError> {
182        let mut inner = self.inner.try_write(true)?;
183        inner.value.take().ok_or(TryTakeError::Empty)
184    }
185}
186
187/// A pool of `SharedCell`s which can be used to mass `take()` and discard them all at once.
188pub struct SharedCellPool {
189    map: Arc<DashMap<usize, Arc<dyn SharedCellDiscard + Send + Sync>>>,
190    token: AtomicUsize,
191}
192
193impl SharedCellPool {
194    pub fn new() -> Self {
195        Self {
196            map: Arc::new(DashMap::new()),
197            token: AtomicUsize::new(0),
198        }
199    }
200
201    pub fn insert<T>(&self, value: T) -> SharedCell<T>
202    where
203        T: Send + Sync + 'static,
204    {
205        let map = self.map.clone();
206        let key = self.token.fetch_add(1, Ordering::Relaxed);
207        let pool = PoolRef { map, key };
208        let value: SharedCell<_> = SharedCell::with_pool(value, pool);
209        self.map.entry(key).insert(Arc::new(value.clone()));
210        value
211    }
212
213    /// Run `take` on all cells in the pool and immediately drop them.
214    pub async fn discard_all(self) -> Result<(), EmptyCellError> {
215        try_join_all(
216            self.map
217                .iter()
218                .map(|r| async move { r.value().discard().await }),
219        )
220        .await?;
221        Ok(())
222    }
223
224    /// Run `take` on all cells in the pool and immediately drop them or produce an error if the cell has already been taken
225    pub async fn discard_or_error_all(self) -> Vec<Result<(), EmptyCellError>> {
226        join_all(
227            self.map
228                .iter()
229                .map(|r| async move { r.value().discard().await }),
230        )
231        .await
232    }
233}
234
235/// Trait to facilitate storing `SharedCell`s of different types in a single pool.
236#[async_trait]
237pub trait SharedCellDiscard {
238    async fn discard(&self) -> Result<(), EmptyCellError>;
239    fn blocking_discard(&self) -> Result<(), EmptyCellError>;
240    fn try_discard(&self) -> Result<(), TryTakeError>;
241}
242
243#[async_trait]
244impl<T: Send + Sync> SharedCellDiscard for SharedCell<T> {
245    fn try_discard(&self) -> Result<(), TryTakeError> {
246        self.try_take()?;
247        Ok(())
248    }
249
250    async fn discard(&self) -> Result<(), EmptyCellError> {
251        self.take().await?;
252        Ok(())
253    }
254
255    fn blocking_discard(&self) -> Result<(), EmptyCellError> {
256        self.blocking_take()?;
257        Ok(())
258    }
259}
260
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;

    /// Once the value has been taken, borrowing must fail.
    #[tokio::test]
    async fn borrow_after_take() -> Result<()> {
        let cell: SharedCell<i32> = 0.into();
        let _ = cell.take().await;
        let borrowed = cell.borrow();
        assert!(borrowed.is_err());
        Ok(())
    }

    /// A live borrow blocks `try_take`; releasing it makes the take succeed.
    #[tokio::test]
    async fn take_after_borrow() -> Result<()> {
        let cell: SharedCell<i32> = 0.into();
        let held = cell.borrow()?;
        let attempt = cell.try_take();
        assert!(attempt.is_err());
        drop(held);
        cell.try_take()?;
        Ok(())
    }
}