Trig Functions #40

Open
r-chong wants to merge 5 commits into staging from extensions/native-trig

@r-chong (Collaborator) commented Apr 13, 2026

Key changes:

  • Add trig functions (sin, cos) and sqrt natively

Reason for change:

  • sqrt is necessary for normalization in attention
  • sin and cos are used in flow matching transformers

Next Steps:

  • toy trig functions

Copilot AI left a comment

Pull request overview

This PR adds native implementations of common elementwise math ops (sin, cos, sqrt) to the Tensor API and autograd engine, with forward + backward support across CPU and CUDA (and WGSL shader entry points).

Changes:

  • Add Tensor.sin(), Tensor.cos(), and Tensor.sqrt() to the TypeScript API, backed by new N-API exports.
  • Implement forward/backward elementwise ops for sin/cos/sqrt in the Rust autograd engine, plus CUDA kernels and WGSL shader entry points.
  • Extend test coverage for both forward correctness and autograd gradients of the new ops.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| test/tensor.test.ts | Adds forward-value tests for sin, cos, and sqrt (including clamped-negative sqrt behavior). |
| test/autograd.test.ts | Adds gradient tests for sin, cos, and sqrt. |
| src/tensor.ts | Exposes sin(), cos(), sqrt() on the public Tensor class. |
| src/native/src/ops/elementwise.rs | Implements forward/backward logic for sin/cos/sqrt and wires them into the tape. |
| src/native/src/lib.rs | Exposes sin_op, cos_op, sqrt_op via N-API for JS/TS. |
| src/native/src/autograd.rs | Adds new BackwardOp variants and dispatch cases for sin/cos/sqrt. |
| src/native/shaders/elementwise.wgsl | Adds WGSL compute entry points for sin/cos/sqrt forward + backward. |
| src/native/kernels/elementwise.cu | Adds CUDA kernels for sin/cos/sqrt forward + backward. |
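
For readers who want the gradient rules behind the sin/cos backward ops spelled out, here is a minimal CPU-only sketch. The function names and slice-based signatures are illustrative assumptions and are not taken from this PR; the finite-difference helper is one common way a gradient test might check the analytic result.

```rust
/// Backward pass for y = sin(x): dL/dx = dL/dy * cos(x).
/// (Hypothetical helper; the PR's real ops work on TensorId values.)
fn sin_backward(x: &[f32], grad_out: &[f32]) -> Vec<f32> {
    x.iter().zip(grad_out).map(|(&xi, &g)| g * xi.cos()).collect()
}

/// Backward pass for y = cos(x): dL/dx = -dL/dy * sin(x).
fn cos_backward(x: &[f32], grad_out: &[f32]) -> Vec<f32> {
    x.iter().zip(grad_out).map(|(&xi, &g)| -g * xi.sin()).collect()
}

/// Central finite difference of `f` at `x`, used to sanity-check
/// an analytic gradient against a numerical estimate.
fn finite_diff(f: impl Fn(f32) -> f32, x: f32, eps: f32) -> f32 {
    (f(x + eps) - f(x - eps)) / (2.0 * eps)
}
```

With an upstream gradient of all ones, `sin_backward` should agree with `finite_diff(f32::sin, x, eps)` to within a small tolerance, and likewise for cos.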

@r-chong requested a review from MankyDanky on April 13, 2026, 07:03

select() evaluates both branches, so the previous form computed
sqrt(x) and 1/0 for x <= 0 before discarding the result. Replace
with an explicit if/else so sqrt(x) is only touched when x > 0.
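
A minimal Rust sketch of the branch-based form this commit message describes, written for the CPU rather than as the actual shader code. The name `sqrt_backward` and its slice-based signature are assumptions made for illustration.

```rust
/// Backward pass for y = sqrt(max(x, 0)), with the gradient
/// masked to 0 wherever the input is <= 0.
/// (Hypothetical helper; not the PR's actual implementation.)
fn sqrt_backward(x: &[f32], grad_out: &[f32]) -> Vec<f32> {
    x.iter()
        .zip(grad_out)
        .map(|(&xi, &g)| {
            if xi > 0.0 {
                // d/dx sqrt(x) = 1 / (2 * sqrt(x)); only evaluated for x > 0,
                // so sqrt of a negative and division by zero never happen.
                g / (2.0 * xi.sqrt())
            } else {
                0.0
            }
        })
        .collect()
}
```

Because the else branch returns a literal zero, no intermediate NaN or infinity is ever produced for non-positive inputs.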

Copilot AI left a comment

Pull request overview

This PR adds native trig-related elementwise ops to the Tensor API (sin, cos, sqrt) and wires them through the Rust backend with autograd support, plus associated CUDA/WGSL kernels and tests.

Changes:

  • Add Tensor.sin(), Tensor.cos(), and Tensor.sqrt() methods in the TypeScript API.
  • Implement forward + backward ops in Rust autograd, including CUDA kernels and WGSL shader entry points.
  • Add unit tests for forward correctness and autograd gradients for the new ops.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| test/tensor.test.ts | Adds forward-value tests for sin/cos/sqrt (including sqrt negative-input clamping expectation). |
| test/autograd.test.ts | Adds gradient tests for sin/cos/sqrt. |
| src/tensor.ts | Exposes sin(), cos(), sqrt() methods that call into native ops. |
| src/native/src/ops/elementwise.rs | Implements sin/cos/sqrt forward + backward and records them on the tape. |
| src/native/src/lib.rs | Exposes new N-API entrypoints sin_op, cos_op, sqrt_op. |
| src/native/src/autograd.rs | Adds BackwardOp variants and dispatch to the new backward functions. |
| src/native/shaders/elementwise.wgsl | Adds WGSL compute entry points for sin/cos/sqrt and their backwards. |
| src/native/kernels/elementwise.cu | Adds CUDA kernels for sin/cos/sqrt and their backwards. |

Comment on lines +1073 to +1079
// sqrt — sqrt(max(x, 0)); gradient masked to 0 where input <= 0
// =========================================================================

#[cfg(any(feature = "cpu", feature = "webgpu"))]
pub fn sqrt(a: TensorId, store: &mut TensorStore, tape: &mut Tape) -> TensorId {
    let data: Vec<f32> = store.to_host(a).iter().map(|x| x.max(0.0).sqrt()).collect();
    let shape = store.shape(a).to_vec();
Collaborator Author

Intentional design choice, documented inline (sqrt(max(x, 0)) with gradient masked to 0 for x ≤ 0), and it's consistent across CPU/CUDA/WGSL.
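
To make the clamping behavior concrete, here is a small standalone sketch of the forward rule on a plain slice. The helper name `sqrt_clamped` is an assumption for illustration; only the `x.max(0.0).sqrt()` expression comes from the diff above.

```rust
/// Forward rule y = sqrt(max(x, 0)) applied elementwise.
/// (Hypothetical free function; the PR's version reads from a TensorStore.)
fn sqrt_clamped(x: &[f32]) -> Vec<f32> {
    // Negative inputs are clamped to 0 first, so they yield 0 instead of NaN.
    // Note: Rust's f32::max returns the non-NaN operand, so on the CPU path
    // a NaN input would also become 0 under this expression.
    x.iter().map(|v| v.max(0.0).sqrt()).collect()
}
```

By contrast, calling `sqrt` directly on a negative `f32` returns NaN, which is the behavior the clamp is there to avoid.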

Comment on lines +932 to +946
let shape = store.shape(a).to_vec();
let n = shape_size(&shape);
let out = store.zeros(&shape);
let out_ptr = store.dev_ptr(out);
let a_ptr = store.dev_ptr(a);
let dev = GpuDevice::instance();
let func = dev.get_func("sin_f32");
unsafe {
    dev.stream.launch_builder(func)
        .arg(&out_ptr)
        .arg(&a_ptr)
        .arg(&(n as i32))
        .launch(launch_cfg(n as u32))
        .unwrap();
}
Collaborator Author

Valid - Fixed in 3110588.

The kernels exist in elementwise.cu and elementwise.wgsl but weren't
in device.rs's compile_and_load / compile_shader symbol lists, so
the first call to get_func would panic on GPU builds.
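
One way to catch this class of bug before it reaches a GPU build is a check that every kernel name the ops look up is also present in the loader's symbol list. The sketch below is a self-contained illustration: `KernelRegistry` is a hypothetical stand-in for the symbol lists in device.rs, and every kernel name other than `sin_f32` is assumed rather than taken from the PR.

```rust
use std::collections::HashSet;

/// Hypothetical stand-in for the symbol list handed to the kernel loader.
struct KernelRegistry {
    loaded: HashSet<String>,
}

impl KernelRegistry {
    /// Build a registry from the symbol names that were compiled and loaded.
    fn new(symbols: &[&str]) -> Self {
        Self {
            loaded: symbols.iter().map(|s| s.to_string()).collect(),
        }
    }

    /// Return the names in `required` that were never registered,
    /// preserving the order in which they were requested.
    fn missing<'a>(&self, required: &[&'a str]) -> Vec<&'a str> {
        required
            .iter()
            .copied()
            .filter(|name| !self.loaded.contains(*name))
            .collect()
    }
}
```

A unit test could then assert that `missing(...)` is empty for the full list of kernel names used by the elementwise ops, turning a first-call panic into a failure that shows up without GPU hardware.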