[Fix] I-01 Use safe casting for the leaf index verification (#222)

* Use safe casting for the leaf index verification

* use return variable for the cast uint32
This commit is contained in:
The Dark Jester
2024-10-31 09:38:06 -07:00
committed by GitHub
parent 9a588d80bb
commit 4a9b4ec096
4 changed files with 61 additions and 3 deletions

View File

@@ -11,6 +11,12 @@ import { Utils } from "../../lib/Utils.sol";
library SparseMerkleTreeVerifier {
using Utils for *;
/**
* @dev Value doesn't fit in a uint of `bits` size.
* @dev This is based on OpenZeppelin's SafeCast library.
*/
error SafeCastOverflowedUintDowncast(uint8 bits, uint256 value);
/**
* @dev Custom error for when the leaf index is out of bounds.
*/
@@ -30,7 +36,7 @@ library SparseMerkleTreeVerifier {
uint32 _leafIndex,
bytes32 _root
) internal pure returns (bool) {
uint32 maxAllowedIndex = uint32((2 ** _proof.length) - 1);
uint32 maxAllowedIndex = safeCastToUint32((2 ** _proof.length) - 1);
if (_leafIndex > maxAllowedIndex) {
revert LeafIndexOutOfBounds(_leafIndex, maxAllowedIndex);
}
@@ -46,4 +52,17 @@ library SparseMerkleTreeVerifier {
}
return node == _root;
}
/**
* @notice Tries to safely cast to uint32.
* @param _value The value being cast to uint32.
* @return castUint32 Returns a uint32 safely cast.
* @dev This is based on OpenZeppelin's SafeCast library.
*/
function safeCastToUint32(uint256 _value) internal pure returns (uint32 castUint32) {
if (_value > type(uint32).max) {
revert SafeCastOverflowedUintDowncast(32, _value);
}
castUint32 = uint32(_value);
}
}

View File

@@ -21,6 +21,10 @@ contract TestSparseMerkleTreeVerifier {
return Utils._efficientKeccak(_left, _right);
}
function testSafeCastToUint32(uint256 _value) external pure returns (uint32) {
return SparseMerkleTreeVerifier.safeCastToUint32(_value);
}
function getLeafHash(
address _from,
address _to,

View File

@@ -1,5 +1,7 @@
import { ethers } from "hardhat";
export const MAX_UINT32 = BigInt(2 ** 32 - 1);
export const MAX_UINT33 = BigInt(2 ** 33 - 1);
export const HASH_ZERO = ethers.ZeroHash;
export const ADDRESS_ZERO = ethers.ZeroAddress;
export const HASH_WITHOUT_ZERO_FIRST_BYTE = "0xf887bbc07b0e849fb625aafadf4cb6b65b98e492fbb689705312bf1db98ead7f";

View File

@@ -3,9 +3,9 @@ import { loadFixture } from "@nomicfoundation/hardhat-network-helpers";
import { expect } from "chai";
import { ethers } from "hardhat";
import { TestSparseMerkleTreeVerifier } from "../../../typechain-types";
import { MESSAGE_FEE, MESSAGE_VALUE_1ETH } from "../../common/constants";
import { MAX_UINT32, MAX_UINT33, MESSAGE_FEE, MESSAGE_VALUE_1ETH } from "../../common/constants";
import { deployFromFactory } from "../../common/deployment";
import { expectRevertWithCustomError } from "contracts/test/common/helpers";
import { expectRevertWithCustomError, generateRandomBytes, range } from "contracts/test/common/helpers";
describe("SparseMerkleTreeVerifier", () => {
let sparseMerkleTreeVerifier: TestSparseMerkleTreeVerifier;
@@ -180,5 +180,38 @@ describe("SparseMerkleTreeVerifier", () => {
[invalidLeafIndex, 2 ** proof.length - 1],
);
});
it("Should revert if casting to a value higher than 32 bit uint", async () => {
const testValue = MAX_UINT32 + 1n;
await expectRevertWithCustomError(
sparseMerkleTreeVerifier,
sparseMerkleTreeVerifier.testSafeCastToUint32(testValue),
"SafeCastOverflowedUintDowncast",
[32, testValue],
);
});
it("Should revert if proof length results in casting to a value higher than 32 bit uint", async () => {
const proof = range(0, 32).map(() => generateRandomBytes(32));
await expectRevertWithCustomError(
sparseMerkleTreeVerifier,
sparseMerkleTreeVerifier.verifyMerkleProof(generateRandomBytes(32), proof, 25, generateRandomBytes(32)),
"SafeCastOverflowedUintDowncast",
[32, MAX_UINT33],
);
});
it("Should cast if casting to max uint 32", async () => {
const testValue = MAX_UINT32;
expect(await sparseMerkleTreeVerifier.testSafeCastToUint32(testValue)).to.equal(testValue);
});
it("Should cast if casting to lower than max uint 32", async () => {
const testValue = MAX_UINT32 - 1n;
expect(await sparseMerkleTreeVerifier.testSafeCastToUint32(testValue)).to.equal(testValue);
});
});
});