Skip to content

Commit 82433c7

Browse files
committed
Testing fix
1 parent 1621fa2 commit 82433c7

1 file changed

Lines changed: 20 additions & 17 deletions

File tree

backends/qualcomm/quantizer/quantizer.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -462,26 +462,29 @@ def _replace_inf(self, graph_module: GraphModule) -> GraphModule:
462462
node.args = tuple(arg_list)
463463
elif node.op == "get_attr":
464464
constant_tensor = attrgetter(node.target)(graph_module)
465-
if (
465+
if not (
466466
torch.is_tensor(constant_tensor)
467467
and constant_tensor.is_floating_point()
468+
and torch.isinf(constant_tensor).any()
468469
):
469-
# Anything smaller than float16.min, which covers float32.min and float(-inf)
470-
min_value = torch.finfo(torch.float16).min
471-
# Anything larger than float16.max, which covers float32.max and float(inf)
472-
max_value = torch.finfo(torch.float16).max
473-
474-
quant_min, quant_max = float("inf"), float("-inf")
475-
for source_node in node.users:
476-
if quant_range := self._get_quant_range(source_node):
477-
quant_min = min(quant_min, -quant_range)
478-
quant_max = max(quant_max, quant_range)
479-
480-
if quant_min != float("inf") and quant_max != float("-inf"):
481-
# Inplace update
482-
with torch.no_grad():
483-
constant_tensor[constant_tensor <= min_value] = quant_min
484-
constant_tensor[constant_tensor >= max_value] = quant_max
470+
continue
471+
472+
# Anything smaller than float16.min, which covers float32.min and float(-inf)
473+
min_value = torch.finfo(torch.float16).min
474+
# Anything larger than float16.max, which covers float32.max and float(inf)
475+
max_value = torch.finfo(torch.float16).max
476+
477+
quant_min, quant_max = float("inf"), float("-inf")
478+
for source_node in node.users:
479+
if quant_range := self._get_quant_range(source_node):
480+
quant_min = min(quant_min, -quant_range)
481+
quant_max = max(quant_max, quant_range)
482+
483+
if quant_min != float("inf") and quant_max != float("-inf"):
484+
# Inplace update
485+
with torch.no_grad():
486+
constant_tensor[constant_tensor <= min_value] = quant_min
487+
constant_tensor[constant_tensor >= max_value] = quant_max
485488

486489
graph_module.recompile()
487490

0 commit comments

Comments
 (0)