From f6b1fe61727e7e8c4a6bcc43eaec58ae5ce2ead3 Mon Sep 17 00:00:00 2001 From: Hsiao-Wei Wang Date: Tue, 14 Jul 2020 17:36:27 +0800 Subject: [PATCH] Refactor tests and avoiding passing `shart_store` to helper functions --- specs/phase1/shard-fork-choice.md | 30 +-- .../test/fork_choice/test_on_shard_block.py | 178 ++++++++++++++---- 2 files changed, 159 insertions(+), 49 deletions(-) diff --git a/specs/phase1/shard-fork-choice.md b/specs/phase1/shard-fork-choice.md index 1835df432..39e957819 100644 --- a/specs/phase1/shard-fork-choice.md +++ b/specs/phase1/shard-fork-choice.md @@ -47,7 +47,8 @@ def get_forkchoice_shard_store(anchor_state: BeaconState, shard: Shard) -> Shard #### `get_shard_latest_attesting_balance` ```python -def get_shard_latest_attesting_balance(store: Store, shard_store: ShardStore, root: Root) -> Gwei: +def get_shard_latest_attesting_balance(store: Store, shard: Shard, root: Root) -> Gwei: + shard_store = store.shard_stores[shard] state = store.checkpoint_states[store.justified_checkpoint] active_indices = get_active_validator_indices(state, get_current_epoch(state)) return Gwei(sum( @@ -58,7 +59,7 @@ def get_shard_latest_attesting_balance(store: Store, shard_store: ShardStore, ro # would be ignored once their newer vote is accepted. Check if it makes sense. and get_shard_ancestor( store, - shard_store, + shard, shard_store.latest_messages[i].root, shard_store.signed_blocks[root].message.slot, ) == root @@ -69,10 +70,14 @@ def get_shard_latest_attesting_balance(store: Store, shard_store: ShardStore, ro #### `get_shard_head` ```python -def get_shard_head(store: Store, shard_store: ShardStore) -> Root: +def get_shard_head(store: Store, shard: Shard) -> Root: # Execute the LMD-GHOST fork choice + """ + Execute the LMD-GHOST fork choice. + """ + shard_store = store.shard_stores[shard] beacon_head_root = get_head(store) - shard_head_state = store.block_states[beacon_head_root].shard_states[shard_store.shard] + shard_head_state = store.block_states[beacon_head_root].shard_states[shard] shard_head_root = shard_head_state.latest_block_root shard_blocks = { root: signed_shard_block.message for root, signed_shard_block in shard_store.signed_blocks.items() @@ -88,17 +93,18 @@ def get_shard_head(store: Store, shard_store: ShardStore) -> Root: return shard_head_root # Sort by latest attesting balance with ties broken lexicographically shard_head_root = max( - children, key=lambda root: (get_shard_latest_attesting_balance(store, shard_store, root), root) + children, key=lambda root: (get_shard_latest_attesting_balance(store, shard, root), root) ) ``` #### `get_shard_ancestor` ```python -def get_shard_ancestor(store: Store, shard_store: ShardStore, root: Root, slot: Slot) -> Root: +def get_shard_ancestor(store: Store, shard: Shard, root: Root, slot: Slot) -> Root: + shard_store = store.shard_stores[shard] block = shard_store.signed_blocks[root].message if block.slot > slot: - return get_shard_ancestor(store, shard_store, block.shard_parent_root, slot) + return get_shard_ancestor(store, shard, block.shard_parent_root, slot) elif block.slot == slot: return root else: @@ -109,17 +115,17 @@ def get_shard_ancestor(store: Store, shard_store: ShardStore, root: Root, slot: #### `get_pending_shard_blocks` ```python -def get_pending_shard_blocks(store: Store, shard_store: ShardStore) -> Sequence[SignedShardBlock]: +def get_pending_shard_blocks(store: Store, shard: Shard) -> Sequence[SignedShardBlock]: """ Return the canonical shard block branch that has not yet been crosslinked. """ - shard = shard_store.shard + shard_store = store.shard_stores[shard] beacon_head_root = get_head(store) beacon_head_state = store.block_states[beacon_head_root] latest_shard_block_root = beacon_head_state.shard_states[shard].latest_block_root - shard_head_root = get_shard_head(store, shard_store) + shard_head_root = get_shard_head(store, shard) root = shard_head_root signed_shard_blocks = [] while root != latest_shard_block_root: @@ -136,9 +142,9 @@ def get_pending_shard_blocks(store: Store, shard_store: ShardStore) -> Sequence[ #### `on_shard_block` ```python -def on_shard_block(store: Store, shard_store: ShardStore, signed_shard_block: SignedShardBlock) -> None: +def on_shard_block(store: Store, shard: Shard, signed_shard_block: SignedShardBlock) -> None: + shard_store = store.shard_stores[shard] shard_block = signed_shard_block.message - shard = shard_store.shard # Check shard # TODO: check it in networking spec diff --git a/tests/core/pyspec/eth2spec/test/fork_choice/test_on_shard_block.py b/tests/core/pyspec/eth2spec/test/fork_choice/test_on_shard_block.py index 3b03906d9..dae9246fb 100644 --- a/tests/core/pyspec/eth2spec/test/fork_choice/test_on_shard_block.py +++ b/tests/core/pyspec/eth2spec/test/fork_choice/test_on_shard_block.py @@ -8,27 +8,43 @@ from eth2spec.test.helpers.shard_block import ( get_committee_index_of_shard, ) from eth2spec.test.helpers.fork_choice import add_block_to_store, get_anchor_root +from eth2spec.test.helpers.shard_transitions import is_full_crosslink from eth2spec.test.helpers.state import state_transition_and_sign_block from eth2spec.test.helpers.block import build_empty_block -def run_on_shard_block(spec, store, shard_store, signed_block, valid=True): +def run_on_shard_block(spec, store, shard, signed_block, valid=True): if not valid: try: - spec.on_shard_block(store, shard_store, signed_block) + spec.on_shard_block(store, shard, signed_block) except AssertionError: return else: assert False - spec.on_shard_block(store, shard_store, signed_block) + spec.on_shard_block(store, shard, signed_block) + shard_store = store.shard_stores[shard] assert shard_store.signed_blocks[hash_tree_root(signed_block.message)] == signed_block -def apply_shard_block(spec, store, shard_store, beacon_parent_state, shard_blocks_buffer): - shard = shard_store.shard +def initialize_store(spec, state, shard): + store = spec.get_forkchoice_store(state) + anchor_root = get_anchor_root(spec, state) + assert spec.get_head(store) == anchor_root + + shard_head_root = spec.get_shard_head(store, shard) + assert shard_head_root == state.shard_states[shard].latest_block_root + shard_store = store.shard_stores[shard] + assert shard_store.block_states[shard_head_root].slot == 1 + assert shard_store.block_states[shard_head_root] == state.shard_states[shard] + + return store + + +def create_and_apply_shard_block(spec, store, shard, beacon_parent_state, shard_blocks_buffer): body = b'\x56' * 4 - shard_head_root = spec.get_shard_head(store, shard_store) + shard_head_root = spec.get_shard_head(store, shard) + shard_store = store.shard_stores[shard] shard_parent_state = shard_store.block_states[shard_head_root] assert shard_parent_state.slot != beacon_parent_state.slot shard_block = build_shard_block( @@ -36,12 +52,12 @@ def apply_shard_block(spec, store, shard_store, beacon_parent_state, shard_block shard_parent_state=shard_parent_state, slot=beacon_parent_state.slot, body=body, signed=True ) shard_blocks_buffer.append(shard_block) - run_on_shard_block(spec, store, shard_store, shard_block) - assert spec.get_shard_head(store, shard_store) == shard_block.message.hash_tree_root() + run_on_shard_block(spec, store, shard, shard_block) + assert spec.get_shard_head(store, shard) == shard_block.message.hash_tree_root() -def check_pending_shard_blocks(spec, store, shard_store, shard_blocks_buffer): - pending_shard_blocks = spec.get_pending_shard_blocks(store, shard_store) +def check_pending_shard_blocks(spec, store, shard, shard_blocks_buffer): + pending_shard_blocks = spec.get_pending_shard_blocks(store, shard) assert pending_shard_blocks == shard_blocks_buffer @@ -52,10 +68,22 @@ def is_in_offset_sets(spec, beacon_head_state, shard): return beacon_head_state.slot in offset_slots -def apply_shard_and_beacon(spec, state, store, shard_store, shard_blocks_buffer): - store.time = store.time + spec.SECONDS_PER_SLOT * spec.SLOTS_PER_EPOCH +def create_attestation_for_shard_blocks(spec, beacon_parent_state, shard, committee_index, blocks, + filter_participant_set=None): + shard_transition = spec.get_shard_transition(beacon_parent_state, shard, blocks) + attestation = get_valid_on_time_attestation( + spec, + beacon_parent_state, + index=committee_index, + shard_transition=shard_transition, + signed=False, + ) + return attestation - shard = shard_store.shard + +def create_beacon_block_with_shard_transition( + spec, state, store, shard, shard_blocks_buffer, is_checking_pending_shard_blocks=True): + beacon_block = build_empty_block(spec, state, slot=state.slot + 1) committee_index = get_committee_index_of_shard(spec, state, state.slot, shard) has_shard_committee = committee_index is not None # has committee of `shard` at this slot @@ -63,14 +91,12 @@ def apply_shard_and_beacon(spec, state, store, shard_store, shard_blocks_buffer) # If next slot has committee of `shard`, add `shard_transtion` to the proposing beacon block if has_shard_committee and len(shard_blocks_buffer) > 0: - # Sanity check `get_pending_shard_blocks` function - check_pending_shard_blocks(spec, store, shard_store, shard_blocks_buffer) + # Sanity check `get_pending_shard_blocks` + # Assert that the pending shard blocks set in the store equal to shard_blocks_buffer + if is_checking_pending_shard_blocks: + check_pending_shard_blocks(spec, store, shard, shard_blocks_buffer) # Use temporary next state to get ShardTransition of shard block - shard_transitions = get_shard_transitions( - spec, - state, - shard_block_dict={shard: shard_blocks_buffer}, - ) + shard_transitions = get_shard_transitions(spec, state, shard_block_dict={shard: shard_blocks_buffer}) shard_transition = shard_transitions[shard] attestation = get_valid_on_time_attestation( spec, @@ -86,15 +112,31 @@ def apply_shard_and_beacon(spec, state, store, shard_store, shard_blocks_buffer) # Clear buffer shard_blocks_buffer.clear() - signed_beacon_block = state_transition_and_sign_block(spec, state, beacon_block) # transition! - add_block_to_store(spec, store, signed_beacon_block) - assert spec.get_head(store) == beacon_block.hash_tree_root() + return beacon_block - # On shard block at transitioned `state.slot` + +def apply_all_attestation_to_store(spec, store, attestations): + for attestation in attestations: + spec.on_attestation(store, attestation) + + +def apply_beacon_block_to_store(spec, state, store, beacon_block): + signed_beacon_block = state_transition_and_sign_block(spec, state, beacon_block) # transition! + store.time = store.time + spec.SECONDS_PER_SLOT + add_block_to_store(spec, store, signed_beacon_block) + apply_all_attestation_to_store(spec, store, signed_beacon_block.message.body.attestations) + + +def create_and_apply_beacon_and_shard_blocks(spec, state, store, shard, shard_blocks_buffer): + beacon_block = create_beacon_block_with_shard_transition(spec, state, store, shard, shard_blocks_buffer) + apply_beacon_block_to_store(spec, state, store, beacon_block) + + # On shard block at the transitioned `state.slot` if is_in_offset_sets(spec, state, shard): # The created shard block would be appended to `shard_blocks_buffer` - apply_shard_block(spec, store, shard_store, state, shard_blocks_buffer) + create_and_apply_shard_block(spec, store, shard, state, shard_blocks_buffer) + has_shard_committee = get_committee_index_of_shard(spec, state, state.slot, shard) is not None return has_shard_committee @@ -107,23 +149,85 @@ def test_basic(spec, state): shard = spec.Shard(1) # Initialization - store = spec.get_forkchoice_store(state) - anchor_root = get_anchor_root(spec, state) - assert spec.get_head(store) == anchor_root - - shard_store = store.shard_stores[shard] - shard_head_root = spec.get_shard_head(store, shard_store) - assert shard_head_root == state.shard_states[shard].latest_block_root - assert shard_store.block_states[shard_head_root].slot == 1 - assert shard_store.block_states[shard_head_root] == state.shard_states[shard] + store = initialize_store(spec, state, shard) # For mainnet config, it's possible that only one committee of `shard` per epoch. # we set this counter to test more rounds. shard_committee_counter = 2 - shard_blocks_buffer = [] + shard_blocks_buffer = [] # the accumulated shard blocks that haven't been crosslinked yet while shard_committee_counter > 0: - has_shard_committee = apply_shard_and_beacon( - spec, state, store, shard_store, shard_blocks_buffer + has_shard_committee = create_and_apply_beacon_and_shard_blocks( + spec, state, store, shard, shard_blocks_buffer ) if has_shard_committee: shard_committee_counter -= 1 + + +def create_simple_fork(spec, state, store, shard): + # Beacon block + beacon_block = create_beacon_block_with_shard_transition(spec, state, store, shard, []) + apply_beacon_block_to_store(spec, state, store, beacon_block) + + beacon_head_root = spec.get_head(store) + assert beacon_head_root == beacon_block.hash_tree_root() + beacon_parent_state = store.block_states[beacon_head_root] + shard_store = store.shard_stores[shard] + shard_parent_state = shard_store.block_states[spec.get_shard_head(store, shard)] + + # Shard block A + body = b'\x56' * 4 + forking_block_child = build_shard_block( + spec, beacon_parent_state, shard, + shard_parent_state=shard_parent_state, slot=beacon_parent_state.slot, body=body, signed=True + ) + run_on_shard_block(spec, store, shard, forking_block_child) + + # Shard block B + body = b'\x78' * 4 # different body + shard_block_b = build_shard_block( + spec, beacon_parent_state, shard, + shard_parent_state=shard_parent_state, slot=beacon_parent_state.slot, body=body, signed=True + ) + run_on_shard_block(spec, store, shard, shard_block_b) + + # Set forking_block + current_head = spec.get_shard_head(store, shard) + if current_head == forking_block_child.message.hash_tree_root(): + head_block = forking_block_child + forking_block = shard_block_b + else: + assert current_head == shard_block_b.message.hash_tree_root() + head_block = shard_block_b + forking_block = forking_block_child + + return head_block, forking_block + + +@with_all_phases_except([PHASE0]) +@spec_state_test +@never_bls # Set to never_bls for testing `check_pending_shard_blocks` +def test_shard_simple_fork(spec, state): + if not is_full_crosslink(spec, state): + # skip + return + + spec.PHASE_1_GENESIS_SLOT = 0 # NOTE: mock genesis slot here + state = spec.upgrade_to_phase1(state) + shard = spec.Shard(1) + + # Initialization + store = initialize_store(spec, state, shard) + + # Create fork + _, forking_block = create_simple_fork(spec, state, store, shard) + + # Vote for forking_block + state = store.block_states[spec.get_head(store)].copy() + beacon_block = create_beacon_block_with_shard_transition(spec, state, store, shard, [forking_block], + is_checking_pending_shard_blocks=False) + # apply_beacon_block_to_store(spec, state, store, beacon_block) + store.time = store.time + spec.SECONDS_PER_SLOT + apply_all_attestation_to_store(spec, store, beacon_block.body.attestations) + + # Head block has been changed + assert spec.get_shard_head(store, shard) == forking_block.message.hash_tree_root()