1use std::cell::UnsafeCell;
10use std::fmt::Debug;
11use std::ops::Deref;
12use std::ops::DerefMut;
13use std::ptr::NonNull;
14use std::sync::Arc;
15
16use atomic_refcell::AtomicRef;
17use atomic_refcell::AtomicRefCell;
18use atomic_refcell::AtomicRefMut;
19use atomic_refcell::BorrowError;
20use atomic_refcell::BorrowMutError;
21use serde::Deserialize;
22use serde::Deserializer;
23use serde::Serialize;
24use serde::Serializer;
25
26pub struct AliasTrackingRefCell<T: ?Sized> {
69 alias_tracker: Arc<AtomicRefCell<()>>,
70 value: UnsafeCell<T>,
71}
72
73unsafe impl<T: ?Sized + Send> Send for AliasTrackingRefCell<T> {}
75unsafe impl<T: ?Sized + Send + Sync> Sync for AliasTrackingRefCell<T> {}
77
78impl<T> AliasTrackingRefCell<T> {
79 #[inline]
82 pub fn new(value: T) -> Self {
83 Self {
84 value: UnsafeCell::new(value),
85 alias_tracker: Arc::new(AtomicRefCell::new(())),
86 }
87 }
88
89 #[inline]
91 pub fn new_with_alias(value: T, alias: &Self) -> Self {
92 Self {
93 value: UnsafeCell::new(value),
94 alias_tracker: alias.alias_tracker.clone(),
95 }
96 }
97}
98
99impl<T: ?Sized> AliasTrackingRefCell<T> {
100 pub fn borrow(&self) -> AliasTrackingRef<T> {
107 match self.try_borrow() {
108 Ok(borrow) => borrow,
109 Err(e) => panic!("borrow failed: {:?}", e),
110 }
111 }
112
113 pub fn try_borrow(&self) -> Result<AliasTrackingRef<T>, BorrowError> {
119 Ok(AliasTrackingRef {
120 borrow: self.alias_tracker.try_borrow()?,
121 value: unsafe { NonNull::new_unchecked(self.value.get()) },
124 })
125 }
126
127 pub fn borrow_mut(&self) -> AliasTrackingRefMut<T> {
135 match self.try_borrow_mut() {
136 Ok(borrow) => borrow,
137 Err(e) => panic!("borrow_mut failed: {:?}", e),
138 }
139 }
140
141 pub fn try_borrow_mut(&self) -> Result<AliasTrackingRefMut<T>, BorrowMutError> {
147 Ok(AliasTrackingRefMut {
148 borrow: self.alias_tracker.try_borrow_mut()?,
149 value: unsafe { NonNull::new_unchecked(self.value.get()) },
151 })
152 }
153
154 pub fn aliases(&self, other: &Self) -> bool {
156 Arc::ptr_eq(&self.alias_tracker, &other.alias_tracker)
157 }
158
159 pub(crate) fn alias_ptr(&self) -> *const AtomicRefCell<()> {
163 Arc::as_ptr(&self.alias_tracker)
164 }
165
166 pub unsafe fn get_unchecked(&self) -> &T {
170 unsafe { self.value.get().as_ref().unwrap() }
172 }
173}
174
175impl<T: ?Sized + PartialEq> PartialEq for AliasTrackingRefCell<T> {
176 #[inline]
177 fn eq(&self, other: &Self) -> bool {
178 *self.borrow() == *other.borrow()
179 }
180}
181
182impl<T: ?Sized + Eq> Eq for AliasTrackingRefCell<T> {}
183
184impl<T: ?Sized + Debug> Debug for AliasTrackingRefCell<T> {
185 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
186 match self.try_borrow() {
187 Ok(borrow) => f
188 .debug_struct("AliasTrackingRefCell")
189 .field("value", &borrow)
190 .finish(),
191 Err(_) => f
192 .debug_struct("AliasTrackingRefCell")
193 .field("value", &"<mutably borrowed elsewhere>")
194 .finish(),
195 }
196 }
197}
198
199impl<T: Serialize> Serialize for AliasTrackingRefCell<T> {
200 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
201 where
202 S: Serializer,
203 {
204 let borrow = self.try_borrow().map_err(serde::ser::Error::custom)?;
205 borrow.serialize(serializer)
206 }
207}
208
209impl<'de, T: Deserialize<'de>> Deserialize<'de> for AliasTrackingRefCell<T> {
210 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
211 where
212 D: Deserializer<'de>,
213 {
214 let value = T::deserialize(deserializer)?;
215 Ok(Self::new(value))
216 }
217}
218
219pub struct AliasTrackingRef<'a, T: ?Sized + 'a> {
220 value: NonNull<T>,
224 #[allow(dead_code)]
227 borrow: AtomicRef<'a, ()>,
228}
229
230unsafe impl<'a, T: ?Sized> Sync for AliasTrackingRef<'a, T> where for<'b> &'b T: Sync {}
232unsafe impl<'a, T: ?Sized> Send for AliasTrackingRef<'a, T> where for<'b> &'b T: Send {}
234
235impl<'a, T: ?Sized + Debug + 'a> Debug for AliasTrackingRef<'a, T> {
236 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
237 <T as Debug>::fmt(self, f)
238 }
239}
240
241impl<'a, T: ?Sized> Deref for AliasTrackingRef<'a, T> {
242 type Target = T;
243
244 #[inline]
245 fn deref(&self) -> &T {
246 unsafe { self.value.as_ref() }
248 }
249}
250
251pub struct AliasTrackingRefMut<'a, T: ?Sized + 'a> {
252 value: NonNull<T>,
256 #[allow(dead_code)]
259 borrow: AtomicRefMut<'a, ()>,
260}
261
262unsafe impl<'a, T: ?Sized> Sync for AliasTrackingRefMut<'a, T> where for<'b> &'b T: Sync {}
264unsafe impl<'a, T: ?Sized> Send for AliasTrackingRefMut<'a, T> where for<'b> &'b T: Send {}
266
267impl<'a, T: ?Sized + Debug + 'a> Debug for AliasTrackingRefMut<'a, T> {
268 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
269 <T as Debug>::fmt(self, f)
270 }
271}
272
273impl<'a, T: ?Sized> Deref for AliasTrackingRefMut<'a, T> {
274 type Target = T;
275
276 #[inline]
277 fn deref(&self) -> &T {
278 unsafe { self.value.as_ref() }
280 }
281}
282
283impl<'b, T: ?Sized> DerefMut for AliasTrackingRefMut<'b, T> {
284 #[inline]
285 fn deref_mut(&mut self) -> &mut T {
286 unsafe { self.value.as_mut() }
288 }
289}
290
/// Cloning whose safety the compiler cannot check.
///
/// `AliasTrackingRefCell<T>` implements `Clone` for every `T: CloneUnsafe` by
/// calling [`CloneUnsafe::clone_unsafe`] on the wrapped value without taking a
/// borrow, and gives the new cell the same alias tracker as the original.
pub trait CloneUnsafe {
    /// Produces a copy of `self`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is not spelled out in this file.
    /// Presumably the copy may alias the original's underlying data, so the
    /// caller must keep both under shared alias tracking — confirm with the
    /// implementors.
    unsafe fn clone_unsafe(&self) -> Self;
}
299
300impl<T: CloneUnsafe> Clone for AliasTrackingRefCell<T> {
301 fn clone(&self) -> Self {
302 Self {
303 alias_tracker: self.alias_tracker.clone(),
304 value: UnsafeCell::new(unsafe { self.value.get().as_ref().unwrap().clone_unsafe() }),
307 }
308 }
309}
310
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;
    use crate::Tensor;
    use crate::bridge::ffi::test_make_alias;
    use crate::bridge::ffi::test_make_tensor;

    /// A clone shares the original's alias tracker, so mutably borrowing the
    /// clone while the original is borrowed must panic.
    ///
    /// `expected` pins the panic to `borrow_mut`; a bare `#[should_panic]`
    /// would also pass if the fixture setup panicked.
    #[should_panic(expected = "borrow_mut failed")]
    #[test]
    fn clone_then_mut_borrow() {
        let t = test_make_tensor();
        let cell = AliasTrackingRefCell::new(t);
        let _borrow = cell.borrow();

        let clone = cell.clone();
        clone.borrow_mut();
    }

    /// A cell created with `new_with_alias` conflicts with live borrows of the
    /// cell it aliases.
    #[should_panic(expected = "borrow_mut failed")]
    #[test]
    fn alias_mut_borrow() {
        let t = test_make_tensor();
        #[allow(clippy::undocumented_unsafe_blocks)]
        let t_alias = unsafe { test_make_alias(&t) };
        let cell = AliasTrackingRefCell::new(t);
        let cell_alias = AliasTrackingRefCell::new_with_alias(t_alias, &cell);
        let _borrow = cell.borrow();

        cell_alias.borrow_mut();
    }

    /// Once the borrow of the original has gone out of scope, the alias can be
    /// borrowed both mutably and immutably.
    #[test]
    fn alias_mut_borrow_scoped() {
        let t = test_make_tensor();
        #[allow(clippy::undocumented_unsafe_blocks)]
        let t_alias = unsafe { test_make_alias(&t) };
        let cell = AliasTrackingRefCell::new(t);
        let cell_alias = AliasTrackingRefCell::new_with_alias(t_alias, &cell);
        {
            let _borrow = cell.borrow();
        }

        {
            cell_alias.borrow_mut();
        }
        {
            cell_alias.borrow();
        }
    }

    /// Several shared borrows may coexist; shared and exclusive borrows
    /// exclude each other in both directions.
    #[test]
    fn try_borrow() {
        let t = test_make_tensor();
        let cell = AliasTrackingRefCell::new(t);
        {
            let b1 = cell.try_borrow();
            assert!(b1.is_ok());
            let b2 = cell.try_borrow();
            assert!(b2.is_ok());
            let b3 = cell.try_borrow();
            assert!(b3.is_ok());

            let borrow_mut = cell.try_borrow_mut();
            assert!(borrow_mut.is_err());
        }
        let borrow_mut = cell.try_borrow_mut();
        assert!(borrow_mut.is_ok());
        let borrow = cell.try_borrow();
        assert!(borrow.is_err());
    }

    /// A cell survives a bincode round trip with an equal value.
    #[test]
    fn serialize() -> Result<()> {
        let c1 = AliasTrackingRefCell::new(test_make_tensor());
        let buf = bincode::serialize(&c1)?;
        let c2: AliasTrackingRefCell<Tensor> = bincode::deserialize(&buf)?;
        assert_eq!(*c1.borrow(), *c2.borrow());
        Ok(())
    }
}