Skip to content

Commit 4022149

Browse files
committedOct 13, 2024·
Merge branch 'main' into cycle-detection
2 parents 0ea8a6f + 46efcd5 commit 4022149

File tree

9 files changed

+100
-10
lines changed

9 files changed

+100
-10
lines changed
 

‎src/buffer.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::CPU;
1212
use crate::{
1313
flag::AllocFlag, shape::Shape, Alloc, Base, ClearBuf, CloneBuf, Device, DevicelessAble, HasId,
1414
IsShapeIndep, OnDropBuffer, OnNewBuffer, PtrType, Read, ReplaceBuf, ShallowCopy, Unit,
15-
WrappedData, WriteBuf, ZeroGrad,
15+
WrappedCopy, WrappedData, WriteBuf, ZeroGrad,
1616
};
1717

1818
pub use self::num::Num;
@@ -479,11 +479,14 @@ impl<'a, T: Unit, D: Device, S: Shape> Buffer<'a, T, D, S> {
479479
pub fn to_dims<O: Shape>(self) -> Buffer<'a, T, D, O>
480480
where
481481
D: crate::ToDim<T, S, O>,
482-
D::Data<T, S>: ShallowCopy,
482+
D::Data<T, S>: WrappedCopy<Base = D::Base<T, S>>,
483+
D::Base<T, S>: ShallowCopy,
483484
{
485+
let base = unsafe { (*self).shallow() };
486+
let data = self.data.wrapped_copy(base);
484487
let buf = ManuallyDrop::new(self);
485488

486-
let mut data = buf.device().to_dim(unsafe { buf.data.shallow() });
489+
let mut data = buf.device().to_dim(data);
487490
unsafe { data.set_flag(AllocFlag::None) };
488491

489492
Buffer {

‎src/devices/cpu/cpu_ptr.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use core::{
77

88
use std::alloc::handle_alloc_error;
99

10-
use crate::{flag::AllocFlag, HasId, HostPtr, Id, PtrType, ShallowCopy};
10+
use crate::{flag::AllocFlag, HasId, HostPtr, Id, PtrType, ShallowCopy, WrappedCopy};
1111

1212
/// The pointer used for `CPU` [`Buffer`](crate::Buffer)s
1313
#[derive(Debug)]
@@ -229,6 +229,15 @@ impl<T> PtrType for CPUPtr<T> {
229229
}
230230
}
231231

232+
impl<T> WrappedCopy for CPUPtr<T> {
233+
type Base = Self;
234+
235+
#[inline]
236+
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self {
237+
to_wrap
238+
}
239+
}
240+
232241
impl<T> ShallowCopy for CPUPtr<T> {
233242
#[inline]
234243
unsafe fn shallow(&self) -> Self {

‎src/devices/cuda/cuda_ptr.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::api::{cu_read, cufree, cumalloc, CudaResult};
2-
use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy};
2+
use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy, WrappedCopy};
33
use core::marker::PhantomData;
44

55
/// The pointer used for `CUDA` [`Buffer`](crate::Buffer)s
@@ -76,6 +76,15 @@ impl<T> Drop for CUDAPtr<T> {
7676
}
7777
}
7878

79+
impl<T> WrappedCopy for CUDAPtr<T> {
80+
type Base = Self;
81+
82+
#[inline]
83+
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self {
84+
to_wrap
85+
}
86+
}
87+
7988
impl<T> ShallowCopy for CUDAPtr<T> {
8089
#[inline]
8190
unsafe fn shallow(&self) -> Self {

‎src/devices/opencl/cl_ptr.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::HostPtr;
88

99
use min_cl::api::release_mem_object;
1010

11-
use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy};
11+
use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy, WrappedCopy};
1212

1313
/// The pointer used for `OpenCL` [`Buffer`](crate::Buffer)s
1414
#[derive(Debug, PartialEq, Eq)]
@@ -59,6 +59,15 @@ impl<T> CLPtr<T> {
5959
}
6060
}
6161

62+
impl<T> WrappedCopy for CLPtr<T> {
63+
type Base = Self;
64+
65+
#[inline]
66+
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self {
67+
to_wrap
68+
}
69+
}
70+
6271
impl<T> ShallowCopy for CLPtr<T> {
6372
#[inline]
6473
unsafe fn shallow(&self) -> Self {

‎src/devices/stack_array.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use core::ops::{Deref, DerefMut};
22

3-
use crate::{shape::Shape, HasId, HostPtr, PtrType, ShallowCopy};
3+
use crate::{shape::Shape, HasId, HostPtr, PtrType, ShallowCopy, WrappedCopy};
44

55
/// A possibly multi-dimensional array allocated on the stack.
66
/// It uses `S:`[`Shape`] to get the type of the array.
@@ -137,6 +137,16 @@ impl<S: Shape, T> HostPtr<T> for StackArray<S, T> {
137137
}
138138
}
139139

140+
141+
impl<S: Shape, T> WrappedCopy for StackArray<S, T> {
142+
type Base = Self;
143+
144+
#[inline]
145+
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self {
146+
to_wrap
147+
}
148+
}
149+
140150
impl<S: Shape, T> ShallowCopy for StackArray<S, T>
141151
where
142152
S::ARR<T>: Copy,

‎src/devices/vulkan/vk_array.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use core::{
99
};
1010
use std::rc::Rc;
1111

12-
use crate::{flag::AllocFlag, HasId, HostPtr, PtrType, ShallowCopy};
12+
use crate::{flag::AllocFlag, HasId, HostPtr, PtrType, ShallowCopy, WrappedCopy};
1313

1414
use super::{context::Context, submit_and_wait};
1515

@@ -228,6 +228,15 @@ impl<T> VkArray<T> {
228228
}
229229
}
230230

231+
impl<T> WrappedCopy for VkArray<T> {
232+
type Base = Self;
233+
234+
#[inline]
235+
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self {
236+
to_wrap
237+
}
238+
}
239+
231240
impl<T> ShallowCopy for VkArray<T> {
232241
#[inline]
233242
unsafe fn shallow(&self) -> Self {

‎src/lib.rs

+5
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,11 @@ pub trait Unit {} // useful for Sync and Send or 'static
180180

181181
impl<T> Unit for T {}
182182

183+
pub trait WrappedCopy {
184+
type Base;
185+
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self;
186+
}
187+
183188
/// Used to shallow-copy a pointer. Use is discouraged.
184189
pub trait ShallowCopy {
185190
/// # Safety

‎src/modules/autograd/wrapper.rs

+17-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use core::marker::PhantomData;
22

3-
use crate::{flag::AllocFlag, Autograd, HasId, PtrType, ShallowCopy, WrappedData};
3+
use crate::{flag::AllocFlag, Autograd, HasId, PtrType, ShallowCopy, WrappedCopy, WrappedData};
44

55
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
66
pub struct ReqGradWrapper<Data, T> {
@@ -74,6 +74,22 @@ impl<Data: PtrType, T> PtrType for ReqGradWrapper<Data, T> {
7474
}
7575
}
7676

77+
impl<Data, T> WrappedCopy for ReqGradWrapper<Data, T>
78+
where
79+
Data: WrappedCopy<Base = T>,
80+
{
81+
type Base = T;
82+
83+
#[inline]
84+
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self {
85+
Self {
86+
requires_grad: self.requires_grad,
87+
data: self.data.wrapped_copy(to_wrap),
88+
_pd: PhantomData,
89+
}
90+
}
91+
}
92+
7793
impl<Data, T> ShallowCopy for ReqGradWrapper<Data, T>
7894
where
7995
Data: ShallowCopy,

‎src/modules/lazy/wrapper.rs

+21-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ use core::{
66
ops::{Deref, DerefMut},
77
};
88

9-
use crate::{flag::AllocFlag, HasId, HostPtr, Lazy, PtrType, ShallowCopy, WrappedData};
9+
use crate::{
10+
flag::AllocFlag, HasId, HostPtr, Lazy, PtrType, ShallowCopy, WrappedCopy, WrappedData,
11+
};
1012

1113
#[derive(Debug, Default)]
1214
pub struct LazyWrapper<Data, T> {
@@ -102,6 +104,24 @@ impl<T, Data: HostPtr<T>> HostPtr<T> for LazyWrapper<Data, T> {
102104
}
103105
}
104106

107+
impl<Data, T> WrappedCopy for LazyWrapper<Data, T>
108+
where
109+
Data: WrappedCopy<Base = T>,
110+
{
111+
type Base = T;
112+
113+
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self {
114+
LazyWrapper {
115+
maybe_data: match &self.maybe_data {
116+
MaybeData::Data(data) => MaybeData::Data(data.wrapped_copy(to_wrap)),
117+
MaybeData::Id(id) => MaybeData::Id(*id),
118+
MaybeData::None => unimplemented!(),
119+
},
120+
_pd: PhantomData,
121+
}
122+
}
123+
}
124+
105125
impl<Data: ShallowCopy, T> ShallowCopy for LazyWrapper<Data, T> {
106126
#[inline]
107127
unsafe fn shallow(&self) -> Self {

0 commit comments

Comments
 (0)
Please sign in to comment.