@@ -10,9 +10,7 @@ use crate::cpu::{CPUPtr, CPU};
10
10
use crate :: CPU ;
11
11
12
12
use crate :: {
13
- flag:: AllocFlag , shape:: Shape , Alloc , Base , ClearBuf , CloneBuf , Device , DevicelessAble , HasId ,
14
- IsShapeIndep , OnDropBuffer , OnNewBuffer , PtrType , Read , ReplaceBuf , ShallowCopy , Unit ,
15
- WrappedData , WriteBuf , ZeroGrad ,
13
+ flag:: AllocFlag , shape:: Shape , Alloc , Base , ClearBuf , CloneBuf , CowMut , Device , DevicelessAble , HasId , IsShapeIndep , OnDropBuffer , OnNewBuffer , PtrType , Read , ReplaceBuf , ShallowCopy , Unit , WrappedData , WriteBuf , ZeroGrad
16
14
} ;
17
15
18
16
pub use self :: num:: Num ;
@@ -42,7 +40,7 @@ mod num;
42
40
#[ cfg_attr( feature = "serde" , derive( serde:: Serialize , serde:: Deserialize ) ) ]
43
41
pub struct Buffer < ' a , T : Unit = f32 , D : Device = CPU < Base > , S : Shape = ( ) > {
44
42
/// the type of pointer
45
- pub ( crate ) data : D :: Data < T , S > ,
43
+ pub ( crate ) data : CowMut < ' a , D :: Data < T , S > > ,
46
44
/// A reference to the corresponding device. Mainly used for operations without a device parameter.
47
45
#[ cfg_attr( feature = "serde" , serde( skip) ) ]
48
46
pub ( crate ) device : Option < & ' a D > ,
@@ -81,7 +79,7 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
81
79
where
82
80
D : OnNewBuffer < ' a , T , D , S > ,
83
81
{
84
- let data = device. base_to_data ( base) ;
82
+ let data = CowMut :: Owned ( device. base_to_data ( base) ) ;
85
83
let buf = Buffer {
86
84
data,
87
85
device : Some ( device) ,
@@ -265,7 +263,7 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
265
263
D : DevicelessAble < ' b , T , S > ,
266
264
{
267
265
Buffer {
268
- data : device. base_to_data ( device. alloc ( len, AllocFlag :: None ) . unwrap ( ) ) ,
266
+ data : CowMut :: Owned ( device. base_to_data ( device. alloc ( len, AllocFlag :: None ) . unwrap ( ) ) ) ,
269
267
device : None ,
270
268
}
271
269
}
@@ -275,6 +273,12 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
275
273
where
276
274
D :: Data < T , S > : Default ,
277
275
{
276
+
277
+ if !self . data . is_owned ( ) {
278
+ // TODO: return None
279
+ unimplemented ! ( )
280
+ }
281
+
278
282
if let Some ( device) = self . device {
279
283
if self . data . flag ( ) != AllocFlag :: None {
280
284
device. on_drop_buffer ( device, & self )
@@ -283,9 +287,11 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
283
287
284
288
let mut val = ManuallyDrop :: new ( self ) ;
285
289
286
- let data = core:: mem:: take ( & mut val. data ) ;
290
+ let CowMut :: Owned ( owned) = core:: mem:: take ( & mut val. data ) else {
291
+ unimplemented ! ( )
292
+ } ;
287
293
288
- Buffer { data, device : None }
294
+ Buffer { data : CowMut :: Owned ( owned ) , device : None }
289
295
}
290
296
291
297
/// Returns the device of the `Buffer`.
@@ -393,10 +399,11 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
393
399
where
394
400
<D as Device >:: Data < T , S > : ShallowCopy ,
395
401
{
396
- Buffer {
397
- data : self . data . shallow ( ) ,
398
- device : self . device ,
399
- }
402
+ todo ! ( )
403
+ // Buffer {
404
+ // data: self.data.shallow(),
405
+ // device: self.device,
406
+ // }
400
407
}
401
408
402
409
/// Sets all elements in `Buffer` to the default value.
@@ -479,15 +486,16 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
479
486
D : crate :: ToDim < T , S , O > ,
480
487
D :: Data < T , S > : ShallowCopy ,
481
488
{
482
- let buf = ManuallyDrop :: new ( self ) ;
489
+ todo ! ( )
490
+ // let buf = ManuallyDrop::new(self);
483
491
484
- let mut data = buf. device ( ) . to_dim ( unsafe { buf. data . shallow ( ) } ) ;
485
- unsafe { data. set_flag ( AllocFlag :: None ) } ;
492
+ // let mut data = buf.device().to_dim(unsafe { buf.data.shallow() });
493
+ // unsafe { data.set_flag(AllocFlag::None) };
486
494
487
- Buffer {
488
- data,
489
- device : buf. device ,
490
- }
495
+ // Buffer {
496
+ // data,
497
+ // device: buf.device,
498
+ // }
491
499
}
492
500
}
493
501
@@ -550,10 +558,11 @@ impl<'a, T: Unit, S: Shape> Buffer<'a, T, CPU<Base>, S> {
550
558
/// The `Buffer` does not manage deallocation of the allocated memory.
551
559
#[ inline]
552
560
pub unsafe fn from_raw_host ( ptr : * mut T , len : usize ) -> Buffer < ' a , T , CPU < Base > , S > {
553
- Buffer {
554
- data : CPUPtr :: from_ptr ( ptr, len, AllocFlag :: Wrapper ) ,
555
- device : None ,
556
- }
561
+ todo ! ( )
562
+ // Buffer {
563
+ // data: CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper),
564
+ // device: None,
565
+ // }
557
566
}
558
567
}
559
568
@@ -571,10 +580,11 @@ impl<'a, Mods: OnDropBuffer, T: Unit, S: Shape> Buffer<'a, T, CPU<Mods>, S> {
571
580
ptr : * mut T ,
572
581
len : usize ,
573
582
) -> Buffer < ' a , T , CPU < Mods > , S > {
574
- Buffer {
575
- data : device. wrap_in_base ( CPUPtr :: from_ptr ( ptr, len, AllocFlag :: Wrapper ) ) ,
576
- device : Some ( device) ,
577
- }
583
+ todo ! ( )
584
+ // Buffer {
585
+ // data: device.wrap_in_base(CPUPtr::from_ptr(ptr, len, AllocFlag::Wrapper)),
586
+ // device: Some(device),
587
+ // }
578
588
}
579
589
}
580
590
@@ -633,7 +643,7 @@ where
633
643
{
634
644
fn default ( ) -> Self {
635
645
Self {
636
- data : D :: Data :: < T , S > :: default ( ) ,
646
+ data : Default :: default ( ) ,
637
647
device : None ,
638
648
}
639
649
}
0 commit comments