[squashed] shard transition wip

Fix the wrong `get_shard_proposer_index` parameters order

Phase 1 WIP

Add shard transition basic test

Fix lint error

Fix
This commit is contained in:
Hsiao-Wei Wang
2020-04-06 19:30:32 +08:00
parent 849d3f83bf
commit 4e8a7ff115
10 changed files with 374 additions and 237 deletions

View File

@@ -1,6 +1,6 @@
from typing import List
from eth2spec.test.context import expect_assertion_error, PHASE0
from eth2spec.test.context import expect_assertion_error, PHASE0, PHASE1
from eth2spec.test.helpers.state import state_transition_and_sign_block
from eth2spec.test.helpers.block import build_empty_block_for_next_slot
from eth2spec.test.helpers.keys import privkeys
@@ -43,7 +43,7 @@ def run_attestation_processing(spec, state, attestation, valid=True):
yield 'post', state
def build_attestation_data(spec, state, slot, index):
def build_attestation_data(spec, state, slot, index, shard_transition=None):
assert state.slot >= slot
if slot == state.slot:
@@ -66,12 +66,21 @@ def build_attestation_data(spec, state, slot, index):
source_epoch = state.current_justified_checkpoint.epoch
source_root = state.current_justified_checkpoint.root
if spec.fork == PHASE1 and shard_transition is not None:
shard_transition_root = shard_transition.hash_tree_root()
head_shard_root = shard_transition.shard_data_roots[len(shard_transition.shard_data_roots) - 1]
else:
shard_transition_root = spec.Root()
head_shard_root = spec.Root()
return spec.AttestationData(
slot=slot,
index=index,
beacon_block_root=block_root,
source=spec.Checkpoint(epoch=source_epoch, root=source_root),
target=spec.Checkpoint(epoch=spec.compute_epoch_at_slot(slot), root=epoch_boundary_root),
head_shard_root=head_shard_root,
shard_transition_root=shard_transition_root,
)
@@ -89,7 +98,7 @@ def convert_to_valid_on_time_attestation(spec, state, attestation, signed=False)
return attestation
def get_valid_on_time_attestation(spec, state, slot=None, index=None, signed=False):
def get_valid_on_time_attestation(spec, state, slot=None, index=None, shard_transition=None, signed=False):
'''
Construct on-time attestation for next slot
'''
@@ -98,7 +107,15 @@ def get_valid_on_time_attestation(spec, state, slot=None, index=None, signed=Fal
if index is None:
index = 0
return get_valid_attestation(spec, state, slot=slot, index=index, signed=signed, on_time=True)
return get_valid_attestation(
spec,
state,
slot=slot,
index=index,
shard_transition=shard_transition,
signed=signed,
on_time=True,
)
def get_valid_late_attestation(spec, state, slot=None, index=None, signed=False):
@@ -113,13 +130,20 @@ def get_valid_late_attestation(spec, state, slot=None, index=None, signed=False)
return get_valid_attestation(spec, state, slot=slot, index=index, signed=signed, on_time=False)
def get_valid_attestation(spec, state, slot=None, index=None, empty=False, signed=False, on_time=True):
def get_valid_attestation(spec,
state,
slot=None,
index=None,
shard_transition=None,
empty=False,
signed=False,
on_time=True):
if slot is None:
slot = state.slot
if index is None:
index = 0
attestation_data = build_attestation_data(spec, state, slot, index)
attestation_data = build_attestation_data(spec, state, slot=slot, index=index, shard_transition=shard_transition)
beacon_committee = spec.get_beacon_committee(
state,
@@ -138,7 +162,7 @@ def get_valid_attestation(spec, state, slot=None, index=None, empty=False, signe
if signed:
sign_attestation(spec, state, attestation)
if spec.fork == 'phase1' and on_time:
if spec.fork == PHASE1 and on_time:
attestation = convert_to_valid_on_time_attestation(spec, state, attestation, signed)
return attestation
@@ -210,7 +234,7 @@ def get_attestation_custody_signature(spec, state, attestation_data, block_index
def sign_attestation(spec, state, attestation):
if spec.fork == 'phase1' and any(attestation.custody_bits_blocks):
if spec.fork == PHASE1 and any(attestation.custody_bits_blocks):
sign_on_time_attestation(spec, state, attestation)
return

View File

@@ -0,0 +1,28 @@
from eth2spec.test.context import expect_assertion_error
def run_crosslinks_processing(spec, state, shard_transitions, attestations, valid=True):
"""
Run ``process_attestation``, yielding:
- pre-state ('pre')
- shard_transitions ('shard_transitions')
- attestations ('attestations')
- post-state ('post').
If ``valid == False``, run expecting ``AssertionError``
"""
# yield pre-state
yield 'pre', state
yield 'shard_transitions', shard_transitions
yield 'attestations', attestations
# If the attestation is invalid, processing is aborted, and there is no post-state.
if not valid:
expect_assertion_error(lambda: spec.process_crosslinks(state, shard_transitions, attestations))
yield 'post', None
return
# process crosslinks
spec.process_crosslinks(state, shard_transitions, attestations)
# yield post-state
yield 'post', state

View File

@@ -1,63 +0,0 @@
from eth2spec.utils.ssz.ssz_typing import Bitlist
from eth2spec.utils import bls
from eth2spec.test.helpers.keys import privkeys
import eth2spec.test.helpers.attestations as phase0_attestations
def get_valid_on_time_attestation(spec, state, index=None, signed=False):
'''
Construct on-time attestation for next slot
'''
if index is None:
index = 0
attestation = phase0_attestations.get_valid_attestation(spec, state, state.slot, index, False)
shard = spec.get_shard(state, attestation)
offset_slots = spec.compute_offset_slots(spec.get_latest_slot_for_shard(state, shard), state.slot + 1)
for _ in offset_slots:
attestation.custody_bits_blocks.append(
Bitlist[spec.MAX_VALIDATORS_PER_COMMITTEE]([0 for _ in attestation.aggregation_bits])
)
if signed:
sign_attestation(spec, state, attestation)
return attestation
def sign_attestation(spec, state, attestation):
if not any(attestation.custody_bits_blocks):
phase0_attestations.sign_attestation(spec, state, attestation)
return
committee = spec.get_beacon_committee(state, attestation.data.slot, attestation.data.index)
signatures = []
for block_index, custody_bits in enumerate(attestation.custody_bits_blocks):
for participant, abit, cbit in zip(committee, attestation.aggregation_bits, custody_bits):
if not abit:
continue
signatures.append(get_attestation_custody_signature(
spec,
state,
attestation.data,
block_index,
cbit,
privkeys[participant]
))
attestation.signature = bls.Aggregate(signatures)
def get_attestation_custody_signature(spec, state, attestation_data, block_index, bit, privkey):
domain = spec.get_domain(state, spec.DOMAIN_BEACON_ATTESTER, attestation_data.target.epoch)
signing_root = spec.compute_signing_root(
spec.AttestationCustodyBitWrapper(
attestation_data.hash_tree_root(),
block_index,
bit,
),
domain,
)
return bls.Sign(privkey, signing_root)

View File

@@ -1,71 +0,0 @@
from eth2spec.test.helpers.keys import privkeys
from eth2spec.utils import bls
from eth2spec.utils.bls import only_with_bls
from eth2spec.utils.ssz.ssz_impl import (
hash_tree_root,
)
from .attestations import (
sign_shard_attestation,
)
@only_with_bls()
def sign_shard_block(spec, beacon_state, shard_state, block, proposer_index=None):
if proposer_index is None:
proposer_index = spec.get_shard_proposer_index(beacon_state, shard_state.shard, block.slot)
privkey = privkeys[proposer_index]
domain = spec.get_domain(beacon_state, spec.DOMAIN_SHARD_PROPOSER, spec.compute_epoch_of_shard_slot(block.slot))
signing_root = spec.compute_signing_root(block, domain)
block.signature = bls.Sign(privkey, signing_root)
def build_empty_shard_block(spec,
beacon_state,
shard_state,
slot,
signed=False,
full_attestation=False):
if slot is None:
slot = shard_state.slot
previous_beacon_header = beacon_state.latest_block_header.copy()
if previous_beacon_header.state_root == spec.Bytes32():
previous_beacon_header.state_root = beacon_state.hash_tree_root()
beacon_block_root = hash_tree_root(previous_beacon_header)
previous_block_header = shard_state.latest_block_header.copy()
if previous_block_header.state_root == spec.Bytes32():
previous_block_header.state_root = shard_state.hash_tree_root()
parent_root = hash_tree_root(previous_block_header)
block = spec.ShardBlock(
shard=shard_state.shard,
slot=slot,
beacon_block_root=beacon_block_root,
parent_root=parent_root,
block_size_sum=shard_state.block_size_sum + spec.SHARD_HEADER_SIZE,
)
if full_attestation:
shard_committee = spec.get_shard_committee(beacon_state, shard_state.shard, block.slot)
block.aggregation_bits = list(
(True,) * len(shard_committee) +
(False,) * (spec.MAX_PERIOD_COMMITTEE_SIZE * 2 - len(shard_committee))
)
else:
shard_committee = []
block.attestations = sign_shard_attestation(
spec,
beacon_state,
shard_state,
block,
participants=shard_committee,
)
if signed:
sign_shard_block(spec, beacon_state, shard_state, block)
return block

View File

@@ -1,18 +0,0 @@
from eth2spec.test.helpers.phase1.shard_block import sign_shard_block
def configure_shard_state(spec, beacon_state, shard=0):
beacon_state.slot = spec.Slot(spec.SHARD_GENESIS_EPOCH * spec.SLOTS_PER_EPOCH)
shard_state = spec.get_genesis_shard_state(spec.Shard(shard))
shard_state.slot = spec.ShardSlot(spec.SHARD_GENESIS_EPOCH * spec.SHARD_SLOTS_PER_EPOCH)
return beacon_state, shard_state
def shard_state_transition_and_sign_block(spec, beacon_state, shard_state, block):
"""
Shard state transition via the provided ``block``
then package the block with the state root and signature.
"""
spec.shard_state_transition(beacon_state, shard_state, block)
block.state_root = shard_state.hash_tree_root()
sign_shard_block(spec, beacon_state, shard_state, block)

View File

@@ -0,0 +1,47 @@
from eth2spec.test.helpers.keys import privkeys
from eth2spec.utils import bls
from eth2spec.utils.bls import only_with_bls
@only_with_bls()
def sign_shard_block(spec, beacon_state, shard, block, proposer_index=None):
slot = block.message.slot
if proposer_index is None:
proposer_index = spec.get_shard_proposer_index(beacon_state, slot, shard)
privkey = privkeys[proposer_index]
domain = spec.get_domain(beacon_state, spec.DOMAIN_SHARD_PROPOSAL, spec.compute_epoch_at_slot(slot))
signing_root = spec.compute_signing_root(block.message, domain)
block.signature = bls.Sign(privkey, signing_root)
def build_empty_shard_block(spec,
beacon_state,
shard,
slot,
body=None,
signed=False):
shard_state = beacon_state.shard_states[shard]
if slot is None:
slot = shard_state.slot
if body is None:
body = []
proposer_index = spec.get_shard_proposer_index(beacon_state, slot, shard)
block = spec.ShardBlock(
shard_parent_root=shard_state.latest_block_root,
beacon_parent_root=spec.get_block_root_at_slot(beacon_state, spec.get_previous_slot(beacon_state.slot)),
slot=slot,
proposer_index=proposer_index,
body=body,
)
signed_block = spec.SignedShardBlock(
message=block,
)
if signed:
sign_shard_block(spec, beacon_state, shard, signed_block, proposer_index=proposer_index)
return signed_block

View File

@@ -0,0 +1,52 @@
from eth2spec.test.context import (
with_all_phases_except,
spec_state_test,
always_bls,
)
from eth2spec.test.helpers.attestations import (
get_valid_on_time_attestation,
)
from eth2spec.test.helpers.crosslinks import (
run_crosslinks_processing,
)
from eth2spec.test.helpers.shard_block import build_empty_shard_block
from eth2spec.test.helpers.state import next_epoch, next_slot
@with_all_phases_except(['phase0'])
@spec_state_test
@always_bls
def test_basic_crosslinks(spec, state):
next_epoch(spec, state)
next_epoch(spec, state)
state = spec.upgrade_to_phase1(state)
next_slot(spec, state)
committee_index = spec.CommitteeIndex(0)
shard = spec.compute_shard_from_committee_index(state, committee_index, state.slot)
shard_block = build_empty_shard_block(spec, state, shard, body=b'\x12' * 10, slot=state.slot, signed=True)
shard_blocks = [shard_block]
next_slot(spec, state)
shard_transition = spec.get_shard_transition(state, shard, shard_blocks)
shard_transitions = [spec.ShardTransition()] * len(state.shard_states)
shard_transitions[shard] = shard_transition
attestation = get_valid_on_time_attestation(
spec,
state,
slot=state.slot,
index=committee_index,
shard_transition=shard_transition,
signed=True,
)
attestations = [attestation]
offset_slots = spec.get_offset_slots(state, shard)
yield from run_crosslinks_processing(spec, state, shard_transitions, attestations)
shard_state = state.shard_states[shard]
assert shard_state.slot == offset_slots[-1]
assert shard_state.latest_block_root == shard_block.message.hash_tree_root()