Skip to content

Commit ff3fc1c

Browse files
authored
Merge pull request #37 from mni-ml/staging
Refactor/tests
2 parents c2367db + 0cad6ab commit ff3fc1c

File tree

13 files changed

+949
-862
lines changed

13 files changed

+949
-862
lines changed

.github/workflows/build.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,4 +270,4 @@ jobs:
270270
npx tsc
271271
272272
- name: Run tests
273-
run: node test/ops.test.js
273+
run: npm test

.github/workflows/publish.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ jobs:
249249
npx tsc
250250
251251
- name: Run tests
252-
run: node test/ops.test.js
252+
run: npm test
253253

254254
publish:
255255
needs: [test, build-cuda, build-cuda-windows]

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,5 @@ coverage/
4141

4242
# Temporary files
4343
tmp/
44-
temp/
44+
temp/
45+
pnpm-lock.yaml

package.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"gpu"
2222
],
2323
"engines": {
24-
"node": ">=20.0.0"
24+
"node": ">=22.18.0"
2525
},
2626
"exports": {
2727
".": {
@@ -58,6 +58,6 @@
5858
"build:native:cuda": "cd src/native && cargo build --release --no-default-features --features cuda && cp target/release/libmni_framework_native.so mni-framework-native.linux-x64-gnu.node",
5959
"build:native:webgpu": "cd src/native && cargo build --release --no-default-features --features webgpu && cp target/release/libmni_framework_native.dylib mni-framework-native.darwin-arm64.node",
6060
"build:all": "npm run build:native && npm run build",
61-
"test": "node test/ops.test.js"
61+
"test": "node --no-warnings --loader ./test/loader.js test/run.ts"
6262
}
6363
}

test/autograd.test.ts

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import {
2+
Tensor,
3+
} from '../dist/index.js';
4+
import { assert, assertClose, section } from './helpers.js';
5+
6+
// ============================================================
7+
// Autograd / backward
8+
// ============================================================
9+
10+
section('Autograd / backward');
11+
12+
// simple gradient: d/dx (x^2) at x=3 should be 6
13+
const xGrad = Tensor.fromFloat32(new Float32Array([3]), [1]).setRequiresGrad(true);
14+
const xSq = xGrad.pow(2);
15+
xSq.backward();
16+
const gradX = xGrad.grad;
17+
assert(gradX !== null, 'gradient exists after backward');
18+
assertClose(gradX!.toFloat32()[0], 6, 1e-3, 'd/dx(x^2) at x=3 = 6');
19+
20+
// gradient through mul
21+
const paramA = Tensor.fromFloat32(new Float32Array([1, 2, 3, 4]), [2, 2]).setRequiresGrad(true);
22+
const paramB = Tensor.fromFloat32(new Float32Array([5, 6, 7, 8]), [2, 2]).setRequiresGrad(true);
23+
const c = paramA.mul(paramB).sum(0).sum(0);
24+
c.backward();
25+
assert(paramA.grad !== null, 'paramA gradient exists');
26+
assert(paramB.grad !== null, 'paramB gradient exists');
27+
28+
const gradAData = paramA.grad!.toFloat32();
29+
assertClose(gradAData[0], 5, 1e-3, 'grad_a[0] = b[0]');
30+
assertClose(gradAData[1], 6, 1e-3, 'grad_a[1] = b[1]');
31+
32+
// gradient through add
33+
const addX = Tensor.fromFloat32(new Float32Array([2, 3]), [2]).setRequiresGrad(true);
34+
const addY = Tensor.fromFloat32(new Float32Array([4, 5]), [2]).setRequiresGrad(true);
35+
const addSum = addX.add(addY).sum();
36+
addSum.backward();
37+
assert(addX.grad !== null && addY.grad !== null, 'add gradients exist');
38+
assertClose(addX.grad!.toFloat32()[0], 1, 1e-3, 'd/dx(x+y).sum() = 1');
39+
assertClose(addY.grad!.toFloat32()[0], 1, 1e-3, 'd/dy(x+y).sum() = 1');
40+
41+
// gradient through matmul
42+
const mmX = Tensor.fromFloat32(new Float32Array([1, 0, 0, 1]), [2, 2]).setRequiresGrad(true);
43+
const mmY = Tensor.fromFloat32(new Float32Array([3, 4, 5, 6]), [2, 2]).setRequiresGrad(true);
44+
const mmOut = mmX.matmul(mmY).sum();
45+
mmOut.backward();
46+
assert(mmX.grad !== null, 'matmul grad exists');
47+

test/helpers.ts

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
const green = '\x1b[32m';
2+
const red = '\x1b[31m';
3+
const yellow = '\x1b[33m';
4+
const bold = '\x1b[1m';
5+
const dim = '\x1b[2m';
6+
const reset = '\x1b[0m';
7+
const greenBg = '\x1b[30m\x1b[42m';
8+
const redBg = '\x1b[37m\x1b[41m';
9+
10+
let totalPassed = 0;
11+
let totalFailed = 0;
12+
let totalSkipped = 0;
13+
let suitePassed = 0;
14+
let suiteFailed = 0;
15+
let suiteSkipped = 0;
16+
let suitesRun = 0;
17+
let suitesFailed = 0;
18+
const failures: string[] = [];
19+
const startTime = Date.now();
20+
21+
export function assert(cond: boolean, msg: string): void {
22+
if (!cond) {
23+
suiteFailed++;
24+
totalFailed++;
25+
failures.push(msg);
26+
console.log(` ${red}\u2716${reset} ${msg}`);
27+
} else {
28+
suitePassed++;
29+
totalPassed++;
30+
console.log(` ${green}\u2713${reset} ${dim}${msg}${reset}`);
31+
}
32+
}
33+
34+
export function assertClose(a: number, b: number, tol: number = 1e-4, msg: string = ''): void {
35+
if (Math.abs(a - b) > tol) {
36+
suiteFailed++;
37+
totalFailed++;
38+
const detail = `${msg}: ${a} != ${b} (tol=${tol})`;
39+
failures.push(detail);
40+
console.log(` ${red}\u2716${reset} ${detail}`);
41+
} else {
42+
suitePassed++;
43+
totalPassed++;
44+
console.log(` ${green}\u2713${reset} ${dim}${msg}${reset}`);
45+
}
46+
}
47+
48+
export function skip(msg: string): void {
49+
suiteSkipped++;
50+
totalSkipped++;
51+
console.log(` ${yellow}\u25CB${reset} ${dim}skipped: ${msg}${reset}`);
52+
}
53+
54+
export function section(name: string): void {
55+
console.log(` ${name}`);
56+
}
57+
58+
export function startSuite(file: string): void {
59+
suitesRun++;
60+
suitePassed = 0;
61+
suiteFailed = 0;
62+
suiteSkipped = 0;
63+
console.log('');
64+
}
65+
66+
export function endSuite(file: string): void {
67+
const badge = suiteFailed > 0
68+
? `${redBg} FAIL ${reset}`
69+
: `${greenBg} PASS ${reset}`;
70+
console.log(`${badge} ${file}`);
71+
if (suiteFailed > 0) suitesFailed++;
72+
}
73+
74+
export function summarize(): void {
75+
const elapsed = ((Date.now() - startTime) / 1000).toFixed(3);
76+
const total = totalPassed + totalFailed + totalSkipped;
77+
78+
console.log('');
79+
80+
// Suites line
81+
const suitesPassedCount = suitesRun - suitesFailed;
82+
if (suitesFailed > 0) {
83+
console.log(`${bold}Test Suites:${reset} ${red}${suitesFailed} failed${reset}, ${green}${suitesPassedCount} passed${reset}, ${suitesRun} total`);
84+
} else {
85+
console.log(`${bold}Test Suites:${reset} ${green}${suitesPassedCount} passed${reset}, ${suitesRun} total`);
86+
}
87+
88+
// Tests line
89+
const parts: string[] = [];
90+
if (totalFailed > 0) parts.push(`${red}${totalFailed} failed${reset}`);
91+
if (totalSkipped > 0) parts.push(`${yellow}${totalSkipped} skipped${reset}`);
92+
parts.push(`${green}${totalPassed} passed${reset}`);
93+
console.log(`${bold}Tests:${reset} ${parts.join(', ')}, ${total} total`);
94+
95+
// Time
96+
console.log(`${bold}Time:${reset} ${elapsed} s`);
97+
98+
if (totalFailed > 0) {
99+
console.log(`\n${red}Failures:${reset}`);
100+
for (const f of failures) console.log(` ${red}\u2716${reset} ${f}`);
101+
process.exit(1);
102+
}
103+
}

test/loader.js

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import { readFileSync } from 'fs';
2+
import { fileURLToPath } from 'url';
3+
import { dirname, resolve as pathResolve } from 'path';
4+
5+
export async function resolve(specifier, context, nextResolve) {
6+
// Remap relative .js imports to sibling .ts files when the .js file is absent.
7+
if (specifier.endsWith('.js') && (specifier.startsWith('./') || specifier.startsWith('../'))) {
8+
const parentPath = context.parentURL ? fileURLToPath(context.parentURL) : process.cwd();
9+
const parentDir = context.parentURL ? dirname(parentPath) : parentPath;
10+
const resolved = pathResolve(parentDir, specifier);
11+
12+
// If .js doesn't exist, try .ts
13+
try {
14+
readFileSync(resolved);
15+
} catch {
16+
const tsPath = resolved.replace(/\.js$/, '.ts');
17+
try {
18+
readFileSync(tsPath);
19+
return nextResolve(specifier.replace(/\.js$/, '.ts'), context);
20+
} catch {
21+
// Fall through to default resolution
22+
}
23+
}
24+
}
25+
return nextResolve(specifier, context);
26+
}

test/module.test.ts

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import {
2+
Tensor,
3+
Module, Parameter,
4+
Linear, ReLU,
5+
Adam, SGD, GradScaler,
6+
mseLoss,
7+
} from '../dist/index.js';
8+
import { assert, assertClose, section } from './helpers.js';
9+
10+
// ============================================================
11+
// Module system
12+
// ============================================================
13+
14+
section('Module system');
15+
16+
class TestNet extends Module {
17+
l1: any;
18+
l2: any;
19+
relu: any;
20+
constructor() {
21+
super();
22+
this.l1 = new Linear(3, 4);
23+
this.l2 = new Linear(4, 2);
24+
this.relu = new ReLU();
25+
}
26+
forward(x: any) {
27+
return this.l2.forward(this.relu.forward(this.l1.forward(x)));
28+
}
29+
}
30+
31+
const net = new TestNet();
32+
33+
const params = net.parameters();
34+
assert(params.length === 4, 'TestNet has 4 parameters (2 weights + 2 biases)');
35+
36+
const named = net.namedParameters();
37+
assert(named.length === 4, 'namedParameters count');
38+
const names = named.map(([n]: [string, any]) => n);
39+
assert(names.some((n: string) => n.includes('l1')), 'namedParameters includes l1');
40+
assert(names.some((n: string) => n.includes('l2')), 'namedParameters includes l2');
41+
42+
const kids = net.children();
43+
assert(kids.length === 3, 'TestNet has 3 children (l1, l2, relu)');
44+
45+
const allMods = net.modules();
46+
assert(allMods.length >= 4, 'modules() includes self + children');
47+
48+
net.eval();
49+
assert(net.training === false, 'eval sets training=false');
50+
net.train();
51+
assert(net.training === true, 'train sets training=true');
52+
53+
const netInput = Tensor.rand([2, 3]);
54+
const netOut = net.forward(netInput);
55+
assert(netOut.shape[0] === 2 && netOut.shape[1] === 2, 'TestNet output shape');
56+
57+
// ============================================================
58+
// Optimizers
59+
// ============================================================
60+
61+
section('Optimizers');
62+
63+
// SGD
64+
const sgdParam = Tensor.fromFloat32(new Float32Array([5, 5, 5, 5]), [2, 2]).setRequiresGrad(true);
65+
const sgdParamObj = new Parameter(sgdParam);
66+
const sgdTarget = Tensor.zeros([2, 2]);
67+
const sgdLoss = sgdParamObj.value.sub(sgdTarget).pow(2).mean();
68+
sgdLoss.backward();
69+
const sgd = new SGD([sgdParamObj], 0.1);
70+
const sgdBefore = sgdParamObj.value.toFloat32()[0];
71+
sgd.step();
72+
const sgdAfter = sgdParamObj.value.toFloat32()[0];
73+
assert(sgdAfter < sgdBefore, 'SGD step reduces parameter toward target');
74+
sgd.zeroGrad();
75+
76+
// Adam
77+
const adamParam = Tensor.fromFloat32(new Float32Array([5, 5, 5, 5]), [2, 2]).setRequiresGrad(true);
78+
const adamParamObj = new Parameter(adamParam);
79+
const adamTarget = Tensor.zeros([2, 2]);
80+
const adamLoss = adamParamObj.value.sub(adamTarget).pow(2).mean();
81+
adamLoss.backward();
82+
const adam = new Adam([adamParamObj], { lr: 0.01 });
83+
adam.step();
84+
const adamAfter = adamParamObj.value.toFloat32()[0];
85+
assert(adamAfter !== 5, 'Adam step changes parameter');
86+
adam.zeroGrad();
87+
88+
// Adam returns grad norm
89+
const adamParam2 = Tensor.fromFloat32(new Float32Array([3, 3]), [2]).setRequiresGrad(true);
90+
const adamParamObj2 = new Parameter(adamParam2);
91+
const adamLoss2 = adamParamObj2.value.pow(2).sum();
92+
adamLoss2.backward();
93+
const adam2 = new Adam([adamParamObj2], { lr: 0.01 });
94+
const gradNorm = adam2.step();
95+
assert(typeof gradNorm === 'number', 'Adam.step() returns grad norm');
96+
97+
// GradScaler
98+
const scaler = new GradScaler({ initScale: 1024 });
99+
assert(scaler.getScale() === 1024, 'GradScaler initial scale');
100+
101+
const gsLossInput = Tensor.fromFloat32(new Float32Array([2, 3]), [2]);
102+
const scaledLoss = scaler.scaleLoss(gsLossInput);
103+
const scaledData = scaledLoss.toFloat32();
104+
assertClose(scaledData[0], 2 * 1024, 1e-1, 'scaleLoss scales by initScale');
105+
assertClose(scaledData[1], 3 * 1024, 1e-1, 'scaleLoss scales second element');
106+
107+
// ============================================================
108+
// End-to-end training loop
109+
// ============================================================
110+
111+
section('End-to-end training');
112+
113+
const trainX = Tensor.fromFloat32(new Float32Array([0, 1, 2, 3, 4, 5]), [6, 1]);
114+
const trainY = Tensor.fromFloat32(new Float32Array([1, 3, 5, 7, 9, 11]), [6, 1]);
115+
const regNet = new Linear(1, 1);
116+
const regOptim = new Adam(regNet.parameters(), { lr: 0.05 });
117+
118+
let earlyLoss: number | null = null;
119+
for (let i = 0; i < 200; i++) {
120+
regOptim.zeroGrad();
121+
const pred = regNet.forward(trainX);
122+
const loss = mseLoss(pred, trainY);
123+
if (i === 10) earlyLoss = loss.toFloat32()[0];
124+
loss.backward();
125+
regOptim.step();
126+
}
127+
const finalPred = regNet.forward(trainX);
128+
const finalLoss = mseLoss(finalPred, trainY).toFloat32()[0];
129+
assert(finalLoss < earlyLoss!, 'training reduces loss');
130+
assert(finalLoss < 1.0, 'training converges to low loss');
131+

test/native.test.ts

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import {
2+
Tensor, native,
3+
flashAttention, residualLayerNorm, biasGelu,
4+
} from '../dist/index.js';
5+
import { assert, skip, section } from './helpers.js';
6+
7+
// ============================================================
8+
// FlashAttention / ResidualLayerNorm / BiasGelu (GPU/CUDA only)
9+
// ============================================================
10+
11+
section('FlashAttention / ResidualLayerNorm / BiasGelu');
12+
13+
if (typeof native.flashAttention === 'function') {
14+
const nHeads = 2, seqLen = 4, headDim = 8;
15+
const qAtt = Tensor.rand([1, nHeads, seqLen, headDim]);
16+
const kAtt = Tensor.rand([1, nHeads, seqLen, headDim]);
17+
const vAtt = Tensor.rand([1, nHeads, seqLen, headDim]);
18+
const scale = 1.0 / Math.sqrt(headDim);
19+
const attOut = flashAttention(qAtt, kAtt, vAtt, scale, true);
20+
assert(attOut.shape[0] === 1, 'attention batch');
21+
assert(attOut.shape[1] === nHeads, 'attention heads');
22+
assert(attOut.shape[2] === seqLen, 'attention seq_len');
23+
assert(attOut.shape[3] === headDim, 'attention head_dim');
24+
} else {
25+
skip('flashAttention not available in CPU build');
26+
}
27+
28+
if (typeof native.residualLayernorm === 'function') {
29+
const rlnX = Tensor.rand([2, 4]);
30+
const rlnResidual = Tensor.rand([2, 4]);
31+
const rlnGamma = Tensor.ones([4]).setRequiresGrad(true);
32+
const rlnBeta = Tensor.zeros([4]).setRequiresGrad(true);
33+
const rlnOut = residualLayerNorm(rlnX, rlnResidual, rlnGamma, rlnBeta);
34+
assert(rlnOut.shape[0] === 2 && rlnOut.shape[1] === 4, 'residualLayerNorm shape');
35+
} else {
36+
skip('residualLayerNorm not available in CPU build');
37+
}
38+
39+
if (typeof native.biasGelu === 'function') {
40+
const bgX = Tensor.rand([2, 4]);
41+
const bgBias = Tensor.rand([4]);
42+
const bgOut = biasGelu(bgX, bgBias);
43+
assert(bgOut.shape[0] === 2 && bgOut.shape[1] === 4, 'biasGelu shape');
44+
} else {
45+
skip('biasGelu not available in CPU build');
46+
}
47+

0 commit comments

Comments
 (0)