input_gen: Generate proofs for punks

This commit is contained in:
Yi Sun
2022-02-14 00:52:52 -06:00
parent 7dfbdae986
commit 6e19d7db3b
4 changed files with 91 additions and 56 deletions

View File

@@ -0,0 +1,6 @@
#!/bin/bash
# Generate one storage-proof input file per punk index, 0 through 9999 inclusive.
# Each iteration runs generate_storage_proof_inputs.py in --storage mode for a
# single --punk_slot and writes the result to its own JSON file under inputs/.
for ((PUNK_IDX = 0; PUNK_IDX <= 9999; PUNK_IDX++))
do
    # Per-punk output path, e.g. inputs/input_punk_pf42.json
    OUT_FILE=inputs/input_punk_pf"$PUNK_IDX".json
    python generate_storage_proof_inputs.py --storage --punk_slot "$PUNK_IDX" --storage_file_str "$OUT_FILE"
done

View File

@@ -27,7 +27,7 @@ def serialize_hex(val_hex):
val_arr = [int(nib, 16) for nib in val_hex]
return val_arr
def gen_proof_input(proof, root, key, value, maxValueHexLen, maxDepth=None):
def gen_proof_input(proof, root, key, value, maxValueHexLen, maxDepth=None, debug=False):
LEAF_RLP_HEXS_LEN = 74 + maxValueHexLen
NODE_RLP_HEXS_LEN = 1064
EXT_RLP_HEXS_LEN = 4 + 2 + 64 + 2 + 64
@@ -63,7 +63,8 @@ def gen_proof_input(proof, root, key, value, maxValueHexLen, maxDepth=None):
nhash = keccak256(node[2:])
node_decode = mpt.node.Node.decode(bytearray.fromhex(node[2:]))
print(idx, node_decode, nhash, node)
if debug:
print(idx, node_decode, nhash, node)
if type(node_decode) is mpt.node.Node.Leaf:
rlp_prefix = node[2:4]
curr_idx = 4
@@ -117,9 +118,10 @@ def gen_proof_input(proof, root, key, value, maxValueHexLen, maxDepth=None):
leafRlpHexs = serialize_hex(node[2:])
leafRlpHexs = leafRlpHexs + [0 for x in range(LEAF_RLP_HEXS_LEN - len(leafRlpHexs))]
print(idx, 'Leaf', node_decode.path._data.hex(), node_decode.data.hex())
print(node_decode.path)
if debug:
print(idx, 'Leaf', node_decode.path._data.hex(), node_decode.data.hex())
print(node_decode.path)
elif type(node_decode) is mpt.node.Node.Branch:
rlp_prefix = node[2:4]
curr_idx = 4
@@ -144,9 +146,10 @@ def gen_proof_input(proof, root, key, value, maxValueHexLen, maxDepth=None):
nodeRlpHexs.append(node_rlp)
nodeTypes.append(0)
print(idx, 'Branch', nhash, node_decode.encode().hex(), node_decode.data.hex())
for idx2, b in enumerate(node_decode.branches):
print(idx, 'Branch', idx2, b.hex())
if debug:
print(idx, 'Branch', nhash, node_decode.encode().hex(), node_decode.data.hex())
for idx2, b in enumerate(node_decode.branches):
print(idx, 'Branch', idx2, b.hex())
elif type(node_decode) is mpt.node.Node.Extension:
rlp_prefix = node[2:4]
curr_idx = 4
@@ -203,12 +206,12 @@ def gen_proof_input(proof, root, key, value, maxValueHexLen, maxDepth=None):
nodeRefHexLen.append(temp)
temp = serialize_hex(node[2:])
print('LENGTH: {}'.format(len(temp)))
temp = temp + [0 for idx in range(NODE_RLP_HEXS_LEN - len(temp))]
nodeRlpHexs.append(temp)
nodeTypes.append(1)
print(idx, 'Extension', nhash, node_decode.encode().hex())
print(idx, 'Extension', node_decode.path._data.hex(), node_decode.next_ref.hex())
if debug:
print(idx, 'Extension', nhash, node_decode.encode().hex())
print(idx, 'Extension', node_decode.path._data.hex(), node_decode.next_ref.hex())
if maxDepth is not None:
for idx in range(maxDepth - len(proof)):
@@ -247,7 +250,7 @@ def gen_proof_input(proof, root, key, value, maxValueHexLen, maxDepth=None):
}
return ret
def get_storage_pf(punk_pfs, slot=None):
def get_storage_pf(punk_pfs, slot=None, debug=False):
punk_pf = None
for x in punk_pfs['result']['storageProof']:
if x['key'] == slot:
@@ -256,18 +259,21 @@ def get_storage_pf(punk_pfs, slot=None):
key = keccak256(punk_pf['key'][2:])
value = punk_pf['value'][2:]
if len(value) % 2 == 1:
value = '0' + value
proof = punk_pf['proof']
root = punk_pfs['result']['storageHash'][2:]
print('addr: {}'.format(punk_pfs['result']['address']))
print('stor root: {}'.format(root))
print('key: {}'.format(key))
print('value: {}'.format(value))
if debug:
print('addr: {}'.format(punk_pfs['result']['address']))
print('stor root: {}'.format(root))
print('key: {}'.format(key))
print('value: {}'.format(value))
pf = gen_proof_input(proof, root, key, rlp.encode(bytearray.fromhex(value)).hex(), 114)
pf = gen_proof_input(proof, root, key, rlp.encode(bytearray.fromhex(value)).hex(), 114, debug=debug)
return pf
def get_addr_pf(punk_pfs):
def get_addr_pf(punk_pfs, debug=False):
acct_pf = punk_pfs['result']['accountProof']
key = keccak256(punk_pfs['result']['address'][2:])
nonce = punk_pfs['result']['nonce'][2:]
@@ -279,25 +285,25 @@ def get_addr_pf(punk_pfs):
int(balance, 16),
bytearray.fromhex(storageHash),
bytearray.fromhex(codeHash)])
print('key: {}'.format(key))
print('value: {}'.format(addr_rlp.hex()))
print('nonce: {}'.format(nonce))
print('balance: {}'.format(balance))
print('storageHash: {}'.format(storageHash))
print('codeHash: {}'.format(codeHash))
if debug:
print('key: {}'.format(key))
print('value: {}'.format(addr_rlp.hex()))
print('nonce: {}'.format(nonce))
print('balance: {}'.format(balance))
print('storageHash: {}'.format(storageHash))
print('codeHash: {}'.format(codeHash))
pf = gen_proof_input(acct_pf, keccak256(acct_pf[0][2:]), key, addr_rlp.hex(), 228)
pf = gen_proof_input(acct_pf, keccak256(acct_pf[0][2:]), key, addr_rlp.hex(), 228, debug=debug)
return pf
parser = argparse.ArgumentParser()
parser.add_argument('--debug', action='store_true', default=False)
parser.add_argument('--addr', action='store_true', default=False)
parser.add_argument('--addr_file_str', type=str, default='input_address_proof.json')
parser.add_argument('--addr_file_str', type=str, default='inputs/input_address_proof.json')
parser.add_argument('--storage', action='store_true', default=False)
parser.add_argument('--storage_file_str', type=str, default='input_storage_proof.json')
parser.add_argument('--storage_file_str', type=str, default='inputs/input_storage_proof.json')
parser.add_argument('--slot', type=int, default=10)
parser.add_argument('--punk_slot', type=int, default=0)
args = parser.parse_args()
@@ -307,7 +313,7 @@ def main():
punk_block = json.loads(f.read())
if args.addr:
addr_pf = get_addr_pf(punk_block)
addr_pf = get_addr_pf(punk_block, debug=args.debug)
pf_str = pprint.pformat(addr_pf, width=100, compact=True).replace("'", '"')
with open(args.addr_file_str, 'w') as f:
f.write(pf_str)
@@ -335,7 +341,10 @@ def main():
y = ''.join(['0' for idx in range(64 - len(y))]) + y
x = x + y
slot = keccak256(x)
storage_pf = get_storage_pf(punk_block, slot=slot)
storage_pf = get_storage_pf(punk_block, slot=slot, debug=args.debug)
print('Punk {:5} depth {:3}'.format(args.punk_slot, storage_pf['depth']))
pf_str = pprint.pformat(storage_pf, width=100, compact=True).replace("'", '"')
with open(args.storage_file_str, 'w') as f:
f.write(pf_str)

View File

@@ -0,0 +1,7 @@
#!/bin/bash
# Generate one transaction-proof input file per transaction index, 0 through 196
# inclusive. Each iteration runs generate_tx_proof_inputs.py for a single --tx_idx
# with fixed size bounds (--max_depth 5, --max_key_len 6, --max_val_len 234) and
# writes the result to its own JSON file under inputs/.
# NOTE(review): the upper bound 196 is presumably the last transaction index in
# the block being processed -- confirm against the block data.
for TX_IDX in {0..196}
do
    # The output name must be keyed on TX_IDX. The previous version interpolated
    # the unset variable PUNK_IDX, so every iteration wrote to the same path
    # (inputs/input_tx_pf.json) and only the last proof survived.
    python generate_tx_proof_inputs.py --tx_idx "$TX_IDX" --max_depth 5 --max_key_len 6 --max_val_len 234 --file_str inputs/input_tx_pf"$TX_IDX".json
    # Progress indicator: print the index that was just processed.
    echo "$TX_IDX"
done

View File

@@ -33,7 +33,7 @@ def construct_mpt(d):
trie.update(key, d[key])
return trie, storage
def get_proof(storage, proof, node, path):
def get_mpt_proof(storage, proof, node, path):
if len(path) == 0:
proof = proof + [node]
return proof
@@ -51,7 +51,7 @@ def get_proof(storage, proof, node, path):
else:
node = node.next_ref
node = mpt.node.Node.decode(node)
return get_proof(storage, proof, node, rest_path)
return get_mpt_proof(storage, proof, node, rest_path)
elif type(node) is mpt.node.Node.Branch:
branch = node.branches[path.at(0)]
proof = proof + [node]
@@ -61,9 +61,9 @@ def get_proof(storage, proof, node, path):
node = branch
node = mpt.node.Node.decode(node)
if len(branch) > 0:
return get_proof(storage, proof, node, path.consume(1))
return get_mpt_proof(storage, proof, node, path.consume(1))
def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHexLen):
def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHexLen, debug=False):
maxNodeRefLen = 64
maxLeafRlpHexLen = 4 + (maxKeyHexLen + 2) + 4 + maxValueHexLen
maxBranchRlpHexLen = 1064 + 2 + maxValueHexLen
@@ -116,7 +116,8 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex
lhash = keccak256(last.encode().hex())
if type(last) is mpt.node.Node.Branch:
leafRlpHexs = [0 for idx in range(maxLeafRlpHexLen)]
print('Branch', len(last.data.hex()))
if debug:
print('Branch', len(last.data.hex()))
isTerminalBranch = 1
last = last.encode().hex()
@@ -155,11 +156,13 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex
terminalBranchVtValueHexLen = 2 * int(str_len, 16)
node_rlp = serialize_hex(last)
print('LENGTH: {}'.format(len(node_rlp)))
if debug:
print('LENGTH: {}'.format(len(node_rlp)))
node_rlp = node_rlp + [0 for x in range(maxBranchRlpHexLen - len(node_rlp))]
terminalBranchRlpHexs.append(node_rlp)
else:
print('Leaf', len(last.data.hex()), last.encode().hex(), last.path)
if debug:
print('Leaf', len(last.data.hex()), last.encode().hex(), last.path)
terminalBranchNodeRefHexLen = [0 for idx in range(16)]
terminalBranchRlpHexs = [0 for idx in range(maxBranchRlpHexLen)]
@@ -223,7 +226,8 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex
leafValueLenHexLen = 2 * int(str_len, 16)
leafRlpHexs = serialize_hex(last)
print('LENGTH: {}'.format(len(leafRlpHexs)))
if debug:
print('LENGTH: {}'.format(len(leafRlpHexs)))
leafRlpHexs = leafRlpHexs + [0 for x in range(maxLeafRlpHexLen - len(leafRlpHexs))]
for idx, node in enumerate(proof[:-1]):
@@ -231,10 +235,12 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex
node = node.encode().hex()
node_decode = mpt.node.Node.decode(bytearray.fromhex(node))
print(idx, node_decode, nhash, node)
if debug:
print(idx, node_decode, nhash, node)
if type(node_decode) is mpt.node.Node.Branch:
for idx, x in enumerate(node_decode.branches):
print(idx, x.hex())
if debug:
for idx, x in enumerate(node_decode.branches):
print(idx, x.hex())
rlp_prefix = node[:2]
curr_idx = 2
if int(rlp_prefix, 16) <= int('f7', 16):
@@ -273,14 +279,17 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex
nodeVtValueHexLen.append(2 * int(str_len, 16))
node_rlp = serialize_hex(node)
print('LENGTH: {}'.format(len(node_rlp)))
if debug:
print('LENGTH: {}'.format(len(node_rlp)))
node_rlp = node_rlp + [0 for x in range(maxBranchRlpHexLen - len(node_rlp))]
nodeRlpHexs.append(node_rlp)
nodeTypes.append(0)
print('Branch', len(node_decode.data.hex()), nhash, node_decode.encode().hex(), node_decode.data.hex())
if debug:
print('Branch', len(node_decode.data.hex()), nhash, node_decode.encode().hex(), node_decode.data.hex())
elif type(node_decode) is mpt.node.Node.Extension:
print(node_decode.path)
if debug:
print(node_decode.path)
rlp_prefix = node[:2]
curr_idx = 2
if int(rlp_prefix, 16) <= int('f7', 16):
@@ -313,8 +322,6 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex
rlp_prefix = node[curr_idx: curr_idx + 2]
curr_idx = curr_idx + 2
temp = []
print('aa', rlp_prefix)
print(node[curr_idx:])
if int(rlp_prefix, 16) <= int('b7', 16):
temp.append(2 * (int(rlp_prefix, 16) - int('80', 16)))
curr_idx = curr_idx + temp[-1]
@@ -325,7 +332,8 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex
nodeVtValueHexLen.append(0)
temp = serialize_hex(node)
print('LENGTH: {}'.format(len(temp)))
if debug:
print('LENGTH: {}'.format(len(temp)))
temp = temp + [0 for idx in range(maxBranchRlpHexLen - len(temp))]
nodeRlpHexs.append(temp)
nodeTypes.append(1)
@@ -377,7 +385,7 @@ def gen_proof_input(proof, root, key, value, maxDepth, maxKeyHexLen, maxValueHex
"depth": depth }
return ret
def get_pf(block, tx_idx):
def get_pf(block, tx_idx, max_depth=None, max_key_len=64, max_val_len=234, debug=False):
block = block['result']
block_hash = block['hash']
@@ -402,7 +410,8 @@ def get_pf(block, tx_idx):
tx_list = block['transactions']
raw_tx_dict = {}
print('{} tx in block'.format(len(tx_list)))
if debug:
print('{} tx in block'.format(len(tx_list)))
for idx, tx in enumerate(tx_list):
if tx['type'] == '0x0':
raw_tx = [int(tx['nonce'], 16),
@@ -439,31 +448,35 @@ def get_pf(block, tx_idx):
raw_tx_dict[rlp.encode(idx)] = bytearray.fromhex('02') + rlp.encode(raw_tx)
else:
print('type not handled: {}'.format(tx['type']))
print(idx, rlp.encode(idx).hex(), tx['type'], tx['hash'], len(rlp.encode(raw_tx).hex()))
if debug:
print(idx, rlp.encode(idx).hex(), tx['type'], tx['hash'], len(rlp.encode(raw_tx).hex()))
trie, storage = construct_mpt(raw_tx_dict)
root = mpt.node.Node.decode(storage[trie._root])
path = mpt.nibble_path.NibblePath(rlp.encode(tx_idx))
pf = get_proof(storage, [], root, path)
print(rlp.encode(tx_idx).hex(), pf)
pf = get_mpt_proof(storage, [], root, path)
value = trie.get(rlp.encode(tx_idx)).hex()
print(root)
ret = gen_proof_input(pf, keccak256(root.encode().hex()), rlp.encode(tx_idx).hex(), value, 6, 64, 234)
ret = gen_proof_input(pf, keccak256(root.encode().hex()), rlp.encode(tx_idx).hex(), value, max_depth, max_key_len, max_val_len, debug=debug)
return ret
parser = argparse.ArgumentParser()
parser.add_argument('--debug', action='store_true', default=False)
parser.add_argument('--tx_idx', type=int, default=2)
parser.add_argument('--file_str', type=str, default='input_tx_proof.json')
parser.add_argument('--file_str', type=str, default='inputs/input_tx_proof.json')
parser.add_argument('--max_depth', type=int, default=5)
parser.add_argument('--max_key_len', type=int, default=6)
parser.add_argument('--max_val_len', type=int, default=234)
args = parser.parse_args()
def main():
with open('punk_block.json', 'r') as f:
block = json.loads(f.read())
pf = get_pf(block, args.tx_idx)
pf = get_pf(block, args.tx_idx, max_depth=args.max_depth, max_key_len=args.max_key_len, max_val_len=args.max_val_len, debug=args.debug)
pf_str = pprint.pformat(pf, width=100, compact=True).replace("'", '"')
with open(args.file_str, 'w') as f:
f.write(pf_str)