diff --git a/beacon-chain/core/altair/block.go b/beacon-chain/core/altair/block.go index 76776bb4be..3b96b812d4 100644 --- a/beacon-chain/core/altair/block.go +++ b/beacon-chain/core/altair/block.go @@ -160,7 +160,7 @@ func ApplySyncRewardsPenalties(s state.BeaconStateAltair, votedIndices, didntVot } // SyncRewards returns the proposer reward and the sync participant reward given the total active balance in state. -func SyncRewards(activeBalance uint64) (proposerReward uint64, participantReward uint64, err error) { +func SyncRewards(activeBalance uint64) (proposerReward, participantReward uint64, err error) { cfg := params.BeaconConfig() totalActiveIncrements := activeBalance / cfg.EffectiveBalanceIncrement baseRewardPerInc, err := BaseRewardPerIncrement(activeBalance) diff --git a/beacon-chain/rpc/eth/beacon/BUILD.bazel b/beacon-chain/rpc/eth/beacon/BUILD.bazel index d9621c546d..3ff83b3f8b 100644 --- a/beacon-chain/rpc/eth/beacon/BUILD.bazel +++ b/beacon-chain/rpc/eth/beacon/BUILD.bazel @@ -16,6 +16,7 @@ go_library( visibility = ["//beacon-chain:__subpackages__"], deps = [ "//beacon-chain/blockchain:go_default_library", + "//beacon-chain/core/altair:go_default_library", "//beacon-chain/core/blocks:go_default_library", "//beacon-chain/core/feed:go_default_library", "//beacon-chain/core/feed/block:go_default_library", @@ -67,6 +68,7 @@ go_test( "pool_test.go", "server_test.go", "state_test.go", + "sync_committee_test.go", "validator_test.go", ], embed = [":go_default_library"], @@ -86,6 +88,7 @@ go_test( "//beacon-chain/state/v1:go_default_library", "//proto/eth/service:go_default_library", "//proto/eth/v1:go_default_library", + "//proto/eth/v2:go_default_library", "//proto/migration:go_default_library", "//proto/prysm/v1alpha1:go_default_library", "//proto/prysm/v1alpha1/block:go_default_library", diff --git a/beacon-chain/rpc/eth/beacon/state.go b/beacon-chain/rpc/eth/beacon/state.go index 47149d92ff..3733a00cb1 100644 --- a/beacon-chain/rpc/eth/beacon/state.go +++ b/beacon-chain/rpc/eth/beacon/state.go @@ -3,7 +3,11 @@ package beacon import ( "bytes" "context" + "strconv" + types "github.com/prysmaticlabs/eth2-types" + corehelpers "github.com/prysmaticlabs/prysm/beacon-chain/core/helpers" + "github.com/prysmaticlabs/prysm/beacon-chain/rpc/eth/helpers" "github.com/prysmaticlabs/prysm/beacon-chain/rpc/statefetcher" "github.com/prysmaticlabs/prysm/beacon-chain/state" ethpb "github.com/prysmaticlabs/prysm/proto/eth/v1" @@ -16,6 +20,11 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) +type stateRequest struct { + epoch *types.Epoch + stateId []byte +} + // GetGenesis retrieves details of the chain's genesis which can be used to identify chain. func (bs *Server) GetGenesis(ctx context.Context, _ *emptypb.Empty) (*ethpb.GenesisResponse, error) { ctx, span := trace.StartSpan(ctx, "beaconv1.GetGenesis") @@ -82,12 +91,7 @@ func (bs *Server) GetStateFork(ctx context.Context, req *ethpb.StateRequest) (*e state, err = bs.StateFetcher.State(ctx, req.StateId) if err != nil { - if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { - return nil, status.Errorf(codes.NotFound, "State not found: %v", stateNotFoundErr) - } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { - return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) - } - return nil, status.Errorf(codes.Internal, "Could not get state: %v", err) + return nil, helpers.PrepareStateFetchGRPCError(err) } fork := state.Fork() @@ -130,6 +134,31 @@ func (bs *Server) GetFinalityCheckpoints(ctx context.Context, req *ethpb.StateRe }, nil } +func (bs *Server) stateFromRequest(ctx context.Context, req *stateRequest) (state.BeaconState, error) { + if req.epoch != nil { + slot, err := corehelpers.StartSlot(*req.epoch) + if err != nil { + return nil, status.Errorf( + codes.Internal, + "Could not calculate start slot for epoch %d: %v", + *req.epoch, + err, + ) + } + st, err := bs.StateFetcher.State(ctx, []byte(strconv.FormatUint(uint64(slot), 10))) + if err != nil { + return nil, helpers.PrepareStateFetchGRPCError(err) + } + return st, nil + } + var err error + st, err := bs.StateFetcher.State(ctx, req.stateId) + if err != nil { + return nil, helpers.PrepareStateFetchGRPCError(err) + } + return st, nil +} + func checkpoint(sourceCheckpoint *eth.Checkpoint) *ethpb.Checkpoint { if sourceCheckpoint != nil { return ðpb.Checkpoint{ diff --git a/beacon-chain/rpc/eth/beacon/sync_committee.go b/beacon-chain/rpc/eth/beacon/sync_committee.go index 73058500b0..afc2eafc2f 100644 --- a/beacon-chain/rpc/eth/beacon/sync_committee.go +++ b/beacon-chain/rpc/eth/beacon/sync_committee.go @@ -2,15 +2,98 @@ package beacon import ( "context" + "fmt" "github.com/golang/protobuf/ptypes/empty" + types "github.com/prysmaticlabs/eth2-types" + "github.com/prysmaticlabs/prysm/beacon-chain/core/altair" + "github.com/prysmaticlabs/prysm/beacon-chain/state" "github.com/prysmaticlabs/prysm/proto/eth/v2" + ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + "github.com/prysmaticlabs/prysm/shared/bytesutil" + "github.com/prysmaticlabs/prysm/shared/params" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) -func (bs *Server) ListSyncCommittees(ctx context.Context, request *eth.StateSyncCommitteesRequest) (*eth.StateSyncCommitteesResponse, error) { - panic("implement me") +// ListSyncCommittees retrieves the sync committees for the given epoch. +// If the epoch is not passed in, then the sync committees for the epoch of the state will be obtained. +func (bs *Server) ListSyncCommittees(ctx context.Context, req *eth.StateSyncCommitteesRequest) (*eth.StateSyncCommitteesResponse, error) { + st, err := bs.stateFromRequest(ctx, &stateRequest{ + epoch: req.Epoch, + stateId: req.StateId, + }) + if err != nil { + return nil, status.Errorf(codes.Internal, "Could not fetch beacon state using request: %v", err) + } + + // Get the current sync committee and sync committee indices from the state. + committeeIndices, committee, err := currentCommitteeIndicesFromState(st) + if err != nil { + return nil, status.Errorf(codes.Internal, "Could not get sync committee indices from state: %v", err) + } + subcommittees, err := extractSyncSubcommittees(st, committee) + if err != nil { + return nil, status.Errorf(codes.Internal, "Could not extract sync subcommittees: %v", err) + } + + return ð.StateSyncCommitteesResponse{ + Data: ð.SyncCommitteeValidators{ + Validators: committeeIndices, + ValidatorAggregates: subcommittees, + }, + }, nil } -func (bs *Server) SubmitSyncCommitteeSignature(ctx context.Context, message *eth.SyncCommitteeMessage) (*empty.Empty, error) { - panic("implement me") +// SubmitSyncCommitteeSignature -- +func (bs *Server) SubmitSyncCommitteeSignature(_ context.Context, _ *eth.SyncCommitteeMessage) (*empty.Empty, error) { + return nil, status.Error(codes.Unimplemented, "Unimplemented") +} + +func currentCommitteeIndicesFromState(st state.BeaconState) ([]types.ValidatorIndex, *ethpb.SyncCommittee, error) { + committee, err := st.CurrentSyncCommittee() + if err != nil { + return nil, nil, fmt.Errorf( + "could not get sync committee: %v", err, + ) + } + + committeeIndices := make([]types.ValidatorIndex, len(committee.Pubkeys)) + for i, key := range committee.Pubkeys { + index, ok := st.ValidatorIndexByPubkey(bytesutil.ToBytes48(key)) + if !ok { + return nil, nil, fmt.Errorf( + "validator index not found for pubkey %#x", + bytesutil.Trunc(key), + ) + } + committeeIndices[i] = index + } + return committeeIndices, committee, nil +} + +func extractSyncSubcommittees(st state.BeaconState, committee *ethpb.SyncCommittee) ([]*eth.SyncSubcommitteeValidators, error) { + subcommitteeCount := params.BeaconConfig().SyncCommitteeSubnetCount + subcommittees := make([]*eth.SyncSubcommitteeValidators, subcommitteeCount) + for i := uint64(0); i < subcommitteeCount; i++ { + pubkeys, err := altair.SyncSubCommitteePubkeys(committee, types.CommitteeIndex(i)) + if err != nil { + return nil, fmt.Errorf( + "failed to get subcommittee pubkeys: %v", err, + ) + } + subcommittee := ð.SyncSubcommitteeValidators{Validators: make([]types.ValidatorIndex, len(pubkeys))} + for j, key := range pubkeys { + index, ok := st.ValidatorIndexByPubkey(bytesutil.ToBytes48(key)) + if !ok { + return nil, fmt.Errorf( + "validator index not found for pubkey %#x", + bytesutil.Trunc(key), + ) + } + subcommittee.Validators[j] = index + } + subcommittees[i] = subcommittee + } + return subcommittees, nil } diff --git a/beacon-chain/rpc/eth/beacon/sync_committee_test.go b/beacon-chain/rpc/eth/beacon/sync_committee_test.go new file mode 100644 index 0000000000..2b8593c118 --- /dev/null +++ b/beacon-chain/rpc/eth/beacon/sync_committee_test.go @@ -0,0 +1,142 @@ +package beacon + +import ( + "context" + "testing" + + types "github.com/prysmaticlabs/eth2-types" + "github.com/prysmaticlabs/prysm/beacon-chain/rpc/testutil" + ethpbv2 "github.com/prysmaticlabs/prysm/proto/eth/v2" + ethpbalpha "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" + "github.com/prysmaticlabs/prysm/shared/bytesutil" + "github.com/prysmaticlabs/prysm/shared/params" + sharedtestutil "github.com/prysmaticlabs/prysm/shared/testutil" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" + "github.com/prysmaticlabs/prysm/shared/testutil/require" +) + +func Test_currentCommitteeIndicesFromState(t *testing.T) { + st, _ := sharedtestutil.DeterministicGenesisStateAltair(t, params.BeaconConfig().SyncCommitteeSize) + vals := st.Validators() + wantedCommittee := make([][]byte, params.BeaconConfig().SyncCommitteeSize) + wantedIndices := make([]types.ValidatorIndex, len(wantedCommittee)) + for i := 0; i < len(wantedCommittee); i++ { + wantedIndices[i] = types.ValidatorIndex(i) + wantedCommittee[i] = vals[i].PublicKey + } + require.NoError(t, st.SetCurrentSyncCommittee(ðpbalpha.SyncCommittee{ + Pubkeys: wantedCommittee, + AggregatePubkey: bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength), + })) + + t.Run("OK", func(t *testing.T) { + indices, committee, err := currentCommitteeIndicesFromState(st) + require.NoError(t, err) + require.DeepEqual(t, wantedIndices, indices) + require.DeepEqual(t, wantedCommittee, committee.Pubkeys) + }) + t.Run("validator in committee not found in state", func(t *testing.T) { + wantedCommittee[0] = bytesutil.PadTo([]byte("fakepubkey"), 48) + require.NoError(t, st.SetCurrentSyncCommittee(ðpbalpha.SyncCommittee{ + Pubkeys: wantedCommittee, + AggregatePubkey: bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength), + })) + _, _, err := currentCommitteeIndicesFromState(st) + require.ErrorContains(t, "index not found for pubkey", err) + }) +} + +func Test_extractSyncSubcommittees(t *testing.T) { + st, _ := sharedtestutil.DeterministicGenesisStateAltair(t, params.BeaconConfig().SyncCommitteeSize) + vals := st.Validators() + syncCommittee := make([][]byte, params.BeaconConfig().SyncCommitteeSize) + for i := 0; i < len(syncCommittee); i++ { + syncCommittee[i] = vals[i].PublicKey + } + require.NoError(t, st.SetCurrentSyncCommittee(ðpbalpha.SyncCommittee{ + Pubkeys: syncCommittee, + AggregatePubkey: bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength), + })) + + commSize := params.BeaconConfig().SyncCommitteeSize + subCommSize := params.BeaconConfig().SyncCommitteeSize / params.BeaconConfig().SyncCommitteeSubnetCount + wantedSubcommitteeValidators := make([][]types.ValidatorIndex, 0) + + for i := uint64(0); i < commSize; i += subCommSize { + sub := make([]types.ValidatorIndex, 0) + start := i + end := i + subCommSize + if end > commSize { + end = commSize + } + for j := start; j < end; j++ { + sub = append(sub, types.ValidatorIndex(j)) + } + wantedSubcommitteeValidators = append(wantedSubcommitteeValidators, sub) + } + + t.Run("OK", func(t *testing.T) { + committee, err := st.CurrentSyncCommittee() + require.NoError(t, err) + subcommittee, err := extractSyncSubcommittees(st, committee) + require.NoError(t, err) + for i, got := range subcommittee { + want := wantedSubcommitteeValidators[i] + require.DeepEqual(t, want, got.Validators) + } + }) + t.Run("validator in subcommittee not found in state", func(t *testing.T) { + syncCommittee[0] = bytesutil.PadTo([]byte("fakepubkey"), 48) + require.NoError(t, st.SetCurrentSyncCommittee(ðpbalpha.SyncCommittee{ + Pubkeys: syncCommittee, + AggregatePubkey: bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength), + })) + committee, err := st.CurrentSyncCommittee() + require.NoError(t, err) + _, err = extractSyncSubcommittees(st, committee) + require.ErrorContains(t, "index not found for pubkey", err) + }) +} + +func TestListSyncCommittees(t *testing.T) { + ctx := context.Background() + st, _ := sharedtestutil.DeterministicGenesisStateAltair(t, params.BeaconConfig().SyncCommitteeSize) + syncCommittee := make([][]byte, params.BeaconConfig().SyncCommitteeSize) + vals := st.Validators() + for i := 0; i < len(syncCommittee); i++ { + syncCommittee[i] = vals[i].PublicKey + } + require.NoError(t, st.SetCurrentSyncCommittee(ðpbalpha.SyncCommittee{ + Pubkeys: syncCommittee, + AggregatePubkey: bytesutil.PadTo([]byte{}, params.BeaconConfig().BLSPubkeyLength), + })) + stRoot, err := st.HashTreeRoot(ctx) + require.NoError(t, err) + + s := &Server{ + StateFetcher: &testutil.MockFetcher{ + BeaconState: st, + }, + } + req := ðpbv2.StateSyncCommitteesRequest{StateId: stRoot[:]} + resp, err := s.ListSyncCommittees(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp.Data) + committeeVals := resp.Data.Validators + require.NotNil(t, committeeVals) + require.Equal(t, params.BeaconConfig().SyncCommitteeSize, uint64(len(committeeVals)), "incorrect committee size") + for i := uint64(0); i < params.BeaconConfig().SyncCommitteeSize; i++ { + assert.Equal(t, types.ValidatorIndex(i), committeeVals[i]) + } + require.NotNil(t, resp.Data.ValidatorAggregates) + assert.Equal(t, params.BeaconConfig().SyncCommitteeSubnetCount, uint64(len(resp.Data.ValidatorAggregates))) + for i := uint64(0); i < params.BeaconConfig().SyncCommitteeSubnetCount; i++ { + vStartIndex := types.ValidatorIndex(params.BeaconConfig().SyncCommitteeSize / params.BeaconConfig().SyncCommitteeSubnetCount * i) + vEndIndex := types.ValidatorIndex(params.BeaconConfig().SyncCommitteeSize/params.BeaconConfig().SyncCommitteeSubnetCount*(i+1) - 1) + j := 0 + for vIndex := vStartIndex; vIndex <= vEndIndex; vIndex++ { + assert.Equal(t, vIndex, resp.Data.ValidatorAggregates[i].Validators[j]) + j++ + } + } +} diff --git a/beacon-chain/rpc/eth/beacon/validator.go b/beacon-chain/rpc/eth/beacon/validator.go index a04ad00b7d..469ff901f5 100644 --- a/beacon-chain/rpc/eth/beacon/validator.go +++ b/beacon-chain/rpc/eth/beacon/validator.go @@ -6,9 +6,8 @@ import ( "github.com/pkg/errors" types "github.com/prysmaticlabs/eth2-types" - "github.com/prysmaticlabs/prysm/beacon-chain/core/helpers" - rpchelpers "github.com/prysmaticlabs/prysm/beacon-chain/rpc/eth/helpers" - "github.com/prysmaticlabs/prysm/beacon-chain/rpc/statefetcher" + corehelpers "github.com/prysmaticlabs/prysm/beacon-chain/core/helpers" + "github.com/prysmaticlabs/prysm/beacon-chain/rpc/eth/helpers" "github.com/prysmaticlabs/prysm/beacon-chain/state" v1 "github.com/prysmaticlabs/prysm/beacon-chain/state/v1" ethpb "github.com/prysmaticlabs/prysm/proto/eth/v1" @@ -40,12 +39,7 @@ func (e *invalidValidatorIdError) Error() string { func (bs *Server) GetValidator(ctx context.Context, req *ethpb.StateValidatorRequest) (*ethpb.StateValidatorResponse, error) { state, err := bs.StateFetcher.State(ctx, req.StateId) if err != nil { - if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { - return nil, status.Errorf(codes.NotFound, "could not get state: %v", stateNotFoundErr) - } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { - return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) - } - return nil, status.Errorf(codes.Internal, "State not found: %v", err) + return nil, helpers.PrepareStateFetchGRPCError(err) } if len(req.ValidatorId) == 0 { return nil, status.Error(codes.InvalidArgument, "Validator ID is required") @@ -64,12 +58,7 @@ func (bs *Server) GetValidator(ctx context.Context, req *ethpb.StateValidatorReq func (bs *Server) ListValidators(ctx context.Context, req *ethpb.StateValidatorsRequest) (*ethpb.StateValidatorsResponse, error) { state, err := bs.StateFetcher.State(ctx, req.StateId) if err != nil { - if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { - return nil, status.Errorf(codes.NotFound, "State not found: %v", stateNotFoundErr) - } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { - return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) - } - return nil, status.Errorf(codes.Internal, "Could not get state: %v", err) + return nil, helpers.PrepareStateFetchGRPCError(err) } valContainers, err := valContainersByRequestIds(state, req.Id) @@ -90,18 +79,18 @@ func (bs *Server) ListValidators(ctx context.Context, req *ethpb.StateValidators } filterStatus[ss] = true } - epoch := helpers.SlotToEpoch(state.Slot()) + epoch := corehelpers.SlotToEpoch(state.Slot()) filteredVals := make([]*ethpb.ValidatorContainer, 0, len(valContainers)) for _, vc := range valContainers { readOnlyVal, err := v1.NewValidator(migration.V1ValidatorToV1Alpha1(vc.Validator)) if err != nil { return nil, status.Errorf(codes.Internal, "Could not convert validator: %v", err) } - valStatus, err := rpchelpers.ValidatorStatus(readOnlyVal, epoch) + valStatus, err := helpers.ValidatorStatus(readOnlyVal, epoch) if err != nil { return nil, status.Errorf(codes.Internal, "Could not get validator status: %v", err) } - valSubStatus, err := rpchelpers.ValidatorSubStatus(readOnlyVal, epoch) + valSubStatus, err := helpers.ValidatorSubStatus(readOnlyVal, epoch) if err != nil { return nil, status.Errorf(codes.Internal, "Could not get validator sub status: %v", err) } @@ -116,12 +105,7 @@ func (bs *Server) ListValidators(ctx context.Context, req *ethpb.StateValidators func (bs *Server) ListValidatorBalances(ctx context.Context, req *ethpb.ValidatorBalancesRequest) (*ethpb.ValidatorBalancesResponse, error) { state, err := bs.StateFetcher.State(ctx, req.StateId) if err != nil { - if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { - return nil, status.Errorf(codes.NotFound, "State not found: %v", stateNotFoundErr) - } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { - return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) - } - return nil, status.Errorf(codes.Internal, "Could not get state: %v", err) + return nil, helpers.PrepareStateFetchGRPCError(err) } valContainers, err := valContainersByRequestIds(state, req.Id) @@ -143,32 +127,27 @@ func (bs *Server) ListValidatorBalances(ctx context.Context, req *ethpb.Validato func (bs *Server) ListCommittees(ctx context.Context, req *ethpb.StateCommitteesRequest) (*ethpb.StateCommitteesResponse, error) { state, err := bs.StateFetcher.State(ctx, req.StateId) if err != nil { - if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { - return nil, status.Errorf(codes.NotFound, "State not found: %v", stateNotFoundErr) - } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { - return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) - } - return nil, status.Errorf(codes.Internal, "Could not get state: %v", err) + return nil, helpers.PrepareStateFetchGRPCError(err) } - epoch := helpers.SlotToEpoch(state.Slot()) + epoch := corehelpers.SlotToEpoch(state.Slot()) if req.Epoch != nil { epoch = *req.Epoch } - activeCount, err := helpers.ActiveValidatorCount(state, epoch) + activeCount, err := corehelpers.ActiveValidatorCount(state, epoch) if err != nil { return nil, status.Errorf(codes.Internal, "Could not get active validator count: %v", err) } - startSlot, err := helpers.StartSlot(epoch) + startSlot, err := corehelpers.StartSlot(epoch) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "Invalid epoch: %v", err) } - endSlot, err := helpers.EndSlot(epoch) + endSlot, err := corehelpers.EndSlot(epoch) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "Invalid epoch: %v", err) } - committeesPerSlot := helpers.SlotCommitteeCount(activeCount) + committeesPerSlot := corehelpers.SlotCommitteeCount(activeCount) committees := make([]*ethpb.Committee, 0) for slot := startSlot; slot <= endSlot; slot++ { if req.Slot != nil && slot != *req.Slot { @@ -178,7 +157,7 @@ func (bs *Server) ListCommittees(ctx context.Context, req *ethpb.StateCommittees if req.Index != nil && index != *req.Index { continue } - committee, err := helpers.BeaconCommitteeFromState(state, slot, index) + committee, err := corehelpers.BeaconCommitteeFromState(state, slot, index) if err != nil { return nil, status.Errorf(codes.Internal, "Could not get committee: %v", err) } @@ -196,7 +175,7 @@ func (bs *Server) ListCommittees(ctx context.Context, req *ethpb.StateCommittees // This function returns the validator object based on the passed in ID. The validator ID could be its public key, // or its index. func valContainersByRequestIds(state state.BeaconState, validatorIds [][]byte) ([]*ethpb.ValidatorContainer, error) { - epoch := helpers.SlotToEpoch(state.Slot()) + epoch := corehelpers.SlotToEpoch(state.Slot()) var valContainers []*ethpb.ValidatorContainer if len(validatorIds) == 0 { allValidators := state.Validators() @@ -207,7 +186,7 @@ func valContainersByRequestIds(state state.BeaconState, validatorIds [][]byte) ( if err != nil { return nil, status.Errorf(codes.Internal, "Could not convert validator: %v", err) } - subStatus, err := rpchelpers.ValidatorSubStatus(readOnlyVal, epoch) + subStatus, err := helpers.ValidatorSubStatus(readOnlyVal, epoch) if err != nil { return nil, errors.Wrap(err, "could not get validator sub status") } @@ -250,7 +229,7 @@ func valContainersByRequestIds(state state.BeaconState, validatorIds [][]byte) ( if err != nil { return nil, status.Errorf(codes.Internal, "Could not convert validator: %v", err) } - subStatus, err := rpchelpers.ValidatorSubStatus(readOnlyVal, epoch) + subStatus, err := helpers.ValidatorSubStatus(readOnlyVal, epoch) if err != nil { return nil, errors.Wrap(err, "could not get validator sub status") } diff --git a/beacon-chain/rpc/eth/debug/BUILD.bazel b/beacon-chain/rpc/eth/debug/BUILD.bazel index a4d4583827..b1149813a3 100644 --- a/beacon-chain/rpc/eth/debug/BUILD.bazel +++ b/beacon-chain/rpc/eth/debug/BUILD.bazel @@ -11,6 +11,7 @@ go_library( deps = [ "//beacon-chain/blockchain:go_default_library", "//beacon-chain/db:go_default_library", + "//beacon-chain/rpc/eth/helpers:go_default_library", "//beacon-chain/rpc/statefetcher:go_default_library", "//proto/eth/v1:go_default_library", "//proto/eth/v2:go_default_library", diff --git a/beacon-chain/rpc/eth/debug/debug.go b/beacon-chain/rpc/eth/debug/debug.go index 4f22ef345c..908b2ef8c6 100644 --- a/beacon-chain/rpc/eth/debug/debug.go +++ b/beacon-chain/rpc/eth/debug/debug.go @@ -3,7 +3,7 @@ package debug import ( "context" - "github.com/prysmaticlabs/prysm/beacon-chain/rpc/statefetcher" + "github.com/prysmaticlabs/prysm/beacon-chain/rpc/eth/helpers" ethpbv1 "github.com/prysmaticlabs/prysm/proto/eth/v1" ethpbv2 "github.com/prysmaticlabs/prysm/proto/eth/v2" "go.opencensus.io/trace" @@ -19,12 +19,7 @@ func (ds *Server) GetBeaconState(ctx context.Context, req *ethpbv1.StateRequest) state, err := ds.StateFetcher.State(ctx, req.StateId) if err != nil { - if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { - return nil, status.Errorf(codes.NotFound, "State not found: %v", stateNotFoundErr) - } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { - return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) - } - return nil, status.Errorf(codes.Internal, "Invalid state ID: %v", err) + return nil, helpers.PrepareStateFetchGRPCError(err) } protoState, err := state.ToProto() @@ -44,12 +39,7 @@ func (ds *Server) GetBeaconStateSSZ(ctx context.Context, req *ethpbv1.StateReque state, err := ds.StateFetcher.State(ctx, req.StateId) if err != nil { - if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { - return nil, status.Errorf(codes.NotFound, "State not found: %v", stateNotFoundErr) - } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { - return nil, status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) - } - return nil, status.Errorf(codes.Internal, "Invalid state ID: %v", err) + return nil, helpers.PrepareStateFetchGRPCError(err) } sszState, err := state.MarshalSSZ() diff --git a/beacon-chain/rpc/eth/helpers/BUILD.bazel b/beacon-chain/rpc/eth/helpers/BUILD.bazel index e8f41a845a..95fbeb4170 100644 --- a/beacon-chain/rpc/eth/helpers/BUILD.bazel +++ b/beacon-chain/rpc/eth/helpers/BUILD.bazel @@ -2,15 +2,21 @@ load("@prysm//tools/go:def.bzl", "go_library", "go_test") go_library( name = "go_default_library", - srcs = ["validator_status.go"], + srcs = [ + "state.go", + "validator_status.go", + ], importpath = "github.com/prysmaticlabs/prysm/beacon-chain/rpc/eth/helpers", visibility = ["//beacon-chain:__subpackages__"], deps = [ + "//beacon-chain/rpc/statefetcher:go_default_library", "//beacon-chain/state:go_default_library", "//proto/eth/v1:go_default_library", "//shared/params:go_default_library", "@com_github_pkg_errors//:go_default_library", "@com_github_prysmaticlabs_eth2_types//:go_default_library", + "@org_golang_google_grpc//codes:go_default_library", + "@org_golang_google_grpc//status:go_default_library", ], ) diff --git a/beacon-chain/rpc/eth/helpers/state.go b/beacon-chain/rpc/eth/helpers/state.go new file mode 100644 index 0000000000..93cfd6cb76 --- /dev/null +++ b/beacon-chain/rpc/eth/helpers/state.go @@ -0,0 +1,18 @@ +package helpers + +import ( + "github.com/prysmaticlabs/prysm/beacon-chain/rpc/statefetcher" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// PrepareStateFetchGRPCError returns an appropriate gRPC error based on the supplied argument. +// The argument error should be a result of fetching state. +func PrepareStateFetchGRPCError(err error) error { + if stateNotFoundErr, ok := err.(*statefetcher.StateNotFoundError); ok { + return status.Errorf(codes.NotFound, "State not found: %v", stateNotFoundErr) + } else if parseErr, ok := err.(*statefetcher.StateIdParseError); ok { + return status.Errorf(codes.InvalidArgument, "Invalid state ID: %v", parseErr) + } + return status.Errorf(codes.Internal, "Invalid state ID: %v", err) +}