diff --git a/encoding/bytesutil/bytes.go b/encoding/bytesutil/bytes.go index 8a5a65b871..45d155a74d 100644 --- a/encoding/bytesutil/bytes.go +++ b/encoding/bytesutil/bytes.go @@ -406,7 +406,17 @@ func ReverseByteOrder(input []byte) []byte { return b } -// NonZeroRoot returns whether or not a root is of proper length and non-zero hash. -func NonZeroRoot(root []byte) bool { - return len(root) == fieldparams.RootLength && string(make([]byte, fieldparams.RootLength)) != string(root) +// ZeroRoot returns whether or not a root is of proper length and non-zero hash. +func ZeroRoot(root []byte) bool { + return string(make([]byte, fieldparams.RootLength)) == string(root) +} + +// IsRoot checks whether the byte array is a root. +func IsRoot(root []byte) bool { + return len(root) == fieldparams.RootLength +} + +// IsValidRoot checks whether the byte array is a valid root. +func IsValidRoot(root []byte) bool { + return IsRoot(root) && !ZeroRoot(root) } diff --git a/encoding/bytesutil/bytes_test.go b/encoding/bytesutil/bytes_test.go index e758f4e623..6a63017a35 100644 --- a/encoding/bytesutil/bytes_test.go +++ b/encoding/bytesutil/bytes_test.go @@ -518,12 +518,67 @@ func TestSafeCopy2d32Bytes(t *testing.T) { assert.DeepEqual(t, input, output) } -func TestNonZeroRoot(t *testing.T) { +func TestZeroRoot(t *testing.T) { input := make([]byte, fieldparams.RootLength) - output := bytesutil.NonZeroRoot(input) - assert.Equal(t, false, output) + output := bytesutil.ZeroRoot(input) + assert.Equal(t, true, output) copy(input[2:], "a") copy(input[3:], "b") - output = bytesutil.NonZeroRoot(input) + output = bytesutil.ZeroRoot(input) + assert.Equal(t, false, output) +} + +func TestIsRoot(t *testing.T) { + input := make([]byte, fieldparams.RootLength) + output := bytesutil.IsRoot(input) assert.Equal(t, true, output) } + +func TestIsValidRoot(t *testing.T) { + + zeroRoot := make([]byte, fieldparams.RootLength) + + validRoot := make([]byte, fieldparams.RootLength) + validRoot[0] = 'a' + + wrongLengthRoot := make([]byte, fieldparams.RootLength-4) + wrongLengthRoot[0] = 'a' + + type args struct { + root []byte + } + + tests := []struct { + name string + args args + want bool + }{ + { + name: "Is ZeroRoot", + args: args{ + root: zeroRoot, + }, + want: false, + }, + { + name: "Is ValidRoot", + args: args{ + root: validRoot, + }, + want: true, + }, + { + name: "Is NonZeroRoot but not length 32", + args: args{ + root: wrongLengthRoot, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := bytesutil.IsValidRoot(tt.args.root) + require.Equal(t, got, tt.want) + }) + } +} diff --git a/validator/accounts/wallet/wallet.go b/validator/accounts/wallet/wallet.go index 233ba493a5..7c870ec51c 100644 --- a/validator/accounts/wallet/wallet.go +++ b/validator/accounts/wallet/wallet.go @@ -311,7 +311,7 @@ func (w *Wallet) InitializeKeymanager(ctx context.Context, cfg iface.InitKeymana } // TODO(9883): future work needs to address how initialize keymanager is called for web3signer. // an error may be thrown for genesis validators root for some InitializeKeymanager calls. - if !bytesutil.NonZeroRoot(config.GenesisValidatorsRoot) { + if !bytesutil.IsValidRoot(config.GenesisValidatorsRoot) { return nil, errors.New("web3signer requires a genesis validators root value") } km, err = remote_web3signer.NewKeymanager(ctx, config) diff --git a/validator/keymanager/remote-web3signer/keymanager.go b/validator/keymanager/remote-web3signer/keymanager.go index 2996fc59da..9acd9e2ce1 100644 --- a/validator/keymanager/remote-web3signer/keymanager.go +++ b/validator/keymanager/remote-web3signer/keymanager.go @@ -47,7 +47,7 @@ type Keymanager struct { // NewKeymanager instantiates a new web3signer key manager. func NewKeymanager(_ context.Context, cfg *SetupConfig) (*Keymanager, error) { - if cfg.BaseEndpoint == "" || !bytesutil.NonZeroRoot(cfg.GenesisValidatorsRoot) { + if cfg.BaseEndpoint == "" || !bytesutil.IsValidRoot(cfg.GenesisValidatorsRoot) { return nil, fmt.Errorf("invalid setup config, one or more configs are empty: BaseEndpoint: %v, GenesisValidatorsRoot: %#x", cfg.BaseEndpoint, cfg.GenesisValidatorsRoot) } if cfg.PublicKeysURL != "" && len(cfg.ProvidedPublicKeys) != 0 { @@ -103,7 +103,7 @@ func getSignRequestJson(ctx context.Context, validator *validator.Validate, requ if request == nil { return nil, errors.New("nil sign request provided") } - if !bytesutil.NonZeroRoot(genesisValidatorsRoot) { + if !bytesutil.IsValidRoot(genesisValidatorsRoot) { return nil, fmt.Errorf("invalid genesis validators root length, genesis root: %v", genesisValidatorsRoot) } switch request.Object.(type) {