Decompressor Library (#229)

* chore: bring over changes from another branch

* chore: blocks as byte arrays not lists; tx.S <- 1

* build: update compressor lib release script by jpnovais

* docs: update docs and run go generate

---------

Co-authored-by: Arya Tabaie <15056835+Tabaie@users.noreply.github.com>
Authored by Arya Tabaie on 2024-10-22 14:32:14 -05:00, committed by GitHub
parent d2a120f3b4
commit ce964fc0cb
33 changed files with 1402 additions and 482 deletions

View File

@@ -7,6 +7,7 @@ on:
description: 'Version (e.g. v1.2.3)'
required: true
default: 'v0.0.0'
type: string
draft-release:
description: 'Draft Release'
required: false
@@ -37,13 +38,16 @@ jobs:
VERSION: ${{ github.event.inputs.version }}
SRC_SHNARF: "./lib/shnarf_calculator/shnarf_calculator.go"
TARGET_SHNARF: "shnarf_calculator"
SRC_COMPRESSOR: "./lib/compressor/libcompressor.go"
SRC_COMPRESSOR: "./lib/compressor/libcompressor/libcompressor.go"
TARGET_COMPRESSOR: "blob_compressor"
SRC_DECOMPRESSOR: "./lib/compressor/libdecompressor/libdecompressor.go"
TARGET_DECOMPRESSOR: "blob_decompressor"
run: |
cd prover
mkdir target
GOARCH="amd64" go build -tags=nocorset -buildmode=c-shared -o ./target/${TARGET_SHNARF}_${VERSION}_linux_x86_64.so ${SRC_SHNARF}
GOARCH="amd64" go build -tags=nocorset -buildmode=c-shared -o ./target/${TARGET_COMPRESSOR}_${VERSION}_linux_x86_64.so ${SRC_COMPRESSOR}
GOARCH="amd64" go build -tags=nocorset -buildmode=c-shared -o ./target/${TARGET_DECOMPRESSOR}_${VERSION}_linux_x86_64.so ${SRC_DECOMPRESSOR}
- name: Cache built binaries
uses: actions/upload-artifact@master
@@ -69,13 +73,16 @@ jobs:
VERSION: ${{ github.event.inputs.version }}
SRC_SHNARF: "./lib/shnarf_calculator/shnarf_calculator.go"
TARGET_SHNARF: "shnarf_calculator"
SRC_COMPRESSOR: "./lib/compressor/libcompressor.go"
SRC_COMPRESSOR: "./lib/compressor/libcompressor/libcompressor.go"
TARGET_COMPRESSOR: "blob_compressor"
SRC_DECOMPRESSOR: "./lib/compressor/libdecompressor/libdecompressor.go"
TARGET_DECOMPRESSOR: "blob_decompressor"
run: |
cd prover
mkdir target
GOARCH="arm64" go build -tags=nocorset -buildmode=c-shared -o ./target/${TARGET_SHNARF}_${VERSION}_linux_arm64.so ${SRC_SHNARF}
GOARCH="arm64" go build -tags=nocorset -buildmode=c-shared -o ./target/${TARGET_COMPRESSOR}_${VERSION}_linux_arm64.so ${SRC_COMPRESSOR}
GOARCH="arm64" go build -tags=nocorset -buildmode=c-shared -o ./target/${TARGET_DECOMPRESSOR}_${VERSION}_linux_arm64.so ${SRC_DECOMPRESSOR}
- name: Cache built binaries
uses: actions/upload-artifact@master
with:
@@ -98,8 +105,10 @@ jobs:
VERSION: ${{ github.event.inputs.version }}
SRC_SHNARF: "./lib/shnarf_calculator/shnarf_calculator.go"
TARGET_SHNARF: "shnarf_calculator"
SRC_COMPRESSOR: "./lib/compressor/libcompressor.go"
SRC_COMPRESSOR: "./lib/compressor/libcompressor/libcompressor.go"
TARGET_COMPRESSOR: "blob_compressor"
SRC_DECOMPRESSOR: "./lib/compressor/libdecompressor/libdecompressor.go"
TARGET_DECOMPRESSOR: "blob_decompressor"
run: |
cd prover
mkdir target
@@ -107,6 +116,8 @@ jobs:
GOARCH="arm64" go build -tags=nocorset -buildmode=c-shared -o ./target/${TARGET_SHNARF}_${VERSION}_darwin_arm64.dylib ${SRC_SHNARF}
GOARCH="amd64" go build -tags=nocorset -buildmode=c-shared -o ./target/${TARGET_COMPRESSOR}_${VERSION}_darwin_x86_64.dylib ${SRC_COMPRESSOR}
GOARCH="arm64" go build -tags=nocorset -buildmode=c-shared -o ./target/${TARGET_COMPRESSOR}_${VERSION}_darwin_arm64.dylib ${SRC_COMPRESSOR}
GOARCH="amd64" go build -tags=nocorset -buildmode=c-shared -o ./target/${TARGET_DECOMPRESSOR}_${VERSION}_darwin_x86_64.dylib ${SRC_DECOMPRESSOR}
GOARCH="arm64" go build -tags=nocorset -buildmode=c-shared -o ./target/${TARGET_DECOMPRESSOR}_${VERSION}_darwin_arm64.dylib ${SRC_DECOMPRESSOR}
- name: Cache built binaries
uses: actions/upload-artifact@v4

View File

@@ -9,7 +9,7 @@ import (
"strings"
"testing"
blob "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/consensys/linea-monorepo/prover/utils"
@@ -270,7 +270,7 @@ func TestKZGWithPoint(t *testing.T) {
}
// Compute all the prover fields
snarkHash, err := blob.MiMCChecksumPackedData(blobBytes[:], fr381.Bits-1, blob.NoTerminalSymbol())
snarkHash, err := encode.MiMCChecksumPackedData(blobBytes[:], fr381.Bits-1, encode.NoTerminalSymbol())
assert.NoError(t, err)
xUnreduced := evaluationChallenge(snarkHash, blobHash[:])

View File

@@ -7,9 +7,9 @@ import (
"hash"
"github.com/consensys/linea-monorepo/prover/crypto/mimc"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
blob "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1"
"github.com/consensys/linea-monorepo/prover/utils"
"golang.org/x/crypto/sha3"
)
@@ -72,7 +72,7 @@ func CraftResponseCalldata(req *Request) (*Response, error) {
}
// Compute all the prover fields
snarkHash, err := blob.MiMCChecksumPackedData(compressedStream, fr381.Bits-1, blob.NoTerminalSymbol())
snarkHash, err := encode.MiMCChecksumPackedData(compressedStream, fr381.Bits-1, encode.NoTerminalSymbol())
if err != nil {
return nil, fmt.Errorf("crafting response: could not compute snark hash: %w", err)
}

View File

@@ -4,6 +4,7 @@ import (
"crypto/sha256"
"errors"
"fmt"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
blob "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1"
@@ -91,7 +92,7 @@ func CraftResponse(req *Request) (*Response, error) {
}
// Compute all the prover fields
snarkHash, err := blob.MiMCChecksumPackedData(append(compressedStream, make([]byte, blob.MaxUsableBytes-len(compressedStream))...), fr381.Bits-1, blob.NoTerminalSymbol())
snarkHash, err := encode.MiMCChecksumPackedData(append(compressedStream, make([]byte, blob.MaxUsableBytes-len(compressedStream))...), fr381.Bits-1, encode.NoTerminalSymbol())
if err != nil {
return nil, fmt.Errorf("crafting response: could not compute snark hash: %w", err)
}

View File

@@ -117,10 +117,11 @@ func TestTransactionSigning(t *testing.T) {
assert.Equal(t, from.Hex(), recovered.Hex(), "Mismatch of the recovered address")
// Simulates the decoding of the transaction
var decodedTx types.Transaction
err = DecodeTxFromBytes(bytes.NewReader(rlp), &decodedTx)
decodedTxData, err := DecodeTxFromBytes(bytes.NewReader(rlp))
require.NoError(t, err)
decodedTx := types.NewTx(decodedTxData)
assert.Equal(t, tx.To(), decodedTx.To())
assert.Equal(t, tx.Nonce(), decodedTx.Nonce())
assert.Equal(t, tx.Data(), decodedTx.Data())

View File

@@ -97,25 +97,25 @@ const (
// than the transaction, the remaining bytes are discarded and only the
// first bytes are used to decode the transaction. The function returns the
// decoded transaction data.
func DecodeTxFromBytes(b *bytes.Reader, tx *types.Transaction) (err error) {
func DecodeTxFromBytes(b *bytes.Reader) (tx types.TxData, err error) {
var (
firstByte byte
)
if b.Len() == 0 {
return fmt.Errorf("empty buffer")
return nil, fmt.Errorf("empty buffer")
}
if firstByte, err = b.ReadByte(); err != nil {
return fmt.Errorf("could not read the first byte: %w", err)
return nil, fmt.Errorf("could not read the first byte: %w", err)
}
switch {
case firstByte == types.DynamicFeeTxType:
return decodeDynamicFeeTx(b, tx)
return decodeDynamicFeeTx(b)
case firstByte == types.AccessListTxType:
return decodeAccessListTx(b, tx)
return decodeAccessListTx(b)
// According to the RLP rule, `0xc0 + x` or `0xf7` indicates that the current
// item is a list and this is what's used to identify that the transaction is
// a legacy transaction or a EIP-155 transaction.
@@ -125,69 +125,69 @@ func DecodeTxFromBytes(b *bytes.Reader, tx *types.Transaction) (err error) {
// Set the byte-reader backward so that we can apply the rlp-decoder
// over it.
b.UnreadByte()
return decodeLegacyTx(b, tx)
return decodeLegacyTx(b)
default:
return nil, fmt.Errorf("unexpected first byte: %x", firstByte)
}
return fmt.Errorf("unexpected first byte: %x", firstByte)
}
// decodeDynamicFeeTx decodes a [types.DynamicFeeTx] from a [bytes.Reader] and
// returns an error if decoding fails.
func decodeDynamicFeeTx(b *bytes.Reader, tx *types.Transaction) (err error) {
func decodeDynamicFeeTx(b *bytes.Reader) (parsedTx *types.DynamicFeeTx, err error) {
decTx := []any{}
if err := rlp.Decode(b, &decTx); err != nil {
return fmt.Errorf("could not rlp decode transaction: %w", err)
if err = rlp.Decode(b, &decTx); err != nil {
return nil, fmt.Errorf("could not rlp decode transaction: %w", err)
}
if len(decTx) != dynFeeNumField {
return fmt.Errorf("invalid number of field for a dynamic transaction")
return nil, fmt.Errorf("invalid number of fields for a dynamic fee transaction")
}
parsedTx := types.DynamicFeeTx{}
parsedTx = new(types.DynamicFeeTx)
err = errors.Join(
tryCast(&parsedTx.ChainID, decTx[0], "chainID"),
tryCast(&parsedTx.Nonce, decTx[1], "nonce"),
tryCast(&parsedTx.GasTipCap, decTx[2], "gas-tip-cap"),
tryCast(&parsedTx.GasFeeCap, decTx[3], "gas-fee-cap"),
tryCast(&parsedTx.Gas, decTx[4], "gas"),
tryCast(&parsedTx.To, decTx[5], "to"),
tryCast(&parsedTx.Value, decTx[6], "value"),
tryCast(&parsedTx.Data, decTx[7], "data"),
tryCast(&parsedTx.AccessList, decTx[8], "access-list"),
TryCast(&parsedTx.ChainID, decTx[0], "chainID"),
TryCast(&parsedTx.Nonce, decTx[1], "nonce"),
TryCast(&parsedTx.GasTipCap, decTx[2], "gas-tip-cap"),
TryCast(&parsedTx.GasFeeCap, decTx[3], "gas-fee-cap"),
TryCast(&parsedTx.Gas, decTx[4], "gas"),
TryCast(&parsedTx.To, decTx[5], "to"),
TryCast(&parsedTx.Value, decTx[6], "value"),
TryCast(&parsedTx.Data, decTx[7], "data"),
TryCast(&parsedTx.AccessList, decTx[8], "access-list"),
)
*tx = *types.NewTx(&parsedTx)
return err
return
}
// decodeAccessListTx decodes an [types.AccessListTx] from a [bytes.Reader]
// decodeAccessListTx decodes a [types.AccessListTx] from a [bytes.Reader]
// and returns an error if decoding fails.
func decodeAccessListTx(b *bytes.Reader, tx *types.Transaction) (err error) {
func decodeAccessListTx(b *bytes.Reader) (parsedTx *types.AccessListTx, err error) {
decTx := []any{}
if err := rlp.Decode(b, &decTx); err != nil {
return fmt.Errorf("could not rlp decode transaction: %w", err)
return nil, fmt.Errorf("could not rlp decode transaction: %w", err)
}
if len(decTx) != accessListTxNumField {
return fmt.Errorf("invalid number of field for a dynamic transaction")
return nil, fmt.Errorf("invalid number of fields for an access list transaction")
}
parsedTx := types.AccessListTx{}
parsedTx = new(types.AccessListTx)
err = errors.Join(
tryCast(&parsedTx.ChainID, decTx[0], "chainID"),
tryCast(&parsedTx.Nonce, decTx[1], "nonce"),
tryCast(&parsedTx.GasPrice, decTx[2], "gas-price"),
tryCast(&parsedTx.Gas, decTx[3], "gas"),
tryCast(&parsedTx.To, decTx[4], "to"),
tryCast(&parsedTx.Value, decTx[5], "value"),
tryCast(&parsedTx.Data, decTx[6], "data"),
tryCast(&parsedTx.AccessList, decTx[7], "access-list"),
TryCast(&parsedTx.ChainID, decTx[0], "chainID"),
TryCast(&parsedTx.Nonce, decTx[1], "nonce"),
TryCast(&parsedTx.GasPrice, decTx[2], "gas-price"),
TryCast(&parsedTx.Gas, decTx[3], "gas"),
TryCast(&parsedTx.To, decTx[4], "to"),
TryCast(&parsedTx.Value, decTx[5], "value"),
TryCast(&parsedTx.Data, decTx[6], "data"),
TryCast(&parsedTx.AccessList, decTx[7], "access-list"),
)
*tx = *types.NewTx(&parsedTx)
return err
return
}
// decodeLegacyTx decodes a [types.LegacyTx] from a [bytes.Reader] and returns
@@ -197,36 +197,35 @@ func decodeAccessListTx(b *bytes.Reader, tx *types.Transaction) (err error) {
// not decoded although it could. The reason is that it is complicated to set
// it in the returned element as it "included" in the signature and we don't
// encode the signature.
func decodeLegacyTx(b *bytes.Reader, tx *types.Transaction) (err error) {
func decodeLegacyTx(b *bytes.Reader) (parsedTx *types.LegacyTx, err error) {
decTx := []any{}
if err := rlp.Decode(b, &decTx); err != nil {
return fmt.Errorf("could not rlp decode transaction: %w", err)
if err = rlp.Decode(b, &decTx); err != nil {
return nil, fmt.Errorf("could not rlp decode transaction: %w", err)
}
if len(decTx) != legacyTxNumField && len(decTx) != unprotectedTxNumField {
return fmt.Errorf("unexpected number of field")
return nil, fmt.Errorf("unexpected number of fields")
}
parsedTx := types.LegacyTx{}
parsedTx = new(types.LegacyTx)
err = errors.Join(
tryCast(&parsedTx.Nonce, decTx[0], "nonce"),
tryCast(&parsedTx.GasPrice, decTx[1], "gas-price"),
tryCast(&parsedTx.Gas, decTx[2], "gas"),
tryCast(&parsedTx.To, decTx[3], "to"),
tryCast(&parsedTx.Value, decTx[4], "value"),
tryCast(&parsedTx.Data, decTx[5], "data"),
TryCast(&parsedTx.Nonce, decTx[0], "nonce"),
TryCast(&parsedTx.GasPrice, decTx[1], "gas-price"),
TryCast(&parsedTx.Gas, decTx[2], "gas"),
TryCast(&parsedTx.To, decTx[3], "to"),
TryCast(&parsedTx.Value, decTx[4], "value"),
TryCast(&parsedTx.Data, decTx[5], "data"),
)
*tx = *types.NewTx(&parsedTx)
return err
return
}
// tryCast will attempt to set t with the underlying value of `from` will return
// TryCast will attempt to set `into` with the underlying value of `from` and will return
// an error if the type does not match. The explainer string is used to generate
// the error if any.
func tryCast[T any](into *T, from any, explainer string) error {
func TryCast[T any](into *T, from any, explainer string) error {
if into == nil || from == nil {
return fmt.Errorf("from or into is/are nil")
@@ -234,7 +233,7 @@ func tryCast[T any](into *T, from any, explainer string) error {
// The rlp encoding is not "type-aware", if the underlying field is an
// access-list, it will decode into []interface{} (and we recursively parse
// it) otherwise, it always decode to `[]byte`
// it) otherwise, it always decodes to `[]byte`
if list, ok := (from).([]interface{}); ok {
var (
@@ -249,7 +248,7 @@ func tryCast[T any](into *T, from any, explainer string) error {
for i := range accessList {
err = errors.Join(
err,
tryCast(&accessList[i], list[i], fmt.Sprintf("%v[%v]", explainer, i)),
TryCast(&accessList[i], list[i], fmt.Sprintf("%v[%v]", explainer, i)),
)
}
*into = (any(accessList)).(T)
@@ -258,8 +257,8 @@ func tryCast[T any](into *T, from any, explainer string) error {
case types.AccessTuple:
tuple := types.AccessTuple{}
err = errors.Join(
tryCast(&tuple.Address, list[0], fmt.Sprintf("%v.%v", explainer, "address")),
tryCast(&tuple.StorageKeys, list[1], fmt.Sprintf("%v.%v", explainer, "storage-key")),
TryCast(&tuple.Address, list[0], fmt.Sprintf("%v.%v", explainer, "address")),
TryCast(&tuple.StorageKeys, list[1], fmt.Sprintf("%v.%v", explainer, "storage-key")),
)
*into = (any(tuple)).(T)
return err
@@ -267,7 +266,7 @@ func tryCast[T any](into *T, from any, explainer string) error {
case []common.Hash:
hashes := make([]common.Hash, length)
for i := range hashes {
tryCast(&hashes[i], list[i], fmt.Sprintf("%v[%v]", explainer, i))
TryCast(&hashes[i], list[i], fmt.Sprintf("%v[%v]", explainer, i))
}
*into = (any(hashes)).(T)
return err
@@ -285,7 +284,7 @@ func tryCast[T any](into *T, from any, explainer string) error {
switch intoAny.(type) {
case *common.Address:
// Parse the bytes as a UTF8 string (= direct casting in go).
// Then, the string as an hexstring encoded address.
// Then, the string as a hex string encoded address.
address := common.BytesToAddress(fromBytes)
*into = any(&address).(T)
case common.Address:
@@ -295,7 +294,7 @@ func tryCast[T any](into *T, from any, explainer string) error {
*into = any(address).(T)
case common.Hash:
// Parse the bytes as a UTF8 string (= direct casting in go).
// Then, the string as an hexstring encoded address.
// Then, the string as a hexstring encoded hash.
hash := common.BytesToHash(fromBytes)
*into = any(hash).(T)
case *big.Int:

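Since tryCast is now exported as TryCast, the v0 decoder added later in this commit can reuse it. A minimal sketch of the call pattern, using a toy two-field RLP payload in place of a real transaction body (the field values are illustrative):

package main

import (
    "bytes"
    "errors"
    "fmt"

    "github.com/consensys/linea-monorepo/prover/backend/ethereum"
    "github.com/ethereum/go-ethereum/rlp"
)

func main() {
    // rlp decodes untyped items into []byte leaves; TryCast then converts
    // each leaf into its typed destination, naming the field on failure.
    enc, err := rlp.EncodeToBytes([]any{uint64(7), []byte{0xca, 0xfe}})
    if err != nil {
        panic(err)
    }

    fields := []any{}
    if err := rlp.Decode(bytes.NewReader(enc), &fields); err != nil {
        panic(err)
    }

    var (
        nonce uint64
        data  []byte
    )
    if err := errors.Join(
        ethereum.TryCast(&nonce, fields[0], "nonce"),
        ethereum.TryCast(&data, fields[1], "data"),
    ); err != nil {
        panic(err)
    }
    fmt.Println(nonce, data) // 7 [202 254]
}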
View File

@@ -7,6 +7,7 @@ import (
"testing"
v0 "github.com/consensys/linea-monorepo/prover/circuits/blobdecompression/v0"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1/test_utils"
"github.com/consensys/gnark-crypto/ecc"
@@ -74,7 +75,9 @@ func mustGetTestCompressedData(t *testing.T) (resp blobsubmission.Response, blob
blobBytes, err = base64.StdEncoding.DecodeString(resp.CompressedData)
assert.NoError(t, err)
_, _, _, err = blob.DecompressBlob(blobBytes, dict)
dictStore, err := dictionary.SingletonStore(dict, 0)
assert.NoError(t, err)
_, _, _, err = blob.DecompressBlob(blobBytes, dictStore)
assert.NoError(t, err)
return

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"github.com/consensys/linea-monorepo/prover/circuits/internal"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
@@ -55,7 +56,12 @@ func Assign(blobData, dict []byte, eip4844Enabled bool, x [32]byte, y fr381.Elem
return
}
header, uncompressedData, _, err := blob.DecompressBlob(blobData, dict)
dictStore, err := dictionary.SingletonStore(dict, 0)
if err != nil {
err = fmt.Errorf("failed to create dictionary store: %w", err)
return
}
header, uncompressedData, _, err := blob.DecompressBlob(blobData, dictStore)
if err != nil {
err = fmt.Errorf("decompression circuit assignment : could not decompress the data : %w", err)
return

View File

@@ -5,6 +5,7 @@ package v1_test
import (
"encoding/base64"
"encoding/hex"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
"testing"
"github.com/consensys/gnark-crypto/ecc"
@@ -27,7 +28,9 @@ func prepareTestBlob(t require.TestingT) (c, a frontend.Circuit) {
func prepare(t require.TestingT, blobBytes []byte) (c *v1.Circuit, a frontend.Circuit) {
_, payload, _, err := blobcompressorv1.DecompressBlob(blobBytes, blobtestutils.GetDict(t))
dictStore, err := dictionary.SingletonStore(blobtestutils.GetDict(t), 1)
assert.NoError(t, err)
_, payload, _, err := blobcompressorv1.DecompressBlob(blobBytes, dictStore)
assert.NoError(t, err)
resp, err := blobsubmission.CraftResponse(&blobsubmission.Request{

View File

@@ -4,6 +4,8 @@ import (
"bytes"
"errors"
"fmt"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
"math/big"
"github.com/consensys/gnark-crypto/ecc"
@@ -235,7 +237,12 @@ func AssignFPI(blobBytes, dict []byte, eip4844Enabled bool, x [32]byte, y fr381.
return
}
header, payload, _, err := blob.DecompressBlob(blobBytes, dict)
dictStore, err := dictionary.SingletonStore(dict, 1)
if err != nil {
err = fmt.Errorf("failed to create dictionary store: %w", err)
return
}
header, payload, _, err := blob.DecompressBlob(blobBytes, dictStore)
if err != nil {
return
}
@@ -266,7 +273,7 @@ func AssignFPI(blobBytes, dict []byte, eip4844Enabled bool, x [32]byte, y fr381.
if len(blobBytes) != 128*1024 {
panic("blobBytes length is not 128*1024")
}
fpi.SnarkHash, err = blob.MiMCChecksumPackedData(blobBytes, fr381.Bits-1, blob.NoTerminalSymbol()) // TODO if forced to remove the above check, pad with zeros
fpi.SnarkHash, err = encode.MiMCChecksumPackedData(blobBytes, fr381.Bits-1, encode.NoTerminalSymbol()) // TODO if forced to remove the above check, pad with zeros
return
}

View File

@@ -6,6 +6,8 @@ import (
"errors"
"testing"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
"github.com/consensys/linea-monorepo/prover/utils"
"github.com/consensys/linea-monorepo/prover/circuits/blobdecompression/v1/test_utils"
@@ -20,7 +22,7 @@ import (
"github.com/consensys/gnark/std/hash/mimc"
"github.com/consensys/gnark/test"
blob "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1"
blobtesting "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1/test_utils"
blobtestutils "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1/test_utils"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -30,7 +32,7 @@ func TestParseHeader(t *testing.T) {
maxBlobSize := 1024
blobs := [][]byte{
blobtesting.GenTestBlob(t, 100000),
blobtestutils.GenTestBlob(t, 100000),
}
for _, blobData := range blobs {
@@ -48,14 +50,17 @@ func TestParseHeader(t *testing.T) {
test.NoTestEngine(),
}
dictStore, err := dictionary.SingletonStore(blobtestutils.GetDict(t), 1)
assert.NoError(t, err)
for _, blobData := range blobs {
header, _, blocks, err := blob.DecompressBlob(blobData, blobtesting.GetDict(t))
header, _, blocks, err := blob.DecompressBlob(blobData, dictStore)
assert.NoError(t, err)
assert.LessOrEqual(t, len(blocks), MaxNbBatches, "too many batches")
unpacked, err := blob.UnpackAlign(blobData, fr381.Bits-1, false)
unpacked, err := encode.UnpackAlign(blobData, fr381.Bits-1, false)
require.NoError(t, err)
assignment := &testParseHeaderCircuit{
@@ -88,9 +93,9 @@ func TestChecksumBatches(t *testing.T) {
var batchEndss [nbAssignments][]int
for i := range batchEndss {
batchEndss[i] = make([]int, blobtesting.RandIntn(MaxNbBatches)+1)
batchEndss[i] = make([]int, blobtestutils.RandIntn(MaxNbBatches)+1)
for j := range batchEndss[i] {
batchEndss[i][j] = 31 + blobtesting.RandIntn(62)
batchEndss[i][j] = 31 + blobtestutils.RandIntn(62)
if j > 0 {
batchEndss[i][j] += batchEndss[i][j-1]
}
@@ -161,7 +166,7 @@ func testChecksumBatches(t *testing.T, blob []byte, batchEndss ...[]int) {
Sums: sums,
NbBatches: len(batchEnds),
}
assignment.Sums[blobtesting.RandIntn(len(batchEnds))] = 3
assignment.Sums[blobtestutils.RandIntn(len(batchEnds))] = 3
assert.Error(t, test.IsSolved(&circuit, &assignment, ecc.BLS12_377.ScalarField()))
@@ -224,7 +229,7 @@ func TestUnpackCircuit(t *testing.T) {
runTest := func(b []byte) {
var packedBuf bytes.Buffer
_, err := blob.PackAlign(&packedBuf, b, fr381.Bits-1) // todo use two different slices
_, err := encode.PackAlign(&packedBuf, b, fr381.Bits-1) // todo use two different slices
assert.NoError(t, err)
circuit := unpackCircuit{
@@ -308,7 +313,7 @@ func TestBlobChecksum(t *testing.T) { // aka "snark hash"
assignment := testDataChecksumCircuit{
DataBytes: dataVarsPadded[:nPadded],
}
assignment.Checksum, err = blob.MiMCChecksumPackedData(dataPadded[:nPadded], fr381.Bits-1, blob.NoTerminalSymbol())
assignment.Checksum, err = encode.MiMCChecksumPackedData(dataPadded[:nPadded], fr381.Bits-1, encode.NoTerminalSymbol())
assert.NoError(t, err)
assert.NoError(t, test.IsSolved(&circuit, &assignment, ecc.BLS12_377.ScalarField()))
@@ -338,9 +343,11 @@ func (c *testDataChecksumCircuit) Define(api frontend.API) error {
}
func TestDictHash(t *testing.T) {
blobBytes := blobtesting.GenTestBlob(t, 1)
dict := blobtesting.GetDict(t)
header, _, _, err := blob.DecompressBlob(blobBytes, dict) // a bit roundabout, but the header field is not public
blobBytes := blobtestutils.GenTestBlob(t, 1)
dict := blobtestutils.GetDict(t)
dictStore, err := dictionary.SingletonStore(dict, 1)
assert.NoError(t, err)
header, _, _, err := blob.DecompressBlob(blobBytes, dictStore) // a bit roundabout, but the header field is not public
assert.NoError(t, err)
circuit := testDataDictHashCircuit{

View File

@@ -1,15 +1,17 @@
package blob
import (
"bytes"
"errors"
"os"
"path/filepath"
"strings"
fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/consensys/gnark-crypto/hash"
"github.com/consensys/linea-monorepo/prover/circuits/blobdecompression/v0/compress"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
v0 "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v0"
v1 "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1"
"github.com/ethereum/go-ethereum/rlp"
)
func GetVersion(blob []byte) uint16 {
@@ -23,17 +25,6 @@ func GetVersion(blob []byte) uint16 {
return 0
}
// DictionaryChecksum according to the given spec version
func DictionaryChecksum(dict []byte, version uint16) ([]byte, error) {
switch version {
case 1:
return v1.MiMCChecksumPackedData(dict, 8)
case 0:
return compress.ChecksumPaddedBytes(dict, len(dict), hash.MIMC_BLS12_377.New(), fr381.Bits), nil
}
return nil, errors.New("unsupported version")
}
// GetRepoRootPath assumes that current working directory is within the repo
func GetRepoRootPath() (string, error) {
wd, err := os.Getwd()
@@ -57,3 +48,41 @@ func GetDict() ([]byte, error) {
dictPath := filepath.Join(repoRoot, "prover/lib/compressor/compressor_dict.bin")
return os.ReadFile(dictPath)
}
// DecompressBlob takes in a Linea blob and outputs an RLP encoded list of RLP encoded blocks.
// Due to information loss during pre-compression encoding, two pieces of information are represented "hackily":
// The block hash is in the ParentHash field.
// The transaction's from address is in the signature.R field.
func DecompressBlob(blob []byte, dictStore dictionary.Store) ([]byte, error) {
vsn := GetVersion(blob)
var (
blockDecoder func(*bytes.Reader) (encode.DecodedBlockData, error)
blocks [][]byte
err error
)
switch vsn {
case 0:
_, _, blocks, err = v0.DecompressBlob(blob, dictStore)
blockDecoder = v0.DecodeBlockFromUncompressed
case 1:
_, _, blocks, err = v1.DecompressBlob(blob, dictStore)
blockDecoder = v1.DecodeBlockFromUncompressed
default:
return nil, errors.New("unrecognized blob version")
}
if err != nil {
return nil, err
}
blocksSerialized := make([][]byte, len(blocks))
var decodedBlock encode.DecodedBlockData
for i, block := range blocks {
if decodedBlock, err = blockDecoder(bytes.NewReader(block)); err != nil {
return nil, err
}
if blocksSerialized[i], err = rlp.EncodeToBytes(decodedBlock.ToStd()); err != nil {
return nil, err
}
}
return rlp.EncodeToBytes(blocksSerialized)
}
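A sketch of how a caller might consume the new top-level entry point, recovering the two "hacky" fields the doc comment describes. The file paths and blob source are assumptions:

package main

import (
    "fmt"
    "os"

    "github.com/consensys/linea-monorepo/prover/lib/compressor/blob"
    "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
    "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
    "github.com/ethereum/go-ethereum/core/types"
    "github.com/ethereum/go-ethereum/rlp"
)

func main() {
    blobBytes, err := os.ReadFile("blob.bin") // illustrative path
    if err != nil {
        panic(err)
    }
    store := dictionary.NewStore("compressor_dict.bin") // panics if the dict cannot load

    decompressed, err := blob.DecompressBlob(blobBytes, store)
    if err != nil {
        panic(err)
    }

    var blocksSerialized [][]byte
    if err := rlp.DecodeBytes(decompressed, &blocksSerialized); err != nil {
        panic(err)
    }
    for _, raw := range blocksSerialized {
        var b types.Block
        if err := rlp.DecodeBytes(raw, &b); err != nil {
            panic(err)
        }
        // Per the doc comment: the block hash rides in ParentHash and the
        // sender of each tx rides in signature.R.
        fmt.Println("block hash:", b.ParentHash())
        for _, tx := range b.Transactions() {
            fmt.Printf("  from: %x\n", encode.GetAddressFromR(tx))
        }
    }
}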

View File

@@ -1,14 +1,162 @@
package blob_test
import (
"testing"
"bytes"
"fmt"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1/test_utils"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
v0 "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v0"
blobv1testing "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1/test_utils"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/rlp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"os"
"path/filepath"
"testing"
)
func TestGetVersion(t *testing.T) {
_blob := test_utils.GenTestBlob(t, 1)
_blob := blobv1testing.GenTestBlob(t, 1)
assert.Equal(t, uint32(0x10000), uint32(0xffff)+uint32(blob.GetVersion(_blob)), "version should match the current one")
}
const dictPath = "../compressor_dict.bin"
func TestAddToBlob(t *testing.T) {
dictStore := dictionary.NewStore()
require.NoError(t, dictStore.Load(dictPath))
blobData := withNoError(t, os.ReadFile, "testdata/v0/sample-blob-01b9918c3f0ceb6a.bin")
header, _, blocksSerialized, err := v0.DecompressBlob(blobData, dictStore)
require.NoError(t, err)
blobData = withNoError(t, os.ReadFile, "testdata/v0/sample-blob-0151eda71505187b5.bin")
_, _, blocksSerializedNext, err := v0.DecompressBlob(blobData, dictStore)
require.NoError(t, err)
bm, err := v0.NewBlobMaker(v0.MaxUsableBytes, "../compressor_dict.bin")
require.NoError(t, err)
var ok bool
writeBlock := func(blocks *[][]byte) {
dbd, err := v0.DecodeBlockFromUncompressed(bytes.NewReader((*blocks)[0]))
assert.NoError(t, err)
stdBlockRlp, err := rlp.EncodeToBytes(dbd.ToStd())
assert.NoError(t, err)
ok, err = bm.Write(stdBlockRlp, false, encode.WithTxAddressGetter(encode.GetAddressFromR))
assert.NoError(t, err)
*blocks = (*blocks)[1:]
}
for i := 0; i < header.NbBatches(); i++ {
for j := 0; j < header.NbBlocksInBatch(i); j++ {
writeBlock(&blocksSerialized)
assert.True(t, ok)
}
bm.StartNewBatch()
}
assert.Empty(t, blocksSerialized)
util0 := 100 * bm.Len() / v0.MaxUsableBytes
require.NoError(t, err)
for ok { // all in one batch
writeBlock(&blocksSerializedNext)
}
util1 := 100 * bm.Len() / v0.MaxUsableBytes
fmt.Printf("%d%%\n%d%%\n", util0, util1)
}
func withNoError[X, Y any](t *testing.T, f func(X) (Y, error), x X) Y {
y, err := f(x)
require.NoError(t, err)
return y
}
func TestDecompressBlob(t *testing.T) {
store := dictionary.NewStore("../compressor_dict.bin")
files := newRecursiveFolderIterator(t, "testdata")
for files.hasNext() {
f := files.next()
if filepath.Ext(f.path) == ".bin" {
t.Run(f.path, func(t *testing.T) {
decompressed, err := blob.DecompressBlob(f.content, store)
assert.NoError(t, err)
t.Log("decompressed length", len(decompressed))
// load decompressed blob as blocks
var blocksSerialized [][]byte
assert.NoError(t, rlp.DecodeBytes(decompressed, &blocksSerialized))
t.Log("number of decoded blocks", len(blocksSerialized))
for _, blockSerialized := range blocksSerialized {
var b types.Block
assert.NoError(t, rlp.DecodeBytes(blockSerialized, &b))
}
})
}
}
}
type dirEntryWithFullPath struct {
path string
content os.DirEntry
}
// goes through all files in a directory and its subdirectories
type recursiveFolderIterator struct {
toVisit []dirEntryWithFullPath
t *testing.T
pathLen int
}
type file struct {
content []byte
path string
}
func (i *recursiveFolderIterator) openDir(path string) {
content, err := os.ReadDir(path)
require.NoError(i.t, err)
for _, c := range content {
i.toVisit = append(i.toVisit, dirEntryWithFullPath{path: filepath.Join(path, c.Name()), content: c})
}
}
func (i *recursiveFolderIterator) hasNext() bool {
return i.peek() != nil
}
func (i *recursiveFolderIterator) next() *file {
f := i.peek()
if f != nil {
i.toVisit = i.toVisit[:len(i.toVisit)-1]
}
return f
}
// counter-intuitively, peek does most of the work by ensuring the top of the stack is always a file
func (i *recursiveFolderIterator) peek() *file {
for len(i.toVisit) != 0 {
lastIndex := len(i.toVisit) - 1
c := i.toVisit[lastIndex]
if c.content.IsDir() {
i.toVisit = i.toVisit[:lastIndex]
i.openDir(c.path)
} else {
b, err := os.ReadFile(c.path)
require.NoError(i.t, err)
return &file{content: b, path: c.path[i.pathLen:]}
}
}
return nil
}
func newRecursiveFolderIterator(t *testing.T, path string) *recursiveFolderIterator {
res := recursiveFolderIterator{t: t, pathLen: len(path) + 1}
res.openDir(path)
return &res
}

View File

@@ -0,0 +1,79 @@
package dictionary
import (
"bytes"
"errors"
"github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/consensys/gnark-crypto/hash"
"github.com/consensys/linea-monorepo/prover/circuits/blobdecompression/v0/compress"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
"os"
)
// Checksum according to the given spec version
func Checksum(dict []byte, version uint16) ([]byte, error) {
switch version {
case 1:
return encode.MiMCChecksumPackedData(dict, 8)
case 0:
return compress.ChecksumPaddedBytes(dict, len(dict), hash.MIMC_BLS12_377.New(), fr.Bits), nil
}
return nil, errors.New("unsupported version")
}
type Store []map[string][]byte
func NewStore(paths ...string) Store {
res := make(Store, 2)
for i := range res {
res[i] = make(map[string][]byte)
}
if err := res.Load(paths...); err != nil {
panic(err)
}
return res
}
func SingletonStore(dict []byte, version uint16) (Store, error) {
key, err := Checksum(dict, version)
if err != nil {
return nil, err
}
s := make(Store, version+1)
s[version] = map[string][]byte{string(key): dict}
return s, nil
}
func (s Store) Load(paths ...string) error {
loadVsn := func(vsn uint16) error {
for _, path := range paths {
dict, err := os.ReadFile(path)
if err != nil {
return err
}
checksum, err := Checksum(dict, vsn)
if err != nil {
return err
}
key := string(checksum)
existing, exists := s[vsn][key]
if exists && !bytes.Equal(dict, existing) { // should be incredibly unlikely
return errors.New("mismatching dictionary found")
}
s[vsn][key] = dict
}
return nil
}
return errors.Join(loadVsn(0), loadVsn(1))
}
func (s Store) Get(checksum []byte, version uint16) ([]byte, error) {
if int(version) >= len(s) {
return nil, errors.New("unrecognized blob version")
}
res, ok := s[version][string(checksum)]
if !ok {
return nil, errors.New("dictionary not found")
}
return res, nil
}
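A short usage sketch of the new store, assuming a v1 dictionary on disk (the path is illustrative). Decompressors look dictionaries up by the checksum carried in the blob header, which is what Get simulates here:

package main

import (
    "fmt"
    "os"

    "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
)

func main() {
    dict, err := os.ReadFile("compressor_dict.bin") // illustrative path
    if err != nil {
        panic(err)
    }

    // A store holding just this dictionary, keyed by its v1 checksum.
    store, err := dictionary.SingletonStore(dict, 1)
    if err != nil {
        panic(err)
    }

    // Simulate the lookup a decompressor performs with a header checksum.
    checksum, err := dictionary.Checksum(dict, 1)
    if err != nil {
        panic(err)
    }
    if _, err := store.Get(checksum, 1); err != nil {
        panic(err)
    }
    fmt.Printf("dictionary %x registered\n", checksum[:4])
}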

View File

@@ -0,0 +1,337 @@
package encode
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/gnark-crypto/hash"
"github.com/consensys/linea-monorepo/prover/backend/ethereum"
typesLinea "github.com/consensys/linea-monorepo/prover/utils/types"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/icza/bitio"
"io"
"math/big"
)
// UnpackAlign unpacks r (packed with PackAlign) and returns the unpacked data.
func UnpackAlign(r []byte, packingSize int, noTerminalSymbol bool) ([]byte, error) {
bytesPerElem := (packingSize + 7) / 8
packingSizeLastU64 := uint8(packingSize % 64)
if packingSizeLastU64 == 0 {
packingSizeLastU64 = 64
}
n := len(r) / bytesPerElem
if n*bytesPerElem != len(r) {
return nil, fmt.Errorf("invalid data length; expected multiple of %d", bytesPerElem)
}
var out bytes.Buffer
w := bitio.NewWriter(&out)
for i := 0; i < n; i++ {
// read bytes
element := r[bytesPerElem*i : bytesPerElem*(i+1)]
// write bits
w.TryWriteBits(binary.BigEndian.Uint64(element[0:8]), packingSizeLastU64)
for j := 8; j < bytesPerElem; j += 8 {
w.TryWriteBits(binary.BigEndian.Uint64(element[j:j+8]), 64)
}
}
if w.TryError != nil {
return nil, fmt.Errorf("when writing to bitio.Writer: %w", w.TryError)
}
if err := w.Close(); err != nil {
return nil, fmt.Errorf("when closing bitio.Writer: %w", err)
}
if !noTerminalSymbol {
// the last nonzero byte should be 0xff
outLen := out.Len() - 1
for out.Bytes()[outLen] == 0 {
outLen--
}
if out.Bytes()[outLen] != 0xff {
return nil, errors.New("invalid terminal symbol")
}
out.Truncate(outLen)
}
return out.Bytes(), nil
}
type packAlignSettings struct {
dataNbBits int
lastByteNbUnusedBits uint8
noTerminalSymbol bool
additionalInput [][]byte
}
func (s *packAlignSettings) initialize(length int, options ...packAlignOption) {
for _, opt := range options {
opt(s)
}
nbBytes := length
for _, data := range s.additionalInput {
nbBytes += len(data)
}
if !s.noTerminalSymbol {
nbBytes++
}
s.dataNbBits = nbBytes*8 - int(s.lastByteNbUnusedBits)
}
type packAlignOption func(*packAlignSettings)
func NoTerminalSymbol() packAlignOption {
return func(o *packAlignSettings) {
o.noTerminalSymbol = true
}
}
// PackAlignSize returns the size of the data when packed with PackAlign.
func PackAlignSize(length0, packingSize int, options ...packAlignOption) (n int) {
var s packAlignSettings
s.initialize(length0, options...)
// we may need to add some bits to a and b to ensure we can process some blocks of 248 bits
extraBits := (packingSize - s.dataNbBits%packingSize) % packingSize
nbBits := s.dataNbBits + extraBits
return (nbBits / packingSize) * ((packingSize + 7) / 8)
}
func WithAdditionalInput(data ...[]byte) packAlignOption {
return func(o *packAlignSettings) {
o.additionalInput = append(o.additionalInput, data...)
}
}
func WithLastByteNbUnusedBits(n uint8) packAlignOption {
if n > 7 {
panic("only 8 bits to a byte")
}
return func(o *packAlignSettings) {
o.lastByteNbUnusedBits = n
}
}
// PackAlign writes a (and any additional inputs) to w, packed and aligned to a
// packingSize-bit boundary. It returns the number of bytes written to w.
func PackAlign(w io.Writer, a []byte, packingSize int, options ...packAlignOption) (n int64, err error) {
var s packAlignSettings
s.initialize(len(a), options...)
if !s.noTerminalSymbol && s.lastByteNbUnusedBits != 0 {
return 0, errors.New("terminal symbols with byte aligned input not yet supported")
}
// we may need to add some bits to a and b to ensure we can process some blocks of packingSize bits
nbBits := (s.dataNbBits + (packingSize - 1)) / packingSize * packingSize
extraBits := nbBits - s.dataNbBits
// padding will always be less than bytesPerElem bytes
bytesPerElem := (packingSize + 7) / 8
packingSizeLastU64 := uint8(packingSize % 64)
if packingSizeLastU64 == 0 {
packingSizeLastU64 = 64
}
bytePadding := (extraBits + 7) / 8
buf := make([]byte, bytesPerElem, bytesPerElem+1)
// the last nonzero byte is 0xff
if !s.noTerminalSymbol {
buf = append(buf, 0)
buf[0] = 0xff
}
inReaders := make([]io.Reader, 2+len(s.additionalInput))
inReaders[0] = bytes.NewReader(a)
for i, data := range s.additionalInput {
inReaders[i+1] = bytes.NewReader(data)
}
inReaders[len(inReaders)-1] = bytes.NewReader(buf[:bytePadding+1])
r := bitio.NewReader(io.MultiReader(inReaders...))
var tryWriteErr error
tryWrite := func(v uint64) {
if tryWriteErr == nil {
tryWriteErr = binary.Write(w, binary.BigEndian, v)
}
}
for i := 0; i < nbBits/packingSize; i++ {
tryWrite(r.TryReadBits(packingSizeLastU64))
for j := int(packingSizeLastU64); j < packingSize; j += 64 {
tryWrite(r.TryReadBits(64))
}
}
if tryWriteErr != nil {
return 0, fmt.Errorf("when writing to w: %w", tryWriteErr)
}
if r.TryError != nil {
return 0, fmt.Errorf("when reading from multi-reader: %w", r.TryError)
}
n1 := (nbBits / packingSize) * bytesPerElem
if n1 != PackAlignSize(len(a), packingSize, options...) {
return 0, errors.New("inconsistent PackAlignSize")
}
return int64(n1), nil
}
// MiMCChecksumPackedData re-packs the data tightly into bls12-377 elements and computes the MiMC checksum.
// It only supports packing without a terminal symbol; input with a terminal symbol is interpreted at its full padded length.
func MiMCChecksumPackedData(data []byte, inputPackingSize int, hashPackingOptions ...packAlignOption) ([]byte, error) {
dataNbBits := len(data) * 8
if inputPackingSize%8 != 0 {
inputBytesPerElem := (inputPackingSize + 7) / 8
dataNbBits = dataNbBits / inputBytesPerElem * inputPackingSize
var err error
if data, err = UnpackAlign(data, inputPackingSize, true); err != nil {
return nil, err
}
}
lastByteNbUnusedBits := 8 - dataNbBits%8
if lastByteNbUnusedBits == 8 {
lastByteNbUnusedBits = 0
}
var bb bytes.Buffer
packingOptions := make([]packAlignOption, len(hashPackingOptions)+1)
copy(packingOptions, hashPackingOptions)
packingOptions[len(packingOptions)-1] = WithLastByteNbUnusedBits(uint8(lastByteNbUnusedBits))
if _, err := PackAlign(&bb, data, fr.Bits-1, packingOptions...); err != nil {
return nil, err
}
hsh := hash.MIMC_BLS12_377.New()
hsh.Write(bb.Bytes())
return hsh.Sum(nil), nil
}
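To make the packing contract concrete, a round-trip sketch (the payload is illustrative): PackAlign pads to a whole number of packingSize-bit chunks and appends a 0xff terminal symbol ahead of the zero padding, which UnpackAlign strips off again.

package main

import (
    "bytes"
    "fmt"

    fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
    "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
)

func main() {
    payload := []byte("arbitrary compressed payload") // illustrative

    var packed bytes.Buffer
    n, err := encode.PackAlign(&packed, payload, fr381.Bits-1)
    if err != nil {
        panic(err)
    }
    // PackAlignSize predicts the packed length without writing anything.
    if int(n) != encode.PackAlignSize(len(payload), fr381.Bits-1) {
        panic("inconsistent PackAlignSize")
    }

    unpacked, err := encode.UnpackAlign(packed.Bytes(), fr381.Bits-1, false)
    if err != nil {
        panic(err)
    }
    fmt.Println(bytes.Equal(unpacked, payload)) // true
}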
// DecodedBlockData is a wrapper struct storing the different fields of a block
// that we deserialize when decoding an ethereum block.
type DecodedBlockData struct {
// BlockHash stores the decoded block hash
BlockHash common.Hash
// Timestamp holds the Unix timestamp of the block
Timestamp uint64
// Froms stores the sender address of every transaction
Froms []common.Address
// Txs stores the list of the decoded transactions.
Txs []types.TxData
}
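// InjectFromAddressIntoR stashes the sender address in the signature.R field
// of a copy of txData, with S set to a placeholder value of 1. The resulting
// signature is deliberately invalid; GetAddressFromR is its inverse.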
func InjectFromAddressIntoR(txData types.TxData, from *common.Address) *types.Transaction {
switch txData := txData.(type) {
case *types.DynamicFeeTx:
tx := *txData
tx.R = new(big.Int)
tx.R.SetBytes(from[:])
tx.S = big.NewInt(1)
return types.NewTx(&tx)
case *types.AccessListTx:
tx := *txData
tx.R = new(big.Int)
tx.R.SetBytes(from[:])
tx.S = big.NewInt(1)
return types.NewTx(&tx)
case *types.LegacyTx:
tx := *txData
tx.R = new(big.Int)
tx.R.SetBytes(from[:])
tx.S = big.NewInt(1)
return types.NewTx(&tx)
default:
panic("unexpected transaction type")
}
}
// ToStd converts the decoded block data into a standard
// block object capable of being encoded in a way consumable
// by existing decoders. The process involves some abuse,
// whereby 1) the "from" address of a transaction is put in the
// signature.R field, though the signature as a whole is invalid.
// 2) the block hash is stored in the ParentHash field in the block
// header.
func (d *DecodedBlockData) ToStd() *types.Block {
header := types.Header{
ParentHash: d.BlockHash,
Time: d.Timestamp,
}
body := types.Body{
Transactions: make([]*types.Transaction, len(d.Txs)),
}
for i := range d.Txs {
body.Transactions[i] = InjectFromAddressIntoR(d.Txs[i], &d.Froms[i])
}
return types.NewBlock(&header, &body, nil, emptyTrieHasher{})
}
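// GetAddressFromR recovers a sender address previously stashed in the
// signature.R field by InjectFromAddressIntoR.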
func GetAddressFromR(tx *types.Transaction) typesLinea.EthAddress {
_, r, _ := tx.RawSignatureValues()
var res typesLinea.EthAddress
r.FillBytes(res[:])
return res
}
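The R-field trick round-trips; a quick sketch with a made-up transaction and sender:

package main

import (
    "fmt"
    "math/big"

    "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
    "github.com/ethereum/go-ethereum/common"
    "github.com/ethereum/go-ethereum/core/types"
)

func main() {
    from := common.HexToAddress("0x000000000000000000000000000000000000dEaD") // illustrative
    inner := &types.LegacyTx{Nonce: 1, Gas: 21000, GasPrice: big.NewInt(7), Value: big.NewInt(0)}

    // Stash the sender in signature.R (S becomes 1).
    tx := encode.InjectFromAddressIntoR(inner, &from)

    recovered := encode.GetAddressFromR(tx)
    fmt.Printf("%x\n", recovered) // the bytes of `from`
}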
// TODO delete if unused
type fixedTrieHasher common.Hash
func (e fixedTrieHasher) Reset() {
}
func (e fixedTrieHasher) Update(_, _ []byte) error {
return nil
}
func (e fixedTrieHasher) Hash() common.Hash {
return common.Hash(e)
}
type emptyTrieHasher struct{}
func (h emptyTrieHasher) Reset() {
}
func (h emptyTrieHasher) Update(_, _ []byte) error {
return nil
}
func (h emptyTrieHasher) Hash() common.Hash {
return common.Hash{}
}
type TxAddressGetter func(*types.Transaction) typesLinea.EthAddress
type Config struct {
GetAddress TxAddressGetter
}
func NewConfig() Config {
return Config{
GetAddress: ethereum.GetFrom,
}
}
type Option func(*Config)
func WithTxAddressGetter(g TxAddressGetter) Option {
return func(cfg *Config) {
cfg.GetAddress = g
}
}
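This Option plumbing is what lets re-encoding read the sender from wherever it was stashed; the bm.Write call in TestAddToBlob earlier in this commit passes WithTxAddressGetter(GetAddressFromR) for exactly that reason. A minimal sketch of applying the option:

package main

import (
    "fmt"

    "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
)

func main() {
    // Default: senders are recovered from real signatures (ethereum.GetFrom).
    cfg := encode.NewConfig()

    // Swap in the R-field getter, as EncodeTxForCompression does when it
    // receives this option.
    encode.WithTxAddressGetter(encode.GetAddressFromR)(&cfg)

    fmt.Println(cfg.GetAddress != nil) // cfg.GetAddress is now GetAddressFromR
}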

View File

@@ -0,0 +1,22 @@
package test_utils
import (
"github.com/consensys/linea-monorepo/prover/backend/ethereum"
typesLinea "github.com/consensys/linea-monorepo/prover/utils/types"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/stretchr/testify/assert"
"testing"
)
// CheckSameTx checks if the most essential fields in two transactions are equal
// TODO cover type-specific fields
func CheckSameTx(t *testing.T, orig, decoded *types.Transaction, decodedFrom common.Address) {
assert.Equal(t, orig.Type(), decoded.Type())
assert.Equal(t, orig.To(), decoded.To())
assert.Equal(t, orig.Nonce(), decoded.Nonce())
assert.Equal(t, orig.Data(), decoded.Data())
assert.Equal(t, orig.Value(), decoded.Value())
assert.Equal(t, orig.Cost(), decoded.Cost())
assert.Equal(t, ethereum.GetFrom(orig), typesLinea.EthAddress(decodedFrom))
}

View File

@@ -5,6 +5,8 @@ import (
"encoding/binary"
"errors"
"fmt"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
"io"
"os"
"strings"
@@ -39,6 +41,7 @@ type BlobMaker struct {
limit int // maximum size of the compressed data
compressor *lzss.Compressor // compressor used to compress the blob body
dict []byte // dictionary used for compression
dictStore dictionary.Store
header Header
@@ -67,6 +70,10 @@ func NewBlobMaker(dataLimit int, dictPath string) (*BlobMaker, error) {
}
dict = lzss.AugmentDict(dict)
blobMaker.dict = dict
blobMaker.dictStore, err = dictionary.SingletonStore(dict, 0)
if err != nil {
return nil, err
}
dictChecksum := compress.ChecksumPaddedBytes(dict, len(dict), hash.MIMC_BLS12_377.New(), fr.Bits)
copy(blobMaker.header.DictChecksum[:], dictChecksum)
@@ -119,7 +126,7 @@ func (bm *BlobMaker) Written() int {
func (bm *BlobMaker) Bytes() []byte {
if bm.currentBlobLength > 0 {
// sanity check that we can always decompress.
header, rawBlocks, _, err := DecompressBlob(bm.currentBlob[:bm.currentBlobLength], bm.dict)
header, rawBlocks, _, err := DecompressBlob(bm.currentBlob[:bm.currentBlobLength], bm.dictStore)
if err != nil {
var sbb strings.Builder
fmt.Fprintf(&sbb, "invalid blob: %v\n", err)
@@ -130,8 +137,8 @@ func (bm *BlobMaker) Bytes() []byte {
panic(sbb.String())
}
// compare the header
if !header.Equals(&bm.header) {
panic("invalid blob: header mismatch")
if err = header.CheckEquality(&bm.header); err != nil {
panic(fmt.Errorf("invalid blob: header mismatch: %v", err))
}
rawBlocksUnpacked, err := UnpackAlign(rawBlocks)
if err != nil {
@@ -146,7 +153,7 @@ func (bm *BlobMaker) Bytes() []byte {
// Write attempts to append the RLP block to the current batch.
// If forceReset is set, this will NOT append the bytes but will still return true if the chunk could have been appended
func (bm *BlobMaker) Write(rlpBlock []byte, forceReset bool) (ok bool, err error) {
func (bm *BlobMaker) Write(rlpBlock []byte, forceReset bool, encodingOptions ...encode.Option) (ok bool, err error) {
// decode the RLP block.
var block types.Block
@@ -156,7 +163,7 @@ func (bm *BlobMaker) Write(rlpBlock []byte, forceReset bool) (ok bool, err error
// re-encode it for compression
bm.buf.Reset()
if err := EncodeBlockForCompression(&block, &bm.buf); err != nil {
if err := EncodeBlockForCompression(&block, &bm.buf, encodingOptions...); err != nil {
return false, fmt.Errorf("when re-encoding block for compression: %w", err)
}
blockLen := bm.buf.Len()
@@ -281,7 +288,8 @@ func (bm *BlobMaker) Equals(other *BlobMaker) bool {
}
// DecompressBlob decompresses a blob and returns the header and the blocks as they were compressed.
func DecompressBlob(b, dict []byte) (blobHeader *Header, rawBlocks []byte, blocks [][]byte, err error) {
// rawBlocks is the raw payload of the blob, delivered in packed format. TODO: this is a bad idea; fix
func DecompressBlob(b []byte, dictStore dictionary.Store) (blobHeader *Header, rawBlocks []byte, blocks [][]byte, err error) {
// UnpackAlign the blob
b, err = UnpackAlign(b)
if err != nil {
@@ -295,11 +303,10 @@ func DecompressBlob(b, dict []byte) (blobHeader *Header, rawBlocks []byte, block
return nil, nil, nil, fmt.Errorf("failed to read blob header: %w", err)
}
// ensure the dict hash matches
{
if !bytes.Equal(compress.ChecksumPaddedBytes(dict, len(dict), hash.MIMC_BLS12_377.New(), fr.Bits), blobHeader.DictChecksum[:]) {
return nil, nil, nil, errors.New("invalid dict hash")
}
// retrieve dict
dict, err := dictStore.Get(blobHeader.DictChecksum[:], 0)
if err != nil {
return nil, nil, nil, err
}
b = b[read:]
@@ -438,7 +445,8 @@ func UnpackAlign(r []byte) ([]byte, error) {
cpt++
}
// last byte should be equal to cpt modulo 31, plus one
if cpt != int(out.Bytes()[out.Len()-1])-1 {
lastNonZero := out.Bytes()[out.Len()-1]
if (cpt % 31) != int(lastNonZero)-1 {
return nil, errors.New("invalid padding length")
}
out.Truncate(out.Len() - 1)

View File

@@ -8,14 +8,16 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
encodeTesting "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode/test_utils"
"github.com/consensys/linea-monorepo/prover/utils"
"io"
"math/big"
"math/rand"
"os"
"slices"
"testing"
"github.com/consensys/linea-monorepo/prover/utils"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v0/compress/lzss"
"github.com/consensys/linea-monorepo/prover/backend/ethereum"
@@ -603,7 +605,11 @@ func decompressBlob(b []byte) ([][][]byte, error) {
if err != nil {
return nil, fmt.Errorf("can't read dict: %w", err)
}
header, _, blocks, err := DecompressBlob(b, dict)
dictStore, err := dictionary.SingletonStore(dict, 0)
if err != nil {
return nil, err
}
header, _, blocks, err := DecompressBlob(b, dictStore)
if err != nil {
return nil, fmt.Errorf("can't decompress blob: %w", err)
}
@@ -744,3 +750,18 @@ func TestPack(t *testing.T) {
assert.Equal(s2, original[n1:], "slices should match")
}
}
func TestEncode(t *testing.T) {
var block types.Block
assert.NoError(t, rlp.DecodeBytes(testBlocks[0], &block))
tx := block.Transactions()[0]
var bb bytes.Buffer
assert.NoError(t, EncodeTxForCompression(tx, &bb))
var from common.Address
txBackData, err := DecodeTxFromUncompressed(bytes.NewReader(slices.Clone(bb.Bytes())), &from)
assert.NoError(t, err)
txBack := types.NewTx(txBackData)
encodeTesting.CheckSameTx(t, tx, txBack, from)
}

View File

@@ -1,23 +1,27 @@
package v0
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/consensys/linea-monorepo/prover/backend/ethereum"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/rlp"
)
// EncodeBlockForCompression encodes a block for compression.
func EncodeBlockForCompression(block *types.Block, w io.Writer) error {
func EncodeBlockForCompression(block *types.Block, w io.Writer, encodingOptions ...encode.Option) error {
if err := binary.Write(w, binary.LittleEndian, block.Time()); err != nil {
return err
}
for _, tx := range block.Transactions() {
if err := EncodeTxForCompression(tx, w); err != nil {
if err := EncodeTxForCompression(tx, w, encodingOptions...); err != nil {
return err
}
}
@@ -26,7 +30,11 @@ func EncodeBlockForCompression(block *types.Block, w io.Writer) error {
// EncodeTxForCompression encodes a transaction for compression.
// this code is from zk-evm-monorepo/prover/... but doesn't include the chainID
func EncodeTxForCompression(tx *types.Transaction, w io.Writer) error {
func EncodeTxForCompression(tx *types.Transaction, w io.Writer, encodingOptions ...encode.Option) error {
cfg := encode.NewConfig()
for _, o := range encodingOptions {
o(&cfg)
}
switch {
// LONDON with dynamic fees
case tx.Type() == types.DynamicFeeTxType:
@@ -39,7 +47,7 @@ func EncodeTxForCompression(tx *types.Transaction, w io.Writer) error {
tx.GasTipCap(),
tx.GasFeeCap(),
tx.Gas(),
ethereum.GetFrom(tx),
cfg.GetAddress(tx),
tx.To(),
tx.Value(),
tx.Data(),
@@ -57,7 +65,7 @@ func EncodeTxForCompression(tx *types.Transaction, w io.Writer) error {
tx.Nonce(),
tx.GasPrice(),
tx.Gas(),
ethereum.GetFrom(tx),
cfg.GetAddress(tx),
tx.To(),
tx.Value(),
tx.Data(),
@@ -71,7 +79,7 @@ func EncodeTxForCompression(tx *types.Transaction, w io.Writer) error {
tx.Nonce(),
tx.GasPrice(),
tx.Gas(),
ethereum.GetFrom(tx),
cfg.GetAddress(tx),
tx.To(),
tx.Value(),
tx.Data(),
@@ -85,7 +93,7 @@ func EncodeTxForCompression(tx *types.Transaction, w io.Writer) error {
tx.Nonce(),
tx.GasPrice(),
tx.Gas(),
ethereum.GetFrom(tx),
cfg.GetAddress(tx),
tx.To(),
tx.Value(),
tx.Data(),
@@ -98,3 +106,140 @@ func EncodeTxForCompression(tx *types.Transaction, w io.Writer) error {
return nil
}
// DecodeBlockFromUncompressed inverts [EncodeBlockForCompression]. It is primarily meant for
// testing and ensuring the encoding is bijective.
func DecodeBlockFromUncompressed(r *bytes.Reader) (encode.DecodedBlockData, error) {
var decTimestamp uint64
if err := binary.Read(r, binary.LittleEndian, &decTimestamp); err != nil {
return encode.DecodedBlockData{}, fmt.Errorf("could not decode timestamp: %w", err)
}
decodedBlk := encode.DecodedBlockData{
Timestamp: decTimestamp,
}
for r.Len() != 0 {
var (
from common.Address
)
if tx, err := DecodeTxFromUncompressed(r, &from); err != nil {
return encode.DecodedBlockData{}, fmt.Errorf("could not decode transaction #%v: %w", len(decodedBlk.Txs), err)
} else {
decodedBlk.Txs = append(decodedBlk.Txs, tx)
decodedBlk.Froms = append(decodedBlk.Froms, from)
}
}
return decodedBlk, nil
}
func ReadTxAsRlp(r *bytes.Reader) (fields []any, _type uint8, err error) {
firstByte, err := r.ReadByte()
if err != nil {
err = fmt.Errorf("could not read the first byte: %w", err)
return
}
// According to the RLP rule, `0xc0 + x` or `0xf7` indicates that the current
// item is a list and this is what's used to identify that the transaction is
// a legacy transaction or an EIP-155 transaction.
//
// Note that 0xc0 would indicate an empty list and thus be an invalid tx.
if firstByte == types.AccessListTxType || firstByte == types.DynamicFeeTxType {
_type = firstByte
} else {
if firstByte > 0xc0 {
// Set the byte-reader backward so that we can apply the rlp-decoder
// over it.
if err = r.UnreadByte(); err != nil {
return
}
_type = 0
} else {
err = fmt.Errorf("unexpected first byte: %x", firstByte)
return
}
}
err = rlp.Decode(r, &fields)
return
}
// DecodeTxFromUncompressed puts all the transaction data into the output, except for the from address,
// which will be put where the argument "from" is referencing
func DecodeTxFromUncompressed(r *bytes.Reader, from *common.Address) (types.TxData, error) {
fields, _type, err := ReadTxAsRlp(r)
if err != nil {
return nil, err
}
decoders := [3]func([]any, *common.Address) (types.TxData, error){
decodeLegacyTx,
decodeAccessListTx,
decodeDynamicFeeTx,
}
return decoders[_type](fields, from)
}
func decodeLegacyTx(fields []any, from *common.Address) (types.TxData, error) {
if len(fields) != 7 {
return nil, fmt.Errorf("unexpected number of fields")
}
tx := new(types.LegacyTx)
err := errors.Join(
ethereum.TryCast(&tx.Nonce, fields[0], "nonce"),
ethereum.TryCast(&tx.GasPrice, fields[1], "gas-price"),
ethereum.TryCast(&tx.Gas, fields[2], "gas"),
ethereum.TryCast(from, fields[3], "from"),
ethereum.TryCast(&tx.To, fields[4], "to"),
ethereum.TryCast(&tx.Value, fields[5], "value"),
ethereum.TryCast(&tx.Data, fields[6], "data"),
)
return tx, err
}
func decodeAccessListTx(fields []any, from *common.Address) (types.TxData, error) {
if len(fields) != 8 {
return nil, fmt.Errorf("invalid number of fields for an access list transaction")
}
tx := new(types.AccessListTx)
err := errors.Join(
ethereum.TryCast(&tx.Nonce, fields[0], "nonce"),
ethereum.TryCast(&tx.GasPrice, fields[1], "gas-price"),
ethereum.TryCast(&tx.Gas, fields[2], "gas"),
ethereum.TryCast(from, fields[3], "from"),
ethereum.TryCast(&tx.To, fields[4], "to"),
ethereum.TryCast(&tx.Value, fields[5], "value"),
ethereum.TryCast(&tx.Data, fields[6], "data"),
ethereum.TryCast(&tx.AccessList, fields[7], "access-list"),
)
return tx, err
}
func decodeDynamicFeeTx(fields []any, from *common.Address) (types.TxData, error) {
if len(fields) != 9 {
return nil, fmt.Errorf("invalid number of fields for a dynamic fee transaction")
}
tx := new(types.DynamicFeeTx)
err := errors.Join(
ethereum.TryCast(&tx.Nonce, fields[0], "nonce"),
ethereum.TryCast(&tx.GasTipCap, fields[1], "gas-tip-cap"),
ethereum.TryCast(&tx.GasFeeCap, fields[2], "gas-fee-cap"),
ethereum.TryCast(&tx.Gas, fields[3], "gas"),
ethereum.TryCast(from, fields[4], "from"),
ethereum.TryCast(&tx.To, fields[5], "to"),
ethereum.TryCast(&tx.Value, fields[6], "value"),
ethereum.TryCast(&tx.Data, fields[7], "data"),
ethereum.TryCast(&tx.AccessList, fields[8], "access-list"),
)
return tx, err
}

View File

@@ -2,6 +2,7 @@ package v0
import (
"encoding/binary"
"errors"
"fmt"
"io"
@@ -22,11 +23,17 @@ type Header struct {
}
func (s *Header) Equals(other *Header) bool {
return s.CheckEquality(other) == nil
}
// CheckEquality is similar to Equals but returns a description of the mismatch,
// or nil if the objects are equal
func (s *Header) CheckEquality(other *Header) error {
if other == nil {
return false
return errors.New("empty header")
}
if s.DictChecksum != other.DictChecksum {
return false
return errors.New("dictionary mismatch")
}
// we ignore batches of len(0), since caller could have
@@ -36,25 +43,27 @@ func (s *Header) Equals(other *Header) bool {
small, large = other, s
}
absJ := 0
for i := range small.table {
if len(small.table[i]) != len(large.table[i]) {
return false
return fmt.Errorf("batch size mismatch at #%d", i)
}
for j := range small.table[i] {
if small.table[i][j] != large.table[i][j] {
return false
return fmt.Errorf("block size mismatch at block #%d of batch #%d, #%d total", j, i, absJ+j)
}
}
absJ += len(small.table[i])
}
// remaining batches of large should be empty
for i := len(small.table); i < len(large.table); i++ {
if len(large.table[i]) != 0 {
return false
return errors.New("batch count mismatch")
}
}
return true
return nil
}
func (s *Header) NbBatches() int {

View File

@@ -2,10 +2,10 @@ package v1
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
"os"
"slices"
"strings"
@@ -13,10 +13,6 @@ import (
fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/sirupsen/logrus"
fr377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/gnark-crypto/hash"
"github.com/icza/bitio"
"github.com/consensys/compress/lzss"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/rlp"
@@ -40,6 +36,7 @@ type BlobMaker struct {
Limit int // maximum size of the compressed data
compressor *lzss.Compressor // compressor used to compress the blob body
dict []byte // dictionary used for compression
dictStore dictionary.Store // dictionary store comprising only dict, used for decompression sanity checks
Header Header
@@ -68,8 +65,11 @@ func NewBlobMaker(dataLimit int, dictPath string) (*BlobMaker, error) {
}
dict = lzss.AugmentDict(dict)
blobMaker.dict = dict
if blobMaker.dictStore, err = dictionary.SingletonStore(dict, 1); err != nil {
return nil, err
}
dictChecksum, err := MiMCChecksumPackedData(dict, 8)
dictChecksum, err := encode.MiMCChecksumPackedData(dict, 8)
if err != nil {
return nil, err
}
@@ -116,7 +116,7 @@ func (bm *BlobMaker) Written() int {
func (bm *BlobMaker) Bytes() []byte {
if bm.currentBlobLength > 0 {
// sanity check that we can always decompress.
header, rawBlocks, _, err := DecompressBlob(bm.currentBlob[:bm.currentBlobLength], bm.dict)
header, rawBlocks, _, err := DecompressBlob(bm.currentBlob[:bm.currentBlobLength], bm.dictStore)
if err != nil {
var sbb strings.Builder
fmt.Fprintf(&sbb, "invalid blob: %v\n", err)
@@ -191,13 +191,13 @@ func (bm *BlobMaker) Write(rlpBlock []byte, forceReset bool) (ok bool, err error
}
// check that the header + the compressed data fits in the blob
fitsInBlob := PackAlignSize(bm.buf.Len()+bm.compressor.Len(), fr381.Bits-1) <= bm.Limit
fitsInBlob := encode.PackAlignSize(bm.buf.Len()+bm.compressor.Len(), fr381.Bits-1) <= bm.Limit
if !fitsInBlob {
// first thing to check is if we bypass compression, would that fit?
if bm.compressor.ConsiderBypassing() {
// we can bypass compression and get a better ratio.
// let's check if now we fit in the blob.
if PackAlignSize(bm.buf.Len()+bm.compressor.Len(), fr381.Bits-1) <= bm.Limit {
if encode.PackAlignSize(bm.buf.Len()+bm.compressor.Len(), fr381.Bits-1) <= bm.Limit {
goto bypass
}
}
@@ -221,7 +221,7 @@ bypass:
// copy the compressed data to the blob
bm.packBuffer.Reset()
n2, err := PackAlign(&bm.packBuffer, bm.buf.Bytes(), fr381.Bits-1, WithAdditionalInput(bm.compressor.Bytes()))
n2, err := encode.PackAlign(&bm.packBuffer, bm.buf.Bytes(), fr381.Bits-1, encode.WithAdditionalInput(bm.compressor.Bytes()))
if err != nil {
bm.compressor.Revert()
bm.Header.removeLastBlock()
@@ -264,9 +264,9 @@ func (bm *BlobMaker) Equals(other *BlobMaker) bool {
}
// DecompressBlob decompresses a blob and returns the header, the raw payload and the blocks as they were compressed.
func DecompressBlob(b, dict []byte) (blobHeader *Header, rawPayload []byte, blocks [][]byte, err error) {
func DecompressBlob(b []byte, dictStore dictionary.Store) (blobHeader *Header, rawPayload []byte, blocks [][]byte, err error) {
// UnpackAlign the blob
b, err = UnpackAlign(b, fr381.Bits-1, false)
b, err = encode.UnpackAlign(b, fr381.Bits-1, false)
if err != nil {
return nil, nil, nil, err
}
@@ -277,15 +277,10 @@ func DecompressBlob(b, dict []byte) (blobHeader *Header, rawPayload []byte, bloc
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to read blob header: %w", err)
}
// ensure the dict hash matches
{
expectedDictChecksum, err := MiMCChecksumPackedData(dict, 8)
if err != nil {
return nil, nil, nil, err
}
if !bytes.Equal(expectedDictChecksum, blobHeader.DictChecksum[:]) {
return nil, nil, nil, errors.New("invalid dict hash")
}
// retrieve dict
dict, err := dictStore.Get(blobHeader.DictChecksum[:], 1)
if err != nil {
return nil, nil, nil, err
}
b = b[read:]
@@ -317,210 +312,6 @@ func DecompressBlob(b, dict []byte) (blobHeader *Header, rawPayload []byte, bloc
return blobHeader, rawPayload, blocks, nil
}
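
Callers no longer pass raw dictionary bytes: they build a dictionary.Store keyed by checksum, and DecompressBlob looks the dictionary up from the blob header. A minimal sketch under the import paths shown in this diff; the file names are hypothetical, and the version argument 1 matches the tests below:

    package main

    import (
        "fmt"
        "os"

        "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
        v1 "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1"
    )

    func main() {
        dict, err := os.ReadFile("compressor_dict.bin") // hypothetical dictionary path
        if err != nil {
            panic(err)
        }
        store, err := dictionary.SingletonStore(dict, 1) // a store holding just this dict, version 1
        if err != nil {
            panic(err)
        }
        blob, err := os.ReadFile("blob.bin") // hypothetical blob path
        if err != nil {
            panic(err)
        }
        header, _, blocks, err := v1.DecompressBlob(blob, store)
        if err != nil {
            panic(err)
        }
        fmt.Println(header.NbBatches(), "batches,", len(blocks), "blocks")
    }
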
// PackAlignSize returns the size of the data when packed with PackAlign.
func PackAlignSize(length0, packingSize int, options ...packAlignOption) (n int) {
var s packAlignSettings
s.initialize(length0, options...)
// we may need to add some padding bits to ensure the data splits into whole blocks of packingSize bits
extraBits := (packingSize - s.dataNbBits%packingSize) % packingSize
nbBits := s.dataNbBits + extraBits
return (nbBits / packingSize) * ((packingSize + 7) / 8)
}
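
As a worked example, take packingSize = fr381.Bits-1 = 254 (so bytesPerElem = 32) and length0 = 100 with default options: the terminal symbol brings the bit count to dataNbBits = 101*8 = 808, extraBits = (254 - 808 mod 254) mod 254 = 208, and nbBits = 1016 = 4*254, so PackAlignSize returns 4*32 = 128 bytes, i.e. four field elements.
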
type packAlignSettings struct {
dataNbBits int
lastByteNbUnusedBits uint8
noTerminalSymbol bool
additionalInput [][]byte
}
type packAlignOption func(*packAlignSettings)
func NoTerminalSymbol() packAlignOption {
return func(o *packAlignSettings) {
o.noTerminalSymbol = true
}
}
func WithAdditionalInput(data ...[]byte) packAlignOption {
return func(o *packAlignSettings) {
o.additionalInput = append(o.additionalInput, data...)
}
}
func WithLastByteNbUnusedBits(n uint8) packAlignOption {
if n > 7 {
panic("only 8 bits to a byte")
}
return func(o *packAlignSettings) {
o.lastByteNbUnusedBits = n
}
}
func (s *packAlignSettings) initialize(length int, options ...packAlignOption) {
for _, opt := range options {
opt(s)
}
nbBytes := length
for _, data := range s.additionalInput {
nbBytes += len(data)
}
if !s.noTerminalSymbol {
nbBytes++
}
s.dataNbBits = nbBytes*8 - int(s.lastByteNbUnusedBits)
}
// PackAlign writes a, along with any additional input, to w, packing the data
// into packingSize-bit chunks so that each chunk fits in a field element.
// It returns the length of the data written to w.
func PackAlign(w io.Writer, a []byte, packingSize int, options ...packAlignOption) (n int64, err error) {
var s packAlignSettings
s.initialize(len(a), options...)
if !s.noTerminalSymbol && s.lastByteNbUnusedBits != 0 {
return 0, errors.New("terminal symbols with byte aligned input not yet supported")
}
// we may need to add some bits to a and b to ensure we can process some blocks of packingSize bits
nbBits := (s.dataNbBits + (packingSize - 1)) / packingSize * packingSize
extraBits := nbBits - s.dataNbBits
// padding will always be less than bytesPerElem bytes
bytesPerElem := (packingSize + 7) / 8
packingSizeLastU64 := uint8(packingSize % 64)
if packingSizeLastU64 == 0 {
packingSizeLastU64 = 64
}
bytePadding := (extraBits + 7) / 8
buf := make([]byte, bytesPerElem, bytesPerElem+1)
// the last nonzero byte is 0xff
if !s.noTerminalSymbol {
buf = append(buf, 0)
buf[0] = 0xff
}
inReaders := make([]io.Reader, 2+len(s.additionalInput))
inReaders[0] = bytes.NewReader(a)
for i, data := range s.additionalInput {
inReaders[i+1] = bytes.NewReader(data)
}
inReaders[len(inReaders)-1] = bytes.NewReader(buf[:bytePadding+1])
r := bitio.NewReader(io.MultiReader(inReaders...))
var tryWriteErr error
tryWrite := func(v uint64) {
if tryWriteErr == nil {
tryWriteErr = binary.Write(w, binary.BigEndian, v)
}
}
for i := 0; i < nbBits/packingSize; i++ {
tryWrite(r.TryReadBits(packingSizeLastU64))
for j := int(packingSizeLastU64); j < packingSize; j += 64 {
tryWrite(r.TryReadBits(64))
}
}
if tryWriteErr != nil {
return 0, fmt.Errorf("when writing to w: %w", tryWriteErr)
}
if r.TryError != nil {
return 0, fmt.Errorf("when reading from multi-reader: %w", r.TryError)
}
n1 := (nbBits / packingSize) * bytesPerElem
if n1 != PackAlignSize(len(a), packingSize, options...) {
return 0, errors.New("inconsistent PackAlignSize")
}
return int64(n1), nil
}
// UnpackAlign unpacks r (packed with PackAlign) and returns the unpacked data.
func UnpackAlign(r []byte, packingSize int, noTerminalSymbol bool) ([]byte, error) {
bytesPerElem := (packingSize + 7) / 8
packingSizeLastU64 := uint8(packingSize % 64)
if packingSizeLastU64 == 0 {
packingSizeLastU64 = 64
}
n := len(r) / bytesPerElem
if n*bytesPerElem != len(r) {
return nil, fmt.Errorf("invalid data length; expected multiple of %d", bytesPerElem)
}
var out bytes.Buffer
w := bitio.NewWriter(&out)
for i := 0; i < n; i++ {
// read bytes
element := r[bytesPerElem*i : bytesPerElem*(i+1)]
// write bits
w.TryWriteBits(binary.BigEndian.Uint64(element[0:8]), packingSizeLastU64)
for j := 8; j < bytesPerElem; j += 8 {
w.TryWriteBits(binary.BigEndian.Uint64(element[j:j+8]), 64)
}
}
if w.TryError != nil {
return nil, fmt.Errorf("when writing to bitio.Writer: %w", w.TryError)
}
if err := w.Close(); err != nil {
return nil, fmt.Errorf("when closing bitio.Writer: %w", err)
}
if !noTerminalSymbol {
// the last nonzero byte should be 0xff
outLen := out.Len() - 1
for out.Bytes()[outLen] == 0 {
outLen--
}
if out.Bytes()[outLen] != 0xff {
return nil, errors.New("invalid terminal symbol")
}
out.Truncate(outLen)
}
return out.Bytes(), nil
}
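
PackAlign and UnpackAlign are inverses when called with matching parameters; the 0xff terminal symbol is what lets UnpackAlign tell the payload apart from the zero padding. A round-trip sketch, assuming these functions at their new home in the encode package per this diff:

    var packed bytes.Buffer
    n, err := encode.PackAlign(&packed, []byte("hello"), 254) // 254 = fr381.Bits - 1
    if err != nil {
        panic(err)
    }
    // n == encode.PackAlignSize(5, 254) == 32: one zero-padded 32-byte element
    unpacked, err := encode.UnpackAlign(packed.Bytes(), 254, false) // false: expect the terminal symbol
    if err != nil {
        panic(err)
    }
    // unpacked is []byte("hello") again; n is the packed length
    _ = n
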
// MiMCChecksumPackedData re-packs the data tightly into bls12-377 elements and computes the MiMC checksum.
// Only packing without a terminal symbol is supported. Input with a terminal symbol will be interpreted at its full padded length.
func MiMCChecksumPackedData(data []byte, inputPackingSize int, hashPackingOptions ...packAlignOption) ([]byte, error) {
dataNbBits := len(data) * 8
if inputPackingSize%8 != 0 {
inputBytesPerElem := (inputPackingSize + 7) / 8
dataNbBits = dataNbBits / inputBytesPerElem * inputPackingSize
var err error
if data, err = UnpackAlign(data, inputPackingSize, true); err != nil {
return nil, err
}
}
lastByteNbUnusedBits := 8 - dataNbBits%8
if lastByteNbUnusedBits == 8 {
lastByteNbUnusedBits = 0
}
var bb bytes.Buffer
packingOptions := make([]packAlignOption, len(hashPackingOptions)+1)
copy(packingOptions, hashPackingOptions)
packingOptions[len(packingOptions)-1] = WithLastByteNbUnusedBits(uint8(lastByteNbUnusedBits))
if _, err := PackAlign(&bb, data, fr377.Bits-1, packingOptions...); err != nil {
return nil, err
}
hsh := hash.MIMC_BLS12_377.New()
hsh.Write(bb.Bytes())
return hsh.Sum(nil), nil
}
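
This checksum is what the blob header stores in DictChecksum, and what the dictionary store matches against during decompression. A short sketch, with 8-bit (byte-aligned) input packing as used in NewBlobMaker:

    checksum, err := encode.MiMCChecksumPackedData(dict, 8) // dict: the raw dictionary bytes
    if err != nil {
        panic(err)
    }
    // checksum is the MiMC (BLS12-377) digest copied into Header.DictChecksum
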
// WorstCompressedBlockSize returns the size of the given block, as compressed by an "empty" blob maker.
// That is, with more context the blob maker could compress the block further, but this function
// returns the largest size the compressed block can take.
@@ -557,7 +348,7 @@ func (bm *BlobMaker) WorstCompressedBlockSize(rlpBlock []byte) (bool, int, error
}
// account for the padding
n = PackAlignSize(n, fr381.Bits-1, NoTerminalSymbol())
n = encode.PackAlignSize(n, fr381.Bits-1, encode.NoTerminalSymbol())
return expandingBlock, n, nil
}
@@ -610,7 +401,7 @@ func (bm *BlobMaker) RawCompressedSize(data []byte) (int, error) {
}
// account for the padding
n = PackAlignSize(n, fr381.Bits-1, NoTerminalSymbol())
n = encode.PackAlignSize(n, fr381.Bits-1, encode.NoTerminalSymbol())
return n, nil
}

View File

@@ -9,6 +9,8 @@ import (
"encoding/hex"
"errors"
"fmt"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
"math/big"
"math/rand"
"os"
@@ -57,7 +59,8 @@ func testCompressorSingleSmallBatch(t *testing.T, blocks [][]byte) {
dict, err := os.ReadFile(testDictPath)
assert.NoError(t, err)
_, _, blocksBack, err := v1.DecompressBlob(bm.Bytes(), dict)
dictStore, err := dictionary.SingletonStore(dict, 1)
_, _, blocksBack, err := v1.DecompressBlob(bm.Bytes(), dictStore)
assert.NoError(t, err)
assert.Equal(t, len(blocks), len(blocksBack), "number of blocks should match")
// TODO compare the blocks
@@ -121,7 +124,7 @@ func assertBatchesConsistent(t *testing.T, raw, decoded [][]byte) {
var block types.Block
assert.NoError(t, rlp.Decode(bytes.NewReader(raw[i]), &block))
blockBack, err := test_utils.DecodeBlockFromUncompressed(bytes.NewReader(decoded[i]))
blockBack, err := v1.DecodeBlockFromUncompressed(bytes.NewReader(decoded[i]))
assert.NoError(t, err)
assert.Equal(t, block.Time(), blockBack.Timestamp, "block time should match")
}
@@ -512,7 +515,11 @@ func decompressBlob(b []byte) ([][][]byte, error) {
if err != nil {
return nil, fmt.Errorf("can't read dict: %w", err)
}
header, _, blocks, err := v1.DecompressBlob(b, dict)
dictStore, err := dictionary.SingletonStore(dict, 1)
if err != nil {
return nil, err
}
header, _, blocks, err := v1.DecompressBlob(b, dictStore)
if err != nil {
return nil, fmt.Errorf("can't decompress blob: %w", err)
}
@@ -641,10 +648,10 @@ func TestPack(t *testing.T) {
runTest := func(s1, s2 []byte) {
// pack them
buf.Reset()
written, err := v1.PackAlign(&buf, s1, fr381.Bits-1, v1.WithAdditionalInput(s2))
written, err := encode.PackAlign(&buf, s1, fr381.Bits-1, encode.WithAdditionalInput(s2))
assert.NoError(err, "pack should not generate an error")
assert.Equal(v1.PackAlignSize(len(s1)+len(s2), fr381.Bits-1), int(written), "written bytes should match expected PackAlignSize")
original, err := v1.UnpackAlign(buf.Bytes(), fr381.Bits-1, false)
assert.Equal(encode.PackAlignSize(len(s1)+len(s2), fr381.Bits-1), int(written), "written bytes should match expected PackAlignSize")
original, err := encode.UnpackAlign(buf.Bytes(), fr381.Bits-1, false)
assert.NoError(err, "unpack should not generate an error")
assert.Equal(s1, original[:len(s1)], "slices should match")

View File

@@ -8,6 +8,8 @@ import (
"io"
"github.com/consensys/linea-monorepo/prover/backend/ethereum"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
)
@@ -159,3 +161,51 @@ func PassRlpList(r *bytes.Reader) error {
return nil
}
// DecodeBlockFromUncompressed inverts [EncodeBlockForCompression]. It is primarily meant for
// testing and ensuring the encoding is bijective.
func DecodeBlockFromUncompressed(r *bytes.Reader) (encode.DecodedBlockData, error) {
var (
decNumTxs uint16
decTimestamp uint32
blockHash common.Hash
)
if err := binary.Read(r, binary.BigEndian, &decNumTxs); err != nil {
return encode.DecodedBlockData{}, fmt.Errorf("could not decode nb txs: %w", err)
}
if err := binary.Read(r, binary.BigEndian, &decTimestamp); err != nil {
return encode.DecodedBlockData{}, fmt.Errorf("could not decode timestamp: %w", err)
}
if _, err := r.Read(blockHash[:]); err != nil {
return encode.DecodedBlockData{}, fmt.Errorf("could not read the block hash: %w", err)
}
numTxs := int(decNumTxs)
decodedBlk := encode.DecodedBlockData{
Froms: make([]common.Address, numTxs),
Txs: make([]types.TxData, numTxs),
Timestamp: uint64(decTimestamp),
BlockHash: blockHash,
}
var err error
for i := 0; i < int(decNumTxs); i++ {
if decodedBlk.Txs[i], err = DecodeTxFromUncompressed(r, &decodedBlk.Froms[i]); err != nil {
return encode.DecodedBlockData{}, fmt.Errorf("could not decode transaction #%v: %w", i, err)
}
}
return decodedBlk, nil
}
func DecodeTxFromUncompressed(r *bytes.Reader, from *common.Address) (types.TxData, error) {
if _, err := r.Read(from[:]); err != nil {
return nil, fmt.Errorf("could not read from address: %w", err)
}
return ethereum.DecodeTxFromBytes(r)
}
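
A round-trip sketch of the encode/decode pair, assuming blk is a *types.Block and the snippet lives inside this (v1) package; as the doc comment notes, the encoding is lossy, so the sender comes back in Froms rather than in a verifiable signature:

    var buf bytes.Buffer
    if err := EncodeBlockForCompression(blk, &buf); err != nil {
        panic(err)
    }
    decoded, err := DecodeBlockFromUncompressed(bytes.NewReader(buf.Bytes()))
    if err != nil {
        panic(err)
    }
    // decoded.BlockHash == blk.Hash(), decoded.Timestamp == blk.Time(),
    // decoded.Froms[i] is the recovered sender of transaction i
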

View File

@@ -6,13 +6,12 @@ import (
"bytes"
"encoding/hex"
"fmt"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
encodeTesting "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode/test_utils"
"testing"
v1 "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1/test_utils"
"github.com/consensys/linea-monorepo/prover/utils/types"
"github.com/consensys/linea-monorepo/prover/backend/ethereum"
"github.com/consensys/linea-monorepo/prover/utils"
"github.com/ethereum/go-ethereum/common"
ethtypes "github.com/ethereum/go-ethereum/core/types"
@@ -33,58 +32,57 @@ func TestEncodeDecode(t *testing.T) {
t.Fatalf("could not decode test RLP block: %s", err.Error())
}
var (
buf = &bytes.Buffer{}
expected = test_utils.DecodedBlockData{
BlockHash: block.Hash(),
Txs: make([]ethtypes.Transaction, len(block.Transactions())),
Timestamp: block.Time(),
}
)
var buf bytes.Buffer
for i := range expected.Txs {
expected.Txs[i] = *block.Transactions()[i]
}
if err := v1.EncodeBlockForCompression(&block, buf); err != nil {
if err := v1.EncodeBlockForCompression(&block, &buf); err != nil {
t.Fatalf("failed encoding the block: %s", err.Error())
}
var (
encoded = buf.Bytes()
r = bytes.NewReader(encoded)
decoded, err = test_utils.DecodeBlockFromUncompressed(r)
size, errScan = v1.ScanBlockByteLen(encoded)
)
encoded := buf.Bytes()
r := bytes.NewReader(encoded)
decoded, err := v1.DecodeBlockFromUncompressed(r)
size, errScan := v1.ScanBlockByteLen(encoded)
assert.NoError(t, errScan, "error scanning the payload length")
assert.NotZero(t, size, "scanned a block size of zero")
require.NoError(t, err)
assert.Equal(t, expected.BlockHash, decoded.BlockHash)
assert.Equal(t, expected.Timestamp, decoded.Timestamp)
assert.Equal(t, len(expected.Txs), len(decoded.Txs))
assert.Equal(t, block.Hash(), decoded.BlockHash)
assert.Equal(t, block.Time(), decoded.Timestamp)
assert.Equal(t, len(block.Transactions()), len(decoded.Txs))
for i := range expected.Txs {
checkSameTx(t, &expected.Txs[i], &decoded.Txs[i], decoded.Froms[i])
for i := range block.Transactions() {
encodeTesting.CheckSameTx(t, block.Transactions()[i], ethtypes.NewTx(decoded.Txs[i]), decoded.Froms[i])
if t.Failed() {
return
}
}
t.Log("attempting RLP serialization")
encoded, err = rlp.EncodeToBytes(decoded.ToStd())
assert.NoError(t, err)
var blockBack ethtypes.Block
assert.NoError(t, rlp.Decode(bytes.NewReader(encoded), &blockBack))
assert.Equal(t, block.Hash(), blockBack.ParentHash())
assert.Equal(t, block.Time(), blockBack.Time())
assert.Equal(t, len(block.Transactions()), len(blockBack.Transactions()))
for i := range block.Transactions() {
tx := blockBack.Transactions()[i]
encodeTesting.CheckSameTx(t, block.Transactions()[i], ethtypes.NewTx(decoded.Txs[i]), common.Address(encode.GetAddressFromR(tx)))
if t.Failed() {
return
}
}
})
}
}
func checkSameTx(t *testing.T, orig, decoded *ethtypes.Transaction, from common.Address) {
assert.Equal(t, orig.To(), decoded.To())
assert.Equal(t, orig.Nonce(), decoded.Nonce())
assert.Equal(t, orig.Data(), decoded.Data())
assert.Equal(t, orig.Value(), decoded.Value())
assert.Equal(t, orig.Cost(), decoded.Cost())
assert.Equal(t, ethereum.GetFrom(orig), types.EthAddress(from))
}
func TestPassRlpList(t *testing.T) {
makeRlpSlice := func(n int) []byte {
@@ -138,7 +136,7 @@ func TestVectorDecode(t *testing.T) {
var (
postPadded = append(b, postPad[:]...)
r = bytes.NewReader(b)
_, errDec = test_utils.DecodeBlockFromUncompressed(r)
_, errDec = v1.DecodeBlockFromUncompressed(r)
_, errScan = v1.ScanBlockByteLen(postPadded)
)

View File

@@ -5,20 +5,16 @@ import (
"crypto/rand"
"encoding/binary"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/consensys/linea-monorepo/prover/backend/ethereum"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob"
v1 "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/consensys/compress/lzss"
fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/consensys/linea-monorepo/prover/backend/execution"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/encode"
v1 "github.com/consensys/linea-monorepo/prover/lib/compressor/blob/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -79,7 +75,7 @@ func LoadTestBlocks(testDataDir string) (testBlocks [][]byte, err error) {
return testBlocks, nil
}
func RandIntn(n int) int {
func RandIntn(n int) int { // TODO @Tabaie remove
var b [8]byte
_, _ = rand.Read(b[:])
return int(binary.BigEndian.Uint64(b[:]) % uint64(n))
@@ -102,7 +98,7 @@ func EmptyBlob(t require.TestingT) []byte {
assert.NoError(t, err)
var bb bytes.Buffer
if _, err = v1.PackAlign(&bb, headerB.Bytes(), fr381.Bits-1, v1.WithAdditionalInput(compressor.Bytes())); err != nil {
if _, err = encode.PackAlign(&bb, headerB.Bytes(), fr381.Bits-1, encode.WithAdditionalInput(compressor.Bytes())); err != nil {
panic(err)
}
return bb.Bytes()
@@ -165,72 +161,6 @@ func TestBlocksAndBlobMaker(t require.TestingT) ([][]byte, *v1.BlobMaker) {
return testBlocks, bm
}
// DecodedBlockData is a wrapper struct storing the different fields of a block
// that we deserialize when decoding an ethereum block.
type DecodedBlockData struct {
// BlockHash stores the decoded block hash
BlockHash common.Hash
// Timestamp holds the Unix timestamp of the block in seconds
Timestamp uint64
// Froms stores the list of the sender address of every transaction
Froms []common.Address
// Txs stores the list of the decoded transactions.
Txs []types.Transaction
}
// DecodeBlockFromUncompressed inverts [EncodeBlockForCompression]. It is primarily meant for
// testing and ensuring the encoding is bijective.
func DecodeBlockFromUncompressed(r *bytes.Reader) (DecodedBlockData, error) {
var (
decNumTxs uint16
decTimestamp uint32
blockHash common.Hash
)
if err := binary.Read(r, binary.BigEndian, &decNumTxs); err != nil {
return DecodedBlockData{}, fmt.Errorf("could not decode nb txs: %w", err)
}
if err := binary.Read(r, binary.BigEndian, &decTimestamp); err != nil {
return DecodedBlockData{}, fmt.Errorf("could not decode timestamp: %w", err)
}
if _, err := r.Read(blockHash[:]); err != nil {
return DecodedBlockData{}, fmt.Errorf("could not read the block hash: %w", err)
}
var (
numTxs = int(decNumTxs)
decodedBlk = DecodedBlockData{
Froms: make([]common.Address, numTxs),
Txs: make([]types.Transaction, numTxs),
Timestamp: uint64(decTimestamp),
BlockHash: blockHash,
}
)
for i := 0; i < int(decNumTxs); i++ {
if err := DecodeTxFromUncompressed(r, &decodedBlk.Txs[i], &decodedBlk.Froms[i]); err != nil {
return DecodedBlockData{}, fmt.Errorf("could not decode transaction #%v: %w", i, err)
}
}
return decodedBlk, nil
}
func DecodeTxFromUncompressed(r *bytes.Reader, tx *types.Transaction, from *common.Address) (err error) {
if _, err := r.Read(from[:]); err != nil {
return fmt.Errorf("could not read from address: %w", err)
}
if err := ethereum.DecodeTxFromBytes(r, tx); err != nil {
return fmt.Errorf("could not deserialize transaction")
}
return nil
}
func GetDict(t require.TestingT) []byte {
dict, err := blob.GetDict()
require.NoError(t, err)

View File

@@ -0,0 +1,96 @@
package main
import "C"
import (
"errors"
"strings"
"sync"
"unsafe"
decompressor "github.com/consensys/linea-monorepo/prover/lib/compressor/blob"
"github.com/consensys/linea-monorepo/prover/lib/compressor/blob/dictionary"
)
//go:generate go build -tags nocorset -ldflags "-s -w" -buildmode=c-shared -o libdecompressor.so libdecompressor.go
func main() {}
var (
dictStore dictionary.Store
lastError error
lock sync.Mutex // probably unnecessary if coordinator guarantees single-threaded access
)
// Init initializes the decompressor.
//
//export Init
func Init() {
dictStore = dictionary.NewStore()
}
// LoadDictionaries loads a number of dictionaries into the decompressor
// according to colon-separated paths.
// Returns the number of dictionaries loaded, or -1 if unsuccessful.
// If -1 is returned, the Error() method will return a string describing the error.
//
//export LoadDictionaries
func LoadDictionaries(dictPaths *C.char) C.int {
lock.Lock()
defer lock.Unlock()
pathsConcat := C.GoString(dictPaths)
paths := strings.Split(pathsConcat, ":")
if err := dictStore.Load(paths...); err != nil {
lastError = err
return -1
}
return C.int(len(paths))
}
// Decompress processes a Linea blob and outputs an RLP encoded list of RLP encoded blocks.
// Due to information loss during pre-compression encoding, two pieces of information are represented "hackily":
// The block hash is in the ParentHash field.
// The transaction from address is in the signature.R field.
//
// Returns the number of bytes in out, or -1 in case of failure
// If -1 is returned, the Error() method will return a string describing the error.
//
//export Decompress
func Decompress(blob *C.char, blobLength C.int, out *C.char, outMaxLength C.int) C.int {
lock.Lock()
defer lock.Unlock()
bGo := C.GoBytes(unsafe.Pointer(blob), blobLength)
blocks, err := decompressor.DecompressBlob(bGo, dictStore)
if err != nil {
lastError = err
return -1
}
if len(blocks) > int(outMaxLength) {
lastError = errors.New("decoded blob does not fit in output buffer")
return -1
}
outSlice := unsafe.Slice((*byte)(unsafe.Pointer(out)), len(blocks))
copy(outSlice, blocks)
return C.int(len(blocks))
}
// Error returns the last encountered error.
// If no error was encountered, returns nil.
//
//export Error
func Error() *C.char {
lock.Lock()
defer lock.Unlock()
if lastError != nil {
// this leaks memory, but since this represents a fatal error, it's probably ok.
return C.CString(lastError.Error())
}
return nil
}
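
For a consumer of the shared library, the intended sequence is Init, then LoadDictionaries with a colon-separated path list, then Decompress into a pre-sized buffer, checking Error whenever -1 comes back. In Go terms, the exports wrap roughly this flow (a sketch with hypothetical file names, not part of the library):

    store := dictionary.NewStore()
    // the C API receives these as the single string "dict_a.bin:dict_b.bin"
    if err := store.Load("dict_a.bin", "dict_b.bin"); err != nil {
        panic(err)
    }
    blobBytes, err := os.ReadFile("blob.bin")
    if err != nil {
        panic(err)
    }
    rlpBlocks, err := decompressor.DecompressBlob(blobBytes, store)
    if err != nil {
        panic(err)
    }
    // rlpBlocks is an RLP-encoded list of RLP-encoded blocks; a C caller must
    // provide outMaxLength >= len(rlpBlocks), or Decompress returns -1
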

View File

@@ -0,0 +1,106 @@
/* Code generated by cmd/cgo; DO NOT EDIT. */
/* package command-line-arguments */
#line 1 "cgo-builtin-export-prolog"
#include <stddef.h>
#ifndef GO_CGO_EXPORT_PROLOGUE_H
#define GO_CGO_EXPORT_PROLOGUE_H
#ifndef GO_CGO_GOSTRING_TYPEDEF
typedef struct { const char *p; ptrdiff_t n; } _GoString_;
#endif
#endif
/* Start of preamble from import "C" comments. */
/* End of preamble from import "C" comments. */
/* Start of boilerplate cgo prologue. */
#line 1 "cgo-gcc-export-header-prolog"
#ifndef GO_CGO_PROLOGUE_H
#define GO_CGO_PROLOGUE_H
typedef signed char GoInt8;
typedef unsigned char GoUint8;
typedef short GoInt16;
typedef unsigned short GoUint16;
typedef int GoInt32;
typedef unsigned int GoUint32;
typedef long long GoInt64;
typedef unsigned long long GoUint64;
typedef GoInt64 GoInt;
typedef GoUint64 GoUint;
typedef size_t GoUintptr;
typedef float GoFloat32;
typedef double GoFloat64;
#ifdef _MSC_VER
#include <complex.h>
typedef _Fcomplex GoComplex64;
typedef _Dcomplex GoComplex128;
#else
typedef float _Complex GoComplex64;
typedef double _Complex GoComplex128;
#endif
/*
static assertion to make sure the file is being used on architecture
at least with matching size of GoInt.
*/
typedef char _check_for_64_bit_pointer_matching_GoInt[sizeof(void*)==64/8 ? 1:-1];
#ifndef GO_CGO_GOSTRING_TYPEDEF
typedef _GoString_ GoString;
#endif
typedef void *GoMap;
typedef void *GoChan;
typedef struct { void *t; void *v; } GoInterface;
typedef struct { void *data; GoInt len; GoInt cap; } GoSlice;
#endif
/* End of boilerplate cgo prologue. */
#ifdef __cplusplus
extern "C" {
#endif
// Init initializes the decompressor.
//
extern void Init();
// LoadDictionaries loads a number of dictionaries into the decompressor
// according to colon-separated paths.
// Returns the number of dictionaries loaded, or -1 if unsuccessful.
// If -1 is returned, the Error() method will return a string describing the error.
//
extern int LoadDictionaries(char* dictPaths);
// Decompress processes a Linea blob and outputs an RLP encoded list of RLP encoded blocks.
// Due to information loss during pre-compression encoding, two pieces of information are represented "hackily":
// The block hash is in the ParentHash field.
// The transaction from address is in the signature.R field.
//
// Returns the number of bytes in out, or -1 in case of failure
// If -1 is returned, the Error() method will return a string describing the error.
//
extern int Decompress(char* blob, int blobLength, char* out, int outMaxLength);
// Error returns the last encountered error.
// If no error was encountered, returns nil.
//
extern char* Error();
#ifdef __cplusplus
}
#endif

View File

@@ -3,12 +3,17 @@ package test_utils
import (
"crypto/rand"
"encoding/binary"
"encoding/hex"
"encoding/json"
"fmt"
"github.com/stretchr/testify/require"
"math"
"os"
"reflect"
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
type FakeTestingT struct{}
@@ -41,8 +46,109 @@ func RandIntSliceN(length, n int) []int {
return res
}
type BytesEqualError struct {
Index int
error string
}
func (e *BytesEqualError) Error() string {
return e.error
}
func LoadJson(t *testing.T, path string, v any) {
in, err := os.Open(path)
require.NoError(t, err)
require.NoError(t, json.NewDecoder(in).Decode(v))
}
// BytesEqual compares the byte slices expected and actual,
// returning a readable error message in case of inequality
// TODO error options: block size, check forwards or backwards etc
func BytesEqual(expected, actual []byte) error {
l := min(len(expected), len(actual))
failure := 0
for failure < l {
if expected[failure] != actual[failure] {
break
}
failure++
}
if failure == l && len(expected) == len(actual) {
return nil
}
// there is a mismatch
var sb strings.Builder
const (
radius = 40
blockSize = 32
)
printCentered := func(b []byte) {
for i := max(failure-radius, 0); i <= failure+radius; i++ {
if i%blockSize == 0 && i != failure-radius {
sb.WriteString(" ")
}
if i >= 0 && i < len(b) {
sb.WriteString(hex.EncodeToString([]byte{b[i]})) // inefficient, but this whole error printing sub-procedure will not be run more than once
} else {
sb.WriteString("  ")
}
}
}
sb.WriteString(fmt.Sprintf("mismatch starting at byte %d\n", failure))
sb.WriteString("expected: ")
printCentered(expected)
sb.WriteString("\n")
sb.WriteString("actual: ")
printCentered(actual)
sb.WriteString("\n")
sb.WriteString(" ")
for i := max(failure-radius, 0); i <= failure+radius; {
if i%blockSize == 0 && i != failure-radius {
s := strconv.Itoa(i)
sb.WriteString(" ")
sb.WriteString(s)
i += len(s) / 2
if len(s)%2 != 0 {
sb.WriteString(" ")
i++
}
} else {
if i == failure {
sb.WriteString("^^")
} else {
sb.WriteString("  ")
}
i++
}
}
sb.WriteString("\n")
return &BytesEqualError{
Index: failure,
error: sb.String(),
}
}
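
A usage sketch: on mismatch, the returned error renders both slices in hex around the first differing byte, with a ^^ marker and an offset ruler; the offset is also available programmatically via the Index field:

    if err := BytesEqual(expected, actual); err != nil {
        t.Fatal(err) // prints the annotated hex dump
    }
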
func SlicesEqual[T any](expected, actual []T) error {
if l1, l2 := len(expected), len(actual); l1 != l2 {
return fmt.Errorf("length mismatch %d≠%d", l1, l2)
}
for i := range expected {
if !reflect.DeepEqual(expected[i], actual[i]) {
return fmt.Errorf("mismatch at #%d:\nexpected %v\nencountered %v", i, expected[i], actual[i])
}
}
return nil
}
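
A minimal sketch; SlicesEqual works for any element type that reflect.DeepEqual can compare:

    if err := SlicesEqual([]int{1, 2, 3}, []int{1, 2, 4}); err != nil {
        fmt.Println(err) // reports the first mismatching index and both values
    }
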