diff --git a/src/halmos/bitvec.py b/src/halmos/bitvec.py
index e1d6ac3b..a5b9e354 100644
--- a/src/halmos/bitvec.py
+++ b/src/halmos/bitvec.py
@@ -545,8 +545,6 @@ def mul(
     def div(
         self, other: BV, *, abstraction: FuncDeclRef | None = None
     ) -> "HalmosBitVec":
-        # TODO: div_xy_y
-
         size = self._size
         assert size == other.size
 
@@ -578,7 +576,6 @@ def div(
     def sdiv(
         self, other: BV, *, abstraction: FuncDeclRef | None = None
     ) -> "HalmosBitVec":
-        # TODO: sdiv_xy_y
         size = self._size
         assert size == other.size
 
diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py
index 3d2ead17..1c4673d0 100644
--- a/src/halmos/sevm.py
+++ b/src/halmos/sevm.py
@@ -2102,31 +2102,58 @@ def __init__(self, options: HalmosConfig, fun_info: FunctionInfo) -> None:
         is_generic = self.options.storage_layout == "generic"
         self.storage_model = GenericStorage if is_generic else SolidityStorage
 
-    def div_xy_y(self, w1: Word, w2: Word) -> Word:
-        # return the number of bits required to represent the given value. default = 256
-        def bitsize(w: Word) -> int:
-            if (
-                w.decl().name() == "concat"
-                and is_bv_value(w.arg(0))
-                and int(str(w.arg(0))) == 0
-            ):
-                return 256 - w.arg(0).size()
-            return 256
-
-        w1 = normalize(w1)
-
-        if w1.decl().name() == "bvmul" and w1.num_args() == 2:
-            x = w1.arg(0)
-            y = w1.arg(1)
-            if eq(w2, x) or eq(w2, y):  # xy/x or xy/y
-                size_x = bitsize(x)
-                size_y = bitsize(y)
-                if size_x + size_y <= 256:
-                    if eq(w2, x):  # xy/x == y
-                        return y
-                    else:  # xy/y == x
-                        return x
-        return None
+    def div_xy_y(self, w1: Word, w2: Word, signed: bool = False) -> Word | None:
+        """Simplify `(x * y) / x` to `y`, or `(x * y) / y` to `x`.
+
+        The rewrite is applied only when the product provably cannot wrap
+        around, judged from the zero-extension of both factors. For signed
+        division one more bit of headroom is required so that the product
+        stays non-negative, where SDIV and DIV agree.
+
+        Returns the remaining factor as a BV, or None if the pattern does not
+        match or the rewrite is not provably safe.
+
+        NOTE(review): a zero divisor is not excluded here, while the EVM
+        defines division by zero as 0; confirm callers rule that case out.
+        """
+
+        def value_bits(term: BitVecRef) -> int:
+            # zero-extended value: concat(0, v) carries only the bits of v
+            if (
+                term.decl().name() == "concat"
+                and is_bv_value(term.arg(0))
+                and term.arg(0).as_long() == 0
+            ):
+                return term.size() - term.arg(0).size()
+            return term.size()
+
+        if not (isinstance(w1, BV) and isinstance(w2, BV)):
+            return None
+
+        dividend = w1.as_z3()
+        divisor = w2.as_z3()
+        if not isinstance(dividend, BitVecRef):
+            return None
+
+        if dividend.decl().name() != "bvmul" or dividend.num_args() != 2:
+            return None
+
+        x = dividend.arg(0)
+        y = dividend.arg(1)
+        if eq(divisor, x):  # xy/x == y
+            quotient = y
+        elif eq(divisor, y):  # xy/y == x
+            quotient = x
+        else:
+            return None
+
+        # x * y must not wrap; signed division additionally needs the sign
+        # bit of the product to stay clear
+        limit = dividend.size() - 1 if signed else dividend.size()
+        if value_bits(x) + value_bits(y) > limit:
+            return None
+
+        return BV(quotient, size=w1.size)
 
     def mk_div(self, ex: Exec, x: Any, y: Any) -> Any:
         term = f_div(x, y)
@@ -2150,7 +2177,11 @@ def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word:
             return w1.mul(w2, abstraction=f_mul[w1.size])
 
         if op == OP_DIV:
-            # TODO: div_xy_y
+            # Try simplification first: (x * y) / x → y, (x * y) / y → x
+            # (div_xy_y returns None when the rewrite is not provably safe)
+            simplified_div = self.div_xy_y(w1, w2)
+            if simplified_div is not None:
+                return simplified_div
 
             term = w1.div(w2, abstraction=f_div)
             if term.is_symbolic:
@@ -2166,6 +2197,11 @@ def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word:
             return term
 
         if op == OP_SDIV:
+            # Try signed simplification first
+            simplified_sdiv = self.div_xy_y(w1, w2, signed=True)
+            if simplified_sdiv is not None:
+                return simplified_sdiv
+
             return w1.sdiv(w2, abstraction=f_sdiv)
 
         if op == OP_SMOD:
diff --git a/tests/test_sevm.py b/tests/test_sevm.py
index ef05c2af..1d43f3f0 100644
--- a/tests/test_sevm.py
+++ b/tests/test_sevm.py
@@ -5,6 +5,7 @@
     BitVecSort,
     BitVecVal,
     Concat,
+    ExprRef,
     Extract,
     If,
     LShR,
@@ -504,3 +505,130 @@ def test_valid_jump(sevm, solver, storage):
     assert len(execs) == 1  # Only one valid path should execute
     assert execs[0].pc == 4  # PC should move to the stop
     assert execs[0].current_opcode() == EVM.STOP  # Should terminate cleanly
+
+
+def collect_opcodes(expr):
+    """Recursively collect operator names from a Z3 expression."""
+    seen = set()
+    stack = [expr]
+    while stack:
+        e = stack.pop()
+        if isinstance(e, ExprRef):
+            seen.add(e.decl().name())
+            stack.extend(e.children())
+    return seen
+
+
+@pytest.mark.parametrize(
+    "x_bv, y_bv",
+    [
+        (
+            BV(Concat(BitVecVal(0, 248), BitVec("x", 8)), size=256),
+            BV(Concat(BitVecVal(0, 248), BitVec("y", 8)), size=256),
+        ),
+        (
+            BV(Concat(BitVecVal(0, 128), BitVec("x", 128)), size=256),
+            BV(Concat(BitVecVal(0, 128), BitVec("y", 128)), size=256),
+        ),
+    ],
+)
+def test_div_and_simplification(sevm: SEVM, solver, storage, x_bv, y_bv):
+    """test: raw DIV vs simplified (x*y)/x"""
+
+    # --- Raw DIV ---
+    ex_div = mk_ex(
+        bytes.fromhex("04"),  # DIV
+        sevm,
+        solver,
+        storage,
+        caller,
+        this,
+    )
+    ex_div.st.stack.append(x_bv)
+    ex_div.st.stack.append(y_bv)
+    [out_div] = list(sevm.run(ex_div))
+    expr_div = out_div.st.stack[-1].as_z3()
+    assert f_div.name() in collect_opcodes(expr_div)
+
+    # --- Simplified DIV (x*y)/x ---
+    ex_simplified = mk_ex(
+        bytes.fromhex("0204"),  # 0x02 MUL, 0x04 DIV
+        sevm,
+        solver,
+        storage,
+        caller,
+        this,
+    )
+
+    # Push raw symbolic operands
+    # Stack will be: [x, y, x] → MUL → (x * y), DIV → (x * y) / x
+    ex_simplified.st.stack.append(x_bv)  # divisor for DIV
+    ex_simplified.st.stack.append(y_bv)  # operand for MUL
+    ex_simplified.st.stack.append(x_bv)  # operand for MUL
+
+    [out_simplified] = list(sevm.run(ex_simplified))
+    expr_simplified = out_simplified.st.stack[-1].as_z3()
+
+    # No "div" should remain
+    assert f_div.name() not in collect_opcodes(expr_simplified)
+
+    # ensure expression does not depend on x anymore
+    assert "x" not in str(expr_simplified)
+
+
+@pytest.mark.parametrize(
+    "x_bv, y_bv",
+    [
+        (
+            BV(Concat(BitVecVal(0, 248), BitVec("x", 8)), size=256),
+            BV(Concat(BitVecVal(0, 248), BitVec("y", 8)), size=256),
+        ),
+        (
+            BV(Concat(BitVecVal(0, 129), BitVec("x", 127)), size=256),
+            BV(Concat(BitVecVal(0, 129), BitVec("y", 127)), size=256),
+        ),
+        # full-width operands are excluded: x * y may wrap (x = 2**255, y = 2)
+    ],
+)
+def test_signed_div_and_simplification(sevm: SEVM, solver, storage, x_bv, y_bv):
+    """test: raw SDIV vs simplified (x*y)/x"""
+
+    # --- Raw SDIV ---
+    ex_sdiv = mk_ex(
+        bytes.fromhex("05"),  # 0x05 SDIV
+        sevm,
+        solver,
+        storage,
+        caller,
+        this,
+    )
+    ex_sdiv.st.stack.append(x_bv)
+    ex_sdiv.st.stack.append(y_bv)
+    [out_sdiv] = list(sevm.run(ex_sdiv))
+    expr_sdiv = out_sdiv.st.stack[-1].as_z3()
+    assert f_sdiv.name() in collect_opcodes(expr_sdiv)
+
+    # --- Simplified SDIV (x*y)/x ---
+    ex_simplified = mk_ex(
+        bytes.fromhex("0205"),  # 0x02 MUL, 0x05 SDIV
+        sevm,
+        solver,
+        storage,
+        caller,
+        this,
+    )
+
+    # Push raw symbolic operands
+    # Stack will be: [x, y, x] → MUL → (x * y), SDIV → (x * y) / x
+    ex_simplified.st.stack.append(x_bv)  # divisor for SDIV
+    ex_simplified.st.stack.append(y_bv)  # operand for MUL
+    ex_simplified.st.stack.append(x_bv)  # operand for MUL
+
+    [out_simplified] = list(sevm.run(ex_simplified))
+    expr_simplified = out_simplified.st.stack[-1].as_z3()
+
+    # No "sdiv" should remain
+    assert f_sdiv.name() not in collect_opcodes(expr_simplified)
+
+    # ensure expression does not depend on x anymore
+    assert "x" not in str(expr_simplified)