1use std::any::Any;
10use std::collections::HashMap;
11use std::collections::hash_map::Entry;
12
13use atomic_refcell::AtomicRefCell;
14use derive_more::Display;
15use thiserror::Error;
16
17use crate::RValue;
18use crate::cell::AliasTrackingRef;
19use crate::cell::AliasTrackingRefMut;
20
21#[derive(Error, Debug)]
23#[non_exhaustive]
24pub enum BorrowError {
25 #[error("cannot borrow with type: {0}")]
26 InvalidBorrow(BorrowType),
27}
28
29#[derive(Debug)]
31pub enum Borrow<'a> {
32 #[allow(dead_code)]
34 Shared(AliasTrackingRef<'a, dyn Any>),
35 #[allow(dead_code)]
36 Mutable(AliasTrackingRefMut<'a, dyn Any>),
37}
38
39#[derive(Debug, Display, Clone, Copy, PartialEq, Eq)]
40pub enum BorrowType {
41 Shared,
42 Mutable,
43}
44
45#[derive(Debug)]
48pub struct MultiBorrow<'a> {
49 cells: Vec<(&'a crate::cell::AliasTrackingRefCell<dyn Any>, BorrowType)>,
50}
51
52impl<'a> MultiBorrow<'a> {
53 pub fn new() -> Self {
54 Self { cells: Vec::new() }
55 }
56
57 pub fn borrow(&self) -> Result<Vec<Borrow>, BorrowError> {
58 let mut alias_ptrs: HashMap<
60 *const AtomicRefCell<()>,
61 (&crate::cell::AliasTrackingRefCell<dyn Any>, BorrowType),
62 > = HashMap::new();
63 for (cell, borrow_type) in &self.cells {
64 let alias_ptr = cell.alias_ptr();
65
66 match alias_ptrs.entry(alias_ptr) {
67 Entry::Vacant(entry) => {
68 entry.insert((cell, *borrow_type));
69 }
70 Entry::Occupied(mut entry) => match (entry.get(), borrow_type) {
71 ((_, BorrowType::Shared), BorrowType::Mutable) => {
73 entry.insert((cell, BorrowType::Mutable));
74 }
75 _ => (),
77 },
78 }
79 }
80
81 let mut ret = Vec::new();
82 for (_, (cell, borrow_type)) in alias_ptrs {
83 match borrow_type {
84 BorrowType::Mutable => ret.push(Borrow::Mutable(
85 cell.try_borrow_mut()
86 .map_err(|_| BorrowError::InvalidBorrow(borrow_type))?,
87 )),
88 BorrowType::Shared => ret.push(Borrow::Shared(
89 cell.try_borrow()
90 .map_err(|_| BorrowError::InvalidBorrow(borrow_type))?,
91 )),
92 }
93 }
94 Ok(ret)
95 }
96
97 pub fn add(&mut self, arg: &'a RValue, borrow_type: BorrowType) {
98 match arg {
99 RValue::Tensor(cell) => {
100 self.cells.push((cell, borrow_type));
101 }
102 RValue::TensorList(cells) => {
103 for cell in cells {
104 self.cells.push((cell, borrow_type));
107 }
108 }
109 RValue::Opaque(val) => self.cells.push((val, borrow_type)),
110 _ => (),
111 };
112 }
113}