@@ -56,14 +56,68 @@ def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
5656
5757
5858def _prepare_const_values_for_tosa_dtype (
59- values : np .ndarray , tosa_dtype : ts . DType
59+ values : np .ndarray , tosa_arg : TosaArg
6060) -> np .ndarray :
6161 """Normalize constant storage to the expected TOSA serializer dtype."""
62- if tosa_dtype == ts .DType .INT48 and values .dtype != np .int64 :
62+ if tosa_arg . dtype == ts .DType .INT48 and values .dtype != np .int64 :
6363 return values .astype (np .int64 )
6464 return values
6565
6666
67+ def _get_const_shape (values : np .ndarray , tosa_arg : TosaArg ) -> list [int ]:
68+ """Return the TOSA logical shape for a serialized constant."""
69+ if tosa_arg .dtype == ts .DType .FP4E2M1 :
70+ return normalize_symint (tosa_arg .shape )
71+ return normalize_symint (values .shape )
72+
73+
74+ def _is_packed_fp4_const (values : np .ndarray , tosa_arg : TosaArg ) -> bool :
75+ """FP4 elements are pairwise in each byte of a uint8 tensor.
76+
77+ This function checks if the given values and TOSA argument represent a
78+ packed FP4 constant.
79+
80+ """
81+
82+ return (
83+ tosa_arg .dtype == ts .DType .FP4E2M1
84+ and values .dtype == np .uint8
85+ and values .shape [- 1 ] * 2 == tosa_arg .shape [- 1 ]
86+ )
87+
88+
89+ def _add_const (
90+ tosa_graph : Any ,
91+ values : np .ndarray ,
92+ tosa_arg : TosaArg ,
93+ name : str ,
94+ ) -> None :
95+ """Add a constant, preserving packed FP4 storage when required."""
96+ if _is_packed_fp4_const (values , tosa_arg ):
97+ # TOSA FP4 tensors have logical FP4 shape, but constants are stored as
98+ # packed bytes (two values per byte). Add the raw bytes as INT8 first
99+ # then set TOSA dtype and shape correctly on the tensor metadata.
100+ tosa_graph .addConst (
101+ normalize_symint (values .shape ),
102+ ts .DType .INT8 ,
103+ values ,
104+ name = name ,
105+ )
106+ tensor = tosa_graph .currRegion .currBasicBlock .tensors [name ]
107+ tensor .setDtype (ts .DType .FP4E2M1 )
108+ for dim , size in enumerate (normalize_symint (tosa_arg .shape )):
109+ tensor .SetDimSize (dim , size )
110+ return
111+
112+ prepared_values = _prepare_const_values_for_tosa_dtype (values , tosa_arg )
113+ tosa_graph .addConst (
114+ _get_const_shape (prepared_values , tosa_arg ),
115+ tosa_arg .dtype ,
116+ prepared_values ,
117+ name = name ,
118+ )
119+
120+
67121def process_call_function (
68122 node : torch .fx .Node ,
69123 tosa_graph : Any ,
@@ -154,16 +208,7 @@ def process_inputs_to_parameters(
154208 f"{ type (parameter_data ).__name__ } "
155209 )
156210 parameter_values = _tensor_to_numpy (parameter_data )
157- parameter_values = _prepare_const_values_for_tosa_dtype (
158- parameter_values , tosa_arg .dtype
159- )
160-
161- tosa_graph .addConst (
162- normalize_symint (parameter_values .shape ),
163- tosa_arg .dtype ,
164- parameter_values ,
165- name = tosa_arg .name ,
166- )
211+ _add_const (tosa_graph , parameter_values , tosa_arg , name = tosa_arg .name )
167212
168213
169214def process_inputs_to_buffers (
@@ -188,14 +233,7 @@ def process_inputs_to_buffers(
188233 f"{ type (buffer_data ).__name__ } "
189234 )
190235 buffer_values = _tensor_to_numpy (buffer_data )
191- buffer_values = _prepare_const_values_for_tosa_dtype (buffer_values , tosa_arg .dtype )
192-
193- tosa_graph .addConst (
194- normalize_symint (buffer_values .shape ),
195- tosa_arg .dtype ,
196- buffer_values ,
197- name = tosa_arg .name ,
198- )
236+ _add_const (tosa_graph , buffer_values , tosa_arg , name = tosa_arg .name )
199237
200238
201239def process_inputs_to_lifted_tensor_constants (
@@ -217,14 +255,7 @@ def process_inputs_to_lifted_tensor_constants(
217255 f"{ type (tensor ).__name__ } "
218256 )
219257 tensor_values = _tensor_to_numpy (tensor )
220- tensor_values = _prepare_const_values_for_tosa_dtype (tensor_values , tosa_arg .dtype )
221-
222- tosa_graph .addConst (
223- normalize_symint (tensor_values .shape ),
224- tosa_arg .dtype ,
225- tensor_values ,
226- name = tosa_arg .name ,
227- )
258+ _add_const (tosa_graph , tensor_values , tosa_arg , name = tosa_arg .name )
228259
229260
230261def _is_submodule_input (
0 commit comments