Move ETH2 Types Into Prysm (#10534)

* move eth2 types into Prysm

* bazel

* lint

* use existing math helpers

* rem eth2-types dep

Co-authored-by: james-prysm <90280386+james-prysm@users.noreply.github.com>
Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
This commit is contained in:
Raul Jordan
2022-04-28 13:57:40 +00:00
committed by GitHub
parent 58ad800553
commit 001f719cc3
18 changed files with 1602 additions and 1 deletions

View File

@@ -0,0 +1,38 @@
load("@prysm//tools/go:def.bzl", "go_library", "go_test")
go_library(
name = "go_default_library",
srcs = [
"committee_index.go",
"domain.go",
"epoch.go",
"slot.go",
"sszbytes.go",
"sszuint64.go",
"validator.go",
],
importpath = "github.com/prysmaticlabs/prysm/consensus-types/primitives",
visibility = ["//visibility:public"],
deps = [
"//math:go_default_library",
"@com_github_ferranbt_fastssz//:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = [
"committee_index_test.go",
"domain_test.go",
"epoch_test.go",
"slot_test.go",
"sszbytes_test.go",
"sszuint64_test.go",
"validator_test.go",
],
embed = [":go_default_library"],
deps = [
"//math:go_default_library",
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
],
)

View File

@@ -0,0 +1,54 @@
package types
import (
"fmt"
fssz "github.com/ferranbt/fastssz"
)
var _ fssz.HashRoot = (CommitteeIndex)(0)
var _ fssz.Marshaler = (*CommitteeIndex)(nil)
var _ fssz.Unmarshaler = (*CommitteeIndex)(nil)
// CommitteeIndex --
type CommitteeIndex uint64
// HashTreeRoot returns calculated hash root.
func (c CommitteeIndex) HashTreeRoot() ([32]byte, error) {
return fssz.HashWithDefaultHasher(c)
}
// HashTreeRootWith --
func (c CommitteeIndex) HashTreeRootWith(hh *fssz.Hasher) error {
hh.PutUint64(uint64(c))
return nil
}
// UnmarshalSSZ --
func (c *CommitteeIndex) UnmarshalSSZ(buf []byte) error {
if len(buf) != c.SizeSSZ() {
return fmt.Errorf("expected buffer of length %d receiced %d", c.SizeSSZ(), len(buf))
}
*c = CommitteeIndex(fssz.UnmarshallUint64(buf))
return nil
}
// MarshalSSZTo --
func (c *CommitteeIndex) MarshalSSZTo(dst []byte) ([]byte, error) {
marshalled, err := c.MarshalSSZ()
if err != nil {
return nil, err
}
return append(dst, marshalled...), nil
}
// MarshalSSZ --
func (c *CommitteeIndex) MarshalSSZ() ([]byte, error) {
marshalled := fssz.MarshalUint64([]byte{}, uint64(*c))
return marshalled, nil
}
// SizeSSZ returns the size of the serialized object.
func (c *CommitteeIndex) SizeSSZ() int {
return 8
}

View File

@@ -0,0 +1,28 @@
package types
import (
"testing"
)
func TestCommitteeIndex_Casting(t *testing.T) {
committeeIdx := CommitteeIndex(42)
t.Run("floats", func(t *testing.T) {
var x1 float32 = 42.2
if CommitteeIndex(x1) != committeeIdx {
t.Errorf("Unequal: %v = %v", CommitteeIndex(x1), committeeIdx)
}
var x2 float64 = 42.2
if CommitteeIndex(x2) != committeeIdx {
t.Errorf("Unequal: %v = %v", CommitteeIndex(x2), committeeIdx)
}
})
t.Run("int", func(t *testing.T) {
var x int = 42
if CommitteeIndex(x) != committeeIdx {
t.Errorf("Unequal: %v = %v", CommitteeIndex(x), committeeIdx)
}
})
}

View File

@@ -0,0 +1,57 @@
package types
import (
"fmt"
fssz "github.com/ferranbt/fastssz"
)
var _ fssz.HashRoot = (Domain)([]byte{})
var _ fssz.Marshaler = (*Domain)(nil)
var _ fssz.Unmarshaler = (*Domain)(nil)
// Domain represents a 32 bytes domain object in Ethereum beacon chain consensus.
type Domain []byte
// HashTreeRoot --
func (e Domain) HashTreeRoot() ([32]byte, error) {
return fssz.HashWithDefaultHasher(e)
}
// HashTreeRootWith --
func (e Domain) HashTreeRootWith(hh *fssz.Hasher) error {
hh.PutBytes(e[:])
return nil
}
// UnmarshalSSZ --
func (e *Domain) UnmarshalSSZ(buf []byte) error {
if len(buf) != e.SizeSSZ() {
return fmt.Errorf("expected buffer of length %d received %d", e.SizeSSZ(), len(buf))
}
var b [32]byte
item := Domain(b[:])
copy(item, buf)
*e = item
return nil
}
// MarshalSSZTo --
func (e *Domain) MarshalSSZTo(dst []byte) ([]byte, error) {
marshalled, err := e.MarshalSSZ()
if err != nil {
return nil, err
}
return append(dst, marshalled...), nil
}
// MarshalSSZ --
func (e *Domain) MarshalSSZ() ([]byte, error) {
return *e, nil
}
// SizeSSZ --
func (e *Domain) SizeSSZ() int {
return 32
}

View File

@@ -0,0 +1,90 @@
package types
import (
"reflect"
"testing"
)
func TestDomain_Casting(t *testing.T) {
t.Run("empty byte slice", func(t *testing.T) {
b := make([]byte, 0)
d := Domain(b)
if !reflect.DeepEqual([]byte(d), b) {
t.Errorf("Unequal: %v = %v", d, b)
}
})
t.Run("non-empty byte slice", func(t *testing.T) {
b := make([]byte, 2)
b[0] = byte('a')
b[1] = byte('b')
d := Domain(b)
if !reflect.DeepEqual([]byte(d), b) {
t.Errorf("Unequal: %v = %v", d, b)
}
})
t.Run("byte array", func(t *testing.T) {
var b [2]byte
b[0] = byte('a')
b[1] = byte('b')
d := Domain(b[:])
if !reflect.DeepEqual([]byte(d), b[:]) {
t.Errorf("Unequal: %v = %v", d, b)
}
})
}
func TestDomain_UnmarshalSSZ(t *testing.T) {
t.Run("Ok", func(t *testing.T) {
d := Domain{}
var b = [32]byte{'f', 'o', 'o'}
err := d.UnmarshalSSZ(b[:])
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if !reflect.DeepEqual(b[:], []byte(d)) {
t.Errorf("Unequal: %v = %v", b, []byte(d))
}
})
t.Run("Wrong slice length", func(t *testing.T) {
d := Domain{}
var b = [16]byte{'f', 'o', 'o'}
err := d.UnmarshalSSZ(b[:])
if err == nil {
t.Error("Expected error")
}
})
}
func TestDomain_MarshalSSZTo(t *testing.T) {
d := Domain("foo")
dst := []byte("bar")
b, err := d.MarshalSSZTo(dst)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
expected := []byte("barfoo")
if !reflect.DeepEqual(expected, b) {
t.Errorf("Unequal: %v = %v", expected, b)
}
}
func TestDomain_MarshalSSZ(t *testing.T) {
d := Domain("foo")
b, err := d.MarshalSSZ()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if !reflect.DeepEqual(b, []byte(d)) {
t.Errorf("Unequal: %v = %v", b, []byte(d))
}
}
func TestDomain_SizeSSZ(t *testing.T) {
d := Domain{}
if d.SizeSSZ() != 32 {
t.Errorf("Wrong SSZ size. Expected %v vs actual %v", 32, d.SizeSSZ())
}
}

View File

@@ -0,0 +1,152 @@
package types
import (
"fmt"
fssz "github.com/ferranbt/fastssz"
"github.com/prysmaticlabs/prysm/math"
)
var _ fssz.HashRoot = (Epoch)(0)
var _ fssz.Marshaler = (*Epoch)(nil)
var _ fssz.Unmarshaler = (*Epoch)(nil)
// Epoch represents a single epoch.
type Epoch uint64
// Mul multiplies epoch by x.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (e Epoch) Mul(x uint64) Epoch {
res, err := e.SafeMul(x)
if err != nil {
panic(err.Error())
}
return res
}
// SafeMul multiplies epoch by x.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (e Epoch) SafeMul(x uint64) (Epoch, error) {
res, err := math.Mul64(uint64(e), x)
return Epoch(res), err
}
// Div divides epoch by x.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (e Epoch) Div(x uint64) Epoch {
res, err := e.SafeDiv(x)
if err != nil {
panic(err.Error())
}
return res
}
// SafeDiv divides epoch by x.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (e Epoch) SafeDiv(x uint64) (Epoch, error) {
res, err := math.Div64(uint64(e), x)
return Epoch(res), err
}
// Add increases epoch by x.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (e Epoch) Add(x uint64) Epoch {
res, err := e.SafeAdd(x)
if err != nil {
panic(err.Error())
}
return res
}
// SafeAdd increases epoch by x.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (e Epoch) SafeAdd(x uint64) (Epoch, error) {
res, err := math.Add64(uint64(e), x)
return Epoch(res), err
}
// AddEpoch increases epoch using another epoch value.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (e Epoch) AddEpoch(x Epoch) Epoch {
return e.Add(uint64(x))
}
// SafeAddEpoch increases epoch using another epoch value.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (e Epoch) SafeAddEpoch(x Epoch) (Epoch, error) {
return e.SafeAdd(uint64(x))
}
// Sub subtracts x from the epoch.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (e Epoch) Sub(x uint64) Epoch {
res, err := e.SafeSub(x)
if err != nil {
panic(err.Error())
}
return res
}
// SafeSub subtracts x from the epoch.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (e Epoch) SafeSub(x uint64) (Epoch, error) {
res, err := math.Sub64(uint64(e), x)
return Epoch(res), err
}
// Mod returns result of `epoch % x`.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (e Epoch) Mod(x uint64) Epoch {
res, err := e.SafeMod(x)
if err != nil {
panic(err.Error())
}
return res
}
// SafeMod returns result of `epoch % x`.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (e Epoch) SafeMod(x uint64) (Epoch, error) {
res, err := math.Mod64(uint64(e), x)
return Epoch(res), err
}
// HashTreeRoot --
func (e Epoch) HashTreeRoot() ([32]byte, error) {
return fssz.HashWithDefaultHasher(e)
}
// HashTreeRootWith --
func (e Epoch) HashTreeRootWith(hh *fssz.Hasher) error {
hh.PutUint64(uint64(e))
return nil
}
// UnmarshalSSZ --
func (e *Epoch) UnmarshalSSZ(buf []byte) error {
if len(buf) != e.SizeSSZ() {
return fmt.Errorf("expected buffer of length %d received %d", e.SizeSSZ(), len(buf))
}
*e = Epoch(fssz.UnmarshallUint64(buf))
return nil
}
// MarshalSSZTo --
func (e *Epoch) MarshalSSZTo(dst []byte) ([]byte, error) {
marshalled, err := e.MarshalSSZ()
if err != nil {
return nil, err
}
return append(dst, marshalled...), nil
}
// MarshalSSZ --
func (e *Epoch) MarshalSSZ() ([]byte, error) {
marshalled := fssz.MarshalUint64([]byte{}, uint64(*e))
return marshalled, nil
}
// SizeSSZ --
func (e *Epoch) SizeSSZ() int {
return 8
}

View File

@@ -0,0 +1,209 @@
package types_test
import (
"fmt"
"math"
"testing"
types "github.com/prysmaticlabs/prysm/consensus-types/primitives"
mathprysm "github.com/prysmaticlabs/prysm/math"
)
func TestEpoch_Mul(t *testing.T) {
tests := []struct {
a, b uint64
res types.Epoch
panicMsg string
}{
{a: 0, b: 1, res: 0},
{a: 1 << 32, b: 1, res: 1 << 32},
{a: 1 << 32, b: 100, res: 429496729600},
{a: 1 << 32, b: 1 << 31, res: 9223372036854775808},
{a: 1 << 32, b: 1 << 32, res: 0, panicMsg: mathprysm.ErrMulOverflow.Error()},
{a: 1 << 62, b: 2, res: 9223372036854775808},
{a: 1 << 62, b: 4, res: 0, panicMsg: mathprysm.ErrMulOverflow.Error()},
{a: 1 << 63, b: 1, res: 9223372036854775808},
{a: 1 << 63, b: 2, res: 0, panicMsg: mathprysm.ErrMulOverflow.Error()},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Epoch(%v).Mul(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Epoch
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Epoch(tt.a).Mul(tt.b)
})
} else {
res = types.Epoch(tt.a).Mul(tt.b)
}
if tt.res != res {
t.Errorf("Epoch.Mul() = %v, want %v", res, tt.res)
}
})
}
}
func TestEpoch_Div(t *testing.T) {
tests := []struct {
a, b uint64
res types.Epoch
panicMsg string
}{
{a: 0, b: 1, res: 0},
{a: 1, b: 0, res: 0, panicMsg: mathprysm.ErrDivByZero.Error()},
{a: 1 << 32, b: 1 << 32, res: 1},
{a: 429496729600, b: 1 << 32, res: 100},
{a: 9223372036854775808, b: 1 << 32, res: 1 << 31},
{a: 1 << 32, b: 1 << 32, res: 1},
{a: 9223372036854775808, b: 1 << 62, res: 2},
{a: 9223372036854775808, b: 1 << 63, res: 1},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Epoch(%v).Div(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Epoch
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Epoch(tt.a).Div(tt.b)
})
} else {
res = types.Epoch(tt.a).Div(tt.b)
}
if tt.res != res {
t.Errorf("Epoch.Div() = %v, want %v", res, tt.res)
}
})
}
}
func TestEpoch_Add(t *testing.T) {
tests := []struct {
a, b uint64
res types.Epoch
panicMsg string
}{
{a: 0, b: 1, res: 1},
{a: 1 << 32, b: 1, res: 4294967297},
{a: 1 << 32, b: 100, res: 4294967396},
{a: 1 << 31, b: 1 << 31, res: 4294967296},
{a: 1 << 63, b: 1 << 63, res: 0, panicMsg: mathprysm.ErrAddOverflow.Error()},
{a: 1 << 63, b: 1, res: 9223372036854775809},
{a: math.MaxUint64, b: 1, res: 0, panicMsg: mathprysm.ErrAddOverflow.Error()},
{a: math.MaxUint64, b: 0, res: math.MaxUint64},
{a: 1 << 63, b: 2, res: 9223372036854775810},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Epoch(%v).Add(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Epoch
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Epoch(tt.a).Add(tt.b)
})
} else {
res = types.Epoch(tt.a).Add(tt.b)
}
if tt.res != res {
t.Errorf("Epoch.Add() = %v, want %v", res, tt.res)
}
})
t.Run(fmt.Sprintf("Epoch(%v).AddEpoch(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Epoch
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Epoch(tt.a).AddEpoch(types.Epoch(tt.b))
})
} else {
res = types.Epoch(tt.a).AddEpoch(types.Epoch(tt.b))
}
if tt.res != res {
t.Errorf("Epoch.AddEpoch() = %v, want %v", res, tt.res)
}
})
}
}
func TestEpoch_Sub(t *testing.T) {
tests := []struct {
a, b uint64
res types.Epoch
panicMsg string
}{
{a: 1, b: 0, res: 1},
{a: 0, b: 1, res: 0, panicMsg: mathprysm.ErrSubUnderflow.Error()},
{a: 1 << 32, b: 1, res: 4294967295},
{a: 1 << 32, b: 100, res: 4294967196},
{a: 1 << 31, b: 1 << 31, res: 0},
{a: 1 << 63, b: 1 << 63, res: 0},
{a: 1 << 63, b: 1, res: 9223372036854775807},
{a: math.MaxUint64, b: math.MaxUint64, res: 0},
{a: math.MaxUint64 - 1, b: math.MaxUint64, res: 0, panicMsg: mathprysm.ErrSubUnderflow.Error()},
{a: math.MaxUint64, b: 0, res: math.MaxUint64},
{a: 1 << 63, b: 2, res: 9223372036854775806},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Epoch(%v).Sub(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Epoch
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Epoch(tt.a).Sub(tt.b)
})
} else {
res = types.Epoch(tt.a).Sub(tt.b)
}
if tt.res != res {
t.Errorf("Epoch.Sub() = %v, want %v", res, tt.res)
}
})
}
}
func TestEpoch_Mod(t *testing.T) {
tests := []struct {
a, b uint64
res types.Epoch
panicMsg string
}{
{a: 1, b: 0, res: 0, panicMsg: mathprysm.ErrDivByZero.Error()},
{a: 0, b: 1, res: 0},
{a: 1 << 32, b: 1 << 32, res: 0},
{a: 429496729600, b: 1 << 32, res: 0},
{a: 9223372036854775808, b: 1 << 32, res: 0},
{a: 1 << 32, b: 1 << 32, res: 0},
{a: 9223372036854775808, b: 1 << 62, res: 0},
{a: 9223372036854775808, b: 1 << 63, res: 0},
{a: 1 << 32, b: 17, res: 1},
{a: 1 << 32, b: 19, res: (1 << 32) % 19},
{a: math.MaxUint64, b: math.MaxUint64, res: 0},
{a: 1 << 63, b: 2, res: 0},
{a: 1<<63 + 1, b: 2, res: 1},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Epoch(%v).Mod(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Epoch
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Epoch(tt.a).Mod(tt.b)
})
} else {
res = types.Epoch(tt.a).Mod(tt.b)
}
if tt.res != res {
t.Errorf("Epoch.Mod() = %v, want %v", res, tt.res)
}
})
}
}
func assertPanic(t *testing.T, panicMessage string, f func()) {
defer func() {
if r := recover(); r == nil {
t.Errorf("Expected panic not thrown")
} else if r != panicMessage {
t.Errorf("Unexpected panic thrown, want: %#v, got: %#v", panicMessage, r)
}
}()
f()
}

View File

@@ -0,0 +1,200 @@
package types
import (
fmt "fmt"
fssz "github.com/ferranbt/fastssz"
"github.com/prysmaticlabs/prysm/math"
)
var _ fssz.HashRoot = (Slot)(0)
var _ fssz.Marshaler = (*Slot)(nil)
var _ fssz.Unmarshaler = (*Slot)(nil)
// Slot represents a single slot.
type Slot uint64
// Mul multiplies slot by x.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (s Slot) Mul(x uint64) Slot {
res, err := s.SafeMul(x)
if err != nil {
panic(err.Error())
}
return res
}
// SafeMul multiplies slot by x.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (s Slot) SafeMul(x uint64) (Slot, error) {
res, err := math.Mul64(uint64(s), x)
return Slot(res), err
}
// MulSlot multiplies slot by another slot.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (s Slot) MulSlot(x Slot) Slot {
return s.Mul(uint64(x))
}
// SafeMulSlot multiplies slot by another slot.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (s Slot) SafeMulSlot(x Slot) (Slot, error) {
return s.SafeMul(uint64(x))
}
// Div divides slot by x.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (s Slot) Div(x uint64) Slot {
res, err := s.SafeDiv(x)
if err != nil {
panic(err.Error())
}
return res
}
// SafeDiv divides slot by x.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (s Slot) SafeDiv(x uint64) (Slot, error) {
res, err := math.Div64(uint64(s), x)
return Slot(res), err
}
// DivSlot divides slot by another slot.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (s Slot) DivSlot(x Slot) Slot {
return s.Div(uint64(x))
}
// SafeDivSlot divides slot by another slot.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (s Slot) SafeDivSlot(x Slot) (Slot, error) {
return s.SafeDiv(uint64(x))
}
// Add increases slot by x.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (s Slot) Add(x uint64) Slot {
res, err := s.SafeAdd(x)
if err != nil {
panic(err.Error())
}
return res
}
// SafeAdd increases slot by x.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (s Slot) SafeAdd(x uint64) (Slot, error) {
res, err := math.Add64(uint64(s), x)
return Slot(res), err
}
// AddSlot increases slot by another slot.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (s Slot) AddSlot(x Slot) Slot {
return s.Add(uint64(x))
}
// SafeAddSlot increases slot by another slot.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (s Slot) SafeAddSlot(x Slot) (Slot, error) {
return s.SafeAdd(uint64(x))
}
// Sub subtracts x from the slot.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (s Slot) Sub(x uint64) Slot {
res, err := s.SafeSub(x)
if err != nil {
panic(err.Error())
}
return res
}
// SafeSub subtracts x from the slot.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (s Slot) SafeSub(x uint64) (Slot, error) {
res, err := math.Sub64(uint64(s), x)
return Slot(res), err
}
// SubSlot finds difference between two slot values.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (s Slot) SubSlot(x Slot) Slot {
return s.Sub(uint64(x))
}
// SafeSubSlot finds difference between two slot values.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (s Slot) SafeSubSlot(x Slot) (Slot, error) {
return s.SafeSub(uint64(x))
}
// Mod returns result of `slot % x`.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (s Slot) Mod(x uint64) Slot {
res, err := s.SafeMod(x)
if err != nil {
panic(err.Error())
}
return res
}
// SafeMod returns result of `slot % x`.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (s Slot) SafeMod(x uint64) (Slot, error) {
res, err := math.Mod64(uint64(s), x)
return Slot(res), err
}
// ModSlot returns result of `slot % slot`.
// In case of arithmetic issues (overflow/underflow/div by zero) panic is thrown.
func (s Slot) ModSlot(x Slot) Slot {
return s.Mod(uint64(x))
}
// SafeModSlot returns result of `slot % slot`.
// In case of arithmetic issues (overflow/underflow/div by zero) error is returned.
func (s Slot) SafeModSlot(x Slot) (Slot, error) {
return s.SafeMod(uint64(x))
}
// HashTreeRoot --
func (s Slot) HashTreeRoot() ([32]byte, error) {
return fssz.HashWithDefaultHasher(s)
}
// HashTreeRootWith --
func (s Slot) HashTreeRootWith(hh *fssz.Hasher) error {
hh.PutUint64(uint64(s))
return nil
}
// UnmarshalSSZ --
func (s *Slot) UnmarshalSSZ(buf []byte) error {
if len(buf) != s.SizeSSZ() {
return fmt.Errorf("expected buffer of length %d received %d", s.SizeSSZ(), len(buf))
}
*s = Slot(fssz.UnmarshallUint64(buf))
return nil
}
// MarshalSSZTo --
func (s *Slot) MarshalSSZTo(dst []byte) ([]byte, error) {
marshalled, err := s.MarshalSSZ()
if err != nil {
return nil, err
}
return append(dst, marshalled...), nil
}
// MarshalSSZ --
func (s *Slot) MarshalSSZ() ([]byte, error) {
marshalled := fssz.MarshalUint64([]byte{}, uint64(*s))
return marshalled, nil
}
// SizeSSZ --
func (s *Slot) SizeSSZ() int {
return 8
}

View File

@@ -0,0 +1,329 @@
package types_test
import (
"fmt"
"math"
"testing"
"time"
types "github.com/prysmaticlabs/eth2-types"
)
func TestSlot_Casting(t *testing.T) {
slot := types.Slot(42)
t.Run("time.Duration", func(t *testing.T) {
if uint64(time.Duration(slot)) != uint64(slot) {
t.Error("Slot should produce the same result with time.Duration")
}
})
t.Run("floats", func(t *testing.T) {
var x1 float32 = 42.2
if types.Slot(x1) != slot {
t.Errorf("Unequal: %v = %v", types.Slot(x1), slot)
}
var x2 float64 = 42.2
if types.Slot(x2) != slot {
t.Errorf("Unequal: %v = %v", types.Slot(x2), slot)
}
})
t.Run("int", func(t *testing.T) {
var x int = 42
if types.Slot(x) != slot {
t.Errorf("Unequal: %v = %v", types.Slot(x), slot)
}
})
}
func TestSlot_Mul(t *testing.T) {
tests := []struct {
a, b uint64
res types.Slot
panicMsg string
}{
{a: 0, b: 1, res: 0},
{a: 1 << 32, b: 1, res: 1 << 32},
{a: 1 << 32, b: 100, res: 429496729600},
{a: 1 << 32, b: 1 << 31, res: 9223372036854775808},
{a: 1 << 32, b: 1 << 32, res: 0, panicMsg: types.ErrMulOverflow.Error()},
{a: 1 << 62, b: 2, res: 9223372036854775808},
{a: 1 << 62, b: 4, res: 0, panicMsg: types.ErrMulOverflow.Error()},
{a: 1 << 63, b: 1, res: 9223372036854775808},
{a: 1 << 63, b: 2, res: 0, panicMsg: types.ErrMulOverflow.Error()},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Slot(%v).Mul(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Slot
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Slot(tt.a).Mul(tt.b)
})
} else {
res = types.Slot(tt.a).Mul(tt.b)
}
if tt.res != res {
t.Errorf("Slot.Mul() = %v, want %v", res, tt.res)
}
})
t.Run(fmt.Sprintf("Slot(%v).MulSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Slot
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Slot(tt.a).MulSlot(types.Slot(tt.b))
})
} else {
res = types.Slot(tt.a).MulSlot(types.Slot(tt.b))
}
if tt.res != res {
t.Errorf("Slot.MulSlot() = %v, want %v", res, tt.res)
}
})
t.Run(fmt.Sprintf("Slot(%v).SafeMulSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
res, err := types.Slot(tt.a).SafeMulSlot(types.Slot(tt.b))
if tt.panicMsg != "" && (err == nil || err.Error() != tt.panicMsg) {
t.Errorf("Expected error not thrown, wanted: %v, got: %v", tt.panicMsg, err)
return
}
if tt.res != res {
t.Errorf("Slot.SafeMulSlot() = %v, want %v", res, tt.res)
}
})
}
}
func TestSlot_Div(t *testing.T) {
tests := []struct {
a, b uint64
res types.Slot
panicMsg string
}{
{a: 0, b: 1, res: 0},
{a: 1, b: 0, res: 0, panicMsg: types.ErrDivByZero.Error()},
{a: 1 << 32, b: 1 << 32, res: 1},
{a: 429496729600, b: 1 << 32, res: 100},
{a: 9223372036854775808, b: 1 << 32, res: 1 << 31},
{a: 1 << 32, b: 1 << 32, res: 1},
{a: 9223372036854775808, b: 1 << 62, res: 2},
{a: 9223372036854775808, b: 1 << 63, res: 1},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Slot(%v).Div(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Slot
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Slot(tt.a).Div(tt.b)
})
} else {
res = types.Slot(tt.a).Div(tt.b)
}
if tt.res != res {
t.Errorf("Slot.Div() = %v, want %v", res, tt.res)
}
})
t.Run(fmt.Sprintf("Slot(%v).DivSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Slot
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Slot(tt.a).DivSlot(types.Slot(tt.b))
})
} else {
res = types.Slot(tt.a).DivSlot(types.Slot(tt.b))
}
if tt.res != res {
t.Errorf("Slot.DivSlot() = %v, want %v", res, tt.res)
}
})
t.Run(fmt.Sprintf("Slot(%v).SafeDivSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
res, err := types.Slot(tt.a).SafeDivSlot(types.Slot(tt.b))
if tt.panicMsg != "" && (err == nil || err.Error() != tt.panicMsg) {
t.Errorf("Expected error not thrown, wanted: %v, got: %v", tt.panicMsg, err)
return
}
if tt.res != res {
t.Errorf("Slot.SafeDivSlot() = %v, want %v", res, tt.res)
}
})
}
}
func TestSlot_Add(t *testing.T) {
tests := []struct {
a, b uint64
res types.Slot
panicMsg string
}{
{a: 0, b: 1, res: 1},
{a: 1 << 32, b: 1, res: 4294967297},
{a: 1 << 32, b: 100, res: 4294967396},
{a: 1 << 31, b: 1 << 31, res: 4294967296},
{a: 1 << 63, b: 1 << 63, res: 0, panicMsg: types.ErrAddOverflow.Error()},
{a: 1 << 63, b: 1, res: 9223372036854775809},
{a: math.MaxUint64, b: 1, res: 0, panicMsg: types.ErrAddOverflow.Error()},
{a: math.MaxUint64, b: 0, res: math.MaxUint64},
{a: 1 << 63, b: 2, res: 9223372036854775810},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Slot(%v).Add(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Slot
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Slot(tt.a).Add(tt.b)
})
} else {
res = types.Slot(tt.a).Add(tt.b)
}
if tt.res != res {
t.Errorf("Slot.Add() = %v, want %v", res, tt.res)
}
})
t.Run(fmt.Sprintf("Slot(%v).AddSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Slot
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Slot(tt.a).AddSlot(types.Slot(tt.b))
})
} else {
res = types.Slot(tt.a).AddSlot(types.Slot(tt.b))
}
if tt.res != res {
t.Errorf("Slot.AddSlot() = %v, want %v", res, tt.res)
}
})
t.Run(fmt.Sprintf("Slot(%v).SafeAddSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
res, err := types.Slot(tt.a).SafeAddSlot(types.Slot(tt.b))
if tt.panicMsg != "" && (err == nil || err.Error() != tt.panicMsg) {
t.Errorf("Expected error not thrown, wanted: %v, got: %v", tt.panicMsg, err)
return
}
if tt.res != res {
t.Errorf("Slot.SafeAddSlot() = %v, want %v", res, tt.res)
}
})
}
}
func TestSlot_Sub(t *testing.T) {
tests := []struct {
a, b uint64
res types.Slot
panicMsg string
}{
{a: 1, b: 0, res: 1},
{a: 0, b: 1, res: 0, panicMsg: types.ErrSubUnderflow.Error()},
{a: 1 << 32, b: 1, res: 4294967295},
{a: 1 << 32, b: 100, res: 4294967196},
{a: 1 << 31, b: 1 << 31, res: 0},
{a: 1 << 63, b: 1 << 63, res: 0},
{a: 1 << 63, b: 1, res: 9223372036854775807},
{a: math.MaxUint64, b: math.MaxUint64, res: 0},
{a: math.MaxUint64 - 1, b: math.MaxUint64, res: 0, panicMsg: types.ErrSubUnderflow.Error()},
{a: math.MaxUint64, b: 0, res: math.MaxUint64},
{a: 1 << 63, b: 2, res: 9223372036854775806},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Slot(%v).Sub(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Slot
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Slot(tt.a).Sub(tt.b)
})
} else {
res = types.Slot(tt.a).Sub(tt.b)
}
if tt.res != res {
t.Errorf("Slot.Sub() = %v, want %v", res, tt.res)
}
})
t.Run(fmt.Sprintf("Slot(%v).SubSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Slot
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Slot(tt.a).SubSlot(types.Slot(tt.b))
})
} else {
res = types.Slot(tt.a).SubSlot(types.Slot(tt.b))
}
if tt.res != res {
t.Errorf("Slot.SubSlot() = %v, want %v", res, tt.res)
}
})
t.Run(fmt.Sprintf("Slot(%v).SafeSubSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
res, err := types.Slot(tt.a).SafeSubSlot(types.Slot(tt.b))
if tt.panicMsg != "" && (err == nil || err.Error() != tt.panicMsg) {
t.Errorf("Expected error not thrown, wanted: %v, got: %v", tt.panicMsg, err)
return
}
if tt.res != res {
t.Errorf("Slot.SafeSubSlot() = %v, want %v", res, tt.res)
}
})
}
}
func TestSlot_Mod(t *testing.T) {
tests := []struct {
a, b uint64
res types.Slot
panicMsg string
}{
{a: 1, b: 0, res: 0, panicMsg: types.ErrDivByZero.Error()},
{a: 0, b: 1, res: 0},
{a: 1 << 32, b: 1 << 32, res: 0},
{a: 429496729600, b: 1 << 32, res: 0},
{a: 9223372036854775808, b: 1 << 32, res: 0},
{a: 1 << 32, b: 1 << 32, res: 0},
{a: 9223372036854775808, b: 1 << 62, res: 0},
{a: 9223372036854775808, b: 1 << 63, res: 0},
{a: 1 << 32, b: 17, res: 1},
{a: 1 << 32, b: 19, res: (1 << 32) % 19},
{a: math.MaxUint64, b: math.MaxUint64, res: 0},
{a: 1 << 63, b: 2, res: 0},
{a: 1<<63 + 1, b: 2, res: 1},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("Slot(%v).Mod(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Slot
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Slot(tt.a).Mod(tt.b)
})
} else {
res = types.Slot(tt.a).Mod(tt.b)
}
if tt.res != res {
t.Errorf("Slot.Mod() = %v, want %v", res, tt.res)
}
})
t.Run(fmt.Sprintf("Slot(%v).ModSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
var res types.Slot
if tt.panicMsg != "" {
assertPanic(t, tt.panicMsg, func() {
res = types.Slot(tt.a).ModSlot(types.Slot(tt.b))
})
} else {
res = types.Slot(tt.a).ModSlot(types.Slot(tt.b))
}
if tt.res != res {
t.Errorf("Slot.Mod() = %v, want %v", res, tt.res)
}
})
t.Run(fmt.Sprintf("Slot(%v).SafeModSlot(%v) = %v", tt.a, tt.b, tt.res), func(t *testing.T) {
res, err := types.Slot(tt.a).SafeModSlot(types.Slot(tt.b))
if tt.panicMsg != "" && (err == nil || err.Error() != tt.panicMsg) {
t.Errorf("Expected error not thrown, wanted: %v, got: %v", tt.panicMsg, err)
return
}
if tt.res != res {
t.Errorf("Slot.SafeModSlot() = %v, want %v", res, tt.res)
}
})
}
}

View File

@@ -0,0 +1,21 @@
package types
import (
fssz "github.com/ferranbt/fastssz"
)
// SSZBytes --
type SSZBytes []byte
// HashTreeRoot --
func (b *SSZBytes) HashTreeRoot() ([32]byte, error) {
return fssz.HashWithDefaultHasher(b)
}
// HashTreeRootWith --
func (b *SSZBytes) HashTreeRootWith(hh *fssz.Hasher) error {
indx := hh.Index()
hh.PutBytes(*b)
hh.Merkleize(indx)
return nil
}

View File

@@ -0,0 +1,57 @@
package types_test
import (
"encoding/hex"
"reflect"
"testing"
types "github.com/prysmaticlabs/eth2-types"
)
func TestSSZBytes_HashTreeRoot(t *testing.T) {
tests := []struct {
name string
actualValue []byte
root []byte
wantErr bool
}{
{
name: "random1",
actualValue: hexDecodeOrDie(t, "844e1063e0b396eed17be8eddb7eecd1fe3ea46542a4b72f7466e77325e5aa6d"),
root: hexDecodeOrDie(t, "844e1063e0b396eed17be8eddb7eecd1fe3ea46542a4b72f7466e77325e5aa6d"),
wantErr: false,
},
{
name: "random1",
actualValue: hexDecodeOrDie(t, "7b16162ecd9a28fa80a475080b0e4fff4c27efe19ce5134ce3554b72274d59fd534400ba4c7f699aa1c307cd37c2b103"),
root: hexDecodeOrDie(t, "128ed34ee798b9f00716f9ba5c000df5c99443dabc4d3f2e9bb86c77c732e007"),
wantErr: false,
},
{
name: "random2",
actualValue: []byte{},
root: hexDecodeOrDie(t, "0000000000000000000000000000000000000000000000000000000000000000"),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := types.SSZBytes(tt.actualValue)
htr, err := s.HashTreeRoot()
if err != nil {
t.Errorf("SSZBytes.HashTreeRoot() unexpected error = %v", err)
}
if !reflect.DeepEqual(tt.root, htr[:]) {
t.Errorf("SSZBytes.HashTreeRoot() = %v, want %v", htr[:], tt.root)
}
})
}
}
func hexDecodeOrDie(t *testing.T, str string) []byte {
decoded, err := hex.DecodeString(str)
if err != nil {
t.Errorf("hex.DecodeString(%s) unexpected error = %v", str, err)
}
return decoded
}

View File

@@ -0,0 +1,61 @@
package types
import (
"encoding/binary"
"fmt"
fssz "github.com/ferranbt/fastssz"
)
var _ fssz.HashRoot = (Epoch)(0)
var _ fssz.Marshaler = (*Epoch)(nil)
var _ fssz.Unmarshaler = (*Epoch)(nil)
// SSZUint64 --
type SSZUint64 uint64
// SizeSSZ --
func (s *SSZUint64) SizeSSZ() int {
return 8
}
// MarshalSSZTo --
func (s *SSZUint64) MarshalSSZTo(dst []byte) ([]byte, error) {
marshalled, err := s.MarshalSSZ()
if err != nil {
return nil, err
}
return append(dst, marshalled...), nil
}
// MarshalSSZ --
func (s *SSZUint64) MarshalSSZ() ([]byte, error) {
marshalled := fssz.MarshalUint64([]byte{}, uint64(*s))
return marshalled, nil
}
// UnmarshalSSZ --
func (s *SSZUint64) UnmarshalSSZ(buf []byte) error {
if len(buf) != s.SizeSSZ() {
return fmt.Errorf("expected buffer of length %d received %d", s.SizeSSZ(), len(buf))
}
*s = SSZUint64(fssz.UnmarshallUint64(buf))
return nil
}
// HashTreeRoot --
func (s *SSZUint64) HashTreeRoot() ([32]byte, error) {
buf := make([]byte, 8)
binary.LittleEndian.PutUint64(buf, uint64(*s))
var root [32]byte
copy(root[:], buf)
return root, nil
}
// HashTreeRootWith --
func (s *SSZUint64) HashTreeRootWith(hh *fssz.Hasher) error {
indx := hh.Index()
hh.PutUint64(uint64(*s))
hh.Merkleize(indx)
return nil
}

View File

@@ -0,0 +1,96 @@
package types_test
import (
"reflect"
"strings"
"testing"
types "github.com/prysmaticlabs/eth2-types"
)
func TestSSZUint64_Limit(t *testing.T) {
sszType := types.SSZUint64(0)
serializedObj := [7]byte{}
err := sszType.UnmarshalSSZ(serializedObj[:])
if err == nil || !strings.Contains(err.Error(), "expected buffer of length") {
t.Errorf("Expected Error = %s, got: %v", "expected buffer of length", err)
}
}
func TestSSZUint64_RoundTrip(t *testing.T) {
fixedVal := uint64(8)
sszVal := types.SSZUint64(fixedVal)
marshalledObj, err := sszVal.MarshalSSZ()
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
newVal := types.SSZUint64(0)
err = newVal.UnmarshalSSZ(marshalledObj)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if fixedVal != uint64(newVal) {
t.Errorf("Unequal: %v = %v", fixedVal, uint64(newVal))
}
}
func TestSSZUint64(t *testing.T) {
tests := []struct {
name string
serializedBytes []byte
actualValue uint64
root []byte
wantErr bool
}{
{
name: "max",
serializedBytes: hexDecodeOrDie(t, "ffffffffffffffff"),
actualValue: 18446744073709551615,
root: hexDecodeOrDie(t, "ffffffffffffffff000000000000000000000000000000000000000000000000"),
wantErr: false,
},
{
name: "random",
serializedBytes: hexDecodeOrDie(t, "357c8de9d7204577"),
actualValue: 8594311575614880821,
root: hexDecodeOrDie(t, "357c8de9d7204577000000000000000000000000000000000000000000000000"),
wantErr: false,
},
{
name: "zero",
serializedBytes: hexDecodeOrDie(t, "0000000000000000"),
actualValue: 0,
root: hexDecodeOrDie(t, "0000000000000000000000000000000000000000000000000000000000000000"),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var s types.SSZUint64
if err := s.UnmarshalSSZ(tt.serializedBytes); (err != nil) != tt.wantErr {
t.Errorf("SSZUint64.UnmarshalSSZ() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.actualValue != uint64(s) {
t.Errorf("SSZUint64.UnmarshalSSZ() = %v, want %v", uint64(s), tt.actualValue)
}
serializedBytes, err := s.MarshalSSZ()
if err != nil {
t.Errorf("SSZUint64.MarshalSSZ() unexpected error = %v", err)
}
if !reflect.DeepEqual(tt.serializedBytes, serializedBytes) {
t.Errorf("SSZUint64.MarshalSSZ() = %v, want %v", serializedBytes, tt.serializedBytes)
}
htr, err := s.HashTreeRoot()
if err != nil {
t.Errorf("SSZUint64.HashTreeRoot() unexpected error = %v", err)
}
if !reflect.DeepEqual(tt.root, htr[:]) {
t.Errorf("SSZUint64.HashTreeRoot() = %v, want %v", htr[:], tt.root)
}
})
}
}

View File

@@ -0,0 +1,80 @@
package types
import (
fmt "fmt"
fssz "github.com/ferranbt/fastssz"
)
var _ fssz.HashRoot = (ValidatorIndex)(0)
var _ fssz.Marshaler = (*ValidatorIndex)(nil)
var _ fssz.Unmarshaler = (*ValidatorIndex)(nil)
// ValidatorIndex in eth2.
type ValidatorIndex uint64
// Div divides validator index by x.
func (v ValidatorIndex) Div(x uint64) ValidatorIndex {
if x == 0 {
panic("divbyzero")
}
return ValidatorIndex(uint64(v) / x)
}
// Add increases validator index by x.
func (v ValidatorIndex) Add(x uint64) ValidatorIndex {
return ValidatorIndex(uint64(v) + x)
}
// Sub subtracts x from the validator index.
func (v ValidatorIndex) Sub(x uint64) ValidatorIndex {
if uint64(v) < x {
panic("underflow")
}
return ValidatorIndex(uint64(v) - x)
}
// Mod returns result of `validator index % x`.
func (v ValidatorIndex) Mod(x uint64) ValidatorIndex {
return ValidatorIndex(uint64(v) % x)
}
// HashTreeRoot --
func (v ValidatorIndex) HashTreeRoot() ([32]byte, error) {
return fssz.HashWithDefaultHasher(v)
}
// HashTreeRootWith --
func (v ValidatorIndex) HashTreeRootWith(hh *fssz.Hasher) error {
hh.PutUint64(uint64(v))
return nil
}
// UnmarshalSSZ --
func (v *ValidatorIndex) UnmarshalSSZ(buf []byte) error {
if len(buf) != v.SizeSSZ() {
return fmt.Errorf("expected buffer of length %d received %d", v.SizeSSZ(), len(buf))
}
*v = ValidatorIndex(fssz.UnmarshallUint64(buf))
return nil
}
// MarshalSSZTo --
func (v *ValidatorIndex) MarshalSSZTo(dst []byte) ([]byte, error) {
marshalled, err := v.MarshalSSZ()
if err != nil {
return nil, err
}
return append(dst, marshalled...), nil
}
// MarshalSSZ --
func (v *ValidatorIndex) MarshalSSZ() ([]byte, error) {
marshalled := fssz.MarshalUint64([]byte{}, uint64(*v))
return marshalled, nil
}
// SizeSSZ --
func (v *ValidatorIndex) SizeSSZ() int {
return 8
}

View File

@@ -0,0 +1,35 @@
package types
import (
"testing"
"time"
)
func TestValidatorIndex_Casting(t *testing.T) {
valIdx := ValidatorIndex(42)
t.Run("time.Duration", func(t *testing.T) {
if uint64(time.Duration(valIdx)) != uint64(valIdx) {
t.Error("ValidatorIndex should produce the same result with time.Duration")
}
})
t.Run("floats", func(t *testing.T) {
var x1 float32 = 42.2
if ValidatorIndex(x1) != valIdx {
t.Errorf("Unequal: %v = %v", ValidatorIndex(x1), valIdx)
}
var x2 float64 = 42.2
if ValidatorIndex(x2) != valIdx {
t.Errorf("Unequal: %v = %v", ValidatorIndex(x2), valIdx)
}
})
t.Run("int", func(t *testing.T) {
var x int = 42
if ValidatorIndex(x) != valIdx {
t.Errorf("Unequal: %v = %v", ValidatorIndex(x), valIdx)
}
})
}

View File

@@ -15,5 +15,6 @@ go_test(
deps = [
":go_default_library",
"//testing/require:go_default_library",
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
],
)

View File

@@ -20,7 +20,13 @@ func init() {
}
// ErrOverflow occurs when an operation exceeds max or minimum values.
var ErrOverflow = errors.New("integer overflow")
var (
ErrOverflow = errors.New("integer overflow")
ErrDivByZero = errors.New("integer divide by zero")
ErrMulOverflow = errors.New("multiplication overflows")
ErrAddOverflow = errors.New("addition overflows")
ErrSubUnderflow = errors.New("subtraction underflow")
)
// Common square root values.
var squareRootTable = map[uint64]uint64{
@@ -115,6 +121,15 @@ func Mul64(a, b uint64) (uint64, error) {
return val, nil
}
// Div64 divides two 64-bit unsigned integers and checks for errors.
func Div64(a, b uint64) (uint64, error) {
if b == 0 {
return 0, ErrDivByZero
}
val, _ := bits.Div64(0, a, b)
return val, nil
}
// Add64 adds 2 64-bit unsigned integers and checks if they
// lead to an overflow. If they do not, it returns the result
// without an error.
@@ -135,6 +150,15 @@ func Sub64(a, b uint64) (uint64, error) {
return res, nil
}
// Mod64 finds remainder of division of two 64-bit unsigned integers and checks for errors.
func Mod64(a, b uint64) (uint64, error) {
if b == 0 {
return 0, ErrDivByZero
}
_, val := bits.Div64(0, a, b)
return val, nil
}
// Int returns the integer value of the uint64 argument. If there is an overlow, then an error is
// returned.
func Int(u uint64) (int, error) {

View File

@@ -5,6 +5,7 @@ import (
stdmath "math"
"testing"
types "github.com/prysmaticlabs/eth2-types"
"github.com/prysmaticlabs/prysm/math"
"github.com/prysmaticlabs/prysm/testing/require"
)
@@ -85,6 +86,74 @@ func TestIntegerSquareRoot(t *testing.T) {
}
}
func TestMath_Div64(t *testing.T) {
type args struct {
a uint64
b uint64
}
tests := []struct {
args args
res uint64
err bool
}{
{args: args{0, 1}, res: 0, err: false},
{args: args{0, 1}, res: 0},
{args: args{1, 0}, res: 0, err: true},
{args: args{1 << 32, 1 << 32}, res: 1},
{args: args{429496729600, 1 << 32}, res: 100},
{args: args{9223372036854775808, 1 << 32}, res: 1 << 31},
{args: args{a: 1 << 32, b: 1 << 32}, res: 1},
{args: args{9223372036854775808, 1 << 62}, res: 2},
{args: args{9223372036854775808, 1 << 63}, res: 1},
}
for _, tt := range tests {
got, err := math.Div64(tt.args.a, tt.args.b)
if tt.err && err == nil {
t.Errorf("Div64() Expected Error = %v, want error", tt.err)
continue
}
if tt.res != got {
t.Errorf("Div64() %v, want %v", got, tt.res)
}
}
}
func TestMath_Mod(t *testing.T) {
type args struct {
a uint64
b uint64
}
tests := []struct {
args args
res uint64
err bool
}{
{args: args{1, 0}, res: 0, err: true},
{args: args{0, 1}, res: 0},
{args: args{1 << 32, 1 << 32}, res: 0},
{args: args{429496729600, 1 << 32}, res: 0},
{args: args{9223372036854775808, 1 << 32}, res: 0},
{args: args{1 << 32, 1 << 32}, res: 0},
{args: args{9223372036854775808, 1 << 62}, res: 0},
{args: args{9223372036854775808, 1 << 63}, res: 0},
{args: args{1 << 32, 17}, res: 1},
{args: args{1 << 32, 19}, res: (1 << 32) % 19},
{args: args{stdmath.MaxUint64, stdmath.MaxUint64}, res: 0},
{args: args{1 << 63, 2}, res: 0},
{args: args{1<<63 + 1, 2}, res: 1},
}
for _, tt := range tests {
got, err := types.Mod64(tt.args.a, tt.args.b)
if tt.err && err == nil {
t.Errorf("Mod64() Expected Error = %v, want error", tt.err)
continue
}
if tt.res != got {
t.Errorf("Mod64() %v, want %v", got, tt.res)
}
}
}
func BenchmarkIntegerSquareRootBelow52Bits(b *testing.B) {
val := uint64(1 << 33)
for i := 0; i < b.N; i++ {