Merged
Changes from 165 commits

Commits (180)
7ee9b05
old benchmarking files added
Sep 22, 2025
dc20c8b
benchmark_runner updated
tmfreiberg Sep 23, 2025
366bb4f
breadth sweep onnx models go in breadth directory
tmfreiberg Sep 23, 2025
d63e9cf
prettifying benchmark commands
tmfreiberg Sep 23, 2025
c4cb632
fixes
tmfreiberg Sep 23, 2025
fe72efd
fix
tmfreiberg Sep 23, 2025
89c8957
eye candy added
tmfreiberg Sep 23, 2025
8704e73
eye candy edits
tmfreiberg Sep 23, 2025
f1a316c
refactor benchmark_runner
tmfreiberg Sep 23, 2025
88675af
refactor gen_and_bench
tmfreiberg Sep 23, 2025
a735a5d
refactor benchmarking_helpers
tmfreiberg Sep 23, 2025
1b27530
CLI jstprove -> jst part one
tmfreiberg Sep 23, 2025
7d7d393
CLI jstprove -> jst part two (documentation)
tmfreiberg Sep 23, 2025
539eae4
capture ecc output
tmfreiberg Sep 23, 2025
6a93c99
capture ecc output
tmfreiberg Sep 23, 2025
a6103ca
capture ecc output
tmfreiberg Sep 23, 2025
85e147b
summary card formatting
tmfreiberg Sep 23, 2025
1582015
summary card formatting
tmfreiberg Sep 23, 2025
4ead1fd
adjusting default depth parameters
tmfreiberg Sep 24, 2025
9aa47e0
adding fixed benchmarking model (lenet)
tmfreiberg Sep 26, 2025
d5a4b47
benchmarking md added; --summarize flag removed from cli
tmfreiberg Sep 26, 2025
5dc5357
added jsonl to gitignore
tmfreiberg Sep 26, 2025
c55d523
jst bench lenet fix
tmfreiberg Sep 29, 2025
a807c90
lenet bench fix 2
tmfreiberg Sep 29, 2025
c962831
lenet bench fix 3
tmfreiberg Sep 29, 2025
067dd0c
jst bench lenet fix continued
tmfreiberg Sep 30, 2025
ad92160
lenet_fixed jsonl -> lenet jsonl
tmfreiberg Sep 30, 2025
988d936
Merge branch 'main' into benchmarking
jsgold-1 Sep 30, 2025
32c00ba
removed legacy lazy gen_and_bench import from cli
Sep 30, 2025
738bdb5
Linting/formatting
jsgold-1 Oct 1, 2025
b18e9f9
fixed text formatting errors arising from linter conformity
Oct 6, 2025
aef5e0a
Add other model choice to benchmarking CLI (#68)
jsgold-1 Oct 7, 2025
4d3cd41
Refactor cli and merge with main
jsgold-1 Oct 8, 2025
5da3877
Slight messaging chang
jsgold-1 Oct 8, 2025
5453012
Refactor bench, move helper functinos, add decorator for optional pat…
HudsonGraeme Oct 9, 2025
5d23561
Avoid top-level import from model registry
HudsonGraeme Oct 9, 2025
dd5d45a
Add constants, update ensure parent dir
HudsonGraeme Oct 9, 2025
00e9252
Minor docs changes
jsgold-1 Oct 9, 2025
2917ba5
Bring in tests from quantizer_tests and linting/formatting
jsgold-1 Oct 10, 2025
aa75b91
Merge branch 'main' into single_layer_tests
jsgold-1 Oct 10, 2025
458df51
Add base fix
jsgold-1 Oct 10, 2025
6859cfe
Minor test improvements
jsgold-1 Oct 11, 2025
cdbc74f
Add quantization to original layer comp checks
jsgold-1 Oct 11, 2025
bfa988a
Refactor layer_tests
jsgold-1 Oct 14, 2025
c2fd8dc
Add e2e tests + scalability, ensuring each layer has an e2e test
jsgold-1 Oct 15, 2025
2796e05
Fix maxpool
jsgold-1 Oct 15, 2025
1705a09
Change absence of --model flag in pytest e2e run
jsgold-1 Oct 16, 2025
c42a694
Add tests
jsgold-1 Oct 17, 2025
7047160
adding max layer
tmfreiberg Oct 24, 2025
f3f241d
singular API fix
tmfreiberg Oct 24, 2025
620768f
as type errors addressed
tmfreiberg Oct 24, 2025
6599c98
more errors
tmfreiberg Oct 24, 2025
16147f2
more errors
tmfreiberg Oct 24, 2025
0472503
more errors
tmfreiberg Oct 24, 2025
1861947
more errors
tmfreiberg Oct 24, 2025
624ad37
more errors
tmfreiberg Oct 24, 2025
e460f31
more errors
tmfreiberg Oct 24, 2025
b309068
more errors
tmfreiberg Oct 24, 2025
73839df
more errors
tmfreiberg Oct 24, 2025
1f7089c
more errors
tmfreiberg Oct 24, 2025
0ab5745
more errors
tmfreiberg Oct 24, 2025
cb2e4ef
more errors
tmfreiberg Oct 24, 2025
cc98084
more errors
tmfreiberg Oct 24, 2025
c00d2fd
Quantization refactoring
jsgold-1 Oct 28, 2025
d58e6d4
Rework w and b loading
jsgold-1 Oct 31, 2025
6cff583
Fix multi inputs, quantization refactor and add Add layer
jsgold-1 Nov 4, 2025
bf6f6b7
Add multi-input layer support
jsgold-1 Nov 6, 2025
0a7c0f4
Finalize support for multi-inputs/outputs and add add with initialize…
jsgold-1 Nov 11, 2025
4a9cd8f
Broadcasting and scalar support for Add
jsgold-1 Nov 11, 2025
e315b8b
Docs update
jsgold-1 Nov 11, 2025
2549278
Docs update
jsgold-1 Nov 11, 2025
1c81c24
starting simple with no quantization for max layer
tmfreiberg Nov 11, 2025
f695e58
forgot imports in max.py
tmfreiberg Nov 11, 2025
3e042cf
Merge with main changes
jsgold-1 Nov 11, 2025
d5b3a6f
Merge changes from main/single_layer_tests
jsgold-1 Nov 11, 2025
becdc36
Merge branch 'quantization_refactor' into single_layer_tests_v2
tmfreiberg Nov 11, 2025
41d9673
fix post test_end_to_end_quantization_accuracy update
tmfreiberg Nov 11, 2025
55e8e28
First review fixes and testing update
jsgold-1 Nov 12, 2025
8889a45
Secondary review changes
jsgold-1 Nov 12, 2025
1bf57d3
missing scale error
tmfreiberg Nov 12, 2025
afcbe3e
whoops
tmfreiberg Nov 12, 2025
990e568
mirror add
tmfreiberg Nov 12, 2025
743feaf
min copy max
tmfreiberg Nov 12, 2025
9dfc1f5
More review updates
jsgold-1 Nov 12, 2025
67001f9
min tweak consistent with max
tmfreiberg Nov 12, 2025
b0aaa05
trying clip layer
tmfreiberg Nov 12, 2025
6ce605e
fix import location
tmfreiberg Nov 12, 2025
83ba2d6
Fix review comments
jsgold-1 Nov 12, 2025
362d021
Minor code review changes
jsgold-1 Nov 12, 2025
9b67e84
Merge quantization_refactor into single_layer_tests_v2
tmfreiberg Nov 13, 2025
ef3a601
max layer rust side
tmfreiberg Nov 13, 2025
0d36064
import MaxLayer in layer_kinds.rs
tmfreiberg Nov 13, 2025
459191d
whoops MaxLayer lives in max no maxpool
tmfreiberg Nov 13, 2025
7fd5cd2
finishing min layer on rust side
tmfreiberg Nov 13, 2025
090cbeb
refactor max, min, maxpool, core_math
tmfreiberg Nov 13, 2025
578c9c5
remove unused imports in maxpool.rs
tmfreiberg Nov 13, 2025
2a2de4d
doc edits. max and min e2e work. clip not yet. about to refactor with…
tmfreiberg Nov 14, 2025
9669787
range check gadget added, refactor constrained_max, constrained_min (…
tmfreiberg Nov 14, 2025
f73104d
clip rust side
tmfreiberg Nov 14, 2025
1cb6340
rem_bits unused variable; prefix underscore
tmfreiberg Nov 14, 2025
9cfa4e1
UtilsError -> CircuitError in signature of range_check_pow2 function
tmfreiberg Nov 14, 2025
5341215
clip_config made proper
tmfreiberg Nov 14, 2025
d9a8aa1
forgot to register Clip in onnx_op_quantizer
tmfreiberg Nov 14, 2025
b0e6ebc
clip config update address scalar tensor versus non scalar shapes issue
tmfreiberg Nov 14, 2025
e2708d9
get_test_specs in clip_config updated
tmfreiberg Nov 14, 2025
0fe4450
type mismatch float/int clip addressed
tmfreiberg Nov 17, 2025
2f0386c
address empty input shape problem in clip_config
tmfreiberg Nov 17, 2025
8a6beb4
rename MaxMinAssertionContext to ShiftRangeContext
tmfreiberg Nov 17, 2025
2ebad68
Merge branch 'main' into single_layer_tests_v2
tmfreiberg Nov 21, 2025
9dbb68c
Cleaned up redundant assignments, fixed docstring, updated tests
tmfreiberg Nov 21, 2025
ba2ad5e
addressing errors
tmfreiberg Nov 21, 2025
043eef6
Code Rabbit nits
tmfreiberg Nov 21, 2025
b7641d5
Address ruff/clippy/pre-commit feedback for clip/max/min
tmfreiberg Nov 21, 2025
36748be
closing delimiters added (linter's fault!)
tmfreiberg Nov 21, 2025
df93713
logup added to core_math
tmfreiberg Nov 25, 2025
e7e8f52
added hints.rs
tmfreiberg Nov 25, 2025
8290b04
added logup code to main_runner
tmfreiberg Nov 25, 2025
a02ad22
remove EMptyHintCaller import in main_runner
tmfreiberg Nov 25, 2025
c16ee61
added CircuitField import to main_runner
tmfreiberg Nov 25, 2025
2414c09
one-shot logup step 1 core_math
tmfreiberg Nov 25, 2025
9e3836b
one-shot logup step 2
tmfreiberg Nov 25, 2025
80b7acc
corrections after step 2
tmfreiberg Nov 25, 2025
26a5acf
extending logup to all range checks (except in quantize)
tmfreiberg Nov 26, 2025
bef0797
closing delimiter in min.rs apply
tmfreiberg Nov 26, 2025
d021bd0
syntax error min.rs
tmfreiberg Nov 26, 2025
31bf84c
using logup for rescaling remainder check etc. in quantization
tmfreiberg Nov 26, 2025
0389ee0
debugging rescale
tmfreiberg Nov 26, 2025
e55c5ac
use LogUp for all range checks in quantization process
tmfreiberg Nov 26, 2025
5acbfe2
moving hints.rs
tmfreiberg Nov 28, 2025
7dc731f
refactor move hints
tmfreiberg Nov 28, 2025
41e07b1
move unconstrained_max
tmfreiberg Nov 28, 2025
945a985
move unconstrained_max
tmfreiberg Nov 28, 2025
65c760c
no layerkind in unconstrained_max, hints shouldn't know about ONNX
tmfreiberg Nov 28, 2025
a48aa43
whoops max_min_clip module imported twice in hints mod
tmfreiberg Nov 28, 2025
a7fb66a
moved unconstrained_min and unconstrained_clip to hints/max_min_clip
tmfreiberg Nov 28, 2025
9dd7307
forgot to import to core_math from hints
tmfreiberg Nov 28, 2025
17f862d
forgot to change layerkind error unconstrained_min
tmfreiberg Nov 28, 2025
7b2d6cc
renaming and moving bit operations
tmfreiberg Nov 28, 2025
8976596
fix imports
tmfreiberg Nov 28, 2025
4af9078
Field trait -> bits
tmfreiberg Nov 28, 2025
86504b6
field::FieldTrait
tmfreiberg Nov 28, 2025
f1d6a0b
fix FieldArith import issue
tmfreiberg Nov 28, 2025
99c2bb4
moving logup/range check functions from core_math to range_check
tmfreiberg Dec 3, 2025
8fd8c6f
fixing imports
tmfreiberg Dec 3, 2025
4df8d18
still fixing imports
tmfreiberg Dec 3, 2025
677eb7f
still fixing imports
tmfreiberg Dec 3, 2025
bd19d31
last remnants of core_math moved to max_min_clip
tmfreiberg Dec 3, 2025
06a7a04
address warning no unconstrained_clip
tmfreiberg Dec 3, 2025
51e9b58
relu layer uses logup
tmfreiberg Dec 3, 2025
3e26a1e
fix errors
tmfreiberg Dec 3, 2025
3fb2b71
gradually improving import organisation and docstrings
tmfreiberg Dec 4, 2025
c40c151
gadget and hint import organising and docstrings
tmfreiberg Dec 4, 2025
989265a
starting to standardize docstrings for gadgets
tmfreiberg Dec 4, 2025
6632adf
more docstring revision
tmfreiberg Dec 4, 2025
279ab3e
docstrings, clippy
tmfreiberg Dec 5, 2025
bf1eae7
clip x_bc deref to variable
tmfreiberg Dec 5, 2025
2e1d83a
docstrings format
tmfreiberg Dec 5, 2025
8082e7b
time trailing whitespace/fix end of files
tmfreiberg Dec 5, 2025
532c680
Clean up lint issues and update poetry.lock
tmfreiberg Dec 5, 2025
e56e7a3
silence clippy on docstring nits
tmfreiberg Dec 5, 2025
d7ea66e
Merge remote-tracking branch 'origin/main' into logup_v2
tmfreiberg Dec 5, 2025
c8ae9cf
remove Cast handler registration
tmfreiberg Dec 5, 2025
a1a216d
put model_quant.onnx under tmp_path in out_path for quantized model i…
tmfreiberg Dec 5, 2025
148872f
updated empty_tensor test case
tmfreiberg Dec 5, 2025
f223b64
revised incomplete error message in bits.rs
tmfreiberg Dec 5, 2025
f4c5c31
fixed error message in unconstrained_min (copy-pasted from unconstrai…
tmfreiberg Dec 5, 2025
89e02ca
Verify bounds checking for layer.inputs access
tmfreiberg Dec 5, 2025
fb012db
forgot to save min.rs
tmfreiberg Dec 5, 2025
b476e19
Fix MaxQuantizer.__init__ to honor BaseOpQuantizer’s new_initializers…
tmfreiberg Dec 5, 2025
04fc017
added tmp_path: Path to signature of test_tiny_conv
tmfreiberg Dec 5, 2025
9a1acde
Specify float32 dtype for ONNX compatibility in initializer overrides
tmfreiberg Dec 5, 2025
fe06eda
address overflow concerns
tmfreiberg Dec 5, 2025
38e11bf
fix layerkind typo
tmfreiberg Dec 5, 2025
8f0f7c4
Remove poetry.lock from tracking
tmfreiberg Dec 5, 2025
b81502c
Apply dtype cast consistently to all initializer overrides.
tmfreiberg Dec 5, 2025
3028d4b
Testing fixes
jsgold-1 Dec 5, 2025
93a8246
Linting final touches
jsgold-1 Dec 10, 2025
61aa418
added initializer test to e2e tests
tmfreiberg Dec 11, 2025
f271075
added e2e tests for broadcasting
tmfreiberg Dec 11, 2025
1a9fafd
corrected off-by-one error
tmfreiberg Dec 11, 2025
6 changes: 3 additions & 3 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -3,7 +3,7 @@

## Related Issue
<!-- Link to related GitHub issue (e.g. "Fixes #123", "Addresses #456") -->
-
-

## Type of Change
<!-- Delete options that don't apply -->
@@ -24,8 +24,8 @@

## Deployment Notes
<!-- Special considerations for deployment (migrations, config changes, etc.) -->
-
-

## Additional Comments
<!-- Any other important context for reviewers -->
-
-
1,879 changes: 1,879 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

41 changes: 29 additions & 12 deletions python/core/circuits/base.py
@@ -4,13 +4,12 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any

from numpy import asarray, ndarray
import numpy as np

from python.core.utils.errors import ShapeMismatchError
from python.core.utils.witness_utils import compare_witness_to_io, load_witness

if TYPE_CHECKING:
import numpy as np
import torch

from python.core.circuits.errors import (
@@ -775,18 +774,18 @@ def _gen_witness_preprocessing(
def reshape_inputs_for_inference(
self: Circuit,
inputs: dict[str],
) -> ndarray | dict[str, ndarray]:
) -> np.ndarray | dict[str, np.ndarray]:
"""
Reshape input tensors to match the model's expected input shape.

Parameters
----------
inputs : dict[str] or ndarray
inputs : dict[str] or np.ndarray
Input tensors or a dictionary of tensors.

Returns
-------
ndarray or dict[str, ndarray]
np.ndarray or dict[str, np.ndarray]
Reshaped input(s) ready for inference.
"""

@@ -801,15 +800,33 @@ def reshape_inputs_for_inference(
if isinstance(inputs, dict):
if len(inputs) == 1:
only_key = next(iter(inputs))
inputs = asarray(inputs[only_key])
value = np.asarray(inputs[only_key])

# If shape is a dict, extract the shape for this key
if isinstance(shape, dict):
key_shape = shape.get(only_key, None)
if key_shape is None:
raise CircuitConfigurationError(
missing_attributes=[f"input_shape[{only_key!r}]"],
)
shape = key_shape

# From here on, treat it as a regular reshape
inputs = value
else:
return self._reshape_dict_inputs(inputs, shape)

# --- Regular reshape ---
if not isinstance(shape, (list, tuple)):
msg = (
f"Expected list or tuple shape for reshape, got {type(shape).__name__}"
)
raise CircuitInputError(msg)

try:
return asarray(inputs).reshape(shape)
return np.asarray(inputs).reshape(shape)
except Exception as e:
raise ShapeMismatchError(shape, list(asarray(inputs).shape)) from e
raise ShapeMismatchError(shape, list(np.asarray(inputs).shape)) from e

def _reshape_dict_inputs(
self: Circuit,
@@ -824,7 +841,7 @@ def _reshape_dict_inputs(
)
raise CircuitInputError(msg, parameter="shape", expected="dict")
for key, value in inputs.items():
tensor = asarray(value)
tensor = np.asarray(value)
try:
inputs[key] = tensor.reshape(shape[key])
except Exception as e:
@@ -867,16 +884,16 @@ def reshape_inputs_for_circuit(
value = inputs[key]

# --- handle unsupported input types BEFORE entering try ---
if not isinstance(value, (ndarray, list, tuple)):
if not isinstance(value, (np.ndarray, list, tuple)):
msg = f"Unsupported input type for key '{key}': {type(value).__name__}"
raise CircuitProcessingError(message=msg)

try:
# Convert to tensor, flatten, and back to list
if isinstance(value, ndarray):
if isinstance(value, np.ndarray):
flattened = value.flatten().tolist()
else:
flattened = asarray(value).flatten().tolist()
flattened = np.asarray(value).flatten().tolist()
except Exception as e:
msg = f"Failed to flatten input '{key}' (type {type(value).__name__})"
raise CircuitProcessingError(message=msg) from e
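For orientation, here is a minimal standalone sketch of the single-input dispatch this diff adds to `reshape_inputs_for_inference`: a one-entry input dict is unwrapped, a dict-valued shape is resolved by key, and control then falls through to the ordinary reshape. The `reshape_single_input` helper is hypothetical and not part of the `Circuit` class; error types are simplified.

```python
import numpy as np


def reshape_single_input(inputs: dict, shape) -> np.ndarray:
    """Sketch of the single-key dispatch: unwrap the only entry, then
    resolve a per-input shape dict before the plain reshape."""
    only_key = next(iter(inputs))
    value = np.asarray(inputs[only_key])

    if isinstance(shape, dict):
        key_shape = shape.get(only_key)
        if key_shape is None:
            # The real code raises CircuitConfigurationError here.
            raise KeyError(f"input_shape[{only_key!r}] missing")
        shape = key_shape

    if not isinstance(shape, (list, tuple)):
        raise TypeError(f"Expected list or tuple shape, got {type(shape).__name__}")

    return value.reshape(shape)


# Example: a single named input with a dict of per-input shapes.
batch = {"input": np.arange(12)}
print(reshape_single_input(batch, {"input": [1, 3, 4]}).shape)  # (1, 3, 4)
```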
3 changes: 1 addition & 2 deletions python/core/circuits/errors.py
@@ -1,4 +1,3 @@
# python/core/utils/exceptions.py
from __future__ import annotations

from python.core.utils.helper_functions import RunType
@@ -68,7 +67,7 @@ class CircuitInputError(CircuitError):
actual (any): Actual value encountered (optional).
"""

def __init__( # noqa: PLR0913
def __init__(
self: CircuitInputError,
message: str | None = None,
parameter: str | None = None,
6 changes: 3 additions & 3 deletions python/core/model_processing/converters/base.py
@@ -2,7 +2,7 @@

from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import numpy as np
@@ -16,10 +16,10 @@ class ModelType(Enum):

ONNXLayerDict = dict[
str,
Union[int, str, list[str], dict[str, list[int]], Optional[list], Optional[dict]],
int | str | list[str] | dict[str, list[int]] | list | None | dict,
]

CircuitParamsDict = dict[str, Union[int, dict[str, bool]]]
CircuitParamsDict = dict[str, int | dict[str, bool]]


class ModelConverter(ABC):
8 changes: 4 additions & 4 deletions python/core/model_processing/onnx_custom_ops/__init__.py
@@ -1,16 +1,16 @@
import importlib
import pkgutil
import os
from pathlib import Path

# Get the package name of the current module
package_name = __name__

# Dynamically import all .py files in this package directory (except __init__.py)
package_dir = os.path.dirname(__file__)
package_dir = Path(__file__).parent

__all__ = []
__all__: list[str] = []

for _, module_name, is_pkg in pkgutil.iter_modules([package_dir]):
if not is_pkg and (module_name != "custom_helpers"):
importlib.import_module(f"{package_name}.{module_name}")
__all__.append(module_name)
__all__.append(str(module_name)) # noqa: PYI056
4 changes: 2 additions & 2 deletions python/core/model_processing/onnx_quantizer/exceptions.py
@@ -31,7 +31,7 @@ class InvalidParamError(QuantizationError):
quantization the quantization process.
"""

def __init__( # noqa: PLR0913
def __init__(
self: QuantizationError,
node_name: str,
op_type: str,
@@ -151,7 +151,7 @@ class InvalidConfigError(QuantizationError):
def __init__(
self: QuantizationError,
key: str,
value: str | float | bool | None,
value: str | float | bool | None, # noqa: FBT001
expected: str | None = None,
) -> None:
"""Initialize InvalidConfigError with context about the bad config.
34 changes: 34 additions & 0 deletions python/core/model_processing/onnx_quantizer/layers/base.py
@@ -418,6 +418,40 @@ def insert_scale_node(


class QuantizerBase:
"""
Shared mixin implementing the generic INT64 quantization pipeline.

IMPORTANT:
QuantizerBase is *not* a standalone quantizer. It must always be
combined with BaseOpQuantizer via multiple inheritance:

class FooQuantizer(BaseOpQuantizer, QuantizeFoo):
...

BaseOpQuantizer supplies required methods and attributes that
QuantizerBase relies on:
- add_scaled_initializer_inputs
- insert_scale_node
- get_scaling
- new_initializers (initializer buffer shared with converter)

If a subclass inherits QuantizerBase without BaseOpQuantizer,
QuantizerBase.quantize() will raise attribute errors at runtime.

This mixin centralizes:
- attribute extraction/merging
- optional initializer scaling (USE_WB + SCALE_PLAN)
- optional rescaling of outputs (USE_SCALING)
- creation of the final quantized NodeProto

The Quantize<Op> mixins should define:
- OP_TYPE
- DOMAIN
- USE_WB (bool)
- USE_SCALING (bool)
- SCALE_PLAN (dict[int,int]) if initializer scaling is enabled
"""

OP_TYPE = None
DOMAIN = "ai.onnx.contrib"
DEFAULT_ATTRS: ClassVar = {}
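The docstring above describes `SCALE_PLAN` as a map from initializer input index to the number of times the global scale factor is applied. The following is an illustrative sketch of that assumed semantics only; `apply_scale_plan` and the Gemm-style plan are hypothetical, and the real scaling lives in the `BaseOpQuantizer` helpers the mixin relies on.

```python
import numpy as np


def apply_scale_plan(
    inputs: list[np.ndarray],
    scale_plan: dict[int, int],
    alpha: int,
) -> list[np.ndarray]:
    """Scale each planned input slot by alpha**power and cast to INT64."""
    out = list(inputs)
    for idx, power in scale_plan.items():
        out[idx] = np.round(inputs[idx] * float(alpha) ** power).astype(np.int64)
    return out


# Hypothetical Gemm-style plan: weights scaled once, bias scaled twice so it
# lines up with (alpha*x) @ (alpha*w).
x = np.zeros((1, 2))
w = np.array([[0.5], [-0.25]])
b = np.array([0.125])
scaled = apply_scale_plan([x, w, b], {1: 1, 2: 2}, alpha=1 << 8)
print(scaled[1].ravel(), scaled[2])  # [128 -64] [8192]
```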
92 changes: 92 additions & 0 deletions python/core/model_processing/onnx_quantizer/layers/clip.py
@@ -0,0 +1,92 @@
from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar

if TYPE_CHECKING:
import onnx

from python.core.model_processing.onnx_quantizer.layers.base import (
BaseOpQuantizer,
QuantizerBase,
ScaleConfig,
)


class QuantizeClip(QuantizerBase):
"""
Quantization traits for ONNX Clip.
Semantics:
- X is already scaled/cast to INT64 at the graph boundary by the converter.
- Clip is elementwise + broadcasting.
- The bound inputs (min, max) should live in the *same* fixed-point scale
as X so that Clip(alpha*x; alpha*a, alpha*b) equals alpha * Clip(x; a, b),
i.e. the original Clip evaluated in the scaled domain.
Implementation:
- Treat inputs 1 and 2 (min, max) like "WB-style" slots: we let the
QuantizerBase machinery rescale / cast those inputs using the same
global scale factor.
- No extra internal scaling input is added (USE_SCALING = False).
"""

OP_TYPE = "Clip"
DOMAIN = "" # standard ONNX domain

# We DO want WB-style handling so that min/max initializers get quantized:
USE_WB = True

# Clip does not introduce its own scale input; it just runs in the
# existing fixed-point scale.
USE_SCALING = False

# Scale-plan for WB-style slots:
# - Input index 1: min
# - Input index 2: max
# Each should be scaled once by the global alpha (same as activations).
SCALE_PLAN: ClassVar = {1: 1, 2: 1}


class ClipQuantizer(BaseOpQuantizer, QuantizeClip):
"""
Quantizer for ONNX Clip.
- Keeps the node op_type as "Clip".
- Ensures that any bound inputs (min, max), whether they are dynamic
inputs or initializers, are converted to the same INT64 fixed-point
representation as X.
"""

def __init__(
self,
new_initializers: dict[str, onnx.TensorProto] | None = None,
) -> None:
# Match Max/Min/Add: we simply share the new_initializers dict
# with the converter so any constants we add are collected.
self.new_initializers = new_initializers

def quantize(
self,
node: onnx.NodeProto,
graph: onnx.GraphProto,
scale_config: ScaleConfig,
initializer_map: dict[str, onnx.TensorProto],
) -> list[onnx.NodeProto]:
# Delegate to the shared QuantizerBase logic, which will:
# - keep X as-is (already scaled/cast by the converter),
# - rescale / cast min/max according to SCALE_PLAN,
# - update initializers as needed.
return QuantizeClip.quantize(self, node, graph, scale_config, initializer_map)

def check_supported(
self,
node: onnx.NodeProto,
initializer_map: dict[str, onnx.TensorProto] | None = None,
) -> None:
"""
Minimal support check for Clip:
- Clip is variadic elementwise with optional min/max as inputs or attrs.
- We accept both forms; if attrs are present, ORT enforces semantics.
- Broadcasting is ONNX-standard; we don't restrict further here.
"""
_ = node, initializer_map
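As a sanity check on the scale-invariance the Clip docstring relies on, a standalone NumPy snippet (not part of the quantizer code) can confirm that clipping commutes with a positive global scale:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=1_000)
a, b = -1.0, 1.0
alpha = float(1 << 16)

# clip(alpha*x, alpha*a, alpha*b) == alpha * clip(x, a, b) for alpha > 0,
# which is why the min/max bounds only need to share X's fixed-point scale.
assert np.allclose(np.clip(alpha * x, alpha * a, alpha * b),
                   alpha * np.clip(x, a, b))
```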
47 changes: 47 additions & 0 deletions python/core/model_processing/onnx_quantizer/layers/max.py
@@ -0,0 +1,47 @@
# python/core/model_processing/onnx_quantizer/layers/max.py
from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar

if TYPE_CHECKING:
import onnx

from python.core.model_processing.onnx_quantizer.layers.base import (
BaseOpQuantizer,
QuantizerBase,
ScaleConfig,
)


class QuantizeMax(QuantizerBase):
OP_TYPE = "Max"
DOMAIN = ""
USE_WB = True
USE_SCALING = False
SCALE_PLAN: ClassVar = {1: 1}


class MaxQuantizer(BaseOpQuantizer, QuantizeMax):
def __init__(
self,
new_initializers: dict[str, onnx.TensorProto] | None = None,
) -> None:
self.new_initializers = new_initializers

def quantize(
self,
node: onnx.NodeProto,
graph: onnx.GraphProto,
scale_config: ScaleConfig,
initializer_map: dict[str, onnx.TensorProto],
) -> list[onnx.NodeProto]:
# Delegate to the shared QuantizerBase logic
return QuantizeMax.quantize(self, node, graph, scale_config, initializer_map)

def check_supported(
self,
node: onnx.NodeProto,
initializer_map: dict[str, onnx.TensorProto] | None = None,
) -> None:
# If later we want to enforce/relax broadcasting, add it here.
pass
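The same reasoning appears to apply to elementwise Max: it commutes with a positive global scale, so the plan `SCALE_PLAN = {1: 1}` only has to bring the second operand into the shared fixed-point scale, and no extra rescale node is introduced (`USE_SCALING = False`). A standalone check, not part of `MaxQuantizer`:

```python
import numpy as np

x = np.array([-3.0, 0.5, 2.0])
y = np.array([0.0, 1.0, -1.0])
alpha = float(1 << 8)

# Elementwise max commutes with a positive scale factor, so operands that
# share the same fixed-point scale need no extra rescaling afterwards.
assert np.allclose(np.maximum(alpha * x, alpha * y),
                   alpha * np.maximum(x, y))
```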