feat(experimental): try to infer lambda argument types inside calls #7088

Open · wants to merge 13 commits into master

Conversation

@asterite (Collaborator) commented Jan 16, 2025

Description

Problem

Fixes #6802

Summary

I've been thinking for days about how we could have lambda parameter types be "inferred" from the call they are being passed to.

The first thing that came to mind is that lambdas are most commonly passed as callbacks after invoking a method on self, where either self or a part of self is given to the lambda (like Option::map, BoundedVec::map, etc.). So the first thing I tried in this PR was to eagerly unify a method call's function type with the object type. Then, when elaborating a lambda as a method call argument, we pass it the potential parameter types of the function type in that argument position:

fn map<U>(&self, f: fn(T) -> U) { ... }
                       ^
                     this

self.map(|x| ...)
          ^
    has to unify with this

Then these types are unified without erroring, because we'll check the lambda's type against the parameter type later anyway.

And... that worked! And that already covers a lot of cases.
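
For example, something like this (a minimal sketch using the stdlib Option type) is the kind of call that's covered; the lambda parameter no longer needs an explicit type:

fn main() {
    let three: Option<Field> = Option::some(3);
    // `x` is inferred as `Field` from `Option::map`'s callback parameter type.
    let doubled = three.map(|x| x * 2);
    assert(doubled.unwrap() == 6);
}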

Then I did the same thing for function calls, except that there's no self; but at least it now works if the callback parameter has a concrete type, like this:

fn foo(f: fn(Foo) -> ...

And that worked too! Though I'm not sure there are many uses of that...
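
Here's a hedged sketch of that shape (foo and Foo are hypothetical, only to spell out the kind of signature meant above):

struct Foo {
    x: Field,
}

fn foo(f: fn(Foo) -> Field) -> Field {
    f(Foo { x: 1 })
}

fn main() {
    // `y` is inferred as `Foo` from `foo`'s parameter type.
    assert(foo(|y| y.x) == 1);
}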

BUT: it didn't work in the code Nico shared in Slack, because it's a function call where the first argument given acts like a self type, except that it's not a method call:

for_each_in_bounded_vec(notes, |note, _| {
})

So the final thing I did was to eagerly try to unify each argument's type against the target function type's corresponding parameter as the arguments are elaborated. And that made that example work!
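
For reference, here's a hypothetical signature that would fit that call (the real for_each_in_bounded_vec isn't shown here, so the element and index types are assumptions). With eager argument unification, elaborating notes first binds T and the length, so when the lambda argument is elaborated its note parameter can be taken from the fn(T, u32) parameter:

fn for_each_in_bounded_vec<T, let MaxLen: u32>(vec: BoundedVec<T, MaxLen>, f: fn(T, u32) -> ()) {
    // Call `f` with each element and its index.
    for i in 0..vec.len() {
        f(vec.get(i), i);
    }
}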

It won't work if the lambda comes before the argument that determines its type (which works in Rust), but I think that pattern is uncommon (though we could try to make it work in the future).

Additional Context

I don't know if this is the right way to approach this.

I also don't know if unifying eagerly would cause any issues. One thing that's not done here is using "unify_with_coercions", but given that we don't issue the errors that happen in these eager checks, maybe it's fine (maybe it won't work in cases where an array is automatically converted to a slice, though I guess we could make that work in the future).

One more thing: the changes in the stdlib and programs aren't really necessary, but I wanted to see whether the code still compiles with those changes... and in many cases the code is simplified a bit.

And finally: the code is not the best as I was just experimenting. We should clean it up.

Documentation

Check one:

  • No documentation needed.
  • Documentation included in this PR.
  • [For Experimental Features] Documentation to be submitted in a separate PR.

PR Checklist

  • I have tested the changes locally.
  • I have formatted the changes with Prettier and/or cargo fmt on default settings.

github-actions bot commented Jan 16, 2025

Changes to Brillig bytecode sizes

Generated at commit: 7c978dd4efedeee3edce9a53b582e2f3b1df93ef, compared to commit: 48e613ea97afb829012b3c5b1f8181d4f5ccfb7b

🧾 Summary (10% most significant diffs)

Program Brillig opcodes (+/-) %
higher_order_functions_inliner_zero +26 ❌ +3.90%
slices_inliner_min -18 ✅ -0.68%

Full diff report 👇
Program Brillig opcodes (+/-) %
higher_order_functions_inliner_zero 692 (+26) +3.90%
higher_order_functions_inliner_min 1,448 (+4) +0.28%
slices_inliner_min 2,615 (-18) -0.68%

github-actions bot commented Jan 16, 2025

Changes to number of Brillig opcodes executed

Generated at commit: 7c978dd4efedeee3edce9a53b582e2f3b1df93ef, compared to commit: 48e613ea97afb829012b3c5b1f8181d4f5ccfb7b

🧾 Summary (10% most significant diffs)

Program Brillig opcodes (+/-) %
higher_order_functions_inliner_min -165 ✅ -5.86%
higher_order_functions_inliner_zero -90 ✅ -6.79%

Full diff report 👇
Program Brillig opcodes (+/-) %
slices_inliner_min 4,549 (-75) -1.62%
higher_order_functions_inliner_min 2,651 (-165) -5.86%
higher_order_functions_inliner_zero 1,236 (-90) -6.79%

github-actions bot commented Jan 16, 2025

Compilation Report

Program Compilation Time
sha256_regression 0.994s
regression_4709 0.788s
ram_blowup_regression 15.900s
private-kernel-tail 1.022s
private-kernel-reset 6.166s
private-kernel-inner 1.950s

github-actions bot commented Jan 16, 2025

Execution Report

Program Execution Time
sha256_regression 0.051s
regression_4709 0.001s
ram_blowup_regression 0.620s
private-kernel-tail 0.019s
private-kernel-reset 0.310s
private-kernel-inner 0.068s

@asterite (Collaborator Author)

I didn't expect this to have any repercussions on SSA 😮

Lambdas that involve math operations somehow already compiled without type annotations before this change:

let myarray: [i32; 3] = [1, 2, 3];

// Compiles fine without saying `n: i32`
assert(myarray.any(|n| n > 2));

I think it's because of the >.

But in the end n ends up being i32, so I don't understand how it could change... I'll investigate later today. I guess if it leads to more optimizations that's good, though it's a bit suspicious and might mean this introduced a bug... 🤔

github-actions bot commented Jan 16, 2025

Execution Memory Report

Program Peak Memory %
keccak256 74.63M 0%
workspace 123.86M 0%
regression_4709 315.94M -1%
ram_blowup_regression 512.57M 0%
rollup-root 471.43M -6%
rollup-merge 471.43M -1%
rollup-block-root 471.43M 100%
rollup-block-merge 471.43M -6%
rollup-base-public 471.43M -36%
rollup-base-private 471.43M -21%
private-kernel-tail 180.63M 0%
private-kernel-reset 245.24M 0%
private-kernel-inner 208.64M 0%

github-actions bot commented Jan 16, 2025

Compilation Memory Report

Program Peak Memory %
keccak256 77.56M 0%
workspace 123.87M 0%
regression_4709 424.07M 0%
ram_blowup_regression 1.46G 0%
rollup-root 471.43M -22%
rollup-merge 471.43M -5%
rollup-block-root-single-tx 471.43M 100%
rollup-block-root-empty 471.43M -4%
rollup-block-root 471.44M 100%
rollup-block-merge 471.43M -22%
rollup-base-public 471.44M 100%
rollup-base-private 471.43M 100%
private-kernel-tail 207.18M 0%
private-kernel-reset 584.20M 0%
private-kernel-inner 294.40M 0%

@asterite (Collaborator Author) commented Jan 16, 2025

Also compilation memory and execution memory went up 🤔

Now it's back to normal. Actually the numbers are the same, it's just that maybe the base numbers were outdated...

We could undo commits 79f4076 and 2cffe19 if we wanted, though they are probably good checks to improve performance.

@asterite (Collaborator Author)

It seems that with this PR some functions get inlined right from the beginning, while in master they are not. I'm not sure why... but if this only happens for higher-order functions (maybe uncommon?) and it leads to an optimization, maybe it's good.

@asterite asterite requested a review from a team January 16, 2025 15:38
@jfecher (Contributor) left a comment

Instead of unifying afterward we can push the type down, and that should be sufficient. This is what Rust does as well, since it uses a different algorithm called "bidirectional type inference" which splits type checking into fn infer(T) -> Type, which we have, and fn check(T, Type), which we don't have. Certain constructs like literals are always inferred, and others like lambdas are usually checked.

For our purposes though, we can just use your existing check-arg-type function but pass down the expected type instead of the entire function type and argument index.

Review comments on compiler/noirc_frontend/src/elaborator/expressions.rs (outdated, resolved)
@asterite (Collaborator Author)

Instead of unifying afterward we can push the type down and that should be sufficient

Some weeks ago I tried something similar to what I did here, but I can't remember exactly what (it didn't work, so I deleted the branch). I think I was assigning types instead of unifying and got errors like "can't find method foo for T", and that's why I tried unifying here. But it's probably the case that I had a bug or something else and that's why it wasn't working. I'll try pushing the type down and using it when the type is unspecified 👍

@asterite (Collaborator Author)

Hm, now we get a failure on code like this:

struct U60Repr<let N: u32, let NumSegments: u32> {}

impl<let N: u32, let NumSegments: u32> U60Repr<N, NumSegments> {
    fn new<let NumFieldSegments: u32>(_: [Field; N * NumFieldSegments]) -> Self {
        U60Repr {}
    }
}

fn main() {
    let input: [Field; 6] = [0; 6];
    let _: U60Repr<3, 6> = U60Repr::new(input);
}

The error is this:

error: Expected type [Field; ((6 / _) * _)], found type [Field; 6]

let u60repr: U60Repr<3, 6> = unsafe { U60Repr::new(input) };
                                                   -----

Looking at the types that are compared, these are the ones before this PR:

arg = [Field; (6: numeric u32)]
param = [Field; (Numeric(Shared(RefCell { value: Unbound('8, Numeric(u32)) }): u32) * Numeric(Shared(RefCell { value: Unbound('10, Numeric(u32)) }): u32))]

And these are the ones in this PR:

arg = [Field; (6: numeric u32)]
param = [Field; ((6 / _) * Numeric(Shared(RefCell { value: Unbound('10, Numeric(u32)) }): u32))]

It seems the compiler is trying to solve the equation and that's why there's a division. I don't know if this is a bug in the unification code or if it's a bug introduced in this PR.

@asterite (Collaborator Author) commented Jan 16, 2025

So what I found is happening is...

When checking arg against param, eventually it'll try the math substitution here:

https://github.com/noir-lang/noir/blob/c172880ae47ec4906cda662801bd4b7866c9586b/compiler/noirc_frontend/src/hir_def/types.rs#L1729C47-L1736

Code:

// Handle cases like `4 = a + b` by trying to solve to `a = 4 - b`
let new_type = InfixExpr(
    Box::new(Constant(*value, kind.clone())),
    inverse,
    rhs.clone(),
);
new_type.try_unify(lhs, bindings)?;
Ok(())

That will eventually fail because of an occurs check in Type::try_bind_to:

// Check if the target id occurs within `this` before binding. Otherwise this could
// cause infinitely recursive types
if this.occurs(target_id) {
    Err(UnificationError)
} else {
    bindings.insert(target_id, (var.clone(), this.kind(), this.clone()));
    Ok(())
}

Maybe it's failing because we are binding these types multiple times now?

Something that caught my attention is the ? followed by Ok(()) in the first snippet:

new_type.try_unify(lhs, bindings)?;
Ok(())

That could just be:

new_type.try_unify(lhs, bindings)

I wonder if the intention of that code was to try to solve the math equation but always return Ok(()) if it couldn't be unified, and it ended up accidentally returning an error because of the ?.

If in the occurs check I return Ok(()) instead of Err(UnificationError) the issue is also solved and no test fails, though of course I don't know if that's okay to do.

@jfecher Thoughts?

@jfecher (Contributor) commented Jan 16, 2025

@asterite we can't remove the occurs check; that'd break the check for infinite types like a = (a, a). I think the error is stemming from our inability to solve 6 = ? * ?, because N and NumSegments are not also being pushed down, since they're only known from the let statement. My guess is that before this change, on arguments we did ?a = ?b * ?c instead and could just set the value of ?a to equal ?b * ?c in that case. Then when they're constrained further later, it'd all be fine.

@asterite (Collaborator Author)

I see, thanks.

I think the error is stemming from our inability to solve 6 = ? * ?

In the code before this PR, [Field; 6] was checked against [Field; ?a * ?b] and that somehow didn't seem to break. Maybe it's that it solved ?a to be 6 / ?b? I didn't check.

I think now it's checking 6 = (6 / ?a) * ?b, maybe we don't have that case covered. But that equation could be turned into 6 / ?b = 6 / ?a so ?a and ?b should be the same... not sure! I jumped into working on the loop statement, but I'll come back to this after that.

@jfecher (Contributor) commented Jan 16, 2025

// Handle cases like 4 = a + b by trying to solve to a = 4 - b

I don't see anything wrong with that case aside from the fact that the last Ok(()) is unnecessary, as you mentioned. I don't think it's the cause of the issue here, at least.

I think now it's checking 6 = (6 / ?a) * ?b, maybe we don't have that case covered.

This looks incorrect to me since, as you mentioned, it'd require a = b, which shouldn't follow from the original equation of 6 = a * b. So the error probably lies in how that constraint comes about.

@TomAFrench TomAFrench linked an issue Jan 16, 2025 that may be closed by this pull request
@asterite (Collaborator Author)

Coming back to this, I found that it's trying to unify a Constant against an InfixExpr with these values:

# Infix
lhs = Numeric(Shared(RefCell { value: Unbound('10, Numeric(u32)) }): u32)
op = Multiplication
rhs = ((6: numeric u32) / Numeric(Shared(RefCell { value: Unbound('10, Numeric(u32)) }): u32))

# Constant
value = 6

So it's unifying 6 against (6 / x) * x, and the x is effectively the same variable in both places.

What ends up happening, because of

// Handle cases like `4 = a + b` by trying to solve to `a = 4 - b`

is that it tries to unify x against 6 / (6 / x), and that's where the error happens: the occurs check rejects binding x to a type that contains x.

Now I'm trying to see where that (6 / x) * x is coming from, because I think it could just be 6.

@jfecher (Contributor) commented Jan 16, 2025

The 6 / x * x is definitely a result of trying to solve the equation previously. It's unfortunate we can't just optimize that to 6 anymore like we used to before we considered integer division.

I wonder if adding a version of division that tells the compiler to ignore rounding would fix this. E.g. assuming the original is x = a * b and the translation is x / b = a, then we already know x should be divisible by b, so we should be able to optimize a theoretical x / b * b later on.

@asterite (Collaborator Author)

It's unfortunate we can't just optimize that to 6 anymore like we used to before we considered integer division.

Yeah... that's what I tried at first but found that a test about that failed.

I wonder where that (6 / x) * x came from, because in the code there's only N * NumFieldSegments. Like, here there aren't two identical type variables that came from the user; it's the compiler duplicating them and putting them in the equation...

I tried gating Type::InfixExpr(..) behind a method to see where that's coming from, but all I can see is that it comes from canonicalize, which already receives an InfixExpr like that, and I can't figure out where it all starts 😕

Some commits ago this PR worked, I guess because we didn't type-check as much as we do now (previously we'd only do it if there were lambdas). We could go back to that version, though I guess there would still be code that failed because of this if it was similar to the snippet that fails but also had a lambda argument (maybe uncommon). That is, I wonder if this bug exists regardless of this feature, and whether we could proceed with this feature by only applying these changes if there are lambda arguments.

@asterite (Collaborator Author)

That code is currently erroring on type_check_call because func_type changed after we unified its types with the argument types.

I wanted to try calling type_check_call on a fresh func_type without all the type information we added from the arguments... but I didn't know how. I cloned func_type before type-checking the arguments and passed that instead, but it didn't work.
