Skip to content

Commit 5202b16

Browse files
committed
WIP
1 parent 028b3d5 commit 5202b16

File tree

5 files changed

+151
-114
lines changed

5 files changed

+151
-114
lines changed

nx/bench/bench_conv.ml

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
open Nx
2+
3+
(* Test data specification *)
4+
(* Benchmark cases, one [(name, input_shape, kernel_shape)] triple per entry.
   Each triple is later turned into a pair of all-ones float32 tensors and fed
   to [Nx.convolve2d]. Only "medium_16x16" is currently enabled; the other
   cases are kept commented out so they can be re-enabled by hand.
   NOTE(review): shapes are presumably (batch, channels, height, width) for
   the input and (out_channels, in_channels, kh, kw) for the kernel — confirm
   against [Nx.convolve2d]. *)
let test_specs =
5+
[
6+
(* ("tiny_4x4", [| 1; 1; 4; 4 |], [| 1; 1; 3; 3 |]); *)
7+
(* ("small_8x8", [| 1; 1; 8; 8 |], [| 1; 1; 3; 3 |]); *)
8+
("medium_16x16", [| 1; 4; 16; 16 |], [| 8; 4; 3; 3 |]);
9+
(* Skip large tests for now - they're too slow and might cause memory issues *)
10+
(* ("channels_32x32", [| 1; 8; 32; 32 |], [| 16; 8; 3; 3 |]); *)
11+
(* ("kernel_5x5", [| 1; 4; 16; 16 |], [| 8; 4; 5; 5 |]); *)
12+
(* ("batch_16x16", [| 4; 4; 16; 16 |], [| 8; 4; 3; 3 |]); *)
13+
]
14+
15+
(* Create all test data upfront and keep references *)
16+
(* Materialise the tensors for every entry of [test_specs] once, when the
   module is initialised. Each result is [(name, input, kernel)], where both
   tensors are float32 and filled with ones. *)
let test_data =
  let build (label, input_shape, kernel_shape) =
    let input = ones float32 input_shape in
    let kernel = ones float32 kernel_shape in
    (label, input, kernel)
  in
  test_specs |> List.map build
23+
24+
(* Benchmark original implementation *)
25+
(* [bench_original ()] builds one Ubench case per entry of [test_data].
   Each case is labelled with the entry's name and, when run, performs a
   [`Valid]-padded [Nx.convolve2d] of the pre-built input with its kernel,
   discarding the result.

   The label is passed through unchanged: the former ["" ^ name] produced an
   equal string and only allocated a copy. *)
let bench_original () =
  List.map
    (fun (name, x, k) ->
      Ubench.create name (fun () ->
          Nx.convolve2d ~padding_mode:`Valid x k |> ignore))
    test_data
31+
32+
(* Entry point: print a banner, build the convolution benchmark cases,
   announce how many will run, then execute them and print the report. *)
let () =
  Stdlib.print_string "Convolution Benchmarks\n";
  Stdlib.print_string "=====================\n\n";
  let cases = bench_original () in
  let case_count = List.length cases in
  Printf.printf "Running %d benchmarks...\n" case_count;
  (* Flush so the banner is visible before the benchmarks start. *)
  flush stdout;
  let report = Ubench.run ~warmup:1 ~trials:3 ~min_time:0.01 cases in
  Ubench.print_report report

nx/bench/bench_nx.ml

Lines changed: 96 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,117 +1,108 @@
1-
let string_of_dtype : type a b. (a, b) Nx.dtype -> string = function
2-
| Float32 -> "float32"
3-
| Float64 -> "float64"
4-
| _ -> "other" (* Only float32 and float64 are used here *)
1+
open Nx
52

6-
(* Helper for binary operations: takes two arrays *)
7-
let binary_op_bench : type a b.
8-
((a, b) Nx.t -> (a, b) Nx.t -> (a, b) Nx.t) ->
9-
int ->
10-
(a, b) Nx.dtype ->
11-
unit ->
12-
unit =
13-
fun op size dtype ->
14-
let shape = [| size; size |] in
3+
(* Helper to create test arrays *)
4+
let make_array dtype shape =
155
(* TODO: Fix Nx.rand and use it instead of ones *)
16-
let a = Nx.astype dtype (Nx.ones Nx.float32 shape) in
17-
let b = Nx.astype dtype (Nx.ones Nx.float32 shape) in
18-
fun () -> op a b |> ignore
6+
Nx.astype dtype (Nx.ones float32 shape)
197

20-
(* Helper for unary operations: takes one array *)
21-
let unary_op_bench : type a b.
22-
((a, b) Nx.t -> (a, b) Nx.t) -> int -> (a, b) Nx.dtype -> unit -> unit =
23-
fun op size dtype ->
24-
let shape = [| size; size |] in
25-
(* TODO: Fix Nx.rand and use it instead of ones *)
26-
let a = Nx.astype dtype (Nx.ones Nx.float32 shape) in
27-
fun () -> op a |> ignore
8+
(* Benchmark functions *)
9+
(* [bench_add size dtype] returns a thunk that runs [Nx.add] on two
   [size] x [size] operands and discards the result.

   The operands are built here, when the benchmark is constructed, rather than
   inside the returned thunk, so each timed run covers only the addition and
   not the cost of [make_array]. *)
let bench_add : type a b. int -> (a, b) dtype -> unit -> unit =
 fun size dtype ->
  let a = make_array dtype [| size; size |] in
  let b = make_array dtype [| size; size |] in
  fun () -> Nx.add a b |> ignore
2814

29-
(* Helper for reduction operations: reduces array to scalar/smaller array *)
30-
let reduction_op_bench : type a b c d.
31-
((a, b) Nx.t -> (c, d) Nx.t) -> int -> (a, b) Nx.dtype -> unit -> unit =
32-
fun op size dtype ->
33-
let shape = [| size; size |] in
34-
(* TODO: Fix Nx.rand and use it instead of ones *)
35-
let a = Nx.astype dtype (Nx.ones Nx.float32 shape) in
36-
fun () -> op a |> ignore
15+
(* [bench_mul size dtype] returns a thunk that runs [Nx.mul] on two
   [size] x [size] operands and discards the result.

   The operands are built here, when the benchmark is constructed, rather than
   inside the returned thunk, so each timed run covers only the multiplication
   and not the cost of [make_array]. *)
let bench_mul : type a b. int -> (a, b) dtype -> unit -> unit =
 fun size dtype ->
  let a = make_array dtype [| size; size |] in
  let b = make_array dtype [| size; size |] in
  fun () -> Nx.mul a b |> ignore
3720

38-
(* Helper for matrix operations like matmul *)
39-
let matmul_bench : type a b. int -> (a, b) Nx.dtype -> unit -> unit =
40-
fun size dtype ->
41-
(* TODO: Fix Nx.rand and use it instead of ones *)
42-
let a = Nx.astype dtype (Nx.ones Nx.float32 [| size; size |]) in
43-
let b = Nx.astype dtype (Nx.ones Nx.float32 [| size; size |]) in
44-
fun () -> Nx.matmul a b |> ignore
21+
(* [bench_square size dtype] returns a thunk that runs [Nx.square] on one
   [size] x [size] operand and discards the result.

   The operand is built here, when the benchmark is constructed, rather than
   inside the returned thunk, so each timed run covers only the squaring and
   not the cost of [make_array]. *)
let bench_square : type a b. int -> (a, b) dtype -> unit -> unit =
 fun size dtype ->
  let a = make_array dtype [| size; size |] in
  fun () -> Nx.square a |> ignore
4525

46-
(* List of operations to benchmark *)
47-
let operations : type a b.
48-
int -> (a, b) Nx.dtype -> (string * (unit -> unit)) list =
49-
fun size dtype ->
50-
List.concat
51-
[
52-
(* Binary operations *)
53-
[
54-
("Addition", binary_op_bench Nx.add size dtype);
55-
("Multiplication", binary_op_bench Nx.mul size dtype);
56-
(* Unary operations *)
57-
("Square", unary_op_bench Nx.square size dtype);
58-
];
59-
(* Matrix operations - skip for large sizes *)
60-
(if size <= 100 then [ ("MatMul", matmul_bench size dtype) ] else []);
61-
(* Reductions *)
62-
[ ("Sum", reduction_op_bench Nx.sum size dtype) ];
63-
]
26+
(* [bench_sqrt size dtype] returns a thunk that runs [Nx.sqrt] on one
   [size] x [size] operand and discards the result. Restricted by its type to
   dtypes whose OCaml element type is [float].

   The operand is built here, when the benchmark is constructed, rather than
   inside the returned thunk, so each timed run covers only the square root
   and not the cost of [make_array]. *)
let bench_sqrt : type b. int -> (float, b) dtype -> unit -> unit =
 fun size dtype ->
  let a = make_array dtype [| size; size |] in
  fun () -> Nx.sqrt a |> ignore
6430

65-
let float_operations : type b.
66-
int -> (float, b) Nx.dtype -> (string * (unit -> unit)) list =
67-
fun size dtype ->
68-
[
69-
(* Float-specific unary operations *)
70-
("Sqrt", unary_op_bench Nx.sqrt size dtype);
71-
("Exp", unary_op_bench Nx.exp size dtype);
72-
]
31+
(* [bench_exp size dtype] returns a thunk that runs [Nx.exp] on one
   [size] x [size] operand and discards the result. Restricted by its type to
   dtypes whose OCaml element type is [float].

   The operand is built here, when the benchmark is constructed, rather than
   inside the returned thunk, so each timed run covers only the exponential
   and not the cost of [make_array]. *)
let bench_exp : type b. int -> (float, b) dtype -> unit -> unit =
 fun size dtype ->
  let a = make_array dtype [| size; size |] in
  fun () -> Nx.exp a |> ignore
7335

74-
(* Generate benchmark tests for all combinations *)
75-
let tests ~sizes =
76-
let tests_on_dtype (type a b) (dtype : (a, b) Nx.dtype) =
77-
List.concat_map
78-
(fun size ->
79-
let ops = operations size dtype in
80-
List.map
81-
(fun (op_name, bench_fun) ->
82-
let name =
83-
Printf.sprintf "%s on %dx%d %s" op_name size size
84-
(string_of_dtype dtype)
85-
in
86-
Ubench.create name bench_fun)
87-
ops)
88-
sizes
36+
(* [bench_sum size dtype] returns a thunk that runs [Nx.sum] on one
   [size] x [size] operand and discards the result.

   The operand is built here, when the benchmark is constructed, rather than
   inside the returned thunk, so each timed run covers only the reduction and
   not the cost of [make_array]. *)
let bench_sum : type a b. int -> (a, b) dtype -> unit -> unit =
 fun size dtype ->
  let a = make_array dtype [| size; size |] in
  fun () -> Nx.sum a |> ignore
40+
41+
(* [bench_matmul size dtype] returns a thunk that runs [Nx.matmul] on two
   [size] x [size] operands and discards the result.

   The operands are built here, when the benchmark is constructed, rather than
   inside the returned thunk, so each timed run covers only the matrix product
   and not the cost of [make_array]. *)
let bench_matmul : type a b. int -> (a, b) dtype -> unit -> unit =
 fun size dtype ->
  let a = make_array dtype [| size; size |] in
  let b = make_array dtype [| size; size |] in
  fun () -> Nx.matmul a b |> ignore
46+
47+
(* [bench_conv2d size kernel_size dtype] returns a thunk that runs a
   [`Same]-padded [Nx.convolve2d] of a [1; 3; size; size] input with a
   [16; 3; kernel_size; kernel_size] kernel and discards the result.

   Both tensors are built here, when the benchmark is constructed, rather than
   inside the returned thunk, so each timed run covers only the convolution
   and not the cost of [make_array]. *)
let bench_conv2d : type a b. int -> int -> (a, b) dtype -> unit -> unit =
 fun size kernel_size dtype ->
  let input = make_array dtype [| 1; 3; size; size |] in
  let kernel = make_array dtype [| 16; 3; kernel_size; kernel_size |] in
  fun () -> Nx.convolve2d ~padding_mode:`Same input kernel |> ignore
52+
53+
(* Generate benchmarks *)
54+
let make_benchmarks () =
55+
let sizes = [50; 100] in (* Reduced for faster runs *)
56+
let dtype_name : type a b. (a, b) dtype -> string = function
57+
| Float32 -> "f32"
58+
| Float64 -> "f64"
59+
| _ -> "other"
8960
in
90-
let tests_float_on_dtype (type b) (dtype : (float, b) Nx.dtype) =
91-
List.concat_map
92-
(fun size ->
93-
let ops = float_operations size dtype in
94-
List.map
95-
(fun (op_name, bench_fun) ->
96-
let name =
97-
Printf.sprintf "%s on %dx%d %s" op_name size size
98-
(string_of_dtype dtype)
99-
in
100-
Ubench.create name bench_fun)
101-
ops)
102-
sizes
61+
62+
let bench_for_dtype : type a b. (a, b) dtype -> _ =
63+
fun dtype ->
64+
List.concat_map (fun size ->
65+
let name s = Printf.sprintf "%s %dx%d %s" s size size (dtype_name dtype) in
66+
List.concat [
67+
(* Basic operations *)
68+
[ Ubench.create (name "add") (bench_add size dtype);
69+
Ubench.create (name "mul") (bench_mul size dtype);
70+
Ubench.create (name "square") (bench_square size dtype);
71+
Ubench.create (name "sum") (bench_sum size dtype);
72+
];
73+
74+
(* Float-specific operations *)
75+
(match dtype with
76+
| Float32 ->
77+
[ Ubench.create (name "sqrt") (bench_sqrt size Float32);
78+
Ubench.create (name "exp") (bench_exp size Float32); ]
79+
| Float64 ->
80+
[ Ubench.create (name "sqrt") (bench_sqrt size Float64);
81+
Ubench.create (name "exp") (bench_exp size Float64); ]
82+
| _ -> []);
83+
84+
(* Matrix operations - skip large sizes *)
85+
(if size < 100 then
86+
[ Ubench.create (name "matmul") (bench_matmul size dtype) ]
87+
else []);
88+
89+
(* Convolution - skip large sizes *)
90+
(if size < 100 then
91+
[ Ubench.create (name "conv2d-3x3") (bench_conv2d size 3 dtype);
92+
Ubench.create (name "conv2d-5x5") (bench_conv2d size 5 dtype); ]
93+
else []);
94+
]
95+
) sizes
10396
in
104-
List.concat
105-
[
106-
tests_on_dtype Float32;
107-
tests_on_dtype Float64;
108-
tests_float_on_dtype Float32;
109-
tests_float_on_dtype Float64;
110-
]
97+
98+
List.concat [
99+
bench_for_dtype Float32;
100+
bench_for_dtype Float64;
101+
]
111102

112-
(* Run the benchmarks *)
103+
(* Run benchmarks *)
113104
let () =
114-
print_endline "# Nx Benchmarks";
115-
let tests = tests ~sizes:[ 50; 100; 500 ] in
116-
let results = Ubench.run ~warmup:1 ~trials:3 ~min_time:0.01 tests in
117-
Ubench.print_report results
105+
print_endline "# Nx Benchmarks\n";
106+
let benchmarks = make_benchmarks () in
107+
let results = Ubench.run ~warmup:1 ~trials:2 ~min_time:0.001 benchmarks in
108+
Ubench.print_report results

nx/bench/dune

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,9 @@
22
(name bench_nx)
33
(modules bench_nx)
44
(libraries nx ubench))
5+
6+
7+
(executable
8+
(name bench_conv)
9+
(modules bench_conv)
10+
(libraries nx ubench))

nx/lib/native/internal.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ type ('a, 'b) t = {
1414
(* Helper to map logical indices through a chain of view transformations *)
1515
(* This is needed when views can't be composed into a single view *)
1616
let iterate_view_indices shape indices f =
17-
(* Helper to iterate through all indices of a tensor *)
1817
let ndim = Array.length shape in
1918
if ndim = 0 then f indices
2019
else
2120
let rec iter_dim d =
22-
if d = ndim then f (Array.copy indices)
21+
if d = ndim then
22+
f indices
2323
else
2424
for i = 0 to shape.(d) - 1 do
2525
indices.(d) <- i;

nx/lib/native/nx_native.ml

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,17 +77,15 @@ let op_assign target_t source_t = Internal.blit source_t target_t
7777

7878
(* Helper for binary operations: allocates the output tensor from [a]'s shape and dtype and passes both inputs to [op_func] unchanged (no materialization step) *)
7979
let binary_op op_func a b =
80-
let a' = ensure_materializable a in
81-
let b' = ensure_materializable b in
82-
let ctx = a'.context in
83-
let out_shape = Internal.shape a' in
84-
let out_size = Internal.numel a' in
85-
let out_dtype = a'.dtype in
80+
let ctx = a.context in
81+
let out_shape = Internal.shape a in
82+
let out_size = Internal.numel a in
83+
let out_dtype = a.dtype in
8684
let out_tensor =
8785
op_buffer ctx out_dtype out_size |> fun t ->
8886
with_view t (Lazy_view.create (Symbolic_shape.of_ints out_shape))
8987
in
90-
op_func ctx a' b' out_tensor;
88+
op_func ctx a b out_tensor;
9189
out_tensor
9290

9391
(* Helper for binary comparison operations *)

0 commit comments

Comments
 (0)