1
- let string_of_dtype : type a b. (a, b) Nx.dtype -> string = function
2
- | Float32 -> " float32"
3
- | Float64 -> " float64"
4
- | _ -> " other" (* Only float32 and float64 are used here *)
1
+ open Nx
5
2
6
- (* Helper for binary operations: takes two arrays *)
7
- let binary_op_bench : type a b .
8
- ((a , b ) Nx. t -> (a , b ) Nx. t -> (a , b ) Nx. t ) ->
9
- int ->
10
- (a , b ) Nx. dtype ->
11
- unit ->
12
- unit =
13
- fun op size dtype ->
14
- let shape = [| size; size |] in
3
+ (* Helper to create test arrays *)
4
+ let make_array dtype shape =
15
5
(* TODO: Fix Nx.rand and use it instead of ones *)
16
- let a = Nx. astype dtype (Nx. ones Nx. float32 shape) in
17
- let b = Nx. astype dtype (Nx. ones Nx. float32 shape) in
18
- fun () -> op a b |> ignore
6
+ Nx. astype dtype (Nx. ones float32 shape)
19
7
20
- (* Helper for unary operations: takes one array *)
21
- let unary_op_bench : type a b .
22
- ((a , b ) Nx. t -> (a , b ) Nx. t ) -> int -> (a , b ) Nx. dtype -> unit -> unit =
23
- fun op size dtype ->
24
- let shape = [| size; size |] in
25
- (* TODO: Fix Nx.rand and use it instead of ones *)
26
- let a = Nx. astype dtype (Nx. ones Nx. float32 shape) in
27
- fun () -> op a |> ignore
8
+ (* Benchmark functions *)
9
+ let bench_add : type a b. int -> (a, b) dtype -> unit -> unit =
10
+ fun size dtype () ->
11
+ let a = make_array dtype [| size; size |] in
12
+ let b = make_array dtype [| size; size |] in
13
+ Nx. add a b |> ignore
28
14
29
- (* Helper for reduction operations: reduces array to scalar/smaller array *)
30
- let reduction_op_bench : type a b c d .
31
- ((a , b ) Nx. t -> (c , d ) Nx. t ) -> int -> (a , b ) Nx. dtype -> unit -> unit =
32
- fun op size dtype ->
33
- let shape = [| size; size |] in
34
- (* TODO: Fix Nx.rand and use it instead of ones *)
35
- let a = Nx. astype dtype (Nx. ones Nx. float32 shape) in
36
- fun () -> op a |> ignore
15
+ let bench_mul : type a b. int -> (a, b) dtype -> unit -> unit =
16
+ fun size dtype () ->
17
+ let a = make_array dtype [| size; size |] in
18
+ let b = make_array dtype [| size; size |] in
19
+ Nx. mul a b |> ignore
37
20
38
- (* Helper for matrix operations like matmul *)
39
- let matmul_bench : type a b. int -> (a, b) Nx.dtype -> unit -> unit =
40
- fun size dtype ->
41
- (* TODO: Fix Nx.rand and use it instead of ones *)
42
- let a = Nx. astype dtype (Nx. ones Nx. float32 [| size; size |]) in
43
- let b = Nx. astype dtype (Nx. ones Nx. float32 [| size; size |]) in
44
- fun () -> Nx. matmul a b |> ignore
21
+ let bench_square : type a b. int -> (a, b) dtype -> unit -> unit =
22
+ fun size dtype () ->
23
+ let a = make_array dtype [| size; size |] in
24
+ Nx. square a |> ignore
45
25
46
- (* List of operations to benchmark *)
47
- let operations : type a b .
48
- int -> (a , b ) Nx. dtype -> (string * (unit -> unit )) list =
49
- fun size dtype ->
50
- List. concat
51
- [
52
- (* Binary operations *)
53
- [
54
- (" Addition" , binary_op_bench Nx. add size dtype);
55
- (" Multiplication" , binary_op_bench Nx. mul size dtype);
56
- (* Unary operations *)
57
- (" Square" , unary_op_bench Nx. square size dtype);
58
- ];
59
- (* Matrix operations - skip for large sizes *)
60
- (if size < = 100 then [ (" MatMul" , matmul_bench size dtype) ] else [] );
61
- (* Reductions *)
62
- [ (" Sum" , reduction_op_bench Nx. sum size dtype) ];
63
- ]
26
+ let bench_sqrt : type b. int -> (float, b) dtype -> unit -> unit =
27
+ fun size dtype () ->
28
+ let a = make_array dtype [| size; size |] in
29
+ Nx. sqrt a |> ignore
64
30
65
- let float_operations : type b .
66
- int -> (float , b ) Nx. dtype -> (string * (unit -> unit )) list =
67
- fun size dtype ->
68
- [
69
- (* Float-specific unary operations *)
70
- (" Sqrt" , unary_op_bench Nx. sqrt size dtype);
71
- (" Exp" , unary_op_bench Nx. exp size dtype);
72
- ]
31
+ let bench_exp : type b. int -> (float, b) dtype -> unit -> unit =
32
+ fun size dtype () ->
33
+ let a = make_array dtype [| size; size |] in
34
+ Nx. exp a |> ignore
73
35
74
- (* Generate benchmark tests for all combinations *)
75
- let tests ~sizes =
76
- let tests_on_dtype (type a b ) (dtype : (a, b) Nx.dtype ) =
77
- List. concat_map
78
- (fun size ->
79
- let ops = operations size dtype in
80
- List. map
81
- (fun (op_name , bench_fun ) ->
82
- let name =
83
- Printf. sprintf " %s on %dx%d %s" op_name size size
84
- (string_of_dtype dtype)
85
- in
86
- Ubench. create name bench_fun)
87
- ops)
88
- sizes
36
+ let bench_sum : type a b. int -> (a, b) dtype -> unit -> unit =
37
+ fun size dtype () ->
38
+ let a = make_array dtype [| size; size |] in
39
+ Nx. sum a |> ignore
40
+
41
+ let bench_matmul : type a b. int -> (a, b) dtype -> unit -> unit =
42
+ fun size dtype () ->
43
+ let a = make_array dtype [| size; size |] in
44
+ let b = make_array dtype [| size; size |] in
45
+ Nx. matmul a b |> ignore
46
+
47
+ let bench_conv2d : type a b. int -> int -> (a, b) dtype -> unit -> unit =
48
+ fun size kernel_size dtype () ->
49
+ let input = make_array dtype [| 1 ; 3 ; size; size |] in
50
+ let kernel = make_array dtype [| 16 ; 3 ; kernel_size; kernel_size |] in
51
+ Nx. convolve2d ~padding_mode: `Same input kernel |> ignore
52
+
53
+ (* Generate benchmarks *)
54
+ let make_benchmarks () =
55
+ let sizes = [50 ; 100 ] in (* Reduced for faster runs *)
56
+ let dtype_name : type a b. (a, b) dtype -> string = function
57
+ | Float32 -> " f32"
58
+ | Float64 -> " f64"
59
+ | _ -> " other"
89
60
in
90
- let tests_float_on_dtype (type b ) (dtype : (float, b) Nx.dtype ) =
91
- List. concat_map
92
- (fun size ->
93
- let ops = float_operations size dtype in
94
- List. map
95
- (fun (op_name , bench_fun ) ->
96
- let name =
97
- Printf. sprintf " %s on %dx%d %s" op_name size size
98
- (string_of_dtype dtype)
99
- in
100
- Ubench. create name bench_fun)
101
- ops)
102
- sizes
61
+
62
+ let bench_for_dtype : type a b. (a, b) dtype -> _ =
63
+ fun dtype ->
64
+ List. concat_map (fun size ->
65
+ let name s = Printf. sprintf " %s %dx%d %s" s size size (dtype_name dtype) in
66
+ List. concat [
67
+ (* Basic operations *)
68
+ [ Ubench. create (name " add" ) (bench_add size dtype);
69
+ Ubench. create (name " mul" ) (bench_mul size dtype);
70
+ Ubench. create (name " square" ) (bench_square size dtype);
71
+ Ubench. create (name " sum" ) (bench_sum size dtype);
72
+ ];
73
+
74
+ (* Float-specific operations *)
75
+ (match dtype with
76
+ | Float32 ->
77
+ [ Ubench. create (name " sqrt" ) (bench_sqrt size Float32 );
78
+ Ubench. create (name " exp" ) (bench_exp size Float32 ); ]
79
+ | Float64 ->
80
+ [ Ubench. create (name " sqrt" ) (bench_sqrt size Float64 );
81
+ Ubench. create (name " exp" ) (bench_exp size Float64 ); ]
82
+ | _ -> [] );
83
+
84
+ (* Matrix operations - skip large sizes *)
85
+ (if size < 100 then
86
+ [ Ubench. create (name " matmul" ) (bench_matmul size dtype) ]
87
+ else [] );
88
+
89
+ (* Convolution - skip large sizes *)
90
+ (if size < 100 then
91
+ [ Ubench. create (name " conv2d-3x3" ) (bench_conv2d size 3 dtype);
92
+ Ubench. create (name " conv2d-5x5" ) (bench_conv2d size 5 dtype); ]
93
+ else [] );
94
+ ]
95
+ ) sizes
103
96
in
104
- List. concat
105
- [
106
- tests_on_dtype Float32 ;
107
- tests_on_dtype Float64 ;
108
- tests_float_on_dtype Float32 ;
109
- tests_float_on_dtype Float64 ;
110
- ]
97
+
98
+ List. concat [
99
+ bench_for_dtype Float32 ;
100
+ bench_for_dtype Float64 ;
101
+ ]
111
102
112
- (* Run the benchmarks *)
103
+ (* Run benchmarks *)
113
104
let () =
114
- print_endline " # Nx Benchmarks" ;
115
- let tests = tests ~sizes: [ 50 ; 100 ; 500 ] in
116
- let results = Ubench. run ~warmup: 1 ~trials: 3 ~min_time: 0.01 tests in
117
- Ubench. print_report results
105
+ print_endline " # Nx Benchmarks\n " ;
106
+ let benchmarks = make_benchmarks () in
107
+ let results = Ubench. run ~warmup: 1 ~trials: 2 ~min_time: 0.001 benchmarks in
108
+ Ubench. print_report results
0 commit comments