Commits (26)
7b9a762
Rename ref in GemmTemplate and check for batching
lukamac Oct 29, 2025
49cbd2e
Rename ref in FloatGemmTemplate
lukamac Oct 29, 2025
023cedf
Add min, max for single-item numpy numbers
lukamac Oct 28, 2025
e8b5a1a
Make SignPropTypeChecker an abstract class and refactor
lukamac Oct 29, 2025
151bdda
DeeployTypes small refactors
lukamac Sep 30, 2025
6eff719
Move condition checking for PULPMatrixVecParser and PULPTallGemmParse…
lukamac Oct 29, 2025
6a5c21d
Add node name and op to comment
lukamac Oct 29, 2025
e9a4580
Fix wrong formatting of integer arrays and refactor test io generation
lukamac Oct 1, 2025
9435eca
Move iNoNorm from Generic to Snitch since it's only used there
lukamac Oct 29, 2025
84497c9
Fix wrong path to generated sources
lukamac Oct 1, 2025
74c4a14
Refactor merge_conv_rq_fun
lukamac Oct 2, 2025
82e4da0
Fix flatten values before generating the array
lukamac Oct 17, 2025
0034fa2
Fix MaxPool parseNode
lukamac Oct 29, 2025
9317241
Remove MemcpyTypeChecker and do the usual type listing
lukamac Oct 17, 2025
c3117cd
Fix signprop checks
lukamac Oct 29, 2025
629611e
Add Gemm helper function to get the matrix dimensions
lukamac Oct 14, 2025
50646aa
Fix PULP Requantized Convolution tile constraints to properly handle …
lukamac Oct 29, 2025
abb2e08
Make IntegerDataTypes a tuple
lukamac Oct 27, 2025
505f96d
Add missing newline to error
lukamac Oct 29, 2025
ca02e85
Update changelog
lukamac Oct 30, 2025
9d7c970
Fix missing super parseNode call
lukamac Nov 1, 2025
791c2e6
Add div checks
lukamac Nov 1, 2025
d3b052d
Improve C tensor naming
Xeratec Oct 30, 2025
02b6b65
Prevent empty node names
Xeratec Oct 30, 2025
64d198a
Implement CodeRabit Feedback
Xeratec Nov 3, 2025
2b88967
Update changelog
Xeratec Nov 3, 2025
13 changes: 9 additions & 4 deletions CHANGELOG.md
@@ -4,6 +4,7 @@ This file contains the changelog for the Deeploy project. The changelog is divid
## Unreleased (Planned Release Target: v0.2.1)

### List of Pull Requests
- Refactors and fixes [#131](https://github.com/pulp-platform/Deeploy/pull/131)
- Disallow shape inference [#128](https://github.com/pulp-platform/Deeploy/pull/128)
- Remove memory-aware node bindings [#123](https://github.com/pulp-platform/Deeploy/pull/123)
- Fix missing const's layout transformation and refactor NCHWtoNHWC passes [#122](https://github.com/pulp-platform/Deeploy/pull/122)
@@ -91,6 +92,8 @@ This file contains the changelog for the Deeploy project. The changelog is divid
- Removed Wmem variants of bindings and tile constraints from Neureka
- Disabled ICCT_ITA_8 MemPool test because it was using a lowering that created shapeless tensors
- Added missing shape annotation to the testTypeInferenceDifferentTypes
- ref naming scheme in Gemm and FloatGemm templates from ${data_out}_${<tensor>} to ${nodeName}_${<tensor>}
- move iNoNorm from Generic to Snitch since it uses a Snitch kernel

### Fixed
- Prevent node duplication for graphs generated via GraphSurgeon
@@ -105,6 +108,8 @@ This file contains the changelog for the Deeploy project. The changelog is divid
- Missing layout transformation of the const's (bias, mul, add, shift in Conv/RequantizedConv)
- Keep mul/add rank of requantized Neureka tile constraints
- Fix bias hoisting in generic GEMM with no bias
- formatting of test_input/output integer values
- pulp rqs tile constraints now properly target the last dimension of rqs params

### Removed
- Delete outdated and unused `.gitlab-ci.yml` file
@@ -180,9 +185,9 @@ This release containing major architectural changes, new platform support, enhan


### Added
- BatchNorm kernel
- ConvTranspose kernel
- MaxPool1D kernel
- BatchNorm kernel
- ConvTranspose kernel
- MaxPool1D kernel
- Template for 1D Convolution
- Support for float32 data type in the previous kernels
- Float binding for Pad1D kernel
@@ -321,7 +326,7 @@ This release containing major architectural changes, new platform support, enhan

### Changed
- FloatConvTemplate file
- Platform.py file
- Platform.py file
- Bump the CMake version to 3.24 as required for the chimera-sdk
- Bump GVSoC's version and add chimera simulation target
- Rename the generic source util to utils to avoid name collision with chimera-sdk
8 changes: 8 additions & 0 deletions Deeploy/AbstractDataTypes.py
@@ -206,12 +206,20 @@ def checkValue(cls, value: Union[int, Iterable[int], np.ndarray], ctxt: Optional

if isinstance(value, int):
_max, _min = (value, value)
elif isinstance(value, np.number):
value = value.item()
if isinstance(value, float):
assert value.is_integer(), f"Floating-point value {value} is not an integer."
value = int(value)
_max, _min = (value, value)
elif isinstance(value, np.ndarray):
_max = value.max()
_min = value.min()
elif isinstance(value, Iterable):
_max = max(value)
_min = min(value)
else:
raise ValueError(f"Unsupported value of type {type(value)} with value {value}")

if _max > cls.typeMax:
return False
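The new np.number branch in checkValue normalizes NumPy scalars before the range check. A minimal standalone sketch of that normalization (illustrative only, not the Deeploy API; the int8 bounds are hard-coded for the example):

import numpy as np

def fits_int8(value) -> bool:
    # Mirror of the checkValue flow for a single value: NumPy scalars are
    # converted to native Python numbers, integral floats are accepted,
    # then the value is range-checked against the type's bounds.
    if isinstance(value, np.number):
        value = value.item()
        if isinstance(value, float):
            assert value.is_integer(), f"Floating-point value {value} is not an integer."
            value = int(value)
    return -128 <= value <= 127  # int8_t.typeMin / typeMax, assumed here

print(fits_int8(np.int64(127)))    # True
print(fits_int8(np.float32(3.0)))  # True: 3.0 is an integral float
print(fits_int8(np.int16(200)))    # False: out of int8 range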
10 changes: 5 additions & 5 deletions Deeploy/CommonExtensions/DataTypes.py
@@ -87,11 +87,11 @@ class float64_t(FloatImmediate):

SignedIntegerDataTypes: Tuple[Type[IntegerImmediate], ...] = (int8_t, int16_t, int32_t, int64_t)
UnsignedIntegerDataTypes: Tuple[Type[IntegerImmediate], ...] = (uint8_t, uint16_t, uint32_t, uint64_t)
IntegerDataTypes: Tuple[Type[IntegerImmediate], ...] = (sorted((
*SignedIntegerDataTypes,
*UnsignedIntegerDataTypes,
),
key = lambda _type: _type.typeWidth))
IntegerDataTypes: Tuple[Type[IntegerImmediate], ...] = tuple(
sorted((
*SignedIntegerDataTypes,
*UnsignedIntegerDataTypes,
), key = lambda _type: _type.typeWidth))
FloatDataTypes: Tuple[Type[FloatImmediate], ...] = (bfloat16_t, float16_t, float32_t, float64_t)


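The tuple() wrapper above matters because sorted() always returns a list, so the previous code did not match the declared Tuple[...] type. A self-contained sketch of the same pattern with placeholder classes (not the real Deeploy types):

from typing import Tuple, Type

class Width:  # stand-in for IntegerImmediate subclasses
    typeWidth: int = 0

class i8(Width): typeWidth = 8
class u8(Width): typeWidth = 8
class i16(Width): typeWidth = 16
class u16(Width): typeWidth = 16

Signed: Tuple[Type[Width], ...] = (i8, i16)
Unsigned: Tuple[Type[Width], ...] = (u8, u16)

# sorted() yields a list, so tuple() is needed to keep the collection
# immutable and consistent with the Tuple[...] annotation.
All: Tuple[Type[Width], ...] = tuple(sorted((*Signed, *Unsigned), key = lambda t: t.typeWidth))

print([t.__name__ for t in All])  # ['i8', 'u8', 'i16', 'u16']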
57 changes: 27 additions & 30 deletions Deeploy/CommonExtensions/TypeCheckers/SignPropTypeChecker.py
@@ -2,7 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import List, Optional
from abc import ABC, abstractmethod
from typing import List

import onnx_graphsurgeon as gs

@@ -11,27 +12,28 @@
from Deeploy.Logging import DEFAULT_LOGGER as log


class SignPropTypeChecker(NodeTypeChecker):
class SignPropTypeChecker(NodeTypeChecker, ABC):

@abstractmethod
def _inferNumLevels(self, inputs: List[VariableBuffer],
operatorRepresentation: OperatorRepresentation) -> Optional[List[int]]:
return None
operatorRepresentation: OperatorRepresentation) -> List[int]:
pass

@abstractmethod
def _inferSignedness(self, inputs: List[VariableBuffer],
operatorRepresentation: OperatorRepresentation) -> Optional[List[int]]:
return None
operatorRepresentation: OperatorRepresentation) -> List[bool]:
pass

def typeInferGlobalCtxt(self, ctxt: NetworkContext, node: gs.Node) -> NetworkContext:
ctxt = super().typeInferGlobalCtxt(ctxt, node)

for inputNode, _type in zip(node.inputs, self.input_types):
if isinstance(ctxt.lookup(inputNode.name), ConstantBuffer):
reference = ctxt.lookup(inputNode.name)
if not _type.referencedType.checkPromotion(reference.values):
raise Exception(f"Can't cast {reference} to {_type}!")

reference.nLevels = reference.values.max() - reference.values.min()
reference._signed = _type.referencedType.typeMin < 0
for tensor, ty in zip(node.inputs, self.input_types):
buffer = ctxt.lookup(tensor.name)
if isinstance(buffer, ConstantBuffer):
refTy = ty.referencedType
assert refTy.checkPromotion(buffer.values), f"Can't cast {buffer} to {ty}!"
buffer.nLevels = buffer.values.max() - buffer.values.min()
buffer._signed = refTy.typeMin < 0

return ctxt

@@ -42,21 +44,16 @@ def typeInferOutput(self, ctxt: NetworkContext, node: gs.Node,
inputs = [ctxt.lookup(inputNode.name) for inputNode in node.inputs]
outputs = [ctxt.lookup(outputNode.name) for outputNode in node.outputs]

signProp = all([hasattr(_input, "_signed") and hasattr(_input, "nLevels") for _input in inputs])

if signProp:
nLevels = self._inferNumLevels(inputs, operatorRepresentation)
signedness = self._inferSignedness(inputs, operatorRepresentation)

if nLevels is None or signedness is None:
return ctxt
for obj, nLevel, sign in zip(outputs, nLevels, signedness):
obj.nLevels = nLevel
obj._signed = sign

if issubclass(obj._type.referencedType, IntegerImmediate) and not obj._type.fitsNumLevels(nLevel):
log.warning(
f"{obj.name} has {nLevel} levels, but {obj._type.referencedType.typeName} only supports {obj._type.referencedType.nLevels} levels."
)
nLevels = self._inferNumLevels(inputs, operatorRepresentation)
signedness = self._inferSignedness(inputs, operatorRepresentation)

for obj, nLevels, sign in zip(outputs, nLevels, signedness):
assert isinstance(obj, VariableBuffer)
obj.nLevels = nLevels
obj._signed = sign
refTy = obj._type.referencedType
if issubclass(refTy, IntegerImmediate) and not refTy.fitsNumLevels(nLevels):
log.warning(
f"{obj.name} has {nLevels} levels, but {refTy.typeName} only supports {refTy.nLevels} levels.")

return ctxt
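Since SignPropTypeChecker is now abstract, every concrete checker must implement both inference hooks. A hypothetical minimal subclass for an element-wise operator that simply propagates its first input's annotations might look like this (PassThroughChecker is illustrative, not an existing Deeploy checker, and assumes the first input carries sign-prop annotations):

from typing import List

from Deeploy.CommonExtensions.TypeCheckers.SignPropTypeChecker import SignPropTypeChecker
from Deeploy.DeeployTypes import OperatorRepresentation, VariableBuffer


class PassThroughChecker(SignPropTypeChecker):

    # The single output has as many levels as the first input.
    def _inferNumLevels(self, inputs: List[VariableBuffer],
                        operatorRepresentation: OperatorRepresentation) -> List[int]:
        return [inputs[0].nLevels]

    # The single output's signedness follows the first input.
    def _inferSignedness(self, inputs: List[VariableBuffer],
                         operatorRepresentation: OperatorRepresentation) -> List[bool]:
        return [inputs[0]._signed]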
27 changes: 13 additions & 14 deletions Deeploy/DeeployTypes.py
@@ -251,8 +251,8 @@ def __init__(self, name: str = '', shape = [1], aliases: Optional[List[str]] = N
self._live: bool = False #: bool: DO NOT OVERRIDE - this variable is true if a previous Memory allocation pass has allocated the buffer, and false if this buffer has been deallocated or has not been allocated yet.
self._deploy: bool = True #: bool: MAY OVERRIDE - this variable is a global switch to deactivate the buffer for all purposes without deleting it outright.

self._signed = None
self.nLevels = None
self._signed: bool = None
self.nLevels: int = None

self.is_input: bool = False
self.is_output: bool = False
@@ -1009,9 +1009,10 @@ def annotateType(self, name: str, _type: Type[Pointer]):
VariableBuffer with

"""
obj = self.lookup(name)
obj._type = _type
obj._instance = _type(name, ctxt = self)
buffer = self.lookup(name)
assert isinstance(buffer, VariableBuffer)
buffer._type = _type
buffer._instance = _type(name, ctxt = self)

def copy(self) -> NetworkContext:
"""Return a shallow copy of this NetworkContext
@@ -1312,14 +1313,12 @@ def typeCheckNodeInputs(self, ctxt: NetworkContext, node: gs.Node) -> bool:
return retCheck

def typeInferGlobalCtxt(self, ctxt: NetworkContext, node: gs.Node) -> NetworkContext:
for inputNode, _type in zip(node.inputs, self.input_types):
if isinstance(ctxt.lookup(inputNode.name), ConstantBuffer):
reference = ctxt.lookup(inputNode.name)
if not _type.referencedType.checkPromotion(reference.values):
raise Exception(f"Can't cast {reference} to {_type}!")

ctxt.annotateType(inputNode.name, _type)

for tensor, ty in zip(node.inputs, self.input_types):
buffer = ctxt.lookup(tensor.name)
if isinstance(buffer, ConstantBuffer):
if not ty.referencedType.checkPromotion(buffer.values):
raise Exception(f"Can't cast {buffer} to {ty}!")
ctxt.annotateType(tensor.name, ty)
return ctxt

def annotateDict(self, ctxt: NetworkContext, node: gs.Node, operatorRepresentation: OperatorRepresentation):
@@ -2695,7 +2694,7 @@ def parse(self, default_channels_first: bool = True) -> bool:
f" - Deepest layer available mappers: {[type(x.parser).__name__ for x in deepestLayer.maps]}")
log.error("=" * 80)
raise RuntimeError(
f'Did not find adequate mapping for graph! Explored until layer {deepestLayer.__class__.__name__} of node {deepestNodeName}'
f'Did not find adequate mapping for graph! Explored until layer {deepestLayer.__class__.__name__} of node {deepestNodeName}\n'
f'Candidates: {[type(x.parser).__name__ for x in deepestLayer.maps]}. Exhausted backtracking.')

previousLayer = scheduledLayerList[idx - 1]
19 changes: 1 addition & 18 deletions Deeploy/Targets/Generic/Layers.py
@@ -7,7 +7,7 @@

import numpy as np

from Deeploy.DeeployTypes import NodeMapper, ONNXLayer, OperatorRepresentation, Shape
from Deeploy.DeeployTypes import NodeMapper, ONNXLayer, Shape


class ConcatLayer(ONNXLayer):
@@ -64,23 +64,6 @@ def __init__(self, maps: List[NodeMapper]):
super().__init__(maps)


class iNoNormLayer(ONNXLayer):

def __init__(self, maps: List[NodeMapper]):
super().__init__(maps)

def computeOps(self):
return self.mapper.parser.operatorRepresentation['size'] * 4 # 2 mul, 1 add, 1 right shift

def computeShapes(self, inputShapes: Shape, outputShapes: Shape, operatorRepresentation: OperatorRepresentation,
channels_first: bool) -> Tuple[Shape]:

# JUNGVI: Broadcast the weights and bias to have as many dimensions as the inputs
inputShapes[1] = [1] * (len(inputShapes[0]) - len(inputShapes[1])) + list(inputShapes[1])
inputShapes[2] = inputShapes[1]
return (inputShapes, outputShapes)


class RQSiGELULayer(GELULayer):

def __init__(self, maps: List[NodeMapper]):
109 changes: 41 additions & 68 deletions Deeploy/Targets/Generic/Parsers.py
@@ -227,20 +227,25 @@ def __init__(self):
super().__init__()

def parseNode(self, node: gs.Node) -> bool:
ret = super().parseNode(node)
wellFormed = False
if ret:
pads = self.operatorRepresentation['pads']
kernel_shape = self.operatorRepresentation['kernel_shape']
strides = self.operatorRepresentation['strides']
# 1D: pads should be length 2, kernel_shape length 1, strides length 1
if len(pads) == 2 and len(kernel_shape) == 1 and len(strides) == 1:
wellFormed = True
self.operatorRepresentation['padding_y'] = int(pads[0])
self.operatorRepresentation['padding_y_right'] = int(pads[1])
self.operatorRepresentation['stride_y'] = int(strides[0])
self.operatorRepresentation['dim_kernel_y'] = int(kernel_shape[0])
return wellFormed
if not super().parseNode(node):
return False

pads = self.operatorRepresentation['pads']
kernel_shape = self.operatorRepresentation['kernel_shape']
strides = self.operatorRepresentation['strides']

if not all([
len(pads) == 2,
len(kernel_shape) == 1,
len(strides) == 1,
]):
return False

self.operatorRepresentation['padding_y'] = pads[0]
self.operatorRepresentation['padding_y_right'] = pads[1]
self.operatorRepresentation['stride_y'] = strides[0]
self.operatorRepresentation['dim_kernel_y'] = kernel_shape[0]
return True

def parseNodeCtxt(self, ctxt, node, channels_first = True):
newCtxt, ret = super().parseNodeCtxt(ctxt, node, channels_first)
@@ -269,28 +274,31 @@ def __init__(self):
super().__init__()

def parseNode(self, node: gs.Node) -> bool:
if not super().parseNode(node):
return False

ret = super().parseNode(node)
wellFormed = False
if ret:
pads = self.operatorRepresentation['pads']
kernel_shape = self.operatorRepresentation['kernel_shape']
strides = self.operatorRepresentation['strides']
if len(pads) == 4 and len(kernel_shape) == 2 and len(strides) == 2:
wellFormed = True
pads = self.operatorRepresentation['pads']
kernel_shape = self.operatorRepresentation['kernel_shape']
strides = self.operatorRepresentation['strides']

self.operatorRepresentation['padding_x'] = int(self.operatorRepresentation['pads'][0])
self.operatorRepresentation['padding_y'] = int(self.operatorRepresentation['pads'][1])
self.operatorRepresentation['padding_x_left'] = int(self.operatorRepresentation['pads'][0])
self.operatorRepresentation['padding_y_top'] = int(self.operatorRepresentation['pads'][1])
self.operatorRepresentation['padding_x_right'] = int(self.operatorRepresentation['pads'][2])
self.operatorRepresentation['padding_y_bottom'] = int(self.operatorRepresentation['pads'][3])
self.operatorRepresentation['stride_x'] = int(self.operatorRepresentation['strides'][0])
self.operatorRepresentation['stride_y'] = int(self.operatorRepresentation['strides'][1])
self.operatorRepresentation['dim_kernel_x'] = int(self.operatorRepresentation['kernel_shape'][0])
self.operatorRepresentation['dim_kernel_y'] = int(self.operatorRepresentation['kernel_shape'][1])
if not all([
len(pads) == 4,
len(kernel_shape) == 2,
len(strides) == 2,
]):
return False

return wellFormed
self.operatorRepresentation['padding_x'] = pads[0]
self.operatorRepresentation['padding_y'] = pads[1]
self.operatorRepresentation['padding_x_left'] = pads[0]
self.operatorRepresentation['padding_y_top'] = pads[1]
self.operatorRepresentation['padding_x_right'] = pads[2]
self.operatorRepresentation['padding_y_bottom'] = pads[3]
self.operatorRepresentation['stride_x'] = strides[0]
self.operatorRepresentation['stride_y'] = strides[1]
self.operatorRepresentation['dim_kernel_x'] = kernel_shape[0]
self.operatorRepresentation['dim_kernel_y'] = kernel_shape[1]
return True

def parseNodeCtxt(self,
ctxt: NetworkContext,
@@ -837,41 +845,6 @@ def parseNodeCtxt(self,
return ctxt, True


class iNoNormParser(NodeParser):

def __init__(self):
super().__init__()

def parseNode(self, node: gs.Node) -> bool:

ret = all(['D' in node.attrs, 'mul' in node.attrs, 'n_levels' in node.attrs])

if ret:
self.operatorRepresentation['D'] = node.attrs['D']
self.operatorRepresentation['log2D'] = int(np.log2(node.attrs['D'].values).tolist()[0])
self.operatorRepresentation['mul'] = int(node.attrs['mul'].values.tolist()[0])
self.operatorRepresentation['n_levels'] = node.attrs['n_levels']

return ret

def parseNodeCtxt(self,
ctxt: NetworkContext,
node: gs.Node,
channels_first: bool = True) -> Tuple[NetworkContext, bool]:

data_in = ctxt.lookup(node.inputs[0].name)
weights = ctxt.lookup(node.inputs[1].name)
bias = ctxt.lookup(node.inputs[2].name)
data_out = ctxt.lookup(node.outputs[0].name)
self.operatorRepresentation['data_in'] = data_in.name
self.operatorRepresentation['weights'] = weights.name
self.operatorRepresentation['bias'] = bias.name
self.operatorRepresentation['data_out'] = data_out.name
self.operatorRepresentation['size'] = np.prod(data_in.shape)

return ctxt, True


class RQSiHardswishParser(iHardswishParser):

def __init__(self):
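The rewritten MaxPool parseNode methods above fail fast and only annotate the padding/stride/kernel fields once the attribute ranks match the expected dimensionality. A standalone sketch of the 1D well-formedness check on a plain attribute dict (illustrative only, outside the parser classes):

def is_wellformed_maxpool1d(attrs: dict) -> bool:
    # 1D MaxPool: pads holds (begin, end); kernel_shape and strides hold one entry each.
    return all([
        len(attrs['pads']) == 2,
        len(attrs['kernel_shape']) == 1,
        len(attrs['strides']) == 1,
    ])

print(is_wellformed_maxpool1d({'pads': [1, 1], 'kernel_shape': [3], 'strides': [2]}))              # True
print(is_wellformed_maxpool1d({'pads': [1, 1, 1, 1], 'kernel_shape': [3, 3], 'strides': [2, 2]}))  # False: 2D attributes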