Draft: Changes from all commits (80 commits)
64a2747
Add OperatorDescriptor
lukamac Sep 24, 2025
c5a0c71
Add OperatorDescriptor.py
lukamac Sep 24, 2025
e31ea13
Add operatorDescriptors to NetworkDeployers
lukamac Sep 24, 2025
550b559
Fix extract padding pass
lukamac Sep 24, 2025
ab9fdfe
Fix isoftmax parser
lukamac Sep 24, 2025
a410763
Fix iRMSNorm and iNoNorm parsers
lukamac Sep 24, 2025
f6027fb
Fix ReduceMean type signature
lukamac Sep 24, 2025
475b337
Fix itamax and itapartialmax parsers
lukamac Sep 24, 2025
c6c3109
Fix attr comparison to compare with tuple in neureka
lukamac Sep 24, 2025
cd2270c
Fix keepdims type in fuse mhsa pass
lukamac Sep 24, 2025
2e62e84
Fix old _unpack_const to pass Python literals
lukamac Sep 25, 2025
587d6de
Add RequantizedConv desc
lukamac Sep 25, 2025
0ccd3b8
Fix DW parser
lukamac Sep 28, 2025
c2f2bb2
Fix pulp 1D conv
lukamac Sep 28, 2025
0b60329
Sort operator descriptors alphabetically
lukamac Sep 28, 2025
a19f98a
Add DequantDescriptor
lukamac Sep 28, 2025
4af6552
Add Div, IntegerDiv, RQIntegerDiv
lukamac Sep 28, 2025
2e2e3df
Add DebugPrint, LayerNormalization, iLayerNorm
lukamac Sep 28, 2025
9ac9a62
Add RequantizedOperatorDescriptor
lukamac Sep 28, 2025
e01fdb0
Add flatten and gather
lukamac Sep 28, 2025
1db3ae7
Add Squeeze and Unsqueeze
lukamac Sep 28, 2025
fd30dc7
Add Mul
lukamac Sep 28, 2025
a3309ed
Add MatMul, RQMatMul, MatMulInteger
lukamac Sep 28, 2025
c758fcc
Add Gemm and RQGemm
lukamac Sep 28, 2025
7e951d8
Add RequantizedGemm
lukamac Sep 28, 2025
1ab763e
Fix transA and transB being treated like ints
lukamac Sep 29, 2025
1ec6cde
Add LinearAttention
lukamac Sep 28, 2025
565cd95
Add CLCA
lukamac Sep 28, 2025
26cf648
Add IntegerMean
lukamac Sep 28, 2025
8b00f48
Add MHSA
lukamac Sep 28, 2025
6ecf95d
Add Relu, Reshape, RequantShift
lukamac Sep 28, 2025
9a577a3
Add RequantizedAdd
lukamac Sep 28, 2025
8ae808a
Add RequantizediHardswish
lukamac Sep 28, 2025
5eece92
Add iGELU
lukamac Sep 28, 2025
7598303
Add SoftmaxCrossEntropyLoss(Grad)
lukamac Sep 28, 2025
72c8d21
Add Memcopy for dma tests
lukamac Sep 28, 2025
bff8668
Remove some trailing white space in CHANGELOG.md
lukamac Oct 27, 2025
5ac4e31
Add try canonicalization exceptions
lukamac Oct 27, 2025
2f871d4
Make IntegerDataTypes a tuple
lukamac Oct 27, 2025
31577c3
Fix reshape bindings (which are used for squeeze/unsqueeze too) to typ…
lukamac Oct 27, 2025
90102f5
Canonicalize (un)squeeze operations as pre-opset-13, i.e., put axes i…
lukamac Oct 27, 2025
7bd7353
Add BatchNormalization descriptor
lukamac Oct 27, 2025
16bc463
Add ConvTranspose descriptor
lukamac Oct 27, 2025
d865898
Relax opset check on squeeze operations to a warning
lukamac Oct 27, 2025
cd62a69
Replace prints with logging
lukamac Oct 27, 2025
91bdeb7
Add missing itertools import
lukamac Oct 27, 2025
238d3af
Initialize optional value with None
lukamac Oct 27, 2025
a4198b4
Fix typo
lukamac Oct 27, 2025
e8f1721
Explicit exception coverage
lukamac Oct 27, 2025
f180f85
Rename attrToTensor to attrToInputTensor and add inputTensorToAttr
lukamac Oct 27, 2025
bc75e85
Use inputTensorToAttr in squeeze canonicalization
lukamac Oct 27, 2025
6976c52
Remove duplicate attribute
lukamac Oct 27, 2025
97da07c
Refactor MatMulTileConstraint
lukamac Oct 27, 2025
0c64a3e
Remove duplicate attributes and check that the value is positive
lukamac Oct 27, 2025
fe3bce6
Rename ref in GemmTemplate and check for batching
lukamac Oct 29, 2025
bb88795
Rename ref in FloatGemmTemplate
lukamac Oct 29, 2025
d078a3c
Add min, max for single-item numpy numbers
lukamac Oct 28, 2025
8e81f94
Make SignPropTypeChecker an abstract class and refactor
lukamac Oct 29, 2025
b6ed382
DeeployTypes small refactors
lukamac Sep 30, 2025
77f3339
Move condition checking for PULPMatrixVecParser and PULPTallGemmParse…
lukamac Oct 29, 2025
bc1f46f
Add node name and op to comment
lukamac Oct 29, 2025
6314665
Fix wrong formatting of integer arrays and refactor test io generation
lukamac Oct 1, 2025
38de238
Move iNoNorm from Generic to Snitch since it's only used there
lukamac Oct 29, 2025
c5a5cfd
Fix wrong path to generated sources
lukamac Oct 1, 2025
264cf2a
Refactor merge_conv_rq_fun
lukamac Oct 2, 2025
a3bdc51
Fix flatten values before generating the array
lukamac Oct 17, 2025
a3b86c1
Fix MaxPool parseNode
lukamac Oct 29, 2025
5f15b11
Remove MemcpyTypeChecker and do the usual type listing
lukamac Oct 17, 2025
558bfcc
Fix signprop checks
lukamac Oct 29, 2025
4868f21
Add Gemm helper function to get the matrix dimensions
lukamac Oct 14, 2025
d27ae03
Fix PULP Requantized Convolution tile constraints to properly handle …
lukamac Oct 29, 2025
1d953bd
DeeployTypes.py changes
lukamac Sep 25, 2025
8c1b4e8
Add NodeTemplate.py
lukamac Oct 28, 2025
c241525
Target Template changes
lukamac Oct 29, 2025
444064d
SignPropDeployer changes
lukamac Oct 28, 2025
344dff9
Remove ioBinding and add parse to the methods that the
lukamac Oct 14, 2025
21b0b9c
Engine deployer for some reason is not overriding the _mapNode function
lukamac Oct 16, 2025
9b4cbdb
Neureka parser changes due to canonicalization
lukamac Oct 28, 2025
43bd7e5
Reduce the L2 memory size so that the test still fails
lukamac Oct 17, 2025
b832d69
Rename Mul to MulScalar to reflect better what the kernel implements
lukamac Oct 29, 2025
4 changes: 2 additions & 2 deletions .github/workflows/ci-deeploy.yml
@@ -61,9 +61,9 @@ jobs:
       run: |
         cd DeeployTest
         python testMVP.py -t Tests/CCT/CCT_1_16_16_8 -p Siracusa --defaultMemLevel=L2 --l1=64000 --l2=75000 --memAllocStrategy=MiniMalloc
-        python testMVP.py -t Tests/CCT/CCT_1_16_16_8 -p Siracusa --defaultMemLevel=L2 --l1=64000 --l2=60000 --memAllocStrategy=MiniMalloc --shouldFail
+        python testMVP.py -t Tests/CCT/CCT_1_16_16_8 -p Siracusa --defaultMemLevel=L2 --l1=64000 --l2=50000 --memAllocStrategy=MiniMalloc --shouldFail
         python testMVP.py -t Tests/CCT/CCT_1_16_16_8 -p Siracusa --defaultMemLevel=L2 --l1=64000 --l2=90000 --memAllocStrategy=TetrisRandom
-        python testMVP.py -t Tests/CCT/CCT_1_16_16_8 -p Siracusa --defaultMemLevel=L2 --l1=64000 --l2=75000 --memAllocStrategy=TetrisRandom --shouldFail
+        python testMVP.py -t Tests/CCT/CCT_1_16_16_8 -p Siracusa --defaultMemLevel=L2 --l1=64000 --l2=69000 --memAllocStrategy=TetrisRandom --shouldFail

deeploy-state-serialization:
needs: select-env
8 changes: 4 additions & 4 deletions CHANGELOG.md
@@ -177,9 +177,9 @@ This release containing major architectural changes, new platform support, enhan


 ### Added
-- BatchNorm kernel
-- ConvTranspose kernel
-- MaxPool1D kernel
+- BatchNorm kernel
+- ConvTranspose kernel
+- MaxPool1D kernel
 - Template for 1D Convolution
 - Support for float32 data type in the previous kernels
 - Float binding for Pad1D kernel
@@ -318,7 +318,7 @@ This release containing major architectural changes, new platform support, enhan

 ### Changed
 - FloatConvTemplate file
-- Platform.py file
+- Platform.py file
 - Bump the CMake version to 3.24 as required for the chimera-sdk
 - Bump GVSoC's version and add chimera simulation target
 - Rename the generic source util to utils to avoid name collision with chimera-sdk
8 changes: 8 additions & 0 deletions Deeploy/AbstractDataTypes.py
@@ -206,12 +206,20 @@ def checkValue(cls, value: Union[int, Iterable[int], np.ndarray], ctxt: Optional

         if isinstance(value, int):
             _max, _min = (value, value)
+        elif isinstance(value, np.number):
+            value = value.item()
+            if isinstance(value, float):
+                assert value.is_integer(), f"Floating-point value {value} is not an integer."
+                value = int(value)
+            _max, _min = (value, value)
         elif isinstance(value, np.ndarray):
             _max = value.max()
             _min = value.min()
         elif isinstance(value, Iterable):
             _max = max(value)
             _min = min(value)
+        else:
+            raise ValueError(f"Unsupported value of type {type(value)} with value {value}")

         if _max > cls.typeMax:
             return False
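A minimal sketch of the new `np.number` branch, assuming `int8_t` from `Deeploy.CommonExtensions.DataTypes` and that `ctxt` may be omitted from `checkValue` (both assumptions, not shown in this diff):

```python
import numpy as np

from Deeploy.CommonExtensions.DataTypes import int8_t

# NumPy scalars are unwrapped with .item() before the range check.
print(int8_t.checkValue(np.int16(42)))     # True: 42 fits in [-128, 127]
print(int8_t.checkValue(np.int16(1000)))   # False: exceeds int8_t.typeMax
print(int8_t.checkValue(np.float32(3.0)))  # True: integral float is cast to int
# int8_t.checkValue(np.float32(3.5)) would trip the is_integer() assertion.
```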
10 changes: 5 additions & 5 deletions Deeploy/CommonExtensions/DataTypes.py
@@ -87,11 +87,11 @@ class float64_t(FloatImmediate):

 SignedIntegerDataTypes: Tuple[Type[IntegerImmediate], ...] = (int8_t, int16_t, int32_t, int64_t)
 UnsignedIntegerDataTypes: Tuple[Type[IntegerImmediate], ...] = (uint8_t, uint16_t, uint32_t, uint64_t)
-IntegerDataTypes: Tuple[Type[IntegerImmediate], ...] = (sorted((
-    *SignedIntegerDataTypes,
-    *UnsignedIntegerDataTypes,
-),
-                                                        key = lambda _type: _type.typeWidth))
+IntegerDataTypes: Tuple[Type[IntegerImmediate], ...] = tuple(
+    sorted((
+        *SignedIntegerDataTypes,
+        *UnsignedIntegerDataTypes,
+    ), key = lambda _type: _type.typeWidth))
 FloatDataTypes: Tuple[Type[FloatImmediate], ...] = (bfloat16_t, float16_t, float32_t, float64_t)


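The `tuple(...)` wrapper matters because `sorted` always returns a list, so the old expression's `Tuple[...]` annotation was wrong; a one-liner sketch:

```python
nums = (3, 1, 2)
assert isinstance(sorted(nums), list)          # old definition: a list in disguise
assert isinstance(tuple(sorted(nums)), tuple)  # new definition: really a tuple
```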
@@ -6,7 +6,7 @@

 import onnx_graphsurgeon as gs

-from Deeploy.DeeployTypes import CodeGenVerbosity, NetworkContext, NetworkDeployer, ONNXLayer, _NoVerbosity
+from Deeploy.DeeployTypes import CodeGenVerbosity, NetworkDeployer, ONNXLayer, _NoVerbosity


 class NetworkDeployerWrapper(NetworkDeployer):
@@ -48,8 +48,8 @@ def prepared(self):
         """

     # SignPropDeployer augment
-    def _createIOBindings(self, ctxt: NetworkContext, graph: gs.Graph):
-        return self._innerObject._createIOBindings(ctxt, graph)
+    def parse(self, default_channels_first: bool = True) -> bool:
+        return self._innerObject.parse(default_channels_first)

     # MemoryAwareDeployer, TilerAwareDeployer, and PULPDeployer augments
     def bind(self) -> bool:
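A hedged sketch of the delegation (assuming the wrapper is constructed around an existing deployer, as its `_innerObject` attribute suggests): swapping `_createIOBindings` for `parse` keeps wrapped deployers transparent to the new sign-propagation flow.

```python
# Illustrative only: the wrapper forwards parse() verbatim, so any sign and
# nLevels annotation done by the inner deployer's parse() still runs when wrapped.
wrapped = NetworkDeployerWrapper(innerDeployer)  # innerDeployer is hypothetical
ok = wrapped.parse(default_channels_first = True)
```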
56 changes: 42 additions & 14 deletions Deeploy/CommonExtensions/NetworkDeployers/SignPropDeployer.py
@@ -6,8 +6,10 @@

 import onnx_graphsurgeon as gs

-from Deeploy.AbstractDataTypes import Pointer
-from Deeploy.DeeployTypes import DeploymentPlatform, NetworkDeployer, TopologyOptimizer
+from Deeploy.AbstractDataTypes import IntegerImmediate, Pointer
+from Deeploy.CommonExtensions.TypeCheckers.SignPropTypeChecker import SignPropTypeChecker
+from Deeploy.DeeployTypes import ConstantBuffer, DeploymentPlatform, NetworkDeployer, OperatorDescriptor, \
+    TopologyOptimizer, VariableBuffer
+from Deeploy.Logging import DEFAULT_LOGGER as log


@@ -18,12 +20,13 @@ def __init__(self,
                  deploymentPlatform: DeploymentPlatform,
                  inputTypes: Dict[str, Type[Pointer]],
                  loweringOptimizer: TopologyOptimizer,
+                 operatorDescriptors: Dict[str, OperatorDescriptor],
                  scheduler: Callable = lambda x: x,
                  name: str = 'DeeployNetwork',
                  default_channels_first: bool = True,
                  deeployStateDir: str = "DeeployState",
                  inputOffsets: Dict[str, int] = {}):
-        super().__init__(graph, deploymentPlatform, inputTypes, loweringOptimizer, scheduler, name,
+        super().__init__(graph, deploymentPlatform, inputTypes, loweringOptimizer, operatorDescriptors, scheduler, name,
                          default_channels_first, deeployStateDir)

         if inputOffsets == {}:
@@ -32,17 +35,6 @@

         self.inputOffsets = inputOffsets

-    def _createIOBindings(self, ctxt, graph):
-        ctxt = super()._createIOBindings(ctxt, graph)
-        for node in graph.inputs:
-            data_name = node.name
-            nb = ctxt.lookup(data_name)
-            data_type = self.inputTypes[data_name]
-            nb._signed = (self.inputOffsets[data_name] == 0)
-            nb.nLevels = (2**data_type.referencedType.typeWidth)
-
-        return ctxt
-
     def _printInputOutputSummary(self):
         log.info('Input:')
         for buf in self.inputs():
@@ -55,3 +47,39 @@
             log.info(
                 f" - '{buf.name}': Type: {buf._type.referencedType.typeName}, nLevels: {buf.nLevels}, Signed: {buf._signed}"
             )
+
+    def parse(self, default_channels_first: bool = True) -> bool:
+        parsable = super().parse(default_channels_first)
+        if not parsable:
+            return False
+
+        # Annotate global buffers
+        for obj in self.ctxt.globalObjects.values():
+            assert isinstance(obj, VariableBuffer)
+            refTy = obj._type.referencedType
+            if isinstance(obj, ConstantBuffer):
+                assert refTy.checkPromotion(obj.values), f"Can't cast {obj} to {refTy}"
+                if issubclass(refTy, IntegerImmediate):
+                    obj.nLevels = obj.values.max() - obj.values.min()
+                    obj._signed = refTy.typeMin < 0
+            elif obj.name in self.inputOffsets:
+                obj._signed = (self.inputOffsets[obj.name] == 0)
+                obj.nLevels = (2**refTy.typeWidth)
+
+        # Annotate rest
+        for layer in self.layerBinding.values():
+            node = layer.node
+            opRepr = layer.mapper.parser.operatorRepresentation
+            typeChecker = layer.mapper.binder.typeChecker
+            outTy = self.ctxt.lookup(node.outputs[0].name)._type.referencedType
+            if issubclass(outTy, IntegerImmediate) and isinstance(typeChecker, SignPropTypeChecker):
+                inputs = [self.ctxt.lookup(t.name) for t in node.inputs]
+                outputNLevels = typeChecker._inferNumLevels(inputs, opRepr)
+                outputSigned = typeChecker._inferSignedness(inputs, opRepr)
+
+                outputs = [self.ctxt.lookup(t.name) for t in node.outputs]
+                for buffer, nLevels, signed in zip(outputs, outputNLevels, outputSigned):
+                    buffer.nLevels = nLevels
+                    buffer._signed = signed
+
+        return True
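A hedged sketch of what the new `parse` override annotates (the deployer instance and input names are illustrative): an input whose offset is 0 is marked signed, a nonzero offset (e.g. -128, shifting uint8 data into int8 range) marks it unsigned, and `nLevels` is set to the type's full range.

```python
# Illustrative only: assumes a SignPropDeployer built with 8-bit pointer
# input types and inputOffsets = {"input_0": 0, "input_1": -128}.
assert deployer.parse(default_channels_first = True)

buf0 = deployer.ctxt.lookup("input_0")
print(buf0._signed, buf0.nLevels)  # True, 256: offset == 0 means signed

buf1 = deployer.ctxt.lookup("input_1")
print(buf1._signed, buf1.nLevels)  # False, 256: nonzero offset means unsigned
```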
84 changes: 84 additions & 0 deletions Deeploy/CommonExtensions/NodeTemplate.py
@@ -0,0 +1,84 @@
+# SPDX-FileCopyrightText: 2021 ETH Zurich and University of Bologna
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import List, Sequence, Tuple
+
+import numpy as np
+import onnx_graphsurgeon as gs
+
+from Deeploy.DeeployTypes import NodeTemplate
+
+
+class ElementwiseTemplate(NodeTemplate):
+
+    def alignShapes(self, node: gs.Node) -> Tuple[List[Sequence[int]], List[Sequence[int]]]:
+        assert len(node.outputs) == 1, f"Expected only one output. Received {len(node.outputs)}"
+        shape = tuple(np.broadcast_shapes(*[t.shape for t in node.inputs]))
+        return [shape] * len(node.inputs), [shape]
+
+
+class ElementwiseScalarTemplate(NodeTemplate):
+
+    def alignShapes(self, node: gs.Node) -> Tuple[List[Sequence[int]], List[Sequence[int]]]:
+        assert len(node.inputs) == 2, f"Expected only two inputs. Received {len(node.inputs)}"
+        assert len(node.outputs) == 1, f"Expected only one output. Received {len(node.outputs)}"
+        shape = tuple(node.inputs[0].shape)
+        return [shape, (1,)], [shape]
+
+
+class RequantShiftTemplate(NodeTemplate):
+
+    def alignShapes(self, node: gs.Node) -> Tuple[List[Sequence[int]], List[Sequence[int]]]:
+        inShapes, outShapes = [t.shape for t in node.inputs], [t.shape for t in node.outputs]
+        batch, ch = inShapes[0][:2]
+        # TODO: Copied from old computeShape. Should probably be investigated
+        inShapes[1] = (batch, ch, *inShapes[1][1:])
+        inShapes[2] = (batch, ch, *inShapes[2][1:])
+        return inShapes, outShapes
+
+
+class ConvTemplate(NodeTemplate):
+
+    @staticmethod
+    def minPerChannelTensorShape(node: gs.Node, channels: int) -> Tuple[int, ...]:
+        spatialDims = len(node.attrs["kernel_shape"])
+        if node.attrs["channels_first"]:
+            return (channels,) + (1,) * (spatialDims)
+        else:
+            return (channels,)
+
+    def alignShapes(self, node: gs.Node) -> Tuple[List[Sequence[int]], List[Sequence[int]]]:
+        inShapes, outShapes = [t.shape for t in node.inputs], [t.shape for t in node.outputs]
+        if len(node.inputs) == 3:
+            minBiasShape = self.minPerChannelTensorShape(node, inShapes[1][0])
+            inShapes[2] = minBiasShape
+        return inShapes, outShapes
+
+
+class RequantizedConvTemplate(ConvTemplate):
+
+    def alignShapes(self, node: gs.Node) -> Tuple[List[Sequence[int]], List[Sequence[int]]]:
+        inShapes, outShapes = [t.shape for t in node.inputs[:2]], [t.shape for t in node.outputs]
+        minRqsShape = self.minPerChannelTensorShape(node, inShapes[1][0])
+        rqsShapes = [minRqsShape] * len(node.inputs[2:])
+        return inShapes + rqsShapes, outShapes
+
+
+class GemmTemplate(NodeTemplate):
+
+    def alignShapes(self, node: gs.Node) -> Tuple[List[Sequence[int]], List[Sequence[int]]]:
+        biasShape = node.outputs[0].shape[-2:]
+        return [node.inputs[0].shape, node.inputs[1].shape, biasShape], [node.outputs[0].shape]
+
+
+class RequantizedGemmTemplate(NodeTemplate):
+
+    def alignShapes(self, node: gs.Node) -> Tuple[List[Sequence[int]], List[Sequence[int]]]:
+        inShapes, outShapes = [t.shape for t in node.inputs[:2]], [t.shape for t in node.outputs]
+        if node.attrs["transB"]:
+            N = inShapes[1][-2]
+        else:
+            N = inShapes[1][-1]
+        rqsShapes = [(N,)] * len(node.inputs[2:])
+        return inShapes + rqsShapes, outShapes
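A sketch of how the broadcast alignment behaves, assuming `NodeTemplate` accepts a Mako template string as elsewhere in Deeploy (the template body here is a placeholder, not a real kernel):

```python
import onnx_graphsurgeon as gs

from Deeploy.CommonExtensions.NodeTemplate import ElementwiseTemplate

a = gs.Variable("a", shape = (1, 8, 16))
b = gs.Variable("b", shape = (16,))
out = gs.Variable("out", shape = (1, 8, 16))
node = gs.Node(op = "Add", inputs = [a, b], outputs = [out])

template = ElementwiseTemplate("/* placeholder kernel */")
inShapes, outShapes = template.alignShapes(node)
print(inShapes)   # [(1, 8, 16), (1, 8, 16)]: both inputs aligned to the broadcast shape
print(outShapes)  # [(1, 8, 16)]
```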
59 changes: 29 additions & 30 deletions Deeploy/CommonExtensions/TypeCheckers/SignPropTypeChecker.py
@@ -2,7 +2,8 @@
 #
 # SPDX-License-Identifier: Apache-2.0

-from typing import List, Optional
+from abc import ABC, abstractmethod
+from typing import List

 import onnx_graphsurgeon as gs

@@ -11,27 +12,30 @@
 from Deeploy.Logging import DEFAULT_LOGGER as log


-class SignPropTypeChecker(NodeTypeChecker):
+class SignPropTypeChecker(NodeTypeChecker, ABC):

+    @abstractmethod
     def _inferNumLevels(self, inputs: List[VariableBuffer],
-                        operatorRepresentation: OperatorRepresentation) -> Optional[List[int]]:
-        return None
+                        operatorRepresentation: OperatorRepresentation) -> List[int]:
+        pass

+    @abstractmethod
     def _inferSignedness(self, inputs: List[VariableBuffer],
-                         operatorRepresentation: OperatorRepresentation) -> Optional[List[int]]:
-        return None
+                         operatorRepresentation: OperatorRepresentation) -> List[bool]:
+        pass

     def typeInferGlobalCtxt(self, ctxt: NetworkContext, node: gs.Node) -> NetworkContext:
         ctxt = super().typeInferGlobalCtxt(ctxt, node)

-        for inputNode, _type in zip(node.inputs, self.input_types):
-            if isinstance(ctxt.lookup(inputNode.name), ConstantBuffer):
-                reference = ctxt.lookup(inputNode.name)
-                if not _type.referencedType.checkPromotion(reference.values):
-                    raise Exception(f"Can't cast {reference} to {_type}!")
-
-                reference.nLevels = reference.values.max() - reference.values.min()
-                reference._signed = _type.referencedType.typeMin < 0
+        for tensor, _type in zip(node.inputs, self.input_types):
+            buffer = ctxt.lookup(tensor.name)
+            if isinstance(buffer, ConstantBuffer):
+                refTy = _type.referencedType
+                assert issubclass(refTy, IntegerImmediate)
+                if not refTy.checkPromotion(buffer.values):
+                    raise ValueError(f"Can't cast {buffer} to {refTy}!")
+                buffer.nLevels = buffer.values.max() - buffer.values.min()
+                buffer._signed = refTy.typeMin < 0

         return ctxt

@@ -42,21 +46,16 @@ def typeInferOutput(self, ctxt: NetworkContext, node: gs.Node,
         inputs = [ctxt.lookup(inputNode.name) for inputNode in node.inputs]
         outputs = [ctxt.lookup(outputNode.name) for outputNode in node.outputs]

-        signProp = all([hasattr(_input, "_signed") and hasattr(_input, "nLevels") for _input in inputs])
-
-        if signProp:
-            nLevels = self._inferNumLevels(inputs, operatorRepresentation)
-            signedness = self._inferSignedness(inputs, operatorRepresentation)
-
-            if nLevels is None or signedness is None:
-                return ctxt
-            for obj, nLevel, sign in zip(outputs, nLevels, signedness):
-                obj.nLevels = nLevel
-                obj._signed = sign
-
-                if issubclass(obj._type.referencedType, IntegerImmediate) and not obj._type.fitsNumLevels(nLevel):
-                    log.warning(
-                        f"{obj.name} has {nLevel} levels, but {obj._type.referencedType.typeName} only supports {obj._type.referencedType.nLevels} levels."
-                    )
+        nLevels = self._inferNumLevels(inputs, operatorRepresentation)
+        signedness = self._inferSignedness(inputs, operatorRepresentation)
+
+        for obj, nLevels, sign in zip(outputs, nLevels, signedness):
+            assert isinstance(obj, VariableBuffer)
+            obj.nLevels = nLevels
+            obj._signed = sign
+            refTy = obj._type.referencedType
+            if issubclass(refTy, IntegerImmediate) and not refTy.fitsNumLevels(nLevels):
+                log.warning(
+                    f"{obj.name} has {nLevels} levels, but {refTy.typeName} only supports {refTy.nLevels} levels.")

         return ctxt
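With the checker now abstract, every subclass must implement both inference hooks. A minimal hedged sketch (the class and its nLevels arithmetic are illustrative, not an actual Deeploy checker):

```python
from typing import List

from Deeploy.CommonExtensions.TypeCheckers.SignPropTypeChecker import SignPropTypeChecker
from Deeploy.DeeployTypes import OperatorRepresentation, VariableBuffer


class AddLikeChecker(SignPropTypeChecker):
    """Illustrative checker for a two-input elementwise add."""

    def _inferNumLevels(self, inputs: List[VariableBuffer],
                        operatorRepresentation: OperatorRepresentation) -> List[int]:
        # Worst case, a sum spans the sum of its operands' ranges.
        return [inputs[0].nLevels + inputs[1].nLevels]

    def _inferSignedness(self, inputs: List[VariableBuffer],
                         operatorRepresentation: OperatorRepresentation) -> List[bool]:
        # The result is signed if either operand is signed.
        return [inputs[0]._signed or inputs[1]._signed]
```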