Skip to content

Commit ffd06b7

Browse files
committed
Add CowMut
1 parent 99d36a0 commit ffd06b7

File tree

11 files changed

+135
-72
lines changed

11 files changed

+135
-72
lines changed

Cargo.toml

+2-7
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,8 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }
5151
# min-cl = { version = "0.3.0", optional=true }
5252

5353
[features]
54-
default = ["cpu", "opencl", "blas", "static-api", "stack", "macro", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph", "serde"]
55-
56-
# default = ["cpu"]
57-
# default = ["no-std"]
58-
# default = ["opencl"]
59-
# default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "nnapi"]
60-
54+
# default = ["cpu", "opencl", "blas", "static-api", "stack", "macro", "untyped", "autograd", "autograd", "cached", "lazy", "fork", "graph"]
55+
default = ["cpu"]
6156

6257
std = []
6358

src/buffer.rs

+38-28
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@ use crate::cpu::{CPUPtr, CPU};
1010
use crate::CPU;
1111

1212
use crate::{
13-
flag::AllocFlag, shape::Shape, Alloc, Base, ClearBuf, CloneBuf, Device, DevicelessAble, HasId,
14-
IsShapeIndep, OnDropBuffer, OnNewBuffer, PtrType, Read, ReplaceBuf, ShallowCopy, Unit,
15-
WrappedData, WriteBuf, ZeroGrad,
13+
flag::AllocFlag, shape::Shape, Alloc, Base, ClearBuf, CloneBuf, CowMut, Device, DevicelessAble, HasId, IsShapeIndep, OnDropBuffer, OnNewBuffer, PtrType, Read, ReplaceBuf, ShallowCopy, Unit, WrappedData, WriteBuf, ZeroGrad
1614
};
1715

1816
pub use self::num::Num;
@@ -42,7 +40,7 @@ mod num;
4240
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
4341
pub struct Buffer<'a, T: Unit = f32, D: Device = CPU<Base>, S: Shape = ()> {
4442
/// the type of pointer
45-
pub(crate) data: D::Data<T, S>,
43+
pub(crate) data: CowMut<'a, D::Data<T, S>>,
4644
/// A reference to the corresponding device. Mainly used for operations without a device parameter.
4745
#[cfg_attr(feature = "serde", serde(skip))]
4846
pub(crate) device: Option<&'a D>,
@@ -81,7 +79,7 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
8179
where
8280
D: OnNewBuffer<'a, T, D, S>,
8381
{
84-
let data = device.base_to_data(base);
82+
let data = CowMut::Owned(device.base_to_data(base));
8583
let buf = Buffer {
8684
data,
8785
device: Some(device),
@@ -265,7 +263,7 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
265263
D: DevicelessAble<'b, T, S>,
266264
{
267265
Buffer {
268-
data: device.base_to_data(device.alloc(len, AllocFlag::None).unwrap()),
266+
data: CowMut::Owned(device.base_to_data(device.alloc(len, AllocFlag::None).unwrap())),
269267
device: None,
270268
}
271269
}
@@ -275,6 +273,12 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
275273
where
276274
D::Data<T, S>: Default,
277275
{
276+
277+
if !self.data.is_owned() {
278+
// TODO: return None
279+
unimplemented!()
280+
}
281+
278282
if let Some(device) = self.device {
279283
if self.data.flag() != AllocFlag::None {
280284
device.on_drop_buffer(device, &self)
@@ -283,9 +287,11 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
283287

284288
let mut val = ManuallyDrop::new(self);
285289

286-
let data = core::mem::take(&mut val.data);
290+
let CowMut::Owned(owned) = core::mem::take(&mut val.data) else {
291+
unimplemented!()
292+
};
287293

288-
Buffer { data, device: None }
294+
Buffer { data: CowMut::Owned(owned), device: None }
289295
}
290296

291297
/// Returns the device of the `Buffer`.
@@ -393,10 +399,11 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
393399
where
394400
<D as Device>::Data<T, S>: ShallowCopy,
395401
{
396-
Buffer {
397-
data: self.data.shallow(),
398-
device: self.device,
399-
}
402+
todo!()
403+
// Buffer {
404+
// data: self.data.shallow(),
405+
// device: self.device,
406+
// }
400407
}
401408

402409
/// Sets all elements in `Buffer` to the default value.
@@ -479,15 +486,16 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
479486
D: crate::ToDim<T, S, O>,
480487
D::Data<T, S>: ShallowCopy,
481488
{
482-
let buf = ManuallyDrop::new(self);
489+
todo!()
490+
// let buf = ManuallyDrop::new(self);
483491

484-
let mut data = buf.device().to_dim(unsafe { buf.data.shallow() });
485-
unsafe { data.set_flag(AllocFlag::None) };
492+
// let mut data = buf.device().to_dim(unsafe { buf.data.shallow() });
493+
// unsafe { data.set_flag(AllocFlag::None) };
486494

487-
Buffer {
488-
data,
489-
device: buf.device,
490-
}
495+
// Buffer {
496+
// data,
497+
// device: buf.device,
498+
// }
491499
}
492500
}
493501

@@ -550,10 +558,11 @@ impl<'a, T: Unit, S: Shape> Buffer<'a, T, CPU<Base>, S> {
550558
/// The `Buffer` does not manage deallocation of the allocated memory.
551559
#[inline]
552560
pub unsafe fn from_raw_host(ptr: *mut T, len: usize) -> Buffer<'a, T, CPU<Base>, S> {
553-
Buffer {
554-
data: CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper),
555-
device: None,
556-
}
561+
todo!()
562+
// Buffer {
563+
// data: CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper),
564+
// device: None,
565+
// }
557566
}
558567
}
559568

@@ -571,10 +580,11 @@ impl<'a, Mods: OnDropBuffer, T: Unit, S: Shape> Buffer<'a, T, CPU<Mods>, S> {
571580
ptr: *mut T,
572581
len: usize,
573582
) -> Buffer<'a, T, CPU<Mods>, S> {
574-
Buffer {
575-
data: device.wrap_in_base(CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper)),
576-
device: Some(device),
577-
}
583+
todo!()
584+
// Buffer {
585+
// data: device.wrap_in_base(CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper)),
586+
// device: Some(device),
587+
// }
578588
}
579589
}
580590

@@ -633,7 +643,7 @@ where
633643
{
634644
fn default() -> Self {
635645
Self {
636-
data: D::Data::<T, S>::default(),
646+
data: Default::default(),
637647
device: None,
638648
}
639649
}

src/buffer/num.rs

+5-6
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ use core::{
44
};
55

66
use crate::{
7-
flag::AllocFlag, Alloc, Buffer, CloneBuf, Device, HasId, OnDropBuffer, PtrType, ShallowCopy,
8-
Unit, WrappedData,
7+
flag::AllocFlag, Alloc, Buffer, CloneBuf, CowMut, Device, HasId, OnDropBuffer, PtrType, ShallowCopy, Unit, WrappedData
98
};
109

1110
#[derive(Debug, Default)]
@@ -131,9 +130,9 @@ impl<'a, T: Unit + Clone> CloneBuf<'a, T> for () {
131130
#[inline]
132131
fn clone_buf(&self, buf: &Buffer<'a, T, Self>) -> Buffer<'a, T, Self> {
133132
Buffer {
134-
data: Num {
133+
data: CowMut::Owned(Num {
135134
num: buf.data.num.clone(),
136-
},
135+
}),
137136
device: buf.device,
138137
}
139138
}
@@ -143,7 +142,7 @@ impl<T: crate::number::Number> From<T> for Buffer<'_, T, ()> {
143142
#[inline]
144143
fn from(ptr: T) -> Self {
145144
Buffer {
146-
data: Num { num: ptr },
145+
data: CowMut::Owned(Num { num: ptr }),
147146
device: None,
148147
}
149148
}
@@ -158,7 +157,7 @@ impl<'a, T: Unit> Buffer<'a, T, ()> {
158157
T: Unit + Copy,
159158
{
160159
Buffer {
161-
data: Num { num: self.data.num },
160+
data: CowMut::Owned(Num { num: self.data.num }),
162161
device: self.device,
163162
}
164163
}

src/cow_mut.rs

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
use core::{
2+
default,
3+
ops::{Deref, DerefMut},
4+
};
5+
6+
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
7+
#[derive(Debug)]
8+
pub enum CowMut<'a, T> {
9+
BorrowedMut(&'a mut T),
10+
Owned(T),
11+
}
12+
13+
impl<T: Default> Default for CowMut<'_, T> {
14+
fn default() -> Self {
15+
CowMut::Owned(T::default())
16+
}
17+
}
18+
19+
impl<'a, T> CowMut<'a, T> {
20+
pub fn is_owned(&self) -> bool {
21+
match self {
22+
CowMut::BorrowedMut(_) => false,
23+
CowMut::Owned(_) => true
24+
}
25+
}
26+
27+
pub fn into_owned(self) -> T
28+
where
29+
T: Clone
30+
{
31+
match self {
32+
CowMut::BorrowedMut(b) => b.clone(),
33+
CowMut::Owned(o) => o,
34+
}
35+
}
36+
}
37+
38+
impl<'a, T> Deref for CowMut<'a, T> {
39+
type Target = T;
40+
41+
#[inline]
42+
fn deref(&self) -> &Self::Target {
43+
match self {
44+
CowMut::BorrowedMut(b) => b,
45+
CowMut::Owned(o) => o,
46+
}
47+
}
48+
}
49+
50+
impl<'a, T> DerefMut for CowMut<'a, T> {
51+
#[inline]
52+
fn deref_mut(&mut self) -> &mut Self::Target {
53+
match self {
54+
CowMut::BorrowedMut(b) => b,
55+
CowMut::Owned(o) => o,
56+
}
57+
}
58+
}

src/features.rs

+3-5
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
use core::{cell::RefMut, fmt::Debug, ops::RangeBounds};
66

77
use crate::{
8-
op_hint::OpHint,
9-
range::{AsRange, CursorRange},
10-
AnyOp, HasId, Parents, Shape, UniqueId, Unit, ZeroGrad, CPU,
8+
op_hint::OpHint, range::{AsRange, CursorRange}, AnyOp, CowMut, HasId, Parents, Shape, UniqueId, Unit, ZeroGrad, CPU
119
};
1210

1311
#[cfg(feature = "cached")]
@@ -32,7 +30,7 @@ pub trait Retrieve<D, T: Unit, S: Shape = ()>: OnDropBuffer {
3230
device: &D,
3331
len: usize,
3432
parents: impl Parents<NUM_PARENTS>,
35-
) -> crate::Result<Self::Wrap<T, D::Base<T, S>>>
33+
) -> crate::Result<CowMut<Self::Wrap<T, D::Base<T, S>>>>
3634
where
3735
S: Shape,
3836
D: Device + Alloc<T>;
@@ -80,7 +78,7 @@ pub trait Cursor {
8078
}
8179

8280
#[inline]
83-
fn cached(&self, cb: impl Fn())
81+
fn cached(&self, cb: impl Fn())
8482
where
8583
Self: Sized,
8684
{

src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ pub mod flag;
101101
mod any_op;
102102
#[cfg(feature = "std")]
103103
mod boxed_shallow_copy;
104+
mod cow_mut;
104105
pub mod hooks;
105106
mod id;
106107
mod layer_management;
@@ -116,6 +117,7 @@ mod wrapper;
116117

117118
pub use any_op::*;
118119
pub use cache::*;
120+
pub use cow_mut::*;
119121
pub use features::*;
120122
pub use hooks::*;
121123
pub use id::*;

src/modules/base.rs

+3-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
use crate::{
2-
flag::AllocFlag, AddGradFn, AddOperation, Alloc, CachedBuffers, Cursor, Device, ExecNow, HasId,
3-
HashLocation, Module, OnDropBuffer, OnNewBuffer, Parents, PtrType, ReplaceBuf, Retrieve,
4-
SetOpHint, Setup, Shape, Unit, WrappedData,
2+
flag::AllocFlag, AddGradFn, AddOperation, Alloc, CachedBuffers, CowMut, Cursor, Device, ExecNow, HasId, HashLocation, Module, OnDropBuffer, OnNewBuffer, Parents, PtrType, ReplaceBuf, Retrieve, SetOpHint, Setup, Shape, Unit, WrappedData
53
};
64

75
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
@@ -85,11 +83,11 @@ impl<D, T: Unit, S: Shape> Retrieve<D, T, S> for Base {
8583
device: &D,
8684
len: usize,
8785
_parents: impl Parents<NUM_PARENTS>,
88-
) -> crate::Result<Self::Wrap<T, D::Base<T, S>>>
86+
) -> crate::Result<CowMut<Self::Wrap<T, D::Base<T, S>>>>
8987
where
9088
D: Alloc<T>,
9189
{
92-
device.alloc(len, AllocFlag::None)
90+
device.alloc(len, AllocFlag::None).map(|x| CowMut::Owned(x))
9391
}
9492
}
9593

src/modules/lazy/wrapper.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ impl<Data: HasId, T> HasId for LazyWrapper<Data, T> {
4242
match self.maybe_data {
4343
MaybeData::Data(ref data) => data.id(),
4444
MaybeData::Id(id) => id,
45-
MaybeData::None => unimplemented!()
45+
MaybeData::None => unimplemented!(),
4646
}
4747
}
4848
}
@@ -53,13 +53,14 @@ impl<Data: PtrType, T> PtrType for LazyWrapper<Data, T> {
5353
match self.maybe_data {
5454
MaybeData::Data(ref data) => data.size(),
5555
MaybeData::Id(id) => id.len,
56-
MaybeData::None => unimplemented!()
56+
MaybeData::None => unimplemented!(),
5757
}
5858
}
5959

6060
#[inline]
6161
fn flag(&self) -> AllocFlag {
62-
self.maybe_data.data()
62+
self.maybe_data
63+
.data()
6364
.map(|data| data.flag())
6465
.unwrap_or(AllocFlag::Lazy)
6566
}

src/modules/lazy/wrapper/maybe_data.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@ impl<Data> MaybeData<Data> {
1717
MaybeData::None => None,
1818
}
1919
}
20-
20+
2121
#[inline]
2222
pub fn data_mut(&mut self) -> Option<&mut Data> {
2323
match self {
2424
MaybeData::Data(data) => Some(data),
2525
MaybeData::Id(_id) => None,
26-
MaybeData::None => None
26+
MaybeData::None => None,
2727
}
2828
}
29-
29+
3030
#[inline]
3131
pub fn id(&self) -> Option<&Id> {
3232
match self {
@@ -35,7 +35,7 @@ impl<Data> MaybeData<Data> {
3535
MaybeData::None => None,
3636
}
3737
}
38-
38+
3939
#[inline]
4040
pub fn id_mut(&mut self) -> Option<&mut Id> {
4141
match self {

0 commit comments

Comments
 (0)