Refactor Crosslink Committees at Slot (#1771)

This commit is contained in:
terence tsao
2019-03-02 19:14:04 -08:00
committed by GitHub
parent 6a1addbd1a
commit 94e6cfe478
2 changed files with 273 additions and 123 deletions

View File

@@ -21,6 +21,14 @@ type CrosslinkCommittee struct {
Shard uint64
}
type shufflingInput struct {
seed []byte
shufflingEpoch uint64
slot uint64
startShard uint64
committeesPerEpoch uint64
}
// EpochCommitteeCount returns the number of crosslink committees of an epoch.
//
// Spec pseudocode definition:
@@ -117,9 +125,11 @@ func NextEpochCommitteeCount(state *pb.BeaconState) uint64 {
// CrosslinkCommitteesAtSlot returns the list of crosslink committees, it
// contains the shard associated with the committee and the validator indices
// in that committee.
//
// Spec pseudocode definition:
// def get_crosslink_committees_at_slot(state: BeaconState,
// slot: SlotNumber,
// registry_change=False: bool) -> List[Tuple[List[ValidatorIndex], Shard]]:
// slot: Slot,
// registry_change: bool=False) -> List[Tuple[List[ValidatorIndex], Shard]]:
// """
// Return the list of ``(committee, shard)`` tuples for the ``slot``.
//
@@ -128,143 +138,42 @@ func NextEpochCommitteeCount(state *pb.BeaconState) uint64 {
// """
// epoch = slot_to_epoch(slot)
// current_epoch = get_current_epoch(state)
// previous_epoch = current_epoch - 1 if current_epoch > GENESIS_EPOCH else current_epoch
// previous_epoch = get_previous_epoch(state)
// next_epoch = current_epoch + 1
//
// assert previous_epoch <= epoch <= next_epoch
//
// if epoch == current_epoch:
// committees_per_epoch = get_current_epoch_committee_count(state)
// seed = state.current_shuffling_seed
// shuffling_epoch = state.current_calculation_epoch
// shuffling_start_shard = state.current_epoch_start_shard
// return get_current_epoch_committees_at_slot(state, slot)
// elif epoch == previous_epoch:
// committees_per_epoch = get_previous_epoch_committee_count(state)
// seed = state.previous_shuffling_seed
// shuffling_epoch = state.previous_shuffling_epoch
// shuffling_start_shard = state.previous_shuffling_start_shard
// return get_previous_epoch_committees_at_slot(state, slot)
// elif epoch == next_epoch:
//
// epochs_since_last_registry_update = current_epoch - state.validator_registry_update_epoch
// if registry_change:
// committees_per_epoch = get_next_epoch_committee_count(state)
// shuffling_epoch = next_epoch
// seed = generate_seed(state, next_epoch)
// current_committees_per_epoch = get_current_epoch_committee_count(state)
// shuffling_start_shard = (state.current_epoch_start_shard + current_committees_per_epoch) % SHARD_COUNT
// elif epochs_since_last_registry_update > 1 and is_power_of_two(epochs_since_last_registry_update):
// committees_per_epoch = get_next_epoch_committee_count(state)
// shuffling_epoch = next_epoch
// seed = generate_seed(state, next_epoch)
// shuffling_start_shard = state.current_epoch_start_shard
// else:
// committees_per_epoch = get_current_epoch_committee_count(state)
// shuffling_epoch = state.current_shuffling_epoch
// seed = state.current_epoch_seed
// shuffling_start_shard = state.current_epoch_start_shard
//
// shuffling = get_shuffling(
// seed,
// state.validator_registry,
// shuffling_epoch,
// )
// offset = slot % SLOTS_PER_EPOCH
// committees_per_slot = committees_per_epoch // SLOTS_PER_EPOCH
// slot_start_shard = (shuffling_start_shard + committees_per_slot * offset) % SHARD_COUNT
//
// return [
// (
// shuffling[committees_per_slot * offset + i],
// (slot_start_shard + i) % SHARD_COUNT,
// )
// for i in range(committees_per_slot)
// ]
// return get_next_epoch_committee_count(state, slot, registry_change)
func CrosslinkCommitteesAtSlot(
state *pb.BeaconState,
slot uint64,
registryChange bool) ([]*CrosslinkCommittee, error) {
var committeesPerEpoch uint64
var shufflingEpoch uint64
var shufflingStartShard uint64
var seed [32]byte
var err error
wantedEpoch := SlotToEpoch(slot)
currentEpoch := CurrentEpoch(state)
prevEpoch := PrevEpoch(state)
nextEpoch := NextEpoch(state)
if wantedEpoch < prevEpoch || wantedEpoch > nextEpoch {
switch wantedEpoch {
case currentEpoch:
return currEpochCommitteesAtSlot(state, slot)
case prevEpoch:
return prevEpochCommitteesAtSlot(state, slot)
case nextEpoch:
return nextEpochCommitteesAtSlot(state, slot, registryChange)
default:
return nil, fmt.Errorf(
"input committee epoch %d out of bounds: %d <= epoch <= %d",
wantedEpoch,
prevEpoch,
currentEpoch,
wantedEpoch-params.BeaconConfig().GenesisEpoch,
prevEpoch-params.BeaconConfig().GenesisEpoch,
currentEpoch-params.BeaconConfig().GenesisEpoch,
)
}
if wantedEpoch == currentEpoch {
committeesPerEpoch = CurrentEpochCommitteeCount(state)
seed = bytesutil.ToBytes32(state.CurrentShufflingSeedHash32)
shufflingEpoch = state.CurrentShufflingEpoch
shufflingStartShard = state.CurrentShufflingStartShard
} else if wantedEpoch == prevEpoch {
committeesPerEpoch = PrevEpochCommitteeCount(state)
seed = bytesutil.ToBytes32(state.PreviousShufflingSeedHash32)
shufflingEpoch = state.PreviousShufflingEpoch
shufflingStartShard = state.PreviousShufflingStartShard
} else if wantedEpoch == nextEpoch {
epochsSinceLastRegistryUpdate := currentEpoch - state.ValidatorRegistryUpdateEpoch
if registryChange {
committeesPerEpoch = NextEpochCommitteeCount(state)
shufflingEpoch = nextEpoch
seed, err = GenerateSeed(state, nextEpoch)
currentCommitteesPerEpoch := CurrentEpochCommitteeCount(state)
if err != nil {
return nil, fmt.Errorf("could not generate seed: %v", err)
}
shufflingStartShard = (state.CurrentShufflingStartShard + currentCommitteesPerEpoch) %
params.BeaconConfig().ShardCount
} else if epochsSinceLastRegistryUpdate > 1 &&
mathutil.IsPowerOf2(epochsSinceLastRegistryUpdate) {
committeesPerEpoch = NextEpochCommitteeCount(state)
shufflingEpoch = nextEpoch
seed, err = GenerateSeed(state, nextEpoch)
if err != nil {
return nil, fmt.Errorf("could not generate seed: %v", err)
}
shufflingStartShard = state.CurrentShufflingStartShard
} else {
committeesPerEpoch = CurrentEpochCommitteeCount(state)
shufflingEpoch = state.CurrentShufflingEpoch
seed = bytesutil.ToBytes32(state.CurrentShufflingSeedHash32)
shufflingStartShard = state.CurrentShufflingStartShard
}
}
shuffledIndices, err := Shuffling(
seed,
state.ValidatorRegistry,
shufflingEpoch)
if err != nil {
return nil, fmt.Errorf("could not shuffle epoch validators: %v", err)
}
offSet := slot % params.BeaconConfig().SlotsPerEpoch
committeesPerSlot := committeesPerEpoch / params.BeaconConfig().SlotsPerEpoch
slotStardShard := (shufflingStartShard + committeesPerSlot*offSet) %
params.BeaconConfig().ShardCount
var crosslinkCommittees []*CrosslinkCommittee
for i := uint64(0); i < committeesPerSlot; i++ {
crosslinkCommittees = append(crosslinkCommittees, &CrosslinkCommittee{
Committee: shuffledIndices[committeesPerSlot*offSet+i],
Shard: (slotStardShard + i) % params.BeaconConfig().ShardCount,
})
}
return crosslinkCommittees, nil
}
// Shuffling shuffles input validator indices and splits them by slot and shard.
@@ -511,3 +420,188 @@ func CommitteeAssignment(
}
return []uint64{}, 0, 0, false, fmt.Errorf("could not get assignment validator %d", validatorIndex)
}
// prevEpochCommitteesAtSlot returns a list of crosslink committees of the previous epoch.
//
// Spec pseudocode definition:
// def get_previous_epoch_committees_at_slot(state: BeaconState,
// slot: Slot) -> List[Tuple[List[ValidatorIndex], Shard]]:
// committees_per_epoch = get_previous_epoch_committee_count(state)
// seed = state.previous_shuffling_seed
// shuffling_epoch = state.previous_shuffling_epoch
// shuffling_start_shard = state.previous_shuffling_start_shard
// return get_crosslink_committees(
// state,
// seed,
// shuffling_epoch,
// slot,
// start_shard,
// committees_per_epoch,
// )
func prevEpochCommitteesAtSlot(state *pb.BeaconState, slot uint64) ([]*CrosslinkCommittee, error) {
committeesPerEpoch := PrevEpochCommitteeCount(state)
return crosslinkCommittees(
state, &shufflingInput{
seed: state.PreviousShufflingSeedHash32,
shufflingEpoch: state.PreviousShufflingEpoch,
slot: slot,
startShard: state.PreviousShufflingStartShard,
committeesPerEpoch: committeesPerEpoch,
})
}
// currEpochCommitteesAtSlot returns a list of crosslink committees of the current epoch.
//
// Spec pseudocode definition:
// def get_current_epoch_committees_at_slot(state: BeaconState,
// slot: Slot) -> List[Tuple[List[ValidatorIndex], Shard]]:
// committees_per_epoch = get_current_epoch_committee_count(state)
// seed = state.current_shuffling_seed
// shuffling_epoch = state.current_shuffling_epoch
// shuffling_start_shard = state.current_shuffling_start_shard
// return get_crosslink_committees(
// state,
// seed,
// shuffling_epoch,
// slot,
// start_shard,
// committees_per_epoch,
// )
func currEpochCommitteesAtSlot(state *pb.BeaconState, slot uint64) ([]*CrosslinkCommittee, error) {
committeesPerEpoch := CurrentEpochCommitteeCount(state)
return crosslinkCommittees(
state, &shufflingInput{
seed: state.CurrentShufflingSeedHash32,
shufflingEpoch: state.CurrentShufflingEpoch,
slot: slot,
startShard: state.CurrentShufflingStartShard,
committeesPerEpoch: committeesPerEpoch,
})
}
// nextEpochCommitteesAtSlot returns a list of crosslink committees of the next epoch.
//
// Spec pseudocode definition:
// def get_next_epoch_committees_at_slot(state: BeaconState,
// slot: Slot,
// registry_change: bool) -> List[Tuple[List[ValidatorIndex], Shard]]:
// epochs_since_last_registry_update = current_epoch - state.validator_registry_update_epoch
// if registry_change:
// committees_per_epoch = get_next_epoch_committee_count(state)
// seed = generate_seed(state, next_epoch)
// shuffling_epoch = next_epoch
// current_committees_per_epoch = get_current_epoch_committee_count(state)
// shuffling_start_shard = (state.current_shuffling_start_shard + current_committees_per_epoch) % SHARD_COUNT
// elif epochs_since_last_registry_update > 1 and is_power_of_two(epochs_since_last_registry_update):
// committees_per_epoch = get_next_epoch_committee_count(state)
// seed = generate_seed(state, next_epoch)
// shuffling_epoch = next_epoch
// shuffling_start_shard = state.current_shuffling_start_shard
// else:
// committees_per_epoch = get_current_epoch_committee_count(state)
// seed = state.current_shuffling_seed
// shuffling_epoch = state.current_shuffling_epoch
// shuffling_start_shard = state.current_shuffling_start_shard
//
// return get_crosslink_committees(
// state,
// seed,
// shuffling_epoch,
// slot,
// start_shard,
// committees_per_epoch,
// )
func nextEpochCommitteesAtSlot(state *pb.BeaconState, slot uint64, registryChange bool) ([]*CrosslinkCommittee, error) {
var committeesPerEpoch uint64
var shufflingEpoch uint64
var shufflingStartShard uint64
var seed [32]byte
var err error
epochsSinceLastUpdate := CurrentEpoch(state) - state.ValidatorRegistryUpdateEpoch
if registryChange {
committeesPerEpoch = NextEpochCommitteeCount(state)
shufflingEpoch = NextEpoch(state)
seed, err = GenerateSeed(state, shufflingEpoch)
if err != nil {
return nil, fmt.Errorf("could not generate seed: %v", err)
}
shufflingStartShard = (state.CurrentShufflingStartShard + CurrentEpochCommitteeCount(state)) %
params.BeaconConfig().ShardCount
} else if epochsSinceLastUpdate > 1 &&
mathutil.IsPowerOf2(epochsSinceLastUpdate) {
committeesPerEpoch = NextEpochCommitteeCount(state)
shufflingEpoch = NextEpoch(state)
seed, err = GenerateSeed(state, shufflingEpoch)
if err != nil {
return nil, fmt.Errorf("could not generate seed: %v", err)
}
shufflingStartShard = state.CurrentShufflingStartShard
} else {
committeesPerEpoch = CurrentEpochCommitteeCount(state)
seed = bytesutil.ToBytes32(state.CurrentShufflingSeedHash32)
shufflingEpoch = state.CurrentShufflingEpoch
shufflingStartShard = state.CurrentShufflingStartShard
}
return crosslinkCommittees(
state, &shufflingInput{
seed: seed[:],
shufflingEpoch: shufflingEpoch,
slot: slot,
startShard: shufflingStartShard,
committeesPerEpoch: committeesPerEpoch,
})
}
// crosslinkCommittees breaks down the shuffled indices into list of crosslink committee structs
// which contains of validator indices and the shard they are assigned to.
//
// Spec pseudocode definition:
// def get_crosslink_committees(state: BeaconState,
// seed: Bytes32,
// shuffling_epoch: Epoch,
// slot: Slot,
// start_shard: Shard,
// committees_per_epoch: int) -> List[Tuple[List[ValidatorIndex], Shard]]:
// offset = slot % SLOTS_PER_EPOCH
// committees_per_slot = committees_per_epoch // SLOTS_PER_EPOCH
// slot_start_shard = (shuffling_start_shard + committees_per_slot * offset) % SHARD_COUNT
//
// shuffling = get_shuffling(
// seed,
// state.validator_registry,
// shuffling_epoch,
// )
//
// return [
// (
// shuffling[committees_per_slot * offset + i],
// (slot_start_shard + i) % SHARD_COUNT,
// )
// for i in range(committees_per_slot)
// ]
func crosslinkCommittees(state *pb.BeaconState, input *shufflingInput) ([]*CrosslinkCommittee, error) {
slotsPerEpoch := params.BeaconConfig().SlotsPerEpoch
offSet := input.slot % slotsPerEpoch
committeesPerSlot := input.committeesPerEpoch / slotsPerEpoch
slotStartShard := (input.startShard + committeesPerSlot*offSet) %
params.BeaconConfig().ShardCount
shuffledIndices, err := Shuffling(
bytesutil.ToBytes32(input.seed),
state.ValidatorRegistry,
input.shufflingEpoch)
if err != nil {
return nil, err
}
var crosslinkCommittees []*CrosslinkCommittee
for i := uint64(0); i < committeesPerSlot; i++ {
crosslinkCommittees = append(crosslinkCommittees, &CrosslinkCommittee{
Committee: shuffledIndices[committeesPerSlot*offSet+i],
Shard: (slotStartShard + i) % params.BeaconConfig().ShardCount,
})
}
return crosslinkCommittees, nil
}

View File

@@ -201,12 +201,69 @@ func TestCrosslinkCommitteesAtSlot_OK(t *testing.T) {
}
}
func TestCrosslinkCommitteesAtSlot_RegistryChange(t *testing.T) {
validatorsPerEpoch := params.BeaconConfig().SlotsPerEpoch * params.BeaconConfig().TargetCommitteeSize
committeesPerEpoch := uint64(4)
// Set epoch total validators count to 4 committees per slot.
validators := make([]*pb.Validator, committeesPerEpoch*validatorsPerEpoch)
for i := 0; i < len(validators); i++ {
validators[i] = &pb.Validator{
ExitEpoch: params.BeaconConfig().FarFutureEpoch,
}
}
state := &pb.BeaconState{
ValidatorRegistry: validators,
Slot: params.BeaconConfig().GenesisSlot,
LatestIndexRootHash32S: [][]byte{{'A'}, {'B'}},
LatestRandaoMixes: [][]byte{{'C'}, {'D'}},
}
committees, err := CrosslinkCommitteesAtSlot(state, params.BeaconConfig().GenesisSlot+100, true)
if err != nil {
t.Fatalf("Could not get crosslink committee: %v", err)
}
if len(committees) != int(committeesPerEpoch) {
t.Errorf("Incorrect committee count per slot. Wanted: %d, got: %d",
committeesPerEpoch, len(committees))
}
}
func TestCrosslinkCommitteesAtSlot_EpochSinceLastUpdatePow2(t *testing.T) {
validatorsPerEpoch := params.BeaconConfig().SlotsPerEpoch * params.BeaconConfig().TargetCommitteeSize
committeesPerEpoch := uint64(5)
// Set epoch total validators count to 5 committees per slot.
validators := make([]*pb.Validator, committeesPerEpoch*validatorsPerEpoch)
for i := 0; i < len(validators); i++ {
validators[i] = &pb.Validator{
ExitEpoch: params.BeaconConfig().FarFutureEpoch,
}
}
state := &pb.BeaconState{
ValidatorRegistry: validators,
Slot: params.BeaconConfig().GenesisSlot + 128,
LatestIndexRootHash32S: [][]byte{{'A'}, {'B'}, {'C'}},
LatestRandaoMixes: [][]byte{{'D'}, {'E'}, {'F'}},
ValidatorRegistryUpdateEpoch: params.BeaconConfig().GenesisEpoch,
}
committees, err := CrosslinkCommitteesAtSlot(state, params.BeaconConfig().GenesisSlot+192, false)
if err != nil {
t.Fatalf("Could not get crosslink committee: %v", err)
}
if len(committees) != int(committeesPerEpoch) {
t.Errorf("Incorrect committee count per slot. Wanted: %d, got: %d",
committeesPerEpoch, len(committees))
}
}
func TestCrosslinkCommitteesAtSlot_OutOfBound(t *testing.T) {
want := fmt.Sprintf(
"input committee epoch %d out of bounds: %d <= epoch <= %d",
params.BeaconConfig().GenesisEpoch,
params.BeaconConfig().GenesisEpoch+1,
params.BeaconConfig().GenesisEpoch+2,
0,
1,
2,
)
slot := params.BeaconConfig().GenesisSlot
beaconState := &pb.BeaconState{
@@ -225,8 +282,7 @@ func TestCrosslinkCommitteesAtSlot_ShuffleFailed(t *testing.T) {
}
want := fmt.Sprint(
"could not shuffle epoch validators: " +
"input list exceeded upper bound and reached modulo bias",
"input list exceeded upper bound and reached modulo bias",
)
if _, err := CrosslinkCommitteesAtSlot(state, params.BeaconConfig().GenesisSlot+1, false); !strings.Contains(err.Error(), want) {