diff --git a/backends/xnnpack/test/models/llama2_et_example.py b/backends/xnnpack/test/models/llama2_et_example.py index 378f9dd3d48..23bc9e9aa49 100644 --- a/backends/xnnpack/test/models/llama2_et_example.py +++ b/backends/xnnpack/test/models/llama2_et_example.py @@ -17,9 +17,11 @@ def setUp(self): torch._dynamo.reset() def test_f32(self): + torch.manual_seed(0) self._test() def test_f16(self): + torch.manual_seed(0) self._test(torch.float16) # TODO - dynamic shape @@ -31,7 +33,22 @@ def _test(self, dtype: torch.dtype = torch.float): ], f"Only fp32 and fp16 are supported, but got dtype: {dtype}" llama2 = Llama2Model() - model = llama2.get_eager_model().to(dtype) + model = llama2.get_eager_model() + # The example uses a dummy small model with random weights for demo + # purposes only. Default torch init (e.g. nn.Embedding ~ N(0, 1)) + # combined with the model dim produces intermediate activations that + # overflow in fp16 (max ~65504), yielding nan/-inf and making the + # output comparison flaky. Re-init parameters AND float buffers (RoPE + # tables, causal mask, etc.) to a small bounded range so activations + # stay representable; this still exercises the export + lowering + # pipeline. + with torch.no_grad(): + for p in model.parameters(): + p.uniform_(-0.02, 0.02) + for b in model.buffers(): + if b.is_floating_point(): + b.uniform_(-0.02, 0.02) + model = model.to(dtype) # Only convert fp32 inputs to dtype example_inputs = tuple( diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py index dc92a9542a9..d45f20e786d 100644 --- a/backends/xnnpack/test/ops/test_linear.py +++ b/backends/xnnpack/test/ops/test_linear.py @@ -547,6 +547,7 @@ def get_qnode_checks(quant_node_checks, dialect): # ) def test_qd8_f32_per_channel_shared_dq_chain(self): + torch.manual_seed(42) for use_bias in (False, True): module = SharedDQChain( input_size=13, @@ -561,6 +562,7 @@ def test_qd8_f32_per_channel_shared_dq_chain(self): is_per_channel=True, linear_count=2, uses_bias=use_bias, + atol=1.5e-1, # TODO(T212995726): Investigate right atol for rand[n] inputs ) def _test_qd8_per_channel_linear(self, dtype: torch.dtype = torch.float):