Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 66 additions & 6 deletions src/ethereum/cancun/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
class State:
"""
Contains all information that is preserved between transactions.

Now includes optional state tracking
"""

_main_trie: Trie[Address, Optional[Account]] = field(
Expand All @@ -47,6 +49,8 @@ class State:
]
] = field(default_factory=list)
created_accounts: Set[Address] = field(default_factory=set)

_state_tracker = field(default=None)


@dataclass
Expand All @@ -71,6 +75,8 @@ def close_state(state: State) -> None:
del state._storage_tries
del state._snapshots
del state.created_accounts
if state._state_tracker is not None:
del state._state_tracker


def begin_transaction(
Expand Down Expand Up @@ -186,6 +192,11 @@ def get_account_optional(state: State, address: Address) -> Optional[Account]:
Account at address.
"""
account = trie_get(state._main_trie, address)

if state._state_tracker is not None and state._state_tracker.track_reads:
from .state_tracking import log_state_access, ACCOUNT_READ
log_state_access(state, ACCOUNT_READ, address, value_before=account)

return account


Expand All @@ -205,7 +216,23 @@ def set_account(
account : `Account`
Account to set at address.
"""
# Get old account for tracking
old_account = None
if state._state_tracker is not None and state._state_tracker.track_writes:
old_account = trie_get(state._main_trie, address)

trie_set(state._main_trie, address, account)

# Log write access if tracking enabled
if state._state_tracker is not None and state._state_tracker.track_writes:
from .state_tracking import log_state_access, ACCOUNT_WRITE
log_state_access(
state,
ACCOUNT_WRITE,
address,
value_before=old_account,
value_after=account
)


def destroy_account(state: State, address: Address) -> None:
Expand Down Expand Up @@ -284,12 +311,24 @@ def get_storage(state: State, address: Address, key: Bytes32) -> U256:
"""
trie = state._storage_tries.get(address)
if trie is None:
return U256(0)

value = trie_get(trie, key)

assert isinstance(value, U256)
return value
result = U256(0)
else:
value = trie_get(trie, key)
assert isinstance(value, U256)
result = value

# Log read access if tracking enabled
if state._state_tracker is not None and state._state_tracker.track_reads:
from .state_tracking import log_state_access, STORAGE_READ
log_state_access(
state,
STORAGE_READ,
address,
key=key,
value_before=result
)

return result


def set_storage(
Expand All @@ -312,13 +351,34 @@ def set_storage(
"""
assert trie_get(state._main_trie, address) is not None

# Get old value for tracking
old_value = None
if state._state_tracker is not None and state._state_tracker.track_writes:
trie = state._storage_tries.get(address)
if trie is None:
old_value = U256(0)
else:
old_value = trie_get(trie, key)

trie = state._storage_tries.get(address)
if trie is None:
trie = Trie(secured=True, default=U256(0))
state._storage_tries[address] = trie
trie_set(trie, key, value)
if trie._data == {}:
del state._storage_tries[address]

# Log write access if tracking enabled
if state._state_tracker is not None and state._state_tracker.track_writes:
from .state_tracking import log_state_access, STORAGE_WRITE
log_state_access(
state,
STORAGE_WRITE,
address,
key=key,
value_before=old_value,
value_after=value
)


def storage_root(state: State, address: Address) -> Root:
Expand Down
135 changes: 135 additions & 0 deletions src/ethereum/cancun/state_tracking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple, Union

from ethereum_types.bytes import Bytes32
from ethereum_types.numeric import U256

from .fork_types import Account, Address

# State access types for tracking
ACCOUNT_READ = "account_read"
ACCOUNT_WRITE = "account_write"
STORAGE_READ = "storage_read"
STORAGE_WRITE = "storage_write"


@dataclass
class StateAccess:
"""Record of a single state access for proof generation."""
access_type: str
address: Address
key: Optional[Bytes32] = None
value_before: Optional[Union[Account, U256]] = None
value_after: Optional[Union[Account, U256]] = None


@dataclass
class StateTracker:
"""Tracks state access for merkle proof generation."""
accesses: List[StateAccess] = field(default_factory=list)
main_trie_accessed_keys: Set[Address] = field(default_factory=set)
storage_accessed_keys: Dict[Address, Set[Bytes32]] = field(default_factory=dict)
track_reads: bool = True
track_writes: bool = True


def enable_state_tracking(
state,
track_reads: bool = True,
track_writes: bool = True
) -> None:
"""
Enable state tracking on a State object.

Parameters
----------
state : State
The state to enable tracking on
track_reads : bool
Whether to track read operations
track_writes : bool
Whether to track write operations
"""
state._state_tracker = StateTracker(
track_reads=track_reads,
track_writes=track_writes
)


def disable_state_tracking(state) -> None:
"""
Disable state tracking on a State object.

Parameters
----------
state : State
The state to disable tracking on
"""
state._state_tracker = None


def log_state_access(
state,
access_type: str,
address: Address,
key: Optional[Bytes32] = None,
value_before: Optional[Union[Account, U256]] = None,
value_after: Optional[Union[Account, U256]] = None,
) -> None:
"""
Log a state access if tracking is enabled.

Parameters
----------
state : State
The state (with potential tracker)
access_type : str
Type of access (ACCOUNT_READ, ACCOUNT_WRITE, etc.)
address : Address
Address being accessed
key : Optional[Bytes32]
Storage key (for storage operations)
value_before : Optional[Union[Account, U256]]
Value before the operation
value_after : Optional[Union[Account, U256]]
Value after the operation
"""
if state._state_tracker is None:
return

tracker = state._state_tracker
access = StateAccess(
access_type=access_type,
address=address,
key=key,
value_before=value_before,
value_after=value_after,
)
tracker.accesses.append(access)

if access_type in [ACCOUNT_READ, ACCOUNT_WRITE]:
tracker.main_trie_accessed_keys.add(address)
elif access_type in [STORAGE_READ, STORAGE_WRITE]:
if address not in tracker.storage_accessed_keys:
tracker.storage_accessed_keys[address] = set()
if key is not None:
tracker.storage_accessed_keys[address].add(key)

# Dummy method
def generate_merkle_proof_requests(state) -> Tuple[List[Address], List[Tuple[Address, Bytes32]]]:
"""
Comment on lines +119 to +120
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left this here to show what the next addition may look like -- will delete

Generate lists of proof requests needed for all tracked accesses.

Parameters
----------
state : State
The state containing tracking logs

Returns
-------
account_proofs : List[Address]
List of addresses needing account proofs
storage_proofs : List[Tuple[Address, Bytes32]]
List of (address, storage_key) tuples needing storage proofs
"""
return [], []
Loading