Skip to content

Commit 7bd7353

Browse files
committed
Add BatchNormalization descriptor
1 parent 90102f5 commit 7bd7353

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

Deeploy/OperatorDescriptor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,8 +734,19 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
734734
attrDescriptors = [],
735735
)
736736

737+
batchNormalizationDesc = OperatorDescriptor(
738+
inputDescriptor = IoDesc(["data_in", "scale", "bias", "mean", "variance"]),
739+
outputDescriptor = IoDesc(["data_out"], optional = ["running_mean", "running_var"]),
740+
attrDescriptors = [
741+
AttrDesc("epsilon", FloatUnpack, default = 1e-5),
742+
AttrDesc("momentum", FloatUnpack, default = 0.9),
743+
AttrDesc("training_mode", BoolUnpack, default = False),
744+
],
745+
)
746+
737747
defaultOperatorDescriptors: Dict[str, OperatorDescriptor] = {
738748
"Add": addDesc,
749+
"BatchNormalization": batchNormalizationDesc,
739750
"CLCA": clcaDesc,
740751
"Concat": concatDesc,
741752
"Conv": convDesc,

0 commit comments

Comments
 (0)