
Generating a model without functions? #1935

Open
@noahcoolboy

Description


Hello! I've been trying to port a model from PyTorch to ONNX manually, using onnxscript.
I've tried to come up with an elegant way of doing this by creating "custom blocks" with attributes.
However, because of how onnxscript currently works, I'm running into some issues.

This is my current code:

```python
# Imports are assumed; the issue does not show them.
from onnxscript import FLOAT, script
from onnxscript import opset18 as op  # opset version is an assumption

# `weights` (the PyTorch state_dict) and `Identity` are defined elsewhere
# and are not shown in the issue.


def GConv2D(key: str, kernel_size: int, padding: int):
    weight = weights[key + ".weight"].numpy()
    bias = weights[key + ".bias"].numpy()

    @script()
    def GConv2D(r: FLOAT[...]):
        r = op.Conv(
            r,
            weight,
            bias,
            kernel_shape=[kernel_size, kernel_size],
            pads=[padding, padding, padding, padding],
        )

        return r

    return GConv2D


def GroupResBlock(key: str, in_dim: int, out_dim: int):
    # Decided in Python when the block is built, not inside the graph.
    downsample = GConv2D(key + ".downsample", 1, 0) if in_dim != out_dim else Identity()
    conv1 = GConv2D(key + ".conv1", 3, 1)
    conv2 = GConv2D(key + ".conv2", 3, 1)

    @script()
    def GroupResBlock(x: FLOAT[...]):
        x = conv1(op.Relu(x))
        x = conv2(op.Relu(x))
        x = downsample(x)
        return x

    return GroupResBlock


def MaskDecoderBlock(key: str):
    up_16_8 = GroupResBlock(key + ".up_16_8.out_conv", 256, 128)
    up_8_4 = GroupResBlock(key + ".up_8_4.out_conv", 128, 128)

    @script()
    def MaskDecoderBlock(x: FLOAT[...]):
        x = up_16_8(x)
        x = up_8_4(x)
        return x

    return MaskDecoderBlock


model = MaskDecoderBlock("mask_decoder").to_model_proto()
```

`downsample` in `GroupResBlock` is set conditionally: I want it to downsample only when `in_dim` and `out_dim` differ. To avoid putting this `if` statement in the model itself, the check is done in Python beforehand, so its result is baked into the model as is.

The problem: when `up_16_8` is created, a function named `GroupResBlock` gets defined that includes the downsample block. When `up_8_4` is created afterwards, a function named `GroupResBlock` already exists, so that one is reused, downsample block and `up_16_8` weights included, instead of the new definition.
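The failure can be reproduced with a toy model in plain Python. This is not onnxscript's actual implementation, only an illustration of the behaviour described above: if function bodies are registered by name alone and the first definition wins, both call sites resolve to the `up_16_8` body. ONNX itself identifies model-local functions by domain and name (plus an overload field in newer IR versions), so two structurally different blocks need distinct names to coexist in one model. `ToyFunctionRegistry` and the body lists are hypothetical.

```python
class ToyFunctionRegistry:
    """Toy stand-in: function bodies are stored by name only, first definition wins."""

    def __init__(self):
        self.bodies = {}

    def define(self, name, body):
        # A second definition under an existing name is silently ignored.
        self.bodies.setdefault(name, body)
        return name  # a call site only records the function name

    def resolve(self, name):
        return self.bodies[name]


# Hypothetical bodies: the weight keys each block instance should use.
body_16_8 = ["up_16_8.conv1", "up_16_8.conv2", "up_16_8.downsample"]
body_8_4 = ["up_8_4.conv1", "up_8_4.conv2"]

shared = ToyFunctionRegistry()
call_16_8 = shared.define("GroupResBlock", body_16_8)
call_8_4 = shared.define("GroupResBlock", body_8_4)
print(shared.resolve(call_8_4))  # ['up_16_8.conv1', 'up_16_8.conv2', 'up_16_8.downsample']

# Giving each instance its own name keeps both definitions.
unique = ToyFunctionRegistry()
call_16_8 = unique.define("GroupResBlock_up_16_8", body_16_8)
call_8_4 = unique.define("GroupResBlock_up_8_4", body_8_4)
print(unique.resolve(call_8_4))  # ['up_8_4.conv1', 'up_8_4.conv2']
```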

Is there a way to generate a model proto without functions, so that blocks are never reused and the result is a flat graph instead?
