import { bufferToHex, keccak256 } from "ethereumjs-util";

/**
 * Merkle tree over a set of byte-string leaves, using sorted-pair hashing.
 *
 * Leaves are copied, sorted with `Buffer.compare`, and deduplicated before the
 * tree is built, so the root does not depend on input order or repeated leaves.
 * Each parent node is `keccak256` of its two children concatenated in ascending
 * byte order. A node with no sibling (the last node of an odd-length layer) is
 * promoted to the next layer unchanged.
 *
 * Leaves are used as-is; they are not hashed by this class.
 * NOTE(review): presumably callers pass already-hashed leaves — confirm, since no
 * leaf/inner-node domain separation is applied here.
 */
export default class MerkleTree {
  /** Sorted, deduplicated leaves; this array is also `layers[0]`. */
  private readonly elements: Buffer[];
  /** Maps a leaf's 0x-prefixed hex encoding to its index in `elements`. */
  private readonly bufferElementPositionIndex: Map<string, number>;
  /** `layers[0]` is the leaf layer; the last layer contains only the root. */
  private readonly layers: Buffer[][];

  /**
   * @param elements Leaves of the tree. The caller's array is not mutated.
   * @throws Error("empty tree") when `elements` is empty.
   */
  constructor(elements: readonly Buffer[]) {
    // Sort a copy so duplicates become adjacent, then drop them.
    this.elements = MerkleTree.bufDedup([...elements].sort(Buffer.compare));

    this.bufferElementPositionIndex = new Map(
      this.elements.map((el, index): [string, number] => [bufferToHex(el), index]),
    );

    this.layers = this.getLayers(this.elements);
  }

  /**
   * Builds every layer from the given leaves up to the single-node root layer.
   *
   * @param elements Leaf layer; it becomes `layers[0]` by reference.
   * @returns Layers ordered from leaves (index 0) to root (last index).
   * @throws Error("empty tree") when `elements` is empty.
   */
  getLayers(elements: Buffer[]): Buffer[][] {
    if (elements.length === 0) {
      throw new Error("empty tree");
    }

    const layers: Buffer[][] = [elements];
    let current = elements;
    // Keep folding pairs until only the root remains.
    while (current.length > 1) {
      current = this.getNextLayer(current);
      layers.push(current);
    }
    return layers;
  }

  /**
   * Computes the parent layer of `elements` by hashing adjacent pairs.
   *
   * @param elements A non-root layer.
   * @returns A layer of length `ceil(elements.length / 2)`.
   */
  getNextLayer(elements: Buffer[]): Buffer[] {
    const layer: Buffer[] = [];
    elements.forEach((el, idx) => {
      // Only even indices start a pair; odd indices are consumed as the right sibling.
      if (idx % 2 !== 0) {
        return;
      }
      const sibling = elements[idx + 1];
      // An unpaired trailing node is promoted unchanged rather than hashed.
      layer.push(sibling === undefined ? el : MerkleTree.combinedHash(el, sibling));
    });
    return layer;
  }

  /**
   * Hashes two nodes into their parent: `keccak256` of both buffers concatenated
   * in ascending `Buffer.compare` order, so argument order does not matter.
   *
   * If either argument is missing (falsy, e.g. from an untyped JS caller), the
   * other one is returned unhashed.
   */
  static combinedHash(first: Buffer, second: Buffer): Buffer {
    if (!first) {
      return second;
    }
    if (!second) {
      return first;
    }
    return keccak256(MerkleTree.sortAndConcat(first, second));
  }

  /** Returns the root node of the tree. */
  getRoot(): Buffer {
    const topLayer = this.layers[this.layers.length - 1];
    const root = topLayer !== undefined ? topLayer[0] : undefined;
    if (root === undefined) {
      // Unreachable in practice: the constructor rejects empty input.
      throw new Error("empty tree");
    }
    return root;
  }

  /** Returns the root as a 0x-prefixed hex string. */
  getHexRoot(): string {
    return bufferToHex(this.getRoot());
  }

  /**
   * Builds the inclusion proof for a leaf.
   *
   * The proof lists sibling nodes from the leaf layer upward. Levels where the
   * node has no sibling (it was promoted unchanged) contribute no entry.
   *
   * @param el A leaf that was passed to the constructor.
   * @returns Sibling nodes, leaf level first.
   * @throws Error("Element does not exist in Merkle tree") when `el` is not a leaf.
   */
  getProof(el: Buffer): Buffer[] {
    const position = this.bufferElementPositionIndex.get(bufferToHex(el));
    if (position === undefined) {
      throw new Error("Element does not exist in Merkle tree");
    }

    const proof: Buffer[] = [];
    let idx = position;
    for (const layer of this.layers) {
      const pairElement = MerkleTree.getPairElement(idx, layer);
      if (pairElement) {
        proof.push(pairElement);
      }
      // The parent of node `idx` sits at `floor(idx / 2)` in the next layer.
      idx = Math.floor(idx / 2);
    }
    return proof;
  }

  /**
   * Same as {@link MerkleTree.getProof}, with each node rendered as a
   * 0x-prefixed hex string.
   */
  getHexProof(el: Buffer): string[] {
    const proof = this.getProof(el);
    return MerkleTree.bufArrToHexArr(proof);
  }

  /** Returns the sibling of node `idx` in `layer`, or `null` if it has none. */
  private static getPairElement(idx: number, layer: Buffer[]): Buffer | null {
    const pairIdx = idx % 2 === 0 ? idx + 1 : idx - 1;
    const pair = layer[pairIdx];
    return pair === undefined ? null : pair;
  }

  /**
   * Removes consecutive byte-equal buffers.
   *
   * This yields a fully deduplicated list only when the input is sorted.
   */
  private static bufDedup(elements: Buffer[]): Buffer[] {
    const unique: Buffer[] = [];
    for (const el of elements) {
      const last = unique[unique.length - 1];
      if (last === undefined || !last.equals(el)) {
        unique.push(el);
      }
    }
    return unique;
  }

  /**
   * Converts buffers to 0x-prefixed lowercase hex strings.
   *
   * @throws Error("Array is not an array of buffers") if any entry is not a Buffer.
   */
  private static bufArrToHexArr(arr: Buffer[]): string[] {
    if (arr.some((el) => !Buffer.isBuffer(el))) {
      throw new Error("Array is not an array of buffers");
    }
    return arr.map((el) => "0x" + el.toString("hex"));
  }

  /** Concatenates the buffers in ascending `Buffer.compare` order. */
  private static sortAndConcat(...args: Buffer[]): Buffer {
    return Buffer.concat([...args].sort(Buffer.compare));
  }
}