diff --git a/contracts/extension/interface/plugin/IEntrypointOverrideable.sol b/contracts/extension/interface/plugin/IEntrypointOverrideable.sol new file mode 100644 index 000000000..2830a64c6 --- /dev/null +++ b/contracts/extension/interface/plugin/IEntrypointOverrideable.sol @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.0; + +interface IEntrypointOverrideable { + struct ExtensionMap { + bytes4 selector; + address extension; + } + + function overrideExtensionForFunction(bytes4 _selector, address _extension) external; + + function getAllOverriden() external view returns (ExtensionMap[] memory functionExtensionPairs); +} diff --git a/contracts/extension/plugin/EntrypointOverrideable.sol b/contracts/extension/plugin/EntrypointOverrideable.sol new file mode 100644 index 000000000..f6f29f34d --- /dev/null +++ b/contracts/extension/plugin/EntrypointOverrideable.sol @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.0; + +import "../interface/plugin/IMap.sol"; +import "../interface/plugin/IEntrypointOverrideable.sol"; +import "../../extension/Multicall.sol"; +import "../../openzeppelin-presets/utils/EnumerableSet.sol"; + +library EntrypointOverrideableStorage { + bytes32 public constant ENTRYPOINT_OVERRIDEABLE_STORAGE_POSITION = keccak256("entrypoint.overrideable.storage"); + + struct Data { + EnumerableSet.Bytes32Set functions; + mapping(bytes4 => address) extensionOverride; + } + + function entrypointStorage() internal pure returns (Data storage entrypointData) { + bytes32 position = ENTRYPOINT_OVERRIDEABLE_STORAGE_POSITION; + assembly { + entrypointData.slot := position + } + } +} + +abstract contract EntrypointOverrideable is Multicall, IEntrypointOverrideable { + using EnumerableSet for EnumerableSet.Bytes32Set; + + /*/////////////////////////////////////////////////////////////// + State variables + //////////////////////////////////////////////////////////////*/ + + address public immutable functionMap; + + /*/////////////////////////////////////////////////////////////// + Constructor + initializer logic + //////////////////////////////////////////////////////////////*/ + + constructor(address _functionMap) { + functionMap = _functionMap; + } + + /*/////////////////////////////////////////////////////////////// + Generic contract logic + //////////////////////////////////////////////////////////////*/ + + fallback() external payable virtual { + address extension = _getExtensionOverride(msg.sig); + if (extension == address(0)) { + extension = IMap(functionMap).getExtensionForFunction(msg.sig); + } + + _delegate(extension); + } + + receive() external payable {} + + function _delegate(address implementation) internal virtual { + assembly { + // Copy msg.data. We take full control of memory in this inline assembly + // block because it will not return to Solidity code. We overwrite the + // Solidity scratch pad at memory position 0. + calldatacopy(0, 0, calldatasize()) + + // Call the implementation. + // out and outsize are 0 because we don't know the size yet. + let result := delegatecall(gas(), implementation, 0, calldatasize(), 0, 0) + + // Copy the returned data. + returndatacopy(0, 0, returndatasize()) + + switch result + // delegatecall returns 0 on error. + case 0 { + revert(0, returndatasize()) + } + default { + return(0, returndatasize()) + } + } + } + + /*/////////////////////////////////////////////////////////////// + External functions + //////////////////////////////////////////////////////////////*/ + + function overrideExtensionForFunction(bytes4 _selector, address _extension) external { + require(_canOverrideExtensions(), "Entrypoint: cannot override extensions."); + + EntrypointOverrideableStorage.Data storage data = EntrypointOverrideableStorage.entrypointStorage(); + data.extensionOverride[_selector] = _extension; + + if (_extension != address(0)) { + data.functions.add(bytes32(_selector)); + } else { + data.functions.remove(bytes32(_selector)); + } + } + + function getAllOverriden() external view returns (ExtensionMap[] memory functionExtensionPairs) { + EntrypointOverrideableStorage.Data storage data = EntrypointOverrideableStorage.entrypointStorage(); + uint256 len = data.functions.length(); + functionExtensionPairs = new ExtensionMap[](len); + + for (uint256 i = 0; i < len; i += 1) { + bytes4 selector = bytes4(data.functions.at(i)); + functionExtensionPairs[i] = ExtensionMap(selector, data.extensionOverride[selector]); + } + } + + /*/////////////////////////////////////////////////////////////// + Internal functions + //////////////////////////////////////////////////////////////*/ + + function _getExtensionOverride(bytes4 _selector) internal view returns (address) { + EntrypointOverrideableStorage.Data storage data = EntrypointOverrideableStorage.entrypointStorage(); + return data.extensionOverride[_selector]; + } + + function _canOverrideExtensions() internal view virtual returns (bool); +} diff --git a/docs/EntrypointOverrideable.md b/docs/EntrypointOverrideable.md new file mode 100644 index 000000000..ecb331b5b --- /dev/null +++ b/docs/EntrypointOverrideable.md @@ -0,0 +1,88 @@ +# EntrypointOverrideable + + + + + + + + + +## Methods + +### functionMap + +```solidity +function functionMap() external view returns (address) +``` + + + + + + +#### Returns + +| Name | Type | Description | +|---|---|---| +| _0 | address | undefined | + +### getAllOverriden + +```solidity +function getAllOverriden() external view returns (struct IEntrypointOverrideable.ExtensionMap[] functionExtensionPairs) +``` + + + + + + +#### Returns + +| Name | Type | Description | +|---|---|---| +| functionExtensionPairs | IEntrypointOverrideable.ExtensionMap[] | undefined | + +### multicall + +```solidity +function multicall(bytes[] data) external nonpayable returns (bytes[] results) +``` + +Receives and executes a batch of function calls on this contract. + +*Receives and executes a batch of function calls on this contract.* + +#### Parameters + +| Name | Type | Description | +|---|---|---| +| data | bytes[] | The bytes data that makes up the batch of function calls to execute. | + +#### Returns + +| Name | Type | Description | +|---|---|---| +| results | bytes[] | The bytes data that makes up the result of the batch of function calls executed. | + +### overrideExtensionForFunction + +```solidity +function overrideExtensionForFunction(bytes4 _selector, address _extension) external nonpayable +``` + + + + + +#### Parameters + +| Name | Type | Description | +|---|---|---| +| _selector | bytes4 | undefined | +| _extension | address | undefined | + + + + diff --git a/docs/EntrypointOverrideableStorage.md b/docs/EntrypointOverrideableStorage.md new file mode 100644 index 000000000..0d1a573dc --- /dev/null +++ b/docs/EntrypointOverrideableStorage.md @@ -0,0 +1,32 @@ +# EntrypointOverrideableStorage + + + + + + + + + +## Methods + +### ENTRYPOINT_OVERRIDEABLE_STORAGE_POSITION + +```solidity +function ENTRYPOINT_OVERRIDEABLE_STORAGE_POSITION() external view returns (bytes32) +``` + + + + + + +#### Returns + +| Name | Type | Description | +|---|---|---| +| _0 | bytes32 | undefined | + + + + diff --git a/docs/IEntrypointOverrideable.md b/docs/IEntrypointOverrideable.md new file mode 100644 index 000000000..09231a96e --- /dev/null +++ b/docs/IEntrypointOverrideable.md @@ -0,0 +1,49 @@ +# IEntrypointOverrideable + + + + + + + + + +## Methods + +### getAllOverriden + +```solidity +function getAllOverriden() external view returns (struct IEntrypointOverrideable.ExtensionMap[] functionExtensionPairs) +``` + + + + + + +#### Returns + +| Name | Type | Description | +|---|---|---| +| functionExtensionPairs | IEntrypointOverrideable.ExtensionMap[] | undefined | + +### overrideExtensionForFunction + +```solidity +function overrideExtensionForFunction(bytes4 _selector, address _extension) external nonpayable +``` + + + + + +#### Parameters + +| Name | Type | Description | +|---|---|---| +| _selector | bytes4 | undefined | +| _extension | address | undefined | + + + + diff --git a/src/test/plugin/EntrypointOverrideable.t.sol b/src/test/plugin/EntrypointOverrideable.t.sol new file mode 100644 index 000000000..11a6cd95a --- /dev/null +++ b/src/test/plugin/EntrypointOverrideable.t.sol @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: Apache-2.0 +pragma solidity ^0.8.0; + +import "contracts/extension/plugin/Map.sol"; +import "contracts/extension/plugin/EntrypointOverrideable.sol"; +import { BaseTest } from "../utils/BaseTest.sol"; + +contract Entrypoint is EntrypointOverrideable { + constructor(address _functionMap) EntrypointOverrideable(_functionMap) {} + + function _canOverrideExtensions() internal pure override returns (bool) { + return true; + } +} + +library CounterStorage { + bytes32 public constant COUNTER_STORAGE_POSITION = keccak256("counter.storage"); + + struct Data { + uint256 number; + } + + function counterStorage() internal pure returns (Data storage counterData) { + bytes32 position = COUNTER_STORAGE_POSITION; + assembly { + counterData.slot := position + } + } +} + +contract Counter { + function number() external view returns (uint256) { + CounterStorage.Data storage data = CounterStorage.counterStorage(); + return data.number; + } + + function setNumber(uint256 _newNum) external { + CounterStorage.Data storage data = CounterStorage.counterStorage(); + data.number = _newNum; + } + + function doubleNumber() external { + CounterStorage.Data storage data = CounterStorage.counterStorage(); + data.number *= 4; // Buggy! + } +} + +contract CounterAlternate { + function doubleNumber() external { + CounterStorage.Data storage data = CounterStorage.counterStorage(); + data.number *= 2; // Fixed! + } +} + +contract EntrypointOverrideableTest is BaseTest { + address internal map; + address internal entrypoint; + + address internal counter; + address internal counterAlternate; + + function setUp() public override { + super.setUp(); + + counter = address(new Counter()); + counterAlternate = address(new CounterAlternate()); + + IMap.ExtensionMap[] memory extensionMaps = new IMap.ExtensionMap[](3); + extensionMaps[0] = IMap.ExtensionMap(Counter.number.selector, counter); + extensionMaps[1] = IMap.ExtensionMap(Counter.setNumber.selector, counter); + extensionMaps[2] = IMap.ExtensionMap(Counter.doubleNumber.selector, counter); + + map = address(new Map(extensionMaps)); + entrypoint = address(new Entrypoint(map)); + } + + function test_state_overrideExtensionForFunction() external { + // Set number. + uint256 num = 5; + Counter(entrypoint).setNumber(num); + assertEq(Counter(entrypoint).number(), num); + + // Double number. Bug: it quadruples the number. + Counter(entrypoint).doubleNumber(); + assertEq(Counter(entrypoint).number(), num * 4); + + // Reset number. + Counter(entrypoint).setNumber(num); + assertEq(Counter(entrypoint).number(), num); + + // Fix the extension for `doubleNumber`. + Entrypoint(payable(entrypoint)).overrideExtensionForFunction(Counter.doubleNumber.selector, counterAlternate); + + // Double number. Fixed: it doubles the number. + Counter(entrypoint).doubleNumber(); + assertEq(Counter(entrypoint).number(), num * 2); + + // Get and check all overriden extensions. + IEntrypointOverrideable.ExtensionMap[] memory extensionMapsStored = Entrypoint(payable(entrypoint)) + .getAllOverriden(); + assertEq(extensionMapsStored.length, 1); + assertEq(extensionMapsStored[0].extension, counterAlternate); + assertEq(extensionMapsStored[0].selector, Counter.doubleNumber.selector); + } +}