torch_sys/
borrow.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use std::any::Any;
10use std::collections::HashMap;
11use std::collections::hash_map::Entry;
12
13use atomic_refcell::AtomicRefCell;
14use derive_more::Display;
15use thiserror::Error;
16
17use crate::RValue;
18use crate::cell::AliasTrackingRef;
19use crate::cell::AliasTrackingRefMut;
20
21/// Errors that can occur while calling an operator.
22#[derive(Error, Debug)]
23#[non_exhaustive]
24pub enum BorrowError {
25    #[error("cannot borrow with type: {0}")]
26    InvalidBorrow(BorrowType),
27}
28
29/// Abstracts over the different types of borrows we can have.
30#[derive(Debug)]
31pub enum Borrow<'a> {
32    // Dead code because we never access these, just hold onto them as guards.
33    #[allow(dead_code)]
34    Shared(AliasTrackingRef<'a, dyn Any>),
35    #[allow(dead_code)]
36    Mutable(AliasTrackingRefMut<'a, dyn Any>),
37}
38
39#[derive(Debug, Display, Clone, Copy, PartialEq, Eq)]
40pub enum BorrowType {
41    Shared,
42    Mutable,
43}
44
45/// A helper that batches multiple borrows for a single borrower, deduping them
46/// so we don't accidentally borrow the same alias twice.
47#[derive(Debug)]
48pub struct MultiBorrow<'a> {
49    cells: Vec<(&'a crate::cell::AliasTrackingRefCell<dyn Any>, BorrowType)>,
50}
51
52impl<'a> MultiBorrow<'a> {
53    pub fn new() -> Self {
54        Self { cells: Vec::new() }
55    }
56
57    pub fn borrow(&self) -> Result<Vec<Borrow>, BorrowError> {
58        // Dedupe borrows so that we don't accidentally borrow the same alias twice.
59        let mut alias_ptrs: HashMap<
60            *const AtomicRefCell<()>,
61            (&crate::cell::AliasTrackingRefCell<dyn Any>, BorrowType),
62        > = HashMap::new();
63        for (cell, borrow_type) in &self.cells {
64            let alias_ptr = cell.alias_ptr();
65
66            match alias_ptrs.entry(alias_ptr) {
67                Entry::Vacant(entry) => {
68                    entry.insert((cell, *borrow_type));
69                }
70                Entry::Occupied(mut entry) => match (entry.get(), borrow_type) {
71                    // Upgrade a shared borrow to a mutable borrow.
72                    ((_, BorrowType::Shared), BorrowType::Mutable) => {
73                        entry.insert((cell, BorrowType::Mutable));
74                    }
75                    // Otherwise just leave the existing entry as it is.
76                    _ => (),
77                },
78            }
79        }
80
81        let mut ret = Vec::new();
82        for (_, (cell, borrow_type)) in alias_ptrs {
83            match borrow_type {
84                BorrowType::Mutable => ret.push(Borrow::Mutable(
85                    cell.try_borrow_mut()
86                        .map_err(|_| BorrowError::InvalidBorrow(borrow_type))?,
87                )),
88                BorrowType::Shared => ret.push(Borrow::Shared(
89                    cell.try_borrow()
90                        .map_err(|_| BorrowError::InvalidBorrow(borrow_type))?,
91                )),
92            }
93        }
94        Ok(ret)
95    }
96
97    pub fn add(&mut self, arg: &'a RValue, borrow_type: BorrowType) {
98        match arg {
99            RValue::Tensor(cell) => {
100                self.cells.push((cell, borrow_type));
101            }
102            RValue::TensorList(cells) => {
103                for cell in cells {
104                    // If this is a write to a tensor list, just borrow every
105                    // tensor mutably.
106                    self.cells.push((cell, borrow_type));
107                }
108            }
109            RValue::Opaque(val) => self.cells.push((val, borrow_type)),
110            _ => (),
111        };
112    }
113}