Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow associated types to be ellided from trait constraints #7026

Merged
merged 14 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 124 additions & 23 deletions compiler/noirc_frontend/src/elaborator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ impl<'context> Elaborator<'context> {

self.define_function_metas(&mut items.functions, &mut items.impls, &mut items.trait_impls);

self.collect_traits(&items.traits);
self.collect_traits(&mut items.traits);

// Before we resolve any function symbols we must go through our impls and
// re-collect the methods within into their proper module. This cannot be
Expand Down Expand Up @@ -354,6 +354,7 @@ impl<'context> Elaborator<'context> {
self.current_trait = Some(trait_id);
self.elaborate_functions(unresolved_trait.fns_with_default_impl);
}

self.current_trait = None;

for impls in items.impls.into_values() {
Expand Down Expand Up @@ -475,7 +476,7 @@ impl<'context> Elaborator<'context> {
self.add_existing_variable_to_scope(name, parameter.clone(), warn_if_unused);
}

self.add_trait_constraints_to_scope(&func_meta);
self.add_trait_constraints_to_scope(&func_meta.trait_constraints, func_meta.location.span);

let (hir_func, body_type) = match kind {
FunctionKind::Builtin
Expand All @@ -501,7 +502,7 @@ impl<'context> Elaborator<'context> {
// when multiple impls are available. Instead we default first to choose the Field or u64 impl.
self.check_and_pop_function_context();

self.remove_trait_constraints_from_scope(&func_meta);
self.remove_trait_constraints_from_scope(&func_meta.trait_constraints);

let func_scope_tree = self.scopes.end_function();

Expand Down Expand Up @@ -733,8 +734,13 @@ impl<'context> Elaborator<'context> {
None
}

/// TODO: This is currently only respected for generic free functions
/// there's a bunch of other places where trait constraints can pop up
/// Resolve the given trait constraints and add them to scope as we go.
/// This second step is necessary to resolve subsequent constraints such
/// as `<T as Foo>::Bar: Eq` which may lookup an impl which was assumed
/// by a previous constraint.
///
/// If these constraints are unwanted afterward they should be manually
/// removed from the interner.
fn resolve_trait_constraints(
&mut self,
where_clause: &[UnresolvedTraitConstraint],
Expand All @@ -745,12 +751,92 @@ impl<'context> Elaborator<'context> {
.collect()
}

pub fn resolve_trait_constraint(
/// Expands any traits in a where clause to mention all associated types if they were
/// elided by the user. See `add_missing_named_generics` for more detail.
///
/// Returns all newly created generics to be added to this function/trait/impl.
fn desugar_trait_constraints(
&mut self,
where_clause: &mut [UnresolvedTraitConstraint],
) -> Vec<ResolvedGeneric> {
where_clause
.iter_mut()
.flat_map(|constraint| self.add_missing_named_generics(&mut constraint.trait_bound))
.collect()
}

/// For each associated type that isn't mentioned in a trait bound, this adds
/// the type as an implicit generic to the where clause and returns the newly
/// created generics in a vector to add to the function/trait/impl later.
/// For example, this will turn a function using a trait with 2 associated types:
///
/// `fn foo<T>() where T: Foo { ... }`
///
/// into:
/// `fn foo<T>() where T: Foo<Bar = A, Baz = B> { ... }`
///
/// with a vector of `<A, B>` returned so that the caller can then modify the function to:
/// `fn foo<T, A, B>() where T: Foo<Bar = A, Baz = B> { ... }`
fn add_missing_named_generics(&mut self, bound: &mut TraitBound) -> Vec<ResolvedGeneric> {
let mut added_generics = Vec::new();

let Ok(item) = self.resolve_path_or_error(bound.trait_path.clone()) else {
return Vec::new();
};

let PathResolutionItem::Trait(trait_id) = item else {
return Vec::new();
};

let the_trait = self.get_trait_mut(trait_id);

if the_trait.associated_types.len() > bound.trait_generics.named_args.len() {
for associated_type in &the_trait.associated_types.clone() {
if !bound
.trait_generics
.named_args
.iter()
.any(|(name, _)| name.0.contents == *associated_type.name.as_ref())
{
// This generic isn't contained in the bound's named arguments,
// so add it by creating a fresh type variable.
let new_generic_id = self.interner.next_type_variable_id();
let type_var = TypeVariable::unbound(new_generic_id, Kind::Normal);
jfecher marked this conversation as resolved.
Show resolved Hide resolved

let span = bound.trait_path.span;
let name = associated_type.name.clone();
let typ = Type::NamedGeneric(type_var.clone(), name.clone());
let typ = self.interner.push_quoted_type(typ);
let typ = UnresolvedTypeData::Resolved(typ).with_span(span);
let ident = Ident::new(associated_type.name.as_ref().clone(), span);

bound.trait_generics.named_args.push((ident, typ));
added_generics.push(ResolvedGeneric { name, span, type_var });
}
}
}

added_generics
}

/// Resolves a trait constraint and adds it to scope as an assumed impl.
/// This second step is necessary to resolve subsequent constraints such
/// as `<T as Foo>::Bar: Eq` which may lookup an impl which was assumed
/// by a previous constraint.
fn resolve_trait_constraint(
&mut self,
constraint: &UnresolvedTraitConstraint,
) -> Option<TraitConstraint> {
let typ = self.resolve_type(constraint.typ.clone());
let trait_bound = self.resolve_trait_bound(&constraint.trait_bound)?;

self.add_trait_bound_to_scope(
constraint.trait_bound.trait_path.span,
&typ,
&trait_bound,
trait_bound.trait_id,
);

Some(TraitConstraint { typ, trait_bound })
}

Expand Down Expand Up @@ -800,10 +886,13 @@ impl<'context> Elaborator<'context> {
let has_inline_attribute = has_no_predicates_attribute || should_fold;
let is_pub_allowed = self.pub_allowed(func, in_contract);
self.add_generics(&func.def.generics);
let mut generics = vecmap(&self.generics, |generic| generic.type_var.clone());

let new_generics = self.desugar_trait_constraints(&mut func.def.where_clause);
generics.extend(new_generics.into_iter().map(|generic| generic.type_var));

let mut trait_constraints = self.resolve_trait_constraints(&func.def.where_clause);

let mut generics = vecmap(&self.generics, |generic| generic.type_var.clone());
let mut parameters = Vec::new();
let mut parameter_types = Vec::new();
let mut parameter_idents = Vec::new();
Expand Down Expand Up @@ -874,6 +963,9 @@ impl<'context> Elaborator<'context> {
None
};

// Remove the traits assumed by `resolve_trait_constraints` from scope
self.remove_trait_constraints_from_scope(&trait_constraints);

let meta = FuncMeta {
name: name_ident,
kind: func.kind,
Expand Down Expand Up @@ -1013,10 +1105,10 @@ impl<'context> Elaborator<'context> {
}
}

fn add_trait_constraints_to_scope(&mut self, func_meta: &FuncMeta) {
for constraint in &func_meta.trait_constraints {
fn add_trait_constraints_to_scope(&mut self, constraints: &[TraitConstraint], span: Span) {
for constraint in constraints {
self.add_trait_bound_to_scope(
func_meta,
span,
&constraint.typ,
&constraint.trait_bound,
constraint.trait_bound.trait_id,
Expand All @@ -1030,16 +1122,16 @@ impl<'context> Elaborator<'context> {
let self_type =
self.self_type.clone().expect("Expected a self type if there's a current trait");
self.add_trait_bound_to_scope(
func_meta,
span,
&self_type,
&constraint.trait_bound,
constraint.trait_bound.trait_id,
);
}
}

fn remove_trait_constraints_from_scope(&mut self, func_meta: &FuncMeta) {
for constraint in &func_meta.trait_constraints {
fn remove_trait_constraints_from_scope(&mut self, constraints: &[TraitConstraint]) {
for constraint in constraints {
self.interner
.remove_assumed_trait_implementations_for_trait(constraint.trait_bound.trait_id);
}
Expand All @@ -1052,7 +1144,7 @@ impl<'context> Elaborator<'context> {

fn add_trait_bound_to_scope(
&mut self,
func_meta: &FuncMeta,
span: Span,
object: &Type,
trait_bound: &ResolvedTraitBound,
starting_trait_id: TraitId,
Expand All @@ -1064,7 +1156,6 @@ impl<'context> Elaborator<'context> {
if let Some(the_trait) = self.interner.try_get_trait(trait_id) {
let trait_name = the_trait.name.to_string();
let typ = object.clone();
let span = func_meta.location.span;
self.push_err(TypeCheckError::UnneededTraitConstraint { trait_name, typ, span });
}
}
Expand All @@ -1081,12 +1172,7 @@ impl<'context> Elaborator<'context> {

let parent_trait_bound =
self.instantiate_parent_trait_bound(trait_bound, &parent_trait_bound);
self.add_trait_bound_to_scope(
func_meta,
object,
&parent_trait_bound,
starting_trait_id,
);
self.add_trait_bound_to_scope(span, object, &parent_trait_bound, starting_trait_id);
}
}
}
Expand Down Expand Up @@ -1316,6 +1402,7 @@ impl<'context> Elaborator<'context> {
self.generics = trait_impl.resolved_generics.clone();

let where_clause = self.resolve_trait_constraints(&trait_impl.where_clause);
self.remove_trait_constraints_from_scope(&where_clause);

self.collect_trait_impl_methods(trait_id, trait_impl, &where_clause);

Expand Down Expand Up @@ -1811,6 +1898,17 @@ impl<'context> Elaborator<'context> {
self.add_generics(&trait_impl.generics);
trait_impl.resolved_generics = self.generics.clone();

let new_generics = self.desugar_trait_constraints(&mut trait_impl.where_clause);
for new_generic in new_generics {
trait_impl.resolved_generics.push(new_generic.clone());
self.generics.push(new_generic);
}

// We need to resolve the where clause before any associated types to be
// able to resolve trait as type syntax, eg. `<T as Foo>` in case there
// is a where constraint for `T: Foo`.
let constraints = self.resolve_trait_constraints(&trait_impl.where_clause);

for (_, _, method) in trait_impl.methods.functions.iter_mut() {
// Attach any trait constraints on the impl to the function
method.def.where_clause.append(&mut trait_impl.where_clause.clone());
Expand All @@ -1823,17 +1921,20 @@ impl<'context> Elaborator<'context> {
let impl_id = self.interner.next_trait_impl_id();
self.current_trait_impl = Some(impl_id);

// Fetch trait constraints here
let path_span = trait_impl.trait_path.span;
let (ordered_generics, named_generics) = trait_impl
.trait_id
.map(|trait_id| {
self.resolve_type_args(trait_generics, trait_id, trait_impl.trait_path.span)
// Check for missing generics & associated types for the trait being implemented
self.resolve_trait_args_from_trait_impl(trait_generics, trait_id, path_span)
})
.unwrap_or_default();

trait_impl.resolved_trait_generics = ordered_generics;
self.interner.set_associated_types_for_impl(impl_id, named_generics);

self.remove_trait_constraints_from_scope(&constraints);

let self_type = self.resolve_type(unresolved_type);
self.self_type = Some(self_type.clone());
trait_impl.methods.self_type = Some(self_type);
Expand Down
7 changes: 6 additions & 1 deletion compiler/noirc_frontend/src/elaborator/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use crate::{
use super::Elaborator;

impl<'context> Elaborator<'context> {
pub fn collect_traits(&mut self, traits: &BTreeMap<TraitId, UnresolvedTrait>) {
pub fn collect_traits(&mut self, traits: &mut BTreeMap<TraitId, UnresolvedTrait>) {
for (trait_id, unresolved_trait) in traits {
self.local_module = unresolved_trait.module_id;

Expand All @@ -39,8 +39,13 @@ impl<'context> Elaborator<'context> {
&resolved_generics,
);

let new_generics =
this.desugar_trait_constraints(&mut unresolved_trait.trait_def.where_clause);
this.generics.extend(new_generics);

let where_clause =
this.resolve_trait_constraints(&unresolved_trait.trait_def.where_clause);
this.remove_trait_constraints_from_scope(&where_clause);

// Each associated type in this trait is also an implicit generic
for associated_type in &this.interner.get_trait(*trait_id).associated_types {
Expand Down
Loading
Loading