diff --git a/crates/cairo-lang-semantic/src/expr/inference.rs b/crates/cairo-lang-semantic/src/expr/inference.rs index ae28c20d5fe..8ade7266244 100644 --- a/crates/cairo-lang-semantic/src/expr/inference.rs +++ b/crates/cairo-lang-semantic/src/expr/inference.rs @@ -336,10 +336,21 @@ pub type InferenceResult = Result; #[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)] enum InferenceErrorStatus<'db> { - Pending(InferenceError<'db>), + /// There is a pending error. + Pending(PendingInferenceError<'db>), + /// There was an error but it was already consumed. Consumed(DiagnosticAdded), } +/// A pending inference error. +#[derive(Clone, Debug, Eq, Hash, PartialEq, salsa::Update)] +struct PendingInferenceError<'db> { + /// The actual error. + err: InferenceError<'db>, + /// The optional location of the error. + stable_ptr: Option>, +} + /// A mapping of an impl var's trait items to concrete items #[derive(Debug, Default, PartialEq, Eq, Hash, Clone, SemanticObject, salsa::Update)] pub struct ImplVarTraitItemMappings<'db> { @@ -737,23 +748,18 @@ impl<'db, 'id> Inference<'db, 'id> { /// Returns whether the inference was successful. If not, the error may be found by /// `.error_state()`. pub fn solve(&mut self) -> InferenceResult<()> { - self.solve_ex().map_err(|(err_set, _)| err_set) - } - - /// Same as `solve`, but returns the error stable pointer if an error occurred. - fn solve_ex(&mut self) -> Result<(), (ErrorSet, Option>)> { let ambiguous = std::mem::take(&mut self.ambiguous).into_iter(); self.pending.extend(ambiguous.map(|(var, _)| var)); while let Some(var) = self.pending.pop_front() { // First inference error stops inference. - self.solve_single_pending(var).map_err(|err_set| { - (err_set, self.stable_ptrs.get(&InferenceVar::Impl(var)).copied()) + self.solve_single_pending(var).inspect_err(|_err_set| { + self.add_error_stable_ptr(InferenceVar::Impl(var)); })?; } while let Some(var) = self.negative_pending.pop_front() { // First inference error stops inference. - self.solve_single_negative_pending(var).map_err(|err_set| { - (err_set, self.stable_ptrs.get(&InferenceVar::NegativeImpl(var)).copied()) + self.solve_single_negative_pending(var).inspect_err(|_err_set| { + self.add_error_stable_ptr(InferenceVar::NegativeImpl(var)); })?; } Ok(()) @@ -839,12 +845,9 @@ impl<'db, 'id> Inference<'db, 'id> { /// Finalizes the inference by inferring uninferred numeric literals as felt252. /// Returns an error and does not report it. - pub fn finalize_without_reporting( - &mut self, - ) -> Result<(), (ErrorSet, Option>)> { + pub fn finalize_without_reporting(&mut self) -> Result<(), ErrorSet> { if self.error_status.is_err() { - // TODO(yuval): consider adding error location to the set error. - return Err((ErrorSet, None)); + return Err(ErrorSet); } let info = self.db.core_info(); let numeric_trait_id = info.numeric_literal_trt; @@ -853,7 +856,7 @@ impl<'db, 'id> Inference<'db, 'id> { // Conform all uninferred numeric literals to felt252. loop { let mut changed = false; - self.solve_ex()?; + self.solve()?; for (var, _) in self.ambiguous.clone() { let impl_var = self.impl_var(var).clone(); if impl_var.concrete_trait_id.trait_id(self.db) != numeric_trait_id { @@ -867,8 +870,8 @@ impl<'db, 'id> Inference<'db, 'id> { if self.rewrite(ty).no_err() == felt_ty { continue; } - self.conform_ty(ty, felt_ty).map_err(|err_set| { - (err_set, self.stable_ptrs.get(&InferenceVar::Impl(impl_var.id)).copied()) + self.conform_ty(ty, felt_ty).inspect_err(|_err_set| { + self.add_error_stable_ptr(InferenceVar::Impl(impl_var.id)); })?; changed = true; break; @@ -885,7 +888,7 @@ impl<'db, 'id> Inference<'db, 'id> { let Some((var, err)) = self.first_undetermined_variable() else { return Ok(()); }; - Err((self.set_error(err), self.stable_ptrs.get(&var).copied())) + Err(self.set_error_on_var(err, var)) } /// Finalizes the inference and report diagnostics if there are any errors. @@ -896,12 +899,8 @@ impl<'db, 'id> Inference<'db, 'id> { diagnostics: &mut SemanticDiagnostics<'db>, stable_ptr: SyntaxStablePtrId<'db>, ) { - if let Err((err_set, err_stable_ptr)) = self.finalize_without_reporting() { - let diag = self.report_on_pending_error( - err_set, - diagnostics, - err_stable_ptr.unwrap_or(stable_ptr), - ); + if let Err(err_set) = self.finalize_without_reporting() { + let diag = self.report_on_pending_error(err_set, diagnostics, stable_ptr); let ty_missing = TypeId::missing(self.db, diag); for var in &self.data.type_vars { @@ -985,7 +984,8 @@ impl<'db, 'id> Inference<'db, 'id> { } if !impl_id.is_var_free(self.db) && self.impl_contains_var(impl_id, InferenceVar::Impl(var)) { - return Err(self.set_error(InferenceError::Cycle(InferenceVar::Impl(var)))); + let inference_var = InferenceVar::Impl(var); + return Err(self.set_error_on_var(InferenceError::Cycle(inference_var), inference_var)); } self.impl_assignment.insert(var, impl_id); if let Some(mappings) = self.impl_vars_trait_item_mappings.remove(&var) { @@ -999,9 +999,11 @@ impl<'db, 'id> Inference<'db, 'id> { let ty0 = self.rewrite(ty).no_err(); let ty1 = self.rewrite(impl_ty).no_err(); - let error = - InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 }; - self.error_status = Err(InferenceErrorStatus::Pending(error)); + let err = InferenceError::ImplTypeMismatch { impl_id, trait_type_id, ty0, ty1 }; + self.error_status = Err(InferenceErrorStatus::Pending(PendingInferenceError { + err, + stable_ptr: self.stable_ptrs.get(&InferenceVar::Impl(var)).cloned(), + })); return Err(err_set); } } @@ -1094,7 +1096,7 @@ impl<'db, 'id> Inference<'db, 'id> { assert!(!self.type_assignment.contains_key(&var.id), "Cannot reassign variable."); let inference_var = InferenceVar::Type(var.id); if !ty.is_var_free(self.db) && self.ty_contains_var(ty, inference_var) { - return Err(self.set_error(InferenceError::Cycle(inference_var))); + return Err(self.set_error_on_var(InferenceError::Cycle(inference_var), inference_var)); } // If assigning var to var - making sure assigning to the lower id for proper canonization. if let TypeLongId::Var(other) = ty.long(self.db) @@ -1322,22 +1324,51 @@ impl<'db, 'id> Inference<'db, 'id> { /// Does nothing if an error is already set. /// Returns an `ErrorSet` that can be used in reporting the error. pub fn set_error(&mut self, err: InferenceError<'db>) -> ErrorSet { + self.set_error_ex(err, None) + } + + /// Sets an error in the inference state, with an optional location for the diagnostics + /// reporting. Does nothing if an error is already set. + /// Returns an `ErrorSet` that can be used in reporting the error. + pub fn set_error_ex( + &mut self, + err: InferenceError<'db>, + stable_ptr: Option>, + ) -> ErrorSet { if self.error_status.is_err() { return ErrorSet; } self.error_status = Err(if let InferenceError::Reported(diag_added) = err { InferenceErrorStatus::Consumed(diag_added) } else { - InferenceErrorStatus::Pending(err) + InferenceErrorStatus::Pending(PendingInferenceError { err, stable_ptr }) }); ErrorSet } + /// Sets an error in the inference state, with a var to fetch location for the diagnostics + /// reporting. Does nothing if an error is already set. + /// Returns an `ErrorSet` that can be used in reporting the error. + pub fn set_error_on_var(&mut self, err: InferenceError<'db>, var: InferenceVar) -> ErrorSet { + self.set_error_ex(err, self.stable_ptrs.get(&var).cloned()) + } + /// Returns whether an error is set (either pending or consumed). pub fn is_error_set(&self) -> InferenceResult<()> { self.error_status.as_ref().copied().map_err(|_| ErrorSet) } + /// If there is no stable ptr for the pending error, add it by the given var. + fn add_error_stable_ptr(&mut self, var: InferenceVar) { + let var_stable_ptr = self.stable_ptrs.get(&var).copied(); + if let Err(InferenceErrorStatus::Pending(PendingInferenceError { err: _, stable_ptr })) = + &mut self.error_status + && stable_ptr.is_none() + { + *stable_ptr = var_stable_ptr; + } + } + /// Consumes the error but doesn't report it. If there is no error, or the error is consumed, /// returns None. This should be used with caution. Always prefer to use /// (1) `report_on_pending_error` if possible, or (2) `consume_reported_error` which is safer. @@ -1347,7 +1378,7 @@ impl<'db, 'id> Inference<'db, 'id> { &mut self, err_set: ErrorSet, ) -> Option> { - self.consume_error_inner(err_set, skip_diagnostic()) + Some(self.consume_error_inner(err_set, skip_diagnostic())?.err) } /// Consumes the error that is already reported. If there is no error, or the error is consumed, @@ -1370,10 +1401,16 @@ impl<'db, 'id> Inference<'db, 'id> { &mut self, _err_set: ErrorSet, diag_added: DiagnosticAdded, - ) -> Option> { + ) -> Option> { match &mut self.error_status { Err(InferenceErrorStatus::Pending(error)) => { - let pending_error = std::mem::replace(error, InferenceError::Reported(diag_added)); + let pending_error = std::mem::replace( + error, + PendingInferenceError { + err: InferenceError::Reported(diag_added), + stable_ptr: None, + }, + ); self.error_status = Err(InferenceErrorStatus::Consumed(diag_added)); Some(pending_error) } @@ -1398,8 +1435,8 @@ impl<'db, 'id> Inference<'db, 'id> { }; match state_error { InferenceErrorStatus::Consumed(diag_added) => *diag_added, - InferenceErrorStatus::Pending(error) => { - let diag_added = match error { + InferenceErrorStatus::Pending(pending) => { + let diag_added = match &pending.err { InferenceError::TypeNotInferred(_) if diagnostics.error_count > 0 => { // If we have other diagnostics, there is no need to TypeNotInferred. @@ -1407,7 +1444,7 @@ impl<'db, 'id> Inference<'db, 'id> { // 'DiagnosticAdded' here. skip_diagnostic() } - diag => diag.report(diagnostics, stable_ptr), + diag => diag.report(diagnostics, pending.stable_ptr.unwrap_or(stable_ptr)), }; self.error_status = Err(InferenceErrorStatus::Consumed(diag_added)); diag_added @@ -1422,7 +1459,7 @@ impl<'db, 'id> Inference<'db, 'id> { err_set: ErrorSet, report: impl FnOnce() -> DiagnosticAdded, ) { - if matches!(self.error_status, Err(InferenceErrorStatus::Pending(_))) { + if matches!(self.error_status, Err(InferenceErrorStatus::Pending { .. })) { self.consume_reported_error(err_set, report()); } } diff --git a/crates/cairo-lang-semantic/src/items/imp.rs b/crates/cairo-lang-semantic/src/items/imp.rs index 587458aaff9..68533a6ac3e 100644 --- a/crates/cairo-lang-semantic/src/items/imp.rs +++ b/crates/cairo-lang-semantic/src/items/imp.rs @@ -3052,7 +3052,7 @@ fn implicit_impl_impl_semantic_data<'db>( let impl_lookup_context = resolver.impl_lookup_context(); let resolved_impl = concrete_trait_impl_concrete_trait.and_then(|concrete_trait_id| { let imp = resolver.inference().new_impl_var(concrete_trait_id, None, impl_lookup_context); - resolver.inference().finalize_without_reporting().map_err(|(err_set, _)| { + resolver.inference().finalize_without_reporting().map_err(|err_set| { diagnostics.report( impl_def_id.stable_ptr(db).untyped(), ImplicitImplNotInferred { trait_impl_id, concrete_trait_id }, diff --git a/crates/cairo-lang-semantic/src/items/tests/trait_type b/crates/cairo-lang-semantic/src/items/tests/trait_type index 12df9f07420..3bb293b705d 100644 --- a/crates/cairo-lang-semantic/src/items/tests/trait_type +++ b/crates/cairo-lang-semantic/src/items/tests/trait_type @@ -3185,6 +3185,6 @@ struct S { //! > expected_diagnostics error: `test::M::InputType` type mismatch: `core::felt252` and `core::integer::u32`. - --> lib.cairo:12:13 -fn foo() -> S { - ^^^^ + --> lib.cairo:13:5 + S { x: 3_felt252 } + ^ diff --git a/crates/cairo-lang-semantic/src/types.rs b/crates/cairo-lang-semantic/src/types.rs index 4aef270f614..d1792df2a23 100644 --- a/crates/cairo-lang-semantic/src/types.rs +++ b/crates/cairo-lang-semantic/src/types.rs @@ -814,7 +814,7 @@ pub fn get_impl_at_context<'db>( // It's ok to consume the errors without reporting as this is a helper function meant to find an // impl and return it, but it's ok if the impl can't be found. let impl_id = inference.new_impl_var(concrete_trait_id, stable_ptr, lookup_context); - if let Err((err_set, _)) = inference.finalize_without_reporting() { + if let Err(err_set) = inference.finalize_without_reporting() { return Err(inference .consume_error_without_reporting(err_set) .expect("Error couldn't be already consumed"));