This repository was archived by the owner on Apr 10, 2025. It is now read-only.

RASP validator fails for some programs #11

@langosco

Issue #9 introduces a validator that flags RASP programs that compile incorrectly.
Here is one case in which I think the validator fails (or I've misunderstood how it works). The RASP program below computes, at each position, the sum of all inputs up to and including the current index:

```python
from tracr.rasp import rasp
from tracr.compiler import validating, compiling


def sum_of_inputs() -> rasp.SOp:
    before = rasp.Select(rasp.indices, rasp.indices, rasp.Comparison.LEQ)
    means = rasp.Aggregate(before, rasp.tokens)  # returns sequence s_i = mean_{j<=i} input_j
    sums = rasp.SequenceMap(lambda x, y: x*y, means, rasp.indices+1)
    return sums


sums = sum_of_inputs()

# The output of the RASP program sums differs from the output of the compiled model:
rasp_output = sums([3, 2, 1, 1])
compiled_model = compiling.compile_rasp_to_model(sums, vocab={1,2,3}, max_seq_len=5, compiler_bos="BOS")
compiled_output = compiled_model.apply(["BOS", 3, 2, 1, 1]).decoded

print(rasp_output)  # output: [3.0, 5.0, 6.0, 7.0]
print(compiled_output)  # output: ['BOS', 3, 4, 3, 4]

# However, it looks like the validator doesn't catch the error:
print(validating.validate(sums, [1, 2, 3]))  # returns an empty list, i.e. no issues reported
```
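For comparison, here is a minimal plain-Python sketch (no tracr calls) of the kind of end-to-end check I would expect to catch this: compute the intended prefix sums directly, the same way the RASP program does (prefix mean times prefix length), and diff them position by position against the compiled model's decoded output. `reference_sums` and `find_mismatches` are hypothetical helpers written for illustration, not part of tracr's `validating` module; the compiled output used below is the one reported above.

```python
from itertools import accumulate


def reference_sums(xs):
    """Hypothetical reference: prefix sums computed as
    (mean of xs[0..i]) * (i + 1), mirroring the RASP program."""
    prefix_sums = list(accumulate(xs))
    means = [s / (i + 1) for i, s in enumerate(prefix_sums)]
    return [m * (i + 1) for i, m in enumerate(means)]


def find_mismatches(expected, compiled_decoded, bos="BOS", tol=1e-6):
    """Hypothetical helper: return (index, expected, got) for every position
    where the compiled output disagrees with the reference.
    A leading BOS token is dropped first; zip stops at the shorter sequence."""
    if compiled_decoded and compiled_decoded[0] == bos:
        compiled_decoded = compiled_decoded[1:]
    return [
        (i, e, g)
        for i, (e, g) in enumerate(zip(expected, compiled_decoded))
        if abs(e - g) > tol
    ]


expected = reference_sums([3, 2, 1, 1])
# Decoded output of the compiled model, as reported in this issue.
compiled = ["BOS", 3, 4, 3, 4]
print(expected)                             # [3.0, 5.0, 6.0, 7.0]
print(find_mismatches(expected, compiled))  # [(1, 5.0, 4), (2, 6.0, 3), (3, 7.0, 4)]
```

Only position 0 agrees, so a check of this form would report three mismatches for this input, whereas `validating.validate` returns an empty list.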
