hyperactor_mesh/
shared_cell.rs1use std::fmt::Debug;
10use std::ops::Deref;
11use std::sync::Arc;
12use std::sync::atomic::AtomicUsize;
13use std::sync::atomic::Ordering;
14
15use async_trait::async_trait;
16use dashmap::DashMap;
17use futures::future::join_all;
18use futures::future::try_join_all;
19use preempt_rwlock::OwnedPreemptibleRwLockReadGuard;
20use preempt_rwlock::PreemptibleRwLock;
21use tokio::sync::TryLockError;
22
23#[derive(thiserror::Error, Debug)]
24pub struct EmptyCellError {}
25
26impl From<TryLockError> for EmptyCellError {
27 fn from(_err: TryLockError) -> Self {
28 Self {}
29 }
30}
31
32impl std::fmt::Display for EmptyCellError {
33 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
34 write!(f, "already taken")
35 }
36}
37
38#[derive(thiserror::Error, Debug)]
39pub enum TryTakeError {
40 #[error("already taken")]
41 Empty,
42 #[error("cannot lock: {0}")]
43 TryLockError(#[from] TryLockError),
44}
45
46struct PoolRef {
47 map: Arc<DashMap<usize, Arc<dyn SharedCellDiscard + Send + Sync>>>,
48 key: usize,
49}
50
51impl Debug for PoolRef {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("PoolRef").field("key", &self.key).finish()
54 }
55}
56
57#[derive(Debug)]
58struct Inner<T> {
59 value: Option<T>,
60 pool: Option<PoolRef>,
61}
62
63impl<T> Drop for Inner<T> {
64 fn drop(&mut self) {
65 if let Some(pool) = &self.pool {
66 pool.map.remove(&pool.key);
67 }
68 }
69}
70
71#[derive(Debug)]
75pub struct SharedCell<T> {
76 inner: Arc<PreemptibleRwLock<Inner<T>>>,
77}
78
79impl<T> Clone for SharedCell<T> {
80 fn clone(&self) -> Self {
81 Self {
82 inner: self.inner.clone(),
83 }
84 }
85}
86
87impl<T> From<T> for SharedCell<T> {
88 fn from(value: T) -> Self {
89 Self {
90 inner: Arc::new(PreemptibleRwLock::new(Inner {
91 value: Some(value),
92 pool: None,
93 })),
94 }
95 }
96}
97
98impl<T> SharedCell<T> {
99 fn with_pool(value: T, pool: PoolRef) -> Self {
100 Self {
101 inner: Arc::new(PreemptibleRwLock::new(Inner {
102 value: Some(value),
103 pool: Some(pool),
104 })),
105 }
106 }
107}
108
109pub struct SharedCellRef<T, U = T> {
110 guard: OwnedPreemptibleRwLockReadGuard<Inner<T>, U>,
111}
112
113impl<T> SharedCellRef<T> {
114 fn from(guard: OwnedPreemptibleRwLockReadGuard<Inner<T>>) -> Result<Self, EmptyCellError> {
115 if guard.value.is_none() {
116 return Err(EmptyCellError {});
117 }
118 Ok(Self {
119 guard: OwnedPreemptibleRwLockReadGuard::map(guard, |guard| {
120 guard.value.as_ref().unwrap()
121 }),
122 })
123 }
124
125 pub fn map<F, U>(self, f: F) -> SharedCellRef<T, U>
126 where
127 F: FnOnce(&T) -> &U,
128 {
129 SharedCellRef {
130 guard: OwnedPreemptibleRwLockReadGuard::map(self.guard, f),
131 }
132 }
133
134 pub async fn preempted(&self) {
135 self.guard.preempted().await
136 }
137}
138
139impl<T, U: Debug> Debug for SharedCellRef<T, U> {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 Debug::fmt(&**self, f)
142 }
143}
144
145impl<T, U> Deref for SharedCellRef<T, U> {
146 type Target = U;
147
148 fn deref(&self) -> &Self::Target {
149 &self.guard
150 }
151}
152
153impl<T> SharedCell<T> {
154 pub fn borrow(&self) -> Result<SharedCellRef<T>, EmptyCellError> {
157 SharedCellRef::from(self.inner.clone().try_read_owned()?)
158 }
159
160 pub async fn with_mut<F, R>(&self, f: F) -> Result<R, EmptyCellError>
162 where
163 F: FnOnce(&mut T) -> R,
164 {
165 let mut inner = self.inner.write(true).await;
166 let value = inner.value.as_mut().ok_or(EmptyCellError {})?;
167 Ok(f(value))
168 }
169
170 pub async fn take(&self) -> Result<T, EmptyCellError> {
172 let mut inner = self.inner.write(true).await;
173 inner.value.take().ok_or(EmptyCellError {})
174 }
175
176 pub fn blocking_take(&self) -> Result<T, EmptyCellError> {
177 let mut inner = self.inner.blocking_write(true);
178 inner.value.take().ok_or(EmptyCellError {})
179 }
180
181 pub fn try_take(&self) -> Result<T, TryTakeError> {
182 let mut inner = self.inner.try_write(true)?;
183 inner.value.take().ok_or(TryTakeError::Empty)
184 }
185}
186
187pub struct SharedCellPool {
189 map: Arc<DashMap<usize, Arc<dyn SharedCellDiscard + Send + Sync>>>,
190 token: AtomicUsize,
191}
192
193impl SharedCellPool {
194 pub fn new() -> Self {
195 Self {
196 map: Arc::new(DashMap::new()),
197 token: AtomicUsize::new(0),
198 }
199 }
200
201 pub fn insert<T>(&self, value: T) -> SharedCell<T>
202 where
203 T: Send + Sync + 'static,
204 {
205 let map = self.map.clone();
206 let key = self.token.fetch_add(1, Ordering::Relaxed);
207 let pool = PoolRef { map, key };
208 let value: SharedCell<_> = SharedCell::with_pool(value, pool);
209 self.map.entry(key).insert(Arc::new(value.clone()));
210 value
211 }
212
213 pub async fn discard_all(self) -> Result<(), EmptyCellError> {
215 try_join_all(
216 self.map
217 .iter()
218 .map(|r| async move { r.value().discard().await }),
219 )
220 .await?;
221 Ok(())
222 }
223
224 pub async fn discard_or_error_all(self) -> Vec<Result<(), EmptyCellError>> {
226 join_all(
227 self.map
228 .iter()
229 .map(|r| async move { r.value().discard().await }),
230 )
231 .await
232 }
233}
234
235#[async_trait]
237pub trait SharedCellDiscard {
238 async fn discard(&self) -> Result<(), EmptyCellError>;
239 fn blocking_discard(&self) -> Result<(), EmptyCellError>;
240 fn try_discard(&self) -> Result<(), TryTakeError>;
241}
242
243#[async_trait]
244impl<T: Send + Sync> SharedCellDiscard for SharedCell<T> {
245 fn try_discard(&self) -> Result<(), TryTakeError> {
246 self.try_take()?;
247 Ok(())
248 }
249
250 async fn discard(&self) -> Result<(), EmptyCellError> {
251 self.take().await?;
252 Ok(())
253 }
254
255 fn blocking_discard(&self) -> Result<(), EmptyCellError> {
256 self.blocking_take()?;
257 Ok(())
258 }
259}
260
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;

    /// Once the value has been taken, borrowing the cell must fail.
    #[tokio::test]
    async fn borrow_after_take() -> Result<()> {
        let cell = SharedCell::from(0);
        let _ = cell.take().await;
        let borrowed = cell.borrow();
        assert!(borrowed.is_err());
        Ok(())
    }

    /// A non-blocking take must fail while a borrow is outstanding and
    /// succeed once that borrow has been dropped.
    #[tokio::test]
    async fn take_after_borrow() -> Result<()> {
        let cell = SharedCell::from(0);
        let outstanding = cell.borrow()?;
        assert!(cell.try_take().is_err());
        drop(outstanding);
        cell.try_take()?;
        Ok(())
    }
}
284}