Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions contracts/interfaces/IHasher.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity 0.8.27;

/**
* @dev IHasher. Interface for generating hashes. Specifically used for Merkle Tree hashing.
*/
interface IHasher {
/**
* @dev hash2. hashes two uint256 parameters and returns the resulting hash as uint256.
* @param params The parameters array of size 2 to be hashed.
* @return The resulting hash as uint256.
*/
function hash2(uint256[2] memory params) external pure returns (uint256);

/**
* @dev hash3. hashes three uint256 parameters and returns the resulting hash as uint256.
* @param params The parameters array of size 3 to be hashed.
* @return The resulting hash as uint256.
*/
function hash3(uint256[3] memory params) external pure returns (uint256);
}
41 changes: 31 additions & 10 deletions contracts/lib/SmtLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pragma solidity 0.8.27;

import {PoseidonUnit2L, PoseidonUnit3L} from "./Poseidon.sol";
import {ArrayUtils} from "./ArrayUtils.sol";
import {IHasher} from "../interfaces/IHasher.sol";

/// @title A sparse merkle tree implementation, which keeps tree history.
// Note that this SMT implementation can manage duplicated roots in the history,
Expand Down Expand Up @@ -53,7 +54,8 @@ library SmtLib {
// of the SMT library to add new Data struct fields without shifting down
// storage of upgradable contracts that use this struct as a state variable
// (see https://docs.openzeppelin.com/upgrades-plugins/1.x/writing-upgradeable#storage-gaps)
uint256[45] __gap;
uint256[44] __gap;
IHasher hasher;
}

/**
Expand Down Expand Up @@ -136,6 +138,16 @@ library SmtLib {
_;
}

/**
* @dev Sets custom hashers for the SMT. MUST be called before any other SMT operations.
* @param customHasher IHasher implementation to be used for hashing.
*/
function setHasher(Data storage self, IHasher customHasher) external {
require(address(customHasher) != address(0), "Invalid hasher");
require(self.rootEntries.length == 1, "Hasher must be set before SMT usage");
self.hasher = customHasher;
}

/**
* @dev Add a leaf to the SMT
* @param i Index of a leaf
Expand Down Expand Up @@ -526,16 +538,16 @@ library SmtLib {
if (newLeafBitAtDepth) {
newNodeMiddle = Node({
nodeType: NodeType.MIDDLE,
childLeft: _getNodeHash(oldLeaf),
childRight: _getNodeHash(newLeaf),
childLeft: _getNodeHash(self, oldLeaf),
childRight: _getNodeHash(self, newLeaf),
index: 0,
value: 0
});
} else {
newNodeMiddle = Node({
nodeType: NodeType.MIDDLE,
childLeft: _getNodeHash(newLeaf),
childRight: _getNodeHash(oldLeaf),
childLeft: _getNodeHash(self, newLeaf),
childRight: _getNodeHash(self, oldLeaf),
index: 0,
value: 0
});
Expand All @@ -546,7 +558,7 @@ library SmtLib {
}

function _addNode(Data storage self, Node memory node) internal returns (uint256) {
uint256 nodeHash = _getNodeHash(node);
uint256 nodeHash = _getNodeHash(self, node);
// We don't have any guarantees if the hash function attached is good enough.
// So, if the node hash already exists, we need to check
// if the node in the tree exactly matches the one we are trying to add.
Expand All @@ -563,13 +575,22 @@ library SmtLib {
return nodeHash;
}

function _getNodeHash(Node memory node) internal pure returns (uint256) {
function _getNodeHash(Data storage self, Node memory node) internal view returns (uint256) {
uint256 nodeHash = 0;
if (node.nodeType == NodeType.LEAF) {
uint256[3] memory params = [node.index, node.value, uint256(1)];
nodeHash = PoseidonUnit3L.poseidon(params);
if (address(self.hasher) != address(0)) {
uint256[3] memory params = [node.index, node.value, uint256(1)];
nodeHash = self.hasher.hash3(params);
} else {
uint256[3] memory params = [node.index, node.value, uint256(1)];
nodeHash = PoseidonUnit3L.poseidon(params);
}
} else if (node.nodeType == NodeType.MIDDLE) {
nodeHash = PoseidonUnit2L.poseidon([node.childLeft, node.childRight]);
if (address(self.hasher) != address(0)) {
nodeHash = self.hasher.hash2([node.childLeft, node.childRight]);
} else {
nodeHash = PoseidonUnit2L.poseidon([node.childLeft, node.childRight]);
}
}
return nodeHash; // Note: expected to return 0 if NodeType.EMPTY, which is the only option left
}
Expand Down
17 changes: 17 additions & 0 deletions contracts/lib/hash/KeccakHasher.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity 0.8.27;

import {IHasher} from "../../interfaces/IHasher.sol";

/// @title A IHasher implementation using Keccak256.
contract Keccak256Hasher is IHasher {
function hash2(uint256[2] memory params) external pure override returns (uint256) {
bytes memory encoded = abi.encode(params);
return uint256(keccak256(encoded));
}

function hash3(uint256[3] memory params) external pure override returns (uint256) {
bytes memory encoded = abi.encode(params);
return uint256(keccak256(encoded));
}
}
14 changes: 14 additions & 0 deletions contracts/test-helpers/SmtLibKeccakTestWrapper.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity 0.8.27;

import {SmtLib} from "../lib/SmtLib.sol";
import {SmtLibTestWrapper} from "./SmtLibTestWrapper.sol";
import {Keccak256Hasher} from "../lib/hash/KeccakHasher.sol";

contract SmtLibKeccakTestWrapper is SmtLibTestWrapper {
using SmtLib for SmtLib.Data;

constructor(uint256 maxDepth) SmtLibTestWrapper(maxDepth) {
smtData.setHasher(new Keccak256Hasher());
}
}
7 changes: 5 additions & 2 deletions helpers/DeployHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,11 @@ export class DeployHelper {
return stateLib;
}

async deploySmtLibTestWrapper(maxDepth: number = SMT_MAX_DEPTH): Promise<Contract> {
const contractName = "SmtLibTestWrapper";
async deploySmtLibTestWrapper(
maxDepth: number = SMT_MAX_DEPTH,
useKeccakHashing: boolean = false,
): Promise<Contract> {
const contractName = useKeccakHashing ? "SmtLibKeccakTestWrapper" : "SmtLibTestWrapper";

this.log("deploying poseidons...");
const [poseidon2Elements, poseidon3Elements] = await deployPoseidons([2, 3]);
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@
"lint:contracts": "npx solhint contracts/**/*.sol",
"prettier:contracts": "prettier --write --plugin=prettier-plugin-solidity 'contracts/**/*.sol'",
"slither": "slither .",
"postinstall": "patch-package"
"postinstall": "npx patch-package"
},
"overrides": {
"ws": "^8.17.1",
Expand Down
Loading