@@ -159,6 +159,61 @@ def test_mxfp_conv2d_quantize_supports_fp4_weights() -> None:
159159 )
160160
161161
def test_mxfp_conv2d_preserves_bfloat16_output_dtype() -> None:
    """Check that a bfloat16 model still yields bfloat16 after ``to_mxfp``.

    A ``Conv2dModule`` is cast to bfloat16, converted with an FP8 (e4m3fn)
    weight config, and run once on a bfloat16 input. The test then asserts
    that ``model.conv`` is an ``MXFPConv2dOp``, that its ``output_dtype`` is
    bfloat16, and that the forward result is bfloat16 as well.
    """
    bf16_model = Conv2dModule().eval().to(torch.bfloat16)
    fp8_config = MXFPOpConfig(weight_dtype=torch.float8_e4m3fn)
    to_mxfp(bf16_model, fp8_config)

    sample = torch.randn(1, IN_CHANNELS, 8, 8, dtype=torch.bfloat16)
    result = bf16_model(sample)

    converted = bf16_model.conv
    assert isinstance(converted, MXFPConv2dOp)
    assert converted.output_dtype == torch.bfloat16
    assert result.dtype == torch.bfloat16
174+
175+
def test_mxfp_conv2d_op_output_dtype_constructor_arg() -> None:
    """Check the ``output_dtype`` constructor argument of ``MXFPConv2dOp``.

    Two ops are built from the quantized state of a converted
    ``Conv2dModule``: one without ``output_dtype`` and one with
    ``output_dtype=torch.bfloat16``. The test asserts that the first reports
    and produces float32, and the second reports and produces bfloat16, for
    the same float32 input.
    """
    module = Conv2dModule().eval()
    cfg = MXFPOpConfig(weight_dtype=torch.float8_e4m3fn)
    to_mxfp(module, cfg)
    quantized = module.conv
    assert isinstance(quantized, MXFPConv2dOp)

    # Positional arguments common to both ops under test; only the
    # ``output_dtype`` keyword differs between them.
    shared_args = (
        quantized.weight_qdata,
        quantized.weight_scale,
        quantized.bias,
        quantized.stride,
        quantized.padding,
        quantized.dilation,
        quantized.groups,
        cfg.weight_dtype,
        cfg.block_size,
    )
    default_conv = MXFPConv2dOp(*shared_args)
    bf16_conv = MXFPConv2dOp(*shared_args, output_dtype=torch.bfloat16)

    sample = torch.randn(1, IN_CHANNELS, 8, 8)

    assert default_conv.output_dtype == torch.float32
    assert default_conv(sample).dtype == torch.float32
    assert bf16_conv.output_dtype == torch.bfloat16
    assert bf16_conv(sample).dtype == torch.bfloat16
215+
216+
162217def _test_mxfp_conv2d_export_preserves_custom_op (config : MXFPOpConfig ) -> None :
163218 model = Conv2dModule ().eval ()
164219 to_mxfp (model , config )
@@ -198,6 +253,33 @@ def test_mxfp6_e3m2_conv2d_export_preserves_custom_op() -> None:
198253 )
199254
200255
def test_mxfp_conv2d_export_preserves_inferred_bfloat16_output_dtype() -> None:
    """Check the exported graph casts the MXFP conv result to bfloat16.

    A bfloat16 ``Conv2dModule`` is converted with an FP8 (e4m3fn) weight
    config and exported with ``strict=False``. The test asserts that the
    graph holds exactly one ``aten.to.dtype`` call, that it targets bfloat16
    (both as its argument and in its ``meta["val"]``), and that its input is
    the ``tosa_mxfp.conv2d`` node.
    """
    bf16_model = Conv2dModule().eval().to(torch.bfloat16)
    to_mxfp(bf16_model, MXFPOpConfig(weight_dtype=torch.float8_e4m3fn))

    example_inputs = (torch.randn(1, IN_CHANNELS, 8, 8, dtype=torch.bfloat16),)
    program = export(bf16_model, example_inputs, strict=False)

    def is_dtype_cast(node) -> bool:
        # Only function-call nodes are compared against the cast overload.
        return (
            node.op == "call_function"
            and node.target == torch.ops.aten.to.dtype
        )

    dtype_casts = list(filter(is_dtype_cast, program.graph_module.graph.nodes))

    assert len(dtype_casts) == 1
    cast = dtype_casts[0]
    assert cast.args[1] == torch.bfloat16
    assert cast.meta["val"].dtype == torch.bfloat16
    producer = cast.args[0]
    assert isinstance(producer, torch.fx.Node)
    assert producer.target == torch.ops.tosa_mxfp.conv2d.default
281+
282+
201283def test_mxfp_conv2d_cpu_impl_matches_ref () -> None :
202284 ref_model = Conv2dModule ().eval ()
203285 test_model = Conv2dModule ().eval ()
0 commit comments