torch_sys/
cell.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::cell::UnsafeCell;
10use std::fmt::Debug;
11use std::ops::Deref;
12use std::ops::DerefMut;
13use std::ptr::NonNull;
14use std::sync::Arc;
15
16use atomic_refcell::AtomicRef;
17use atomic_refcell::AtomicRefCell;
18use atomic_refcell::AtomicRefMut;
19use atomic_refcell::BorrowError;
20use atomic_refcell::BorrowMutError;
21use serde::Deserialize;
22use serde::Deserializer;
23use serde::Serialize;
24use serde::Serializer;
25
/// A container that dynamically checks borrow rules in an aliasing-aware fashion.
///
/// `AliasTrackingRefCell`s can alias one another, and a mutable borrow of one
/// `AliasTrackingRefCell` will mutably borrow all its aliases as well. That means that
/// trying to borrow an alias at that point will panic (or fail, for the `try_*` methods).
///
/// The API for `AliasTrackingRefCell` is very similar to [`RefCell`](std::cell::RefCell),
/// with some modifications to account for our explicit management of aliasing.
///
/// # Example
/// ```
/// // type TensorCell = AliasTrackingRefCell<Tensor>;
/// use torch_sys::TensorCell;
/// use torch_sys::test_make_alias;
/// use torch_sys::test_make_tensor;
/// let my_tensor = test_make_tensor();
/// let my_tensor_alias = unsafe { test_make_alias(&my_tensor) };
/// let cell = TensorCell::new(my_tensor);
/// let aliased_cell = TensorCell::new_with_alias(my_tensor_alias, &cell);
///
/// // Can immutably borrow as many times as you want.
/// // You can use the output like a `&Tensor`.
/// let tensor_ref = cell.borrow();
/// let second_borrow = cell.borrow();
/// let tensor_alias_ref = aliased_cell.borrow();
///
/// // But this would panic, since immutable borrows are already active:
/// // let oops = cell.borrow_mut();
/// ```
///
/// # Implementation notes
///
/// The idea is to hold onto a shared-ownership value (`alias_tracker`) that
/// correctly models the aliasing relationships of the underlying tensor
/// storage. So, two `AliasTrackingRefCell`s that contain aliased data
/// will share a reference to the same `alias_tracker`.
///
/// We then use an [`AtomicRefCell`] over that value to dynamically (and
/// thread-safely) enforce the borrow rules collectively across the whole set
/// of aliases.
///
/// The rest of the implementation closely copies the standard library's
/// implementation of `RefCell`.
pub struct AliasTrackingRefCell<T: ?Sized> {
    // Shared (via `Arc`) by every cell that aliases this one. Borrowing it,
    // shared or exclusive, is how borrows are checked across all aliases.
    alias_tracker: Arc<AtomicRefCell<()>>,
    // The wrapped value; borrows of `alias_tracker` govern access to it.
    value: UnsafeCell<T>,
}
72
73// SAFETY: `AliasTrackingRefCell<T> is a cell of T and acts like a reference.
74unsafe impl<T: ?Sized + Send> Send for AliasTrackingRefCell<T> {}
75// SAFETY: `AliasTrackingRefCell<T> is a cell of T and acts like a reference.
76unsafe impl<T: ?Sized + Send + Sync> Sync for AliasTrackingRefCell<T> {}
77
78impl<T> AliasTrackingRefCell<T> {
79    /// Creates a new `AliasTrackingRefCell` that owns the given `T`. This
80    /// `AliasTrackingRefCell` will not alias anything.
81    #[inline]
82    pub fn new(value: T) -> Self {
83        Self {
84            value: UnsafeCell::new(value),
85            alias_tracker: Arc::new(AtomicRefCell::new(())),
86        }
87    }
88
89    /// Creates a new `AliasTrackingRefCell` that aliases `alias`.
90    #[inline]
91    pub fn new_with_alias(value: T, alias: &Self) -> Self {
92        Self {
93            value: UnsafeCell::new(value),
94            alias_tracker: alias.alias_tracker.clone(),
95        }
96    }
97}
98
99impl<T: ?Sized> AliasTrackingRefCell<T> {
100    /// Immutably borrows the given `T` and all its aliases. The borrow lasts until
101    /// `AliasTrackingRef` is dropped.
102    ///
103    /// # Panics
104    /// Will panic if the given `T` or any of its aliases are mutably borrowed.
105    /// For a non-panicking version, see [`try_borrow`](AliasTrackingRefCell::try_borrow).
106    pub fn borrow(&self) -> AliasTrackingRef<T> {
107        match self.try_borrow() {
108            Ok(borrow) => borrow,
109            Err(e) => panic!("borrow failed: {:?}", e),
110        }
111    }
112
113    /// Immutably borrows the given `T` and all its aliases, returning an error if
114    /// it or any of its aliases are currently mutably borrowed. The borrow
115    /// lasts until the `AliasTrackingRef` is dropped.
116    ///
117    /// This is a non-panicking version of [`borrow`](AliasTrackingRefCell::borrow).
118    pub fn try_borrow(&self) -> Result<AliasTrackingRef<T>, BorrowError> {
119        Ok(AliasTrackingRef {
120            borrow: self.alias_tracker.try_borrow()?,
121            // SAFETY: The alias_tracker borrow guarantees that there is only
122            // immutable access to the given `T`.
123            value: unsafe { NonNull::new_unchecked(self.value.get()) },
124        })
125    }
126
127    /// Mutably borrows the tensor and all its aliases. The borrow lasts until
128    /// `AliasTrackingRefMut` is dropped.
129    ///
130    /// # Panics
131    /// Will panic if the `Tensor` or any of its aliases are borrowed. For a
132    /// non-panicking version, see
133    /// [`try_borrow_mut`](TensorCell::try_borrow_mut).
134    pub fn borrow_mut(&self) -> AliasTrackingRefMut<T> {
135        match self.try_borrow_mut() {
136            Ok(borrow) => borrow,
137            Err(e) => panic!("borrow_mut failed: {:?}", e),
138        }
139    }
140
141    /// Mutably borrows the tensor and all its aliases, returning an error if
142    /// the tensor is currently borrowed. The borrow lasts until `TensorRefMut`
143    /// is dropped.
144    ///
145    /// This is a non-panicking version of [`borrow_mut`](TensorCell::borrow_mut).
146    pub fn try_borrow_mut(&self) -> Result<AliasTrackingRefMut<T>, BorrowMutError> {
147        Ok(AliasTrackingRefMut {
148            borrow: self.alias_tracker.try_borrow_mut()?,
149            // SAFETY: The alias_tracker mutable borrow guarantees unique access.
150            value: unsafe { NonNull::new_unchecked(self.value.get()) },
151        })
152    }
153
154    /// Returns true if this `TensorCell` aliases `other`.
155    pub fn aliases(&self, other: &Self) -> bool {
156        Arc::ptr_eq(&self.alias_tracker, &other.alias_tracker)
157    }
158
159    /// Returns a pointer to the alias tracker. Useful for de-duping borrows.
160    /// Visibility limited to crate as it exposes TensorCell internal
161    /// representation.
162    pub(crate) fn alias_ptr(&self) -> *const AtomicRefCell<()> {
163        Arc::as_ptr(&self.alias_tracker)
164    }
165
166    /// Returns a reference to the tensor, without checking for borrows.
167    ///
168    /// SAFETY: The caller must ensure that it holds a borrow on this tensor.
169    pub unsafe fn get_unchecked(&self) -> &T {
170        // SAFETY: see above
171        unsafe { self.value.get().as_ref().unwrap() }
172    }
173}
174
175impl<T: ?Sized + PartialEq> PartialEq for AliasTrackingRefCell<T> {
176    #[inline]
177    fn eq(&self, other: &Self) -> bool {
178        *self.borrow() == *other.borrow()
179    }
180}
181
182impl<T: ?Sized + Eq> Eq for AliasTrackingRefCell<T> {}
183
184impl<T: ?Sized + Debug> Debug for AliasTrackingRefCell<T> {
185    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
186        match self.try_borrow() {
187            Ok(borrow) => f
188                .debug_struct("AliasTrackingRefCell")
189                .field("value", &borrow)
190                .finish(),
191            Err(_) => f
192                .debug_struct("AliasTrackingRefCell")
193                .field("value", &"<mutably borrowed elsewhere>")
194                .finish(),
195        }
196    }
197}
198
199impl<T: Serialize> Serialize for AliasTrackingRefCell<T> {
200    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
201    where
202        S: Serializer,
203    {
204        let borrow = self.try_borrow().map_err(serde::ser::Error::custom)?;
205        borrow.serialize(serializer)
206    }
207}
208
209impl<'de, T: Deserialize<'de>> Deserialize<'de> for AliasTrackingRefCell<T> {
210    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
211    where
212        D: Deserializer<'de>,
213    {
214        let value = T::deserialize(deserializer)?;
215        Ok(Self::new(value))
216    }
217}
218
219pub struct AliasTrackingRef<'a, T: ?Sized + 'a> {
220    // NB: we use a pointer instead of `&'a T` to avoid `noalias` violations,
221    // because a `Ref` argument doesn't hold immutability for the entire 'a
222    // lifetime, only until it drops.
223    value: NonNull<T>,
224    // This is not used, but holding the borrow is what guards the tensor
225    // value.
226    #[allow(dead_code)]
227    borrow: AtomicRef<'a, ()>,
228}
229
230// SAFETY: `AliasTrackingRef<'_, T> acts as a reference.
231unsafe impl<'a, T: ?Sized> Sync for AliasTrackingRef<'a, T> where for<'b> &'b T: Sync {}
232// SAFETY: `AliasTrackingRef<'_, T> acts as a reference.
233unsafe impl<'a, T: ?Sized> Send for AliasTrackingRef<'a, T> where for<'b> &'b T: Send {}
234
235impl<'a, T: ?Sized + Debug + 'a> Debug for AliasTrackingRef<'a, T> {
236    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
237        <T as Debug>::fmt(self, f)
238    }
239}
240
241impl<'a, T: ?Sized> Deref for AliasTrackingRef<'a, T> {
242    type Target = T;
243
244    #[inline]
245    fn deref(&self) -> &T {
246        // SAFETY: We hold shared borrow of the value.
247        unsafe { self.value.as_ref() }
248    }
249}
250
251pub struct AliasTrackingRefMut<'a, T: ?Sized + 'a> {
252    // NB: we use a pointer instead of `&'a T` to avoid `noalias` violations,
253    // because a `Ref` argument doesn't hold immutability for the entire 'a
254    // lifetime, only until it drops.
255    value: NonNull<T>,
256    // This is not used, but holding the borrow is what guards the tensor
257    // value.
258    #[allow(dead_code)]
259    borrow: AtomicRefMut<'a, ()>,
260}
261
262// SAFETY: `AliasTrackingRefMut<'_, T> acts as a mutable reference.
263unsafe impl<'a, T: ?Sized> Sync for AliasTrackingRefMut<'a, T> where for<'b> &'b T: Sync {}
264// SAFETY: `AliasTrackingRefMut<'_, T> acts as a mutable reference.
265unsafe impl<'a, T: ?Sized> Send for AliasTrackingRefMut<'a, T> where for<'b> &'b T: Send {}
266
267impl<'a, T: ?Sized + Debug + 'a> Debug for AliasTrackingRefMut<'a, T> {
268    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
269        <T as Debug>::fmt(self, f)
270    }
271}
272
273impl<'a, T: ?Sized> Deref for AliasTrackingRefMut<'a, T> {
274    type Target = T;
275
276    #[inline]
277    fn deref(&self) -> &T {
278        // SAFETY: We hold an exclusive borrow of the value.
279        unsafe { self.value.as_ref() }
280    }
281}
282
283impl<'b, T: ?Sized> DerefMut for AliasTrackingRefMut<'b, T> {
284    #[inline]
285    fn deref_mut(&mut self) -> &mut T {
286        // SAFETY: We hold an exclusive borrow of the value.
287        unsafe { self.value.as_mut() }
288    }
289}
290
/// `CloneUnsafe` is a trait that allows `AliasTrackingRefCell<T>` to implement
/// `Clone` for any `T` that implements it. The `clone_unsafe` method is unsafe
/// because it does not create an independent copy of the underlying value.
/// Instead, the returned value is meant to be tracked like any other alias, so
/// that borrow-checking rules are enforced across both cells.
pub trait CloneUnsafe {
    /// Returns a value that aliases `self` rather than an independent copy.
    ///
    /// # Safety
    /// The caller must ensure that the returned value is only accessed under the
    /// same borrow tracking as `self`. For example, it can be wrapped in an
    /// `AliasTrackingRefCell` that shares the original cell's alias tracker, as
    /// `AliasTrackingRefCell::clone` does.
    unsafe fn clone_unsafe(&self) -> Self;
}
299
300impl<T: CloneUnsafe> Clone for AliasTrackingRefCell<T> {
301    fn clone(&self) -> Self {
302        Self {
303            alias_tracker: self.alias_tracker.clone(),
304            // SAFETY: We use the alias tracker to ensure that we are handling the underlying
305            // value safely.
306            value: UnsafeCell::new(unsafe { self.value.get().as_ref().unwrap().clone_unsafe() }),
307        }
308    }
309}
310
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;
    use crate::Tensor;
    use crate::bridge::ffi::test_make_alias;
    use crate::bridge::ffi::test_make_tensor;

    /// A clone shares the original's alias tracker, so mutably borrowing the
    /// clone while the original is immutably borrowed must panic.
    #[test]
    #[should_panic]
    fn clone_then_mut_borrow() {
        let original = AliasTrackingRefCell::new(test_make_tensor());
        let _shared = original.borrow();

        let cloned = original.clone();
        // Conflicts with `_shared` through the shared alias tracker.
        let _conflicting = cloned.borrow_mut();
    }

    /// Mutably borrowing an alias while the original is immutably borrowed
    /// must panic.
    #[test]
    #[should_panic]
    fn alias_mut_borrow() {
        let tensor = test_make_tensor();
        #[allow(clippy::undocumented_unsafe_blocks)]
        let tensor_alias = unsafe { test_make_alias(&tensor) };
        let cell = AliasTrackingRefCell::new(tensor);
        let alias_cell = AliasTrackingRefCell::new_with_alias(tensor_alias, &cell);
        let _shared = cell.borrow();

        // Conflicts with `_shared` through the shared alias tracker.
        let _conflicting = alias_cell.borrow_mut();
    }

    /// Once a borrow has been released, aliases can be borrowed again.
    #[test]
    fn alias_mut_borrow_scoped() {
        let tensor = test_make_tensor();
        #[allow(clippy::undocumented_unsafe_blocks)]
        let tensor_alias = unsafe { test_make_alias(&tensor) };
        let cell = AliasTrackingRefCell::new(tensor);
        let alias_cell = AliasTrackingRefCell::new_with_alias(tensor_alias, &cell);

        drop(cell.borrow());
        // The shared borrow above has already been released, so both succeed.
        drop(alias_cell.borrow_mut());
        drop(alias_cell.borrow());
    }

    /// `try_borrow`/`try_borrow_mut` report conflicts instead of panicking.
    #[test]
    fn try_borrow() {
        let cell = AliasTrackingRefCell::new(test_make_tensor());
        {
            let mut held = Vec::new();
            for _ in 0..3 {
                let attempt = cell.try_borrow();
                assert!(attempt.is_ok());
                held.push(attempt);
            }

            assert!(cell.try_borrow_mut().is_err());
        }
        let exclusive = cell.try_borrow_mut();
        assert!(exclusive.is_ok());
        assert!(cell.try_borrow().is_err());
    }

    /// A serialize/deserialize round trip preserves the wrapped value.
    #[test]
    fn serialize() -> Result<()> {
        let original = AliasTrackingRefCell::new(test_make_tensor());
        let bytes = bincode::serialize(&original)?;
        let restored: AliasTrackingRefCell<Tensor> = bincode::deserialize(&bytes)?;
        assert_eq!(*original.borrow(), *restored.borrow());
        Ok(())
    }
}