@@ -117,7 +117,6 @@ def get_cond_while_submodules_ao(
117117 only the ``while_loop`` cond function is processed explicitly there.
118118
119119 """
120-
121120 if not apply_quantization :
122121 return get_cond_while_submodules (graph_module )
123122
@@ -156,6 +155,7 @@ def get_symmetric_quantization_config(
156155 act_qmax (int): Maximum activation quantization value.
157156 weight_qmin (int): Minimum weight quantization value.
158157 weight_qmax (int): Maximum weight quantization value.
158+ eps (float): Minimum scale value used by observers.
159159
160160 Returns:
161161 QuantizationConfig: Quantization settings for activations, weights, and
@@ -525,14 +525,17 @@ def __init__(
525525
526526 @property
527527 def tosa_spec (self ):
528+ """Return the TOSA specification used by the active quantizer."""
528529 return self .quantizer .tosa_spec
529530
530531 @property
531532 def compile_spec (self ):
533+ """Return the compile specification used by the active quantizer."""
532534 return self .quantizer .compile_spec
533535
534536 @property
535537 def global_config (self ):
538+ """Return the fallback quantization configuration."""
536539 return self .quantizer .global_config
537540
538541 @global_config .setter
@@ -546,6 +549,7 @@ def global_config(self, value: Optional[QuantizationConfig]) -> None:
546549
547550 @property
548551 def io_config (self ):
552+ """Return the input and output quantization configuration."""
549553 if isinstance (self .quantizer , _TOSAQuantizerV1 ):
550554 return self .quantizer .io_config
551555 else :
@@ -564,6 +568,7 @@ def io_config(self, value: Optional[QuantizationConfig]) -> None:
564568
565569 @property
566570 def module_type_config (self ):
571+ """Return quantization configuration overrides by module type."""
567572 if isinstance (self .quantizer , _TOSAQuantizerV1 ):
568573 return self .quantizer .module_type_config
569574 else :
@@ -584,6 +589,7 @@ def module_type_config(
584589
585590 @property
586591 def module_name_config (self ):
592+ """Return quantization configuration overrides by module name."""
587593 if isinstance (self .quantizer , _TOSAQuantizerV1 ):
588594 return getattr (self .quantizer , "module_name_config" , {})
589595 else :
@@ -692,6 +698,7 @@ def set_node_finder(
692698 quantization_config (Optional[QuantizationConfig]): Configuration
693699 describing quantization settings for nodes matched by the provided
694700 NodeFinder. ``None`` indicates no quantization.
701+ node_finder (NodeFinder): Predicate used to select nodes.
695702
696703 """
697704 if self .use_composable_quantizer :
@@ -757,14 +764,18 @@ def annotate(self, model: GraphModule) -> GraphModule:
757764 return self .quantizer .annotate (model )
758765
759766 def validate (self , model : GraphModule ) -> None :
760- """Validate the quantization results. Currently, this includes:
761- - Ensure tensor inputs to each operator live on the same device.
767+ """Validate the quantization results.
768+
769+ Currently, this ensures tensor inputs to each operator live on the same
770+ device.
762771
763772 Args:
764773 model (GraphModule): GraphModule being validated.
774+
765775 Raises:
766776 ValueError: If tensor inputs for any operator span more than one
767777 device.
778+
768779 """
769780 for node in model .graph .nodes :
770781 if node .op != "call_function" :
@@ -809,8 +820,7 @@ def _quantize_with_submodules(
809820 is_qat : bool = False ,
810821 fold_quantize : bool = True ,
811822 ):
812- """Quantizes a GraphModule in a way such that conditional submodules are
813- handled properly.
823+ """Quantize a GraphModule with conditional submodule handling.
814824
815825 Note: torchao's prepare_pt2e and convert_pt2e natively handle
816826 while_loop body_fn submodules, so we only manually process cond
@@ -823,8 +833,8 @@ def _quantize_with_submodules(
823833 model with submodules, at least one sample per code path is
824834 needed.
825835 is_qat (bool): Whether to do quantization aware training or not.
826- fold_quantize (bool): Enables or disables constant folding when quantization
827- is completed.
836+ fold_quantize (bool): Enables or disables constant folding when
837+ quantization is completed.
828838
829839 Returns:
830840 GraphModule: The quantized model.
@@ -949,7 +959,6 @@ def _set_disallow_tfa_for_nodes(self, model: GraphModule) -> None:
949959 quantized models.
950960
951961 """
952-
953962 # First, set all nodes according to global config
954963 for node in model .graph .nodes :
955964 node .meta [DISALLOW_TFA_META_KEY ] = self .global_config is None
@@ -1104,10 +1113,10 @@ def __init__(
11041113
11051114 @property
11061115 def quantizers (self ) -> List [Quantizer ]:
1107- """Returns the configured quantizers in order of precedence, ensuring
1108- the global config and shared_qspec_quantizer are applied last.
1116+ """Return the configured quantizers in order of precedence.
11091117
1110- The returned list is a shallow copy; quantizer instances are shared.
1118+ The returned list is a shallow copy; quantizer instances are shared. The
1119+ global config and shared_qspec_quantizer are applied last.
11111120
11121121 """
11131122 quantizers = self ._quantizers .copy ()
@@ -1119,9 +1128,7 @@ def quantizers(self) -> List[Quantizer]:
11191128
11201129 @quantizers .setter
11211130 def quantizers (self , value : List [Quantizer ]) -> None :
1122- """Override of quantizers setter to allow for dynamic updating of
1123- quantizers without accessing self._quantizers.
1124- """
1131+ """Update quantizers without accessing self._quantizers directly."""
11251132 self ._quantizers = value
11261133
11271134 def annotate (self , model ):
0 commit comments