Skip to content

Commit

Permalink
feat(experimental): try to infer lambda argument types inside calls (#…
Browse files Browse the repository at this point in the history
…7088)

Co-authored-by: jfecher <[email protected]>
  • Loading branch information
asterite and jfecher authored Jan 17, 2025
1 parent ed12ad7 commit a3b823c
Show file tree
Hide file tree
Showing 19 changed files with 364 additions and 193 deletions.
126 changes: 99 additions & 27 deletions compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl<'context> Elaborator<'context> {
ExpressionKind::If(if_) => self.elaborate_if(*if_),
ExpressionKind::Variable(variable) => return self.elaborate_variable(variable),
ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple),
ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda),
ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda, None),
ExpressionKind::Parenthesized(expr) => return self.elaborate_expression(*expr),
ExpressionKind::Quote(quote) => self.elaborate_quote(quote, expr.span),
ExpressionKind::Comptime(comptime, _) => {
Expand Down Expand Up @@ -387,17 +387,28 @@ impl<'context> Elaborator<'context> {

fn elaborate_call(&mut self, call: CallExpression, span: Span) -> (HirExpression, Type) {
let (func, func_type) = self.elaborate_expression(*call.func);
let func_arg_types =
if let Type::Function(args, _, _, _) = &func_type { Some(args) } else { None };

let mut arguments = Vec::with_capacity(call.arguments.len());
let args = vecmap(call.arguments, |arg| {
let args = vecmap(call.arguments.into_iter().enumerate(), |(arg_index, arg)| {
let span = arg.span;
let expected_type = func_arg_types.and_then(|args| args.get(arg_index));

let (arg, typ) = if call.is_macro_call {
self.elaborate_in_comptime_context(|this| this.elaborate_expression(arg))
self.elaborate_in_comptime_context(|this| {
this.elaborate_expression_with_type(arg, expected_type)
})
} else {
self.elaborate_expression(arg)
self.elaborate_expression_with_type(arg, expected_type)
};

// Try to unify this argument type against the function's argument type
// so that a potential lambda following this argument can have more concrete types.
if let Some(expected_type) = expected_type {
let _ = expected_type.unify(&typ);
}

arguments.push(arg);
(typ, arg, span)
});
Expand Down Expand Up @@ -458,24 +469,55 @@ impl<'context> Elaborator<'context> {
None
};

let call_span = Span::from(object_span.start()..method_name_span.end());
let location = Location::new(call_span, self.file);

let (function_id, function_name) = method_ref.clone().into_function_id_and_name(
object_type.clone(),
generics.clone(),
location,
self.interner,
);

let func_type =
self.type_check_variable(function_name.clone(), function_id, generics.clone());
self.interner.push_expr_type(function_id, func_type.clone());

let func_arg_types =
if let Type::Function(args, _, _, _) = &func_type { Some(args) } else { None };

// Try to unify the object type with the first argument of the function.
// The reason to do this is that many methods that take a lambda will yield `self` or part of `self`
// as a parameter. By unifying `self` with the first argument we'll potentially get more
// concrete types in the arguments that are function types, which will later be passed as
// lambda parameter hints.
if let Some(first_arg_type) = func_arg_types.and_then(|args| args.first()) {
let _ = first_arg_type.unify(&object_type);
}

// These arguments will be given to the desugared function call.
// Compared to the method arguments, they also contain the object.
let mut function_args = Vec::with_capacity(method_call.arguments.len() + 1);
let mut arguments = Vec::with_capacity(method_call.arguments.len());

function_args.push((object_type.clone(), object, object_span));

for arg in method_call.arguments {
for (arg_index, arg) in method_call.arguments.into_iter().enumerate() {
let span = arg.span;
let (arg, typ) = self.elaborate_expression(arg);
let expected_type = func_arg_types.and_then(|args| args.get(arg_index + 1));
let (arg, typ) = self.elaborate_expression_with_type(arg, expected_type);

// Try to unify this argument type against the function's argument type
// so that a potential lambda following this argument can have more concrete types.
if let Some(expected_type) = expected_type {
let _ = expected_type.unify(&typ);
}

arguments.push(arg);
function_args.push((typ, arg, span));
}

let call_span = Span::from(object_span.start()..method_name_span.end());
let location = Location::new(call_span, self.file);
let method = method_call.method_name;
let turbofish_generics = generics.clone();
let is_macro_call = method_call.is_macro_call;
let method_call =
HirMethodCallExpression { method, object, arguments, location, generics };
Expand All @@ -485,18 +527,9 @@ impl<'context> Elaborator<'context> {
// Desugar the method call into a normal, resolved function call
// so that the backend doesn't need to worry about methods
// TODO: update object_type here?
let ((function_id, function_name), function_call) = method_call.into_function_call(
method_ref,
object_type,
is_macro_call,
location,
self.interner,
);

let func_type =
self.type_check_variable(function_name, function_id, turbofish_generics);

self.interner.push_expr_type(function_id, func_type.clone());
let function_call =
method_call.into_function_call(function_id, is_macro_call, location);

self.interner
.add_function_reference(func_id, Location::new(method_name_span, self.file));
Expand All @@ -520,6 +553,26 @@ impl<'context> Elaborator<'context> {
}
}

/// Elaborates an expression knowing that it has to match a given type.
fn elaborate_expression_with_type(
&mut self,
arg: Expression,
typ: Option<&Type>,
) -> (ExprId, Type) {
let ExpressionKind::Lambda(lambda) = arg.kind else {
return self.elaborate_expression(arg);
};

let span = arg.span;
let type_hint =
if let Some(Type::Function(func_args, _, _, _)) = typ { Some(func_args) } else { None };
let (hir_expr, typ) = self.elaborate_lambda(*lambda, type_hint);
let id = self.interner.push_expr(hir_expr);
self.interner.push_expr_location(id, span, self.file);
self.interner.push_expr_type(id, typ.clone());
(id, typ)
}

fn check_method_call_visibility(&mut self, func_id: FuncId, object_type: &Type, name: &Ident) {
if !method_call_is_visible(
object_type,
Expand Down Expand Up @@ -846,19 +899,38 @@ impl<'context> Elaborator<'context> {
(HirExpression::Tuple(element_ids), Type::Tuple(element_types))
}

fn elaborate_lambda(&mut self, lambda: Lambda) -> (HirExpression, Type) {
/// For elaborating a lambda we might get `parameters_type_hints`. These come from a potential
/// call that has this lambda as the argument.
/// The parameter type hints will be the types of the function type corresponding to the lambda argument.
fn elaborate_lambda(
&mut self,
lambda: Lambda,
parameters_type_hints: Option<&Vec<Type>>,
) -> (HirExpression, Type) {
self.push_scope();
let scope_index = self.scopes.current_scope_index();

self.lambda_stack.push(LambdaContext { captures: Vec::new(), scope_index });

let mut arg_types = Vec::with_capacity(lambda.parameters.len());
let parameters = vecmap(lambda.parameters, |(pattern, typ)| {
let parameter = DefinitionKind::Local(None);
let typ = self.resolve_inferred_type(typ);
arg_types.push(typ.clone());
(self.elaborate_pattern(pattern, typ.clone(), parameter, true), typ)
});
let parameters =
vecmap(lambda.parameters.into_iter().enumerate(), |(index, (pattern, typ))| {
let parameter = DefinitionKind::Local(None);
let typ = if let UnresolvedTypeData::Unspecified = typ.typ {
if let Some(parameter_type_hint) =
parameters_type_hints.and_then(|hints| hints.get(index))
{
parameter_type_hint.clone()
} else {
self.interner.next_type_variable_with_kind(Kind::Any)
}
} else {
self.resolve_type(typ)
};

arg_types.push(typ.clone());
(self.elaborate_pattern(pattern, typ.clone(), parameter, true), typ)
});

let return_type = self.resolve_inferred_type(lambda.return_type);
let body_span = lambda.body.span;
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,7 @@ impl<'context> Elaborator<'context> {
Type::MutableReference(typ) => {
self.mark_type_as_used(typ);
}
Type::InfixExpr(left, _op, right) => {
Type::InfixExpr(left, _op, right, _) => {
self.mark_type_as_used(left);
self.mark_type_as_used(right);
}
Expand Down Expand Up @@ -1688,7 +1688,7 @@ impl<'context> Elaborator<'context> {
Type::MutableReference(typ) | Type::Array(_, typ) | Type::Slice(typ) => {
self.check_type_is_not_more_private_then_item(name, visibility, typ, span);
}
Type::InfixExpr(left, _op, right) => {
Type::InfixExpr(left, _op, right, _) => {
self.check_type_is_not_more_private_then_item(name, visibility, left, span);
self.check_type_is_not_more_private_then_item(name, visibility, right, span);
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ impl<'context> Elaborator<'context> {
}
}
(lhs, rhs) => {
let infix = Type::InfixExpr(Box::new(lhs), op, Box::new(rhs));
let infix = Type::infix_expr(Box::new(lhs), op, Box::new(rhs));
Type::CheckedCast { from: Box::new(infix.clone()), to: Box::new(infix) }
.canonicalize()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ impl Type {
Type::Constant(..) => panic!("Type::Constant where a type was expected: {self:?}"),
Type::Quoted(quoted_type) => UnresolvedTypeData::Quoted(*quoted_type),
Type::Error => UnresolvedTypeData::Error,
Type::InfixExpr(lhs, op, rhs) => {
Type::InfixExpr(lhs, op, rhs, _) => {
let lhs = Box::new(lhs.to_type_expression());
let rhs = Box::new(rhs.to_type_expression());
let span = Span::default();
Expand Down
37 changes: 20 additions & 17 deletions compiler/noirc_frontend/src/hir_def/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,24 +225,15 @@ impl HirMethodReference {
}
}
}
}

impl HirMethodCallExpression {
/// Converts a method call into a function call
///
/// Returns ((func_var_id, func_var), call_expr)
pub fn into_function_call(
mut self,
method: HirMethodReference,
pub fn into_function_id_and_name(
self,
object_type: Type,
is_macro_call: bool,
generics: Option<Vec<Type>>,
location: Location,
interner: &mut NodeInterner,
) -> ((ExprId, HirIdent), HirCallExpression) {
let mut arguments = vec![self.object];
arguments.append(&mut self.arguments);

let (id, impl_kind) = match method {
) -> (ExprId, HirIdent) {
let (id, impl_kind) = match self {
HirMethodReference::FuncId(func_id) => {
(interner.function_definition_id(func_id), ImplKind::NotATraitMethod)
}
Expand All @@ -261,10 +252,22 @@ impl HirMethodCallExpression {
}
};
let func_var = HirIdent { location, id, impl_kind };
let func = interner.push_expr(HirExpression::Ident(func_var.clone(), self.generics));
let func = interner.push_expr(HirExpression::Ident(func_var.clone(), generics));
interner.push_expr_location(func, location.span, location.file);
let expr = HirCallExpression { func, arguments, location, is_macro_call };
((func, func_var), expr)
(func, func_var)
}
}

impl HirMethodCallExpression {
pub fn into_function_call(
mut self,
func: ExprId,
is_macro_call: bool,
location: Location,
) -> HirCallExpression {
let mut arguments = vec![self.object];
arguments.append(&mut self.arguments);
HirCallExpression { func, arguments, location, is_macro_call }
}
}

Expand Down
Loading

0 comments on commit a3b823c

Please sign in to comment.