mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-01-08 07:03:58 -05:00
Faster and cached square root (#12191)
* Faster and cached square root * deal with 0 * Rename function --------- Co-authored-by: terencechain <terence@prysmaticlabs.com>
This commit is contained in:
@@ -265,7 +265,7 @@ func AttestationsDelta(beaconState state.BeaconState, bal *precompute.Balance, v
|
||||
finalizedEpoch := beaconState.FinalizedCheckpointEpoch()
|
||||
increment := cfg.EffectiveBalanceIncrement
|
||||
factor := cfg.BaseRewardFactor
|
||||
baseRewardMultiplier := increment * factor / math.IntegerSquareRoot(bal.ActiveCurrentEpoch)
|
||||
baseRewardMultiplier := increment * factor / math.CachedSquareRoot(bal.ActiveCurrentEpoch)
|
||||
leak := helpers.IsInInactivityLeak(prevEpoch, finalizedEpoch)
|
||||
|
||||
// Modified in Altair and Bellatrix.
|
||||
|
||||
@@ -58,5 +58,5 @@ func BaseRewardPerIncrement(activeBalance uint64) (uint64, error) {
|
||||
return 0, errors.New("active balance can't be 0")
|
||||
}
|
||||
cfg := params.BeaconConfig()
|
||||
return cfg.EffectiveBalanceIncrement * cfg.BaseRewardFactor / math.IntegerSquareRoot(activeBalance), nil
|
||||
return cfg.EffectiveBalanceIncrement * cfg.BaseRewardFactor / math.CachedSquareRoot(activeBalance), nil
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ func AttestationsDelta(state state.ReadOnlyBeaconState, pBal *Balance, vp []*Val
|
||||
prevEpoch := time.PrevEpoch(state)
|
||||
finalizedEpoch := state.FinalizedCheckpointEpoch()
|
||||
|
||||
sqrtActiveCurrentEpoch := math.IntegerSquareRoot(pBal.ActiveCurrentEpoch)
|
||||
sqrtActiveCurrentEpoch := math.CachedSquareRoot(pBal.ActiveCurrentEpoch)
|
||||
for i, v := range vp {
|
||||
rewards[i], penalties[i] = attestationDelta(pBal, sqrtActiveCurrentEpoch, v, prevEpoch, finalizedEpoch)
|
||||
}
|
||||
@@ -161,7 +161,7 @@ func ProposersDelta(state state.ReadOnlyBeaconState, pBal *Balance, vp []*Valida
|
||||
rewards := make([]uint64, numofVals)
|
||||
|
||||
totalBalance := pBal.ActiveCurrentEpoch
|
||||
balanceSqrt := math.IntegerSquareRoot(totalBalance)
|
||||
balanceSqrt := math.CachedSquareRoot(totalBalance)
|
||||
// Balance square root cannot be 0, this prevents division by 0.
|
||||
if balanceSqrt == 0 {
|
||||
balanceSqrt = 1
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
stdmath "math"
|
||||
"math/bits"
|
||||
"sync"
|
||||
|
||||
"github.com/thomaso-mirodin/intmath/u64"
|
||||
)
|
||||
@@ -26,6 +27,12 @@ var (
|
||||
ErrMulOverflow = errors.New("multiplication overflows")
|
||||
ErrAddOverflow = errors.New("addition overflows")
|
||||
ErrSubUnderflow = errors.New("subtraction underflow")
|
||||
|
||||
// Sensible guess for 500 000 validators
|
||||
cachedSquareRoot = struct {
|
||||
sync.Mutex
|
||||
squareRoot, balance uint64
|
||||
}{squareRoot: 126491106, balance: 15999999897103236}
|
||||
)
|
||||
|
||||
// Common square root values.
|
||||
@@ -43,6 +50,28 @@ var squareRootTable = map[uint64]uint64{
|
||||
4194304: 2048,
|
||||
}
|
||||
|
||||
// CachedSquareRoot implements Newton's algorithm to compute the square root of
|
||||
// the given uint64 starting from the last cached value
|
||||
func CachedSquareRoot(balance uint64) uint64 {
|
||||
if balance == 0 {
|
||||
return 0
|
||||
}
|
||||
cachedSquareRoot.Lock()
|
||||
defer cachedSquareRoot.Unlock()
|
||||
if balance == cachedSquareRoot.balance {
|
||||
return cachedSquareRoot.squareRoot
|
||||
}
|
||||
cachedSquareRoot.balance = balance
|
||||
val := balance / cachedSquareRoot.squareRoot
|
||||
for {
|
||||
cachedSquareRoot.squareRoot = (cachedSquareRoot.squareRoot + val) / 2
|
||||
val = balance / cachedSquareRoot.squareRoot
|
||||
if cachedSquareRoot.squareRoot <= val {
|
||||
return cachedSquareRoot.squareRoot
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IntegerSquareRoot defines a function that returns the
|
||||
// largest possible integer root of a number using go's standard library.
|
||||
func IntegerSquareRoot(n uint64) uint64 {
|
||||
|
||||
@@ -82,6 +82,7 @@ func TestIntegerSquareRoot(t *testing.T) {
|
||||
|
||||
for _, testVals := range tt {
|
||||
require.Equal(t, testVals.root, math.IntegerSquareRoot(testVals.number))
|
||||
require.Equal(t, testVals.root, math.CachedSquareRoot(testVals.number))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,6 +168,35 @@ func BenchmarkIntegerSquareRootAbove52Bits(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSquareRootEffectiveBalance(b *testing.B) {
|
||||
val := uint64(1 << 62)
|
||||
for i := 0; i < b.N; i++ {
|
||||
require.Equal(b, uint64(1<<31), math.CachedSquareRoot(val))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSquareRootBabylonian(b *testing.B) {
|
||||
//Start with 700K validators' effective balance
|
||||
val := uint64(22400000000000000)
|
||||
for i := 0; i < b.N; i++ {
|
||||
sqr := math.CachedSquareRoot(val)
|
||||
require.Equal(b, true, sqr^2 <= val)
|
||||
require.Equal(b, true, (sqr+1)*(sqr+1) > val)
|
||||
val += 10_000_000_000
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSquareRootOldWay(b *testing.B) {
|
||||
//Start with 700K validators' effective balance
|
||||
val := uint64(22400000000000000)
|
||||
for i := 0; i < b.N; i++ {
|
||||
sqr := math.IntegerSquareRoot(val)
|
||||
require.Equal(b, true, sqr^2 <= val)
|
||||
require.Equal(b, true, (sqr+1)*(sqr+1) > val)
|
||||
val += 10_000_000_000
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIntegerSquareRoot_WithDatatable(b *testing.B) {
|
||||
val := uint64(1024)
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
||||
Reference in New Issue
Block a user