Description
Hello! I've been trying to manually port a model from PyTorch to ONNX using onnxscript.
I've tried to come up with an elegant way of doing this by creating "custom blocks" with attributes.
However, because of how onnxscript currently works, I've run into an issue.
This is my current code:
from onnxscript import FLOAT, script
from onnxscript import opset18 as op

# `weights` is a PyTorch state dict loaded elsewhere, e.g.
# weights = torch.load("model.pth", map_location="cpu")

def GConv2D(key: str, kernel_size: int, padding: int):
    # Pull the pretrained parameters for this layer so they get baked
    # into the graph as constants.
    weight = weights[key + ".weight"].numpy()
    bias = weights[key + ".bias"].numpy()

    @script()
    def GConv2D(r: FLOAT[...]):
        r = op.Conv(
            r,
            weight,
            bias,
            kernel_shape=[kernel_size, kernel_size],
            pads=[padding, padding, padding, padding],
        )
        return r

    return GConv2D

def Identity():
    # Minimal stand-in (its definition wasn't in my snippet): a block
    # that passes its input through unchanged.
    @script()
    def Identity(x: FLOAT[...]):
        return op.Identity(x)

    return Identity

def GroupResBlock(key: str, in_dim: int, out_dim: int):
    # The downsample decision is made at build time so that no If node
    # ends up in the exported graph.
    downsample = GConv2D(key + ".downsample", 1, 0) if in_dim != out_dim else Identity()
    conv1 = GConv2D(key + ".conv1", 3, 1)
    conv2 = GConv2D(key + ".conv2", 3, 1)

    @script()
    def GroupResBlock(x: FLOAT[...]):
        x = conv1(op.Relu(x))
        x = conv2(op.Relu(x))
        x = downsample(x)
        return x

    return GroupResBlock

def MaskDecoderBlock(key: str):
    up_16_8 = GroupResBlock(key + ".up_16_8.out_conv", 256, 128)  # 256 != 128 -> downsample
    up_8_4 = GroupResBlock(key + ".up_8_4.out_conv", 128, 128)    # 128 == 128 -> identity

    @script()
    def MaskDecoderBlock(x: FLOAT[...]):
        x = up_16_8(x)
        x = up_8_4(x)
        return x

    return MaskDecoderBlock

model = MaskDecoderBlock("mask_decoder").to_model_proto()
"downsample" from GroupResBlock is set conditionally. I want it to downsample if in_dim and out_dim are not equal to each other. To avoid having to put this if statement in the model itself, the check is done before so it can be baked into the model as is.
The issue is that when up_16_8 is created, the function GroupResBlock gets defined as having the downsample block. When up_8_4 is then created, a function named GroupResBlock already exists, so it gets reused (with the downsample block, and the wrong weights!).
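The reuse is visible in the exported proto itself. A quick check (the output is my expectation based on the behaviour described above, not something onnxscript documents):

# Reusing `model` from the snippet above.
print([f.name for f in model.functions])
# Only one FunctionProto named "GroupResBlock" appears, even though two
# differently-weighted instances were intended (GConv2D presumably
# collides the same way, since all instances share that def name).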
Is there a way to generate a model proto without functions, so that blocks are not reused and a flat graph is generated instead?
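For context, I know the onnx package ships an inliner (onnx.inliner.inline_local_functions, available in onnx >= 1.16) that can flatten local functions into the main graph, but as far as I can tell it would not help here: by the time inlining runs, both call sites already reference the same (wrong) FunctionProto, so the wrong body would just be inlined twice. A sketch of what I mean:

from onnx.inliner import inline_local_functions

# Flattens every local function into the main graph. This does remove
# the functions, but both former GroupResBlock call sites end up with
# copies of the single shared body, wrong weights included.
flat_model = inline_local_functions(model)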