Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(experimental): try to infer lambda argument types inside calls #7088

Merged
merged 16 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 99 additions & 27 deletions compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl<'context> Elaborator<'context> {
ExpressionKind::If(if_) => self.elaborate_if(*if_),
ExpressionKind::Variable(variable) => return self.elaborate_variable(variable),
ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple),
ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda),
ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda, None),
ExpressionKind::Parenthesized(expr) => return self.elaborate_expression(*expr),
ExpressionKind::Quote(quote) => self.elaborate_quote(quote, expr.span),
ExpressionKind::Comptime(comptime, _) => {
Expand Down Expand Up @@ -387,17 +387,28 @@ impl<'context> Elaborator<'context> {

fn elaborate_call(&mut self, call: CallExpression, span: Span) -> (HirExpression, Type) {
let (func, func_type) = self.elaborate_expression(*call.func);
let func_arg_types =
if let Type::Function(args, _, _, _) = &func_type { Some(args) } else { None };

let mut arguments = Vec::with_capacity(call.arguments.len());
let args = vecmap(call.arguments, |arg| {
let args = vecmap(call.arguments.into_iter().enumerate(), |(arg_index, arg)| {
let span = arg.span;
let expected_type = func_arg_types.and_then(|args| args.get(arg_index));

let (arg, typ) = if call.is_macro_call {
self.elaborate_in_comptime_context(|this| this.elaborate_expression(arg))
self.elaborate_in_comptime_context(|this| {
this.elaborate_expression_with_type(arg, expected_type)
})
} else {
self.elaborate_expression(arg)
self.elaborate_expression_with_type(arg, expected_type)
};

// Try to unify this argument type against the function's argument type
// so that a potential lambda following this argument can have more concrete types.
if let Some(expected_type) = expected_type {
let _ = expected_type.unify(&typ);
}

arguments.push(arg);
(typ, arg, span)
});
Expand Down Expand Up @@ -458,24 +469,55 @@ impl<'context> Elaborator<'context> {
None
};

let call_span = Span::from(object_span.start()..method_name_span.end());
let location = Location::new(call_span, self.file);

let (function_id, function_name) = method_ref.clone().into_function_id_and_name(
object_type.clone(),
generics.clone(),
location,
self.interner,
);

let func_type =
self.type_check_variable(function_name.clone(), function_id, generics.clone());
self.interner.push_expr_type(function_id, func_type.clone());

let func_arg_types =
if let Type::Function(args, _, _, _) = &func_type { Some(args) } else { None };

// Try to unify the object type with the first argument of the function.
// The reason to do this is that many methods that take a lambda will yield `self` or part of `self`
// as a parameter. By unifying `self` with the first argument we'll potentially get more
// concrete types in the arguments that are function types, which will later be passed as
// lambda parameter hints.
if let Some(first_arg_type) = func_arg_types.and_then(|args| args.first()) {
let _ = first_arg_type.unify(&object_type);
}

// These arguments will be given to the desugared function call.
// Compared to the method arguments, they also contain the object.
let mut function_args = Vec::with_capacity(method_call.arguments.len() + 1);
let mut arguments = Vec::with_capacity(method_call.arguments.len());

function_args.push((object_type.clone(), object, object_span));

for arg in method_call.arguments {
for (arg_index, arg) in method_call.arguments.into_iter().enumerate() {
let span = arg.span;
let (arg, typ) = self.elaborate_expression(arg);
let expected_type = func_arg_types.and_then(|args| args.get(arg_index + 1));
let (arg, typ) = self.elaborate_expression_with_type(arg, expected_type);

// Try to unify this argument type against the function's argument type
// so that a potential lambda following this argument can have more concrete types.
if let Some(expected_type) = expected_type {
let _ = expected_type.unify(&typ);
}

arguments.push(arg);
function_args.push((typ, arg, span));
}

let call_span = Span::from(object_span.start()..method_name_span.end());
let location = Location::new(call_span, self.file);
let method = method_call.method_name;
let turbofish_generics = generics.clone();
let is_macro_call = method_call.is_macro_call;
let method_call =
HirMethodCallExpression { method, object, arguments, location, generics };
Expand All @@ -485,18 +527,9 @@ impl<'context> Elaborator<'context> {
// Desugar the method call into a normal, resolved function call
// so that the backend doesn't need to worry about methods
// TODO: update object_type here?
let ((function_id, function_name), function_call) = method_call.into_function_call(
method_ref,
object_type,
is_macro_call,
location,
self.interner,
);

let func_type =
self.type_check_variable(function_name, function_id, turbofish_generics);

self.interner.push_expr_type(function_id, func_type.clone());
let function_call =
method_call.into_function_call(function_id, is_macro_call, location);

self.interner
.add_function_reference(func_id, Location::new(method_name_span, self.file));
Expand All @@ -520,6 +553,26 @@ impl<'context> Elaborator<'context> {
}
}

/// Elaborates an expression knowing that it has to match a given type.
fn elaborate_expression_with_type(
&mut self,
arg: Expression,
typ: Option<&Type>,
) -> (ExprId, Type) {
let ExpressionKind::Lambda(lambda) = arg.kind else {
return self.elaborate_expression(arg);
};

let span = arg.span;
let type_hint =
if let Some(Type::Function(func_args, _, _, _)) = typ { Some(func_args) } else { None };
let (hir_expr, typ) = self.elaborate_lambda(*lambda, type_hint);
let id = self.interner.push_expr(hir_expr);
self.interner.push_expr_location(id, span, self.file);
self.interner.push_expr_type(id, typ.clone());
(id, typ)
}

fn check_method_call_visibility(&mut self, func_id: FuncId, object_type: &Type, name: &Ident) {
if !method_call_is_visible(
object_type,
Expand Down Expand Up @@ -846,19 +899,38 @@ impl<'context> Elaborator<'context> {
(HirExpression::Tuple(element_ids), Type::Tuple(element_types))
}

fn elaborate_lambda(&mut self, lambda: Lambda) -> (HirExpression, Type) {
/// For elaborating a lambda we might get `parameters_type_hints`. These come from a potential
/// call that has this lambda as the argument.
/// The parameter type hints will be the types of the function type corresponding to the lambda argument.
fn elaborate_lambda(
&mut self,
lambda: Lambda,
parameters_type_hints: Option<&Vec<Type>>,
) -> (HirExpression, Type) {
self.push_scope();
let scope_index = self.scopes.current_scope_index();

self.lambda_stack.push(LambdaContext { captures: Vec::new(), scope_index });

let mut arg_types = Vec::with_capacity(lambda.parameters.len());
let parameters = vecmap(lambda.parameters, |(pattern, typ)| {
let parameter = DefinitionKind::Local(None);
let typ = self.resolve_inferred_type(typ);
arg_types.push(typ.clone());
(self.elaborate_pattern(pattern, typ.clone(), parameter, true), typ)
});
let parameters =
vecmap(lambda.parameters.into_iter().enumerate(), |(index, (pattern, typ))| {
let parameter = DefinitionKind::Local(None);
let typ = if let UnresolvedTypeData::Unspecified = typ.typ {
if let Some(parameter_type_hint) =
parameters_type_hints.and_then(|hints| hints.get(index))
{
parameter_type_hint.clone()
} else {
self.interner.next_type_variable_with_kind(Kind::Any)
}
} else {
self.resolve_type(typ)
};

arg_types.push(typ.clone());
(self.elaborate_pattern(pattern, typ.clone(), parameter, true), typ)
});

let return_type = self.resolve_inferred_type(lambda.return_type);
let body_span = lambda.body.span;
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,7 @@ impl<'context> Elaborator<'context> {
Type::MutableReference(typ) => {
self.mark_type_as_used(typ);
}
Type::InfixExpr(left, _op, right) => {
Type::InfixExpr(left, _op, right, _) => {
self.mark_type_as_used(left);
self.mark_type_as_used(right);
}
Expand Down Expand Up @@ -1688,7 +1688,7 @@ impl<'context> Elaborator<'context> {
Type::MutableReference(typ) | Type::Array(_, typ) | Type::Slice(typ) => {
self.check_type_is_not_more_private_then_item(name, visibility, typ, span);
}
Type::InfixExpr(left, _op, right) => {
Type::InfixExpr(left, _op, right, _) => {
self.check_type_is_not_more_private_then_item(name, visibility, left, span);
self.check_type_is_not_more_private_then_item(name, visibility, right, span);
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ impl<'context> Elaborator<'context> {
}
}
(lhs, rhs) => {
let infix = Type::InfixExpr(Box::new(lhs), op, Box::new(rhs));
let infix = Type::infix_expr(Box::new(lhs), op, Box::new(rhs));
Type::CheckedCast { from: Box::new(infix.clone()), to: Box::new(infix) }
.canonicalize()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ impl Type {
Type::Constant(..) => panic!("Type::Constant where a type was expected: {self:?}"),
Type::Quoted(quoted_type) => UnresolvedTypeData::Quoted(*quoted_type),
Type::Error => UnresolvedTypeData::Error,
Type::InfixExpr(lhs, op, rhs) => {
Type::InfixExpr(lhs, op, rhs, _) => {
let lhs = Box::new(lhs.to_type_expression());
let rhs = Box::new(rhs.to_type_expression());
let span = Span::default();
Expand Down
37 changes: 20 additions & 17 deletions compiler/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,24 +225,15 @@ impl HirMethodReference {
}
}
}
}

impl HirMethodCallExpression {
/// Converts a method call into a function call
///
/// Returns ((func_var_id, func_var), call_expr)
pub fn into_function_call(
mut self,
method: HirMethodReference,
pub fn into_function_id_and_name(
self,
object_type: Type,
is_macro_call: bool,
generics: Option<Vec<Type>>,
location: Location,
interner: &mut NodeInterner,
) -> ((ExprId, HirIdent), HirCallExpression) {
let mut arguments = vec![self.object];
arguments.append(&mut self.arguments);

let (id, impl_kind) = match method {
) -> (ExprId, HirIdent) {
let (id, impl_kind) = match self {
HirMethodReference::FuncId(func_id) => {
(interner.function_definition_id(func_id), ImplKind::NotATraitMethod)
}
Expand All @@ -261,10 +252,22 @@ impl HirMethodCallExpression {
}
};
let func_var = HirIdent { location, id, impl_kind };
let func = interner.push_expr(HirExpression::Ident(func_var.clone(), self.generics));
let func = interner.push_expr(HirExpression::Ident(func_var.clone(), generics));
interner.push_expr_location(func, location.span, location.file);
let expr = HirCallExpression { func, arguments, location, is_macro_call };
((func, func_var), expr)
(func, func_var)
}
}

impl HirMethodCallExpression {
pub fn into_function_call(
mut self,
func: ExprId,
is_macro_call: bool,
location: Location,
) -> HirCallExpression {
let mut arguments = vec![self.object];
arguments.append(&mut self.arguments);
HirCallExpression { func, arguments, location, is_macro_call }
}
}

Expand Down
Loading
Loading