diff --git a/gas-snapshots/ModularAccount.json b/gas-snapshots/ModularAccount.json index bc4ead06..c364f83e 100644 --- a/gas-snapshots/ModularAccount.json +++ b/gas-snapshots/ModularAccount.json @@ -1,16 +1,16 @@ { "Runtime_AccountCreation": "176053", "Runtime_BatchTransfers": "92180", - "Runtime_Erc20Transfer": "77942", - "Runtime_InstallSessionKey_Case1": "429401", - "Runtime_NativeTransfer": "54261", - "Runtime_UseSessionKey_Case1_Counter": "78463", - "Runtime_UseSessionKey_Case1_Token": "111478", + "Runtime_Erc20Transfer": "77964", + "Runtime_InstallSessionKey_Case1": "429379", + "Runtime_NativeTransfer": "54283", + "Runtime_UseSessionKey_Case1_Counter": "78485", + "Runtime_UseSessionKey_Case1_Token": "111500", "UserOp_BatchTransfers": "178933", - "UserOp_Erc20Transfer": "165901", - "UserOp_InstallSessionKey_Case1": "518400", - "UserOp_NativeTransfer": "142174", - "UserOp_UseSessionKey_Case1_Counter": "175309", - "UserOp_UseSessionKey_Case1_Token": "208402", - "UserOp_deferredValidation": "214001" + "UserOp_Erc20Transfer": "165923", + "UserOp_InstallSessionKey_Case1": "518378", + "UserOp_NativeTransfer": "142196", + "UserOp_UseSessionKey_Case1_Counter": "175331", + "UserOp_UseSessionKey_Case1_Token": "208424", + "UserOp_deferredValidation": "213793" } \ No newline at end of file diff --git a/gas-snapshots/SemiModularAccount.json b/gas-snapshots/SemiModularAccount.json index e04f03f5..90fb68d5 100644 --- a/gas-snapshots/SemiModularAccount.json +++ b/gas-snapshots/SemiModularAccount.json @@ -1,16 +1,16 @@ { "Runtime_AccountCreation": "97770", - "Runtime_BatchTransfers": "88037", - "Runtime_Erc20Transfer": "73844", - "Runtime_InstallSessionKey_Case1": "427632", - "Runtime_NativeTransfer": "50173", - "Runtime_UseSessionKey_Case1_Counter": "78765", - "Runtime_UseSessionKey_Case1_Token": "111780", - "UserOp_BatchTransfers": "174087", - "UserOp_Erc20Transfer": "161123", - "UserOp_InstallSessionKey_Case1": "515729", - "UserOp_NativeTransfer": "137402", - "UserOp_UseSessionKey_Case1_Counter": "175540", - "UserOp_UseSessionKey_Case1_Token": "208657", - "UserOp_deferredValidation": "210031" + "Runtime_BatchTransfers": "88048", + "Runtime_Erc20Transfer": "73877", + "Runtime_InstallSessionKey_Case1": "427621", + "Runtime_NativeTransfer": "50206", + "Runtime_UseSessionKey_Case1_Counter": "78787", + "Runtime_UseSessionKey_Case1_Token": "111802", + "UserOp_BatchTransfers": "174076", + "UserOp_Erc20Transfer": "161134", + "UserOp_InstallSessionKey_Case1": "515696", + "UserOp_NativeTransfer": "137413", + "UserOp_UseSessionKey_Case1_Counter": "175562", + "UserOp_UseSessionKey_Case1_Token": "208679", + "UserOp_deferredValidation": "210190" } \ No newline at end of file diff --git a/gas/modular-account/ModularAccount.gas.t.sol b/gas/modular-account/ModularAccount.gas.t.sol index b781acf3..a2aa577b 100644 --- a/gas/modular-account/ModularAccount.gas.t.sol +++ b/gas/modular-account/ModularAccount.gas.t.sol @@ -269,14 +269,10 @@ contract ModularAccountGasTest is ModularAccountBenchmarkBase("ModularAccount") uint48 deferredInstallDeadline = 0; bytes32 digest = _getDeferredInstallStruct( - account1, userOp.nonce, deferredInstallDeadline, deferredValidationInstallCall + account1, userOp.nonce, signerValidation, deferredInstallDeadline, deferredValidationInstallCall ); - bytes memory deferredValidationSig = _signRawHash( - vm, - owner1Key, - _getModuleReplaySafeHash(address(account1), address(singleSignerValidationModule), digest) - ); + bytes memory deferredValidationSig = _signRawHash(vm, owner1Key, digest); userOp.signature = _encodeDeferredInstallUOSignature( _packDeferredInstallData( diff --git a/gas/modular-account/SemiModularAccount.gas.t.sol b/gas/modular-account/SemiModularAccount.gas.t.sol index 6bac403f..58488216 100644 --- a/gas/modular-account/SemiModularAccount.gas.t.sol +++ b/gas/modular-account/SemiModularAccount.gas.t.sol @@ -263,7 +263,7 @@ contract ModularAccountGasTest is ModularAccountBenchmarkBase("SemiModularAccoun uint48 deferredInstallDeadline = 0; bytes32 digest = _getDeferredInstallStruct( - account1, userOp.nonce, deferredInstallDeadline, deferredValidationInstallCall + account1, userOp.nonce, signerValidation, deferredInstallDeadline, deferredValidationInstallCall ); bytes memory deferredValidationSig = _signRawHash(vm, owner1Key, digest); diff --git a/src/account/ModularAccountBase.sol b/src/account/ModularAccountBase.sol index 0c2f733d..c051f0d5 100644 --- a/src/account/ModularAccountBase.sol +++ b/src/account/ModularAccountBase.sol @@ -91,9 +91,13 @@ abstract contract ModularAccountBase is EITHER } - // keccak256("EIP712Domain(uint256 chainId,address verifyingContract)") + // keccak256("EIP712Domain(uint256 chainId,address verifyingContract,bytes32 salt)") bytes32 internal constant _DOMAIN_SEPARATOR_TYPEHASH = - 0x47e79534a245952e8b16893a336b85a3d9ea9fa8c573f3d803afb92a79469218; + 0x71062c282d40422f744945d587dbf4ecfd4f9cfad1d35d62c944373009d96162; + + // keccak256("ReplaySafeHash(bytes32 hash)") + bytes32 private constant _REPLAY_SAFE_HASH_TYPEHASH = + 0x294a8735843d4afb4f017c76faf3b7731def145ed0025fc9b1d5ce30adf113ff; // keccak256("DeferredAction(uint256 nonce,uint48 deadline,bytes call)") bytes32 internal constant _DEFERRED_ACTION_TYPEHASH = @@ -358,7 +362,7 @@ abstract contract ModularAccountBase is (ValidationLocator locator, bytes calldata signatureRemainder) = ValidationLocatorLib.loadFromSignature(signature); - return _isValidSignature(locator.lookupKey(), hash, signatureRemainder); + return _isValidSignature(locator, hash, signatureRemainder); } /// @inheritdoc IERC165 @@ -396,6 +400,20 @@ abstract contract ModularAccountBase is super.upgradeToAndCall(newImplementation, data); } + /// @notice Returns the replay-safe hash generated from the passed typed data hash for 1271 validation. + /// @param hash The typed data hash to wrap in a replay-safe hash. + /// @return The replay-safe hash, to be used for 1271 signature generation. + /// + /// @dev Generates a replay-safe hash to wrap a standard typed data hash. This prevents replay attacks by + /// enforcing the domain separator, which includes this contract's address, the chainId, and the validation + /// module & entity id. This is only relevant for 1271 validation because UserOp validation relies on the UO + /// hash and the Entrypoint has safeguards. + function replaySafeHash(bytes32 hash, ModuleEntity validationModuleEntity) public view returns (bytes32) { + return MessageHashUtils.toTypedDataHash({ + domainSeparator: _domainSeparator(validationModuleEntity), structHash: _hashStructReplaySafeHash(hash) + }); + } + // INTERNAL FUNCTIONS // Parent function validateUserOp enforces that this call can only be made by the EntryPoint @@ -503,7 +521,12 @@ abstract contract ModularAccountBase is uint48 deadline = uint48(bytes6(encodedData[21:27])); - bytes32 typedDataHash = _computeDeferredActionHash(userOpNonce, deadline, encodedData[27:]); + bytes32 typedDataHash = _computeDeferredActionHash( + userOpNonce, + defActionValidationLocator.lookupKey().moduleEntity(_validationStorage), + deadline, + encodedData[27:] + ); // Check if the outer validation applies to the function call _checkIfValidationAppliesCallData( @@ -734,11 +757,12 @@ abstract contract ModularAccountBase is ExecutionLib.invokeRuntimeCallBufferValidation(callBuffer, runtimeValidationFunction, authorization); } - function _computeDeferredActionHash(uint256 userOpNonce, uint48 deadline, bytes calldata selfCall) - internal - view - returns (bytes32) - { + function _computeDeferredActionHash( + uint256 userOpNonce, + ModuleEntity validationModuleEntity, + uint48 deadline, + bytes calldata selfCall + ) internal view returns (bytes32) { // Note: // - A zero deadline translates to "no deadline" // - The user op nonce also includes the data for: @@ -781,7 +805,8 @@ abstract contract ModularAccountBase is structHash := keccak256(fmp, 0x80) } - bytes32 typedDataHash = MessageHashUtils.toTypedDataHash(_domainSeparator(), structHash); + bytes32 typedDataHash = + MessageHashUtils.toTypedDataHash(_domainSeparator(validationModuleEntity), structHash); return typedDataHash; } @@ -805,15 +830,21 @@ abstract contract ModularAccountBase is } } - function _isValidSignature(ValidationLookupKey validationLookupKey, bytes32 hash, bytes calldata signature) + function _isValidSignature(ValidationLocator validationLocator, bytes32 hash, bytes calldata signature) internal view returns (bytes4) { + ValidationLookupKey validationLookupKey = validationLocator.lookupKey(); ValidationStorage storage _validationStorage = getAccountStorage().validationStorage[validationLookupKey]; HookConfig[] memory preSignatureValidationHooks = MemManagementLib.loadValidationHooks(_validationStorage); + if (!validationLocator.isSkipReplayProtection()) { + ModuleEntity validationModuleEntity = validationLookupKey.moduleEntity(_validationStorage); + hash = replaySafeHash(hash, validationModuleEntity); + } + SigCallBuffer sigCallBuffer; if (!_validationIsNative(validationLookupKey) || preSignatureValidationHooks.length > 0) { sigCallBuffer = ExecutionLib.allocateSigCallBuffer(hash, signature); @@ -1104,7 +1135,7 @@ abstract contract ModularAccountBase is return getAccountStorage().validationStorage[validationFunction].selectors.contains(toSetValue(selector)); } - function _domainSeparator() internal view returns (bytes32) { + function _domainSeparator(ModuleEntity validationModuleEntity) internal view returns (bytes32) { bytes32 result; // Compute the hash without permanently allocating memory @@ -1113,12 +1144,26 @@ abstract contract ModularAccountBase is mstore(fmp, _DOMAIN_SEPARATOR_TYPEHASH) mstore(add(fmp, 0x20), chainid()) mstore(add(fmp, 0x40), address()) - result := keccak256(fmp, 0x60) + mstore(add(fmp, 0x60), validationModuleEntity) + result := keccak256(fmp, 0x80) } return result; } + /// @notice Adds a EIP-712 replay safe hash wrapper to the digest + /// @param hash The hash to wrap in a replay-safe hash + /// @return The replay-safe hash + function _hashStructReplaySafeHash(bytes32 hash) internal pure virtual returns (bytes32) { + bytes32 res; + assembly ("memory-safe") { + mstore(0x00, _REPLAY_SAFE_HASH_TYPEHASH) + mstore(0x20, hash) + res := keccak256(0, 0x40) + } + return res; + } + // A virtual function to detect if a validation function is natively implemented. Used for determining call // buffer allocation. function _validationIsNative(ValidationLookupKey) internal pure virtual returns (bool) { diff --git a/src/account/SemiModularAccountBase.sol b/src/account/SemiModularAccountBase.sol index 5d03cae9..c579a9d7 100644 --- a/src/account/SemiModularAccountBase.sol +++ b/src/account/SemiModularAccountBase.sol @@ -48,10 +48,6 @@ abstract contract SemiModularAccountBase is ModularAccountBase { bool fallbackSignerDisabled; } - // keccak256("ReplaySafeHash(bytes32 hash)") - bytes32 private constant _REPLAY_SAFE_HASH_TYPEHASH = - 0x294a8735843d4afb4f017c76faf3b7731def145ed0025fc9b1d5ce30adf113ff; - // keccak256("ERC6900.SemiModularAccount.Storage") uint256 internal constant _SEMI_MODULAR_ACCOUNT_STORAGE_SLOT = 0x5b9dc9aa943f8fa2653ceceda5e3798f0686455280432166ba472eca0bc17a32; @@ -151,12 +147,6 @@ abstract contract SemiModularAccountBase is ModularAccountBase { if (validationLookupKey.eq(FALLBACK_VALIDATION_LOOKUP_KEY)) { address fallbackSigner = _getFallbackSigner(); - // If called during validateUserOp, this implies that we're doing a deferred validation installation. - // In this case, as the hash is already replay-safe, we don't need to wrap it. - if (msg.sig != this.validateUserOp.selector) { - hash = _replaySafeHash(hash); - } - if (_checkSignature(fallbackSigner, hash, signature)) { return _1271_MAGIC_VALUE; } @@ -233,23 +223,6 @@ abstract contract SemiModularAccountBase is ModularAccountBase { return _storage.fallbackSigner; } - /// @notice Returns the replay-safe hash generated from the passed typed data hash for 1271 validation. - /// @param hash The typed data hash to wrap in a replay-safe hash. - /// @return The replay-safe hash, to be used for 1271 signature generation. - /// - /// @dev Generates a replay-safe hash to wrap a standard typed data hash. This prevents replay attacks by - /// enforcing the domain separator, which includes this contract's address and the chainId. This is only - /// relevant for 1271 validation because UserOp validation relies on the UO hash and the Entrypoint has - /// safeguards. - /// - /// NOTE: Like in signature-based validation modules, the returned hash should be used to generate signatures, - /// but the original hash should be passed to the external-facing function for 1271 validation. - function _replaySafeHash(bytes32 hash) internal view returns (bytes32) { - return MessageHashUtils.toTypedDataHash({ - domainSeparator: _domainSeparator(), structHash: _hashStructReplaySafeHash(hash) - }); - } - function _getSemiModularAccountStorage() internal pure returns (SemiModularAccountStorage storage) { SemiModularAccountStorage storage _storage; assembly ("memory-safe") { @@ -269,19 +242,6 @@ abstract contract SemiModularAccountBase is ModularAccountBase { return validationLookupKey.eq(FALLBACK_VALIDATION_LOOKUP_KEY); } - /// @notice Adds a EIP-712 replay safe hash wrapper to the digest - /// @param hash The hash to wrap in a replay-safe hash - /// @return The replay-safe hash - function _hashStructReplaySafeHash(bytes32 hash) internal pure virtual returns (bytes32) { - bytes32 res; - assembly ("memory-safe") { - mstore(0x00, _REPLAY_SAFE_HASH_TYPEHASH) - mstore(0x20, hash) - res := keccak256(0, 0x40) - } - return res; - } - /// @dev Overrides ModularAccountView. function _isNativeFunction(uint32 selector) internal pure virtual override returns (bool) { return super._isNativeFunction(selector) || selector == uint32(this.updateFallbackSignerData.selector) diff --git a/src/libraries/ValidationLocatorLib.sol b/src/libraries/ValidationLocatorLib.sol index 2c8a05b7..70eb196b 100644 --- a/src/libraries/ValidationLocatorLib.sol +++ b/src/libraries/ValidationLocatorLib.sol @@ -56,6 +56,7 @@ library ValidationLocatorLib { uint8 internal constant _VALIDATION_TYPE_GLOBAL = 1; uint8 internal constant _HAS_DEFERRED_ACTION = 2; uint8 internal constant _IS_DIRECT_CALL_VALIDATION = 4; + uint8 internal constant _IS_SKIP_REPLAY_PROTECTION = 8; function moduleEntity(ValidationLookupKey _lookupKey, ValidationStorage storage validationStorage) internal @@ -198,6 +199,10 @@ library ValidationLocatorLib { return (ValidationLocator.unwrap(locator) & _HAS_DEFERRED_ACTION) != 0; } + function isSkipReplayProtection(ValidationLocator locator) internal pure returns (bool) { + return (ValidationLocator.unwrap(locator) & _IS_SKIP_REPLAY_PROTECTION) != 0; + } + function isDirectCallValidation(ValidationLookupKey _lookupKey) internal pure returns (bool) { return (ValidationLookupKey.unwrap(_lookupKey) & _IS_DIRECT_CALL_VALIDATION) != 0; } @@ -318,6 +323,11 @@ library ValidationLocatorLib { return bytes.concat(abi.encodePacked(options, uint32(validationEntityId)), signature); } + function setSkipReplayProtection(bytes memory signature) internal pure returns (bytes memory result) { + signature[0] = bytes1(uint8(signature[0]) | _IS_SKIP_REPLAY_PROTECTION); + return signature; + } + function packSignatureDirectCall( address directCallValidation, bool _isGlobal, diff --git a/src/modules/validation/SingleSignerValidationModule.sol b/src/modules/validation/SingleSignerValidationModule.sol index 6ba5f933..9f94d5c5 100644 --- a/src/modules/validation/SingleSignerValidationModule.sol +++ b/src/modules/validation/SingleSignerValidationModule.sol @@ -19,7 +19,6 @@ pragma solidity ^0.8.28; import {IModule} from "@erc6900/reference-implementation/interfaces/IModule.sol"; import {IValidationModule} from "@erc6900/reference-implementation/interfaces/IValidationModule.sol"; -import {ReplaySafeWrapper} from "@erc6900/reference-implementation/modules/ReplaySafeWrapper.sol"; import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; import {IERC165} from "@openzeppelin/contracts/interfaces/IERC165.sol"; import {ECDSA} from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol"; @@ -41,7 +40,7 @@ import {ModuleBase} from "../ModuleBase.sol"; /// - This validation supports ERC-1271. The signature is valid if it is signed by the owner's private key. /// - This validation supports composition that other validation can relay on entities in this validation to /// validate partially or fully. -contract SingleSignerValidationModule is IValidationModule, ReplaySafeWrapper, ModuleBase { +contract SingleSignerValidationModule is IValidationModule, ModuleBase { uint256 internal constant _SIG_VALIDATION_PASSED = 0; uint256 internal constant _SIG_VALIDATION_FAILED = 1; @@ -124,8 +123,7 @@ contract SingleSignerValidationModule is IValidationModule, ReplaySafeWrapper, M override returns (bytes4) { - bytes32 _replaySafeHash = replaySafeHash(account, digest); - if (_checkSig(signers[entityId][account], _replaySafeHash, signature)) { + if (_checkSig(signers[entityId][account], digest, signature)) { return _1271_MAGIC_VALUE; } return _1271_INVALID; diff --git a/src/modules/validation/WebAuthnValidationModule.sol b/src/modules/validation/WebAuthnValidationModule.sol index 2dfdb502..c911401a 100644 --- a/src/modules/validation/WebAuthnValidationModule.sol +++ b/src/modules/validation/WebAuthnValidationModule.sol @@ -19,7 +19,6 @@ pragma solidity ^0.8.28; import {IModule} from "@erc6900/reference-implementation/interfaces/IModule.sol"; import {IValidationModule} from "@erc6900/reference-implementation/interfaces/IValidationModule.sol"; -import {ReplaySafeWrapper} from "@erc6900/reference-implementation/modules/ReplaySafeWrapper.sol"; import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interfaces/PackedUserOperation.sol"; import {IERC165} from "@openzeppelin/contracts/interfaces/IERC165.sol"; import {WebAuthn} from "webauthn-sol/src/WebAuthn.sol"; @@ -38,7 +37,7 @@ import {ModuleBase} from "../ModuleBase.sol"; /// - This validation supports ERC-1271. The signature is valid if it is signed by the owner's private key. /// - This validation supports composition that other validation can relay on entities in this validation to /// validate partially or fully. -contract WebAuthnValidationModule is IValidationModule, ReplaySafeWrapper, ModuleBase { +contract WebAuthnValidationModule is IValidationModule, ModuleBase { using WebAuthn for WebAuthn.WebAuthnAuth; struct PubKey { @@ -120,8 +119,7 @@ contract WebAuthnValidationModule is IValidationModule, ReplaySafeWrapper, Modul override returns (bytes4) { - bytes32 _replaySafeHash = replaySafeHash(account, digest); - if (_validateSignature(entityId, account, _replaySafeHash, signature)) { + if (_validateSignature(entityId, account, digest, signature)) { return _1271_MAGIC_VALUE; } return _1271_INVALID; diff --git a/test/account/ModularAccount.t.sol b/test/account/ModularAccount.t.sol index a428bd1a..12eb0673 100644 --- a/test/account/ModularAccount.t.sol +++ b/test/account/ModularAccount.t.sol @@ -40,6 +40,7 @@ import {ModularAccountBase} from "../../src/account/ModularAccountBase.sol"; import {SemiModularAccountBytecode} from "../../src/account/SemiModularAccountBytecode.sol"; import {ExecutionInstallDelegate} from "../../src/helpers/ExecutionInstallDelegate.sol"; import {ModuleInstallCommonsLib} from "../../src/libraries/ModuleInstallCommonsLib.sol"; +import {ValidationLocatorLib} from "../../src/libraries/ValidationLocatorLib.sol"; import {SingleSignerValidationModule} from "../../src/modules/validation/SingleSignerValidationModule.sol"; import {Counter} from "../mocks/Counter.sol"; @@ -446,9 +447,14 @@ contract ModularAccountTest is AccountTestBase { function test_isValidSignature() public withSMATest { bytes32 message = keccak256("hello world"); - bytes32 replaySafeHash = _isSMATest - ? _getSMAReplaySafeHash(address(account1), message) - : _getModuleReplaySafeHash(address(account1), address(singleSignerValidationModule), message); + address validationModule = address(0); + if (!_isSMATest) { + validationModule = address(singleSignerValidationModule); + } + + bytes32 replaySafeHash = _getReplaySafeHash( + address(account1), ModuleEntityLib.pack(validationModule, TEST_DEFAULT_VALIDATION_ENTITY_ID), message + ); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, replaySafeHash); @@ -460,6 +466,25 @@ contract ModularAccountTest is AccountTestBase { assertEq(validationResult, bytes4(0x1626ba7e)); } + function test_isValidSignature_withoutReplaySafeHash() public withSMATest { + bytes32 message = keccak256("hello world"); + + address validationModule = address(0); + if (!_isSMATest) { + validationModule = address(singleSignerValidationModule); + } + + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, message); + + bytes memory signature = + _encode1271Signature(_signerValidation, abi.encodePacked(EOA_TYPE_SIGNATURE, r, s, v)); + ValidationLocatorLib.setSkipReplayProtection(signature); + + bytes4 validationResult = IERC1271(address(account1)).isValidSignature(message, signature); + + assertEq(validationResult, bytes4(0x1626ba7e)); + } + // Only need a test case for the negative case, as the positive case is covered by the isValidSignature test function test_signatureValidationFlag_enforce() public withSMATest { // Install a new copy of SingleSignerValidationModule with the signature validation flag set to false @@ -473,9 +498,10 @@ contract ModularAccountTest is AccountTestBase { ); bytes32 message = keccak256("hello world"); - bytes32 replaySafeHash = _isSMATest - ? _getSMAReplaySafeHash(address(account1), message) - : _getModuleReplaySafeHash(address(account1), address(singleSignerValidationModule), message); + + bytes32 replaySafeHash = _getReplaySafeHash( + address(account1), ModuleEntityLib.pack(address(singleSignerValidationModule), newEntityId), message + ); (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, replaySafeHash); diff --git a/test/account/PerHookData.t.sol b/test/account/PerHookData.t.sol index ca15a14e..c296b40c 100644 --- a/test/account/PerHookData.t.sol +++ b/test/account/PerHookData.t.sol @@ -27,6 +27,7 @@ import {PackedUserOperation} from "@eth-infinitism/account-abstraction/interface import {ModularAccountBase} from "../../src/account/ModularAccountBase.sol"; import {ExecutionLib} from "../../src/libraries/ExecutionLib.sol"; +import {ValidationLocatorLib} from "../../src/libraries/ValidationLocatorLib.sol"; import {Counter} from "../mocks/Counter.sol"; import {MockAccessControlHookModule} from "../mocks/modules/MockAccessControlHookModule.sol"; @@ -410,21 +411,17 @@ contract PerHookDataTest is CustomValidationTestBase { bytes memory message = "Hello, world!"; bytes32 messageHash = keccak256(message); - // we use module validation for both cases - bytes32 replaySafeHash = - _getModuleReplaySafeHash(address(account1), address(singleSignerValidationModule), messageHash); - - (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, replaySafeHash); + (uint8 v, bytes32 r, bytes32 s) = vm.sign(owner1Key, messageHash); PreValidationHookData[] memory preValidationHookData = new PreValidationHookData[](1); preValidationHookData[0] = PreValidationHookData({index: 0, validationData: message}); - bytes4 result = account1.isValidSignature( - messageHash, - _encode1271Signature( - _signerValidation, preValidationHookData, abi.encodePacked(EOA_TYPE_SIGNATURE, r, s, v) - ) + bytes memory sig = _encode1271Signature( + _signerValidation, preValidationHookData, abi.encodePacked(EOA_TYPE_SIGNATURE, r, s, v) ); + sig = ValidationLocatorLib.setSkipReplayProtection(sig); + + bytes4 result = account1.isValidSignature(messageHash, sig); assertEq(result, bytes4(0x1626ba7e)); } diff --git a/test/account/SigCallBuffer.t.sol b/test/account/SigCallBuffer.t.sol index 82d5855c..a5163382 100644 --- a/test/account/SigCallBuffer.t.sol +++ b/test/account/SigCallBuffer.t.sol @@ -61,11 +61,13 @@ contract SigCallBufferTest is AccountTestBase { _setUp4ValidationHooks(); + bytes32 replaySafeHash = _getReplaySafeHash(address(account1), _validationFunction, hash); + for (uint256 i = 0; i < 3; i++) { vm.expectCall( address(validationHooks[i]), abi.encodeCall( - IValidationHookModule.preSignatureValidationHook, (uint32(i), beneficiary, hash, "") + IValidationHookModule.preSignatureValidationHook, (uint32(i), beneficiary, replaySafeHash, "") ) ); } @@ -79,7 +81,7 @@ contract SigCallBufferTest is AccountTestBase { address(account1), NEW_VALIDATION_ENTITY_ID, beneficiary, - hash, + replaySafeHash, abi.encodePacked(EOA_TYPE_SIGNATURE) ) ) @@ -107,7 +109,9 @@ contract SigCallBufferTest is AccountTestBase { _setUp4ValidationHooks(); - _expectCalls(fuzzConfig, hash); + bytes32 replaySafeHash = _getReplaySafeHash(address(account1), _validationFunction, hash); + + _expectCalls(fuzzConfig, replaySafeHash); PreValidationHookData[] memory preValidationHookDatasToSend = _generatePreHooksDatasArray(fuzzConfig); @@ -127,7 +131,8 @@ contract SigCallBufferTest is AccountTestBase { function testFuzz_sigCallBuffer(bytes32 hash, FuzzConfig memory fuzzConfig) public withSMATest { _installValidationAndAssocHook(fuzzConfig); - _expectCalls(fuzzConfig, hash); + bytes32 replaySafeHash = _getReplaySafeHash(address(account1), _validationFunction, hash); + _expectCalls(fuzzConfig, replaySafeHash); PreValidationHookData[] memory preValidationHookDatasToSend = _generatePreHooksDatasArray(fuzzConfig); diff --git a/test/modules/WebAuthnValidationModule.t.sol b/test/modules/WebAuthnValidationModule.t.sol index c96924c3..09f23285 100644 --- a/test/modules/WebAuthnValidationModule.t.sol +++ b/test/modules/WebAuthnValidationModule.t.sol @@ -25,6 +25,7 @@ import {Utils, WebAuthnInfo} from "webauthn-sol/test/Utils.sol"; import {ModularAccount} from "../../src/account/ModularAccount.sol"; import {ModularAccountBase} from "../../src/account/ModularAccountBase.sol"; +import {ValidationLocatorLib} from "../../src/libraries/ValidationLocatorLib.sol"; import {WebAuthnValidationModule} from "../../src/modules/validation/WebAuthnValidationModule.sol"; import {AccountTestBase} from "../utils/AccountTestBase.sol"; import {CODELESS_ADDRESS} from "../utils/TestConstants.sol"; @@ -59,7 +60,8 @@ contract WebAuthnValidationModuleTest is AccountTestBase { function test_isValidSignature() external view { bytes32 message = keccak256("message"); - bytes32 challenge = module.replaySafeHash(account, message); + bytes32 challenge = + ModularAccountBase(account).replaySafeHash(message, ModuleEntityLib.pack(address(module), entityId)); assertTrue( ModularAccountBase(account).isValidSignature(message, _get1271SigForChallenge(challenge, 0, 0)) @@ -67,9 +69,18 @@ contract WebAuthnValidationModuleTest is AccountTestBase { ); } + function test_isValidSignature_withoutReplaySafeHash() public view { + bytes32 message = keccak256("message"); + bytes memory signature = _get1271SigForChallenge(message, 0, 0); + ValidationLocatorLib.setSkipReplayProtection(signature); + + assertTrue(ModularAccountBase(account).isValidSignature(message, signature) == 0x1626ba7e); + } + // fuzz message function testFuzz_pass_isValidSignature(bytes32 message) public view { - bytes32 challenge = module.replaySafeHash(account, message); + bytes32 challenge = + ModularAccountBase(account).replaySafeHash(message, ModuleEntityLib.pack(address(module), entityId)); assertTrue( ModularAccountBase(account).isValidSignature(message, _get1271SigForChallenge(challenge, 0, 0)) @@ -79,7 +90,8 @@ contract WebAuthnValidationModuleTest is AccountTestBase { // Fuzz sig function testFuzz_fail_isValidSignature(bytes32 message, uint256 sigR, uint256 sigS) external view { - bytes32 challenge = module.replaySafeHash(account, message); + bytes32 challenge = + ModularAccountBase(account).replaySafeHash(message, ModuleEntityLib.pack(address(module), entityId)); // make sure r, s values isn't the right one by accident. checking 1 should be enough WebAuthnInfo memory webAuthn = Utils.getWebAuthnStruct(challenge); diff --git a/test/utils/AccountTestBase.sol b/test/utils/AccountTestBase.sol index 8dd3b041..2ee12d4f 100644 --- a/test/utils/AccountTestBase.sol +++ b/test/utils/AccountTestBase.sol @@ -380,18 +380,10 @@ abstract contract AccountTestBase is OptimizedTest, ModuleSignatureUtils { bytes memory deferredValidationDatas; { bytes32 digest = _getDeferredInstallStruct( - account, userOpNonce, deferredInstallDeadline, deferredValidationInstallCall + account, userOpNonce, _signerValidation, deferredInstallDeadline, deferredValidationInstallCall ); - bytes32 replaySafeHash; - if (_isSMATest) { - replaySafeHash = digest; - } else { - replaySafeHash = - _getModuleReplaySafeHash(address(account), address(singleSignerValidationModule), digest); - } - - deferredValidationSig = _signRawHash(vm, signingKey, replaySafeHash); + deferredValidationSig = _signRawHash(vm, signingKey, digest); deferredValidationDatas = _packDeferredInstallData( deferredInstallDeadline, defActionValidation, deferredValidationInstallCall diff --git a/test/utils/ModuleSignatureUtils.sol b/test/utils/ModuleSignatureUtils.sol index da1e0a4e..e45b8877 100644 --- a/test/utils/ModuleSignatureUtils.sol +++ b/test/utils/ModuleSignatureUtils.sol @@ -60,9 +60,7 @@ contract ModuleSignatureUtils { bytes32 internal constant _REPLAY_SAFE_HASH_TYPEHASH = keccak256("ReplaySafeHash(bytes32 hash)"); - bytes32 internal constant _ACCOUNT_DOMAIN_SEPARATOR = - keccak256("EIP712Domain(uint256 chainId,address verifyingContract)"); - bytes32 internal constant _MODULE_DOMAIN_SEPARATOR = + bytes32 internal constant _DOMAIN_SEPARATOR = keccak256("EIP712Domain(uint256 chainId,address verifyingContract,bytes32 salt)"); constructor() { @@ -210,21 +208,14 @@ contract ModuleSignatureUtils { return abi.encodePacked(EOA_TYPE_SIGNATURE, r, s, v); } - function _getModuleReplaySafeHash(address account, address validationModule, bytes32 digest) + function _getReplaySafeHash(address account, ModuleEntity validationModuleEntity, bytes32 digest) internal view returns (bytes32) { - bytes32 domainSeparator = - keccak256(abi.encode(_MODULE_DOMAIN_SEPARATOR, block.chainid, validationModule, account)); - - return - MessageHashUtils.toTypedDataHash({domainSeparator: domainSeparator, structHash: _hashStruct(digest)}); - } - - function _getSMAReplaySafeHash(address account, bytes32 digest) internal view returns (bytes32) { return MessageHashUtils.toTypedDataHash({ - domainSeparator: _computeDomainSeparator(account), structHash: _hashStruct(digest) + domainSeparator: _computeDomainSeparator(account, validationModuleEntity), + structHash: _hashStruct(digest) }); } @@ -298,10 +289,11 @@ contract ModuleSignatureUtils { function _getDeferredInstallStruct( ModularAccount account, uint256 userOpNonce, + ModuleEntity validationModuleEntity, uint48 deadline, bytes memory selfCall ) internal view returns (bytes32) { - bytes32 domainSeparator = _computeDomainSeparator(address(account)); + bytes32 domainSeparator = _computeDomainSeparator(address(account), validationModuleEntity); bytes32 selfCallHash = keccak256(selfCall); @@ -311,9 +303,14 @@ contract ModuleSignatureUtils { }); } - // EIP-712 helpers for acount - function _computeDomainSeparator(address account) internal view returns (bytes32) { - return keccak256(abi.encode(_ACCOUNT_DOMAIN_SEPARATOR, block.chainid, account)); + // EIP-712 helpers for account + function _computeDomainSeparator(address account, ModuleEntity validationModuleEntity) + internal + view + returns (bytes32) + { + bytes32 ret = keccak256(abi.encode(_DOMAIN_SEPARATOR, block.chainid, account, validationModuleEntity)); + return ret; } // EIP-712 helpers for acount