From 94e6cfe478537bbb90f56b68ee71e660af104f5f Mon Sep 17 00:00:00 2001 From: terence tsao Date: Sat, 2 Mar 2019 19:14:04 -0800 Subject: [PATCH] Refactor Crosslink Committees at Slot (#1771) --- beacon-chain/core/helpers/committee.go | 330 +++++++++++++------- beacon-chain/core/helpers/committee_test.go | 66 +++- 2 files changed, 273 insertions(+), 123 deletions(-) diff --git a/beacon-chain/core/helpers/committee.go b/beacon-chain/core/helpers/committee.go index 84c4db3150..63ae4aa07b 100644 --- a/beacon-chain/core/helpers/committee.go +++ b/beacon-chain/core/helpers/committee.go @@ -21,6 +21,14 @@ type CrosslinkCommittee struct { Shard uint64 } +type shufflingInput struct { + seed []byte + shufflingEpoch uint64 + slot uint64 + startShard uint64 + committeesPerEpoch uint64 +} + // EpochCommitteeCount returns the number of crosslink committees of an epoch. // // Spec pseudocode definition: @@ -117,9 +125,11 @@ func NextEpochCommitteeCount(state *pb.BeaconState) uint64 { // CrosslinkCommitteesAtSlot returns the list of crosslink committees, it // contains the shard associated with the committee and the validator indices // in that committee. +// +// Spec pseudocode definition: // def get_crosslink_committees_at_slot(state: BeaconState, -// slot: SlotNumber, -// registry_change=False: bool) -> List[Tuple[List[ValidatorIndex], Shard]]: +// slot: Slot, +// registry_change: bool=False) -> List[Tuple[List[ValidatorIndex], Shard]]: // """ // Return the list of ``(committee, shard)`` tuples for the ``slot``. // @@ -128,143 +138,42 @@ func NextEpochCommitteeCount(state *pb.BeaconState) uint64 { // """ // epoch = slot_to_epoch(slot) // current_epoch = get_current_epoch(state) -// previous_epoch = current_epoch - 1 if current_epoch > GENESIS_EPOCH else current_epoch +// previous_epoch = get_previous_epoch(state) // next_epoch = current_epoch + 1 // // assert previous_epoch <= epoch <= next_epoch // // if epoch == current_epoch: -// committees_per_epoch = get_current_epoch_committee_count(state) -// seed = state.current_shuffling_seed -// shuffling_epoch = state.current_calculation_epoch -// shuffling_start_shard = state.current_epoch_start_shard +// return get_current_epoch_committees_at_slot(state, slot) // elif epoch == previous_epoch: -// committees_per_epoch = get_previous_epoch_committee_count(state) -// seed = state.previous_shuffling_seed -// shuffling_epoch = state.previous_shuffling_epoch -// shuffling_start_shard = state.previous_shuffling_start_shard +// return get_previous_epoch_committees_at_slot(state, slot) // elif epoch == next_epoch: -// -// epochs_since_last_registry_update = current_epoch - state.validator_registry_update_epoch -// if registry_change: -// committees_per_epoch = get_next_epoch_committee_count(state) -// shuffling_epoch = next_epoch -// seed = generate_seed(state, next_epoch) -// current_committees_per_epoch = get_current_epoch_committee_count(state) -// shuffling_start_shard = (state.current_epoch_start_shard + current_committees_per_epoch) % SHARD_COUNT -// elif epochs_since_last_registry_update > 1 and is_power_of_two(epochs_since_last_registry_update): -// committees_per_epoch = get_next_epoch_committee_count(state) -// shuffling_epoch = next_epoch -// seed = generate_seed(state, next_epoch) -// shuffling_start_shard = state.current_epoch_start_shard -// else: -// committees_per_epoch = get_current_epoch_committee_count(state) -// shuffling_epoch = state.current_shuffling_epoch -// seed = state.current_epoch_seed -// shuffling_start_shard = state.current_epoch_start_shard -// -// shuffling = get_shuffling( -// seed, -// state.validator_registry, -// shuffling_epoch, -// ) -// offset = slot % SLOTS_PER_EPOCH -// committees_per_slot = committees_per_epoch // SLOTS_PER_EPOCH -// slot_start_shard = (shuffling_start_shard + committees_per_slot * offset) % SHARD_COUNT -// -// return [ -// ( -// shuffling[committees_per_slot * offset + i], -// (slot_start_shard + i) % SHARD_COUNT, -// ) -// for i in range(committees_per_slot) -// ] +// return get_next_epoch_committee_count(state, slot, registry_change) func CrosslinkCommitteesAtSlot( state *pb.BeaconState, slot uint64, registryChange bool) ([]*CrosslinkCommittee, error) { - var committeesPerEpoch uint64 - var shufflingEpoch uint64 - var shufflingStartShard uint64 - var seed [32]byte - var err error wantedEpoch := SlotToEpoch(slot) currentEpoch := CurrentEpoch(state) prevEpoch := PrevEpoch(state) nextEpoch := NextEpoch(state) - if wantedEpoch < prevEpoch || wantedEpoch > nextEpoch { + switch wantedEpoch { + case currentEpoch: + return currEpochCommitteesAtSlot(state, slot) + case prevEpoch: + return prevEpochCommitteesAtSlot(state, slot) + case nextEpoch: + return nextEpochCommitteesAtSlot(state, slot, registryChange) + default: return nil, fmt.Errorf( "input committee epoch %d out of bounds: %d <= epoch <= %d", - wantedEpoch, - prevEpoch, - currentEpoch, + wantedEpoch-params.BeaconConfig().GenesisEpoch, + prevEpoch-params.BeaconConfig().GenesisEpoch, + currentEpoch-params.BeaconConfig().GenesisEpoch, ) } - - if wantedEpoch == currentEpoch { - committeesPerEpoch = CurrentEpochCommitteeCount(state) - seed = bytesutil.ToBytes32(state.CurrentShufflingSeedHash32) - shufflingEpoch = state.CurrentShufflingEpoch - shufflingStartShard = state.CurrentShufflingStartShard - } else if wantedEpoch == prevEpoch { - committeesPerEpoch = PrevEpochCommitteeCount(state) - seed = bytesutil.ToBytes32(state.PreviousShufflingSeedHash32) - shufflingEpoch = state.PreviousShufflingEpoch - shufflingStartShard = state.PreviousShufflingStartShard - } else if wantedEpoch == nextEpoch { - - epochsSinceLastRegistryUpdate := currentEpoch - state.ValidatorRegistryUpdateEpoch - if registryChange { - committeesPerEpoch = NextEpochCommitteeCount(state) - shufflingEpoch = nextEpoch - seed, err = GenerateSeed(state, nextEpoch) - currentCommitteesPerEpoch := CurrentEpochCommitteeCount(state) - if err != nil { - return nil, fmt.Errorf("could not generate seed: %v", err) - } - shufflingStartShard = (state.CurrentShufflingStartShard + currentCommitteesPerEpoch) % - params.BeaconConfig().ShardCount - } else if epochsSinceLastRegistryUpdate > 1 && - mathutil.IsPowerOf2(epochsSinceLastRegistryUpdate) { - committeesPerEpoch = NextEpochCommitteeCount(state) - shufflingEpoch = nextEpoch - seed, err = GenerateSeed(state, nextEpoch) - if err != nil { - return nil, fmt.Errorf("could not generate seed: %v", err) - } - shufflingStartShard = state.CurrentShufflingStartShard - } else { - committeesPerEpoch = CurrentEpochCommitteeCount(state) - shufflingEpoch = state.CurrentShufflingEpoch - seed = bytesutil.ToBytes32(state.CurrentShufflingSeedHash32) - shufflingStartShard = state.CurrentShufflingStartShard - } - } - - shuffledIndices, err := Shuffling( - seed, - state.ValidatorRegistry, - shufflingEpoch) - if err != nil { - return nil, fmt.Errorf("could not shuffle epoch validators: %v", err) - } - - offSet := slot % params.BeaconConfig().SlotsPerEpoch - committeesPerSlot := committeesPerEpoch / params.BeaconConfig().SlotsPerEpoch - slotStardShard := (shufflingStartShard + committeesPerSlot*offSet) % - params.BeaconConfig().ShardCount - - var crosslinkCommittees []*CrosslinkCommittee - for i := uint64(0); i < committeesPerSlot; i++ { - crosslinkCommittees = append(crosslinkCommittees, &CrosslinkCommittee{ - Committee: shuffledIndices[committeesPerSlot*offSet+i], - Shard: (slotStardShard + i) % params.BeaconConfig().ShardCount, - }) - } - - return crosslinkCommittees, nil } // Shuffling shuffles input validator indices and splits them by slot and shard. @@ -511,3 +420,188 @@ func CommitteeAssignment( } return []uint64{}, 0, 0, false, fmt.Errorf("could not get assignment validator %d", validatorIndex) } + +// prevEpochCommitteesAtSlot returns a list of crosslink committees of the previous epoch. +// +// Spec pseudocode definition: +// def get_previous_epoch_committees_at_slot(state: BeaconState, +// slot: Slot) -> List[Tuple[List[ValidatorIndex], Shard]]: +// committees_per_epoch = get_previous_epoch_committee_count(state) +// seed = state.previous_shuffling_seed +// shuffling_epoch = state.previous_shuffling_epoch +// shuffling_start_shard = state.previous_shuffling_start_shard +// return get_crosslink_committees( +// state, +// seed, +// shuffling_epoch, +// slot, +// start_shard, +// committees_per_epoch, +// ) +func prevEpochCommitteesAtSlot(state *pb.BeaconState, slot uint64) ([]*CrosslinkCommittee, error) { + committeesPerEpoch := PrevEpochCommitteeCount(state) + return crosslinkCommittees( + state, &shufflingInput{ + seed: state.PreviousShufflingSeedHash32, + shufflingEpoch: state.PreviousShufflingEpoch, + slot: slot, + startShard: state.PreviousShufflingStartShard, + committeesPerEpoch: committeesPerEpoch, + }) +} + +// currEpochCommitteesAtSlot returns a list of crosslink committees of the current epoch. +// +// Spec pseudocode definition: +// def get_current_epoch_committees_at_slot(state: BeaconState, +// slot: Slot) -> List[Tuple[List[ValidatorIndex], Shard]]: +// committees_per_epoch = get_current_epoch_committee_count(state) +// seed = state.current_shuffling_seed +// shuffling_epoch = state.current_shuffling_epoch +// shuffling_start_shard = state.current_shuffling_start_shard +// return get_crosslink_committees( +// state, +// seed, +// shuffling_epoch, +// slot, +// start_shard, +// committees_per_epoch, +// ) +func currEpochCommitteesAtSlot(state *pb.BeaconState, slot uint64) ([]*CrosslinkCommittee, error) { + committeesPerEpoch := CurrentEpochCommitteeCount(state) + return crosslinkCommittees( + state, &shufflingInput{ + seed: state.CurrentShufflingSeedHash32, + shufflingEpoch: state.CurrentShufflingEpoch, + slot: slot, + startShard: state.CurrentShufflingStartShard, + committeesPerEpoch: committeesPerEpoch, + }) +} + +// nextEpochCommitteesAtSlot returns a list of crosslink committees of the next epoch. +// +// Spec pseudocode definition: +// def get_next_epoch_committees_at_slot(state: BeaconState, +// slot: Slot, +// registry_change: bool) -> List[Tuple[List[ValidatorIndex], Shard]]: +// epochs_since_last_registry_update = current_epoch - state.validator_registry_update_epoch +// if registry_change: +// committees_per_epoch = get_next_epoch_committee_count(state) +// seed = generate_seed(state, next_epoch) +// shuffling_epoch = next_epoch +// current_committees_per_epoch = get_current_epoch_committee_count(state) +// shuffling_start_shard = (state.current_shuffling_start_shard + current_committees_per_epoch) % SHARD_COUNT +// elif epochs_since_last_registry_update > 1 and is_power_of_two(epochs_since_last_registry_update): +// committees_per_epoch = get_next_epoch_committee_count(state) +// seed = generate_seed(state, next_epoch) +// shuffling_epoch = next_epoch +// shuffling_start_shard = state.current_shuffling_start_shard +// else: +// committees_per_epoch = get_current_epoch_committee_count(state) +// seed = state.current_shuffling_seed +// shuffling_epoch = state.current_shuffling_epoch +// shuffling_start_shard = state.current_shuffling_start_shard +// +// return get_crosslink_committees( +// state, +// seed, +// shuffling_epoch, +// slot, +// start_shard, +// committees_per_epoch, +// ) +func nextEpochCommitteesAtSlot(state *pb.BeaconState, slot uint64, registryChange bool) ([]*CrosslinkCommittee, error) { + var committeesPerEpoch uint64 + var shufflingEpoch uint64 + var shufflingStartShard uint64 + var seed [32]byte + var err error + + epochsSinceLastUpdate := CurrentEpoch(state) - state.ValidatorRegistryUpdateEpoch + if registryChange { + committeesPerEpoch = NextEpochCommitteeCount(state) + shufflingEpoch = NextEpoch(state) + seed, err = GenerateSeed(state, shufflingEpoch) + if err != nil { + return nil, fmt.Errorf("could not generate seed: %v", err) + } + shufflingStartShard = (state.CurrentShufflingStartShard + CurrentEpochCommitteeCount(state)) % + params.BeaconConfig().ShardCount + } else if epochsSinceLastUpdate > 1 && + mathutil.IsPowerOf2(epochsSinceLastUpdate) { + committeesPerEpoch = NextEpochCommitteeCount(state) + shufflingEpoch = NextEpoch(state) + seed, err = GenerateSeed(state, shufflingEpoch) + if err != nil { + return nil, fmt.Errorf("could not generate seed: %v", err) + } + shufflingStartShard = state.CurrentShufflingStartShard + } else { + committeesPerEpoch = CurrentEpochCommitteeCount(state) + seed = bytesutil.ToBytes32(state.CurrentShufflingSeedHash32) + shufflingEpoch = state.CurrentShufflingEpoch + shufflingStartShard = state.CurrentShufflingStartShard + } + + return crosslinkCommittees( + state, &shufflingInput{ + seed: seed[:], + shufflingEpoch: shufflingEpoch, + slot: slot, + startShard: shufflingStartShard, + committeesPerEpoch: committeesPerEpoch, + }) +} + +// crosslinkCommittees breaks down the shuffled indices into list of crosslink committee structs +// which contains of validator indices and the shard they are assigned to. +// +// Spec pseudocode definition: +// def get_crosslink_committees(state: BeaconState, +// seed: Bytes32, +// shuffling_epoch: Epoch, +// slot: Slot, +// start_shard: Shard, +// committees_per_epoch: int) -> List[Tuple[List[ValidatorIndex], Shard]]: +// offset = slot % SLOTS_PER_EPOCH +// committees_per_slot = committees_per_epoch // SLOTS_PER_EPOCH +// slot_start_shard = (shuffling_start_shard + committees_per_slot * offset) % SHARD_COUNT +// +// shuffling = get_shuffling( +// seed, +// state.validator_registry, +// shuffling_epoch, +// ) +// +// return [ +// ( +// shuffling[committees_per_slot * offset + i], +// (slot_start_shard + i) % SHARD_COUNT, +// ) +// for i in range(committees_per_slot) +// ] +func crosslinkCommittees(state *pb.BeaconState, input *shufflingInput) ([]*CrosslinkCommittee, error) { + slotsPerEpoch := params.BeaconConfig().SlotsPerEpoch + offSet := input.slot % slotsPerEpoch + committeesPerSlot := input.committeesPerEpoch / slotsPerEpoch + slotStartShard := (input.startShard + committeesPerSlot*offSet) % + params.BeaconConfig().ShardCount + + shuffledIndices, err := Shuffling( + bytesutil.ToBytes32(input.seed), + state.ValidatorRegistry, + input.shufflingEpoch) + if err != nil { + return nil, err + } + + var crosslinkCommittees []*CrosslinkCommittee + for i := uint64(0); i < committeesPerSlot; i++ { + crosslinkCommittees = append(crosslinkCommittees, &CrosslinkCommittee{ + Committee: shuffledIndices[committeesPerSlot*offSet+i], + Shard: (slotStartShard + i) % params.BeaconConfig().ShardCount, + }) + } + return crosslinkCommittees, nil +} diff --git a/beacon-chain/core/helpers/committee_test.go b/beacon-chain/core/helpers/committee_test.go index edad3ea68c..826085ddb1 100644 --- a/beacon-chain/core/helpers/committee_test.go +++ b/beacon-chain/core/helpers/committee_test.go @@ -201,12 +201,69 @@ func TestCrosslinkCommitteesAtSlot_OK(t *testing.T) { } } +func TestCrosslinkCommitteesAtSlot_RegistryChange(t *testing.T) { + validatorsPerEpoch := params.BeaconConfig().SlotsPerEpoch * params.BeaconConfig().TargetCommitteeSize + committeesPerEpoch := uint64(4) + // Set epoch total validators count to 4 committees per slot. + validators := make([]*pb.Validator, committeesPerEpoch*validatorsPerEpoch) + for i := 0; i < len(validators); i++ { + validators[i] = &pb.Validator{ + ExitEpoch: params.BeaconConfig().FarFutureEpoch, + } + } + + state := &pb.BeaconState{ + ValidatorRegistry: validators, + Slot: params.BeaconConfig().GenesisSlot, + LatestIndexRootHash32S: [][]byte{{'A'}, {'B'}}, + LatestRandaoMixes: [][]byte{{'C'}, {'D'}}, + } + + committees, err := CrosslinkCommitteesAtSlot(state, params.BeaconConfig().GenesisSlot+100, true) + if err != nil { + t.Fatalf("Could not get crosslink committee: %v", err) + } + if len(committees) != int(committeesPerEpoch) { + t.Errorf("Incorrect committee count per slot. Wanted: %d, got: %d", + committeesPerEpoch, len(committees)) + } +} + +func TestCrosslinkCommitteesAtSlot_EpochSinceLastUpdatePow2(t *testing.T) { + validatorsPerEpoch := params.BeaconConfig().SlotsPerEpoch * params.BeaconConfig().TargetCommitteeSize + committeesPerEpoch := uint64(5) + // Set epoch total validators count to 5 committees per slot. + validators := make([]*pb.Validator, committeesPerEpoch*validatorsPerEpoch) + for i := 0; i < len(validators); i++ { + validators[i] = &pb.Validator{ + ExitEpoch: params.BeaconConfig().FarFutureEpoch, + } + } + + state := &pb.BeaconState{ + ValidatorRegistry: validators, + Slot: params.BeaconConfig().GenesisSlot + 128, + LatestIndexRootHash32S: [][]byte{{'A'}, {'B'}, {'C'}}, + LatestRandaoMixes: [][]byte{{'D'}, {'E'}, {'F'}}, + ValidatorRegistryUpdateEpoch: params.BeaconConfig().GenesisEpoch, + } + + committees, err := CrosslinkCommitteesAtSlot(state, params.BeaconConfig().GenesisSlot+192, false) + if err != nil { + t.Fatalf("Could not get crosslink committee: %v", err) + } + if len(committees) != int(committeesPerEpoch) { + t.Errorf("Incorrect committee count per slot. Wanted: %d, got: %d", + committeesPerEpoch, len(committees)) + } +} + func TestCrosslinkCommitteesAtSlot_OutOfBound(t *testing.T) { want := fmt.Sprintf( "input committee epoch %d out of bounds: %d <= epoch <= %d", - params.BeaconConfig().GenesisEpoch, - params.BeaconConfig().GenesisEpoch+1, - params.BeaconConfig().GenesisEpoch+2, + 0, + 1, + 2, ) slot := params.BeaconConfig().GenesisSlot beaconState := &pb.BeaconState{ @@ -225,8 +282,7 @@ func TestCrosslinkCommitteesAtSlot_ShuffleFailed(t *testing.T) { } want := fmt.Sprint( - "could not shuffle epoch validators: " + - "input list exceeded upper bound and reached modulo bias", + "input list exceeded upper bound and reached modulo bias", ) if _, err := CrosslinkCommitteesAtSlot(state, params.BeaconConfig().GenesisSlot+1, false); !strings.Contains(err.Error(), want) {