Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactor safe client to use multichain gateway #74

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions ape_safe/_cli/safe_mgmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from eth_typing import ChecksumAddress

from ape_safe._cli.click_ext import SafeCliContext, safe_argument, safe_cli_ctx
from ape_safe.client import SafeClient


@click.command(name="list")
Expand Down Expand Up @@ -144,11 +145,10 @@ def all_txns(cli_ctx: SafeCliContext, account, confirmed):
if account in cli_ctx.account_manager.aliases:
account = cli_ctx.account_manager.load(account)

address = cli_ctx.conversion_manager.convert(account, AddressType)

# NOTE: Create a client to support non-local safes.
client = cli_ctx.safes.create_client(address)

address = cli_ctx.conversion_manager.convert(account, AddressType)
chain_id = cli_ctx.provider.chain_id
client = SafeClient(address=address, chain_id=chain_id)

for txn in client.get_transactions(confirmed=confirmed):
if isinstance(txn, ExecutedTxData):
success_str = "success" if txn.is_successful else "revert"
Expand Down
63 changes: 22 additions & 41 deletions ape_safe/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,27 +153,6 @@ def delete_account(self, alias: str):
"""
self._get_path(alias).unlink(missing_ok=True)

def create_client(self, key: str) -> BaseSafeClient:
if key in self.aliases:
safe = self.load_account(key)
return safe.client

elif key in self.addresses:
account = cast(SafeAccount, self[cast(AddressType, key)])
return account.client

elif key in self.aliases:
return self.load_account(key).client

else:
address = self.conversion_manager.convert(key, AddressType)
if address in self.addresses:
account = cast(SafeAccount, self[cast(AddressType, key)])
return account.client

# Is not locally managed.
return SafeClient(address=address, chain_id=self.chain_manager.provider.chain_id)

def _get_path(self, alias: str) -> Path:
return self.data_folder.joinpath(f"{alias}.json")

Expand Down Expand Up @@ -207,6 +186,10 @@ def alias(self) -> str:
def account_file(self) -> dict:
return json.loads(self.account_file_path.read_text())

@property
def deployed_chain_ids(self) -> list[int]:
return self.account_file.get("deployed_chain_ids", [])

@property
def address(self) -> AddressType:
try:
Expand Down Expand Up @@ -253,34 +236,32 @@ def fallback_handler(self) -> Optional["ContractInstance"]:
self.chain_manager.contracts.instance_at(address) if address != ZERO_ADDRESS else None
)

@cached_property
def client(self) -> BaseSafeClient:
chain_id = self.provider.chain_id
override_url = os.environ.get("SAFE_TRANSACTION_SERVICE_URL")
def get_client(
self, chain_id: Optional[int] = None, override_url: Optional[str] = None
) -> BaseSafeClient:
if chain_id is None:
chain_id = self.provider.chain_id

if override_url is None:
env_override = os.environ.get("SAFE_TRANSACTION_SERVICE_URL")
if env_override:
override_url = env_override

if self.provider.network.is_local:
if chain_id == 0 or (self.provider.network.is_local and self.provider.chain_id == chain_id):
return MockSafeClient(contract=self.contract)

elif chain_id in self.account_file["deployed_chain_ids"]:
return SafeClient(
address=self.address, chain_id=self.provider.chain_id, override_url=override_url
)
return SafeClient(address=self.address, chain_id=chain_id, override_url=override_url)

elif (
@cached_property
def client(self) -> BaseSafeClient:
if (
self.provider.network.name.endswith("-fork")
and isinstance(self.provider.network, ForkedNetworkAPI)
and self.provider.network.upstream_chain_id in self.account_file["deployed_chain_ids"]
and self.provider.network.upstream_chain_id in self.deployed_chain_ids
):
return SafeClient(
address=self.address,
chain_id=self.provider.network.upstream_chain_id,
override_url=override_url,
)

elif self.provider.network.is_dev:
return MockSafeClient(contract=self.contract)
return self.get_client(chain_id=self.provider.network.upstream_chain_id)

return SafeClient(address=self.address, chain_id=self.provider.chain_id)
return self.get_client()

@property
def version(self) -> Version:
Expand Down
34 changes: 8 additions & 26 deletions ape_safe/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,8 @@
ORIGIN = json.dumps(dict(url="https://apeworx.io", name="Ape Safe", ua=APE_SAFE_USER_AGENT))
assert len(ORIGIN) <= 200 # NOTE: Must be less than 200 chars

TRANSACTION_SERVICE_URL = {
# NOTE: If URLs need to be updated, a list of available service URLs can be found at
# https://docs.safe.global/safe-core-api/available-services.
# NOTE: There should be no trailing slashes at the end of the URL.
1: "https://safe-transaction-mainnet.safe.global",
10: "https://safe-transaction-optimism.safe.global",
56: "https://safe-transaction-bsc.safe.global",
100: "https://safe-transaction-gnosis-chain.safe.global",
137: "https://safe-transaction-polygon.safe.global",
250: "https://safe-txservice.fantom.network",
288: "https://safe-transaction.mainnet.boba.network",
8453: "https://safe-transaction-base.safe.global",
42161: "https://safe-transaction-arbitrum.safe.global",
43114: "https://safe-transaction-avalanche.safe.global",
84531: "https://safe-transaction-base-testnet.safe.global",
11155111: "https://safe-transaction-sepolia.safe.global",
81457: "https://transaction.blast-safe.io",
}
# URL for the multichain client gateway
SAFE_CLIENT_GATEWAY_URL = "https://safe-client.safe.global"


class SafeClient(BaseSafeClient):
Expand All @@ -63,20 +47,18 @@ def __init__(
chain_id: Optional[int] = None,
) -> None:
self.address = address
self.chain_id = chain_id

if override_url:
tx_service_url = override_url

base_url = override_url
self.use_client_gateway = False
elif chain_id:
if chain_id not in TRANSACTION_SERVICE_URL:
raise ClientUnsupportedChainError(chain_id)

tx_service_url = TRANSACTION_SERVICE_URL[chain_id]

base_url = SAFE_CLIENT_GATEWAY_URL
self.use_client_gateway = True
else:
raise ValueError("Must provide one of chain_id or override_url.")

super().__init__(tx_service_url)
super().__init__(base_url)

@property
def safe_details(self) -> SafeDetails:
Expand Down
22 changes: 16 additions & 6 deletions ape_safe/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@


class BaseSafeClient(ABC):
def __init__(self, transaction_service_url: str):
self.transaction_service_url = transaction_service_url
def __init__(self, base_url: str):
self.base_url = base_url

"""Abstract methods"""

Expand Down Expand Up @@ -139,13 +139,23 @@ def _http(self):
return urllib3.PoolManager(ca_certs=certifi.where())

def _request(self, method: str, url: str, json: Optional[dict] = None, **kwargs) -> "Response":
api_version = kwargs.pop("api_version", "v1")

# NOTE: paged requests include full url already
if url.startswith(f"{self.transaction_service_url}/api/v1/"):
if url.startswith(f"{self.base_url}/"):
api_url = url
else:
# **WARNING**: The trailing slash in the URL is CRITICAL!
# If you remove it, things will not work as expected.
api_url = f"{self.transaction_service_url}/api/v1/{url}/"
if (
hasattr(self, "use_client_gateway")
and self.use_client_gateway
and hasattr(self, "chain_id")
):
# **WARNING**: The trailing slash in the URL is CRITICAL!
# If you remove it, things will not work as expected.
api_url = f"{self.base_url}/{api_version}/chains/{self.chain_id}/{url}/"
else:
api_url = f"{self.base_url}/api/v1/{url}/"

do_fail = not kwargs.pop("allow_failure", False)

# Use `or 10` to handle when None is explicit.
Expand Down
5 changes: 3 additions & 2 deletions ape_safe/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,9 @@ class SafeClientException(ApeSafeException):


class ClientUnsupportedChainError(SafeClientException):
def __init__(self, chain_id: int):
super().__init__(f"Unsupported Chain ID '{chain_id}'.")
def __init__(self, chain_id: int, message: Optional[str] = None):
msg = message or f"Unsupported Chain ID '{chain_id}'."
super().__init__(msg)


class ActionNotPerformedError(SafeClientException):
Expand Down
37 changes: 37 additions & 0 deletions tests/functional/test_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,40 @@ def test_safe_account_convert(safe):
convert = safe.conversion_manager.convert
actual = convert(safe, AddressType)
assert actual == safe.address


def test_deployed_chain_ids(safe):
"""
Test the deployed_chain_ids property.
"""
# The deployed_chain_ids should match what's in the account file
assert safe.deployed_chain_ids == safe.account_file.get("deployed_chain_ids", [])


def test_get_client(safe, monkeypatch):
"""
Test getting a client for a specific chain ID.
"""
# Create a test deployed_chain_ids list
test_chain_ids = [1, 10, 100]
monkeypatch.setattr(safe, "deployed_chain_ids", test_chain_ids)

# Should work for a specified chain ID
client = safe.get_client(chain_id=1)
assert client.chain_id == 1
assert hasattr(client, "use_client_gateway")
assert client.use_client_gateway is True

# Should use current chain if no chain ID is provided
current_chain_id = safe.provider.chain_id
client = safe.get_client()
assert client.chain_id == current_chain_id

# Should accept override_url
client = safe.get_client(override_url="https://example.com")
assert not client.use_client_gateway

# Should accept both chain_id and override_url
client = safe.get_client(chain_id=1, override_url="https://example.com")
assert client.chain_id == 1
assert not client.use_client_gateway
Loading