Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/native/kernels/elementwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,42 @@ void log_f32(float* out, const float* a, int n) {
if (i < n) out[i] = logf(a[i]);
}

extern "C" __global__
void sin_f32(float* out, const float* a, int n) {
    // Elementwise sine: out[i] = sin(a[i]) for i in [0, n).
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;  // threads past the tail do nothing
    out[idx] = sinf(a[idx]);
}

extern "C" __global__
void sin_backward_f32(float* dx, const float* dy, const float* x, int n) {
    // Backward for sine: d/dx sin(x) = cos(x), so dx[i] = dy[i] * cos(x[i]).
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;
    dx[idx] = dy[idx] * cosf(x[idx]);
}

extern "C" __global__
void cos_f32(float* out, const float* a, int n) {
    // Elementwise cosine: out[i] = cos(a[i]) for i in [0, n).
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;
    out[idx] = cosf(a[idx]);
}

extern "C" __global__
void cos_backward_f32(float* dx, const float* dy, const float* x, int n) {
    // Backward for cosine: d/dx cos(x) = -sin(x), so dx[i] = -dy[i] * sin(x[i]).
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;
    dx[idx] = -dy[idx] * sinf(x[idx]);
}

extern "C" __global__
void sqrt_f32(float* out, const float* a, int n) {
    // Elementwise sqrt, defined as sqrt(max(x, 0)) by design: negative
    // inputs are clamped to 0 (their gradient is masked in the backward pass).
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;
    out[idx] = sqrtf(fmaxf(a[idx], 0.0f));
}

extern "C" __global__
void sqrt_backward_f32(float* dx, const float* dy, const float* x, int n) {
    // Backward for sqrt(max(x, 0)): dy * 0.5 / sqrt(x) for x > 0; the
    // gradient is masked to 0 where the forward pass clamped the input.
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n) return;
    const float xi = x[idx];
    dx[idx] = (xi > 0.0f) ? (dy[idx] * 0.5f / sqrtf(xi)) : 0.0f;
}

extern "C" __global__
void add_bias_f32(float* out, const float* a, const float* bias, int total, int bias_size) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
Expand Down
43 changes: 43 additions & 0 deletions src/native/shaders/elementwise.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,49 @@ fn log_f32(@builtin(global_invocation_id) gid: vec3u) {
if (i < params.n) { out[i] = log(a[i]); }
}

// Elementwise sine: out[i] = sin(a[i]) for i in [0, params.n).
@compute @workgroup_size(256)
fn sin_f32(@builtin(global_invocation_id) gid: vec3u) {
    let idx = gid.x;
    if (idx >= params.n) { return; }  // invocations past the tail do nothing
    out[idx] = sin(a[idx]);
}

// Backward for sine. Buffers: a = upstream gradient (dy), b = saved input (x).
// d/dx sin(x) = cos(x), so out[i] = dy[i] * cos(x[i]).
@compute @workgroup_size(256)
fn sin_backward_f32(@builtin(global_invocation_id) gid: vec3u) {
    let idx = gid.x;
    if (idx >= params.n) { return; }
    out[idx] = a[idx] * cos(b[idx]);
}

// Elementwise cosine: out[i] = cos(a[i]) for i in [0, params.n).
@compute @workgroup_size(256)
fn cos_f32(@builtin(global_invocation_id) gid: vec3u) {
    let idx = gid.x;
    if (idx >= params.n) { return; }
    out[idx] = cos(a[idx]);
}

// Backward for cosine. Buffers: a = upstream gradient (dy), b = saved input (x).
// d/dx cos(x) = -sin(x), so out[i] = -dy[i] * sin(x[i]).
@compute @workgroup_size(256)
fn cos_backward_f32(@builtin(global_invocation_id) gid: vec3u) {
    let idx = gid.x;
    if (idx >= params.n) { return; }
    out[idx] = -a[idx] * sin(b[idx]);
}

// Elementwise sqrt, defined as sqrt(max(x, 0)) by design: negative inputs
// are clamped to 0 (their gradient is masked in the backward pass).
@compute @workgroup_size(256)
fn sqrt_f32(@builtin(global_invocation_id) gid: vec3u) {
    let idx = gid.x;
    if (idx >= params.n) { return; }
    out[idx] = sqrt(max(a[idx], 0.0));
}

// Backward for sqrt(max(x, 0)). Buffers: a = upstream gradient (dy),
// b = saved input (x). Gradient is dy * 0.5 / sqrt(x) for x > 0, else 0
// (masked where the forward pass clamped the input).
@compute @workgroup_size(256)
fn sqrt_backward_f32(@builtin(global_invocation_id) gid: vec3u) {
    let idx = gid.x;
    if (idx >= params.n) { return; }
    let x = b[idx];
    out[idx] = select(0.0, a[idx] * 0.5 / sqrt(x), x > 0.0);
}

@compute @workgroup_size(256)
fn div_f32(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
Expand Down
6 changes: 6 additions & 0 deletions src/native/src/autograd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ pub enum BackwardOp {
MulScalar,
Exp,
Log,
Sin,
Cos,
Sqrt,
MatMul,
Gelu,
Relu,
Expand Down Expand Up @@ -192,6 +195,9 @@ fn dispatch_backward(
BackwardOp::MulScalar => ops::elementwise::mul_scalar_backward(grad_id, &entry.saved, store),
BackwardOp::Exp => ops::elementwise::exp_backward(grad_id, &entry.saved, store),
BackwardOp::Log => ops::elementwise::log_backward(grad_id, &entry.saved, store),
BackwardOp::Sin => ops::elementwise::sin_backward(grad_id, &entry.saved, store),
BackwardOp::Cos => ops::elementwise::cos_backward(grad_id, &entry.saved, store),
BackwardOp::Sqrt => ops::elementwise::sqrt_backward(grad_id, &entry.saved, store),
BackwardOp::MatMul => ops::matmul::matmul_backward(grad_id, &entry.saved, store),
BackwardOp::Gelu => ops::activation::gelu_backward(grad_id, &entry.saved, store),
BackwardOp::Relu => ops::activation::relu_backward(grad_id, &entry.saved, store),
Expand Down
21 changes: 21 additions & 0 deletions src/native/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,27 @@ pub fn log_op(a: u32) -> u32 {
ops::elementwise::log(a as TensorId, store, tape) as u32
}

/// N-API entry point: elementwise sine of tensor `a`; returns the id of
/// the freshly created output tensor.
#[napi]
pub fn sin_op(a: u32) -> u32 {
    let mut guard = engine().lock();
    let Engine { store, tape, .. } = &mut *guard;
    ops::elementwise::sin(a as TensorId, store, tape) as u32
}

/// N-API entry point: elementwise cosine of tensor `a`; returns the id of
/// the freshly created output tensor.
#[napi]
pub fn cos_op(a: u32) -> u32 {
    let mut guard = engine().lock();
    let Engine { store, tape, .. } = &mut *guard;
    ops::elementwise::cos(a as TensorId, store, tape) as u32
}

/// N-API entry point: elementwise square root of tensor `a` (defined as
/// `sqrt(max(x, 0))`); returns the id of the freshly created output tensor.
#[napi]
pub fn sqrt_op(a: u32) -> u32 {
    let mut guard = engine().lock();
    let Engine { store, tape, .. } = &mut *guard;
    ops::elementwise::sqrt(a as TensorId, store, tape) as u32
}

#[napi]
pub fn sum_op(a: u32, dim: i64) -> u32 {
let mut e = engine().lock();
Expand Down
239 changes: 239 additions & 0 deletions src/native/src/ops/elementwise.rs
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,245 @@ pub fn log_backward(grad: TensorId, saved: &SavedContext, store: &mut TensorStor
} else { vec![None] }
}

// =========================================================================
// sin
// =========================================================================

/// Elementwise sine (CPU/WebGPU path): computes `sin` on the host, stores
/// the result as a new tensor, and records the op on the tape, saving the
/// input tensor for the backward pass.
#[cfg(any(feature = "cpu", feature = "webgpu"))]
pub fn sin(a: TensorId, store: &mut TensorStore, tape: &mut Tape) -> TensorId {
    let shape = store.shape(a).to_vec();
    let host = store.to_host(a);
    let values: Vec<f32> = host.iter().map(|v| v.sin()).collect();
    let out = store.from_vec(values, &shape);
    tape.record(TapeEntry {
        op: BackwardOp::Sin, output_id: out, input_ids: smallvec![a],
        saved: SavedContext::Tensor(a),
    });
    out
}

/// Elementwise sine (CUDA path): launches the `sin_f32` kernel over the
/// tensor and records the op on the tape, saving the input for backward.
#[cfg(feature = "cuda")]
pub fn sin(a: TensorId, store: &mut TensorStore, tape: &mut Tape) -> TensorId {
    let shape = store.shape(a).to_vec();
    let n = shape_size(&shape);
    let out = store.zeros(&shape);
    let out_ptr = store.dev_ptr(out);
    let a_ptr = store.dev_ptr(a);
    let dev = GpuDevice::instance();
    let kernel = dev.get_func("sin_f32");
    // SAFETY: both device pointers refer to store-owned allocations of `n`
    // f32 elements, and the kernel bounds-checks each thread index against n.
    unsafe {
        dev.stream.launch_builder(kernel)
            .arg(&out_ptr)
            .arg(&a_ptr)
            .arg(&(n as i32))
            .launch(launch_cfg(n as u32))
            .unwrap();
    }
    tape.record(TapeEntry {
        op: BackwardOp::Sin, output_id: out, input_ids: smallvec![a],
        saved: SavedContext::Tensor(a),
    });
    out
}

// =========================================================================
// sin_backward
// =========================================================================

/// Backward for sine (CPU/WebGPU path): d/dx sin(x) = cos(x), so
/// dx = dy * cos(x). Returns one gradient slot per input.
#[cfg(any(feature = "cpu", feature = "webgpu"))]
pub fn sin_backward(grad: TensorId, saved: &SavedContext, store: &mut TensorStore) -> Vec<Option<TensorId>> {
    let SavedContext::Tensor(inp) = saved else { return vec![None] };
    let shape = store.shape(grad).to_vec();
    let x = store.to_host(*inp);
    let dy = store.to_host(grad);
    let dx: Vec<f32> = dy.iter().zip(&x).map(|(g, v)| g * v.cos()).collect();
    vec![Some(store.from_vec(dx, &shape))]
}

/// Backward for sine (CUDA path): launches `sin_backward_f32`, which
/// computes dx = dy * cos(x) from the saved input tensor.
#[cfg(feature = "cuda")]
pub fn sin_backward(grad: TensorId, saved: &SavedContext, store: &mut TensorStore) -> Vec<Option<TensorId>> {
    let SavedContext::Tensor(inp) = saved else { return vec![None] };
    let shape = store.shape(grad).to_vec();
    let n = shape_size(&shape);
    let result = store.zeros(&shape);
    let result_ptr = store.dev_ptr(result);
    let grad_ptr = store.dev_ptr(grad);
    let inp_ptr = store.dev_ptr(*inp);
    let dev = GpuDevice::instance();
    let kernel = dev.get_func("sin_backward_f32");
    // SAFETY: all three device pointers refer to store-owned allocations of
    // `n` f32 elements, and the kernel bounds-checks each thread index.
    unsafe {
        dev.stream.launch_builder(kernel)
            .arg(&result_ptr)
            .arg(&grad_ptr)
            .arg(&inp_ptr)
            .arg(&(n as i32))
            .launch(launch_cfg(n as u32))
            .unwrap();
    }
    vec![Some(result)]
}

// =========================================================================
// cos
// =========================================================================

/// Elementwise cosine (CPU/WebGPU path): computes `cos` on the host, stores
/// the result as a new tensor, and records the op on the tape, saving the
/// input tensor for the backward pass.
#[cfg(any(feature = "cpu", feature = "webgpu"))]
pub fn cos(a: TensorId, store: &mut TensorStore, tape: &mut Tape) -> TensorId {
    let shape = store.shape(a).to_vec();
    let host = store.to_host(a);
    let values: Vec<f32> = host.iter().map(|v| v.cos()).collect();
    let out = store.from_vec(values, &shape);
    tape.record(TapeEntry {
        op: BackwardOp::Cos, output_id: out, input_ids: smallvec![a],
        saved: SavedContext::Tensor(a),
    });
    out
}

/// Elementwise cosine (CUDA path): launches the `cos_f32` kernel over the
/// tensor and records the op on the tape, saving the input for backward.
#[cfg(feature = "cuda")]
pub fn cos(a: TensorId, store: &mut TensorStore, tape: &mut Tape) -> TensorId {
    let shape = store.shape(a).to_vec();
    let n = shape_size(&shape);
    let out = store.zeros(&shape);
    let out_ptr = store.dev_ptr(out);
    let a_ptr = store.dev_ptr(a);
    let dev = GpuDevice::instance();
    let kernel = dev.get_func("cos_f32");
    // SAFETY: both device pointers refer to store-owned allocations of `n`
    // f32 elements, and the kernel bounds-checks each thread index against n.
    unsafe {
        dev.stream.launch_builder(kernel)
            .arg(&out_ptr)
            .arg(&a_ptr)
            .arg(&(n as i32))
            .launch(launch_cfg(n as u32))
            .unwrap();
    }
    tape.record(TapeEntry {
        op: BackwardOp::Cos, output_id: out, input_ids: smallvec![a],
        saved: SavedContext::Tensor(a),
    });
    out
}

// =========================================================================
// cos_backward
// =========================================================================

/// Backward for cosine (CPU/WebGPU path): d/dx cos(x) = -sin(x), so
/// dx = -dy * sin(x). Returns one gradient slot per input.
#[cfg(any(feature = "cpu", feature = "webgpu"))]
pub fn cos_backward(grad: TensorId, saved: &SavedContext, store: &mut TensorStore) -> Vec<Option<TensorId>> {
    let SavedContext::Tensor(inp) = saved else { return vec![None] };
    let shape = store.shape(grad).to_vec();
    let x = store.to_host(*inp);
    let dy = store.to_host(grad);
    let dx: Vec<f32> = dy.iter().zip(&x).map(|(g, v)| -g * v.sin()).collect();
    vec![Some(store.from_vec(dx, &shape))]
}

/// Backward for cosine (CUDA path): launches `cos_backward_f32`, which
/// computes dx = -dy * sin(x) from the saved input tensor.
#[cfg(feature = "cuda")]
pub fn cos_backward(grad: TensorId, saved: &SavedContext, store: &mut TensorStore) -> Vec<Option<TensorId>> {
    let SavedContext::Tensor(inp) = saved else { return vec![None] };
    let shape = store.shape(grad).to_vec();
    let n = shape_size(&shape);
    let result = store.zeros(&shape);
    let result_ptr = store.dev_ptr(result);
    let grad_ptr = store.dev_ptr(grad);
    let inp_ptr = store.dev_ptr(*inp);
    let dev = GpuDevice::instance();
    let kernel = dev.get_func("cos_backward_f32");
    // SAFETY: all three device pointers refer to store-owned allocations of
    // `n` f32 elements, and the kernel bounds-checks each thread index.
    unsafe {
        dev.stream.launch_builder(kernel)
            .arg(&result_ptr)
            .arg(&grad_ptr)
            .arg(&inp_ptr)
            .arg(&(n as i32))
            .launch(launch_cfg(n as u32))
            .unwrap();
    }
    vec![Some(result)]
}

// =========================================================================
// sqrt — sqrt(max(x, 0)); gradient masked to 0 where input <= 0
// =========================================================================

/// Elementwise square root (CPU/WebGPU path), defined as `sqrt(max(x, 0))`
/// by design: negative inputs are clamped to 0 and their gradient is masked
/// to 0 in the backward pass. This matches the CUDA and WGSL kernels.
#[cfg(any(feature = "cpu", feature = "webgpu"))]
pub fn sqrt(a: TensorId, store: &mut TensorStore, tape: &mut Tape) -> TensorId {
    let shape = store.shape(a).to_vec();
    let host = store.to_host(a);
    let values: Vec<f32> = host.iter().map(|v| v.max(0.0).sqrt()).collect();
    let out = store.from_vec(values, &shape);
    tape.record(TapeEntry {
        op: BackwardOp::Sqrt, output_id: out, input_ids: smallvec![a],
        saved: SavedContext::Tensor(a),
    });
    out
}

/// Elementwise square root (CUDA path): launches the `sqrt_f32` kernel,
/// which computes `sqrt(max(x, 0))`, and records the op on the tape,
/// saving the input for backward.
#[cfg(feature = "cuda")]
pub fn sqrt(a: TensorId, store: &mut TensorStore, tape: &mut Tape) -> TensorId {
    let shape = store.shape(a).to_vec();
    let n = shape_size(&shape);
    let out = store.zeros(&shape);
    let out_ptr = store.dev_ptr(out);
    let a_ptr = store.dev_ptr(a);
    let dev = GpuDevice::instance();
    let kernel = dev.get_func("sqrt_f32");
    // SAFETY: both device pointers refer to store-owned allocations of `n`
    // f32 elements, and the kernel bounds-checks each thread index against n.
    unsafe {
        dev.stream.launch_builder(kernel)
            .arg(&out_ptr)
            .arg(&a_ptr)
            .arg(&(n as i32))
            .launch(launch_cfg(n as u32))
            .unwrap();
    }
    tape.record(TapeEntry {
        op: BackwardOp::Sqrt, output_id: out, input_ids: smallvec![a],
        saved: SavedContext::Tensor(a),
    });
    out
}

// =========================================================================
// sqrt_backward
// =========================================================================

/// Backward for sqrt (CPU/WebGPU path): dx = dy * 0.5 / sqrt(x) for x > 0,
/// and 0 otherwise (masked where the forward pass clamped the input).
#[cfg(any(feature = "cpu", feature = "webgpu"))]
pub fn sqrt_backward(grad: TensorId, saved: &SavedContext, store: &mut TensorStore) -> Vec<Option<TensorId>> {
    let SavedContext::Tensor(inp) = saved else { return vec![None] };
    let shape = store.shape(grad).to_vec();
    let x = store.to_host(*inp);
    let dy = store.to_host(grad);
    let dx: Vec<f32> = dy.iter().zip(&x)
        .map(|(g, v)| if *v > 0.0 { g * 0.5 / v.sqrt() } else { 0.0 })
        .collect();
    vec![Some(store.from_vec(dx, &shape))]
}

/// Backward for sqrt (CUDA path): launches `sqrt_backward_f32`, which
/// computes dx = dy * 0.5 / sqrt(x) for x > 0 and 0 otherwise.
#[cfg(feature = "cuda")]
pub fn sqrt_backward(grad: TensorId, saved: &SavedContext, store: &mut TensorStore) -> Vec<Option<TensorId>> {
    let SavedContext::Tensor(inp) = saved else { return vec![None] };
    let shape = store.shape(grad).to_vec();
    let n = shape_size(&shape);
    let result = store.zeros(&shape);
    let result_ptr = store.dev_ptr(result);
    let grad_ptr = store.dev_ptr(grad);
    let inp_ptr = store.dev_ptr(*inp);
    let dev = GpuDevice::instance();
    let kernel = dev.get_func("sqrt_backward_f32");
    // SAFETY: all three device pointers refer to store-owned allocations of
    // `n` f32 elements, and the kernel bounds-checks each thread index.
    unsafe {
        dev.stream.launch_builder(kernel)
            .arg(&result_ptr)
            .arg(&grad_ptr)
            .arg(&inp_ptr)
            .arg(&(n as i32))
            .launch(launch_cfg(n as u32))
            .unwrap();
    }
    vec![Some(result)]
}

// =========================================================================
// div
// =========================================================================
Expand Down
Loading
Loading