diff --git a/src/test/validium/FastWithdrawVault.t.sol b/src/test/validium/FastWithdrawVault.t.sol index 1860cce8..92930e3e 100644 --- a/src/test/validium/FastWithdrawVault.t.sol +++ b/src/test/validium/FastWithdrawVault.t.sol @@ -267,7 +267,8 @@ contract FastWithdrawVaultTest is ValidiumTestBase { address(counterpartGateway), address(messenger), address(template), - address(factory) + address(factory), + address(rollup) ) ) ); diff --git a/src/test/validium/L1ERC20GatewayValidium.t.sol b/src/test/validium/L1ERC20GatewayValidium.t.sol index 240909d3..fd6efaed 100644 --- a/src/test/validium/L1ERC20GatewayValidium.t.sol +++ b/src/test/validium/L1ERC20GatewayValidium.t.sol @@ -14,6 +14,7 @@ import {AddressAliasHelper} from "../../libraries/common/AddressAliasHelper.sol" import {IL1ERC20GatewayValidium} from "../../validium/IL1ERC20GatewayValidium.sol"; import {IL2ERC20GatewayValidium} from "../../validium/IL2ERC20GatewayValidium.sol"; import {L1ERC20GatewayValidium} from "../../validium/L1ERC20GatewayValidium.sol"; +import {ScrollChainValidium} from "../../validium/ScrollChainValidium.sol"; import {TransferReentrantToken} from "../mocks/tokens/TransferReentrantToken.sol"; import {FeeOnTransferToken} from "../mocks/tokens/FeeOnTransferToken.sol"; @@ -128,47 +129,63 @@ contract L1ERC20GatewayValidiumTest is ValidiumTestBase { _deposit(sender, amount, recipient, gasLimit); } + function testDepositERC20WrongKey( + uint256 amount, + bytes memory recipient, + uint256 gasLimit + ) public { + (uint256 keyId, ) = rollup.getLatestEncryptionKey(); + hevm.expectRevert(ScrollChainValidium.ErrorUnknownEncryptionKey.selector); + gateway.depositERC20(address(l1Token), recipient, amount, gasLimit, keyId + 1); + } + function testDepositReentrantToken(uint256 amount) public { + (uint256 keyId, ) = rollup.getLatestEncryptionKey(); + // should revert, reentrant before transfer reentrantToken.setReentrantCall( address(gateway), 0, abi.encodeWithSignature( - "depositERC20(address,bytes,uint256,uint256)", + "depositERC20(address,bytes,uint256,uint256,uint256)", address(reentrantToken), new bytes(0), amount, - defaultGasLimit + defaultGasLimit, + keyId ), true ); amount = bound(amount, 1, reentrantToken.balanceOf(address(this))); hevm.expectRevert("ReentrancyGuard: reentrant call"); - gateway.depositERC20(address(reentrantToken), new bytes(0), amount, defaultGasLimit); + + gateway.depositERC20(address(reentrantToken), new bytes(0), amount, defaultGasLimit, keyId); // should revert, reentrant after transfer reentrantToken.setReentrantCall( address(gateway), 0, abi.encodeWithSignature( - "depositERC20(address,bytes,uint256,uint256)", + "depositERC20(address,bytes,uint256,uint256,uint256)", address(reentrantToken), new bytes(0), amount, - defaultGasLimit + defaultGasLimit, + keyId ), false ); amount = bound(amount, 1, reentrantToken.balanceOf(address(this))); hevm.expectRevert("ReentrancyGuard: reentrant call"); - gateway.depositERC20(address(reentrantToken), new bytes(0), amount, defaultGasLimit); + gateway.depositERC20(address(reentrantToken), new bytes(0), amount, defaultGasLimit, keyId); } function testFeeOnTransferTokenFailed(uint256 amount) public { feeToken.setFeeRate(1e9); amount = bound(amount, 1, feeToken.balanceOf(address(this))); + (uint256 keyId, ) = rollup.getLatestEncryptionKey(); hevm.expectRevert(L1ERC20GatewayValidium.ErrorAmountIsZero.selector); - gateway.depositERC20(address(feeToken), new bytes(0), amount, defaultGasLimit); + gateway.depositERC20(address(feeToken), new bytes(0), amount, defaultGasLimit, keyId); } function testFeeOnTransferTokenSucceed(uint256 amount, uint256 feeRate) public { @@ -179,7 +196,8 @@ contract L1ERC20GatewayValidiumTest is ValidiumTestBase { // should succeed, for valid amount uint256 balanceBefore = feeToken.balanceOf(address(gateway)); uint256 fee = (amount * feeRate) / 1e9; - gateway.depositERC20(address(feeToken), new bytes(0), amount, defaultGasLimit); + (uint256 keyId, ) = rollup.getLatestEncryptionKey(); + gateway.depositERC20(address(feeToken), new bytes(0), amount, defaultGasLimit, keyId); uint256 balanceAfter = feeToken.balanceOf(address(gateway)); assertEq(balanceBefore + amount - fee, balanceAfter); } @@ -245,7 +263,8 @@ contract L1ERC20GatewayValidiumTest is ValidiumTestBase { amount = bound(amount, 1, l1Token.balanceOf(address(this))); // deposit some token to L1ERC20GatewayValidium - gateway.depositERC20(address(l1Token), new bytes(0), amount, defaultGasLimit); + (uint256 keyId, ) = rollup.getLatestEncryptionKey(); + gateway.depositERC20(address(l1Token), new bytes(0), amount, defaultGasLimit, keyId); // do finalize withdraw token bytes memory message = abi.encodeWithSelector( @@ -302,7 +321,8 @@ contract L1ERC20GatewayValidiumTest is ValidiumTestBase { amount = bound(amount, 1, l1Token.balanceOf(address(this))); // deposit some token to L1ERC20GatewayValidium - gateway.depositERC20(address(l1Token), new bytes(0), amount, defaultGasLimit); + (uint256 keyId, ) = rollup.getLatestEncryptionKey(); + gateway.depositERC20(address(l1Token), new bytes(0), amount, defaultGasLimit, keyId); // do finalize withdraw token bytes memory message = abi.encodeWithSelector( @@ -385,11 +405,12 @@ contract L1ERC20GatewayValidiumTest is ValidiumTestBase { ); if (amount == 0) { + (uint256 keyId, ) = rollup.getLatestEncryptionKey(); hevm.expectRevert(L1ERC20GatewayValidium.ErrorAmountIsZero.selector); if (from == address(this)) { - gateway.depositERC20(address(l1Token), recipient, amount, gasLimit); + gateway.depositERC20(address(l1Token), recipient, amount, gasLimit, keyId); } else { - gateway.depositERC20(address(l1Token), from, recipient, amount, gasLimit); + gateway.depositERC20(address(l1Token), from, recipient, amount, gasLimit, keyId); } } else { // emit QueueTransaction from L1MessageQueueV2 @@ -412,10 +433,11 @@ contract L1ERC20GatewayValidiumTest is ValidiumTestBase { uint256 gatewayBalance = l1Token.balanceOf(address(gateway)); uint256 feeVaultBalance = address(feeVault).balance; assertEq(l1Messenger.messageSendTimestamp(keccak256(xDomainCalldata)), 0); + (uint256 keyId, ) = rollup.getLatestEncryptionKey(); if (from == address(this)) { - gateway.depositERC20(address(l1Token), recipient, amount, gasLimit); + gateway.depositERC20(address(l1Token), recipient, amount, gasLimit, keyId); } else { - gateway.depositERC20(address(l1Token), from, recipient, amount, gasLimit); + gateway.depositERC20(address(l1Token), from, recipient, amount, gasLimit, keyId); } assertEq(amount + gatewayBalance, l1Token.balanceOf(address(gateway))); assertEq(feeVaultBalance, address(feeVault).balance); @@ -433,7 +455,8 @@ contract L1ERC20GatewayValidiumTest is ValidiumTestBase { address(counterpartGateway), address(messenger), address(template), - address(factory) + address(factory), + address(rollup) ) ) ); diff --git a/src/test/validium/L1WETHGatewayValidium.t.sol b/src/test/validium/L1WETHGatewayValidium.t.sol index 1a61fe5d..53b9ff5b 100644 --- a/src/test/validium/L1WETHGatewayValidium.t.sol +++ b/src/test/validium/L1WETHGatewayValidium.t.sol @@ -108,13 +108,15 @@ contract L1WETHGatewayValidiumTest is ValidiumTestBase { message ); + (uint256 keyId, ) = rollup.getLatestEncryptionKey(); + if (amount == 0) { hevm.expectRevert(L1ERC20GatewayValidium.ErrorAmountIsZero.selector); - wethGateway.deposit(recipient, amount); + wethGateway.deposit(recipient, amount, keyId); } else { // revert when ErrorInsufficientValue hevm.expectRevert(L1WETHGatewayValidium.ErrorInsufficientValue.selector); - wethGateway.deposit{value: amount - 1}(recipient, amount); + wethGateway.deposit{value: amount - 1}(recipient, amount, keyId); // emit QueueTransaction from L1MessageQueueV2 { @@ -137,7 +139,7 @@ contract L1WETHGatewayValidiumTest is ValidiumTestBase { uint256 gatewayBalance = weth.balanceOf(address(gateway)); uint256 feeVaultBalance = address(feeVault).balance; assertEq(l1Messenger.messageSendTimestamp(keccak256(xDomainCalldata)), 0); - wethGateway.deposit{value: amount}(recipient, amount); + wethGateway.deposit{value: amount}(recipient, amount, keyId); assertEq(ethBalance - amount, address(this).balance); assertEq(amount + gatewayBalance, weth.balanceOf(address(gateway))); assertEq(feeVaultBalance, address(feeVault).balance); @@ -155,7 +157,8 @@ contract L1WETHGatewayValidiumTest is ValidiumTestBase { address(counterpartGateway), address(messenger), address(template), - address(factory) + address(factory), + address(rollup) ) ) ); diff --git a/src/test/validium/ValidiumTestBase.t.sol b/src/test/validium/ValidiumTestBase.t.sol index 3cb89861..a48b88ee 100644 --- a/src/test/validium/ValidiumTestBase.t.sol +++ b/src/test/validium/ValidiumTestBase.t.sol @@ -130,6 +130,8 @@ abstract contract ValidiumTestBase is ScrollTestBase { address(new ScrollChainValidium(_chainId, address(messageQueueV2), address(verifier))) ); rollup.initialize(address(this)); + rollup.grantRole(rollup.KEY_MANAGER_ROLE(), address(this)); + rollup.registerNewEncryptionKey(hex"123456789012345678901234567890123456789012345678901234567890123456"); // Make nonzero block.timestamp hevm.warp(1); diff --git a/src/validium/IL1ERC20GatewayValidium.sol b/src/validium/IL1ERC20GatewayValidium.sol index 314e8c55..343b74b6 100644 --- a/src/validium/IL1ERC20GatewayValidium.sol +++ b/src/validium/IL1ERC20GatewayValidium.sol @@ -61,7 +61,8 @@ interface IL1ERC20GatewayValidium { address _token, bytes memory _to, uint256 _amount, - uint256 _gasLimit + uint256 _gasLimit, + uint256 _keyId ) external payable; /// @notice Deposit some token to a recipient's account on L2. @@ -76,7 +77,8 @@ interface IL1ERC20GatewayValidium { address _realSender, bytes memory _to, uint256 _amount, - uint256 _gasLimit + uint256 _gasLimit, + uint256 _keyId ) external payable; /// @notice Complete ERC20 withdraw from L2 to L1 and send fund to recipient's account in L1. diff --git a/src/validium/IScrollChainValidium.sol b/src/validium/IScrollChainValidium.sol index f7013d18..61467d05 100644 --- a/src/validium/IScrollChainValidium.sol +++ b/src/validium/IScrollChainValidium.sol @@ -24,6 +24,12 @@ interface IScrollChainValidium { /// @param withdrawRoot The merkle root on layer2 after this batch. event FinalizeBatch(uint256 indexed batchIndex, bytes32 indexed batchHash, bytes32 stateRoot, bytes32 withdrawRoot); + /// @notice Emitted when a new encryption key is added. + /// @param keyId The incremental index of the key. + /// @param msgIndex The message queue index at the time of key rotation. + /// @param key The encryption key. + event NewEncryptionKey(uint256 indexed keyId, uint256 msgIndex, bytes key); + /************************* * Public View Functions * *************************/ @@ -50,6 +56,14 @@ interface IScrollChainValidium { /// @return Whether the batch is finalized by batch index. function isBatchFinalized(uint256 batchIndex) external view returns (bool); + /// @return The key-id of the latest encryption key. + /// @return The latest encryption key. + function getLatestEncryptionKey() external view returns (uint256, bytes memory); + + /// @param keyId The incremental index for the encryption key. + /// @return The encryption key with the given key-id. + function getEncryptionKey(uint256 keyId) external view returns (bytes memory); + /***************************** * Public Mutating Functions * *****************************/ diff --git a/src/validium/L1ERC20GatewayValidium.sol b/src/validium/L1ERC20GatewayValidium.sol index 8bdba236..1580e439 100644 --- a/src/validium/L1ERC20GatewayValidium.sol +++ b/src/validium/L1ERC20GatewayValidium.sol @@ -10,6 +10,7 @@ import {SafeERC20Upgradeable} from "@openzeppelin/contracts-upgradeable/token/ER import {IL1ScrollMessenger} from "../L1/IL1ScrollMessenger.sol"; import {IL1ERC20GatewayValidium} from "./IL1ERC20GatewayValidium.sol"; import {IL2ERC20GatewayValidium} from "./IL2ERC20GatewayValidium.sol"; +import {IScrollChainValidium} from "./IScrollChainValidium.sol"; import {ScrollGatewayBase} from "../libraries/gateway/ScrollGatewayBase.sol"; @@ -43,6 +44,9 @@ contract L1ERC20GatewayValidium is ScrollGatewayBase, IL1ERC20GatewayValidium { /// @notice The address of ScrollStandardERC20Factory contract in L2. address public immutable l2TokenFactory; + /// @notice The address of ScrollChainValidium contract in L2. + address public immutable scrollChainValidium; + /************* * Variables * *************/ @@ -67,12 +71,14 @@ contract L1ERC20GatewayValidium is ScrollGatewayBase, IL1ERC20GatewayValidium { address _counterpart, address _messenger, address _l2TokenImplementation, - address _l2TokenFactory + address _l2TokenFactory, + address _scrollChainValidium ) ScrollGatewayBase(_counterpart, address(0), _messenger) { _disableInitializers(); l2TokenImplementation = _l2TokenImplementation; l2TokenFactory = _l2TokenFactory; + scrollChainValidium = _scrollChainValidium; } /// @notice Initialize the storage of L1ERC20GatewayValidium. @@ -102,9 +108,10 @@ contract L1ERC20GatewayValidium is ScrollGatewayBase, IL1ERC20GatewayValidium { address _token, bytes memory _to, uint256 _amount, - uint256 _gasLimit + uint256 _gasLimit, + uint256 _keyId ) external payable override { - _deposit(_token, _msgSender(), _to, _amount, new bytes(0), _gasLimit); + _deposit(_token, _msgSender(), _to, _amount, new bytes(0), _gasLimit, _keyId); } /// @inheritdoc IL1ERC20GatewayValidium @@ -113,9 +120,10 @@ contract L1ERC20GatewayValidium is ScrollGatewayBase, IL1ERC20GatewayValidium { address _realSender, bytes memory _to, uint256 _amount, - uint256 _gasLimit + uint256 _gasLimit, + uint256 _keyId ) external payable override { - _deposit(_token, _realSender, _to, _amount, new bytes(0), _gasLimit); + _deposit(_token, _realSender, _to, _amount, new bytes(0), _gasLimit, _keyId); } /// @inheritdoc IL1ERC20GatewayValidium @@ -192,8 +200,12 @@ contract L1ERC20GatewayValidium is ScrollGatewayBase, IL1ERC20GatewayValidium { bytes memory _to, uint256 _amount, bytes memory _data, - uint256 _gasLimit + uint256 _gasLimit, + uint256 _keyId ) internal virtual nonReentrant { + // Validate the encryption key with the given key-id. + IScrollChainValidium(scrollChainValidium).getEncryptionKey(_keyId); + // 1. Transfer token into this contract. _amount = _transferERC20In(_msgSender(), _token, _amount); if (_amount == 0) revert ErrorAmountIsZero(); diff --git a/src/validium/L1WETHGatewayValidium.sol b/src/validium/L1WETHGatewayValidium.sol index 1aa9be27..579fca94 100644 --- a/src/validium/L1WETHGatewayValidium.sol +++ b/src/validium/L1WETHGatewayValidium.sol @@ -50,7 +50,11 @@ contract L1WETHGatewayValidium { /// @notice Deposit ETH to L2 through the `L1ERC20GatewayValidium` contract. /// @param _to The encrypted address of recipient in L2 to receive the token. - function deposit(bytes memory _to, uint256 _amount) external payable { + function deposit( + bytes memory _to, + uint256 _amount, + uint256 _keyId + ) external payable { if (msg.value < _amount) revert ErrorInsufficientValue(); // WETH deposit is safe. @@ -62,7 +66,8 @@ contract L1WETHGatewayValidium { msg.sender, _to, _amount, - GAS_LIMIT + GAS_LIMIT, + _keyId ); } } diff --git a/src/validium/ScrollChainValidium.sol b/src/validium/ScrollChainValidium.sol index e703beb7..a4880a3f 100644 --- a/src/validium/ScrollChainValidium.sol +++ b/src/validium/ScrollChainValidium.sol @@ -41,6 +41,15 @@ contract ScrollChainValidium is AccessControlUpgradeable, PausableUpgradeable, I /// @dev Thrown when given batch is not committed before. error ErrorBatchNotCommitted(); + /// @dev Error thrown when encryption key length is invalid. + error ErrorInvalidEncryptionKeyLength(); + + /// @dev Error thrown the user attempts to use an encryption key that is unknown. + error ErrorUnknownEncryptionKey(); + + /// @dev Error thrown the user attempts to use an encryption key that is deprecated. + error ErrorDeprecatedEncryptionKey(); + /************* * Constants * *************/ @@ -54,6 +63,9 @@ contract ScrollChainValidium is AccessControlUpgradeable, PausableUpgradeable, I /// @notice The role for prover who can finalize batch. bytes32 public constant PROVER_ROLE = keccak256("PROVER_ROLE"); + /// @notice The role that can rotate encryption keys. + bytes32 public constant KEY_MANAGER_ROLE = keccak256("KEY_MANAGER_ROLE"); + /*********************** * Immutable Variables * ***********************/ @@ -67,6 +79,17 @@ contract ScrollChainValidium is AccessControlUpgradeable, PausableUpgradeable, I /// @notice The address of `MultipleVersionRollupVerifier`. address public immutable verifier; + /*********** + * Structs * + ***********/ + + struct EncryptionKey { + // The on-chain message index when the key was set. + uint256 msgIndex; + // The 33-bytes compressed public key, i.e. encryption key. + bytes key; + } + /********************* * Storage Variables * *********************/ @@ -86,6 +109,9 @@ contract ScrollChainValidium is AccessControlUpgradeable, PausableUpgradeable, I /// @dev Mapping from batch index to corresponding withdraw root in Validium L3. mapping(uint256 => bytes32) public override withdrawRoots; + /// @dev An array of encryption keys. + EncryptionKey[] public encryptionKeys; + /*************** * Constructor * ***************/ @@ -127,6 +153,22 @@ contract ScrollChainValidium is AccessControlUpgradeable, PausableUpgradeable, I return _batchIndex <= lastFinalizedBatchIndex; } + /// @inheritdoc IScrollChainValidium + function getLatestEncryptionKey() external view override returns (uint256, bytes memory) { + uint256 _numKeys = encryptionKeys.length; + if (_numKeys == 0) revert ErrorUnknownEncryptionKey(); + return (_numKeys - 1, encryptionKeys[_numKeys - 1].key); + } + + /// @inheritdoc IScrollChainValidium + function getEncryptionKey(uint256 _keyId) external view override returns (bytes memory) { + uint256 _numKeys = encryptionKeys.length; + if (_numKeys == 0) revert ErrorUnknownEncryptionKey(); + if (_keyId >= _numKeys) revert ErrorUnknownEncryptionKey(); + if (_keyId < _numKeys - 1) revert ErrorDeprecatedEncryptionKey(); + return encryptionKeys[_numKeys - 1].key; + } + /***************************** * Public Mutating Functions * *****************************/ @@ -232,6 +274,17 @@ contract ScrollChainValidium is AccessControlUpgradeable, PausableUpgradeable, I * Restricted Functions * ************************/ + function registerNewEncryptionKey(bytes memory _key) external onlyRole(KEY_MANAGER_ROLE) { + if (_key.length != 33) revert ErrorInvalidEncryptionKeyLength(); + uint256 _keyId = encryptionKeys.length; + + // The message from `nextCrossDomainMessageIndex` will utilise the newly registered encryption key. + uint256 _msgIndex = IL1MessageQueueV2(messageQueueV2).nextCrossDomainMessageIndex(); + encryptionKeys.push(EncryptionKey(_msgIndex, _key)); + + emit NewEncryptionKey(_keyId, _msgIndex, _key); + } + /// @notice Pause the contract /// @param _status The pause status to update. function setPause(bool _status) external onlyRole(DEFAULT_ADMIN_ROLE) { @@ -308,7 +361,11 @@ contract ScrollChainValidium is AccessControlUpgradeable, PausableUpgradeable, I bytes32 postStateRoot = stateRoots[batchIndex]; bytes32 withdrawRoot = withdrawRoots[batchIndex]; - // @todo public inputs TBD + // Get the encryption key at the time of on-chain message queue index. + bytes memory encryptionKey = totalL1MessagesPoppedOverall == 0 + ? _getEncryptionKey(0) + : _getEncryptionKey(totalL1MessagesPoppedOverall - 1); + bytes memory publicInputs = abi.encodePacked( layer2ChainId, messageQueueHash, @@ -317,7 +374,8 @@ contract ScrollChainValidium is AccessControlUpgradeable, PausableUpgradeable, I committedBatches[prevBatchIndex], // _prevBatchHash postStateRoot, batchHash, - withdrawRoot + withdrawRoot, + encryptionKey ); // verify bundle, choose the correct verifier based on the last batch @@ -364,4 +422,22 @@ contract ScrollChainValidium is AccessControlUpgradeable, PausableUpgradeable, I revert ErrorIncorrectBatchHash(); } } + + /// @dev Internal function to get the relevant encryption key that was used to encrypt messages up to the provided message index. + /// @param _msgIndex The on-chain message queue index being finalised. + /// @return The encryption key used at the time of the provided on-chain message queue index. + function _getEncryptionKey(uint256 _msgIndex) internal view returns (bytes memory) { + // Start from the "latest" key and continue fetching keys until we find the key + // that was rotated before the message index we have been provided. + uint256 _numKeys = encryptionKeys.length; + if (_numKeys == 0) revert ErrorUnknownEncryptionKey(); + EncryptionKey memory _encryptionKey = encryptionKeys[--_numKeys]; + + while (_encryptionKey.msgIndex > _msgIndex) { + if (_numKeys == 0) revert ErrorUnknownEncryptionKey(); + _encryptionKey = encryptionKeys[--_numKeys]; + } + + return _encryptionKey.key; + } }