diff --git a/beacon-chain/core/state/interop/write_state_to_disk.go b/beacon-chain/core/state/interop/write_state_to_disk.go index 8a7c7379f5..8030832a75 100644 --- a/beacon-chain/core/state/interop/write_state_to_disk.go +++ b/beacon-chain/core/state/interop/write_state_to_disk.go @@ -17,7 +17,7 @@ func WriteStateToDisk(state iface.ReadOnlyBeaconState) { } fp := path.Join(os.TempDir(), fmt.Sprintf("beacon_state_%d.ssz", state.Slot())) log.Warnf("Writing state to disk at %s", fp) - enc, err := state.InnerStateUnsafe().MarshalSSZ() + enc, err := state.MarshalSSZ() if err != nil { log.WithError(err).Error("Failed to ssz encode state") return diff --git a/beacon-chain/rpc/debug/state.go b/beacon-chain/rpc/debug/state.go index 3428455595..21b4cdb476 100644 --- a/beacon-chain/rpc/debug/state.go +++ b/beacon-chain/rpc/debug/state.go @@ -32,7 +32,7 @@ func (ds *Server) GetBeaconState( if err != nil { return nil, status.Errorf(codes.Internal, "Could not compute state by slot: %v", err) } - encoded, err := st.CloneInnerState().MarshalSSZ() + encoded, err := st.MarshalSSZ() if err != nil { return nil, status.Errorf(codes.Internal, "Could not ssz encode beacon state: %v", err) } @@ -44,7 +44,7 @@ func (ds *Server) GetBeaconState( if err != nil { return nil, status.Errorf(codes.Internal, "Could not compute state by block root: %v", err) } - encoded, err := st.CloneInnerState().MarshalSSZ() + encoded, err := st.MarshalSSZ() if err != nil { return nil, status.Errorf(codes.Internal, "Could not ssz encode beacon state: %v", err) } diff --git a/beacon-chain/rpc/debug/state_test.go b/beacon-chain/rpc/debug/state_test.go index 3ee1e6e6c8..ad40693daf 100644 --- a/beacon-chain/rpc/debug/state_test.go +++ b/beacon-chain/rpc/debug/state_test.go @@ -15,7 +15,6 @@ import ( ) func TestServer_GetBeaconState(t *testing.T) { - db := dbTest.SetupDB(t) ctx := context.Background() st, err := testutil.NewBeaconState() @@ -43,7 +42,7 @@ func TestServer_GetBeaconState(t *testing.T) { } res, err := bs.GetBeaconState(ctx, req) require.NoError(t, err) - wanted, err := st.CloneInnerState().MarshalSSZ() + wanted, err := st.MarshalSSZ() require.NoError(t, err) assert.DeepEqual(t, wanted, res.Encoded) req = &pbrpc.BeaconStateRequest{ @@ -57,7 +56,6 @@ func TestServer_GetBeaconState(t *testing.T) { } func TestServer_GetBeaconState_RequestFutureSlot(t *testing.T) { - ds := &Server{GenesisTimeFetcher: &mock.ChainService{}} req := &pbrpc.BeaconStateRequest{ QueryFilter: &pbrpc.BeaconStateRequest_Slot{ diff --git a/beacon-chain/state/getters.go b/beacon-chain/state/getters.go index 91a72e4a0e..f631dd9d6d 100644 --- a/beacon-chain/state/getters.go +++ b/beacon-chain/state/getters.go @@ -1027,3 +1027,11 @@ func (b *BeaconState) safeCopyCheckpoint(input *ethpb.Checkpoint) *ethpb.Checkpo return CopyCheckpoint(input) } + +// MarshalSSZ marshals the underlying beacon state to bytes. +func (b *BeaconState) MarshalSSZ() ([]byte, error) { + if !b.hasInnerState() { + return nil, errors.New("nil beacon state") + } + return b.state.MarshalSSZ() +} diff --git a/beacon-chain/state/getters_test.go b/beacon-chain/state/getters_test.go index 4149f1c79e..5e15e5ac26 100644 --- a/beacon-chain/state/getters_test.go +++ b/beacon-chain/state/getters_test.go @@ -113,3 +113,11 @@ func TestBeaconState_MatchPreviousJustifiedCheckpt(t *testing.T) { beaconState.state = nil require.Equal(t, false, beaconState.MatchPreviousJustifiedCheckpoint(c1)) } + +func TestBeaconState_MarshalSSZ_NilState(t *testing.T) { + s, err := InitializeFromProto(&pb.BeaconState{}) + require.NoError(t, err) + s.state = nil + _, err = s.MarshalSSZ() + require.ErrorContains(t, "nil beacon state", err) +} diff --git a/beacon-chain/state/interface/interface.go b/beacon-chain/state/interface/interface.go index 62cf52425e..808cb11e1f 100644 --- a/beacon-chain/state/interface/interface.go +++ b/beacon-chain/state/interface/interface.go @@ -40,6 +40,7 @@ type ReadOnlyBeaconState interface { HistoricalRoots() [][]byte Slashings() []uint64 FieldReferencesCount() map[string]uint64 + MarshalSSZ() ([]byte, error) } // WriteOnlyBeaconState defines a struct which only has write access to beacon state methods. diff --git a/shared/testutil/state_test.go b/shared/testutil/state_test.go index dc8f4738d4..892d82e798 100644 --- a/shared/testutil/state_test.go +++ b/shared/testutil/state_test.go @@ -12,7 +12,7 @@ import ( func TestNewBeaconState(t *testing.T) { st, err := NewBeaconState() require.NoError(t, err) - b, err := st.InnerStateUnsafe().MarshalSSZ() + b, err := st.MarshalSSZ() require.NoError(t, err) got := &pb.BeaconState{} require.NoError(t, got.UnmarshalSSZ(b)) diff --git a/tools/benchmark-files-gen/main.go b/tools/benchmark-files-gen/main.go index ecb5b54887..25af5dc73c 100644 --- a/tools/benchmark-files-gen/main.go +++ b/tools/benchmark-files-gen/main.go @@ -148,7 +148,7 @@ func generateMarshalledFullStateAndBlock() error { return err } - beaconBytes, err := beaconState.InnerStateUnsafe().MarshalSSZ() + beaconBytes, err := beaconState.MarshalSSZ() if err != nil { return err } @@ -197,7 +197,7 @@ func generate2FullEpochState() error { } } - beaconBytes, err := beaconState.InnerStateUnsafe().MarshalSSZ() + beaconBytes, err := beaconState.MarshalSSZ() if err != nil { return err } diff --git a/tools/interop/export-genesis/main.go b/tools/interop/export-genesis/main.go index 42764fb4cb..88ada3ae60 100644 --- a/tools/interop/export-genesis/main.go +++ b/tools/interop/export-genesis/main.go @@ -37,7 +37,7 @@ func main() { if gs == nil { panic("nil genesis state") } - b, err := gs.InnerStateUnsafe().MarshalSSZ() + b, err := gs.MarshalSSZ() if err != nil { panic(err) }