Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/guides/8_autocast.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ Best Practices
#. **Validate with Real Data**:

- Provide representative input data using the ``calibration_data`` option for more accurate node classification.
- The input names and shapes in ``calibration_data`` should match the ones in the given ONNX model.

#. **Control Reduction Depth**:
- Use ``max_depth_of_reduction`` to limit the depth of reduction operations that can be converted to low precision.
Expand Down
13 changes: 12 additions & 1 deletion modelopt/onnx/autocast/referencerunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def __init__(
"""Initialize with ONNX model path."""
self.model = model
self.input_names = [input.name for input in self.model.graph.input]
self.input_shapes = {
input.name: [s.dim_value for s in input.type.tensor_type.shape.dim]
for input in self.model.graph.input
}
self.providers = self._prepare_ep_list_with_trt_plugin_path(providers, trt_plugins)

def _prepare_ep_list_with_trt_plugin_path(self, providers, trt_plugins):
Expand All @@ -69,12 +73,19 @@ def _load_inputs_from_npz(self, input_data_path):
return [np.load(input_data_path)]

def _validate_inputs(self, data_loader):
"""Validate that input names match the model."""
"""Validate that input names and shapes match the model."""
if isinstance(data_loader, list) and (
isinstance(data_loader[0], (dict, np.lib.npyio.NpzFile))
):
if sorted(self.input_names) != sorted(data_loader[0].keys()):
raise ValueError("Input names from ONNX model do not match provided input names.")
for inp_name, inp_shape in data_loader[0].items():
if self.input_shapes[inp_name] != inp_shape.shape:
raise ValueError(
f"Input shape from '{inp_name}' does not match provided input shape: "
f"{self.input_shapes[inp_name]} vs {list(inp_shape.shape)}. "
f"Please make sure that your calibration data matches the ONNX input shapes."
)
else:
raise ValueError("Invalid input file.")

Expand Down
Loading