Commit 3373e10
feat(nx): Implement lazy views and symbolic shapes
This commit introduces two major architectural improvements to nx:

1. **Lazy View Operations**: View operations (reshape, permute, expand, etc.) are now lazy - they update metadata instead of copying data. Views are only materialized when operations need to access the underlying data.
2. **Symbolic Shapes**: Replace concrete int arrays with symbolic dimensions that can be bound at runtime. This enables shape-polymorphic kernels and dynamic batching.

Key changes:
- Add `Symbolic_shape` module supporting Static/Dynamic dimensions
- Add `Lazy_view` module tracking sequences of view transformations
- Update tensor type to use `Lazy_view.t` instead of `View.t`
- Refactor backend interface to accept symbolic shapes
- Update Native and Metal backends to handle lazy views
- Add `ensure_materializable` pattern for on-demand view materialization

Benefits:
- Memory efficiency: avoid unnecessary data copies for view operations
- Performance: enable view fusion and better memory access patterns
- Flexibility: support dynamic shapes for batching and compilation
- Foundation: prepare for advanced JIT optimizations in Rune

The changes maintain backward compatibility and NumPy semantics while bringing nx closer to modern tensor frameworks. We once again draw significant inspiration from Tinygrad.

WIP WIP WIP
1 parent cd7cbcf commit 3373e10

23 files changed: +2192 −1099 lines changed
Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
# Task: Implement Lazy View Operations and Symbolic Shapes in nx

**IMPORTANT: This checklist must be kept updated throughout implementation**

- [x] Step 1: Add symbolic dimension support to nx core types
- [x] Step 2: Implement Shape_tracker with symbolic shape support
- [x] Step 3: Update backend interface to use Shape_tracker instead of View
- [x] Step 4: Make view operations lazy in frontend
- [ ] Step 5: Update backends to handle lazy views
- [ ] Step 6: Integrate symbolic shapes with Rune effect system
- [ ] Step 7: Update JIT lowering for view realization and symbolic shapes
- [ ] Step 8: Implement view fusion and shape specialization
- [ ] Step 9: Add comprehensive tests
- [ ] Step 10: Update documentation
---

## Objective

Implement lazy view operations and symbolic shapes in nx to:

1. Avoid unnecessary memory allocations and data copies (lazy views)
2. Enable shape-polymorphic kernels and dynamic batching (symbolic shapes)
3. Improve convolution performance through better memory access patterns

## Context

- Current nx has a View system but materializes views eagerly
- All shapes are currently concrete integers - no symbolic support
- Rune's JIT has skeletal symbolic infrastructure (SymVar.t) but it's not integrated
- Tinygrad provides a proven model for both lazy views and symbolic shapes
- Shape_tracker will replace the current View in tensor metadata (addressing the confusion about having both)

## Key Design Decisions

1. **Shape_tracker replaces View**: Instead of having both `view` and `shape_tracker` in tensor metadata, Shape_tracker becomes the single source of truth. It can represent both simple views (one View.t) and complex view chains (multiple View.t).

2. **op_contiguous vs op_realize_view**: We keep `op_contiguous` as the backend operation. The name better matches the semantics - making data contiguous in memory.

3. **Backend-driven realization**: View realization happens in the backend when operations need to access data. The frontend remains agnostic about when realization occurs.

4. **Symbolic shapes throughout**: Replace `int array` with a unified shape type that supports both concrete and symbolic dimensions.

## Implementation Steps

### 1. Add symbolic dimension support (nx/lib/core/)

Create new module for symbolic shapes:
**symbolic.ml/mli**:

```ocaml
module SymVar : sig
  type t = {
    name : string;
    min_bound : int;
    max_bound : int;
    mutable value : int option;
  }
end

type dim =
  | Concrete of int
  | Symbolic of SymVar.t

type shape = dim array

(* Constructors *)
val concrete : int -> dim
val symbolic : string -> min:int -> max:int -> dim

(* Operations *)
val bind : dim -> int -> unit
val ( + ) : dim -> dim -> dim
val ( * ) : dim -> dim -> dim
val ( / ) : dim -> dim -> dim
val ( mod ) : dim -> dim -> dim

(* Utilities *)
val to_concrete : shape -> int array option
val is_concrete : shape -> bool
val evaluate : dim -> int option
val substitute : shape -> (string * int) list -> shape
```
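
To make the two-stage "define symbolically, bind concretely" pattern concrete, here is a hypothetical usage sketch against this interface (the variable name and bounds are illustrative):

```ocaml
(* Hypothetical usage of the Symbolic interface sketched above. *)
let () =
  let batch = Symbolic.symbolic "batch" ~min:1 ~max:1024 in
  let shape = [| batch; Symbolic.concrete 128; Symbolic.concrete 64 |] in
  assert (not (Symbolic.is_concrete shape));
  (* Bind the symbolic dimension to a concrete value at runtime *)
  Symbolic.bind batch 32;
  match Symbolic.to_concrete shape with
  | Some dims -> assert (dims = [| 32; 128; 64 |])
  | None -> assert false
```
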
### 2. Update View.t for symbolic shapes (nx/lib/core/view.ml)

```ocaml
type t = {
  shape : Symbolic.shape;
  strides : Symbolic.dim array; (* Computed from shape, not independently symbolic *)
  offset : Symbolic.dim; (* Can be symbolic from slicing operations *)
  mask : (Symbolic.dim * Symbolic.dim) array option;
  layout : layout;
}
```

**Critical design point**: Strides are NOT independently symbolic - they are always computed from shapes using the standard row-major formula. However, when an inner dimension is symbolic, the computed strides become symbolic expressions.

For example:
- Shape: [batch_size, 128, 64] where batch_size is symbolic
- Strides: [128*64, 64, 1] = [8192, 64, 1] - all concrete, since a dimension's stride depends only on the dimensions to its right
- Shape: [128, seq_len, 64] where seq_len is symbolic
- Strides: [seq_len*64, 64, 1] - here the first stride is a symbolic expression

Update View functions (a sketch of the first follows below):
- `strides_for_shape`: Compute strides from shape using the accumulation formula
- `create`: Use strides_for_shape to derive strides from shape
- Other functions updated to handle symbolic dimensions
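
A minimal sketch of `strides_for_shape` under these assumptions (row-major accumulation; `Symbolic.( * )` is assumed to build a symbolic product when either operand is unbound):

```ocaml
(* strides.(i) = shape.(i+1) * ... * shape.(n-1); innermost stride is 1 *)
let strides_for_shape (shape : Symbolic.shape) : Symbolic.dim array =
  let n = Array.length shape in
  let strides = Array.make n (Symbolic.concrete 1) in
  for i = n - 2 downto 0 do
    strides.(i) <- Symbolic.( * ) strides.(i + 1) shape.(i + 1)
  done;
  strides
```

With shape `[128, seq_len, 64]` this yields `[seq_len*64, 64, 1]`, matching the example above.
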
### 3. Implement Shape_tracker (nx/lib/core/shape_tracker.ml)

```ocaml
type t = {
  views : View.t list; (* List of view transformations *)
  base_shape : Symbolic.shape; (* Original tensor shape *)
}

(* Creation *)
val create : View.t -> t
val from_shape : Symbolic.shape -> t

(* View operations - all lazy *)
val reshape : t -> Symbolic.shape -> t option
val permute : t -> int array -> t
val expand : t -> Symbolic.shape -> t
val pad : t -> (Symbolic.dim * Symbolic.dim) array -> t
val shrink : t -> (Symbolic.dim * Symbolic.dim) array -> t
val flip : t -> bool array -> t

(* Analysis *)
val simplify : t -> t (* Merge compatible views *)
val is_c_contiguous : t -> bool
val real_strides : t -> Symbolic.dim array option (* None if not materializable *)
val to_view : t -> View.t (* Compose all views *)

(* Symbolic operations *)
val vars : t -> SymVar.t list
val bind_vars : t -> (string * int) list -> t
val is_concrete : t -> bool
```
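
As an illustration of what "all lazy" means here, a hedged sketch of `permute`: it rewrites only the metadata of the most recent view and touches no data (field access assumes the View.t record from step 2):

```ocaml
let permute (tracker : t) (axes : int array) : t =
  match tracker.views with
  | [] -> tracker
  | v :: rest ->
      (* Reorder per-dimension metadata; the buffer itself is untouched. *)
      let pick arr = Array.map (fun i -> arr.(i)) axes in
      let v' =
        { v with
          View.shape = pick v.View.shape;
          strides = pick v.View.strides;
          mask = Option.map pick v.View.mask;
        }
      in
      { tracker with views = v' :: rest }
```
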
### 4. Update backend interface (nx/lib/core/backend_intf.ml)

Replace the `view` lens with `shape_tracker`:
```ocaml
val shape_tracker : ('a, 'b) t -> Shape_tracker.t
```

Keep shape-specific operations for Rune's effect system:
- `op_reshape`, `op_expand`, `op_permute`, `op_pad`, `op_shrink`, `op_flip`
- These operations now update Shape_tracker instead of materializing
- They still need to exist as backend ops to raise effects in nx_rune
- In eager mode (CPU/Metal), they just update metadata
- In symbolic mode (Rune), they capture the operation for JIT compilation (see the sketch below)

Keep `op_contiguous` for forcing materialization.
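
A hedged sketch of the nx_rune side using OCaml 5 effects (the effect constructor name `E_reshape` is an assumption; the real effect set lives in nx_rune.ml):

```ocaml
type _ Effect.t +=
  | E_reshape : ('a, 'b) t * Symbolic.shape -> ('a, 'b) t Effect.t

(* Eager backends update metadata directly; the Rune backend instead
   performs an effect that the tracing handler records for the JIT. *)
let op_reshape t shape = Effect.perform (E_reshape (t, shape))
```
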
### 5. Update frontend for lazy views (nx/lib/core/frontend.ml)

All view operations become lazy by updating Shape_tracker:
```ocaml
let reshape t ~shape =
  let tracker = Backend.shape_tracker t in
  match Shape_tracker.reshape tracker shape with
  | Some tracker' -> Backend.update_shape_tracker t tracker'
  | None -> (
      (* Cannot reshape lazily, must materialize first *)
      let t' = contiguous t in
      let tracker = Backend.shape_tracker t' in
      match Shape_tracker.reshape tracker shape with
      | Some tracker' -> Backend.update_shape_tracker t' tracker'
      | None -> failwith "reshape: incompatible shapes")
```

### 6. Backend implementation updates

**Native backend (nx/lib/native/)**:
- When ops need to access data, check if the Shape_tracker is contiguous
- For ops requiring contiguous memory: call op_contiguous internally
- Only force realization when absolutely necessary (e.g., incompatible strides, device transfer)
- Use symbolic shape evaluation for runtime shape binding

**Metal backend (nx/lib/metal/)**:
- Similar logic, but some views can map directly onto buffer views
- Generate kernels that handle strided access patterns

**Key insight**: Realization happens automatically in the backend when needed, not explicitly in the frontend (addressing the NOTE about where realization happens). A sketch of this check follows.
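
A minimal sketch of that backend-side check, the `ensure_materializable` pattern from the commit message (helper names assume the Shape_tracker interface from step 3):

```ocaml
let ensure_materializable t =
  let tracker = shape_tracker t in
  match Shape_tracker.real_strides tracker with
  | Some _ -> t (* view chain collapses to plain strided access; no copy *)
  | None -> op_contiguous t (* e.g. pad/shrink chains that cannot compose *)
```
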
### 7. Rune integration (nx_rune.ml)

- Symbolic_tensor carries a Shape_tracker with potentially symbolic shapes
- No new effect needed - symbolic shapes are resolved during JIT lowering
- During lowering, symbolic shapes are resolved to concrete values via variable bindings
- JIT can specialize kernels for different shape bindings
- Similar to tinygrad's Variable.bind() mechanism, but integrated with Rune's effect system

### 8. JIT updates (rune/lib-jit/)

**ir.ml**:
- Define an independent view representation (since rune_jit is separate from nx):
```ocaml
type view = {
  shape : int array;
  strides : int array;
  offset : int;
  mask : (int * int) array option;
}
```
- VIEW appears only in lowered IR, not in the high-level graph
- Symbolic shapes use existing nodes:
  - `Define_Var` and `Bind` already exist for symbolic variables
  - Shape arithmetic uses regular binary ops (Add, Mul, Div, Mod)
  - No special shape-specific nodes needed

**lowerer.ml** (index-calculation sketch below):
- Analyze Shape_tracker to determine if views can be fused
- Insert efficient index calculations for strided access
- Track shape specializations for kernel caching
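
As a sketch of the index calculation the lowerer emits for strided access (written as plain OCaml over the `view` record above; in lowered IR the same arithmetic becomes Add/Mul nodes):

```ocaml
(* flat = offset + sum over i of idx.(i) * strides.(i) *)
let flat_index (v : view) (idx : int array) : int =
  let flat = ref v.offset in
  Array.iteri (fun i s -> flat := !flat + (idx.(i) * s)) v.strides;
  !flat
```
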
### 9. Example Usage

```ocaml
(* Symbolic shapes remain internal - users use Rune.placeholder.
   [weight] and [input] are assumed to be defined elsewhere. *)
let model weight =
  (* [None] dims internally create symbolic batch and sequence length *)
  let x = Rune.placeholder ~shape:[ None; None; Some 768 ] in

  (* Lazy view operations - just update Shape_tracker *)
  let y =
    x
    |> Nx.transpose ~axes:[ 1; 0; 2 ]
    |> Nx.flatten ~start_axis:0 ~end_axis:1 (* Reshapes to [-1, 768] *)
  in

  (* Build computation graph with symbolic shapes *)
  Nx.matmul y weight

(* At runtime, JIT specializes for concrete shapes *)
let result = Rune.run (model weight) input (* Shape [32; 128; 768] triggers compilation *)
```

## Success Criteria

- [ ] Views are lazy and don't copy data unnecessarily
- [ ] Symbolic shapes enable dynamic batching
- [ ] Convolution performance improved
- [ ] Memory usage reduced for view-heavy code
- [ ] JIT generates efficient specialized kernels
- [ ] All existing tests pass

## Risks & Mitigation

1. **API changes**: Gradual migration with a compatibility layer
2. **Symbolic complexity**: Start simple, add features incrementally
3. **Performance regressions**: Benchmark throughout development
4. **Debugging difficulty**: Add shape tracing and view visualization

## Notes

- Shape_tracker is the single source of truth for tensor shape/view information
- Realization is backend-driven, happening only when data access is needed
- Symbolic shapes follow tinygrad's two-stage pattern: define symbolically, bind concretely
- This design maintains nx's clean separation between frontend API and backend implementation

dev/conv2d/nx_conv.ml

Lines changed: 1 addition & 1 deletion

@@ -500,7 +500,7 @@ let winograd_conv2d x w =
   (* Reshape to match tinygrad: [6, 6, 1, groups, rcout, cin, 1, 1] for 2D *)
   let target_shape = [| 6; 6; 1; groups; rcout; cin; 1; 1 |] in
-  let gfactors = (Nx_native.op_reshape gfactors_raw target_shape)[@landmark "winograd_weight_reshape"] in
+  let gfactors = (Nx_native.op_reshape gfactors_raw (Symbolic_shape.of_ints target_shape))[@landmark "winograd_weight_reshape"] in
   (* Prepare input tiles - Winograd needs 6x6 tiles with 4x4 output each *)
   (* Number of 4x4 output tiles needed *)

nx/lib/core/backend_intf.ml

Lines changed: 4 additions & 4 deletions

@@ -19,8 +19,8 @@ module type S = sig
   (* lenses *)

-  val view : ('a, 'b) t -> View.t
-  (** Return the logical view metadata of [t]. *)
+  val view : ('a, 'b) t -> Lazy_view.t
+  (** Return the view tracker for [t]. *)

   val dtype : ('a, 'b) t -> ('a, 'b) Dtype.t
   (** Element type of [t]. *)

@@ -129,10 +129,10 @@ module type S = sig
   (* Movement Ops - manipulate view metadata *)

-  val op_expand : ('a, 'b) t -> int array -> ('a, 'b) t
+  val op_expand : ('a, 'b) t -> Symbolic_shape.t -> ('a, 'b) t
   (** Broadcast dimensions of size 1 to a new shape. *)

-  val op_reshape : ('a, 'b) t -> int array -> ('a, 'b) t
+  val op_reshape : ('a, 'b) t -> Symbolic_shape.t -> ('a, 'b) t
   (** Change the logical shape without moving data. *)

   val op_permute : ('a, 'b) t -> int array -> ('a, 'b) t
