Skip to content
3 changes: 0 additions & 3 deletions src/halmos/bitvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,6 @@ def mul(
def div(
self, other: BV, *, abstraction: FuncDeclRef | None = None
) -> "HalmosBitVec":
# TODO: div_xy_y

size = self._size
assert size == other.size

Expand Down Expand Up @@ -578,7 +576,6 @@ def div(
def sdiv(
self, other: BV, *, abstraction: FuncDeclRef | None = None
) -> "HalmosBitVec":
# TODO: sdiv_xy_y
size = self._size
assert size == other.size

Expand Down
85 changes: 59 additions & 26 deletions src/halmos/sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2102,31 +2102,55 @@ def __init__(self, options: HalmosConfig, fun_info: FunctionInfo) -> None:
is_generic = self.options.storage_layout == "generic"
self.storage_model = GenericStorage if is_generic else SolidityStorage

def div_xy_y(self, w1: Word, w2: Word) -> Word:
# return the number of bits required to represent the given value. default = 256
def bitsize(w: Word) -> int:
if (
w.decl().name() == "concat"
and is_bv_value(w.arg(0))
and int(str(w.arg(0))) == 0
):
return 256 - w.arg(0).size()
return 256

w1 = normalize(w1)

if w1.decl().name() == "bvmul" and w1.num_args() == 2:
x = w1.arg(0)
y = w1.arg(1)
if eq(w2, x) or eq(w2, y): # xy/x or xy/y
size_x = bitsize(x)
size_y = bitsize(y)
if size_x + size_y <= 256:
if eq(w2, x): # xy/x == y
return y
else: # xy/y == x
return x
return None
def div_xy_y(self, w1: Word, w2: Word, signed: bool = False) -> Word:
try:
# return the number of bits required to represent the given value. default = 256
def bitsize(w: Word) -> int:
if isinstance(w, int): # unwrap if it’s a Python int
return w.bit_length() or 1
elif isinstance(w, BV):
z3w = w.as_z3()
elif isinstance(w, BitVecRef): # unwrap Z3 bit-vector directly
z3w = w
else:
raise TypeError(f"Unsupported Word type: {type(w)}")

# Handle zero-extended constants → concat(0, small_bv)
if (
z3w.decl().name() == "concat"
and is_bv_value(z3w.arg(0))
and int(str(z3w.arg(0))) == 0
):
return 256 - z3w.arg(0).size()

# Default: use the actual width of the bitvec
return z3w.size()

w1_z3 = w1.as_z3() if isinstance(w1, Word) else w1
w2_z3 = w2.as_z3() if isinstance(w2, Word) else w2

if "bvmul" in w1_z3.decl().name() and w1_z3.num_args() == 2:
x = w1_z3.arg(0)
y = w1_z3.arg(1)
if eq(w2_z3, x) or eq(w2_z3, y): # xy/x or xy/y
if signed:
# Signed division: safe to simplify if divisor exactly matches factor
if eq(w2_z3, x):
return BV(y, size=w1.size) # wrap back
elif eq(w2_z3, y):
return BV(x, size=w1.size) # wrap back
else:
# Unsigned division: check for overflow
size_x = bitsize(x)
size_y = bitsize(y)
if size_x + size_y <= 256:
if eq(w2_z3, x): # xy/x == y
return BV(y, size=w1.size) # wrap back
else: # xy/y == x
return BV(x, size=w1.size) # wrap back
return None
except Exception:
return None

def mk_div(self, ex: Exec, x: Any, y: Any) -> Any:
term = f_div(x, y)
Expand All @@ -2150,7 +2174,11 @@ def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word:
return w1.mul(w2, abstraction=f_mul[w1.size])

if op == OP_DIV:
# TODO: div_xy_y
# Check: div_xy_y
# Try simplification first: (x * y) / x → y, (x * y) / y → x
simplified_div = self.div_xy_y(w1, w2)
if simplified_div is not None:
return simplified_div

term = w1.div(w2, abstraction=f_div)
if term.is_symbolic:
Expand All @@ -2166,6 +2194,11 @@ def arith(self, ex: Exec, op: int, w1: Word, w2: Word) -> Word:
return term

if op == OP_SDIV:
# Try signed simplification first
simplified_sdiv = self.div_xy_y(w1, w2, signed=True)
if simplified_sdiv is not None:
return simplified_sdiv

return w1.sdiv(w2, abstraction=f_sdiv)

if op == OP_SMOD:
Expand Down
128 changes: 128 additions & 0 deletions tests/test_sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
BitVecSort,
BitVecVal,
Concat,
ExprRef,
Extract,
If,
LShR,
Expand Down Expand Up @@ -504,3 +505,130 @@ def test_valid_jump(sevm, solver, storage):
assert len(execs) == 1 # Only one valid path should execute
assert execs[0].pc == 4 # PC should move to the stop
assert execs[0].current_opcode() == EVM.STOP # Should terminate cleanly


def collect_opcodes(expr):
"""Recursively collect operator names from a Z3 expression."""
seen = set()
stack = [expr]
while stack:
e = stack.pop()
if isinstance(e, ExprRef):
seen.add(e.decl().name())
stack.extend(e.children())
return seen


@pytest.mark.parametrize(
"x_bv, y_bv",
[
(
BV(Concat(BitVecVal(0, 8), BitVec("x", 8)), size=256),
BV(Concat(BitVecVal(0, 8), BitVec("y", 8)), size=256),
),
(
BV(Concat(BitVecVal(0, 8), BitVec("x", 128)), size=256),
BV(Concat(BitVecVal(0, 8), BitVec("y", 128)), size=256),
),
],
)
def test_div_and_simplification(sevm: SEVM, solver, storage, x_bv, y_bv):
"""test: raw DIV vs simplified (x*y)/x"""

# --- Raw DIV ---
ex_div = mk_ex(
bytes.fromhex("04"), # DIV
sevm,
solver,
storage,
caller,
this,
)
ex_div.st.stack.append(x_bv)
ex_div.st.stack.append(y_bv)
[out_div] = list(sevm.run(ex_div))
expr_div = out_div.st.stack[-1].as_z3()
assert f_div.name() in collect_opcodes(expr_div)

# --- Simplified DIV (x*y)/x ---
ex_simplified = mk_ex(
bytes.fromhex("0204"), # 0x02 MUL, 0x04 DIV
sevm,
solver,
storage,
caller,
this,
)

# Push raw symbolic operands
# Stack will be: [x, y, x] → MUL → (x * y), SDIV → (x * y) / x
ex_simplified.st.stack.append(x_bv) # divisor for SDIV
ex_simplified.st.stack.append(y_bv) # operand for MUL
ex_simplified.st.stack.append(x_bv) # operand for MUL

[out_simplified] = list(sevm.run(ex_simplified))
expr_simplified = out_simplified.st.stack[-1].as_z3()

# No "div" should remain
assert f_div.name() not in collect_opcodes(expr_simplified)

# ensure expression does not depend on x anymore
assert "x" not in str(expr_simplified)


@pytest.mark.parametrize(
"x_bv, y_bv",
[
(
BV(Concat(BitVecVal(0, 8), BitVec("x", 8)), size=256),
BV(Concat(BitVecVal(0, 8), BitVec("y", 8)), size=256),
),
(
BV(Concat(BitVecVal(0, 8), BitVec("x", 128)), size=256),
BV(Concat(BitVecVal(0, 8), BitVec("y", 128)), size=256),
),
(x, y),
],
)
def test_signed_div_and_simplification(sevm: SEVM, solver, storage, x_bv, y_bv):
"""test: raw SDIV vs simplified (x*y)/x"""

# --- Raw SDIV ---
ex_sdiv = mk_ex(
bytes.fromhex("05"), # 0x05 SDIV
sevm,
solver,
storage,
caller,
this,
)
ex_sdiv.st.stack.append(x_bv)
ex_sdiv.st.stack.append(y_bv)
[out_sdiv] = list(sevm.run(ex_sdiv))
expr_sdiv = out_sdiv.st.stack[-1].as_z3()
assert f_sdiv.name() in collect_opcodes(expr_sdiv)

# --- Simplified SDIV (x*y)/x ---
ex_simplified = mk_ex(
bytes.fromhex("0205"), # 0x02 MUL, 0x05 SDIV
sevm,
solver,
storage,
caller,
this,
)

# Push raw symbolic operands
# Stack will be: [x, y, x] → MUL → (x * y), DIV → (x * y) / x
ex_simplified.st.stack.append(x_bv) # divisor for DIV
ex_simplified.st.stack.append(y_bv) # operand for MUL
ex_simplified.st.stack.append(x_bv) # operand for MUL

[out_simplified] = list(sevm.run(ex_simplified))
expr_simplified = out_simplified.st.stack[-1].as_z3()

# No "div" should remain
assert f_div.name() not in collect_opcodes(expr_simplified)

# ensure expression does not depend on x anymore
assert "x" not in str(expr_simplified)