hyperactor_mesh/
shared_cell.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use std::fmt::Debug;
10use std::ops::Deref;
11use std::sync::Arc;
12use std::sync::atomic::AtomicUsize;
13use std::sync::atomic::Ordering;
14
15use async_trait::async_trait;
16use dashmap::DashMap;
17use futures::future::join_all;
18use futures::future::try_join_all;
19use preempt_rwlock::OwnedPreemptibleRwLockReadGuard;
20use preempt_rwlock::PreemptibleRwLock;
21use tokio::sync::TryLockError;
22
23#[derive(thiserror::Error, Debug)]
24pub struct EmptyCellError {}
25
26impl From<TryLockError> for EmptyCellError {
27    fn from(_err: TryLockError) -> Self {
28        Self {}
29    }
30}
31
32impl std::fmt::Display for EmptyCellError {
33    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
34        write!(f, "already taken")
35    }
36}
37
38#[derive(thiserror::Error, Debug)]
39pub enum TryTakeError {
40    #[error("already taken")]
41    Empty,
42    #[error("cannot lock: {0}")]
43    TryLockError(#[from] TryLockError),
44}
45
46struct PoolRef {
47    map: Arc<DashMap<usize, Arc<dyn SharedCellDiscard + Send + Sync>>>,
48    key: usize,
49}
50
51impl Debug for PoolRef {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("PoolRef").field("key", &self.key).finish()
54    }
55}
56
57#[derive(Debug)]
58struct Inner<T> {
59    value: Option<T>,
60    pool: Option<PoolRef>,
61}
62
63impl<T> Drop for Inner<T> {
64    fn drop(&mut self) {
65        if let Some(pool) = &self.pool {
66            pool.map.remove(&pool.key);
67        }
68    }
69}
70
71/// A wrapper class that facilitates sharing an item across different users, supporting:
72/// - Ability grab a reference-counted reference to the item
73/// - Ability to consume the item, leaving the cell in an unusable state
74#[derive(Debug)]
75pub struct SharedCell<T> {
76    inner: Arc<PreemptibleRwLock<Inner<T>>>,
77}
78
79impl<T> Clone for SharedCell<T> {
80    fn clone(&self) -> Self {
81        Self {
82            inner: self.inner.clone(),
83        }
84    }
85}
86
87impl<T> From<T> for SharedCell<T> {
88    fn from(value: T) -> Self {
89        Self {
90            inner: Arc::new(PreemptibleRwLock::new(Inner {
91                value: Some(value),
92                pool: None,
93            })),
94        }
95    }
96}
97
98impl<T> SharedCell<T> {
99    fn with_pool(value: T, pool: PoolRef) -> Self {
100        Self {
101            inner: Arc::new(PreemptibleRwLock::new(Inner {
102                value: Some(value),
103                pool: Some(pool),
104            })),
105        }
106    }
107}
108
109pub struct SharedCellRef<T, U = T> {
110    guard: OwnedPreemptibleRwLockReadGuard<Inner<T>, U>,
111}
112
113impl<T> SharedCellRef<T> {
114    fn from(guard: OwnedPreemptibleRwLockReadGuard<Inner<T>>) -> Result<Self, EmptyCellError> {
115        if guard.value.is_none() {
116            return Err(EmptyCellError {});
117        }
118        Ok(Self {
119            guard: OwnedPreemptibleRwLockReadGuard::map(guard, |guard| {
120                guard.value.as_ref().unwrap()
121            }),
122        })
123    }
124
125    pub fn map<F, U>(self, f: F) -> SharedCellRef<T, U>
126    where
127        F: FnOnce(&T) -> &U,
128    {
129        SharedCellRef {
130            guard: OwnedPreemptibleRwLockReadGuard::map(self.guard, f),
131        }
132    }
133
134    pub async fn preempted(&self) {
135        self.guard.preempted().await
136    }
137}
138
139impl<T, U: Debug> Debug for SharedCellRef<T, U> {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        Debug::fmt(&**self, f)
142    }
143}
144
145impl<T, U> Deref for SharedCellRef<T, U> {
146    type Target = U;
147
148    fn deref(&self) -> &Self::Target {
149        &self.guard
150    }
151}
152
153impl<T> SharedCell<T> {
154    /// Borrow the cell, returning a reference to the item. If the cell is empty, returns an error.
155    /// While references are held, the cell cannot be taken below.
156    pub fn borrow(&self) -> Result<SharedCellRef<T>, EmptyCellError> {
157        SharedCellRef::from(self.inner.clone().try_read_owned()?)
158    }
159
160    /// Execute given closure with write access to the underlying data. If the cell is empty, returns an error.
161    pub async fn with_mut<F, R>(&self, f: F) -> Result<R, EmptyCellError>
162    where
163        F: FnOnce(&mut T) -> R,
164    {
165        let mut inner = self.inner.write(true).await;
166        let value = inner.value.as_mut().ok_or(EmptyCellError {})?;
167        Ok(f(value))
168    }
169
170    /// Non-async variant of [`with_mut`](Self::with_mut). Returns
171    /// `Err` if the write lock cannot be acquired immediately or if
172    /// the cell has already been taken.
173    pub fn try_with_mut<F, R>(&self, f: F) -> Result<R, EmptyCellError>
174    where
175        F: FnOnce(&mut T) -> R,
176    {
177        let mut inner = self.inner.try_write(false).map_err(|_| EmptyCellError {})?;
178        let value = inner.value.as_mut().ok_or(EmptyCellError {})?;
179        Ok(f(value))
180    }
181
182    /// Take the item out of the cell, leaving it in an unusable state.
183    pub async fn take(&self) -> Result<T, EmptyCellError> {
184        let mut inner = self.inner.write(true).await;
185        inner.value.take().ok_or(EmptyCellError {})
186    }
187
188    pub fn blocking_take(&self) -> Result<T, EmptyCellError> {
189        let mut inner = self.inner.blocking_write(true);
190        inner.value.take().ok_or(EmptyCellError {})
191    }
192
193    pub fn try_take(&self) -> Result<T, TryTakeError> {
194        let mut inner = self.inner.try_write(true)?;
195        inner.value.take().ok_or(TryTakeError::Empty)
196    }
197}
198
199/// A pool of `SharedCell`s which can be used to mass `take()` and discard them all at once.
200pub struct SharedCellPool {
201    map: Arc<DashMap<usize, Arc<dyn SharedCellDiscard + Send + Sync>>>,
202    token: AtomicUsize,
203}
204
205impl Default for SharedCellPool {
206    fn default() -> Self {
207        Self::new()
208    }
209}
210
211impl SharedCellPool {
212    pub fn new() -> Self {
213        Self {
214            map: Arc::new(DashMap::new()),
215            token: AtomicUsize::new(0),
216        }
217    }
218
219    pub fn insert<T>(&self, value: T) -> SharedCell<T>
220    where
221        T: Send + Sync + 'static,
222    {
223        let map = self.map.clone();
224        let key = self.token.fetch_add(1, Ordering::Relaxed);
225        let pool = PoolRef { map, key };
226        let value: SharedCell<_> = SharedCell::with_pool(value, pool);
227        self.map.entry(key).insert(Arc::new(value.clone()));
228        value
229    }
230
231    /// Run `take` on all cells in the pool and immediately drop them.
232    pub async fn discard_all(self) -> Result<(), EmptyCellError> {
233        try_join_all(
234            self.map
235                .iter()
236                .map(|r| async move { r.value().discard().await }),
237        )
238        .await?;
239        Ok(())
240    }
241
242    /// Run `take` on all cells in the pool and immediately drop them or produce an error if the cell has already been taken
243    pub async fn discard_or_error_all(self) -> Vec<Result<(), EmptyCellError>> {
244        join_all(
245            self.map
246                .iter()
247                .map(|r| async move { r.value().discard().await }),
248        )
249        .await
250    }
251}
252
253/// Trait to facilitate storing `SharedCell`s of different types in a single pool.
254#[async_trait]
255pub trait SharedCellDiscard {
256    async fn discard(&self) -> Result<(), EmptyCellError>;
257    fn blocking_discard(&self) -> Result<(), EmptyCellError>;
258    fn try_discard(&self) -> Result<(), TryTakeError>;
259}
260
261#[async_trait]
262impl<T: Send + Sync> SharedCellDiscard for SharedCell<T> {
263    fn try_discard(&self) -> Result<(), TryTakeError> {
264        self.try_take()?;
265        Ok(())
266    }
267
268    async fn discard(&self) -> Result<(), EmptyCellError> {
269        self.take().await?;
270        Ok(())
271    }
272
273    fn blocking_discard(&self) -> Result<(), EmptyCellError> {
274        self.blocking_take()?;
275        Ok(())
276    }
277}
278
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;

    /// Once the value has been taken, the cell reports itself as empty to
    /// borrowers and to further takers.
    #[tokio::test]
    async fn borrow_after_take() -> Result<()> {
        let cell = SharedCell::from(0);
        // Check the taken value instead of discarding the result.
        assert_eq!(cell.take().await?, 0);
        assert!(cell.borrow().is_err());
        assert!(matches!(cell.try_take(), Err(TryTakeError::Empty)));
        Ok(())
    }

    /// An outstanding borrow makes `try_take` fail on the lock (not on
    /// emptiness); releasing the borrow lets it succeed.
    #[tokio::test]
    async fn take_after_borrow() -> Result<()> {
        let cell = SharedCell::from(0);
        let b = cell.borrow()?;
        assert!(matches!(cell.try_take(), Err(TryTakeError::TryLockError(_))));
        drop(b);
        assert_eq!(cell.try_take()?, 0);
        Ok(())
    }
}