From 6e19d7db3b2d7fdd2d8ac0dddf8e1ae4722f53f9 Mon Sep 17 00:00:00 2001 From: Yi Sun Date: Mon, 14 Feb 2022 00:52:52 -0600 Subject: [PATCH] input_gen: Generate proofs for punks --- scripts/input_gen/generate_punk_inputs.sh | 6 ++ .../generate_storage_proof_inputs.py | 69 +++++++++++-------- scripts/input_gen/generate_tx_inputs.sh | 7 ++ scripts/input_gen/generate_tx_proof_inputs.py | 65 ++++++++++------- 4 files changed, 91 insertions(+), 56 deletions(-) create mode 100755 scripts/input_gen/generate_punk_inputs.sh create mode 100755 scripts/input_gen/generate_tx_inputs.sh diff --git a/scripts/input_gen/generate_punk_inputs.sh b/scripts/input_gen/generate_punk_inputs.sh new file mode 100755 index 0000000..604e2fc --- /dev/null +++ b/scripts/input_gen/generate_punk_inputs.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +for PUNK_IDX in {0..9999} +do + python generate_storage_proof_inputs.py --storage --punk_slot "$PUNK_IDX" --storage_file_str inputs/input_punk_pf"$PUNK_IDX".json +done diff --git a/scripts/input_gen/generate_storage_proof_inputs.py b/scripts/input_gen/generate_storage_proof_inputs.py index 4dc899b..bb5618f 100644 --- a/scripts/input_gen/generate_storage_proof_inputs.py +++ b/scripts/input_gen/generate_storage_proof_inputs.py @@ -27,7 +27,7 @@ def serialize_hex(val_hex): val_arr = [int(nib, 16) for nib in val_hex] return val_arr -def gen_proof_input(proof, root, key, value, maxValueHexLen, maxDepth=None): +def gen_proof_input(proof, root, key, value, maxValueHexLen, maxDepth=None, debug=False): LEAF_RLP_HEXS_LEN = 74 + maxValueHexLen NODE_RLP_HEXS_LEN = 1064 EXT_RLP_HEXS_LEN = 4 + 2 + 64 + 2 + 64 @@ -63,7 +63,8 @@ def gen_proof_input(proof, root, key, value, maxValueHexLen, maxDepth=None): nhash = keccak256(node[2:]) node_decode = mpt.node.Node.decode(bytearray.fromhex(node[2:])) - print(idx, node_decode, nhash, node) + if debug: + print(idx, node_decode, nhash, node) if type(node_decode) is mpt.node.Node.Leaf: rlp_prefix = node[2:4] curr_idx = 4 @@ -117,9 +118,10 @@ 
def gen_proof_input(proof, root, key, value, maxValueHexLen, maxDepth=None): leafRlpHexs = serialize_hex(node[2:]) leafRlpHexs = leafRlpHexs + [0 for x in range(LEAF_RLP_HEXS_LEN - len(leafRlpHexs))] - - print(idx, 'Leaf', node_decode.path._data.hex(), node_decode.data.hex()) - print(node_decode.path) + + if debug: + print(idx, 'Leaf', node_decode.path._data.hex(), node_decode.data.hex()) + print(node_decode.path) elif type(node_decode) is mpt.node.Node.Branch: rlp_prefix = node[2:4] curr_idx = 4 @@ -144,9 +146,10 @@ def gen_proof_input(proof, root, key, value, maxValueHexLen, maxDepth=None): nodeRlpHexs.append(node_rlp) nodeTypes.append(0) - print(idx, 'Branch', nhash, node_decode.encode().hex(), node_decode.data.hex()) - for idx2, b in enumerate(node_decode.branches): - print(idx, 'Branch', idx2, b.hex()) + if debug: + print(idx, 'Branch', nhash, node_decode.encode().hex(), node_decode.data.hex()) + for idx2, b in enumerate(node_decode.branches): + print(idx, 'Branch', idx2, b.hex()) elif type(node_decode) is mpt.node.Node.Extension: rlp_prefix = node[2:4] curr_idx = 4 @@ -203,12 +206,12 @@ def gen_proof_input(proof, root, key, value, maxValueHexLen, maxDepth=None): nodeRefHexLen.append(temp) temp = serialize_hex(node[2:]) - print('LENGTH: {}'.format(len(temp))) temp = temp + [0 for idx in range(NODE_RLP_HEXS_LEN - len(temp))] nodeRlpHexs.append(temp) nodeTypes.append(1) - print(idx, 'Extension', nhash, node_decode.encode().hex()) - print(idx, 'Extension', node_decode.path._data.hex(), node_decode.next_ref.hex()) + if debug: + print(idx, 'Extension', nhash, node_decode.encode().hex()) + print(idx, 'Extension', node_decode.path._data.hex(), node_decode.next_ref.hex()) if maxDepth is not None: for idx in range(maxDepth - len(proof)): @@ -247,7 +250,7 @@ def gen_proof_input(proof, root, key, value, maxValueHexLen, maxDepth=None): } return ret -def get_storage_pf(punk_pfs, slot=None): +def get_storage_pf(punk_pfs, slot=None, debug=False): punk_pf = None for x in 
punk_pfs['result']['storageProof']: if x['key'] == slot: @@ -256,18 +259,21 @@ def get_storage_pf(punk_pfs, slot=None): key = keccak256(punk_pf['key'][2:]) value = punk_pf['value'][2:] + if len(value) % 2 == 1: + value = '0' + value proof = punk_pf['proof'] root = punk_pfs['result']['storageHash'][2:] - print('addr: {}'.format(punk_pfs['result']['address'])) - print('stor root: {}'.format(root)) - print('key: {}'.format(key)) - print('value: {}'.format(value)) + if debug: + print('addr: {}'.format(punk_pfs['result']['address'])) + print('stor root: {}'.format(root)) + print('key: {}'.format(key)) + print('value: {}'.format(value)) - pf = gen_proof_input(proof, root, key, rlp.encode(bytearray.fromhex(value)).hex(), 114) + pf = gen_proof_input(proof, root, key, rlp.encode(bytearray.fromhex(value)).hex(), 114, debug=debug) return pf -def get_addr_pf(punk_pfs): +def get_addr_pf(punk_pfs, debug=False): acct_pf = punk_pfs['result']['accountProof'] key = keccak256(punk_pfs['result']['address'][2:]) nonce = punk_pfs['result']['nonce'][2:] @@ -279,25 +285,25 @@ def get_addr_pf(punk_pfs): int(balance, 16), bytearray.fromhex(storageHash), bytearray.fromhex(codeHash)]) - - print('key: {}'.format(key)) - print('value: {}'.format(addr_rlp.hex())) - print('nonce: {}'.format(nonce)) - print('balance: {}'.format(balance)) - print('storageHash: {}'.format(storageHash)) - print('codeHash: {}'.format(codeHash)) + if debug: + print('key: {}'.format(key)) + print('value: {}'.format(addr_rlp.hex())) + print('nonce: {}'.format(nonce)) + print('balance: {}'.format(balance)) + print('storageHash: {}'.format(storageHash)) + print('codeHash: {}'.format(codeHash)) - pf = gen_proof_input(acct_pf, keccak256(acct_pf[0][2:]), key, addr_rlp.hex(), 228) + pf = gen_proof_input(acct_pf, keccak256(acct_pf[0][2:]), key, addr_rlp.hex(), 228, debug=debug) return pf parser = argparse.ArgumentParser() parser.add_argument('--debug', action='store_true', default=False) parser.add_argument('--addr', 
action='store_true', default=False) -parser.add_argument('--addr_file_str', type=str, default='input_address_proof.json') +parser.add_argument('--addr_file_str', type=str, default='inputs/input_address_proof.json') parser.add_argument('--storage', action='store_true', default=False) -parser.add_argument('--storage_file_str', type=str, default='input_storage_proof.json') +parser.add_argument('--storage_file_str', type=str, default='inputs/input_storage_proof.json') parser.add_argument('--slot', type=int, default=10) parser.add_argument('--punk_slot', type=int, default=0) args = parser.parse_args() @@ -307,7 +313,7 @@ def main(): punk_block = json.loads(f.read()) if args.addr: - addr_pf = get_addr_pf(punk_block) + addr_pf = get_addr_pf(punk_block, debug=args.debug) pf_str = pprint.pformat(addr_pf, width=100, compact=True).replace("'", '"') with open(args.addr_file_str, 'w') as f: f.write(pf_str) @@ -335,7 +341,10 @@ def main(): y = ''.join(['0' for idx in range(64 - len(y))]) + y x = x + y slot = keccak256(x) - storage_pf = get_storage_pf(punk_block, slot=slot) + + storage_pf = get_storage_pf(punk_block, slot=slot, debug=args.debug) + print('Punk {:5} depth {:3}'.format(args.punk_slot, storage_pf['depth'])) + pf_str = pprint.pformat(storage_pf, width=100, compact=True).replace("'", '"') with open(args.storage_file_str, 'w') as f: f.write(pf_str) diff --git a/scripts/input_gen/generate_tx_inputs.sh b/scripts/input_gen/generate_tx_inputs.sh new file mode 100755 index 0000000..d3f3553 --- /dev/null +++ b/scripts/input_gen/generate_tx_inputs.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +for TX_IDX in {0..196} +do + python generate_tx_proof_inputs.py --tx_idx "$TX_IDX" --max_depth 5 --max_key_len 6 --max_val_len 234 --file_str inputs/input_tx_pf"$TX_IDX".json + echo "$TX_IDX" +done diff --git a/scripts/input_gen/generate_tx_proof_inputs.py b/scripts/input_gen/generate_tx_proof_inputs.py index 161d053..c3f19cd 100644 --- a/scripts/input_gen/generate_tx_proof_inputs.py +++ 
b/scripts/input_gen/generate_tx_proof_inputs.py @@ -33,7 +33,7 @@ def construct_mpt(d): trie.update(key, d[key]) return trie, storage -def get_proof(storage, proof, node, path): +def get_mpt_proof(storage, proof, node, path): if len(path) == 0: proof = proof + [node] return proof @@ -51,7 +51,7 @@ def get_proof(storage, proof, node, path): else: node = node.next_ref node = mpt.node.Node.decode(node) - return get_proof(storage, proof, node, rest_path) + return get_mpt_proof(storage, proof, node, rest_path) elif type(node) is mpt.node.Node.Branch: branch = node.branches[path.at(0)] proof = proof + [node] @@ -61,9 +61,9 @@ def get_proof(storage, proof, node, path): node = branch node = mpt.node.Node.decode(node) if len(branch) > 0: - return get_proof(storage, proof, node, path.consume(1)) + return get_mpt_proof(storage, proof, node, path.consume(1)) -def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHexLen): +def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHexLen, debug=False): maxNodeRefLen = 64 maxLeafRlpHexLen = 4 + (maxKeyHexLen + 2) + 4 + maxValueHexLen maxBranchRlpHexLen = 1064 + 2 + maxValueHexLen @@ -116,7 +116,8 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex lhash = keccak256(last.encode().hex()) if type(last) is mpt.node.Node.Branch: leafRlpHexs = [0 for idx in range(maxLeafRlpHexLen)] - print('Branch', len(last.data.hex())) + if debug: + print('Branch', len(last.data.hex())) isTerminalBranch = 1 last = last.encode().hex() @@ -155,11 +156,13 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex terminalBranchVtValueHexLen = 2 * int(str_len, 16) node_rlp = serialize_hex(last) - print('LENGTH: {}'.format(len(node_rlp))) + if debug: + print('LENGTH: {}'.format(len(node_rlp))) node_rlp = node_rlp + [0 for x in range(maxBranchRlpHexLen - len(node_rlp))] terminalBranchRlpHexs.append(node_rlp) else: - print('Leaf', len(last.data.hex()), 
last.encode().hex(), last.path) + if debug: + print('Leaf', len(last.data.hex()), last.encode().hex(), last.path) terminalBranchNodeRefHexLen = [0 for idx in range(16)] terminalBranchRlpHexs = [0 for idx in range(maxBranchRlpHexLen)] @@ -223,7 +226,8 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex leafValueLenHexLen = 2 * int(str_len, 16) leafRlpHexs = serialize_hex(last) - print('LENGTH: {}'.format(len(leafRlpHexs))) + if debug: + print('LENGTH: {}'.format(len(leafRlpHexs))) leafRlpHexs = leafRlpHexs + [0 for x in range(maxLeafRlpHexLen - len(leafRlpHexs))] for idx, node in enumerate(proof[:-1]): @@ -231,10 +235,12 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex node = node.encode().hex() node_decode = mpt.node.Node.decode(bytearray.fromhex(node)) - print(idx, node_decode, nhash, node) + if debug: + print(idx, node_decode, nhash, node) if type(node_decode) is mpt.node.Node.Branch: - for idx, x in enumerate(node_decode.branches): - print(idx, x.hex()) + if debug: + for idx, x in enumerate(node_decode.branches): + print(idx, x.hex()) rlp_prefix = node[:2] curr_idx = 2 if int(rlp_prefix, 16) <= int('f7', 16): @@ -273,14 +279,17 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex nodeVtValueHexLen.append(2 * int(str_len, 16)) node_rlp = serialize_hex(node) - print('LENGTH: {}'.format(len(node_rlp))) + if debug: + print('LENGTH: {}'.format(len(node_rlp))) node_rlp = node_rlp + [0 for x in range(maxBranchRlpHexLen - len(node_rlp))] nodeRlpHexs.append(node_rlp) nodeTypes.append(0) - print('Branch', len(node_decode.data.hex()), nhash, node_decode.encode().hex(), node_decode.data.hex()) + if debug: + print('Branch', len(node_decode.data.hex()), nhash, node_decode.encode().hex(), node_decode.data.hex()) elif type(node_decode) is mpt.node.Node.Extension: - print(node_decode.path) + if debug: + print(node_decode.path) rlp_prefix = node[:2] curr_idx = 2 if int(rlp_prefix, 
16) <= int('f7', 16): @@ -313,8 +322,6 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex rlp_prefix = node[curr_idx: curr_idx + 2] curr_idx = curr_idx + 2 temp = [] - print('aa', rlp_prefix) - print(node[curr_idx:]) if int(rlp_prefix, 16) <= int('b7', 16): temp.append(2 * (int(rlp_prefix, 16) - int('80', 16))) curr_idx = curr_idx + temp[-1] @@ -325,7 +332,8 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex nodeVtValueHexLen.append(0) temp = serialize_hex(node) - print('LENGTH: {}'.format(len(temp))) + if debug: + print('LENGTH: {}'.format(len(temp))) temp = temp + [0 for idx in range(maxBranchRlpHexLen - len(temp))] nodeRlpHexs.append(temp) nodeTypes.append(1) @@ -377,7 +385,7 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex "depth": depth } return ret -def get_pf(block, tx_idx): +def get_pf(block, tx_idx, max_depth=None, max_key_len=64, max_val_len=234, debug=False): block = block['result'] block_hash = block['hash'] @@ -402,7 +410,8 @@ def get_pf(block, tx_idx): tx_list = block['transactions'] raw_tx_dict = {} - print('{} tx in block'.format(len(tx_list))) + if debug: + print('{} tx in block'.format(len(tx_list))) for idx, tx in enumerate(tx_list): if tx['type'] == '0x0': raw_tx = [int(tx['nonce'], 16), @@ -439,31 +448,35 @@ def get_pf(block, tx_idx): raw_tx_dict[rlp.encode(idx)] = bytearray.fromhex('02') + rlp.encode(raw_tx) else: print('type not handled: {}'.format(tx['type'])) - print(idx, rlp.encode(idx).hex(), tx['type'], tx['hash'], len(rlp.encode(raw_tx).hex())) + if debug: + print(idx, rlp.encode(idx).hex(), tx['type'], tx['hash'], len(rlp.encode(raw_tx).hex())) trie, storage = construct_mpt(raw_tx_dict) root = mpt.node.Node.decode(storage[trie._root]) path = mpt.nibble_path.NibblePath(rlp.encode(tx_idx)) - pf = get_proof(storage, [], root, path) - print(rlp.encode(tx_idx).hex(), pf) + + pf = get_mpt_proof(storage, [], root, path) value = 
trie.get(rlp.encode(tx_idx)).hex() - print(root) - ret = gen_proof_input(pf, keccak256(root.encode().hex()), rlp.encode(tx_idx).hex(), value, 6, 64, 234) + ret = gen_proof_input(pf, keccak256(root.encode().hex()), rlp.encode(tx_idx).hex(), value, max_depth, max_key_len, max_val_len, debug=debug) return ret parser = argparse.ArgumentParser() parser.add_argument('--debug', action='store_true', default=False) parser.add_argument('--tx_idx', type=int, default=2) -parser.add_argument('--file_str', type=str, default='input_tx_proof.json') +parser.add_argument('--file_str', type=str, default='inputs/input_tx_proof.json') +parser.add_argument('--max_depth', type=int, default=5) +parser.add_argument('--max_key_len', type=int, default=6) +parser.add_argument('--max_val_len', type=int, default=234) args = parser.parse_args() def main(): with open('punk_block.json', 'r') as f: block = json.loads(f.read()) - pf = get_pf(block, args.tx_idx) + pf = get_pf(block, args.tx_idx, max_depth=args.max_depth, max_key_len=args.max_key_len, max_val_len=args.max_val_len, debug=args.debug) + pf_str = pprint.pformat(pf, width=100, compact=True).replace("'", '"') with open(args.file_str, 'w') as f: f.write(pf_str)