@@ -462,26 +462,29 @@ def _replace_inf(self, graph_module: GraphModule) -> GraphModule:
462462 node .args = tuple (arg_list )
463463 elif node .op == "get_attr" :
464464 constant_tensor = attrgetter (node .target )(graph_module )
465- if (
465+ if not (
466466 torch .is_tensor (constant_tensor )
467467 and constant_tensor .is_floating_point ()
468+ and torch .isinf (constant_tensor ).any ()
468469 ):
469- # Anything smaller than float16.min, which covers float32.min and float(-inf)
470- min_value = torch .finfo (torch .float16 ).min
471- # Anything larger than float16.max, which covers float32.max and float(inf)
472- max_value = torch .finfo (torch .float16 ).max
473-
474- quant_min , quant_max = float ("inf" ), float ("-inf" )
475- for source_node in node .users :
476- if quant_range := self ._get_quant_range (source_node ):
477- quant_min = min (quant_min , - quant_range )
478- quant_max = max (quant_max , quant_range )
479-
480- if quant_min != float ("inf" ) and quant_max != float ("-inf" ):
481- # Inplace update
482- with torch .no_grad ():
483- constant_tensor [constant_tensor <= min_value ] = quant_min
484- constant_tensor [constant_tensor >= max_value ] = quant_max
470+ continue
471+
472+ # Anything smaller than float16.min, which covers float32.min and float(-inf)
473+ min_value = torch .finfo (torch .float16 ).min
474+ # Anything larger than float16.max, which covers float32.max and float(inf)
475+ max_value = torch .finfo (torch .float16 ).max
476+
477+ quant_min , quant_max = float ("inf" ), float ("-inf" )
478+ for source_node in node .users :
479+ if quant_range := self ._get_quant_range (source_node ):
480+ quant_min = min (quant_min , - quant_range )
481+ quant_max = max (quant_max , quant_range )
482+
483+ if quant_min != float ("inf" ) and quant_max != float ("-inf" ):
484+ # Inplace update
485+ with torch .no_grad ():
486+ constant_tensor [constant_tensor <= min_value ] = quant_min
487+ constant_tensor [constant_tensor >= max_value ] = quant_max
485488
486489 graph_module .recompile ()
487490
0 commit comments