Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 75 additions & 38 deletions crates/cairo-lang-semantic/src/expr/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,21 @@ pub type InferenceResult<T> = Result<T, ErrorSet>;

#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
enum InferenceErrorStatus<'db> {
Pending(InferenceError<'db>),
/// There is a pending error.
Pending(PendingInferenceError<'db>),
/// There was an error but it was already consumed.
Consumed(DiagnosticAdded),
}

/// A pending inference error.
#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)]
struct PendingInferenceError<'db> {
/// The actual error.
err: InferenceError<'db>,
/// The optional location of the error.
stable_ptr: Option<SyntaxStablePtrId<'db>>,
}

/// A mapping of an impl var's trait items to concrete items
#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, SemanticObject, salsa::Update)]
pub struct ImplVarTraitItemMappings<'db> {
Expand Down Expand Up @@ -737,23 +748,18 @@ impl<'db, 'id> Inference<'db, 'id> {
/// Returns whether the inference was successful. If not, the error may be found by
/// `.error_state()`.
pub fn solve(&mut self) -> InferenceResult<()> {
self.solve_ex().map_err(|(err_set, _)| err_set)
}

/// Same as `solve`, but returns the error stable pointer if an error occurred.
fn solve_ex(&mut self) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId<'db>>)> {
let ambiguous = std::mem::take(&mut self.ambiguous).into_iter();
self.pending.extend(ambiguous.map(|(var, _)| var));
while let Some(var) = self.pending.pop_front() {
// First inference error stops inference.
self.solve_single_pending(var).map_err(|err_set| {
(err_set, self.stable_ptrs.get(&InferenceVar::Impl(var)).copied())
self.solve_single_pending(var).inspect_err(|_err_set| {
self.add_error_stable_ptr(InferenceVar::Impl(var));
})?;
}
while let Some(var) = self.negative_pending.pop_front() {
// First inference error stops inference.
self.solve_single_negative_pending(var).map_err(|err_set| {
(err_set, self.stable_ptrs.get(&InferenceVar::NegativeImpl(var)).copied())
self.solve_single_negative_pending(var).inspect_err(|_err_set| {
self.add_error_stable_ptr(InferenceVar::NegativeImpl(var));
})?;
}
Ok(())
Expand Down Expand Up @@ -839,12 +845,9 @@ impl<'db, 'id> Inference<'db, 'id> {

/// Finalizes the inference by inferring uninferred numeric literals as felt252.
/// Returns an error and does not report it.
pub fn finalize_without_reporting(
&mut self,
) -> Result<(), (ErrorSet, Option<SyntaxStablePtrId<'db>>)> {
pub fn finalize_without_reporting(&mut self) -> Result<(), ErrorSet> {
if self.error_status.is_err() {
// TODO(yuval): consider adding error location to the set error.
return Err((ErrorSet, None));
return Err(ErrorSet);
}
let info = self.db.core_info();
let numeric_trait_id = info.numeric_literal_trt;
Expand All @@ -853,7 +856,7 @@ impl<'db, 'id> Inference<'db, 'id> {
// Conform all uninferred numeric literals to felt252.
loop {
let mut changed = false;
self.solve_ex()?;
self.solve()?;
for (var, _) in self.ambiguous.clone() {
let impl_var = self.impl_var(var).clone();
if impl_var.concrete_trait_id.trait_id(self.db) != numeric_trait_id {
Expand All @@ -867,8 +870,8 @@ impl<'db, 'id> Inference<'db, 'id> {
if self.rewrite(ty).no_err() == felt_ty {
continue;
}
self.conform_ty(ty, felt_ty).map_err(|err_set| {
(err_set, self.stable_ptrs.get(&InferenceVar::Impl(impl_var.id)).copied())
self.conform_ty(ty, felt_ty).inspect_err(|_err_set| {
self.add_error_stable_ptr(InferenceVar::Impl(impl_var.id));
})?;
changed = true;
break;
Expand All @@ -885,7 +888,7 @@ impl<'db, 'id> Inference<'db, 'id> {
let Some((var, err)) = self.first_undetermined_variable() else {
return Ok(());
};
Err((self.set_error(err), self.stable_ptrs.get(&var).copied()))
Err(self.set_error_on_var(err, var))
}

/// Finalizes the inference and report diagnostics if there are any errors.
Expand All @@ -896,12 +899,8 @@ impl<'db, 'id> Inference<'db, 'id> {
diagnostics: &mut SemanticDiagnostics<'db>,
stable_ptr: SyntaxStablePtrId<'db>,
) {
if let Err((err_set, err_stable_ptr)) = self.finalize_without_reporting() {
let diag = self.report_on_pending_error(
err_set,
diagnostics,
err_stable_ptr.unwrap_or(stable_ptr),
);
if let Err(err_set) = self.finalize_without_reporting() {
let diag = self.report_on_pending_error(err_set, diagnostics, stable_ptr);

let ty_missing = TypeId::missing(self.db, diag);
for var in &self.data.type_vars {
Expand Down Expand Up @@ -985,7 +984,8 @@ impl<'db, 'id> Inference<'db, 'id> {
}
if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var))
{
return Err(self.set_error(InferenceError::Cycle(InferenceVar::Impl(var))));
let inference_var = InferenceVar::Impl(var);
return Err(self.set_error_on_var(InferenceError::Cycle(inference_var), inference_var));
}
self.impl_assignment.insert(var, impl_id);
if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) {
Expand All @@ -999,9 +999,11 @@ impl<'db, 'id> Inference<'db, 'id> {
let ty0 = self.rewrite(ty).no_err();
let ty1 = self.rewrite(impl_ty).no_err();

let error =
InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 };
self.error_status = Err(InferenceErrorStatus::Pending(error));
let err = InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 };
self.error_status = Err(InferenceErrorStatus::Pending(PendingInferenceError {
err,
stable_ptr: self.stable_ptrs.get(&InferenceVar::Impl(var)).cloned(),
}));
return Err(err_set);
}
}
Expand Down Expand Up @@ -1094,7 +1096,7 @@ impl<'db, 'id> Inference<'db, 'id> {
assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable.");
let inference_var = InferenceVar::Type(var.id);
if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) {
return Err(self.set_error(InferenceError::Cycle(inference_var)));
return Err(self.set_error_on_var(InferenceError::Cycle(inference_var), inference_var));
}
// If assigning var to var - making sure assigning to the lower id for proper canonization.
if let TypeLongId::Var(other) = ty.long(self.db)
Expand Down Expand Up @@ -1322,22 +1324,51 @@ impl<'db, 'id> Inference<'db, 'id> {
/// Does nothing if an error is already set.
/// Returns an `ErrorSet` that can be used in reporting the error.
pub fn set_error(&mut self, err: InferenceError<'db>) -> ErrorSet {
self.set_error_ex(err, None)
}

/// Sets an error in the inference state, with an optional location for the diagnostics
/// reporting. Does nothing if an error is already set.
/// Returns an `ErrorSet` that can be used in reporting the error.
pub fn set_error_ex(
&mut self,
err: InferenceError<'db>,
stable_ptr: Option<SyntaxStablePtrId<'db>>,
) -> ErrorSet {
if self.error_status.is_err() {
return ErrorSet;
}
self.error_status = Err(if let InferenceError::Reported(diag_added) = err {
InferenceErrorStatus::Consumed(diag_added)
} else {
InferenceErrorStatus::Pending(err)
InferenceErrorStatus::Pending(PendingInferenceError { err, stable_ptr })
});
ErrorSet
}

/// Sets an error in the inference state, with a var to fetch location for the diagnostics
/// reporting. Does nothing if an error is already set.
/// Returns an `ErrorSet` that can be used in reporting the error.
pub fn set_error_on_var(&mut self, err: InferenceError<'db>, var: InferenceVar) -> ErrorSet {
self.set_error_ex(err, self.stable_ptrs.get(&var).cloned())
}

/// Returns whether an error is set (either pending or consumed).
pub fn is_error_set(&self) -> InferenceResult<()> {
self.error_status.as_ref().copied().map_err(|_| ErrorSet)
}

/// If there is no stable ptr for the pending error, add it by the given var.
fn add_error_stable_ptr(&mut self, var: InferenceVar) {
let var_stable_ptr = self.stable_ptrs.get(&var).copied();
if let Err(InferenceErrorStatus::Pending(PendingInferenceError { err: _, stable_ptr })) =
&mut self.error_status
&& stable_ptr.is_none()
{
*stable_ptr = var_stable_ptr;
}
}

/// Consumes the error but doesn't report it. If there is no error, or the error is consumed,
/// returns None. This should be used with caution. Always prefer to use
/// (1) `report_on_pending_error` if possible, or (2) `consume_reported_error` which is safer.
Expand All @@ -1347,7 +1378,7 @@ impl<'db, 'id> Inference<'db, 'id> {
&mut self,
err_set: ErrorSet,
) -> Option<InferenceError<'db>> {
self.consume_error_inner(err_set, skip_diagnostic())
Some(self.consume_error_inner(err_set, skip_diagnostic())?.err)
}

/// Consumes the error that is already reported. If there is no error, or the error is consumed,
Expand All @@ -1370,10 +1401,16 @@ impl<'db, 'id> Inference<'db, 'id> {
&mut self,
_err_set: ErrorSet,
diag_added: DiagnosticAdded,
) -> Option<InferenceError<'db>> {
) -> Option<PendingInferenceError<'db>> {
match &mut self.error_status {
Err(InferenceErrorStatus::Pending(error)) => {
let pending_error = std::mem::replace(error, InferenceError::Reported(diag_added));
let pending_error = std::mem::replace(
error,
PendingInferenceError {
err: InferenceError::Reported(diag_added),
stable_ptr: None,
},
);
self.error_status = Err(InferenceErrorStatus::Consumed(diag_added));
Some(pending_error)
}
Expand All @@ -1398,16 +1435,16 @@ impl<'db, 'id> Inference<'db, 'id> {
};
match state_error {
InferenceErrorStatus::Consumed(diag_added) => *diag_added,
InferenceErrorStatus::Pending(error) => {
let diag_added = match error {
InferenceErrorStatus::Pending(pending) => {
let diag_added = match &pending.err {
InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => {
// If we have other diagnostics, there is no need to TypeNotInferred.

// Note that `diagnostics` is not empty, so it is safe to return
// 'DiagnosticAdded' here.
skip_diagnostic()
}
diag => diag.report(diagnostics, stable_ptr),
diag => diag.report(diagnostics, pending.stable_ptr.unwrap_or(stable_ptr)),
};
self.error_status = Err(InferenceErrorStatus::Consumed(diag_added));
diag_added
Expand All @@ -1422,7 +1459,7 @@ impl<'db, 'id> Inference<'db, 'id> {
err_set: ErrorSet,
report: impl FnOnce() -> DiagnosticAdded,
) {
if matches!(self.error_status, Err(InferenceErrorStatus::Pending(_))) {
if matches!(self.error_status, Err(InferenceErrorStatus::Pending { .. })) {
self.consume_reported_error(err_set, report());
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/cairo-lang-semantic/src/items/imp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3052,7 +3052,7 @@ fn implicit_impl_impl_semantic_data<'db>(
let impl_lookup_context = resolver.impl_lookup_context();
let resolved_impl = concrete_trait_impl_concrete_trait.and_then(|concrete_trait_id| {
let imp = resolver.inference().new_impl_var(concrete_trait_id, None, impl_lookup_context);
resolver.inference().finalize_without_reporting().map_err(|(err_set, _)| {
resolver.inference().finalize_without_reporting().map_err(|err_set| {
diagnostics.report(
impl_def_id.stable_ptr(db).untyped(),
ImplicitImplNotInferred { trait_impl_id, concrete_trait_id },
Expand Down
6 changes: 3 additions & 3 deletions crates/cairo-lang-semantic/src/items/tests/trait_type
Original file line number Diff line number Diff line change
Expand Up @@ -3185,6 +3185,6 @@ struct S<impl M: MyTrait> {

//! > expected_diagnostics
error: `test::M::InputType` type mismatch: `core::felt252` and `core::integer::u32`.
--> lib.cairo:12:13
fn foo() -> S<M> {
^^^^
--> lib.cairo:13:5
S { x: 3_felt252 }
^
2 changes: 1 addition & 1 deletion crates/cairo-lang-semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ pub fn get_impl_at_context<'db>(
// It's ok to consume the errors without reporting as this is a helper function meant to find an
// impl and return it, but it's ok if the impl can't be found.
let impl_id = inference.new_impl_var(concrete_trait_id, stable_ptr, lookup_context);
if let Err((err_set, _)) = inference.finalize_without_reporting() {
if let Err(err_set) = inference.finalize_without_reporting() {
return Err(inference
.consume_error_without_reporting(err_set)
.expect("Error couldn't be already consumed"));
Expand Down