From 001f719cc3bf3aa2047e7efe4d80ced151fb0159 Mon Sep 17 00:00:00 2001 From: Raul Jordan Date: Thu, 28 Apr 2022 13:57:40 +0000 Subject: [PATCH] Move ETH2 Types Into Prysm (#10534) * move eth2 types into Prysm * bazel * lint * use existing math helpers * rem eth2-types dep Co-authored-by: james-prysm <90280386+james-prysm@users.noreply.github.com> Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com> --- consensus-types/primitives/BUILD.bazel | 38 ++ consensus-types/primitives/committee_index.go | 54 +++ .../primitives/committee_index_test.go | 28 ++ consensus-types/primitives/domain.go | 57 +++ consensus-types/primitives/domain_test.go | 90 +++++ consensus-types/primitives/epoch.go | 152 ++++++++ consensus-types/primitives/epoch_test.go | 209 +++++++++++ consensus-types/primitives/slot.go | 200 +++++++++++ consensus-types/primitives/slot_test.go | 329 ++++++++++++++++++ consensus-types/primitives/sszbytes.go | 21 ++ consensus-types/primitives/sszbytes_test.go | 57 +++ consensus-types/primitives/sszuint64.go | 61 ++++ consensus-types/primitives/sszuint64_test.go | 96 +++++ consensus-types/primitives/validator.go | 80 +++++ consensus-types/primitives/validator_test.go | 35 ++ math/BUILD.bazel | 1 + math/math_helper.go | 26 +- math/math_helper_test.go | 69 ++++ 18 files changed, 1602 insertions(+), 1 deletion(-) create mode 100644 consensus-types/primitives/BUILD.bazel create mode 100644 consensus-types/primitives/committee_index.go create mode 100644 consensus-types/primitives/committee_index_test.go create mode 100644 consensus-types/primitives/domain.go create mode 100644 consensus-types/primitives/domain_test.go create mode 100644 consensus-types/primitives/epoch.go create mode 100644 consensus-types/primitives/epoch_test.go create mode 100644 consensus-types/primitives/slot.go create mode 100644 consensus-types/primitives/slot_test.go create mode 100644 consensus-types/primitives/sszbytes.go create mode 100644 consensus-types/primitives/sszbytes_test.go create mode 100644 consensus-types/primitives/sszuint64.go create mode 100644 consensus-types/primitives/sszuint64_test.go create mode 100644 consensus-types/primitives/validator.go create mode 100644 consensus-types/primitives/validator_test.go diff --git a/consensus-types/primitives/BUILD.bazel b/consensus-types/primitives/BUILD.bazel new file mode 100644 index 0000000000..0d950a4c95 --- /dev/null +++ b/consensus-types/primitives/BUILD.bazel @@ -0,0 +1,38 @@ +load("@prysm//tools/go:def.bzl", "go_library", "go_test") + +go_library( + name = "go_default_library", + srcs = [ + "committee_index.go", + "domain.go", + "epoch.go", + "slot.go", + "sszbytes.go", + "sszuint64.go", + "validator.go", + ], + importpath = "github.com/prysmaticlabs/prysm/consensus-types/primitives", + visibility = ["//visibility:public"], + deps = [ + "//math:go_default_library", + "@com_github_ferranbt_fastssz//:go_default_library", + ], +) + +go_test( + name = "go_default_test", + srcs = [ + "committee_index_test.go", + "domain_test.go", + "epoch_test.go", + "slot_test.go", + "sszbytes_test.go", + "sszuint64_test.go", + "validator_test.go", + ], + embed = [":go_default_library"], + deps = [ + "//math:go_default_library", + "@com_github_prysmaticlabs_eth2_types//:go_default_library", + ], +) diff --git a/consensus-types/primitives/committee_index.go b/consensus-types/primitives/committee_index.go new file mode 100644 index 0000000000..ff338337b8 --- /dev/null +++ b/consensus-types/primitives/committee_index.go @@ -0,0 +1,54 @@ +package types + +import ( + "fmt" + + fssz "github.com/ferranbt/fastssz" +) + +var _ fssz.HashRoot = (CommitteeIndex)(0) +var _ fssz.Marshaler = (*CommitteeIndex)(nil) +var _ fssz.Unmarshaler = (*CommitteeIndex)(nil) + +// CommitteeIndex -- +type CommitteeIndex uint64 + +// HashTreeRoot returns calculated hash root. +func (c CommitteeIndex) HashTreeRoot() ([32]byte, error) { + return fssz.HashWithDefaultHasher(c) +} + +// HashTreeRootWith -- +func (c CommitteeIndex) HashTreeRootWith(hh *fssz.Hasher) error { + hh.PutUint64(uint64(c)) + return nil +} + +// UnmarshalSSZ -- +func (c *CommitteeIndex) UnmarshalSSZ(buf []byte) error { + if len(buf) != c.SizeSSZ() { + return fmt.Errorf("expected buffer of length %d receiced %d", c.SizeSSZ(), len(buf)) + } + *c = CommitteeIndex(fssz.UnmarshallUint64(buf)) + return nil +} + +// MarshalSSZTo -- +func (c *CommitteeIndex) MarshalSSZTo(dst []byte) ([]byte, error) { + marshalled, err := c.MarshalSSZ() + if err != nil { + return nil, err + } + return append(dst, marshalled...), nil +} + +// MarshalSSZ -- +func (c *CommitteeIndex) MarshalSSZ() ([]byte, error) { + marshalled := fssz.MarshalUint64([]byte{}, uint64(*c)) + return marshalled, nil +} + +// SizeSSZ returns the size of the serialized object. +func (c *CommitteeIndex) SizeSSZ() int { + return 8 +} diff --git a/consensus-types/primitives/committee_index_test.go b/consensus-types/primitives/committee_index_test.go new file mode 100644 index 0000000000..087e283bc2 --- /dev/null +++ b/consensus-types/primitives/committee_index_test.go @@ -0,0 +1,28 @@ +package types + +import ( + "testing" +) + +func TestCommitteeIndex_Casting(t *testing.T) { + committeeIdx := CommitteeIndex(42) + + t.Run("floats", func(t *testing.T) { + var x1 float32 = 42.2 + if CommitteeIndex(x1) != committeeIdx { + t.Errorf("Unequal: %v = %v", CommitteeIndex(x1), committeeIdx) + } + + var x2 float64 = 42.2 + if CommitteeIndex(x2) != committeeIdx { + t.Errorf("Unequal: %v = %v", CommitteeIndex(x2), committeeIdx) + } + }) + + t.Run("int", func(t *testing.T) { + var x int = 42 + if CommitteeIndex(x) != committeeIdx { + t.Errorf("Unequal: %v = %v", CommitteeIndex(x), committeeIdx) + } + }) +} diff --git a/consensus-types/primitives/domain.go b/consensus-types/primitives/domain.go new file mode 100644 index 0000000000..685ee3c740 --- /dev/null +++ b/consensus-types/primitives/domain.go @@ -0,0 +1,57 @@ +package types + +import ( + "fmt" + + fssz "github.com/ferranbt/fastssz" +) + +var _ fssz.HashRoot = (Domain)([]byte{}) +var _ fssz.Marshaler = (*Domain)(nil) +var _ fssz.Unmarshaler = (*Domain)(nil) + +// Domain represents a 32 bytes domain object in Ethereum beacon chain consensus. +type Domain []byte + +// HashTreeRoot -- +func (e Domain) HashTreeRoot() ([32]byte, error) { + return fssz.HashWithDefaultHasher(e) +} + +// HashTreeRootWith -- +func (e Domain) HashTreeRootWith(hh *fssz.Hasher) error { + hh.PutBytes(e[:]) + return nil +} + +// UnmarshalSSZ -- +func (e *Domain) UnmarshalSSZ(buf []byte) error { + if len(buf) != e.SizeSSZ() { + return fmt.Errorf("expected buffer of length %d received %d", e.SizeSSZ(), len(buf)) + } + + var b [32]byte + item := Domain(b[:]) + copy(item, buf) + *e = item + return nil +} + +// MarshalSSZTo -- +func (e *Domain) MarshalSSZTo(dst []byte) ([]byte, error) { + marshalled, err := e.MarshalSSZ() + if err != nil { + return nil, err + } + return append(dst, marshalled...), nil +} + +// MarshalSSZ -- +func (e *Domain) MarshalSSZ() ([]byte, error) { + return *e, nil +} + +// SizeSSZ -- +func (e *Domain) SizeSSZ() int { + return 32 +} diff --git a/consensus-types/primitives/domain_test.go b/consensus-types/primitives/domain_test.go new file mode 100644 index 0000000000..ebe3f82a6d --- /dev/null +++ b/consensus-types/primitives/domain_test.go @@ -0,0 +1,90 @@ +package types + +import ( + "reflect" + "testing" +) + +func TestDomain_Casting(t *testing.T) { + t.Run("empty byte slice", func(t *testing.T) { + b := make([]byte, 0) + d := Domain(b) + if !reflect.DeepEqual([]byte(d), b) { + t.Errorf("Unequal: %v = %v", d, b) + } + }) + + t.Run("non-empty byte slice", func(t *testing.T) { + b := make([]byte, 2) + b[0] = byte('a') + b[1] = byte('b') + d := Domain(b) + if !reflect.DeepEqual([]byte(d), b) { + t.Errorf("Unequal: %v = %v", d, b) + } + }) + + t.Run("byte array", func(t *testing.T) { + var b [2]byte + b[0] = byte('a') + b[1] = byte('b') + d := Domain(b[:]) + if !reflect.DeepEqual([]byte(d), b[:]) { + t.Errorf("Unequal: %v = %v", d, b) + } + }) +} + +func TestDomain_UnmarshalSSZ(t *testing.T) { + t.Run("Ok", func(t *testing.T) { + d := Domain{} + var b = [32]byte{'f', 'o', 'o'} + err := d.UnmarshalSSZ(b[:]) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if !reflect.DeepEqual(b[:], []byte(d)) { + t.Errorf("Unequal: %v = %v", b, []byte(d)) + } + }) + + t.Run("Wrong slice length", func(t *testing.T) { + d := Domain{} + var b = [16]byte{'f', 'o', 'o'} + err := d.UnmarshalSSZ(b[:]) + if err == nil { + t.Error("Expected error") + } + }) +} + +func TestDomain_MarshalSSZTo(t *testing.T) { + d := Domain("foo") + dst := []byte("bar") + b, err := d.MarshalSSZTo(dst) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + expected := []byte("barfoo") + if !reflect.DeepEqual(expected, b) { + t.Errorf("Unequal: %v = %v", expected, b) + } +} + +func TestDomain_MarshalSSZ(t *testing.T) { + d := Domain("foo") + b, err := d.MarshalSSZ() + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if !reflect.DeepEqual(b, []byte(d)) { + t.Errorf("Unequal: %v = %v", b, []byte(d)) + } +} + +func TestDomain_SizeSSZ(t *testing.T) { + d := Domain{} + if d.SizeSSZ() != 32 { + t.Errorf("Wrong SSZ size. Expected %v vs actual %v", 32, d.SizeSSZ()) + } +} diff --git a/consensus-types/primitives/epoch.go b/consensus-types/primitives/epoch.go new file mode 100644 index 0000000000..80727d7af3 --- /dev/null +++ b/consensus-types/primitives/epoch.go @@ -0,0 +1,152 @@ +package types + +import ( + "fmt" + + fssz "github.com/ferranbt/fastssz" + "github.com/prysmaticlabs/prysm/math" +) + +var _ fssz.HashRoot = (Epoch)(0) +var _ fssz.Marshaler = (*Epoch)(nil) +var _ fssz.Unmarshaler = (*Epoch)(nil) + +// Epoch represents a single epoch. +type Epoch uint64 + +// Mul multiplies epoch by x. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (e Epoch) Mul(x uint64) Epoch { + res, err := e.SafeMul(x) + if err != nil { + panic(err.Error()) + } + return res +} + +// SafeMul multiplies epoch by x. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (e Epoch) SafeMul(x uint64) (Epoch, error) { + res, err := math.Mul64(uint64(e), x) + return Epoch(res), err +} + +// Div divides epoch by x. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (e Epoch) Div(x uint64) Epoch { + res, err := e.SafeDiv(x) + if err != nil { + panic(err.Error()) + } + return res +} + +// SafeDiv divides epoch by x. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (e Epoch) SafeDiv(x uint64) (Epoch, error) { + res, err := math.Div64(uint64(e), x) + return Epoch(res), err +} + +// Add increases epoch by x. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (e Epoch) Add(x uint64) Epoch { + res, err := e.SafeAdd(x) + if err != nil { + panic(err.Error()) + } + return res +} + +// SafeAdd increases epoch by x. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (e Epoch) SafeAdd(x uint64) (Epoch, error) { + res, err := math.Add64(uint64(e), x) + return Epoch(res), err +} + +// AddEpoch increases epoch using another epoch value. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (e Epoch) AddEpoch(x Epoch) Epoch { + return e.Add(uint64(x)) +} + +// SafeAddEpoch increases epoch using another epoch value. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (e Epoch) SafeAddEpoch(x Epoch) (Epoch, error) { + return e.SafeAdd(uint64(x)) +} + +// Sub subtracts x from the epoch. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (e Epoch) Sub(x uint64) Epoch { + res, err := e.SafeSub(x) + if err != nil { + panic(err.Error()) + } + return res +} + +// SafeSub subtracts x from the epoch. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (e Epoch) SafeSub(x uint64) (Epoch, error) { + res, err := math.Sub64(uint64(e), x) + return Epoch(res), err +} + +// Mod returns result of `epoch % x`. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (e Epoch) Mod(x uint64) Epoch { + res, err := e.SafeMod(x) + if err != nil { + panic(err.Error()) + } + return res +} + +// SafeMod returns result of `epoch % x`. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (e Epoch) SafeMod(x uint64) (Epoch, error) { + res, err := math.Mod64(uint64(e), x) + return Epoch(res), err +} + +// HashTreeRoot -- +func (e Epoch) HashTreeRoot() ([32]byte, error) { + return fssz.HashWithDefaultHasher(e) +} + +// HashTreeRootWith -- +func (e Epoch) HashTreeRootWith(hh *fssz.Hasher) error { + hh.PutUint64(uint64(e)) + return nil +} + +// UnmarshalSSZ -- +func (e *Epoch) UnmarshalSSZ(buf []byte) error { + if len(buf) != e.SizeSSZ() { + return fmt.Errorf("expected buffer of length %d received %d", e.SizeSSZ(), len(buf)) + } + *e = Epoch(fssz.UnmarshallUint64(buf)) + return nil +} + +// MarshalSSZTo -- +func (e *Epoch) MarshalSSZTo(dst []byte) ([]byte, error) { + marshalled, err := e.MarshalSSZ() + if err != nil { + return nil, err + } + return append(dst, marshalled...), nil +} + +// MarshalSSZ -- +func (e *Epoch) MarshalSSZ() ([]byte, error) { + marshalled := fssz.MarshalUint64([]byte{}, uint64(*e)) + return marshalled, nil +} + +// SizeSSZ -- +func (e *Epoch) SizeSSZ() int { + return 8 +} diff --git a/consensus-types/primitives/epoch_test.go b/consensus-types/primitives/epoch_test.go new file mode 100644 index 0000000000..2239c64ac1 --- /dev/null +++ b/consensus-types/primitives/epoch_test.go @@ -0,0 +1,209 @@ +package types_test + +import ( + "fmt" + "math" + "testing" + + types "github.com/prysmaticlabs/prysm/consensus-types/primitives" + mathprysm "github.com/prysmaticlabs/prysm/math" +) + +func TestEpoch_Mul(t *testing.T) { + tests := []struct { + a, b uint64 + res types.Epoch + panicMsg string + }{ + {a: 0, b: 1, res: 0}, + {a: 1 << 32, b: 1, res: 1 << 32}, + {a: 1 << 32, b: 100, res: 429496729600}, + {a: 1 << 32, b: 1 << 31, res: 9223372036854775808}, + {a: 1 << 32, b: 1 << 32, res: 0, panicMsg: mathprysm.ErrMulOverflow.Error()}, + {a: 1 << 62, b: 2, res: 9223372036854775808}, + {a: 1 << 62, b: 4, res: 0, panicMsg: mathprysm.ErrMulOverflow.Error()}, + {a: 1 << 63, b: 1, res: 9223372036854775808}, + {a: 1 << 63, b: 2, res: 0, panicMsg: mathprysm.ErrMulOverflow.Error()}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("Epoch(%v).Mul(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Epoch + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Epoch(tt.a).Mul(tt.b) + }) + } else { + res = types.Epoch(tt.a).Mul(tt.b) + } + if tt.res != res { + t.Errorf("Epoch.Mul() = %v, want %v", res, tt.res) + } + }) + } +} + +func TestEpoch_Div(t *testing.T) { + tests := []struct { + a, b uint64 + res types.Epoch + panicMsg string + }{ + {a: 0, b: 1, res: 0}, + {a: 1, b: 0, res: 0, panicMsg: mathprysm.ErrDivByZero.Error()}, + {a: 1 << 32, b: 1 << 32, res: 1}, + {a: 429496729600, b: 1 << 32, res: 100}, + {a: 9223372036854775808, b: 1 << 32, res: 1 << 31}, + {a: 1 << 32, b: 1 << 32, res: 1}, + {a: 9223372036854775808, b: 1 << 62, res: 2}, + {a: 9223372036854775808, b: 1 << 63, res: 1}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("Epoch(%v).Div(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Epoch + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Epoch(tt.a).Div(tt.b) + }) + } else { + res = types.Epoch(tt.a).Div(tt.b) + } + if tt.res != res { + t.Errorf("Epoch.Div() = %v, want %v", res, tt.res) + } + }) + } +} + +func TestEpoch_Add(t *testing.T) { + tests := []struct { + a, b uint64 + res types.Epoch + panicMsg string + }{ + {a: 0, b: 1, res: 1}, + {a: 1 << 32, b: 1, res: 4294967297}, + {a: 1 << 32, b: 100, res: 4294967396}, + {a: 1 << 31, b: 1 << 31, res: 4294967296}, + {a: 1 << 63, b: 1 << 63, res: 0, panicMsg: mathprysm.ErrAddOverflow.Error()}, + {a: 1 << 63, b: 1, res: 9223372036854775809}, + {a: math.MaxUint64, b: 1, res: 0, panicMsg: mathprysm.ErrAddOverflow.Error()}, + {a: math.MaxUint64, b: 0, res: math.MaxUint64}, + {a: 1 << 63, b: 2, res: 9223372036854775810}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("Epoch(%v).Add(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Epoch + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Epoch(tt.a).Add(tt.b) + }) + } else { + res = types.Epoch(tt.a).Add(tt.b) + } + if tt.res != res { + t.Errorf("Epoch.Add() = %v, want %v", res, tt.res) + } + }) + t.Run(fmt.Sprintf("Epoch(%v).AddEpoch(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Epoch + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Epoch(tt.a).AddEpoch(types.Epoch(tt.b)) + }) + } else { + res = types.Epoch(tt.a).AddEpoch(types.Epoch(tt.b)) + } + if tt.res != res { + t.Errorf("Epoch.AddEpoch() = %v, want %v", res, tt.res) + } + }) + } +} + +func TestEpoch_Sub(t *testing.T) { + tests := []struct { + a, b uint64 + res types.Epoch + panicMsg string + }{ + {a: 1, b: 0, res: 1}, + {a: 0, b: 1, res: 0, panicMsg: mathprysm.ErrSubUnderflow.Error()}, + {a: 1 << 32, b: 1, res: 4294967295}, + {a: 1 << 32, b: 100, res: 4294967196}, + {a: 1 << 31, b: 1 << 31, res: 0}, + {a: 1 << 63, b: 1 << 63, res: 0}, + {a: 1 << 63, b: 1, res: 9223372036854775807}, + {a: math.MaxUint64, b: math.MaxUint64, res: 0}, + {a: math.MaxUint64 - 1, b: math.MaxUint64, res: 0, panicMsg: mathprysm.ErrSubUnderflow.Error()}, + {a: math.MaxUint64, b: 0, res: math.MaxUint64}, + {a: 1 << 63, b: 2, res: 9223372036854775806}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("Epoch(%v).Sub(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Epoch + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Epoch(tt.a).Sub(tt.b) + }) + } else { + res = types.Epoch(tt.a).Sub(tt.b) + } + if tt.res != res { + t.Errorf("Epoch.Sub() = %v, want %v", res, tt.res) + } + }) + } +} + +func TestEpoch_Mod(t *testing.T) { + tests := []struct { + a, b uint64 + res types.Epoch + panicMsg string + }{ + {a: 1, b: 0, res: 0, panicMsg: mathprysm.ErrDivByZero.Error()}, + {a: 0, b: 1, res: 0}, + {a: 1 << 32, b: 1 << 32, res: 0}, + {a: 429496729600, b: 1 << 32, res: 0}, + {a: 9223372036854775808, b: 1 << 32, res: 0}, + {a: 1 << 32, b: 1 << 32, res: 0}, + {a: 9223372036854775808, b: 1 << 62, res: 0}, + {a: 9223372036854775808, b: 1 << 63, res: 0}, + {a: 1 << 32, b: 17, res: 1}, + {a: 1 << 32, b: 19, res: (1 << 32) % 19}, + {a: math.MaxUint64, b: math.MaxUint64, res: 0}, + {a: 1 << 63, b: 2, res: 0}, + {a: 1<<63 + 1, b: 2, res: 1}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("Epoch(%v).Mod(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Epoch + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Epoch(tt.a).Mod(tt.b) + }) + } else { + res = types.Epoch(tt.a).Mod(tt.b) + } + if tt.res != res { + t.Errorf("Epoch.Mod() = %v, want %v", res, tt.res) + } + }) + } +} + +func assertPanic(t *testing.T, panicMessage string, f func()) { + defer func() { + if r := recover(); r == nil { + t.Errorf("Expected panic not thrown") + } else if r != panicMessage { + t.Errorf("Unexpected panic thrown, want: %#v, got: %#v", panicMessage, r) + } + }() + f() +} diff --git a/consensus-types/primitives/slot.go b/consensus-types/primitives/slot.go new file mode 100644 index 0000000000..114f2826da --- /dev/null +++ b/consensus-types/primitives/slot.go @@ -0,0 +1,200 @@ +package types + +import ( + fmt "fmt" + + fssz "github.com/ferranbt/fastssz" + "github.com/prysmaticlabs/prysm/math" +) + +var _ fssz.HashRoot = (Slot)(0) +var _ fssz.Marshaler = (*Slot)(nil) +var _ fssz.Unmarshaler = (*Slot)(nil) + +// Slot represents a single slot. +type Slot uint64 + +// Mul multiplies slot by x. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (s Slot) Mul(x uint64) Slot { + res, err := s.SafeMul(x) + if err != nil { + panic(err.Error()) + } + return res +} + +// SafeMul multiplies slot by x. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (s Slot) SafeMul(x uint64) (Slot, error) { + res, err := math.Mul64(uint64(s), x) + return Slot(res), err +} + +// MulSlot multiplies slot by another slot. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (s Slot) MulSlot(x Slot) Slot { + return s.Mul(uint64(x)) +} + +// SafeMulSlot multiplies slot by another slot. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (s Slot) SafeMulSlot(x Slot) (Slot, error) { + return s.SafeMul(uint64(x)) +} + +// Div divides slot by x. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (s Slot) Div(x uint64) Slot { + res, err := s.SafeDiv(x) + if err != nil { + panic(err.Error()) + } + return res +} + +// SafeDiv divides slot by x. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (s Slot) SafeDiv(x uint64) (Slot, error) { + res, err := math.Div64(uint64(s), x) + return Slot(res), err +} + +// DivSlot divides slot by another slot. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (s Slot) DivSlot(x Slot) Slot { + return s.Div(uint64(x)) +} + +// SafeDivSlot divides slot by another slot. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (s Slot) SafeDivSlot(x Slot) (Slot, error) { + return s.SafeDiv(uint64(x)) +} + +// Add increases slot by x. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (s Slot) Add(x uint64) Slot { + res, err := s.SafeAdd(x) + if err != nil { + panic(err.Error()) + } + return res +} + +// SafeAdd increases slot by x. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (s Slot) SafeAdd(x uint64) (Slot, error) { + res, err := math.Add64(uint64(s), x) + return Slot(res), err +} + +// AddSlot increases slot by another slot. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (s Slot) AddSlot(x Slot) Slot { + return s.Add(uint64(x)) +} + +// SafeAddSlot increases slot by another slot. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (s Slot) SafeAddSlot(x Slot) (Slot, error) { + return s.SafeAdd(uint64(x)) +} + +// Sub subtracts x from the slot. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (s Slot) Sub(x uint64) Slot { + res, err := s.SafeSub(x) + if err != nil { + panic(err.Error()) + } + return res +} + +// SafeSub subtracts x from the slot. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (s Slot) SafeSub(x uint64) (Slot, error) { + res, err := math.Sub64(uint64(s), x) + return Slot(res), err +} + +// SubSlot finds difference between two slot values. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (s Slot) SubSlot(x Slot) Slot { + return s.Sub(uint64(x)) +} + +// SafeSubSlot finds difference between two slot values. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (s Slot) SafeSubSlot(x Slot) (Slot, error) { + return s.SafeSub(uint64(x)) +} + +// Mod returns result of `slot % x`. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (s Slot) Mod(x uint64) Slot { + res, err := s.SafeMod(x) + if err != nil { + panic(err.Error()) + } + return res +} + +// SafeMod returns result of `slot % x`. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (s Slot) SafeMod(x uint64) (Slot, error) { + res, err := math.Mod64(uint64(s), x) + return Slot(res), err +} + +// ModSlot returns result of `slot % slot`. +// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown. +func (s Slot) ModSlot(x Slot) Slot { + return s.Mod(uint64(x)) +} + +// SafeModSlot returns result of `slot % slot`. +// In case of arithmetic issues (overflow/underflow/div by zero) error is returned. +func (s Slot) SafeModSlot(x Slot) (Slot, error) { + return s.SafeMod(uint64(x)) +} + +// HashTreeRoot -- +func (s Slot) HashTreeRoot() ([32]byte, error) { + return fssz.HashWithDefaultHasher(s) +} + +// HashTreeRootWith -- +func (s Slot) HashTreeRootWith(hh *fssz.Hasher) error { + hh.PutUint64(uint64(s)) + return nil +} + +// UnmarshalSSZ -- +func (s *Slot) UnmarshalSSZ(buf []byte) error { + if len(buf) != s.SizeSSZ() { + return fmt.Errorf("expected buffer of length %d received %d", s.SizeSSZ(), len(buf)) + } + *s = Slot(fssz.UnmarshallUint64(buf)) + return nil +} + +// MarshalSSZTo -- +func (s *Slot) MarshalSSZTo(dst []byte) ([]byte, error) { + marshalled, err := s.MarshalSSZ() + if err != nil { + return nil, err + } + return append(dst, marshalled...), nil +} + +// MarshalSSZ -- +func (s *Slot) MarshalSSZ() ([]byte, error) { + marshalled := fssz.MarshalUint64([]byte{}, uint64(*s)) + return marshalled, nil +} + +// SizeSSZ -- +func (s *Slot) SizeSSZ() int { + return 8 +} diff --git a/consensus-types/primitives/slot_test.go b/consensus-types/primitives/slot_test.go new file mode 100644 index 0000000000..6306745952 --- /dev/null +++ b/consensus-types/primitives/slot_test.go @@ -0,0 +1,329 @@ +package types_test + +import ( + "fmt" + "math" + "testing" + "time" + + types "github.com/prysmaticlabs/eth2-types" +) + +func TestSlot_Casting(t *testing.T) { + slot := types.Slot(42) + + t.Run("time.Duration", func(t *testing.T) { + if uint64(time.Duration(slot)) != uint64(slot) { + t.Error("Slot should produce the same result with time.Duration") + } + }) + + t.Run("floats", func(t *testing.T) { + var x1 float32 = 42.2 + if types.Slot(x1) != slot { + t.Errorf("Unequal: %v = %v", types.Slot(x1), slot) + } + + var x2 float64 = 42.2 + if types.Slot(x2) != slot { + t.Errorf("Unequal: %v = %v", types.Slot(x2), slot) + } + }) + + t.Run("int", func(t *testing.T) { + var x int = 42 + if types.Slot(x) != slot { + t.Errorf("Unequal: %v = %v", types.Slot(x), slot) + } + }) +} + +func TestSlot_Mul(t *testing.T) { + tests := []struct { + a, b uint64 + res types.Slot + panicMsg string + }{ + {a: 0, b: 1, res: 0}, + {a: 1 << 32, b: 1, res: 1 << 32}, + {a: 1 << 32, b: 100, res: 429496729600}, + {a: 1 << 32, b: 1 << 31, res: 9223372036854775808}, + {a: 1 << 32, b: 1 << 32, res: 0, panicMsg: types.ErrMulOverflow.Error()}, + {a: 1 << 62, b: 2, res: 9223372036854775808}, + {a: 1 << 62, b: 4, res: 0, panicMsg: types.ErrMulOverflow.Error()}, + {a: 1 << 63, b: 1, res: 9223372036854775808}, + {a: 1 << 63, b: 2, res: 0, panicMsg: types.ErrMulOverflow.Error()}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("Slot(%v).Mul(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Slot + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Slot(tt.a).Mul(tt.b) + }) + } else { + res = types.Slot(tt.a).Mul(tt.b) + } + if tt.res != res { + t.Errorf("Slot.Mul() = %v, want %v", res, tt.res) + } + }) + t.Run(fmt.Sprintf("Slot(%v).MulSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Slot + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Slot(tt.a).MulSlot(types.Slot(tt.b)) + }) + } else { + res = types.Slot(tt.a).MulSlot(types.Slot(tt.b)) + } + if tt.res != res { + t.Errorf("Slot.MulSlot() = %v, want %v", res, tt.res) + } + }) + t.Run(fmt.Sprintf("Slot(%v).SafeMulSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + res, err := types.Slot(tt.a).SafeMulSlot(types.Slot(tt.b)) + if tt.panicMsg != "" && (err == nil || err.Error() != tt.panicMsg) { + t.Errorf("Expected error not thrown, wanted: %v, got: %v", tt.panicMsg, err) + return + } + if tt.res != res { + t.Errorf("Slot.SafeMulSlot() = %v, want %v", res, tt.res) + } + }) + } +} + +func TestSlot_Div(t *testing.T) { + tests := []struct { + a, b uint64 + res types.Slot + panicMsg string + }{ + {a: 0, b: 1, res: 0}, + {a: 1, b: 0, res: 0, panicMsg: types.ErrDivByZero.Error()}, + {a: 1 << 32, b: 1 << 32, res: 1}, + {a: 429496729600, b: 1 << 32, res: 100}, + {a: 9223372036854775808, b: 1 << 32, res: 1 << 31}, + {a: 1 << 32, b: 1 << 32, res: 1}, + {a: 9223372036854775808, b: 1 << 62, res: 2}, + {a: 9223372036854775808, b: 1 << 63, res: 1}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("Slot(%v).Div(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Slot + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Slot(tt.a).Div(tt.b) + }) + } else { + res = types.Slot(tt.a).Div(tt.b) + } + if tt.res != res { + t.Errorf("Slot.Div() = %v, want %v", res, tt.res) + } + }) + t.Run(fmt.Sprintf("Slot(%v).DivSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Slot + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Slot(tt.a).DivSlot(types.Slot(tt.b)) + }) + } else { + res = types.Slot(tt.a).DivSlot(types.Slot(tt.b)) + } + if tt.res != res { + t.Errorf("Slot.DivSlot() = %v, want %v", res, tt.res) + } + }) + t.Run(fmt.Sprintf("Slot(%v).SafeDivSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + res, err := types.Slot(tt.a).SafeDivSlot(types.Slot(tt.b)) + if tt.panicMsg != "" && (err == nil || err.Error() != tt.panicMsg) { + t.Errorf("Expected error not thrown, wanted: %v, got: %v", tt.panicMsg, err) + return + } + if tt.res != res { + t.Errorf("Slot.SafeDivSlot() = %v, want %v", res, tt.res) + } + }) + } +} + +func TestSlot_Add(t *testing.T) { + tests := []struct { + a, b uint64 + res types.Slot + panicMsg string + }{ + {a: 0, b: 1, res: 1}, + {a: 1 << 32, b: 1, res: 4294967297}, + {a: 1 << 32, b: 100, res: 4294967396}, + {a: 1 << 31, b: 1 << 31, res: 4294967296}, + {a: 1 << 63, b: 1 << 63, res: 0, panicMsg: types.ErrAddOverflow.Error()}, + {a: 1 << 63, b: 1, res: 9223372036854775809}, + {a: math.MaxUint64, b: 1, res: 0, panicMsg: types.ErrAddOverflow.Error()}, + {a: math.MaxUint64, b: 0, res: math.MaxUint64}, + {a: 1 << 63, b: 2, res: 9223372036854775810}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("Slot(%v).Add(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Slot + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Slot(tt.a).Add(tt.b) + }) + } else { + res = types.Slot(tt.a).Add(tt.b) + } + if tt.res != res { + t.Errorf("Slot.Add() = %v, want %v", res, tt.res) + } + }) + t.Run(fmt.Sprintf("Slot(%v).AddSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Slot + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Slot(tt.a).AddSlot(types.Slot(tt.b)) + }) + } else { + res = types.Slot(tt.a).AddSlot(types.Slot(tt.b)) + } + if tt.res != res { + t.Errorf("Slot.AddSlot() = %v, want %v", res, tt.res) + } + }) + t.Run(fmt.Sprintf("Slot(%v).SafeAddSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + res, err := types.Slot(tt.a).SafeAddSlot(types.Slot(tt.b)) + if tt.panicMsg != "" && (err == nil || err.Error() != tt.panicMsg) { + t.Errorf("Expected error not thrown, wanted: %v, got: %v", tt.panicMsg, err) + return + } + if tt.res != res { + t.Errorf("Slot.SafeAddSlot() = %v, want %v", res, tt.res) + } + }) + } +} + +func TestSlot_Sub(t *testing.T) { + tests := []struct { + a, b uint64 + res types.Slot + panicMsg string + }{ + {a: 1, b: 0, res: 1}, + {a: 0, b: 1, res: 0, panicMsg: types.ErrSubUnderflow.Error()}, + {a: 1 << 32, b: 1, res: 4294967295}, + {a: 1 << 32, b: 100, res: 4294967196}, + {a: 1 << 31, b: 1 << 31, res: 0}, + {a: 1 << 63, b: 1 << 63, res: 0}, + {a: 1 << 63, b: 1, res: 9223372036854775807}, + {a: math.MaxUint64, b: math.MaxUint64, res: 0}, + {a: math.MaxUint64 - 1, b: math.MaxUint64, res: 0, panicMsg: types.ErrSubUnderflow.Error()}, + {a: math.MaxUint64, b: 0, res: math.MaxUint64}, + {a: 1 << 63, b: 2, res: 9223372036854775806}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("Slot(%v).Sub(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Slot + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Slot(tt.a).Sub(tt.b) + }) + } else { + res = types.Slot(tt.a).Sub(tt.b) + } + if tt.res != res { + t.Errorf("Slot.Sub() = %v, want %v", res, tt.res) + } + }) + t.Run(fmt.Sprintf("Slot(%v).SubSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Slot + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Slot(tt.a).SubSlot(types.Slot(tt.b)) + }) + } else { + res = types.Slot(tt.a).SubSlot(types.Slot(tt.b)) + } + if tt.res != res { + t.Errorf("Slot.SubSlot() = %v, want %v", res, tt.res) + } + }) + t.Run(fmt.Sprintf("Slot(%v).SafeSubSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + res, err := types.Slot(tt.a).SafeSubSlot(types.Slot(tt.b)) + if tt.panicMsg != "" && (err == nil || err.Error() != tt.panicMsg) { + t.Errorf("Expected error not thrown, wanted: %v, got: %v", tt.panicMsg, err) + return + } + if tt.res != res { + t.Errorf("Slot.SafeSubSlot() = %v, want %v", res, tt.res) + } + }) + } +} + +func TestSlot_Mod(t *testing.T) { + tests := []struct { + a, b uint64 + res types.Slot + panicMsg string + }{ + {a: 1, b: 0, res: 0, panicMsg: types.ErrDivByZero.Error()}, + {a: 0, b: 1, res: 0}, + {a: 1 << 32, b: 1 << 32, res: 0}, + {a: 429496729600, b: 1 << 32, res: 0}, + {a: 9223372036854775808, b: 1 << 32, res: 0}, + {a: 1 << 32, b: 1 << 32, res: 0}, + {a: 9223372036854775808, b: 1 << 62, res: 0}, + {a: 9223372036854775808, b: 1 << 63, res: 0}, + {a: 1 << 32, b: 17, res: 1}, + {a: 1 << 32, b: 19, res: (1 << 32) % 19}, + {a: math.MaxUint64, b: math.MaxUint64, res: 0}, + {a: 1 << 63, b: 2, res: 0}, + {a: 1<<63 + 1, b: 2, res: 1}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("Slot(%v).Mod(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Slot + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Slot(tt.a).Mod(tt.b) + }) + } else { + res = types.Slot(tt.a).Mod(tt.b) + } + if tt.res != res { + t.Errorf("Slot.Mod() = %v, want %v", res, tt.res) + } + }) + t.Run(fmt.Sprintf("Slot(%v).ModSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + var res types.Slot + if tt.panicMsg != "" { + assertPanic(t, tt.panicMsg, func() { + res = types.Slot(tt.a).ModSlot(types.Slot(tt.b)) + }) + } else { + res = types.Slot(tt.a).ModSlot(types.Slot(tt.b)) + } + if tt.res != res { + t.Errorf("Slot.Mod() = %v, want %v", res, tt.res) + } + }) + t.Run(fmt.Sprintf("Slot(%v).SafeModSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) { + res, err := types.Slot(tt.a).SafeModSlot(types.Slot(tt.b)) + if tt.panicMsg != "" && (err == nil || err.Error() != tt.panicMsg) { + t.Errorf("Expected error not thrown, wanted: %v, got: %v", tt.panicMsg, err) + return + } + if tt.res != res { + t.Errorf("Slot.SafeModSlot() = %v, want %v", res, tt.res) + } + }) + } +} diff --git a/consensus-types/primitives/sszbytes.go b/consensus-types/primitives/sszbytes.go new file mode 100644 index 0000000000..c51a2c0ac1 --- /dev/null +++ b/consensus-types/primitives/sszbytes.go @@ -0,0 +1,21 @@ +package types + +import ( + fssz "github.com/ferranbt/fastssz" +) + +// SSZBytes -- +type SSZBytes []byte + +// HashTreeRoot -- +func (b *SSZBytes) HashTreeRoot() ([32]byte, error) { + return fssz.HashWithDefaultHasher(b) +} + +// HashTreeRootWith -- +func (b *SSZBytes) HashTreeRootWith(hh *fssz.Hasher) error { + indx := hh.Index() + hh.PutBytes(*b) + hh.Merkleize(indx) + return nil +} diff --git a/consensus-types/primitives/sszbytes_test.go b/consensus-types/primitives/sszbytes_test.go new file mode 100644 index 0000000000..b665e9c3e2 --- /dev/null +++ b/consensus-types/primitives/sszbytes_test.go @@ -0,0 +1,57 @@ +package types_test + +import ( + "encoding/hex" + "reflect" + "testing" + + types "github.com/prysmaticlabs/eth2-types" +) + +func TestSSZBytes_HashTreeRoot(t *testing.T) { + tests := []struct { + name string + actualValue []byte + root []byte + wantErr bool + }{ + { + name: "random1", + actualValue: hexDecodeOrDie(t, "844e1063e0b396eed17be8eddb7eecd1fe3ea46542a4b72f7466e77325e5aa6d"), + root: hexDecodeOrDie(t, "844e1063e0b396eed17be8eddb7eecd1fe3ea46542a4b72f7466e77325e5aa6d"), + wantErr: false, + }, + { + name: "random1", + actualValue: hexDecodeOrDie(t, "7b16162ecd9a28fa80a475080b0e4fff4c27efe19ce5134ce3554b72274d59fd534400ba4c7f699aa1c307cd37c2b103"), + root: hexDecodeOrDie(t, "128ed34ee798b9f00716f9ba5c000df5c99443dabc4d3f2e9bb86c77c732e007"), + wantErr: false, + }, + { + name: "random2", + actualValue: []byte{}, + root: hexDecodeOrDie(t, "0000000000000000000000000000000000000000000000000000000000000000"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := types.SSZBytes(tt.actualValue) + htr, err := s.HashTreeRoot() + if err != nil { + t.Errorf("SSZBytes.HashTreeRoot() unexpected error = %v", err) + } + if !reflect.DeepEqual(tt.root, htr[:]) { + t.Errorf("SSZBytes.HashTreeRoot() = %v, want %v", htr[:], tt.root) + } + }) + } +} + +func hexDecodeOrDie(t *testing.T, str string) []byte { + decoded, err := hex.DecodeString(str) + if err != nil { + t.Errorf("hex.DecodeString(%s) unexpected error = %v", str, err) + } + return decoded +} diff --git a/consensus-types/primitives/sszuint64.go b/consensus-types/primitives/sszuint64.go new file mode 100644 index 0000000000..1ce66b9182 --- /dev/null +++ b/consensus-types/primitives/sszuint64.go @@ -0,0 +1,61 @@ +package types + +import ( + "encoding/binary" + "fmt" + + fssz "github.com/ferranbt/fastssz" +) + +var _ fssz.HashRoot = (Epoch)(0) +var _ fssz.Marshaler = (*Epoch)(nil) +var _ fssz.Unmarshaler = (*Epoch)(nil) + +// SSZUint64 -- +type SSZUint64 uint64 + +// SizeSSZ -- +func (s *SSZUint64) SizeSSZ() int { + return 8 +} + +// MarshalSSZTo -- +func (s *SSZUint64) MarshalSSZTo(dst []byte) ([]byte, error) { + marshalled, err := s.MarshalSSZ() + if err != nil { + return nil, err + } + return append(dst, marshalled...), nil +} + +// MarshalSSZ -- +func (s *SSZUint64) MarshalSSZ() ([]byte, error) { + marshalled := fssz.MarshalUint64([]byte{}, uint64(*s)) + return marshalled, nil +} + +// UnmarshalSSZ -- +func (s *SSZUint64) UnmarshalSSZ(buf []byte) error { + if len(buf) != s.SizeSSZ() { + return fmt.Errorf("expected buffer of length %d received %d", s.SizeSSZ(), len(buf)) + } + *s = SSZUint64(fssz.UnmarshallUint64(buf)) + return nil +} + +// HashTreeRoot -- +func (s *SSZUint64) HashTreeRoot() ([32]byte, error) { + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, uint64(*s)) + var root [32]byte + copy(root[:], buf) + return root, nil +} + +// HashTreeRootWith -- +func (s *SSZUint64) HashTreeRootWith(hh *fssz.Hasher) error { + indx := hh.Index() + hh.PutUint64(uint64(*s)) + hh.Merkleize(indx) + return nil +} diff --git a/consensus-types/primitives/sszuint64_test.go b/consensus-types/primitives/sszuint64_test.go new file mode 100644 index 0000000000..bb4cbd6759 --- /dev/null +++ b/consensus-types/primitives/sszuint64_test.go @@ -0,0 +1,96 @@ +package types_test + +import ( + "reflect" + "strings" + "testing" + + types "github.com/prysmaticlabs/eth2-types" +) + +func TestSSZUint64_Limit(t *testing.T) { + sszType := types.SSZUint64(0) + serializedObj := [7]byte{} + err := sszType.UnmarshalSSZ(serializedObj[:]) + if err == nil || !strings.Contains(err.Error(), "expected buffer of length") { + t.Errorf("Expected Error = %s, got: %v", "expected buffer of length", err) + } +} + +func TestSSZUint64_RoundTrip(t *testing.T) { + fixedVal := uint64(8) + sszVal := types.SSZUint64(fixedVal) + + marshalledObj, err := sszVal.MarshalSSZ() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + newVal := types.SSZUint64(0) + + err = newVal.UnmarshalSSZ(marshalledObj) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if fixedVal != uint64(newVal) { + t.Errorf("Unequal: %v = %v", fixedVal, uint64(newVal)) + } +} + +func TestSSZUint64(t *testing.T) { + tests := []struct { + name string + serializedBytes []byte + actualValue uint64 + root []byte + wantErr bool + }{ + { + name: "max", + serializedBytes: hexDecodeOrDie(t, "ffffffffffffffff"), + actualValue: 18446744073709551615, + root: hexDecodeOrDie(t, "ffffffffffffffff000000000000000000000000000000000000000000000000"), + wantErr: false, + }, + { + name: "random", + serializedBytes: hexDecodeOrDie(t, "357c8de9d7204577"), + actualValue: 8594311575614880821, + root: hexDecodeOrDie(t, "357c8de9d7204577000000000000000000000000000000000000000000000000"), + wantErr: false, + }, + { + name: "zero", + serializedBytes: hexDecodeOrDie(t, "0000000000000000"), + actualValue: 0, + root: hexDecodeOrDie(t, "0000000000000000000000000000000000000000000000000000000000000000"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var s types.SSZUint64 + if err := s.UnmarshalSSZ(tt.serializedBytes); (err != nil) != tt.wantErr { + t.Errorf("SSZUint64.UnmarshalSSZ() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.actualValue != uint64(s) { + t.Errorf("SSZUint64.UnmarshalSSZ() = %v, want %v", uint64(s), tt.actualValue) + } + + serializedBytes, err := s.MarshalSSZ() + if err != nil { + t.Errorf("SSZUint64.MarshalSSZ() unexpected error = %v", err) + } + if !reflect.DeepEqual(tt.serializedBytes, serializedBytes) { + t.Errorf("SSZUint64.MarshalSSZ() = %v, want %v", serializedBytes, tt.serializedBytes) + } + + htr, err := s.HashTreeRoot() + if err != nil { + t.Errorf("SSZUint64.HashTreeRoot() unexpected error = %v", err) + } + if !reflect.DeepEqual(tt.root, htr[:]) { + t.Errorf("SSZUint64.HashTreeRoot() = %v, want %v", htr[:], tt.root) + } + }) + } +} diff --git a/consensus-types/primitives/validator.go b/consensus-types/primitives/validator.go new file mode 100644 index 0000000000..0dd94eac62 --- /dev/null +++ b/consensus-types/primitives/validator.go @@ -0,0 +1,80 @@ +package types + +import ( + fmt "fmt" + + fssz "github.com/ferranbt/fastssz" +) + +var _ fssz.HashRoot = (ValidatorIndex)(0) +var _ fssz.Marshaler = (*ValidatorIndex)(nil) +var _ fssz.Unmarshaler = (*ValidatorIndex)(nil) + +// ValidatorIndex in eth2. +type ValidatorIndex uint64 + +// Div divides validator index by x. +func (v ValidatorIndex) Div(x uint64) ValidatorIndex { + if x == 0 { + panic("divbyzero") + } + return ValidatorIndex(uint64(v) / x) +} + +// Add increases validator index by x. +func (v ValidatorIndex) Add(x uint64) ValidatorIndex { + return ValidatorIndex(uint64(v) + x) +} + +// Sub subtracts x from the validator index. +func (v ValidatorIndex) Sub(x uint64) ValidatorIndex { + if uint64(v) < x { + panic("underflow") + } + return ValidatorIndex(uint64(v) - x) +} + +// Mod returns result of `validator index % x`. +func (v ValidatorIndex) Mod(x uint64) ValidatorIndex { + return ValidatorIndex(uint64(v) % x) +} + +// HashTreeRoot -- +func (v ValidatorIndex) HashTreeRoot() ([32]byte, error) { + return fssz.HashWithDefaultHasher(v) +} + +// HashTreeRootWith -- +func (v ValidatorIndex) HashTreeRootWith(hh *fssz.Hasher) error { + hh.PutUint64(uint64(v)) + return nil +} + +// UnmarshalSSZ -- +func (v *ValidatorIndex) UnmarshalSSZ(buf []byte) error { + if len(buf) != v.SizeSSZ() { + return fmt.Errorf("expected buffer of length %d received %d", v.SizeSSZ(), len(buf)) + } + *v = ValidatorIndex(fssz.UnmarshallUint64(buf)) + return nil +} + +// MarshalSSZTo -- +func (v *ValidatorIndex) MarshalSSZTo(dst []byte) ([]byte, error) { + marshalled, err := v.MarshalSSZ() + if err != nil { + return nil, err + } + return append(dst, marshalled...), nil +} + +// MarshalSSZ -- +func (v *ValidatorIndex) MarshalSSZ() ([]byte, error) { + marshalled := fssz.MarshalUint64([]byte{}, uint64(*v)) + return marshalled, nil +} + +// SizeSSZ -- +func (v *ValidatorIndex) SizeSSZ() int { + return 8 +} diff --git a/consensus-types/primitives/validator_test.go b/consensus-types/primitives/validator_test.go new file mode 100644 index 0000000000..e08d57fe1e --- /dev/null +++ b/consensus-types/primitives/validator_test.go @@ -0,0 +1,35 @@ +package types + +import ( + "testing" + "time" +) + +func TestValidatorIndex_Casting(t *testing.T) { + valIdx := ValidatorIndex(42) + + t.Run("time.Duration", func(t *testing.T) { + if uint64(time.Duration(valIdx)) != uint64(valIdx) { + t.Error("ValidatorIndex should produce the same result with time.Duration") + } + }) + + t.Run("floats", func(t *testing.T) { + var x1 float32 = 42.2 + if ValidatorIndex(x1) != valIdx { + t.Errorf("Unequal: %v = %v", ValidatorIndex(x1), valIdx) + } + + var x2 float64 = 42.2 + if ValidatorIndex(x2) != valIdx { + t.Errorf("Unequal: %v = %v", ValidatorIndex(x2), valIdx) + } + }) + + t.Run("int", func(t *testing.T) { + var x int = 42 + if ValidatorIndex(x) != valIdx { + t.Errorf("Unequal: %v = %v", ValidatorIndex(x), valIdx) + } + }) +} diff --git a/math/BUILD.bazel b/math/BUILD.bazel index 192020846f..80db29eed3 100644 --- a/math/BUILD.bazel +++ b/math/BUILD.bazel @@ -15,5 +15,6 @@ go_test( deps = [ ":go_default_library", "//testing/require:go_default_library", + "@com_github_prysmaticlabs_eth2_types//:go_default_library", ], ) diff --git a/math/math_helper.go b/math/math_helper.go index d1ca6d4086..95c4398f0b 100644 --- a/math/math_helper.go +++ b/math/math_helper.go @@ -20,7 +20,13 @@ func init() { } // ErrOverflow occurs when an operation exceeds max or minimum values. -var ErrOverflow = errors.New("integer overflow") +var ( + ErrOverflow = errors.New("integer overflow") + ErrDivByZero = errors.New("integer divide by zero") + ErrMulOverflow = errors.New("multiplication overflows") + ErrAddOverflow = errors.New("addition overflows") + ErrSubUnderflow = errors.New("subtraction underflow") +) // Common square root values. var squareRootTable = map[uint64]uint64{ @@ -115,6 +121,15 @@ func Mul64(a, b uint64) (uint64, error) { return val, nil } +// Div64 divides two 64-bit unsigned integers and checks for errors. +func Div64(a, b uint64) (uint64, error) { + if b == 0 { + return 0, ErrDivByZero + } + val, _ := bits.Div64(0, a, b) + return val, nil +} + // Add64 adds 2 64-bit unsigned integers and checks if they // lead to an overflow. If they do not, it returns the result // without an error. @@ -135,6 +150,15 @@ func Sub64(a, b uint64) (uint64, error) { return res, nil } +// Mod64 finds remainder of division of two 64-bit unsigned integers and checks for errors. +func Mod64(a, b uint64) (uint64, error) { + if b == 0 { + return 0, ErrDivByZero + } + _, val := bits.Div64(0, a, b) + return val, nil +} + // Int returns the integer value of the uint64 argument. If there is an overlow, then an error is // returned. func Int(u uint64) (int, error) { diff --git a/math/math_helper_test.go b/math/math_helper_test.go index 4351888b12..2dbd1684dd 100644 --- a/math/math_helper_test.go +++ b/math/math_helper_test.go @@ -5,6 +5,7 @@ import ( stdmath "math" "testing" + types "github.com/prysmaticlabs/eth2-types" "github.com/prysmaticlabs/prysm/math" "github.com/prysmaticlabs/prysm/testing/require" ) @@ -85,6 +86,74 @@ func TestIntegerSquareRoot(t *testing.T) { } } +func TestMath_Div64(t *testing.T) { + type args struct { + a uint64 + b uint64 + } + tests := []struct { + args args + res uint64 + err bool + }{ + {args: args{0, 1}, res: 0, err: false}, + {args: args{0, 1}, res: 0}, + {args: args{1, 0}, res: 0, err: true}, + {args: args{1 << 32, 1 << 32}, res: 1}, + {args: args{429496729600, 1 << 32}, res: 100}, + {args: args{9223372036854775808, 1 << 32}, res: 1 << 31}, + {args: args{a: 1 << 32, b: 1 << 32}, res: 1}, + {args: args{9223372036854775808, 1 << 62}, res: 2}, + {args: args{9223372036854775808, 1 << 63}, res: 1}, + } + for _, tt := range tests { + got, err := math.Div64(tt.args.a, tt.args.b) + if tt.err && err == nil { + t.Errorf("Div64() Expected Error = %v, want error", tt.err) + continue + } + if tt.res != got { + t.Errorf("Div64() %v, want %v", got, tt.res) + } + } +} + +func TestMath_Mod(t *testing.T) { + type args struct { + a uint64 + b uint64 + } + tests := []struct { + args args + res uint64 + err bool + }{ + {args: args{1, 0}, res: 0, err: true}, + {args: args{0, 1}, res: 0}, + {args: args{1 << 32, 1 << 32}, res: 0}, + {args: args{429496729600, 1 << 32}, res: 0}, + {args: args{9223372036854775808, 1 << 32}, res: 0}, + {args: args{1 << 32, 1 << 32}, res: 0}, + {args: args{9223372036854775808, 1 << 62}, res: 0}, + {args: args{9223372036854775808, 1 << 63}, res: 0}, + {args: args{1 << 32, 17}, res: 1}, + {args: args{1 << 32, 19}, res: (1 << 32) % 19}, + {args: args{stdmath.MaxUint64, stdmath.MaxUint64}, res: 0}, + {args: args{1 << 63, 2}, res: 0}, + {args: args{1<<63 + 1, 2}, res: 1}, + } + for _, tt := range tests { + got, err := types.Mod64(tt.args.a, tt.args.b) + if tt.err && err == nil { + t.Errorf("Mod64() Expected Error = %v, want error", tt.err) + continue + } + if tt.res != got { + t.Errorf("Mod64() %v, want %v", got, tt.res) + } + } +} + func BenchmarkIntegerSquareRootBelow52Bits(b *testing.B) { val := uint64(1 << 33) for i := 0; i < b.N; i++ {