diff --git a/config/fieldparams/mainnet.go b/config/fieldparams/mainnet.go index 8073d41df3..cfb2d2b03a 100644 --- a/config/fieldparams/mainnet.go +++ b/config/fieldparams/mainnet.go @@ -15,6 +15,8 @@ const ( SyncCommitteeLength = 512 // SYNC_COMMITTEE_SIZE RootLength = 32 // RootLength defines the byte length of a Merkle root. BLSSignatureLength = 96 // BLSSignatureLength defines the byte length of a BLSSignature. + MaxTxsPerPayloadLength = 1048576 // MaxTxsPerPayloadLength defines the maximum number of transactions that can be included in a payload. + MaxBytesPerTxLength = 1073741824 // MaxBytesPerTxLength defines the maximum number of bytes that can be included in a transaction. FeeRecipientLength = 20 // FeeRecipientLength defines the byte length of a fee recipient. LogsBloomLength = 256 // LogsBloomLength defines the byte length of a logs bloom. ) diff --git a/config/fieldparams/minimal.go b/config/fieldparams/minimal.go index 8e445a01eb..227816bfec 100644 --- a/config/fieldparams/minimal.go +++ b/config/fieldparams/minimal.go @@ -15,6 +15,8 @@ const ( SyncCommitteeLength = 32 // SYNC_COMMITTEE_SIZE RootLength = 32 // RootLength defines the byte length of a Merkle root. BLSSignatureLength = 96 // BLSSignatureLength defines the byte length of a BLSSignature. + MaxTxsPerPayloadLength = 1048576 // MaxTxsPerPayloadLength defines the maximum number of transactions that can be included in a payload. + MaxBytesPerTxLength = 1073741824 // MaxBytesPerTxLength defines the maximum number of bytes that can be included in a transaction. FeeRecipientLength = 20 // FeeRecipientLength defines the byte length of a fee recipient. LogsBloomLength = 256 // LogsBloomLength defines the byte length of a logs bloom. ) diff --git a/encoding/ssz/BUILD.bazel b/encoding/ssz/BUILD.bazel index 166507f4a4..78378fd63c 100644 --- a/encoding/ssz/BUILD.bazel +++ b/encoding/ssz/BUILD.bazel @@ -37,6 +37,7 @@ go_test( ], deps = [ ":go_default_library", + "//config/fieldparams:go_default_library", "//crypto/hash:go_default_library", "//proto/prysm/v1alpha1:go_default_library", "//testing/assert:go_default_library", diff --git a/encoding/ssz/htrutils.go b/encoding/ssz/htrutils.go index dec79e84e0..358f2446a6 100644 --- a/encoding/ssz/htrutils.go +++ b/encoding/ssz/htrutils.go @@ -90,3 +90,82 @@ func SlashingsRoot(slashings []uint64) ([32]byte, error) { } return BitwiseMerkleize(hash.CustomSHA256Hasher(), slashingChunks, uint64(len(slashingChunks)), uint64(len(slashingChunks))) } + +// TransactionsRoot computes the HTR for the Transactions' property of the ExecutionPayload +// The code was largely copy/pasted from the code generated to compute the HTR of the entire +// ExecutionPayload. +func TransactionsRoot(txs [][]byte) ([32]byte, error) { + hasher := hash.CustomSHA256Hasher() + listMarshaling := make([][]byte, 0) + for i := 0; i < len(txs); i++ { + rt, err := transactionRoot(txs[i]) + if err != nil { + return [32]byte{}, err + } + listMarshaling = append(listMarshaling, rt[:]) + } + + bytesRoot, err := BitwiseMerkleize(hasher, listMarshaling, uint64(len(listMarshaling)), fieldparams.MaxTxsPerPayloadLength) + if err != nil { + return [32]byte{}, errors.Wrap(err, "could not compute merkleization") + } + bytesRootBuf := new(bytes.Buffer) + if err := binary.Write(bytesRootBuf, binary.LittleEndian, uint64(len(txs))); err != nil { + return [32]byte{}, errors.Wrap(err, "could not marshal length") + } + bytesRootBufRoot := make([]byte, 32) + copy(bytesRootBufRoot, bytesRootBuf.Bytes()) + return MixInLength(bytesRoot, bytesRootBufRoot), nil +} + +func transactionRoot(tx []byte) ([32]byte, error) { + hasher := hash.CustomSHA256Hasher() + chunkedRoots, err := PackChunks(tx) + if err != nil { + return [32]byte{}, err + } + + maxLength := (fieldparams.MaxBytesPerTxLength + 31) / 32 + bytesRoot, err := BitwiseMerkleize(hasher, chunkedRoots, uint64(len(chunkedRoots)), uint64(maxLength)) + if err != nil { + return [32]byte{}, errors.Wrap(err, "could not compute merkleization") + } + bytesRootBuf := new(bytes.Buffer) + if err := binary.Write(bytesRootBuf, binary.LittleEndian, uint64(len(tx))); err != nil { + return [32]byte{}, errors.Wrap(err, "could not marshal length") + } + bytesRootBufRoot := make([]byte, 32) + copy(bytesRootBufRoot, bytesRootBuf.Bytes()) + return MixInLength(bytesRoot, bytesRootBufRoot), nil +} + +// PackChunks a given byte array into chunks. It'll pad the last chunk with zero bytes if +// it does not have length bytes per chunk. +func PackChunks(bytes []byte) ([][]byte, error) { + numItems := len(bytes) + var chunks [][]byte + for i := 0; i < numItems; i += 32 { + j := i + 32 + // We create our upper bound index of the chunk, if it is greater than numItems, + // we set it as numItems itself. + if j > numItems { + j = numItems + } + // We create chunks from the list of items based on the + // indices determined above. + chunks = append(chunks, bytes[i:j]) + } + + if len(chunks) == 0 { + return chunks, nil + } + + // Right-pad the last chunk with zero bytes if it does not + // have length bytes. + lastChunk := chunks[len(chunks)-1] + for len(lastChunk) < 32 { + lastChunk = append(lastChunk, 0) + } + chunks[len(chunks)-1] = lastChunk + return chunks, nil +} diff --git a/encoding/ssz/htrutils_test.go b/encoding/ssz/htrutils_test.go index 92477b86bf..6153557887 100644 --- a/encoding/ssz/htrutils_test.go +++ b/encoding/ssz/htrutils_test.go @@ -1,8 +1,10 @@ package ssz_test import ( + "reflect" "testing" + fieldparams "github.com/prysmaticlabs/prysm/config/fieldparams" "github.com/prysmaticlabs/prysm/crypto/hash" "github.com/prysmaticlabs/prysm/encoding/ssz" ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" @@ -61,3 +63,98 @@ func TestSlashingsRoot(t *testing.T) { require.NoError(t, err) assert.Equal(t, expected, result) } + +func TestTransactionsRoot(t *testing.T) { + tests := []struct { + name string + txs [][]byte + want [32]byte + wantErr bool + }{ + { + name: "nil", + txs: nil, + want: [32]byte{127, 254, 36, 30, 166, 1, 135, 253, 176, 24, 123, 250, 34, 222, 53, 209, 249, 190, 215, 171, 6, 29, 148, 1, 253, 71, 227, 74, 84, 251, 237, 225}, + }, + { + name: "empty", + txs: [][]byte{}, + want: [32]byte{127, 254, 36, 30, 166, 1, 135, 253, 176, 24, 123, 250, 34, 222, 53, 209, 249, 190, 215, 171, 6, 29, 148, 1, 253, 71, 227, 74, 84, 251, 237, 225}, + }, + { + name: "one tx", + txs: [][]byte{{1, 2, 3}}, + want: [32]byte{102, 209, 140, 87, 217, 28, 68, 12, 133, 42, 77, 136, 191, 18, 234, 105, 166, 228, 216, 235, 230, 95, 200, 73, 85, 33, 134, 254, 219, 97, 82, 209}, + }, + { + name: "max txs", + txs: func() [][]byte { + var txs [][]byte + for i := 0; i < fieldparams.MaxTxsPerPayloadLength; i++ { + txs = append(txs, []byte{}) + } + return txs + }(), + want: [32]byte{13, 66, 254, 206, 203, 58, 48, 133, 78, 218, 48, 231, 120, 90, 38, 72, 73, 137, 86, 9, 31, 213, 185, 101, 103, 144, 0, 236, 225, 57, 47, 244}, + }, + { + name: "exceed max txs", + txs: func() [][]byte { + var txs [][]byte + for i := 0; i < fieldparams.MaxTxsPerPayloadLength+1; i++ { + txs = append(txs, []byte{}) + } + return txs + }(), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ssz.TransactionsRoot(tt.txs) + if (err != nil) != tt.wantErr { + t.Errorf("TransactionsRoot() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("TransactionsRoot() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestPackChunks(t *testing.T) { + tests := []struct { + name string + input []byte + want [][]byte + }{ + { + name: "nil", + input: nil, + want: [][]byte{}, + }, + { + name: "empty", + input: []byte{}, + want: [][]byte{}, + }, + { + name: "one", + input: []byte{1}, + want: [][]byte{{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + }, + { + name: "one, two", + input: []byte{1, 2}, + want: [][]byte{{1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ssz.PackChunks(tt.input) + require.NoError(t, err) + require.DeepSSZEqual(t, tt.want, got) + }) + } +}