diff --git a/specl/crates/specl-symbolic/src/encoder.rs b/specl/crates/specl-symbolic/src/encoder.rs index cdf245e..06a5fba 100644 --- a/specl/crates/specl-symbolic/src/encoder.rs +++ b/specl/crates/specl-symbolic/src/encoder.rs @@ -568,82 +568,131 @@ impl<'a> EncoderCtx<'a> { base: &CompiledExpr, index: &CompiledExpr, ) -> SymbolicResult { - // Handle nested index: d[k][i] where d is Dict[Range, Seq[T]] - if let CompiledExpr::Index { - base: outer_base, - index: outer_key, - } = base + // Flatten index chain: d[k1][k2][k3] → (base_expr, [k1, k2, k3]) + let mut keys: Vec<&CompiledExpr> = vec![index]; + let mut current: &CompiledExpr = base; + while let CompiledExpr::Index { + base: inner, + index: inner_key, + } = current { - return self.encode_nested_index(outer_base, outer_key, index); + keys.push(inner_key.as_ref()); + current = inner.as_ref(); } + keys.reverse(); - // Handle Local(n)[i] for compound locals (seq within dict) - if let CompiledExpr::Local(local_idx) = base { + // Handle compound Local(n)[keys...]: local is bound to a dict slot + if let CompiledExpr::Local(local_idx) = current { if let Some((var_idx, step, key_z3)) = self .resolve_compound_local(*local_idx) .map(|(v, s, k)| (*v, *s, k.clone())) { - return self.encode_compound_local_index(var_idx, step, &key_z3, index); + return self.encode_compound_local_chain(var_idx, step, &key_z3, &keys); } } - let (specl_var_idx, step) = self.extract_var_step(base)?; + let (specl_var_idx, step) = self.extract_var_step(current)?; let entry = &self.layout.entries[specl_var_idx]; - match &entry.kind { + let z3_vars = &self.step_vars[step][specl_var_idx]; + self.resolve_index_chain(&entry.kind, z3_vars, &keys) + } + + /// Recursively resolve an index chain on a compound VarKind. + fn resolve_index_chain( + &mut self, + kind: &VarKind, + vars: &[Dynamic], + keys: &[&CompiledExpr], + ) -> SymbolicResult { + if keys.is_empty() { + return if vars.len() == 1 { + Ok(vars[0].clone()) + } else { + Err(SymbolicError::Encoding( + "index chain incomplete: compound value requires more indices".into(), + )) + }; + } + + match kind { VarKind::ExplodedDict { key_lo, key_hi, value_kind, } => { + let stride = value_kind.z3_var_count(); let key_lo = *key_lo; let key_hi = *key_hi; - let z3_vars = &self.step_vars[step][specl_var_idx]; - let stride = value_kind.z3_var_count(); + let remaining = &keys[1..]; - if stride == 1 { - // Simple value kind (Bool/Int): one var per key - if let Some(concrete_key) = self.try_concrete_int(index) { + // Last key and scalar value: simple access + if remaining.is_empty() && stride == 1 { + if let Some(concrete_key) = self.try_concrete_int(keys[0]) { let offset = (concrete_key - key_lo) as usize; - if offset < z3_vars.len() { - Ok(z3_vars[offset].clone()) + return if offset < vars.len() { + Ok(vars[offset].clone()) } else { Err(SymbolicError::Encoding(format!( "dict key {} out of range [{}, {}]", concrete_key, key_lo, key_hi ))) - } + }; } else { - let key_z3 = self.encode_int(index)?; - self.build_ite_chain(&key_z3, z3_vars, key_lo) + let key_z3 = self.encode_int(keys[0])?; + return self.build_ite_chain(&key_z3, vars, key_lo); } - } else { - // Compound value kind: d[k] is not a scalar - Err(SymbolicError::Encoding(format!( - "dict variable '{}' has compound values; use nested index (d[k][i]) or len(d[k])", - entry.name - ))) } - } - VarKind::ExplodedSet { lo, .. } => { - let lo = *lo; - let z3_vars = &self.step_vars[step][specl_var_idx]; - if let Some(concrete_key) = self.try_concrete_int(index) { - let offset = (concrete_key - lo) as usize; - if offset < z3_vars.len() { - Ok(z3_vars[offset].clone()) - } else { - Ok(Dynamic::from_ast(&Bool::from_bool(false))) - } + + // Compound value needs more indices + if remaining.is_empty() { + return Err(SymbolicError::Encoding( + "dict has compound values; use nested index (d[k1][k2]) or len(d[k])" + .into(), + )); + } + + // Descend into value_kind with remaining keys + if let Some(concrete_key) = self.try_concrete_int(keys[0]) { + let offset = (concrete_key - key_lo) as usize * stride; + let slot_vars = &vars[offset..offset + stride]; + self.resolve_index_chain(value_kind, slot_vars, remaining) } else { - let key_z3 = self.encode_int(index)?; - self.build_ite_chain(&key_z3, z3_vars, lo) + let key_z3 = self.encode_int(keys[0])?; + let num_keys = (key_hi - key_lo + 1) as usize; + let mut result: Option = None; + for k in (0..num_keys).rev() { + let offset = k * stride; + let slot_vars = &vars[offset..offset + stride]; + let val = self.resolve_index_chain(value_kind, slot_vars, remaining)?; + let k_z3 = Int::from_i64(key_lo + k as i64); + let cond = key_z3.eq(&k_z3); + result = Some(match result { + None => val, + Some(prev) => { + if let (Some(vi), Some(pi)) = (val.as_int(), prev.as_int()) { + Dynamic::from_ast(&cond.ite(&vi, &pi)) + } else if let (Some(vb), Some(pb)) = (val.as_bool(), prev.as_bool()) + { + Dynamic::from_ast(&cond.ite(&vb, &pb)) + } else { + prev + } + } + }); + } + result + .ok_or_else(|| SymbolicError::Encoding("empty dict for index chain".into())) } } VarKind::ExplodedSeq { max_len, .. } => { + if keys.len() != 1 { + return Err(SymbolicError::Encoding( + "cannot index deeper into sequence elements".into(), + )); + } let max_len = *max_len; - let z3_vars = &self.step_vars[step][specl_var_idx]; - let elem_vars = &z3_vars[1..1 + max_len]; - if let Some(concrete_idx) = self.try_concrete_int(index) { + let elem_vars = &vars[1..1 + max_len]; + if let Some(concrete_idx) = self.try_concrete_int(keys[0]) { let idx = concrete_idx as usize; if idx < max_len { Ok(elem_vars[idx].clone()) @@ -654,25 +703,43 @@ impl<'a> EncoderCtx<'a> { ))) } } else { - let idx_z3 = self.encode_int(index)?; + let idx_z3 = self.encode_int(keys[0])?; self.build_ite_chain(&idx_z3, elem_vars, 0) } } - _ => Err(SymbolicError::Encoding(format!( - "index on non-dict/set/seq variable '{}'", - entry.name - ))), + VarKind::ExplodedSet { lo, .. } => { + if keys.len() != 1 { + return Err(SymbolicError::Encoding( + "cannot index deeper into set elements".into(), + )); + } + let lo = *lo; + if let Some(concrete_key) = self.try_concrete_int(keys[0]) { + let offset = (concrete_key - lo) as usize; + if offset < vars.len() { + Ok(vars[offset].clone()) + } else { + Ok(Dynamic::from_ast(&Bool::from_bool(false))) + } + } else { + let key_z3 = self.encode_int(keys[0])?; + self.build_ite_chain(&key_z3, vars, lo) + } + } + VarKind::Bool | VarKind::Int { .. } => { + Err(SymbolicError::Encoding("index on scalar variable".into())) + } } } - /// Nested index: d[k][i] where d is Dict[Range, Seq[T]]. - fn encode_nested_index( + /// Resolve index chain for a compound local (d[outer_key][keys...]). + fn encode_compound_local_chain( &mut self, - dict_base: &CompiledExpr, - dict_key: &CompiledExpr, - seq_index: &CompiledExpr, + var_idx: usize, + step: usize, + outer_key_z3: &Int, + keys: &[&CompiledExpr], ) -> SymbolicResult { - let (var_idx, step) = self.extract_var_step(dict_base)?; let entry = &self.layout.entries[var_idx]; let (key_lo, key_hi, value_kind) = match &entry.kind { VarKind::ExplodedDict { @@ -680,108 +747,104 @@ impl<'a> EncoderCtx<'a> { key_hi, value_kind, } => (*key_lo, *key_hi, value_kind.as_ref()), - _ => { - return Err(SymbolicError::Encoding( - "nested index requires dict base".into(), - )) - } + _ => return Err(SymbolicError::Encoding("compound local on non-dict".into())), }; - let stride = value_kind.z3_var_count(); - let max_len = match value_kind { - VarKind::ExplodedSeq { max_len, .. } => *max_len, - _ => { - return Err(SymbolicError::Encoding( - "nested index: dict value must be Seq".into(), - )) - } - }; - let z3_vars = &self.step_vars[step][var_idx]; - - if let Some(concrete_key) = self.try_concrete_int(dict_key) { - // Concrete dict key: directly access the seq elements for this key - let key_offset = (concrete_key - key_lo) as usize * stride; - let elem_vars = &z3_vars[key_offset + 1..key_offset + 1 + max_len]; - if let Some(concrete_idx) = self.try_concrete_int(seq_index) { - let idx = concrete_idx as usize; - if idx < max_len { - Ok(elem_vars[idx].clone()) - } else { - Err(SymbolicError::Encoding(format!( - "seq index {} out of bounds (max_len {})", - concrete_idx, max_len - ))) - } - } else { - let idx_z3 = self.encode_int(seq_index)?; - self.build_ite_chain(&idx_z3, elem_vars, 0) - } - } else { - // Symbolic dict key: ITE chain over all keys, each resolving the seq element - let key_z3 = self.encode_int(dict_key)?; - let idx_z3 = self.encode_int(seq_index)?; - let num_keys = (key_hi - key_lo + 1) as usize; - // Build the ITE: for each possible key, get the element at seq_index - let mut result: Option = None; - for k in (0..num_keys).rev() { - let key_offset = k * stride; - let elem_vars = &z3_vars[key_offset + 1..key_offset + 1 + max_len]; - // ITE chain for seq_index within this key's elements - let elem_val = self.build_ite_chain(&idx_z3, elem_vars, 0)?; - let k_z3 = Int::from_i64(key_lo + k as i64); - let cond = key_z3.eq(&k_z3); - result = Some(match result { - None => elem_val, - Some(prev) => { - if let (Some(ei), Some(pi)) = (elem_val.as_int(), prev.as_int()) { - Dynamic::from_ast(&cond.ite(&ei, &pi)) - } else if let (Some(eb), Some(pb)) = (elem_val.as_bool(), prev.as_bool()) { - Dynamic::from_ast(&cond.ite(&eb, &pb)) - } else { - prev - } + let num_keys = (key_hi - key_lo + 1) as usize; + let mut result: Option = None; + for k in (0..num_keys).rev() { + let offset = k * stride; + let slot_vars = &z3_vars[offset..offset + stride]; + let val = self.resolve_index_chain(value_kind, slot_vars, keys)?; + let k_z3 = Int::from_i64(key_lo + k as i64); + let cond = outer_key_z3.eq(&k_z3); + result = Some(match result { + None => val, + Some(prev) => { + if let (Some(vi), Some(pi)) = (val.as_int(), prev.as_int()) { + Dynamic::from_ast(&cond.ite(&vi, &pi)) + } else if let (Some(vb), Some(pb)) = (val.as_bool(), prev.as_bool()) { + Dynamic::from_ast(&cond.ite(&vb, &pb)) + } else { + prev } - }); - } - result.ok_or_else(|| SymbolicError::Encoding("empty dict for nested index".into())) + } + }); } + result.ok_or_else(|| SymbolicError::Encoding("empty compound local".into())) } - /// len(d[k]) for Dict[Range, Seq[T]]: return the len var for the seq at key k. - fn encode_nested_len( + /// Compute len of a VarKind from its Z3 vars. + fn len_of_kind(&self, kind: &VarKind, vars: &[Dynamic]) -> SymbolicResult { + match kind { + VarKind::ExplodedSeq { .. } => Ok(vars[0].clone()), + VarKind::ExplodedSet { .. } => { + let one = Int::from_i64(1); + let zero = Int::from_i64(0); + let terms: Vec = vars + .iter() + .map(|v| v.as_bool().unwrap().ite(&one, &zero)) + .collect(); + Ok(Dynamic::from_ast(&Int::add(&terms))) + } + VarKind::ExplodedDict { key_lo, key_hi, .. } => { + Ok(Dynamic::from_ast(&Int::from_i64(key_hi - key_lo + 1))) + } + _ => Err(SymbolicError::Encoding("len on scalar".into())), + } + } + + /// Recursively resolve len(d[k1][k2]...) for nested compounds. + fn resolve_nested_len( &mut self, - dict_base: &CompiledExpr, - dict_key: &CompiledExpr, + kind: &VarKind, + vars: &[Dynamic], + keys: &[&CompiledExpr], ) -> SymbolicResult { - let (var_idx, step) = self.extract_var_step(dict_base)?; - let entry = &self.layout.entries[var_idx]; - let (key_lo, key_hi, value_kind) = match &entry.kind { + if keys.is_empty() { + return self.len_of_kind(kind, vars); + } + + match kind { VarKind::ExplodedDict { key_lo, key_hi, value_kind, - } => (*key_lo, *key_hi, value_kind.as_ref()), - _ => { - return Err(SymbolicError::Encoding( - "nested len requires dict base".into(), - )) - } - }; - - let stride = value_kind.z3_var_count(); - let z3_vars = &self.step_vars[step][var_idx]; + } => { + let stride = value_kind.z3_var_count(); + let key_lo = *key_lo; + let key_hi = *key_hi; + let remaining = &keys[1..]; - if let Some(concrete_key) = self.try_concrete_int(dict_key) { - let key_offset = (concrete_key - key_lo) as usize * stride; - Ok(z3_vars[key_offset].clone()) // len var is at offset 0 within each key's stride - } else { - // Symbolic key: ITE chain over all len vars - let key_z3 = self.encode_int(dict_key)?; - let num_keys = (key_hi - key_lo + 1) as usize; - let len_vars: Vec = - (0..num_keys).map(|k| z3_vars[k * stride].clone()).collect(); - self.build_ite_chain(&key_z3, &len_vars, key_lo) + if let Some(concrete_key) = self.try_concrete_int(keys[0]) { + let offset = (concrete_key - key_lo) as usize * stride; + let slot_vars = &vars[offset..offset + stride]; + if remaining.is_empty() { + self.len_of_kind(value_kind, slot_vars) + } else { + self.resolve_nested_len(value_kind, slot_vars, remaining) + } + } else { + let key_z3 = self.encode_int(keys[0])?; + let num_keys = (key_hi - key_lo + 1) as usize; + let mut len_vals = Vec::new(); + for k in 0..num_keys { + let offset = k * stride; + let slot_vars = &vars[offset..offset + stride]; + let len = if remaining.is_empty() { + self.len_of_kind(value_kind, slot_vars)? + } else { + self.resolve_nested_len(value_kind, slot_vars, remaining)? + }; + len_vals.push(len); + } + self.build_ite_chain(&key_z3, &len_vals, key_lo) + } + } + _ => Err(SymbolicError::Encoding( + "nested len on non-dict kind".into(), + )), } } @@ -838,15 +901,43 @@ impl<'a> EncoderCtx<'a> { )), } } - // len(d[k]) for Dict[Range, Seq[T]] - CompiledExpr::Index { base, index } => self.encode_nested_len(base, index), + // len(d[k]) or len(d[k1][k2]) for nested compounds + CompiledExpr::Index { .. } => { + let mut keys: Vec<&CompiledExpr> = Vec::new(); + let mut cur = inner; + while let CompiledExpr::Index { + base: ib, + index: ik, + } = cur + { + keys.push(ik.as_ref()); + cur = ib.as_ref(); + } + keys.reverse(); + + // Handle compound local len + if let CompiledExpr::Local(local_idx) = cur { + if let Some((var_idx, step, key_z3)) = self + .resolve_compound_local(*local_idx) + .map(|(v, s, k)| (*v, *s, k.clone())) + { + return self + .encode_compound_local_nested_len(var_idx, step, &key_z3, &keys); + } + } + + let (var_idx, step) = self.extract_var_step(cur)?; + let entry = &self.layout.entries[var_idx]; + let z3_vars = &self.step_vars[step][var_idx]; + self.resolve_nested_len(&entry.kind, z3_vars, &keys) + } // len(Local(n)) for compound locals or set locals CompiledExpr::Local(idx) => { if let Some((var_idx, step, key_z3)) = self .resolve_compound_local(*idx) .map(|(v, s, k)| (*v, *s, k.clone())) { - self.encode_compound_local_len(var_idx, step, &key_z3) + self.encode_compound_local_nested_len(var_idx, step, &key_z3, &[]) } else if let Some(members) = self.resolve_set_local(*idx) { Ok(Dynamic::from_ast(&Int::from_i64(members.len() as i64))) } else { @@ -1123,6 +1214,55 @@ impl<'a> EncoderCtx<'a> { "values() requires a simple-valued dict variable".into(), )) } + // x in d[k] where d is Dict[Range, Set[Range]] + CompiledExpr::Index { .. } => { + let mut keys: Vec<&CompiledExpr> = Vec::new(); + let mut cur = set; + while let CompiledExpr::Index { + base: ib, + index: ik, + } = cur + { + keys.push(ik.as_ref()); + cur = ib.as_ref(); + } + keys.reverse(); + + let (var_idx, step) = self.extract_var_step(cur)?; + let entry = &self.layout.entries[var_idx]; + let z3_vars = &self.step_vars[step][var_idx]; + + // Walk the VarKind to find the Set slot + let (set_vars, set_kind) = self.resolve_set_slot(&entry.kind, z3_vars, &keys)?; + + if let VarKind::ExplodedSet { lo, .. } = set_kind { + let lo = *lo; + if let Some(concrete_elem) = self.try_concrete_int(elem) { + let offset = (concrete_elem - lo) as usize; + let result = if offset < set_vars.len() { + set_vars[offset].as_bool().unwrap() + } else { + Bool::from_bool(false) + }; + let final_val = if negate { result.not() } else { result }; + Ok(Dynamic::from_ast(&final_val)) + } else { + let elem_z3 = self.encode_int(elem)?; + let result = self.build_ite_chain(&elem_z3, set_vars, lo)?; + let result_bool = result.as_bool().unwrap(); + let final_val = if negate { + result_bool.not() + } else { + result_bool + }; + Ok(Dynamic::from_ast(&final_val)) + } + } else { + Err(SymbolicError::Encoding( + "'in' on nested index: inner value is not a set".into(), + )) + } + } _ => { // Try as a set expression with known bounds (union, intersect, etc.) if self.is_set_expr(set) { @@ -1214,90 +1354,82 @@ impl<'a> EncoderCtx<'a> { None } - /// Access element `elem_idx` of a compound local (seq within dict). - fn encode_compound_local_index( + /// Resolve nested len for a compound local: len(Local(n)[k1][k2]...). + fn encode_compound_local_nested_len( &mut self, var_idx: usize, step: usize, - key_z3: &Int, - elem_index: &CompiledExpr, + outer_key_z3: &Int, + keys: &[&CompiledExpr], ) -> SymbolicResult { let entry = &self.layout.entries[var_idx]; - if let VarKind::ExplodedDict { - key_lo, - key_hi, - value_kind, - } = &entry.kind - { - let stride = value_kind.z3_var_count(); - if let VarKind::ExplodedSeq { max_len, .. } = value_kind.as_ref() { - let max_len = *max_len; - let idx_z3 = self.encode_int(elem_index)?; - let z3_vars = &self.step_vars[step][var_idx]; - // Build ITE chain: for each key k, for each element index i - // if key == k && idx == i then z3_vars[k*stride + 1 + i] - let key_lo = *key_lo; - let key_hi = *key_hi; - let mut result: Option = None; - for k in (key_lo..=key_hi).rev() { - let k_offset = (k - key_lo) as usize * stride; - for i in (0..max_len).rev() { - let elem_var = z3_vars[k_offset + 1 + i].clone(); - let cond = Bool::and(&[ - &key_z3.eq(&Int::from_i64(k)), - &idx_z3.eq(&Int::from_i64(i as i64)), - ]); - result = Some(match result { - None => elem_var, - Some(prev) => { - if let (Some(ev), Some(pv)) = (elem_var.as_int(), prev.as_int()) { - Dynamic::from_ast(&cond.ite(&ev, &pv)) - } else if let (Some(ev), Some(pv)) = - (elem_var.as_bool(), prev.as_bool()) - { - Dynamic::from_ast(&cond.ite(&ev, &pv)) - } else { - return Err(SymbolicError::Encoding( - "compound local index: type mismatch".into(), - )); - } - } - }); - } - } - return result - .ok_or_else(|| SymbolicError::Encoding("empty compound local".into())); + let (key_lo, key_hi, value_kind) = match &entry.kind { + VarKind::ExplodedDict { + key_lo, + key_hi, + value_kind, + } => (*key_lo, *key_hi, value_kind.as_ref()), + _ => { + return Err(SymbolicError::Encoding( + "compound local len on non-dict".into(), + )) } + }; + let stride = value_kind.z3_var_count(); + let z3_vars = &self.step_vars[step][var_idx]; + let num_keys = (key_hi - key_lo + 1) as usize; + let mut len_vals = Vec::new(); + for k in 0..num_keys { + let offset = k * stride; + let slot_vars = &z3_vars[offset..offset + stride]; + let len = if keys.is_empty() { + self.len_of_kind(value_kind, slot_vars)? + } else { + self.resolve_nested_len(value_kind, slot_vars, keys)? + }; + len_vals.push(len); } - Err(SymbolicError::Encoding( - "compound local index on non-dict-of-seq".into(), - )) + self.build_ite_chain(outer_key_z3, &len_vals, key_lo) } - /// Get the len of a compound local (seq within dict). - fn encode_compound_local_len( - &self, - var_idx: usize, - step: usize, - key_z3: &Int, - ) -> SymbolicResult { - let entry = &self.layout.entries[var_idx]; - if let VarKind::ExplodedDict { - key_lo, - key_hi, - value_kind, - } = &entry.kind - { - let stride = value_kind.z3_var_count(); - let z3_vars = &self.step_vars[step][var_idx]; - let num_keys = (*key_hi - *key_lo + 1) as usize; - let len_vars: Vec = - (0..num_keys).map(|k| z3_vars[k * stride].clone()).collect(); - return self.build_ite_chain(key_z3, &len_vars, *key_lo); + /// Resolve an index chain to find the Set slot vars and kind. + fn resolve_set_slot<'b>( + &mut self, + kind: &'b VarKind, + vars: &'b [Dynamic], + keys: &[&CompiledExpr], + ) -> SymbolicResult<(&'b [Dynamic], &'b VarKind)> { + if keys.is_empty() { + return Ok((vars, kind)); + } + match kind { + VarKind::ExplodedDict { + key_lo, + key_hi: _, + value_kind, + } => { + let stride = value_kind.z3_var_count(); + let key_lo = *key_lo; + let remaining = &keys[1..]; + + if let Some(concrete_key) = self.try_concrete_int(keys[0]) { + let offset = (concrete_key - key_lo) as usize * stride; + let slot_vars = &vars[offset..offset + stride]; + if remaining.is_empty() { + Ok((slot_vars, value_kind.as_ref())) + } else { + self.resolve_set_slot(value_kind, slot_vars, remaining) + } + } else { + Err(SymbolicError::Encoding( + "set membership with symbolic dict key in nested index".into(), + )) + } + } + _ => Err(SymbolicError::Encoding( + "resolve_set_slot: unexpected kind".into(), + )), } - Err(SymbolicError::Encoding( - "compound local len on non-dict".into(), - )) } /// Get head (first element) of a compound local (seq within dict). @@ -1331,9 +1463,25 @@ impl<'a> EncoderCtx<'a> { // === Seq helpers === fn encode_seq_head(&mut self, inner: &CompiledExpr) -> SymbolicResult { - // Handle head(d[k]) for Dict[Range, Seq[T]] - if let CompiledExpr::Index { base, index } = inner { - return self.encode_nested_index(base, index, &CompiledExpr::Int(0)); + // Handle head(d[k]) or head(d[k1][k2]) for nested compounds + if let CompiledExpr::Index { .. } = inner { + let mut keys: Vec<&CompiledExpr> = Vec::new(); + let mut cur = inner; + while let CompiledExpr::Index { + base: ib, + index: ik, + } = cur + { + keys.push(ik.as_ref()); + cur = ib.as_ref(); + } + keys.reverse(); + let zero = CompiledExpr::Int(0); + keys.push(&zero); + let (var_idx, step) = self.extract_var_step(cur)?; + let entry = &self.layout.entries[var_idx]; + let z3_vars = &self.step_vars[step][var_idx]; + return self.resolve_index_chain(&entry.kind, z3_vars, &keys); } // Handle head(Local(n)) for compound local if let CompiledExpr::Local(idx) = inner { @@ -1352,7 +1500,6 @@ impl<'a> EncoderCtx<'a> { entry.name ))); } - // Element 0 is at offset 1 (after len var) Ok(self.step_vars[step][var_idx][1].clone()) } diff --git a/specl/crates/specl-symbolic/src/state_vars.rs b/specl/crates/specl-symbolic/src/state_vars.rs index 990f76e..65cf9eb 100644 --- a/specl/crates/specl-symbolic/src/state_vars.rs +++ b/specl/crates/specl-symbolic/src/state_vars.rs @@ -115,35 +115,55 @@ fn type_to_kind( hi: Some(*hi), }), Type::Fn(key_ty, val_ty) => { - if let Type::Range(lo, hi) = key_ty.as_ref() { - let value_kind = type_to_kind_simple(val_ty, seq_bound)?; - Ok(VarKind::ExplodedDict { - key_lo: *lo, - key_hi: *hi, - value_kind: Box::new(value_kind), - }) + // Reject Dict with Seq keys (e.g., Dict[Seq[Int], V]) + if matches!(key_ty.as_ref(), Type::Seq(_)) { + return Err(crate::SymbolicError::Unsupported( + "Dict with sequence keys requires enumerating all possible sequences; \ + use Dict[Int, V] with an integer key instead" + .into(), + )); + } + let key_range = if let Type::Range(lo, hi) = key_ty.as_ref() { + (*lo, *hi) } else if matches!(key_ty.as_ref(), Type::Int | Type::Nat) { - if let Some((lo, hi)) = infer_dict_range(var_idx, spec, consts) { - let value_kind = type_to_kind_simple(val_ty, seq_bound)?; - Ok(VarKind::ExplodedDict { - key_lo: lo, - key_hi: hi, - value_kind: Box::new(value_kind), - }) - } else { - Err(crate::SymbolicError::Unsupported(format!( + infer_dict_range(var_idx, spec, consts).ok_or_else(|| { + crate::SymbolicError::Unsupported(format!( "Dict with unbounded key type {:?} (cannot infer range from init)", key_ty - ))) - } + )) + })? } else { - Err(crate::SymbolicError::Unsupported(format!( + return Err(crate::SymbolicError::Unsupported(format!( "Dict with non-range key type: {:?}", key_ty - ))) - } + ))); + }; + let init_rhs = find_init_rhs(var_idx, spec); + let init_body = init_rhs.and_then(extract_fn_body); + let value_kind = type_to_kind_value( + val_ty, + var_idx, + spec, + consts, + string_table, + seq_bound, + init_body, + )?; + Ok(VarKind::ExplodedDict { + key_lo: key_range.0, + key_hi: key_range.1, + value_kind: Box::new(value_kind), + }) } Type::Set(elem_ty) => { + if matches!(elem_ty.as_ref(), Type::Seq(_)) { + return Err(crate::SymbolicError::Unsupported( + "Set[Seq[T]] requires exponential encoding (tracks membership of every \ + possible sequence). Workaround: model messages as Dict[Int, Seq[Int]] \ + with a message counter instead of Set[Seq[Int]]" + .into(), + )); + } if let Type::Range(lo, hi) = elem_ty.as_ref() { Ok(VarKind::ExplodedSet { lo: *lo, hi: *hi }) } else if matches!(elem_ty.as_ref(), Type::Int | Type::Nat) { @@ -187,6 +207,124 @@ fn type_to_kind( } } +/// Resolve value types within containers, with full spec context for range inference. +/// The `init_body` parameter is the FnLit body at this nesting level, used to infer +/// inner dict key ranges and set element ranges from init expressions. +fn type_to_kind_value( + ty: &Type, + var_idx: usize, + spec: &CompiledSpec, + consts: &[Value], + string_table: &[String], + seq_bound: usize, + init_body: Option<&CompiledExpr>, +) -> SymbolicResult { + match ty { + Type::Bool => Ok(VarKind::Bool), + Type::Int => Ok(VarKind::Int { lo: None, hi: None }), + Type::String => { + if string_table.is_empty() { + Ok(VarKind::Int { lo: None, hi: None }) + } else { + Ok(VarKind::Int { + lo: Some(0), + hi: Some(string_table.len() as i64 - 1), + }) + } + } + Type::Nat => Ok(VarKind::Int { + lo: Some(0), + hi: None, + }), + Type::Range(lo, hi) => Ok(VarKind::Int { + lo: Some(*lo), + hi: Some(*hi), + }), + Type::Seq(elem_ty) => { + let elem_kind = type_to_kind_simple(elem_ty, seq_bound)?; + Ok(VarKind::ExplodedSeq { + max_len: seq_bound, + elem_kind: Box::new(elem_kind), + }) + } + Type::Fn(key_ty, val_ty) => { + if matches!(key_ty.as_ref(), Type::Seq(_)) { + return Err(crate::SymbolicError::Unsupported( + "Dict with sequence keys requires enumerating all possible sequences".into(), + )); + } + let key_range = if let Type::Range(lo, hi) = key_ty.as_ref() { + (*lo, *hi) + } else if matches!(key_ty.as_ref(), Type::Int | Type::Nat) { + // Infer from init body's FnLit domain + if let Some(body) = init_body { + extract_domain_range(body, consts).ok_or_else(|| { + crate::SymbolicError::Unsupported(format!( + "nested Dict with unbounded key type {:?} \ + (cannot infer range from init body)", + key_ty + )) + })? + } else { + // Fall back to action parameter inference + infer_dict_range(var_idx, spec, consts).ok_or_else(|| { + crate::SymbolicError::Unsupported(format!( + "nested Dict with unbounded key type {:?} (cannot infer range)", + key_ty + )) + })? + } + } else { + return Err(crate::SymbolicError::Unsupported(format!( + "Dict with non-range key type in value context: {:?}", + key_ty + ))); + }; + let inner_init_body = init_body.and_then(extract_fn_body); + let value_kind = type_to_kind_value( + val_ty, + var_idx, + spec, + consts, + string_table, + seq_bound, + inner_init_body, + )?; + Ok(VarKind::ExplodedDict { + key_lo: key_range.0, + key_hi: key_range.1, + value_kind: Box::new(value_kind), + }) + } + Type::Set(elem_ty) => { + if matches!(elem_ty.as_ref(), Type::Seq(_)) { + return Err(crate::SymbolicError::Unsupported( + "Set[Seq[T]] in value context requires exponential encoding".into(), + )); + } + if let Type::Range(lo, hi) = elem_ty.as_ref() { + Ok(VarKind::ExplodedSet { lo: *lo, hi: *hi }) + } else if matches!(elem_ty.as_ref(), Type::Int | Type::Nat) { + // Try action parameter inference (works for inner sets too) + if let Some((lo, hi)) = infer_set_range_from_actions(var_idx, spec, consts) { + Ok(VarKind::ExplodedSet { lo, hi }) + } else { + Err(crate::SymbolicError::Unsupported(format!( + "Set with unbounded element type {:?} (cannot infer range)", + elem_ty + ))) + } + } else { + Err(crate::SymbolicError::Unsupported(format!( + "Set with non-range element type: {:?}", + elem_ty + ))) + } + } + _ => type_to_kind_simple(ty, seq_bound), + } +} + /// Simple type_to_kind without spec/const context (for value types within containers). fn type_to_kind_simple(ty: &Type, seq_bound: usize) -> SymbolicResult { match ty { @@ -224,6 +362,41 @@ fn type_to_kind_simple(ty: &Type, seq_bound: usize) -> SymbolicResult { } } +/// Extract the init RHS expression for a variable (e.g., for `x = FnLit{...}` returns the FnLit). +fn find_init_rhs<'a>(var_idx: usize, spec: &'a CompiledSpec) -> Option<&'a CompiledExpr> { + find_init_rhs_in(var_idx, &spec.init) +} + +fn find_init_rhs_in<'a>(var_idx: usize, expr: &'a CompiledExpr) -> Option<&'a CompiledExpr> { + match expr { + CompiledExpr::Binary { + op: specl_ir::BinOp::And, + left, + right, + } => find_init_rhs_in(var_idx, left).or_else(|| find_init_rhs_in(var_idx, right)), + CompiledExpr::Binary { + op: specl_ir::BinOp::Eq, + left, + right, + } => match left.as_ref() { + CompiledExpr::PrimedVar(idx) | CompiledExpr::Var(idx) if *idx == var_idx => { + Some(right.as_ref()) + } + _ => None, + }, + _ => None, + } +} + +/// Extract the body from a FnLit expression. +fn extract_fn_body(expr: &CompiledExpr) -> Option<&CompiledExpr> { + if let CompiledExpr::FnLit { body, .. } = expr { + Some(body.as_ref()) + } else { + None + } +} + /// Infer dict key range from init expression or action parameters. fn infer_dict_range(var_idx: usize, spec: &CompiledSpec, consts: &[Value]) -> Option<(i64, i64)> { if let Some(range) = find_var_init_range(var_idx, &spec.init, consts) { diff --git a/specl/crates/specl-symbolic/src/trace.rs b/specl/crates/specl-symbolic/src/trace.rs index bd34819..5165404 100644 --- a/specl/crates/specl-symbolic/src/trace.rs +++ b/specl/crates/specl-symbolic/src/trace.rs @@ -257,7 +257,7 @@ fn extract_state( state } -/// Format a compound value (e.g., Seq within a Dict). +/// Format a compound value (e.g., Seq, Dict, Set within a Dict). fn format_compound_value( model: &Model, kind: &VarKind, @@ -289,7 +289,63 @@ fn format_compound_value( } format!("[{}]", elems.join(", ")) } - _ => "?".to_string(), + VarKind::ExplodedDict { + key_lo, + key_hi, + value_kind, + } => { + let inner_stride = value_kind.z3_var_count(); + let mut pairs = Vec::new(); + for (i, k) in (*key_lo..=*key_hi).enumerate() { + if inner_stride == 1 { + let val = model + .eval(&vars[i], true) + .and_then(|v| { + v.as_int() + .and_then(|i| i.as_i64()) + .map(|n| format_int_value(n, value_kind, string_table)) + .or_else(|| { + v.as_bool().and_then(|b| b.as_bool()).map(|b| b.to_string()) + }) + }) + .unwrap_or_else(|| "?".to_string()); + pairs.push(format!("{}: {}", k, val)); + } else { + let offset = i * inner_stride; + let inner_vars = &vars[offset..offset + inner_stride]; + let val_str = + format_compound_value(model, value_kind, inner_vars, string_table); + pairs.push(format!("{}: {}", k, val_str)); + } + } + format!("{{{}}}", pairs.join(", ")) + } + VarKind::ExplodedSet { lo, hi } => { + let mut members = Vec::new(); + for (i, k) in (*lo..=*hi).enumerate() { + let is_member = model + .eval(&vars[i], true) + .and_then(|v| v.as_bool()) + .and_then(|b| b.as_bool()) + .unwrap_or(false); + if is_member { + members.push(k.to_string()); + } + } + format!("{{{}}}", members.join(", ")) + } + VarKind::Bool => model + .eval(&vars[0], true) + .and_then(|v| v.as_bool()) + .and_then(|b| b.as_bool()) + .map(|b| b.to_string()) + .unwrap_or_else(|| "?".to_string()), + VarKind::Int { .. } => model + .eval(&vars[0], true) + .and_then(|v| v.as_int()) + .and_then(|i| i.as_i64()) + .map(|n| format_int_value(n, kind, string_table)) + .unwrap_or_else(|| "?".to_string()), } } diff --git a/specl/crates/specl-symbolic/src/transition.rs b/specl/crates/specl-symbolic/src/transition.rs index 713d557..3594ae9 100644 --- a/specl/crates/specl-symbolic/src/transition.rs +++ b/specl/crates/specl-symbolic/src/transition.rs @@ -155,7 +155,9 @@ fn encode_init_assignment( let k_val = Dynamic::from_ast(&Int::from_i64(k)); enc.locals.push(k_val); - if stride == 1 { + let is_scalar_value = + matches!(value_kind.as_ref(), VarKind::Bool | VarKind::Int { .. }); + if is_scalar_value { let body_z3 = enc.encode(body)?; if let (Some(vi), Some(ri)) = (z3_vars[i].as_int(), body_z3.as_int()) { solver.assert(&vi.eq(&ri)); @@ -165,35 +167,11 @@ fn encode_init_assignment( solver.assert(&vb.eq(&rb)); } } else { - // Compound value: handle SeqLit let key_offset = i * stride; let key_vars = &z3_vars[key_offset..key_offset + stride]; - if let CompiledExpr::SeqLit(elems) = body.as_ref() { - let len_var = key_vars[0].as_int().unwrap(); - solver.assert(&len_var.eq(&Int::from_i64(elems.len() as i64))); - if let VarKind::ExplodedSeq { max_len, .. } = value_kind.as_ref() { - for (ei, elem_expr) in elems.iter().enumerate() { - if ei >= *max_len { - break; - } - let val = enc.encode(elem_expr)?; - let offset = 1 + ei; - if let (Some(vi), Some(ri)) = - (key_vars[offset].as_int(), val.as_int()) - { - solver.assert(&vi.eq(&ri)); - } else if let (Some(vb), Some(rb)) = - (key_vars[offset].as_bool(), val.as_bool()) - { - solver.assert(&vb.eq(&rb)); - } - } - } - } else { - return Err(SymbolicError::Encoding( - "Dict[Range, Seq] init body must be SeqLit".into(), - )); - } + encode_init_compound_body( + solver, &mut enc, body, key_vars, value_kind, consts, + )?; } } Ok(()) @@ -293,6 +271,101 @@ fn encode_init_assignment( } } +/// Recursively encode init for compound value bodies (inner dict, set, seq). +fn encode_init_compound_body( + solver: &Solver, + enc: &mut EncoderCtx, + body: &CompiledExpr, + slot_vars: &[Dynamic], + value_kind: &VarKind, + consts: &[Value], +) -> SymbolicResult<()> { + match value_kind { + VarKind::ExplodedSeq { max_len, .. } => { + if let CompiledExpr::SeqLit(elems) = body { + let len_var = slot_vars[0].as_int().unwrap(); + solver.assert(&len_var.eq(&Int::from_i64(elems.len() as i64))); + for (ei, elem_expr) in elems.iter().enumerate() { + if ei >= *max_len { + break; + } + let val = enc.encode(elem_expr)?; + let offset = 1 + ei; + if let (Some(vi), Some(ri)) = (slot_vars[offset].as_int(), val.as_int()) { + solver.assert(&vi.eq(&ri)); + } else if let (Some(vb), Some(rb)) = + (slot_vars[offset].as_bool(), val.as_bool()) + { + solver.assert(&vb.eq(&rb)); + } + } + Ok(()) + } else { + Err(SymbolicError::Encoding( + "compound init: expected SeqLit for Seq value".into(), + )) + } + } + VarKind::ExplodedDict { + key_lo, + key_hi, + value_kind: inner_vk, + } => { + if let CompiledExpr::FnLit { + domain: _, + body: inner_body, + } = body + { + let inner_stride = inner_vk.z3_var_count(); + for (j_idx, j) in (*key_lo..=*key_hi).enumerate() { + let j_val = Dynamic::from_ast(&Int::from_i64(j)); + enc.locals.push(j_val); + let inner_offset = j_idx * inner_stride; + let inner_vars = &slot_vars[inner_offset..inner_offset + inner_stride]; + if matches!(inner_vk.as_ref(), VarKind::Bool | VarKind::Int { .. }) { + let body_z3 = enc.encode(inner_body)?; + if let (Some(vi), Some(ri)) = (inner_vars[0].as_int(), body_z3.as_int()) { + solver.assert(&vi.eq(&ri)); + } else if let (Some(vb), Some(rb)) = + (inner_vars[0].as_bool(), body_z3.as_bool()) + { + solver.assert(&vb.eq(&rb)); + } + } else { + encode_init_compound_body( + solver, enc, inner_body, inner_vars, inner_vk, consts, + )?; + } + enc.locals.pop(); + } + Ok(()) + } else { + Err(SymbolicError::Encoding(format!( + "compound init: expected FnLit for Dict value, got {:?}", + std::mem::discriminant(body) + ))) + } + } + VarKind::ExplodedSet { lo, hi } => { + let flags = enc.encode_as_set(body, *lo, *hi)?; + for (i, flag) in flags.iter().enumerate() { + let vb = slot_vars[i].as_bool().unwrap(); + solver.assert(&vb.eq(flag)); + } + Ok(()) + } + _ => { + let body_z3 = enc.encode(body)?; + if let (Some(vi), Some(ri)) = (slot_vars[0].as_int(), body_z3.as_int()) { + solver.assert(&vi.eq(&ri)); + } else if let (Some(vb), Some(rb)) = (slot_vars[0].as_bool(), body_z3.as_bool()) { + solver.assert(&vb.eq(&rb)); + } + Ok(()) + } + } +} + /// Encode the transition relation for one step: step → step+1. /// Returns a Bool that is the disjunction of all enabled actions. pub fn encode_transition( @@ -528,7 +601,7 @@ fn encode_primed_assignment( base: _, key, value, - } if stride == 1 => { + } if matches!(value_kind.as_ref(), VarKind::Bool | VarKind::Int { .. }) => { let mut conjuncts = Vec::new(); let key_z3 = enc.encode_int(key)?; let val_z3 = enc.encode(value)?; @@ -554,7 +627,9 @@ fn encode_primed_assignment( Ok(Bool::and(&conjuncts)) } - CompiledExpr::FnLit { domain: _, body } if stride == 1 => { + CompiledExpr::FnLit { domain: _, body } + if matches!(value_kind.as_ref(), VarKind::Bool | VarKind::Int { .. }) => + { let mut conjuncts = Vec::new(); for (i, k) in (key_lo..=key_hi).enumerate() { let k_val = Dynamic::from_ast(&Int::from_i64(k)); @@ -578,7 +653,7 @@ fn encode_primed_assignment( left: _, right, } => { - if stride == 1 { + if matches!(value_kind.as_ref(), VarKind::Bool | VarKind::Int { .. }) { encode_dict_merge(right, enc, next_vars, curr_vars, key_lo, key_hi) } else { encode_dict_merge_compound( @@ -587,43 +662,28 @@ fn encode_primed_assignment( ) } } - // FnLit with compound values (Dict[Range, Seq[T]] init) + // FnLit with compound values CompiledExpr::FnLit { domain: _, body } => { - // body is evaluated per key, expecting a SeqLit let mut conjuncts = Vec::new(); for (i, k) in (key_lo..=key_hi).enumerate() { let k_val = Dynamic::from_ast(&Int::from_i64(k)); enc.locals.push(k_val); let key_offset = i * stride; let key_next = &next_vars[key_offset..key_offset + stride]; - if let CompiledExpr::SeqLit(elems) = body.as_ref() { - let next_len = key_next[0].as_int().unwrap(); - conjuncts.push(next_len.eq(&Int::from_i64(elems.len() as i64))); - if let VarKind::ExplodedSeq { max_len, elem_kind } = value_kind.as_ref() - { - let es = elem_kind.z3_var_count(); - for (ei, elem_expr) in elems.iter().enumerate() { - if ei >= *max_len { - break; - } - let val = enc.encode(elem_expr)?; - let offset = 1 + ei * es; - if let (Some(ni), Some(ri)) = - (key_next[offset].as_int(), val.as_int()) - { - conjuncts.push(ni.eq(&ri)); - } else if let (Some(nb), Some(rb)) = - (key_next[offset].as_bool(), val.as_bool()) - { - conjuncts.push(nb.eq(&rb)); - } + let key_curr = &curr_vars[key_offset..key_offset + stride]; + let updated = + encode_compound_update_for_slot(enc, body, key_curr, value_kind); + match updated { + Ok(vals) => { + for (j, val) in vals.iter().enumerate() { + let c = eq_dynamic(&key_next[j], val)?; + conjuncts.push(c); } } - } else { - enc.locals.pop(); - return Err(SymbolicError::Encoding( - "Dict[Range, Seq] init body must be a SeqLit".into(), - )); + Err(e) => { + enc.locals.pop(); + return Err(e); + } } enc.locals.pop(); } @@ -894,15 +954,6 @@ fn encode_dict_merge_compound( } }; - let max_len = match value_kind { - VarKind::ExplodedSeq { max_len, .. } => *max_len, - _ => { - return Err(SymbolicError::Encoding( - "compound dict merge only supports Seq values".into(), - )); - } - }; - // Encode Z3 key expressions for each update pair let mut encoded_keys: Vec<(Int, &CompiledExpr)> = Vec::new(); for (key_expr, val_expr) in &pairs { @@ -919,26 +970,12 @@ fn encode_dict_merge_compound( let key_curr = &curr_vars[key_offset..key_offset + stride]; let k_z3 = Int::from_i64(k); - // Check if any update key matches this slot - // Build ITE: for each update pair, if key matches → apply operation, else frame - // With multiple update pairs, later pairs shadow earlier ones (last match wins) - // We process pairs in reverse to build nested ITEs correctly. - - // First, compute the "framed" values (copy current to next) - let frame_len = key_curr[0].as_int().unwrap(); - let mut frame_elems: Vec = Vec::new(); - for j in 0..stride { - frame_elems.push(key_curr[j].clone()); - } + let frame_elems: Vec = (0..stride).map(|j| key_curr[j].clone()).collect(); - // For each update pair, compute what the updated values would be - // and build ITE selection let mut result_vars: Vec = frame_elems; for (pair_key, val_expr) in encoded_keys.iter().rev() { let is_match = pair_key.eq(&k_z3); - let updated = - encode_seq_update_for_slot(enc, *val_expr, key_curr, max_len, &frame_len)?; - // ITE per var: if this key matches, use updated; else use previous result + let updated = encode_compound_update_for_slot(enc, val_expr, key_curr, value_kind)?; let mut new_result = Vec::with_capacity(stride); for j in 0..stride { let selected = ite_dynamic(&is_match, &updated[j], &result_vars[j])?; @@ -947,7 +984,6 @@ fn encode_dict_merge_compound( result_vars = new_result; } - // Assert next == result for j in 0..stride { let c = eq_dynamic(&key_next[j], &result_vars[j])?; conjuncts.push(c); @@ -957,90 +993,209 @@ fn encode_dict_merge_compound( Ok(Bool::and(&conjuncts)) } -/// Encode a seq operation (append, tail, literal) for a single dict key slot. -/// Returns the updated Z3 var values (len + elements). -fn encode_seq_update_for_slot( +/// Encode a compound update for a single dict key slot. +/// Handles Seq (append, tail, literal), Dict (inner merge), and Set (union, literal). +fn encode_compound_update_for_slot( enc: &mut EncoderCtx, val_expr: &CompiledExpr, - key_curr: &[Dynamic], - max_len: usize, - curr_len: &Int, + slot_curr: &[Dynamic], + value_kind: &VarKind, ) -> SymbolicResult> { - let stride = key_curr.len(); - match val_expr { - CompiledExpr::Binary { - op: BinOp::Concat, - left: _, - right: concat_right, - } => { - // Append: base ++ [elem] - if let CompiledExpr::SeqLit(elems) = concat_right.as_ref() { - if elems.len() == 1 { - let new_len = Dynamic::from_ast(&Int::add(&[curr_len, &Int::from_i64(1)])); - let appended = enc.encode(&elems[0])?; + let stride = slot_curr.len(); + + // Identity/frame patterns (d | {k: d[k]}) + if matches!( + val_expr, + CompiledExpr::Index { .. } | CompiledExpr::Local(_) + ) { + return Ok(slot_curr.to_vec()); + } + + // If/then/else: recurse on both branches + if let CompiledExpr::If { + cond, + then_branch, + else_branch, + } = val_expr + { + let cond_z3 = enc.encode_bool(cond)?; + let then_vars = encode_compound_update_for_slot(enc, then_branch, slot_curr, value_kind)?; + let else_vars = encode_compound_update_for_slot(enc, else_branch, slot_curr, value_kind)?; + let mut result = Vec::with_capacity(stride); + for i in 0..stride { + result.push(ite_dynamic(&cond_z3, &then_vars[i], &else_vars[i])?); + } + return Ok(result); + } + + match value_kind { + VarKind::ExplodedSeq { max_len, .. } => { + let max_len = *max_len; + let curr_len = slot_curr[0].as_int().unwrap(); + match val_expr { + CompiledExpr::Binary { + op: BinOp::Concat, + left: _, + right: concat_right, + } => { + if let CompiledExpr::SeqLit(elems) = concat_right.as_ref() { + if elems.len() == 1 { + let new_len = + Dynamic::from_ast(&Int::add(&[&curr_len, &Int::from_i64(1)])); + let appended = enc.encode(&elems[0])?; + let mut result = vec![new_len]; + for j in 0..max_len { + let j_z3 = Int::from_i64(j as i64); + let is_append = curr_len.eq(&j_z3); + let updated = + ite_dynamic(&is_append, &appended, &slot_curr[1 + j])?; + result.push(updated); + } + return Ok(result); + } + } + Err(SymbolicError::Encoding( + "dict-of-seq merge: only single-element append supported".into(), + )) + } + CompiledExpr::SeqTail(_) => { + let new_len = Dynamic::from_ast(&Int::sub(&[&curr_len, &Int::from_i64(1)])); let mut result = vec![new_len]; - for j in 0..max_len { - let j_z3 = Int::from_i64(j as i64); - let is_append = curr_len.eq(&j_z3); - let updated = ite_dynamic(&is_append, &appended, &key_curr[1 + j])?; - result.push(updated); + for i in 0..max_len.saturating_sub(1) { + result.push(slot_curr[1 + (i + 1)].clone()); } - return Ok(result); + if max_len > 0 { + result.push(slot_curr[stride - 1].clone()); + } + Ok(result) } + CompiledExpr::SeqLit(elems) => { + let mut result = vec![Dynamic::from_ast(&Int::from_i64(elems.len() as i64))]; + for (i, elem_expr) in elems.iter().enumerate() { + if i >= max_len { + break; + } + result.push(enc.encode(elem_expr)?); + } + while result.len() < stride { + result.push(slot_curr[result.len()].clone()); + } + Ok(result) + } + _ => Err(SymbolicError::Encoding(format!( + "unsupported seq operation in dict merge: {:?}", + std::mem::discriminant(val_expr) + ))), } - Err(SymbolicError::Encoding( - "dict-of-seq merge: only single-element append supported".into(), - )) } - CompiledExpr::SeqTail(_) => { - let new_len = Dynamic::from_ast(&Int::sub(&[curr_len, &Int::from_i64(1)])); - let mut result = vec![new_len]; - for i in 0..max_len.saturating_sub(1) { - result.push(key_curr[1 + (i + 1)].clone()); // shift left - } - // Last element slot: keep current (doesn't matter, beyond new len) - if max_len > 0 { - result.push(key_curr[stride - 1].clone()); + VarKind::ExplodedDict { + key_lo, + key_hi, + value_kind: inner_vk, + } => { + let inner_stride = inner_vk.z3_var_count(); + match val_expr { + // Inner dict merge: val_expr = base | {k: v, ...} + CompiledExpr::Binary { + op: BinOp::Union, + left: _, + right, + } => { + let inner_pairs: Vec<(&CompiledExpr, &CompiledExpr)> = match right.as_ref() { + CompiledExpr::DictLit(ps) => ps.iter().map(|(k, v)| (k, v)).collect(), + _ => { + return Err(SymbolicError::Encoding( + "inner dict merge: expected DictLit on right side".into(), + )) + } + }; + let mut encoded_inner_keys: Vec<(Int, &CompiledExpr)> = Vec::new(); + for (key_expr, inner_val_expr) in &inner_pairs { + let ik_z3 = enc.encode_int(key_expr)?; + encoded_inner_keys.push((ik_z3, inner_val_expr)); + } + + let mut result = Vec::new(); + for j in *key_lo..=*key_hi { + let j_idx = (j - key_lo) as usize; + let j_offset = j_idx * inner_stride; + let j_curr = &slot_curr[j_offset..j_offset + inner_stride]; + let j_z3 = Int::from_i64(j); + + let mut j_result: Vec = j_curr.to_vec(); + for (ik_z3, inner_val) in encoded_inner_keys.iter().rev() { + let is_match = ik_z3.eq(&j_z3); + let updated = + encode_compound_update_for_slot(enc, inner_val, j_curr, inner_vk)?; + let mut new_j = Vec::with_capacity(inner_stride); + for s in 0..inner_stride { + new_j.push(ite_dynamic(&is_match, &updated[s], &j_result[s])?); + } + j_result = new_j; + } + result.extend(j_result); + } + Ok(result) + } + // FnLit (full reassignment of inner dict) + CompiledExpr::FnLit { domain: _, body } => { + let mut result = Vec::new(); + for j in *key_lo..=*key_hi { + let j_val = Dynamic::from_ast(&Int::from_i64(j)); + enc.locals.push(j_val); + let j_idx = (j - key_lo) as usize; + let j_offset = j_idx * inner_stride; + let j_curr = &slot_curr[j_offset..j_offset + inner_stride]; + if matches!(inner_vk.as_ref(), VarKind::Bool | VarKind::Int { .. }) { + let body_z3 = enc.encode(body)?; + result.push(body_z3); + } else { + let updated = + encode_compound_update_for_slot(enc, body, j_curr, inner_vk)?; + result.extend(updated); + } + enc.locals.pop(); + } + Ok(result) + } + _ => Err(SymbolicError::Encoding(format!( + "unsupported inner dict operation: {:?}", + std::mem::discriminant(val_expr) + ))), } - Ok(result) } - CompiledExpr::SeqLit(elems) => { - let mut result = vec![Dynamic::from_ast(&Int::from_i64(elems.len() as i64))]; - for (i, elem_expr) in elems.iter().enumerate() { - if i >= max_len { - break; + VarKind::ExplodedSet { lo, hi } => { + match val_expr { + // Set union: val_expr = base union {elem, ...} + CompiledExpr::Binary { + op: BinOp::Union, + left: _, + right, + } => { + let new_flags = enc.encode_as_set(right, *lo, *hi)?; + let mut result = Vec::new(); + for (i, new_flag) in new_flags.iter().enumerate() { + let curr_flag = slot_curr[i].as_bool().unwrap(); + result.push(Dynamic::from_ast(&Bool::or(&[curr_flag, new_flag.clone()]))); + } + Ok(result) } - result.push(enc.encode(elem_expr)?); - } - // Pad remaining slots with current values - while result.len() < stride { - result.push(key_curr[result.len()].clone()); + // Set literal (direct assignment) + CompiledExpr::SetLit(_) | CompiledExpr::SetComprehension { .. } => { + let flags = enc.encode_as_set(val_expr, *lo, *hi)?; + Ok(flags.into_iter().map(|b| Dynamic::from_ast(&b)).collect()) + } + _ => Err(SymbolicError::Encoding(format!( + "unsupported set operation in dict merge: {:?}", + std::mem::discriminant(val_expr) + ))), } - Ok(result) } - CompiledExpr::If { - cond, - then_branch, - else_branch, - } => { - let cond_z3 = enc.encode_bool(cond)?; - let then_vars = - encode_seq_update_for_slot(enc, then_branch, key_curr, max_len, curr_len)?; - let else_vars = - encode_seq_update_for_slot(enc, else_branch, key_curr, max_len, curr_len)?; - let mut result = Vec::with_capacity(stride); - for i in 0..stride { - result.push(ite_dynamic(&cond_z3, &then_vars[i], &else_vars[i])?); - } - Ok(result) + _ => { + // Scalar value: just encode directly + let val_z3 = enc.encode(val_expr)?; + Ok(vec![val_z3]) } - // Index into same variable (identity/frame): just return current slot values. - // This handles patterns like `d = d | {i: d[i]}` where the value is unchanged. - CompiledExpr::Index { .. } | CompiledExpr::Local(_) => Ok(key_curr.to_vec()), - _ => Err(SymbolicError::Encoding(format!( - "unsupported seq operation in dict merge: {:?}", - std::mem::discriminant(val_expr) - ))), } }