Skip to content

Commit

Permalink
Remove OnDropBuffer 2 (ocl)
Browse files Browse the repository at this point in the history
  • Loading branch information
elftausend committed Nov 24, 2024
1 parent 292143b commit 0ea7a75
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 70 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ min-cl = { git = "https://github.com/elftausend/min-cl", optional = true }
[features]
# default = ["cpu", "blas", "static-api", "macro", "cached", "autograd", "stack", "opencl", "fork", "graph", "untyped"]

default = ["cpu", "cached", "autograd", "static-api", "blas", "macro", "fork"]
default = ["cpu", "cached", "autograd", "static-api", "blas", "macro", "fork", "opencl"]
# default = ["no-std"]
# default = ["opencl"]
# default = ["untyped", "cpu", "lazy", "graph", "autograd", "fork", "serde", "json", "half", "cached", "static-api", "stack", "opencl", "nnapi"]
Expand Down
4 changes: 2 additions & 2 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ impl<'a, Mods: WrappedData, T: Unit, S: Shape> Buffer<'a, T, CPU<Mods>, S> {
}

#[cfg(feature = "opencl")]
impl<'a, Mods: OnDropBuffer, T: Unit, S: Shape> Buffer<'a, T, crate::OpenCL<Mods>, S> {
impl<'a, Mods: WrappedData, T: Unit, S: Shape> Buffer<'a, T, crate::OpenCL<Mods>, S> {
/// Returns the OpenCL pointer of the `Buffer`.
#[inline]
pub fn cl_ptr(&self) -> *mut core::ffi::c_void {
Expand All @@ -561,7 +561,7 @@ impl<'a, Mods: OnDropBuffer, T: Unit, S: Shape> Buffer<'a, T, crate::OpenCL<Mods
}

#[cfg(feature = "cuda")]
impl<'a, Mods: OnDropBuffer, T: Unit> Buffer<'a, T, crate::CUDA<Mods>> {
impl<'a, Mods: WrappedData, T: Unit> Buffer<'a, T, crate::CUDA<Mods>> {
/// Returns a non null CUDA pointer
#[inline]
pub fn cu_ptr(&self) -> u64 {
Expand Down
7 changes: 7 additions & 0 deletions src/cache/owned_cache/fast_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ pub struct FastCache {
pub nodes: LockedMap<UniqueId, Box<dyn Any>, BuildHasherDefault<NoHasher>>,
}

impl FastCache {
#[inline]
pub fn new() -> Self {
FastCache::default()
}
}

impl Cache<Box<dyn Any>> for FastCache {
#[inline]
fn get_mut(&self, id: UniqueId, _len: usize) -> State<RefMut<Box<dyn Any>>> {
Expand Down
37 changes: 21 additions & 16 deletions src/devices/opencl/cl_device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{flag::AllocFlag, opencl::KernelLaunch};
use crate::{impl_device_traits, Shape, Unit};
use crate::{
pass_down_use_gpu_or_cpu, Alloc, Base, Buffer, Cached, CachedCPU, CloneBuf, Device,
IsShapeIndep, Module, OnDropBuffer, OnNewBuffer, Setup, WrappedData, CPU,
IsShapeIndep, Module, OnNewBuffer, Setup, WrappedData, CPU,
};

use core::ops::{Deref, DerefMut};
Expand Down Expand Up @@ -165,8 +165,8 @@ impl<Mods> OpenCL<Mods> {
}
}

impl<Mods: OnDropBuffer> Device for OpenCL<Mods> {
type Data<T: Unit, S: Shape> = Self::Wrap<T, Self::Base<T, S>>;
impl<Mods: WrappedData> Device for OpenCL<Mods> {
type Data<'a, T: Unit, S: Shape> = Self::Wrap<'a, T, Self::Base<T, S>>;
type Base<U: Unit, S: Shape> = CLPtr<U>;
type Error = ();

Expand All @@ -175,34 +175,39 @@ impl<Mods: OnDropBuffer> Device for OpenCL<Mods> {
// OpenCL::<Base>::new(chosen_cl_idx())
}
#[inline(always)]
fn base_to_data<T: Unit, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<T, S> {
fn default_base_to_data<'a, T: Unit, S: Shape>(&'a self, base: Self::Base<T, S>) -> Self::Data<'a, T, S> {
self.wrap_in_base(base)
}

#[inline(always)]
fn wrap_to_data<T: Unit, S: Shape>(
fn default_base_to_data_unbound<'a, T: Unit, S: Shape>(&self, base: Self::Base<T, S>) -> Self::Data<'a, T, S> {
self.wrap_in_base_unbound(base)
}

#[inline(always)]
fn wrap_to_data<'a, T: Unit, S: Shape>(
&self,
wrap: Self::Wrap<T, Self::Base<T, S>>,
) -> Self::Data<T, S> {
wrap: Self::Wrap<'a, T, Self::Base<T, S>>,
) -> Self::Data<'a, T, S> {
wrap
}

#[inline(always)]
fn data_as_wrap<T: Unit, S: Shape>(
data: &Self::Data<T, S>,
) -> &Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap<'a, 'b, T: Unit, S: Shape>(
data: &'b Self::Data<'a, T, S>,
) -> &'b Self::Wrap<'a, T, Self::Base<T, S>> {
data
}

#[inline(always)]
fn data_as_wrap_mut<T: Unit, S: Shape>(
data: &mut Self::Data<T, S>,
) -> &mut Self::Wrap<T, Self::Base<T, S>> {
fn data_as_wrap_mut<'a, 'b, T: Unit, S: Shape>(
data: &'b mut Self::Data<'a, T, S>,
) -> &'b mut Self::Wrap<'a, T, Self::Base<T, S>> {
data
}
}

unsafe impl<Mods: OnDropBuffer> IsShapeIndep for OpenCL<Mods> {}
unsafe impl<Mods: WrappedData> IsShapeIndep for OpenCL<Mods> {}

impl<Mods> Debug for OpenCL<Mods> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Expand All @@ -222,7 +227,7 @@ impl<Mods> Debug for OpenCL<Mods> {
}
}

impl<Mods: OnDropBuffer, T: Unit> Alloc<T> for OpenCL<Mods> {
impl<Mods: WrappedData, T: Unit> Alloc<T> for OpenCL<Mods> {
fn alloc<S: Shape>(&self, mut len: usize, flag: AllocFlag) -> crate::Result<CLPtr<T>> {
if S::LEN > len {
len = S::LEN
Expand Down Expand Up @@ -270,7 +275,7 @@ impl<Mods: OnDropBuffer, T: Unit> Alloc<T> for OpenCL<Mods> {
}
}

impl<'a, T: Unit, Mods: OnDropBuffer + OnNewBuffer<'a, T, Self, ()>> CloneBuf<'a, T>
impl<'a, T: Unit, Mods: WrappedData + OnNewBuffer<'a, T, Self, ()>> CloneBuf<'a, T>
for OpenCL<Mods>
{
fn clone_buf(&'a self, buf: &Buffer<'a, T, Self>) -> Buffer<'a, T, Self> {
Expand Down
11 changes: 1 addition & 10 deletions src/devices/opencl/cl_ptr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::HostPtr;

use min_cl::api::release_mem_object;

use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy, Unit, WrappedCopy};
use crate::{flag::AllocFlag, HasId, Id, PtrType, ShallowCopy, Unit};

/// The pointer used for `OpenCL` [`Buffer`](crate::Buffer)s
#[derive(Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -59,15 +59,6 @@ impl<T> CLPtr<T> {
}
}

impl<T> WrappedCopy for CLPtr<T> {
type Base = Self;

#[inline]
fn wrapped_copy(&self, to_wrap: Self::Base) -> Self {
to_wrap
}
}

impl<T> ShallowCopy for CLPtr<T> {
#[inline]
unsafe fn shallow(&self) -> Self {
Expand Down
4 changes: 2 additions & 2 deletions src/devices/opencl/fusing.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{OnDropBuffer, OpenCL, UnaryFusing};
use crate::{WrappedData, OpenCL, UnaryFusing};

impl<Mods: OnDropBuffer> UnaryFusing for OpenCL<Mods> {
impl<Mods: WrappedData> UnaryFusing for OpenCL<Mods> {
#[cfg(feature = "lazy")]
#[cfg(feature = "graph")]
#[inline]
Expand Down
6 changes: 3 additions & 3 deletions src/devices/opencl/kernel_enqueue.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{number::Number, Buffer, OnDropBuffer, OpenCL, Shape, Unit};
use crate::{number::Number, Buffer, OpenCL, Shape, Unit, WrappedData};
use min_cl::{
api::{set_kernel_arg, OCLErrorKind},
CLDevice,
Expand Down Expand Up @@ -77,14 +77,14 @@ pub trait AsClCvoidPtr {
}
}

impl<'a, Mods: OnDropBuffer, T: Unit, S: Shape> AsClCvoidPtr for &Buffer<'a, T, OpenCL<Mods>, S> {
impl<'a, Mods: WrappedData, T: Unit, S: Shape> AsClCvoidPtr for &Buffer<'a, T, OpenCL<Mods>, S> {
#[inline]
fn as_cvoid_ptr(&self) -> *const c_void {
self.base().ptr
}
}

impl<'a, Mods: OnDropBuffer, T: Unit, S: Shape> AsClCvoidPtr for Buffer<'a, T, OpenCL<Mods>, S> {
impl<'a, Mods: WrappedData, T: Unit, S: Shape> AsClCvoidPtr for Buffer<'a, T, OpenCL<Mods>, S> {
#[inline]
fn as_cvoid_ptr(&self) -> *const c_void {
self.base().ptr
Expand Down
25 changes: 11 additions & 14 deletions src/devices/opencl/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@ use min_cl::{
};

use crate::{
bounds_to_range, cpu_stack_ops::clear_slice, location, op_hint::unary, pass_down_add_operation,
pass_down_exec_now, prelude::Number, AddOperation, ApplyFunction, Buffer, CDatatype, ClearBuf,
CopySlice, OnDropBuffer, OpenCL, Read, Resolve, Retrieve, Retriever, SetOpHint, Shape,
ToCLSource, ToMarker, TwoWay, UnaryGrad, Unit, UseGpuOrCpu, WriteBuf, ZeroGrad,
bounds_to_range, cpu_stack_ops::clear_slice, location, op_hint::unary, pass_down_add_operation, pass_down_exec_now, prelude::Number, AddOperation, ApplyFunction, Buffer, CDatatype, ClearBuf, CopySlice, OpenCL, Read, Resolve, Retrieve, Retriever, SetOpHint, Shape, ToCLSource, ToMarker, TwoWay, UnaryGrad, Unit, UseGpuOrCpu, WrappedData, WriteBuf, ZeroGrad
};

use super::{enqueue_kernel, CLPtr};
Expand All @@ -24,7 +21,7 @@ use super::{enqueue_kernel, CLPtr};
pass_down_add_operation!(OpenCL);
pass_down_exec_now!(OpenCL);

impl<Mods: OnDropBuffer + UseGpuOrCpu, T: CDatatype + Default> ClearBuf<T> for OpenCL<Mods> {
impl<Mods: WrappedData + UseGpuOrCpu, T: CDatatype + Default> ClearBuf<T> for OpenCL<Mods> {
#[inline]
fn clear(&self, buf: &mut Buffer<T, OpenCL<Mods>>) {
/*crate::fork!(
Expand All @@ -49,7 +46,7 @@ impl<Mods: OnDropBuffer + UseGpuOrCpu, T: CDatatype + Default> ClearBuf<T> for O
}
}

impl<Mods: OnDropBuffer, T: CDatatype> ZeroGrad<T> for OpenCL<Mods> {
impl<Mods: WrappedData, T: CDatatype> ZeroGrad<T> for OpenCL<Mods> {
#[inline]
fn zero_grad<S: Shape>(&self, data: &mut Self::Base<T, S>) {
try_cl_clear(self, data).unwrap()
Expand Down Expand Up @@ -90,7 +87,7 @@ pub fn try_cl_clear<T: CDatatype>(device: &CLDevice, lhs: &mut CLPtr<T>) -> crat
Ok(())
}

impl<T: Unit, S: Shape, Mods: OnDropBuffer> WriteBuf<T, S> for OpenCL<Mods> {
impl<T: Unit, S: Shape, Mods: WrappedData> WriteBuf<T, S> for OpenCL<Mods> {
#[inline]
fn write(&self, buf: &mut Buffer<T, Self, S>, data: &[T]) {
let event = unsafe { self.device.enqueue_write_buffer(buf.cl_ptr(), data, false) }.unwrap();
Expand Down Expand Up @@ -170,7 +167,7 @@ impl<T: Unit> CopySlice<T> for OpenCL {

impl<Mods, T, S> Read<T, S> for OpenCL<Mods>
where
Mods: OnDropBuffer + 'static,
Mods: WrappedData + 'static,
T: Unit + Clone + Default,
S: Shape,
{
Expand Down Expand Up @@ -221,18 +218,18 @@ fn try_read_cl_buf_to_vec<T: Clone + Default>(
Ok(read)
}

impl<T, S, Mods> ApplyFunction<T, S> for OpenCL<Mods>
impl<'a, T, S, Mods> ApplyFunction<'a, T, S> for OpenCL<Mods>
where
T: CDatatype + Number,
S: Shape,
Mods: AddOperation + Retrieve<Self, T, S> + UseGpuOrCpu + SetOpHint<T> + 'static,
Mods: AddOperation + Retrieve<'a, Self, T, S> + UseGpuOrCpu + SetOpHint<T> + 'static,
{
#[inline]
fn apply_fn<F>(
&self,
&'a self,
buf: &Buffer<T, Self, S>,
f: impl Fn(Resolve<T>) -> F + Copy + 'static,
) -> Buffer<T, Self, S>
) -> Buffer<'a, T, Self, S>
where
F: TwoWay<T>,
{
Expand Down Expand Up @@ -301,7 +298,7 @@ where
Ok(())
}

impl<T, S, Mods: OnDropBuffer + AddOperation + 'static> UnaryGrad<T, S> for OpenCL<Mods>
impl<T, S, Mods: WrappedData + AddOperation + 'static> UnaryGrad<T, S> for OpenCL<Mods>
where
T: CDatatype + Number,
S: Shape,
Expand All @@ -325,7 +322,7 @@ where

/// A failable OpenCL version of [`add_unary_grad`](UnaryGrad::add_unary_grad).
/// Writes the unary gradient (with chainrule) to the lhs_grad [`Buffer`].
pub fn try_cl_add_unary_grad<T, S, F, Mods: OnDropBuffer>(
pub fn try_cl_add_unary_grad<T, S, F, Mods: WrappedData>(
device: &OpenCL<Mods>,
lhs: &Buffer<T, OpenCL<Mods>, S>,
lhs_grad: &mut Buffer<T, OpenCL<Mods>, S>,
Expand Down
36 changes: 19 additions & 17 deletions src/devices/opencl/unified.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
use core::{any::Any, hash::BuildHasherDefault};
use std::{collections::HashMap, ffi::c_void, sync::Arc};

use crate::{AllocFlag, DeviceError, Unit};
use crate::{AllocFlag, Cache, DeviceError, Unit};

use super::CLPtr;
use crate::{
Base, Buffer, CachedCPU, CachedModule, Cursor, Device, OnDropBuffer, OpenCL, Shape,
Base, Buffer, CachedCPU, CachedModule, Cursor, Device, WrappedData, OpenCL, Shape,
UnifiedMemChain, UniqueId, CPU,
};
use min_cl::api::{create_buffer, MemFlags};

impl<Mods: UnifiedMemChain<Self> + OnDropBuffer> UnifiedMemChain<Self> for OpenCL<Mods> {
impl<Mods: UnifiedMemChain<Self> + WrappedData> UnifiedMemChain<Self> for OpenCL<Mods> {
#[inline]
fn construct_unified_buf_from_cpu_buf<'a, T: Unit + 'static, S: Shape>(
&self,
Expand All @@ -22,11 +22,12 @@ impl<Mods: UnifiedMemChain<Self> + OnDropBuffer> UnifiedMemChain<Self> for OpenC
}
}

impl<Mods, OclMods, SimpleMods> UnifiedMemChain<OpenCL<OclMods>>
for CachedModule<Mods, OpenCL<SimpleMods>>
impl<Mods, CacheType, OclMods, SimpleMods> UnifiedMemChain<OpenCL<OclMods>>
for CachedModule<Mods, OpenCL<SimpleMods>, CacheType>
where
OclMods: Cursor + OnDropBuffer,
SimpleMods: OnDropBuffer,
CacheType: Cache<Box<dyn Any>>,
OclMods: Cursor + WrappedData,
SimpleMods: WrappedData,
{
#[inline]
fn construct_unified_buf_from_cpu_buf<'a, T: Unit + 'static, S: Shape>(
Expand All @@ -37,7 +38,7 @@ where
construct_buffer(
device,
no_drop_buf,
&mut self.cache.borrow_mut().nodes,
&self.cache
device.cursor() as UniqueId,
)
}
Expand Down Expand Up @@ -65,8 +66,8 @@ pub unsafe fn to_cached_unified<OclMods, CpuMods, T, S>(
id: crate::UniqueId,
) -> crate::Result<*mut c_void>
where
OclMods: OnDropBuffer,
CpuMods: OnDropBuffer,
OclMods: WrappedData,
CpuMods: WrappedData,
T: Unit + 'static,
S: Shape,
{
Expand Down Expand Up @@ -117,15 +118,16 @@ where
/// Ok(())
/// }
/// ```
pub fn construct_buffer<'a, OclMods, CpuMods, T, S>(
pub fn construct_buffer<'a, OclMods, CpuMods, T, S, CacheType>(
device: &'a OpenCL<OclMods>,
no_drop: Buffer<'a, T, CPU<CpuMods>, S>,
cache: &mut HashMap<crate::UniqueId, Arc<dyn Any>, BuildHasherDefault<crate::NoHasher>>,
cache: &CacheType,
id: crate::UniqueId,
) -> crate::Result<Buffer<'a, T, OpenCL<OclMods>, S>>
where
OclMods: Cursor + OnDropBuffer,
CpuMods: OnDropBuffer,
OclMods: Cursor + WrappedData,
CacheType: Cache<Box<dyn Any>>,
CpuMods: WrappedData,
T: Unit + 'static,
S: Shape,
{
Expand All @@ -138,11 +140,11 @@ where
unsafe { device.bump_cursor() };

// if buffer was already converted, return the cache entry.
if let Some(rawcl) = cache.get(&id) {
if let Some(rawcl) = cache.get(id, no_drop.len) {
let rawcl = rawcl
.downcast_ref::<<OpenCL<OclMods> as Device>::Base<T, S>>()
.unwrap();
let data = device.base_to_data::<T, S>(CLPtr {
let data = device.default_base_to_data::<T, S>(CLPtr {
ptr: rawcl.ptr,
host_ptr: rawcl.host_ptr,
len: no_drop.len(),
Expand All @@ -156,7 +158,7 @@ where
let (host_ptr, len) = (no_drop.base().ptr, no_drop.len());
let ptr = unsafe { to_cached_unified(device, no_drop, cache, id)? };

let data = device.base_to_data::<T, S>(CLPtr {
let data = device.default_base_to_data::<T, S>(CLPtr {
ptr,
host_ptr,
len,
Expand Down
Loading

0 comments on commit 0ea7a75

Please sign in to comment.