diff --git a/src/Utils.sol b/src/Utils.sol index 8da8174..4edc00e 100644 --- a/src/Utils.sol +++ b/src/Utils.sol @@ -1,6 +1,7 @@ - // SPDX-License-Identifier: Apache 2 -pragma solidity ^0.8.19; +pragma solidity ^0.8.4; + +import { WORD_SIZE, SCRATCH_SPACE_PTR, FREE_MEMORY_PTR } from "./constants/Common.sol"; error NotAnEvmAddress(bytes32); @@ -12,7 +13,8 @@ function fromUniversalAddress(bytes32 universalAddr) pure returns (address addr) if (bytes12(universalAddr) != 0) revert NotAnEvmAddress(universalAddr); - assembly ("memory-safe") { + /// @solidity memory-safe-assembly + assembly { addr := universalAddr } } @@ -22,21 +24,56 @@ function fromUniversalAddress(bytes32 universalAddr) pure returns (address addr) * Meant to be used to easily bubble up errors from low level calls when they fail. */ function reRevert(bytes memory err) pure { - assembly ("memory-safe") { + /// @solidity memory-safe-assembly + assembly { revert(add(err, 32), mload(err)) } } //see Optimization.md for rationale on avoiding short-circuiting function eagerAnd(bool lhs, bool rhs) pure returns (bool ret) { - assembly ("memory-safe") { + /// @solidity memory-safe-assembly + assembly { ret := and(lhs, rhs) } } //see Optimization.md for rationale on avoiding short-circuiting function eagerOr(bool lhs, bool rhs) pure returns (bool ret) { - assembly ("memory-safe") { + /// @solidity memory-safe-assembly + assembly { ret := or(lhs, rhs) } } + +function keccak256Word(bytes32 word) pure returns (bytes32 hash) { + /// @solidity memory-safe-assembly + assembly { + mstore(SCRATCH_SPACE_PTR, word) + hash := keccak256(SCRATCH_SPACE_PTR, WORD_SIZE) + } +} + +function keccak256SliceUnchecked( + bytes memory encoded, + uint offset, + uint length +) pure returns (bytes32 hash) { + /// @solidity memory-safe-assembly + assembly { + // The length of the bytes type `length` field is that of a word in memory + let ptr := add(add(encoded, offset), WORD_SIZE) + hash := keccak256(ptr, length) + } +} + +function keccak256Cd( + bytes calldata encoded +) pure returns (bytes32 hash) { + /// @solidity memory-safe-assembly + assembly { + let freeMemory := mload(FREE_MEMORY_PTR) + calldatacopy(freeMemory, encoded.offset, encoded.length) + hash := keccak256(freeMemory, encoded.length) + } +} diff --git a/src/constants/Common.sol b/src/constants/Common.sol index 27f19d9..3f5f9f4 100644 --- a/src/constants/Common.sol +++ b/src/constants/Common.sol @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache 2 pragma solidity ^0.8.4; +//see https://docs.soliditylang.org/en/v0.8.4/internals/layout_in_memory.html +uint256 constant SCRATCH_SPACE_PTR = 0x00; uint256 constant FREE_MEMORY_PTR = 0x40; uint256 constant WORD_SIZE = 32; //we can't define _WORD_SIZE_MINUS_ONE via _WORD_SIZE - 1 because of solc restrictions diff --git a/src/libraries/BytesParsing.sol b/src/libraries/BytesParsing.sol index d794ca5..84643d6 100644 --- a/src/libraries/BytesParsing.sol +++ b/src/libraries/BytesParsing.sol @@ -9,9 +9,14 @@ library BytesParsing { error LengthMismatch(uint256 encodedLength, uint256 expectedLength); error InvalidBoolVal(uint8 val); - function checkBound(uint offset, uint length) internal pure { - if (offset > length) - revert OutOfBounds(offset, length); + /** + * Implements runtime check of logic that accesses memory. + * @param pastTheEndOffset The offset past the end relative to the accessed memory fragment. + * @param length The length of the memory fragment accessed. + */ + function checkBound(uint pastTheEndOffset, uint length) internal pure { + if (pastTheEndOffset > length) + revert OutOfBounds(pastTheEndOffset, length); } function checkLength(uint encodedLength, uint expectedLength) internal pure { diff --git a/test/Keccak.t.sol b/test/Keccak.t.sol new file mode 100644 index 0000000..d232233 --- /dev/null +++ b/test/Keccak.t.sol @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: Apache 2 + +// forge test --match-contract TestKeccak + +pragma solidity ^0.8.0; + +import "forge-std/Test.sol"; +import { keccak256Word, keccak256SliceUnchecked, keccak256Cd } from "../src/Utils.sol"; + +contract TestKeccak is Test { + using { keccak256Word } for bytes32; + using { keccak256SliceUnchecked } for bytes; + + function test_bytesShouldHashTheSame(bytes calldata data) public { + bytes32 hash = data.keccak256SliceUnchecked(0, data.length); + bytes32 hashCd = keccak256Cd(data); + bytes32 expectedHash = keccak256(abi.encodePacked(data)); + assertEq(hash, expectedHash); + assertEq(hashCd, expectedHash); + } + + function test_bytesSubArrayEndShouldHashTheSame(bytes calldata data, uint seed) public { + vm.assume(data.length > 0); + uint length = seed % data.length; + bytes calldata slice = data[0 : length]; + + bytes32 hash = data.keccak256SliceUnchecked(0, length); + bytes32 hashCd = keccak256Cd(slice); + + bytes32 expectedHash = keccak256(abi.encodePacked(slice)); + assertEq(hash, expectedHash); + assertEq(hashCd, expectedHash); + } + + function test_bytesSubArrayStartShouldHashTheSame(bytes calldata data, uint seed) public { + vm.assume(data.length > 0); + uint start = seed % data.length; + bytes calldata slice = data[start : data.length]; + + bytes32 hash = data.keccak256SliceUnchecked(start, data.length - start); + bytes32 hashCd = keccak256Cd(slice); + + bytes32 expectedHash = keccak256(abi.encodePacked(slice)); + assertEq(hash, expectedHash); + assertEq(hashCd, expectedHash); + } + + function test_bytesSubArrayStartEndShouldHashTheSame(bytes calldata data, uint seed) public { + vm.assume(data.length > 0); + uint end = bound(seed, 1, data.length); + uint start = uint(keccak256(abi.encodePacked(seed))) % end; + bytes calldata slice = data[start : end]; + + bytes32 hash = data.keccak256SliceUnchecked(start, end - start); + bytes32 hashCd = keccak256Cd(slice); + + bytes32 expectedHash = keccak256(abi.encodePacked(slice)); + assertEq(hash, expectedHash); + assertEq(hashCd, expectedHash); + } + + function test_wordShouldHashTheSame(bytes32 data) public { + bytes32 hash = data.keccak256Word(); + assertEq(hash, keccak256(abi.encodePacked(data))); + } +} \ No newline at end of file diff --git a/test/generated/BytesParsingTestWrapper.sol b/test/generated/BytesParsingTestWrapper.sol index a246dc3..c4586d7 100644 --- a/test/generated/BytesParsingTestWrapper.sol +++ b/test/generated/BytesParsingTestWrapper.sol @@ -6,8 +6,8 @@ import "wormhole-sdk/libraries/BytesParsing.sol"; // This file was auto-generated by wormhole-solidity-sdk gen/libraryTestWrapper.ts contract BytesParsingTestWrapper { - function checkBound(uint offset, uint length) external pure { - BytesParsing.checkBound(offset, length); + function checkBound(uint pastTheEndOffset, uint length) external pure { + BytesParsing.checkBound(pastTheEndOffset, length); } function checkLength(uint encodedLength, uint expectedLength) external pure {