Skip to content

Commit 42069e2

Browse files
orizidean-starkware
authored andcommitted
Inferring closures types within higher order functions
1 parent 739ecc2 commit 42069e2

File tree

5 files changed

+154
-23
lines changed

5 files changed

+154
-23
lines changed

corelib/src/test/option_test.cairo

+1-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ fn test_default_for_option() {
209209
#[test]
210210
fn test_option_some_map() {
211211
let maybe_some_string: Option<ByteArray> = Option::Some("Hello, World!");
212-
let maybe_some_len = maybe_some_string.map(|s: ByteArray| s.len());
212+
let maybe_some_len = maybe_some_string.map(|s| s.len());
213213
assert!(maybe_some_len == Option::Some(13));
214214
}
215215

crates/cairo-lang-semantic/src/db.rs

+17-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use crate::diagnostic::SemanticDiagnosticKind;
2626
use crate::expr::inference::{self, ImplVar, ImplVarId};
2727
use crate::items::constant::{ConstCalcInfo, ConstValueId, Constant, ImplConstantId};
2828
use crate::items::function_with_body::FunctionBody;
29-
use crate::items::functions::{ImplicitPrecedence, InlineConfiguration};
29+
use crate::items::functions::{GenericFunctionId, ImplicitPrecedence, InlineConfiguration};
3030
use crate::items::generics::{GenericParam, GenericParamData, GenericParamsData};
3131
use crate::items::imp::{
3232
ImplId, ImplImplId, ImplLookupContext, ImplicitImplImplData, UninferredImpl,
@@ -1431,6 +1431,22 @@ pub trait SemanticGroup:
14311431
#[salsa::invoke(items::functions::concrete_function_signature)]
14321432
fn concrete_function_signature(&self, function_id: FunctionId) -> Maybe<semantic::Signature>;
14331433

1434+
/// Returns a mapping of closure types to their associated parameter types for a concrete
1435+
/// function.
1436+
#[salsa::invoke(items::functions::concrete_function_closure_params)]
1437+
fn concrete_function_closure_params(
1438+
&self,
1439+
function_id: FunctionId,
1440+
) -> Maybe<OrderedHashMap<semantic::TypeId, semantic::TypeId>>;
1441+
1442+
/// Returns a `HashMap` where the key is the closure type, and the value is a
1443+
/// vector of generic parameter types.
1444+
#[salsa::invoke(items::functions::get_closure_params)]
1445+
fn get_closure_params(
1446+
&self,
1447+
generic_function_id: GenericFunctionId,
1448+
) -> Maybe<OrderedHashMap<TypeId, TypeId>>;
1449+
14341450
// Generic type.
14351451
// =============
14361452
/// Returns the generic params of a generic type.

crates/cairo-lang-semantic/src/expr/compute.rs

+59-15
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ use crate::expr::inference::{ImplVarTraitItemMappings, InferenceId};
6666
use crate::items::constant::{ConstValue, resolve_const_expr_and_evaluate, validate_const_expr};
6767
use crate::items::enm::SemanticEnumEx;
6868
use crate::items::feature_kind::extract_item_feature_config;
69-
use crate::items::functions::function_signature_params;
69+
use crate::items::functions::{concrete_function_closure_params, function_signature_params};
7070
use crate::items::imp::{ImplLookupContext, filter_candidate_traits, infer_impl_by_self};
7171
use crate::items::modifiers::compute_mutability;
7272
use crate::items::us::get_use_path_segments;
@@ -424,7 +424,7 @@ pub fn maybe_compute_expr_semantic(
424424
ast::Expr::Indexed(expr) => compute_expr_indexed_semantic(ctx, expr),
425425
ast::Expr::FixedSizeArray(expr) => compute_expr_fixed_size_array_semantic(ctx, expr),
426426
ast::Expr::For(expr) => compute_expr_for_semantic(ctx, expr),
427-
ast::Expr::Closure(expr) => compute_expr_closure_semantic(ctx, expr),
427+
ast::Expr::Closure(expr) => compute_expr_closure_semantic(ctx, expr, None),
428428
}
429429
}
430430

@@ -882,7 +882,7 @@ fn compute_expr_function_call_semantic(
882882
let mut arg_types = vec![];
883883
for arg_syntax in args_iter {
884884
let stable_ptr = arg_syntax.stable_ptr();
885-
let arg = compute_named_argument_clause(ctx, arg_syntax);
885+
let arg = compute_named_argument_clause(ctx, arg_syntax, None);
886886
if arg.2 != Mutability::Immutable {
887887
return Err(ctx.diagnostics.report(stable_ptr, RefClosureArgument));
888888
}
@@ -930,7 +930,7 @@ fn compute_expr_function_call_semantic(
930930
let named_args: Vec<_> = args_syntax
931931
.elements(syntax_db)
932932
.into_iter()
933-
.map(|arg_syntax| compute_named_argument_clause(ctx, arg_syntax))
933+
.map(|arg_syntax| compute_named_argument_clause(ctx, arg_syntax, None))
934934
.collect();
935935
if named_args.len() != 1 {
936936
return Err(ctx.diagnostics.report(syntax, WrongNumberOfArguments {
@@ -979,16 +979,21 @@ fn compute_expr_function_call_semantic(
979979
let mut args_iter = args_syntax.elements(syntax_db).into_iter();
980980
// Normal parameters
981981
let mut named_args = vec![];
982-
for _ in function_parameter_types(ctx, function)? {
982+
let closure_params = concrete_function_closure_params(db, function)?;
983+
for ty in function_parameter_types(ctx, function)? {
983984
let Some(arg_syntax) = args_iter.next() else {
984985
continue;
985986
};
986-
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
987+
named_args.push(compute_named_argument_clause(
988+
ctx,
989+
arg_syntax,
990+
closure_params.get(&ty).cloned(),
991+
));
987992
}
988993

989994
// Maybe coupon
990995
if let Some(arg_syntax) = args_iter.next() {
991-
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
996+
named_args.push(compute_named_argument_clause(ctx, arg_syntax, None));
992997
}
993998

994999
expr_function_call(ctx, function, named_args, syntax, syntax.stable_ptr().into())
@@ -1006,6 +1011,7 @@ fn compute_expr_function_call_semantic(
10061011
pub fn compute_named_argument_clause(
10071012
ctx: &mut ComputationContext<'_>,
10081013
arg_syntax: ast::Arg,
1014+
closure_param_types: Option<TypeId>,
10091015
) -> NamedArg {
10101016
let syntax_db = ctx.db.upcast();
10111017

@@ -1018,12 +1024,24 @@ pub fn compute_named_argument_clause(
10181024
let arg_clause = arg_syntax.arg_clause(syntax_db);
10191025
let (expr, arg_name_identifier) = match arg_clause {
10201026
ast::ArgClause::Unnamed(arg_unnamed) => {
1021-
(compute_expr_semantic(ctx, &arg_unnamed.value(syntax_db)), None)
1027+
let arg_expr = arg_unnamed.value(syntax_db);
1028+
if let ast::Expr::Closure(expr_closure) = arg_expr {
1029+
(handle_closure_expr(ctx, &expr_closure, closure_param_types), None)
1030+
} else {
1031+
(compute_expr_semantic(ctx, &arg_unnamed.value(syntax_db)), None)
1032+
}
1033+
}
1034+
ast::ArgClause::Named(arg_named) => {
1035+
let arg_expr = arg_named.value(syntax_db);
1036+
if let ast::Expr::Closure(expr_closure) = arg_expr {
1037+
(handle_closure_expr(ctx, &expr_closure, closure_param_types), None)
1038+
} else {
1039+
(
1040+
compute_expr_semantic(ctx, &arg_named.value(syntax_db)),
1041+
Some(arg_named.name(syntax_db)),
1042+
)
1043+
}
10221044
}
1023-
ast::ArgClause::Named(arg_named) => (
1024-
compute_expr_semantic(ctx, &arg_named.value(syntax_db)),
1025-
Some(arg_named.name(syntax_db)),
1026-
),
10271045
ast::ArgClause::FieldInitShorthand(arg_field_init_shorthand) => {
10281046
let name_expr = arg_field_init_shorthand.name(syntax_db);
10291047
let stable_ptr: ast::ExprPtr = name_expr.stable_ptr().into();
@@ -1038,6 +1056,17 @@ pub fn compute_named_argument_clause(
10381056
NamedArg(expr, arg_name_identifier, mutability)
10391057
}
10401058

1059+
fn handle_closure_expr(
1060+
ctx: &mut ComputationContext<'_>,
1061+
expr_closure: &ast::ExprClosure,
1062+
closure_param_types: Option<TypeId>,
1063+
) -> ExprAndId {
1064+
let expr = compute_expr_closure_semantic(ctx, expr_closure, closure_param_types);
1065+
let expr = wrap_maybe_with_missing(ctx, expr, ast::ExprPtr::from(expr_closure.stable_ptr()));
1066+
let id = ctx.arenas.exprs.alloc(expr.clone());
1067+
ExprAndId { expr, id }
1068+
}
1069+
10411070
pub fn compute_root_expr(
10421071
ctx: &mut ComputationContext<'_>,
10431072
syntax: &ast::ExprBlock,
@@ -1645,6 +1674,7 @@ fn compute_loop_body_semantic(
16451674
fn compute_expr_closure_semantic(
16461675
ctx: &mut ComputationContext<'_>,
16471676
syntax: &ast::ExprClosure,
1677+
param_types: Option<TypeId>,
16481678
) -> Maybe<Expr> {
16491679
ctx.are_closures_in_context = true;
16501680
let syntax_db = ctx.db.upcast();
@@ -1663,6 +1693,14 @@ fn compute_expr_closure_semantic(
16631693
} else {
16641694
vec![]
16651695
};
1696+
let closure_type =
1697+
TypeLongId::Tuple(params.iter().map(|param| param.ty).collect()).intern(new_ctx.db);
1698+
if let Some(param_types) = param_types {
1699+
if let Err(err_set) = new_ctx.resolver.inference().conform_ty(closure_type, param_types)
1700+
{
1701+
new_ctx.resolver.inference().consume_error_without_reporting(err_set);
1702+
}
1703+
}
16661704

16671705
params.iter().filter(|param| param.mutability == Mutability::Reference).for_each(|param| {
16681706
new_ctx.diagnostics.report(param.stable_ptr(ctx.db.upcast()), RefClosureParam);
@@ -2834,16 +2872,22 @@ fn method_call_expr(
28342872
// Self argument.
28352873
let mut named_args = vec![NamedArg(fixed_lexpr, None, mutability)];
28362874
// Other arguments.
2837-
for _ in function_parameter_types(ctx, function_id)?.skip(1) {
2875+
let closure_params: OrderedHashMap<TypeId, TypeId> =
2876+
concrete_function_closure_params(ctx.db, function_id)?;
2877+
for ty in function_parameter_types(ctx, function_id)?.skip(1) {
28382878
let Some(arg_syntax) = args_iter.next() else {
28392879
break;
28402880
};
2841-
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
2881+
named_args.push(compute_named_argument_clause(
2882+
ctx,
2883+
arg_syntax,
2884+
closure_params.get(&ty).cloned(),
2885+
));
28422886
}
28432887

28442888
// Maybe coupon
28452889
if let Some(arg_syntax) = args_iter.next() {
2846-
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
2890+
named_args.push(compute_named_argument_clause(ctx, arg_syntax, None));
28472891
}
28482892

28492893
expr_function_call(ctx, function_id, named_args, &expr, stable_ptr)

crates/cairo-lang-semantic/src/expr/test_data/closure

+20
Original file line numberDiff line numberDiff line change
@@ -768,3 +768,23 @@ error: Closure parameters cannot be references
768768
--> lib.cairo:2:14
769769
let _ = |ref a| {
770770
^^^^^
771+
772+
//! > ==========================================================================
773+
774+
//! > Passing closures as args with less explicit typing.
775+
776+
//! > test_runner_name
777+
test_function_diagnostics(expect_diagnostics: false)
778+
779+
//! > function
780+
fn foo() -> Option<u32> {
781+
let x: Option<Array<i32>> = Option::Some(array![1, 2, 3]);
782+
x.map(|x| x.len())
783+
}
784+
785+
//! > function_name
786+
foo
787+
788+
//! > module_code
789+
790+
//! > expected_diagnostics

crates/cairo-lang-semantic/src/items/functions.rs

+57-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use std::fmt::Debug;
21
use std::sync::Arc;
32

43
use cairo_lang_debug::DebugWithDb;
@@ -14,6 +13,7 @@ use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
1413
use cairo_lang_syntax as syntax;
1514
use cairo_lang_syntax::attribute::structured::Attribute;
1615
use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast};
16+
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
1717
use cairo_lang_utils::{
1818
Intern, LookupIntern, OptionFrom, define_short_id, require, try_extract_matches,
1919
};
@@ -27,16 +27,16 @@ use super::generics::{fmt_generic_args, generic_params_to_args};
2727
use super::imp::{ImplId, ImplLongId};
2828
use super::modifiers;
2929
use super::trt::ConcreteTraitGenericFunctionId;
30-
use crate::corelib::{panic_destruct_trait_fn, unit_ty};
30+
use crate::corelib::{fn_traits, panic_destruct_trait_fn, unit_ty};
3131
use crate::db::SemanticGroup;
3232
use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
3333
use crate::expr::compute::Environment;
3434
use crate::resolve::{Resolver, ResolverData};
3535
use crate::substitution::{GenericSubstitution, SemanticRewriter, SubstitutionRewriter};
3636
use crate::types::resolve_type;
3737
use crate::{
38-
ConcreteImplId, ConcreteImplLongId, ConcreteTraitLongId, GenericParam, SemanticDiagnostic,
39-
TypeId, semantic, semantic_object_for_id,
38+
ConcreteImplId, ConcreteImplLongId, ConcreteTraitLongId, GenericArgumentId, GenericParam,
39+
SemanticDiagnostic, TypeId, semantic, semantic_object_for_id,
4040
};
4141

4242
/// A generic function of an impl.
@@ -146,8 +146,11 @@ impl GenericFunctionId {
146146
GenericFunctionId::Extern(id) => db.extern_function_declaration_generic_params(id),
147147
GenericFunctionId::Impl(id) => {
148148
let concrete_trait_id = db.impl_concrete_trait(id.impl_id)?;
149-
let id = ConcreteTraitGenericFunctionId::new(db, concrete_trait_id, id.function);
150-
db.concrete_trait_function_generic_params(id)
149+
let concrete_id =
150+
ConcreteTraitGenericFunctionId::new(db, concrete_trait_id, id.function);
151+
let substitution = GenericSubstitution::from_impl(id.impl_id);
152+
let mut rewriter = SubstitutionRewriter { db, substitution: &substitution };
153+
rewriter.rewrite(db.concrete_trait_function_generic_params(concrete_id)?)
151154
}
152155
GenericFunctionId::Trait(id) => db.concrete_trait_function_generic_params(id),
153156
}
@@ -860,6 +863,19 @@ pub fn concrete_function_signature(
860863
SubstitutionRewriter { db, substitution: &substitution }.rewrite(generic_signature)
861864
}
862865

866+
/// Query implementation of [crate::db::SemanticGroup::concrete_function_closure_params].
867+
pub fn concrete_function_closure_params(
868+
db: &dyn SemanticGroup,
869+
function_id: FunctionId,
870+
) -> Maybe<OrderedHashMap<semantic::TypeId, semantic::TypeId>> {
871+
let ConcreteFunction { generic_function, generic_args, .. } =
872+
function_id.lookup_intern(db).function;
873+
let generic_params = generic_function.generic_params(db)?;
874+
let generic_closure_params = db.get_closure_params(generic_function)?;
875+
let substitution = GenericSubstitution::new(&generic_params, &generic_args);
876+
SubstitutionRewriter { db, substitution: &substitution }.rewrite(generic_closure_params)
877+
}
878+
863879
/// For a given list of AST parameters, returns the list of semantic parameters along with the
864880
/// corresponding environment.
865881
fn update_env_with_ast_params(
@@ -1010,3 +1026,38 @@ impl FromIterator<TypeId> for ImplicitPrecedence {
10101026
Self(Vec::from_iter(iter))
10111027
}
10121028
}
1029+
1030+
/// This function retrieves a mapping of closure types to their associated parameter types.
1031+
/// It analyzes the generic parameters of the current context
1032+
/// to identify any closures and their respective parameter types. It checks
1033+
/// for `Fn`, `FnMut`, or `FnOnce` traits among the generic parameters and
1034+
/// returns a `HashMap` where the key is the closure type, and the value is a
1035+
/// vector of parameter types.
1036+
pub fn get_closure_params(
1037+
db: &dyn SemanticGroup,
1038+
generic_function_id: GenericFunctionId,
1039+
) -> Maybe<OrderedHashMap<TypeId, TypeId>> {
1040+
let mut closure_params_map = OrderedHashMap::default();
1041+
let generic_params = generic_function_id.generic_params(db)?;
1042+
1043+
for param in generic_params {
1044+
if let GenericParam::Impl(generic_param_impl) = param {
1045+
let trait_id = generic_param_impl.concrete_trait?.trait_id(db);
1046+
1047+
if fn_traits(db).contains(&trait_id) {
1048+
if let Ok(concrete_trait) = generic_param_impl.concrete_trait {
1049+
let [
1050+
GenericArgumentId::Type(closure_type),
1051+
GenericArgumentId::Type(params_type),
1052+
] = *concrete_trait.generic_args(db)
1053+
else {
1054+
unreachable!()
1055+
};
1056+
1057+
closure_params_map.insert(closure_type, params_type);
1058+
}
1059+
}
1060+
}
1061+
}
1062+
Ok(closure_params_map)
1063+
}

0 commit comments

Comments
 (0)