diff --git a/src/IWorldID.sol b/src/IWorldID.sol index 1c4f985..95220c7 100644 --- a/src/IWorldID.sol +++ b/src/IWorldID.sol @@ -4,5 +4,23 @@ pragma solidity ^0.8.0; interface IWorldID { - function verifyIdentity(address user) external view returns (bool); + /// @notice Verifies a WorldID zero knowledge proof. + /// @dev Note that a double-signaling check is not included here, and should be carried by the + /// caller. + /// @dev It is highly recommended that the implementation is restricted to `view` if possible. + /// + /// @param root The root of the Merkle tree + /// @param signalHash A keccak256 hash of the Semaphore signal + /// @param nullifierHash The nullifier hash + /// @param externalNullifierHash A keccak256 hash of the external nullifier + /// @param proof The zero-knowledge proof + /// + /// @custom:reverts string If the `proof` is invalid. + function verifyProof( + uint256 root, + uint256 signalHash, + uint256 nullifierHash, + uint256 externalNullifierHash, + uint256[8] calldata proof + ) external view; } diff --git a/src/Wallet.sol b/src/Wallet.sol index 5238665..aae3dcd 100644 --- a/src/Wallet.sol +++ b/src/Wallet.sol @@ -5,13 +5,13 @@ import "./interfaces/IERC20.sol"; import "./IWorldID.sol"; contract Wallet { - address public owner; - IWorldID public worldID; - IERC20 public usdt; + address public immutable owner; + IWorldID public immutable worldID; mapping(address => Transaction[]) public transactions; mapping(address => mapping(address => uint256)) token_balance; mapping(address => bool) public supportedTokens; + mapping(uint256 => bool) public nullifierHashes; struct Transaction { uint256 amount; @@ -23,9 +23,15 @@ contract Wallet { _; } - // Modifier to ensure that only verified users can call the function - modifier onlyVerified(address user) { - require(worldID.verifyIdentity(user), "User not verified"); + // Modifier to ensure that only users with a valid ZK proof can call the function + modifier onlyValidProof(bytes calldata _zkProof) { + { + (uint256 _root, uint256 _signalHash, uint256 _nullifierHash, uint256 _externalNullifierHash, uint256[8] memory _proof) + = abi.decode(_zkProof,(uint256, uint256, uint256, uint256, uint256[8])); + if (nullifierHashes[_nullifierHash]) revert("Invalid Nullifier"); + worldID.verifyProof(_root, _signalHash, _nullifierHash, _externalNullifierHash,_proof); + nullifierHashes[_nullifierHash] = true; + } _; } @@ -38,15 +44,12 @@ contract Wallet { require(msg.sender != address(0), "zero address found"); owner = _owner; worldID = IWorldID(_worldID); - usdt = IERC20(_usdt); supportedTokens[_usdt] = true; // Add USDT as a default supported token } - function createWorldId() external onlyOwner {} - - function transfer(address _recipient, address _token, uint256 _amount) - external - onlyVerified(msg.sender) + function transfer(address _recipient, address _token, uint256 _amount, bytes calldata _zkProof) + external + onlyValidProof(_zkProof) onlySupportedToken(_token) { require(_recipient != address(0), "Zero address detected"); diff --git a/test/Transfer.t.sol b/test/Transfer.t.sol index a5f5083..b27b019 100644 --- a/test/Transfer.t.sol +++ b/test/Transfer.t.sol @@ -6,19 +6,7 @@ import "../src/interfaces/IERC20.sol"; import "../src/Wallet.sol"; import "../src/IWorldID.sol"; - -// MockWorldID contract -contract MockWorldIDContract is IWorldID { - mapping(address => bool) public verifiedUsers; - - function verifyIdentity(address user) external view override returns (bool) { - return verifiedUsers[user]; - } - - function setVerified(address user, bool verified) public { - verifiedUsers[user] = verified; - } -} +import {MockWorldID} from "test/mocks/MockWorldID.sol"; // MockERC20Token contract contract MockERC20Token is IERC20 { @@ -83,7 +71,7 @@ contract MockERC20Token is IERC20 { contract USDTTransferTest is Test { Wallet private transferContract; MockERC20Token private mockUSDT; - MockWorldIDContract private mockWorldID; + MockWorldID private mockWorldID; address private user1 = address(0x1); address private user2 = address(0x2); uint256 private initialBalance = 1000e18; @@ -91,9 +79,10 @@ contract USDTTransferTest is Test { // Setup function to deploy mocks and the transfer contract function setUp() public { - mockWorldID = new MockWorldIDContract(); + mockWorldID = new MockWorldID(); + mockUSDT = new MockERC20Token(); - transferContract = new Wallet(address(_owner), address(mockWorldID), address(mockUSDT)); + transferContract = new Wallet(_owner, address(mockWorldID), address(mockUSDT)); // Mint some USDT for user1 mockUSDT.mint(user1, initialBalance); @@ -101,16 +90,13 @@ contract USDTTransferTest is Test { // Approve transferContract to spend user1's tokens vm.prank(user1); mockUSDT.approve(address(transferContract), initialBalance); - - // Set World ID verification for user1 - mockWorldID.setVerified(user1, true); } // Test constructor - Verifying contract state after deployment function testConstructor() public view { assertEq(transferContract.owner(), address(this)); assertEq(address(transferContract.worldID()), address(mockWorldID)); - assertEq(address(transferContract.usdt()), address(mockUSDT)); + assertTrue(transferContract.supportedTokens(address(mockUSDT))); } // Test transfer function - Successful transfer @@ -118,10 +104,11 @@ contract USDTTransferTest is Test { uint256 transferAmount = 100e18; uint256 user1BalanceBefore = mockUSDT.balanceOf(user1); uint256 user2BalanceBefore = mockUSDT.balanceOf(user2); + bytes memory validZkProof = mockWorldID.generateZkProof(uint256(keccak256("some seed")), true); // Perform the transfer vm.prank(user1); - transferContract.transfer(user2, address(mockUSDT), transferAmount); + transferContract.transfer(user2, address(mockUSDT), transferAmount, validZkProof); // Assert balances after transfer assertEq(mockUSDT.balanceOf(user1), user1BalanceBefore - transferAmount); @@ -131,39 +118,43 @@ contract USDTTransferTest is Test { // Test transfer function - Insufficient balance function testTransferInsufficientBalance() public { uint256 transferAmount = initialBalance + 1e18; + bytes memory validZkProof = mockWorldID.generateZkProof(uint256(keccak256("some seed")), true); vm.prank(user1); vm.expectRevert("Insufficient balance"); - transferContract.transfer(user2, address(mockUSDT), transferAmount); + transferContract.transfer(user2, address(mockUSDT), transferAmount, validZkProof); } // Test transfer function - Unverified user function testTransferUnverifiedUser() public { uint256 transferAmount = 100e18; + bytes memory invalidZkProof = mockWorldID.generateZkProof(uint256(keccak256("some seed")), false); - // Set user1 as unverified - mockWorldID.setVerified(user1, false); - + // Expect revert with specific reason + vm.expectRevert("NonExistentRoot or ExpiredRoot"); vm.prank(user1); - vm.expectRevert("User not verified"); - transferContract.transfer(user2, address(mockUSDT), transferAmount); + transferContract.transfer(user2, address(mockUSDT), transferAmount, invalidZkProof); } - // Test transfer function - Transfer to unverified recipient - function testTransferToUnverifiedRecipient() public { + // Test transfer function - Transfer with used proof should revert + function testTransferWithUsedProofReverts() public { uint256 transferAmount = 100e18; + bytes memory validZkProof = mockWorldID.generateZkProof(uint256(keccak256("some seed")), true); - mockWorldID.setVerified(user2, false); - - vm.expectRevert("User not verified"); + vm.prank(user1); + transferContract.transfer(user2, address(mockUSDT), transferAmount, validZkProof); - transferContract.transfer(user2, address(mockUSDT), transferAmount); + // Expect revert with specific reason + vm.expectRevert("Invalid Nullifier"); + vm.prank(user1); + transferContract.transfer(user2, address(mockUSDT), transferAmount, validZkProof); } // Test transfer function - Multiple consecutive transfers function testMultipleConsecutiveTransfers() public { uint256 transferAmount1 = 100e18; uint256 transferAmount2 = 50e18; + bytes memory validZkProof = mockWorldID.generateZkProof(uint256(keccak256("some seed")), true); // User1's initial balance uint256 user1BalanceBefore = mockUSDT.balanceOf(user1); @@ -171,7 +162,7 @@ contract USDTTransferTest is Test { // First transfer vm.prank(user1); - transferContract.transfer(user2, address(mockUSDT), transferAmount1); + transferContract.transfer(user2, address(mockUSDT), transferAmount1, validZkProof); // Assert balances after first transfer assertEq(mockUSDT.balanceOf(user1), user1BalanceBefore - transferAmount1); @@ -181,8 +172,9 @@ contract USDTTransferTest is Test { user1BalanceBefore = mockUSDT.balanceOf(user1); user2BalanceBefore = mockUSDT.balanceOf(user2); + validZkProof = mockWorldID.generateZkProof(uint256(keccak256("some other seed")), true); vm.prank(user1); - transferContract.transfer(user2, address(mockUSDT), transferAmount2); + transferContract.transfer(user2, address(mockUSDT), transferAmount2, validZkProof); assertEq(mockUSDT.balanceOf(user1), user1BalanceBefore - transferAmount2); assertEq(mockUSDT.balanceOf(user2), user2BalanceBefore + transferAmount2); @@ -191,10 +183,11 @@ contract USDTTransferTest is Test { // Test transfer function - Zero amount transfer function testTransferZeroAmount() public { uint256 transferAmount = 0; + bytes memory validZkProof = mockWorldID.generateZkProof(uint256(keccak256("some seed")), true); vm.prank(user1); vm.expectRevert("Transfer amount must be greater than zero"); - transferContract.transfer(user2, address(mockUSDT), transferAmount); + transferContract.transfer(user2, address(mockUSDT), transferAmount, validZkProof); } } diff --git a/test/Wallet.t.sol b/test/Wallet.t.sol index c9f0f09..5769678 100644 --- a/test/Wallet.t.sol +++ b/test/Wallet.t.sol @@ -19,7 +19,7 @@ contract WalletTest is Test { address public nonOwner; MockERC20 public usdt; MockERC20 public anotherToken; - MockWorldID public worldID; + MockWorldID public mockWorldID; /// @notice Set up the test environment before each test function setUp() public { @@ -30,10 +30,10 @@ contract WalletTest is Test { usdt = new MockERC20("USDT", "USDT"); anotherToken = new MockERC20("Another Token", "ATKN"); - worldID = new MockWorldID(); + mockWorldID = new MockWorldID(); factory = new WalletFactory(); - (wallet,) = factory.createWallet(address(worldID), address(usdt)); + (wallet,) = factory.createWallet(address(mockWorldID), address(usdt)); // Fund users usdt.mint(user1, 1000); @@ -45,15 +45,14 @@ contract WalletTest is Test { vm.prank(user2); usdt.approve(address(wallet), type(uint256).max); - // Set users as verified in MockWorldID - worldID.setVerified(user1, true); - worldID.setVerified(user2, true); } /// @notice Test recording a single transaction function testRecordSingleTransaction() public { + bytes memory validZkProof = mockWorldID.generateZkProof(uint256(keccak256("some seed")), true); + vm.prank(user1); - wallet.transfer(user2, address(usdt), 100); + wallet.transfer(user2, address(usdt), 100, validZkProof); vm.prank(user1); Wallet.Transaction[] memory history = wallet.getTransactionHistory(user1); @@ -64,9 +63,12 @@ contract WalletTest is Test { /// @notice Test recording multiple transactions function testRecordMultipleTransactions() public { + bytes memory validZkProof = mockWorldID.generateZkProof(uint256(keccak256("some seed")), true); vm.startPrank(user1); - wallet.transfer(user2, address(usdt), 100); - wallet.transfer(user2, address(usdt), 200); + wallet.transfer(user2, address(usdt), 100, validZkProof); + + validZkProof = mockWorldID.generateZkProof(uint256(keccak256("some other seed")), true); + wallet.transfer(user2, address(usdt), 200, validZkProof); Wallet.Transaction[] memory history = wallet.getTransactionHistory(user1); vm.stopPrank(); @@ -80,14 +82,17 @@ contract WalletTest is Test { /// @notice Test recording transactions for different users function testRecordTransactionsForDifferentUsers() public { + bytes memory validZkProof = mockWorldID.generateZkProof(uint256(keccak256("some seed")), true); vm.prank(user1); - wallet.transfer(user2, address(usdt), 100); + wallet.transfer(user2, address(usdt), 100, validZkProof); + validZkProof = mockWorldID.generateZkProof(uint256(keccak256("some other seed")), true); vm.prank(user1); - wallet.transfer(user2, address(usdt), 50); + wallet.transfer(user2, address(usdt), 50, validZkProof); + validZkProof = mockWorldID.generateZkProof(uint256(keccak256("some other one seed")), true); vm.prank(user2); - wallet.transfer(user1, address(usdt), 50); + wallet.transfer(user1, address(usdt), 50, validZkProof); vm.prank(user1); Wallet.Transaction[] memory user1History = wallet.getTransactionHistory(user1); @@ -103,12 +108,13 @@ contract WalletTest is Test { /// @notice Test recording a large amount transaction function testRecordLargeAmountTransaction() public { + bytes memory validZkProof = mockWorldID.generateZkProof(uint256(keccak256("some seed")), true); uint256 largeAmount = type(uint256).max / 2; // Use half of max to avoid overflow usdt.mint(user1, largeAmount); vm.startPrank(user1); usdt.approve(address(wallet), largeAmount); - wallet.transfer(user2, address(usdt), largeAmount); + wallet.transfer(user2, address(usdt), largeAmount, validZkProof); Wallet.Transaction[] memory history = wallet.getTransactionHistory(user1); vm.stopPrank(); diff --git a/test/WalletFactory.t.sol b/test/WalletFactory.t.sol index d3d2186..d860b0f 100644 --- a/test/WalletFactory.t.sol +++ b/test/WalletFactory.t.sol @@ -3,19 +3,19 @@ pragma solidity ^0.8.13; import {Test, console} from "forge-std/Test.sol"; import {WalletFactory} from "../src/WalletFactory.sol"; -import {MockERC20Token, MockWorldIDContract} from "./Transfer.t.sol"; // Import your MockERC20Token contract +import {MockERC20Token} from "./Transfer.t.sol"; // Import your MockERC20Token contract contract WalletFactoryTest is Test { WalletFactory public factory; MockERC20Token public mockUSDT; - MockWorldIDContract public mockWorldID; + address public mockWorldID = address(bytes20(keccak256("World ID Contract Address"))); function setUp() public { factory = new WalletFactory(); } function test_CreateWallet() public { - factory.createWallet(address(mockWorldID), address(mockUSDT)); + factory.createWallet(mockWorldID, address(mockUSDT)); assertEq(factory.getWalletClones().length, 1); } diff --git a/test/mocks/MockWorldID.sol b/test/mocks/MockWorldID.sol index c08ce50..6b7f71a 100644 --- a/test/mocks/MockWorldID.sol +++ b/test/mocks/MockWorldID.sol @@ -4,13 +4,48 @@ pragma solidity ^0.8.13; import "../../src/IWorldID.sol"; contract MockWorldID is IWorldID { - mapping(address => bool) private _verified; + mapping(bytes32 => bool) private _validProofs; - function verifyIdentity(address user) external view override returns (bool) { - return _verified[user]; + function verifyProof( + uint256 root, + uint256 signalHash, + uint256 nullifierHash, + uint256 externalNullifierHash, + uint256[8] calldata proof + ) external view override { + bytes32 proofHash = keccak256( + abi.encode( + root, + signalHash, + nullifierHash, + externalNullifierHash, + proof + ) + ); + if (!_validProofs[proofHash]) revert("NonExistentRoot or ExpiredRoot"); } - function setVerified(address user, bool status) external { - _verified[user] = status; + function generateZkProof(uint256 seed, bool set) external returns (bytes memory zkProof) { + uint256 root = uint256(keccak256(abi.encodePacked(seed, "root"))); + uint256 signalHash = uint256(keccak256(abi.encodePacked(seed, "signalHash"))); + uint256 nullifierHash = uint256(keccak256(abi.encodePacked(seed, "nullifierHash"))); + uint256 externalNullifierHash = uint256(keccak256(abi.encodePacked(seed, "externalNullifierHash"))); + uint256[8] memory proof; + for (uint256 i = 0; i < 8; i++) { + proof[i] = uint256(keccak256(abi.encodePacked(seed, i))); + } + if(set) { + bytes32 proofHash = keccak256( + abi.encode( + root, + signalHash, + nullifierHash, + externalNullifierHash, + proof + ) + ); + _validProofs[proofHash] = true; + } + return abi.encode(root, signalHash, nullifierHash, externalNullifierHash, proof); } }