hyperactor_mesh/
shared_cell.rs1use std::fmt::Debug;
10use std::ops::Deref;
11use std::sync::Arc;
12use std::sync::atomic::AtomicUsize;
13use std::sync::atomic::Ordering;
14
15use async_trait::async_trait;
16use dashmap::DashMap;
17use futures::future::join_all;
18use futures::future::try_join_all;
19use preempt_rwlock::OwnedPreemptibleRwLockReadGuard;
20use preempt_rwlock::PreemptibleRwLock;
21use tokio::sync::TryLockError;
22
23#[derive(thiserror::Error, Debug)]
24pub struct EmptyCellError {}
25
26impl From<TryLockError> for EmptyCellError {
27 fn from(_err: TryLockError) -> Self {
28 Self {}
29 }
30}
31
32impl std::fmt::Display for EmptyCellError {
33 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
34 write!(f, "already taken")
35 }
36}
37
38#[derive(thiserror::Error, Debug)]
39pub enum TryTakeError {
40 #[error("already taken")]
41 Empty,
42 #[error("cannot lock: {0}")]
43 TryLockError(#[from] TryLockError),
44}
45
46struct PoolRef {
47 map: Arc<DashMap<usize, Arc<dyn SharedCellDiscard + Send + Sync>>>,
48 key: usize,
49}
50
51impl Debug for PoolRef {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("PoolRef").field("key", &self.key).finish()
54 }
55}
56
57#[derive(Debug)]
58struct Inner<T> {
59 value: Option<T>,
60 pool: Option<PoolRef>,
61}
62
63impl<T> Drop for Inner<T> {
64 fn drop(&mut self) {
65 if let Some(pool) = &self.pool {
66 pool.map.remove(&pool.key);
67 }
68 }
69}
70
71#[derive(Debug)]
75pub struct SharedCell<T> {
76 inner: Arc<PreemptibleRwLock<Inner<T>>>,
77}
78
79impl<T> Clone for SharedCell<T> {
80 fn clone(&self) -> Self {
81 Self {
82 inner: self.inner.clone(),
83 }
84 }
85}
86
87impl<T> From<T> for SharedCell<T> {
88 fn from(value: T) -> Self {
89 Self {
90 inner: Arc::new(PreemptibleRwLock::new(Inner {
91 value: Some(value),
92 pool: None,
93 })),
94 }
95 }
96}
97
98impl<T> SharedCell<T> {
99 fn with_pool(value: T, pool: PoolRef) -> Self {
100 Self {
101 inner: Arc::new(PreemptibleRwLock::new(Inner {
102 value: Some(value),
103 pool: Some(pool),
104 })),
105 }
106 }
107}
108
109pub struct SharedCellRef<T, U = T> {
110 guard: OwnedPreemptibleRwLockReadGuard<Inner<T>, U>,
111}
112
113impl<T> SharedCellRef<T> {
114 fn from(guard: OwnedPreemptibleRwLockReadGuard<Inner<T>>) -> Result<Self, EmptyCellError> {
115 if guard.value.is_none() {
116 return Err(EmptyCellError {});
117 }
118 Ok(Self {
119 guard: OwnedPreemptibleRwLockReadGuard::map(guard, |guard| {
120 guard.value.as_ref().unwrap()
121 }),
122 })
123 }
124
125 pub fn map<F, U>(self, f: F) -> SharedCellRef<T, U>
126 where
127 F: FnOnce(&T) -> &U,
128 {
129 SharedCellRef {
130 guard: OwnedPreemptibleRwLockReadGuard::map(self.guard, f),
131 }
132 }
133
134 pub async fn preempted(&self) {
135 self.guard.preempted().await
136 }
137}
138
139impl<T, U: Debug> Debug for SharedCellRef<T, U> {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 Debug::fmt(&**self, f)
142 }
143}
144
145impl<T, U> Deref for SharedCellRef<T, U> {
146 type Target = U;
147
148 fn deref(&self) -> &Self::Target {
149 &self.guard
150 }
151}
152
153impl<T> SharedCell<T> {
154 pub fn borrow(&self) -> Result<SharedCellRef<T>, EmptyCellError> {
157 SharedCellRef::from(self.inner.clone().try_read_owned()?)
158 }
159
160 pub async fn with_mut<F, R>(&self, f: F) -> Result<R, EmptyCellError>
162 where
163 F: FnOnce(&mut T) -> R,
164 {
165 let mut inner = self.inner.write(true).await;
166 let value = inner.value.as_mut().ok_or(EmptyCellError {})?;
167 Ok(f(value))
168 }
169
170 pub fn try_with_mut<F, R>(&self, f: F) -> Result<R, EmptyCellError>
174 where
175 F: FnOnce(&mut T) -> R,
176 {
177 let mut inner = self.inner.try_write(false).map_err(|_| EmptyCellError {})?;
178 let value = inner.value.as_mut().ok_or(EmptyCellError {})?;
179 Ok(f(value))
180 }
181
182 pub async fn take(&self) -> Result<T, EmptyCellError> {
184 let mut inner = self.inner.write(true).await;
185 inner.value.take().ok_or(EmptyCellError {})
186 }
187
188 pub fn blocking_take(&self) -> Result<T, EmptyCellError> {
189 let mut inner = self.inner.blocking_write(true);
190 inner.value.take().ok_or(EmptyCellError {})
191 }
192
193 pub fn try_take(&self) -> Result<T, TryTakeError> {
194 let mut inner = self.inner.try_write(true)?;
195 inner.value.take().ok_or(TryTakeError::Empty)
196 }
197}
198
199pub struct SharedCellPool {
201 map: Arc<DashMap<usize, Arc<dyn SharedCellDiscard + Send + Sync>>>,
202 token: AtomicUsize,
203}
204
205impl Default for SharedCellPool {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
211impl SharedCellPool {
212 pub fn new() -> Self {
213 Self {
214 map: Arc::new(DashMap::new()),
215 token: AtomicUsize::new(0),
216 }
217 }
218
219 pub fn insert<T>(&self, value: T) -> SharedCell<T>
220 where
221 T: Send + Sync + 'static,
222 {
223 let map = self.map.clone();
224 let key = self.token.fetch_add(1, Ordering::Relaxed);
225 let pool = PoolRef { map, key };
226 let value: SharedCell<_> = SharedCell::with_pool(value, pool);
227 self.map.entry(key).insert(Arc::new(value.clone()));
228 value
229 }
230
231 pub async fn discard_all(self) -> Result<(), EmptyCellError> {
233 try_join_all(
234 self.map
235 .iter()
236 .map(|r| async move { r.value().discard().await }),
237 )
238 .await?;
239 Ok(())
240 }
241
242 pub async fn discard_or_error_all(self) -> Vec<Result<(), EmptyCellError>> {
244 join_all(
245 self.map
246 .iter()
247 .map(|r| async move { r.value().discard().await }),
248 )
249 .await
250 }
251}
252
253#[async_trait]
255pub trait SharedCellDiscard {
256 async fn discard(&self) -> Result<(), EmptyCellError>;
257 fn blocking_discard(&self) -> Result<(), EmptyCellError>;
258 fn try_discard(&self) -> Result<(), TryTakeError>;
259}
260
261#[async_trait]
262impl<T: Send + Sync> SharedCellDiscard for SharedCell<T> {
263 fn try_discard(&self) -> Result<(), TryTakeError> {
264 self.try_take()?;
265 Ok(())
266 }
267
268 async fn discard(&self) -> Result<(), EmptyCellError> {
269 self.take().await?;
270 Ok(())
271 }
272
273 fn blocking_discard(&self) -> Result<(), EmptyCellError> {
274 self.blocking_take()?;
275 Ok(())
276 }
277}
278
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;

    /// Once the value has been taken, the cell can no longer be borrowed.
    #[tokio::test]
    async fn borrow_after_take() -> Result<()> {
        let shared = SharedCell::from(0);
        // The take's own result is irrelevant; only the emptiness afterwards
        // is under test.
        drop(shared.take().await);
        assert!(shared.borrow().is_err());
        Ok(())
    }

    /// An outstanding borrow makes `try_take` fail; releasing it lets the take
    /// succeed.
    #[tokio::test]
    async fn take_after_borrow() -> Result<()> {
        let shared = SharedCell::from(0);
        let reader = shared.borrow()?;
        assert!(shared.try_take().is_err());
        drop(reader);
        shared.try_take()?;
        Ok(())
    }
}