Conversation
Pull request overview
This PR adds native implementations of common transcendental ops (sin, cos, sqrt) to the Tensor API and autograd engine, with forward + backward support across CPU and CUDA (and WGSL shader entry points).
Changes:
- Add `Tensor.sin()`, `Tensor.cos()`, and `Tensor.sqrt()` to the TypeScript API, backed by new N-API exports.
- Implement forward/backward elementwise ops for sin/cos/sqrt in the Rust autograd engine, plus CUDA kernels and WGSL shader entry points.
- Extend test coverage for both forward correctness and autograd gradients of the new ops.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| test/tensor.test.ts | Adds forward-value tests for sin, cos, and sqrt (including clamped-negative sqrt behavior). |
| test/autograd.test.ts | Adds gradient tests for sin, cos, and sqrt. |
| src/tensor.ts | Exposes sin(), cos(), sqrt() on the public Tensor class. |
| src/native/src/ops/elementwise.rs | Implements forward/backward logic for sin/cos/sqrt and wires them into the tape. |
| src/native/src/lib.rs | Exposes sin_op, cos_op, sqrt_op via N-API for JS/TS. |
| src/native/src/autograd.rs | Adds new BackwardOp variants and dispatch cases for sin/cos/sqrt. |
| src/native/shaders/elementwise.wgsl | Adds WGSL compute entry points for sin/cos/sqrt forward + backward. |
| src/native/kernels/elementwise.cu | Adds CUDA kernels for sin/cos/sqrt forward + backward. |
In WGSL, `select()` evaluates both value operands, so the previous form computed `sqrt(x)` and `1/0` for x <= 0 before discarding the unused result. Replace it with an explicit `if`/`else` so `sqrt(x)` is only evaluated when x > 0.
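A minimal sketch of the guarded form in WGSL (variable names are assumed; the actual shader code may differ):

```wgsl
// Before: select() evaluates both value operands, so the division
// runs for every x, including x <= 0.
// let d = 0.5 * select(0.0, 1.0 / sqrt(x), x > 0.0);

// After: sqrt and the division are only reached when x > 0.
var d: f32 = 0.0;
if (x > 0.0) {
    d = 0.5 / sqrt(x);
}
```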
Pull request overview
This PR adds native trig-related elementwise ops to the Tensor API (sin, cos, sqrt) and wires them through the Rust backend with autograd support, plus associated CUDA/WGSL kernels and tests.
Changes:
- Add `Tensor.sin()`, `Tensor.cos()`, and `Tensor.sqrt()` methods in the TypeScript API.
- Implement forward + backward ops in Rust autograd, including CUDA kernels and WGSL shader entry points.
- Add unit tests for forward correctness and autograd gradients for the new ops.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| test/tensor.test.ts | Adds forward-value tests for sin/cos/sqrt (including sqrt negative-input clamping expectation). |
| test/autograd.test.ts | Adds gradient tests for sin/cos/sqrt. |
| src/tensor.ts | Exposes sin(), cos(), sqrt() methods that call into native ops. |
| src/native/src/ops/elementwise.rs | Implements sin/cos/sqrt forward + backward and records them on the tape. |
| src/native/src/lib.rs | Exposes new N-API entrypoints sin_op, cos_op, sqrt_op. |
| src/native/src/autograd.rs | Adds BackwardOp variants and dispatch to the new backward functions. |
| src/native/shaders/elementwise.wgsl | Adds WGSL compute entry points for sin/cos/sqrt and their backwards. |
| src/native/kernels/elementwise.cu | Adds CUDA kernels for sin/cos/sqrt and their backwards. |
```rust
// sqrt — sqrt(max(x, 0)); gradient masked to 0 where input <= 0
// =========================================================================

#[cfg(any(feature = "cpu", feature = "webgpu"))]
pub fn sqrt(a: TensorId, store: &mut TensorStore, tape: &mut Tape) -> TensorId {
    let data: Vec<f32> = store.to_host(a).iter().map(|x| x.max(0.0).sqrt()).collect();
    let shape = store.shape(a).to_vec();
```
This is an intentional design choice, documented inline (`sqrt(max(x, 0))` with the gradient masked to 0 for x ≤ 0), and it's consistent across CPU/CUDA/WGSL.
```rust
    let shape = store.shape(a).to_vec();
    let n = shape_size(&shape);
    let out = store.zeros(&shape);
    let out_ptr = store.dev_ptr(out);
    let a_ptr = store.dev_ptr(a);
    let dev = GpuDevice::instance();
    let func = dev.get_func("sin_f32");
    unsafe {
        dev.stream.launch_builder(func)
            .arg(&out_ptr)
            .arg(&a_ptr)
            .arg(&(n as i32))
            .launch(launch_cfg(n as u32))
            .unwrap();
    }
```
The kernels exist in `elementwise.cu` and `elementwise.wgsl`, but they weren't added to the `compile_and_load` / `compile_shader` symbol lists in `device.rs`, so the first call to `get_func` would panic on GPU builds.