Compare commits

..

1 Commits

Author SHA1 Message Date
Péter Garamvölgyi
8a0b526391 refactor(orm): Make ORM usage consistent (#627) 2023-07-10 09:18:08 +02:00
18 changed files with 429 additions and 262 deletions

View File

@@ -159,13 +159,13 @@ func (r *Layer1Relayer) processSavedEvent(msg *orm.L1Message) error {
// ProcessGasPriceOracle imports gas price to layer2
func (r *Layer1Relayer) ProcessGasPriceOracle() {
latestBlockHeight, err := r.l1Block.GetLatestL1BlockHeight()
latestBlockHeight, err := r.l1Block.GetLatestL1BlockHeight(r.ctx)
if err != nil {
log.Warn("Failed to fetch latest L1 block height from db", "err", err)
return
}
blocks, err := r.l1Block.GetL1Blocks(map[string]interface{}{
blocks, err := r.l1Block.GetL1Blocks(r.ctx, map[string]interface{}{
"number": latestBlockHeight,
})
if err != nil {

View File

@@ -153,8 +153,8 @@ func testL1RelayerGasOracleConfirm(t *testing.T) {
// Check the database for the updated status using TryTimes.
ok := utils.TryTimes(5, func() bool {
msg1, err1 := l1BlockOrm.GetL1Blocks(map[string]interface{}{"hash": "gas-oracle-1"})
msg2, err2 := l1BlockOrm.GetL1Blocks(map[string]interface{}{"hash": "gas-oracle-2"})
msg1, err1 := l1BlockOrm.GetL1Blocks(ctx, map[string]interface{}{"hash": "gas-oracle-1"})
msg2, err2 := l1BlockOrm.GetL1Blocks(ctx, map[string]interface{}{"hash": "gas-oracle-2"})
return err1 == nil && len(msg1) == 1 && types.GasOracleStatus(msg1[0].GasOracleStatus) == types.GasOracleImported &&
err2 == nil && len(msg2) == 1 && types.GasOracleStatus(msg2[0].GasOracleStatus) == types.GasOracleFailed
})
@@ -175,28 +175,28 @@ func testL1RelayerProcessGasPriceOracle(t *testing.T) {
var l1BlockOrm *orm.L1Block
convey.Convey("GetLatestL1BlockHeight failure", t, func() {
targetErr := errors.New("GetLatestL1BlockHeight error")
patchGuard := gomonkey.ApplyMethodFunc(l1BlockOrm, "GetLatestL1BlockHeight", func() (uint64, error) {
patchGuard := gomonkey.ApplyMethodFunc(l1BlockOrm, "GetLatestL1BlockHeight", func(ctx context.Context) (uint64, error) {
return 0, targetErr
})
defer patchGuard.Reset()
l1Relayer.ProcessGasPriceOracle()
})
patchGuard := gomonkey.ApplyMethodFunc(l1BlockOrm, "GetLatestL1BlockHeight", func() (uint64, error) {
patchGuard := gomonkey.ApplyMethodFunc(l1BlockOrm, "GetLatestL1BlockHeight", func(ctx context.Context) (uint64, error) {
return 100, nil
})
defer patchGuard.Reset()
convey.Convey("GetL1Blocks failure", t, func() {
targetErr := errors.New("GetL1Blocks error")
patchGuard.ApplyMethodFunc(l1BlockOrm, "GetL1Blocks", func(fields map[string]interface{}) ([]orm.L1Block, error) {
patchGuard.ApplyMethodFunc(l1BlockOrm, "GetL1Blocks", func(ctx context.Context, fields map[string]interface{}) ([]orm.L1Block, error) {
return nil, targetErr
})
l1Relayer.ProcessGasPriceOracle()
})
convey.Convey("Block not exist", t, func() {
patchGuard.ApplyMethodFunc(l1BlockOrm, "GetL1Blocks", func(fields map[string]interface{}) ([]orm.L1Block, error) {
patchGuard.ApplyMethodFunc(l1BlockOrm, "GetL1Blocks", func(ctx context.Context, fields map[string]interface{}) ([]orm.L1Block, error) {
tmpInfo := []orm.L1Block{
{Hash: "gas-oracle-1", Number: 0},
{Hash: "gas-oracle-2", Number: 1},
@@ -206,7 +206,7 @@ func testL1RelayerProcessGasPriceOracle(t *testing.T) {
l1Relayer.ProcessGasPriceOracle()
})
patchGuard.ApplyMethodFunc(l1BlockOrm, "GetL1Blocks", func(fields map[string]interface{}) ([]orm.L1Block, error) {
patchGuard.ApplyMethodFunc(l1BlockOrm, "GetL1Blocks", func(ctx context.Context, fields map[string]interface{}) ([]orm.L1Block, error) {
tmpInfo := []orm.L1Block{
{
Hash: "gas-oracle-1",

View File

@@ -74,7 +74,7 @@ func NewL1WatcherClient(ctx context.Context, client *ethclient.Client, startHeig
}
l1BlockOrm := orm.NewL1Block(db)
savedL1BlockHeight, err := l1BlockOrm.GetLatestL1BlockHeight()
savedL1BlockHeight, err := l1BlockOrm.GetLatestL1BlockHeight(ctx)
if err != nil {
log.Warn("Failed to fetch latest L1 block height from db", "err", err)
savedL1BlockHeight = 0

View File

@@ -109,7 +109,7 @@ func (w *L2WatcherClient) TryFetchRunningMissingBlocks(blockHeight uint64) {
}
// Fetch and store block traces for missing blocks
for from := uint64(heightInDB) + 1; from <= blockHeight; from += blockTracesFetchLimit {
for from := heightInDB + 1; from <= blockHeight; from += blockTracesFetchLimit {
to := from + blockTracesFetchLimit - 1
if to > blockHeight {

View File

@@ -88,7 +88,7 @@ func testFetchRunningMissingBlocks(t *testing.T) {
wc := prepareWatcherClient(l2Cli, db, address)
wc.TryFetchRunningMissingBlocks(latestHeight)
fetchedHeight, err := l2BlockOrm.GetL2BlocksLatestHeight(context.Background())
return err == nil && uint64(fetchedHeight) == latestHeight
return err == nil && fetchedHeight == latestHeight
})
assert.True(t, ok)
}

View File

@@ -71,6 +71,7 @@ func (*Batch) TableName() string {
// The returned batches are sorted in ascending order by their index.
func (o *Batch) GetBatches(ctx context.Context, fields map[string]interface{}, orderByList []string, limit int) ([]*Batch, error) {
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
for key, value := range fields {
db = db.Where(key, value)
@@ -88,46 +89,51 @@ func (o *Batch) GetBatches(ctx context.Context, fields map[string]interface{}, o
var batches []*Batch
if err := db.Find(&batches).Error; err != nil {
return nil, err
return nil, fmt.Errorf("Batch.GetBatches error: %w, fields: %v, orderByList: %v", err, fields, orderByList)
}
return batches, nil
}
// GetBatchCount retrieves the total number of batches in the database.
func (o *Batch) GetBatchCount(ctx context.Context) (uint64, error) {
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
var count int64
err := o.db.WithContext(ctx).Model(&Batch{}).Count(&count).Error
if err != nil {
return 0, err
if err := db.Count(&count).Error; err != nil {
return 0, fmt.Errorf("Batch.GetBatchCount error: %w", err)
}
return uint64(count), nil
}
// GetVerifiedProofByHash retrieves the verified aggregate proof for a batch with the given hash.
func (o *Batch) GetVerifiedProofByHash(ctx context.Context, hash string) (*message.AggProof, error) {
var batch Batch
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Select("proof")
db = db.Where("hash = ? AND proving_status = ?", hash, types.ProvingTaskVerified)
var batch Batch
if err := db.Find(&batch).Error; err != nil {
return nil, err
return nil, fmt.Errorf("Batch.GetVerifiedProofByHash error: %w, batch hash: %v", err, hash)
}
var proof message.AggProof
if err := json.Unmarshal(batch.Proof, &proof); err != nil {
return nil, err
return nil, fmt.Errorf("Batch.GetVerifiedProofByHash error: %w, batch hash: %v", err, hash)
}
return &proof, nil
}
// GetLatestBatch retrieves the latest batch from the database.
func (o *Batch) GetLatestBatch(ctx context.Context) (*Batch, error) {
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Order("index desc")
var latestBatch Batch
err := o.db.WithContext(ctx).Order("index desc").First(&latestBatch).Error
if err != nil {
return nil, err
if err := db.First(&latestBatch).Error; err != nil {
return nil, fmt.Errorf("Batch.GetLatestBatch error: %w", err)
}
return &latestBatch, nil
}
@@ -138,13 +144,14 @@ func (o *Batch) GetRollupStatusByHashList(ctx context.Context, hashes []string)
return nil, nil
}
var batches []Batch
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Select("hash, rollup_status")
db = db.Where("hash IN ?", hashes)
var batches []Batch
if err := db.Find(&batches).Error; err != nil {
return nil, err
return nil, fmt.Errorf("Batch.GetRollupStatusByHashList error: %w, hashes: %v", err, hashes)
}
hashToStatusMap := make(map[string]types.RollupStatus)
@@ -156,7 +163,7 @@ func (o *Batch) GetRollupStatusByHashList(ctx context.Context, hashes []string)
for _, hash := range hashes {
status, ok := hashToStatusMap[hash]
if !ok {
return nil, fmt.Errorf("hash not found in database: %s", hash)
return nil, fmt.Errorf("Batch.GetRollupStatusByHashList: hash not found in database: %s", hash)
}
statuses = append(statuses, status)
}
@@ -171,23 +178,28 @@ func (o *Batch) GetPendingBatches(ctx context.Context, limit int) ([]*Batch, err
return nil, errors.New("limit must be greater than zero")
}
var batches []*Batch
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("rollup_status = ?", types.RollupPending)
db = db.Order("index ASC")
db = db.Limit(limit)
db = db.Where("rollup_status = ?", types.RollupPending).Order("index ASC").Limit(limit)
var batches []*Batch
if err := db.Find(&batches).Error; err != nil {
return nil, err
return nil, fmt.Errorf("Batch.GetPendingBatches error: %w", err)
}
return batches, nil
}
// GetBatchByIndex retrieves the batch by the given index.
func (o *Batch) GetBatchByIndex(ctx context.Context, index uint64) (*Batch, error) {
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("index = ?", index)
var batch Batch
err := o.db.WithContext(ctx).Where("index = ?", index).First(&batch).Error
if err != nil {
return nil, err
if err := db.First(&batch).Error; err != nil {
return nil, fmt.Errorf("Batch.GetBatchByIndex error: %w, index: %v", err, index)
}
return &batch, nil
}
@@ -198,13 +210,8 @@ func (o *Batch) InsertBatch(ctx context.Context, startChunkIndex, endChunkIndex
return nil, errors.New("invalid args")
}
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
parentBatch, err := o.GetLatestBatch(ctx)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
if err != nil && !errors.Is(errors.Unwrap(err), gorm.ErrRecordNotFound) {
log.Error("failed to get the latest batch", "err", err)
return nil, err
}
@@ -258,11 +265,17 @@ func (o *Batch) InsertBatch(ctx context.Context, startChunkIndex, endChunkIndex
RollupStatus: int16(types.RollupPending),
}
if err := db.WithContext(ctx).Create(&newBatch).Error; err != nil {
log.Error("failed to insert batch", "batch", newBatch, "err", err)
return nil, err
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
db.WithContext(ctx)
db = db.Model(&Batch{})
if err := db.Create(&newBatch).Error; err != nil {
log.Error("failed to insert batch", "batch", newBatch, "err", err)
return nil, fmt.Errorf("Batch.InsertBatch error: %w", err)
}
return &newBatch, nil
}
@@ -272,10 +285,15 @@ func (o *Batch) UpdateSkippedBatches(ctx context.Context) (uint64, error) {
int(types.ProvingTaskSkipped),
int(types.ProvingTaskFailed),
}
result := o.db.WithContext(ctx).Model(&Batch{}).Where("rollup_status", int(types.RollupCommitted)).
Where("proving_status IN (?)", provingStatusList).Update("rollup_status", int(types.RollupFinalizationSkipped))
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("rollup_status", int(types.RollupCommitted))
db = db.Where("proving_status IN (?)", provingStatusList)
result := db.Update("rollup_status", int(types.RollupFinalizationSkipped))
if result.Error != nil {
return 0, result.Error
return 0, fmt.Errorf("Batch.UpdateSkippedBatches error: %w", result.Error)
}
return uint64(result.RowsAffected), nil
}
@@ -285,19 +303,19 @@ func (o *Batch) UpdateL2GasOracleStatusAndOracleTxHash(ctx context.Context, hash
updateFields := make(map[string]interface{})
updateFields["oracle_status"] = int(status)
updateFields["oracle_tx_hash"] = txHash
if err := o.db.WithContext(ctx).Model(&Batch{}).Where("hash", hash).Updates(updateFields).Error; err != nil {
return err
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("hash", hash)
if err := db.Updates(updateFields).Error; err != nil {
return fmt.Errorf("Batch.UpdateL2GasOracleStatusAndOracleTxHash error: %w, batch hash: %v, status: %v, txHash: %v", err, hash, status.String(), txHash)
}
return nil
}
// UpdateProvingStatus updates the proving status of a batch.
func (o *Batch) UpdateProvingStatus(ctx context.Context, hash string, status types.ProvingStatus, dbTX ...*gorm.DB) error {
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
updateFields := make(map[string]interface{})
updateFields["proving_status"] = int(status)
@@ -310,19 +328,22 @@ func (o *Batch) UpdateProvingStatus(ctx context.Context, hash string, status typ
updateFields["proved_at"] = time.Now()
}
if err := db.WithContext(ctx).Model(&Batch{}).Where("hash", hash).Updates(updateFields).Error; err != nil {
return err
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
db = db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("hash", hash)
if err := db.Updates(updateFields).Error; err != nil {
return fmt.Errorf("Batch.UpdateProvingStatus error: %w, batch hash: %v, status: %v", err, hash, status.String())
}
return nil
}
// UpdateRollupStatus updates the rollup status of a batch.
func (o *Batch) UpdateRollupStatus(ctx context.Context, hash string, status types.RollupStatus, dbTX ...*gorm.DB) error {
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
updateFields := make(map[string]interface{})
updateFields["rollup_status"] = int(status)
@@ -332,8 +353,17 @@ func (o *Batch) UpdateRollupStatus(ctx context.Context, hash string, status type
case types.RollupFinalized:
updateFields["finalized_at"] = time.Now()
}
if err := db.WithContext(ctx).Model(&Batch{}).Where("hash", hash).Updates(updateFields).Error; err != nil {
return err
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
db = db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("hash", hash)
if err := db.Updates(updateFields).Error; err != nil {
return fmt.Errorf("Batch.UpdateRollupStatus error: %w, batch hash: %v, status: %v", err, hash, status.String())
}
return nil
}
@@ -346,8 +376,13 @@ func (o *Batch) UpdateCommitTxHashAndRollupStatus(ctx context.Context, hash stri
if status == types.RollupCommitted {
updateFields["committed_at"] = time.Now()
}
if err := o.db.WithContext(ctx).Model(&Batch{}).Where("hash", hash).Updates(updateFields).Error; err != nil {
return err
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("hash", hash)
if err := db.Updates(updateFields).Error; err != nil {
return fmt.Errorf("Batch.UpdateCommitTxHashAndRollupStatus error: %w, batch hash: %v, status: %v, commitTxHash: %v", err, hash, status.String(), commitTxHash)
}
return nil
}
@@ -360,8 +395,13 @@ func (o *Batch) UpdateFinalizeTxHashAndRollupStatus(ctx context.Context, hash st
if status == types.RollupFinalized {
updateFields["finalized_at"] = time.Now()
}
if err := o.db.WithContext(ctx).Model(&Batch{}).Where("hash", hash).Updates(updateFields).Error; err != nil {
return err
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("hash", hash)
if err := db.Updates(updateFields).Error; err != nil {
return fmt.Errorf("Batch.UpdateFinalizeTxHashAndRollupStatus error: %w, batch hash: %v, status: %v, commitTxHash: %v", err, hash, status.String(), finalizeTxHash)
}
return nil
}
@@ -371,12 +411,19 @@ func (o *Batch) UpdateFinalizeTxHashAndRollupStatus(ctx context.Context, hash st
func (o *Batch) UpdateProofByHash(ctx context.Context, hash string, proof *message.AggProof, proofTimeSec uint64) error {
proofBytes, err := json.Marshal(proof)
if err != nil {
return err
return fmt.Errorf("Batch.UpdateProofByHash error: %w, batch hash: %v", err, hash)
}
updateFields := make(map[string]interface{})
updateFields["proof"] = proofBytes
updateFields["proof_time_sec"] = proofTimeSec
err = o.db.WithContext(ctx).Model(&Batch{}).Where("hash", hash).Updates(updateFields).Error
return err
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("hash", hash)
if err = db.Updates(updateFields).Error; err != nil {
return fmt.Errorf("Batch.UpdateProofByHash error: %w, batch hash: %v", err, hash)
}
return nil
}

View File

@@ -3,6 +3,7 @@ package orm
import (
"context"
"errors"
"fmt"
"time"
"scroll-tech/common/types"
@@ -61,19 +62,22 @@ func (*Chunk) TableName() string {
// The returned chunks are sorted in ascending order by their index.
func (o *Chunk) GetChunksInRange(ctx context.Context, startIndex uint64, endIndex uint64) ([]*Chunk, error) {
if startIndex > endIndex {
return nil, errors.New("start index should be less than or equal to end index")
return nil, fmt.Errorf("Chunk.GetChunksInRange: start index should be less than or equal to end index, start index: %v, end index: %v", startIndex, endIndex)
}
var chunks []*Chunk
db := o.db.WithContext(ctx).Where("index >= ? AND index <= ?", startIndex, endIndex)
db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Where("index >= ? AND index <= ?", startIndex, endIndex)
db = db.Order("index ASC")
var chunks []*Chunk
if err := db.Find(&chunks).Error; err != nil {
return nil, err
return nil, fmt.Errorf("Chunk.GetChunksInRange error: %w, start index: %v, end index: %v", err, startIndex, endIndex)
}
if startIndex+uint64(len(chunks)) != endIndex+1 {
return nil, errors.New("number of chunks not expected in the specified range")
// sanity check
if uint64(len(chunks)) != endIndex-startIndex+1 {
return nil, fmt.Errorf("Chunk.GetChunksInRange: incorrect number of chunks, expected: %v, got: %v, start index: %v, end index: %v", endIndex-startIndex+1, len(chunks), startIndex, endIndex)
}
return chunks, nil
@@ -81,25 +85,27 @@ func (o *Chunk) GetChunksInRange(ctx context.Context, startIndex uint64, endInde
// GetUnbatchedChunks retrieves unbatched chunks from the database.
func (o *Chunk) GetUnbatchedChunks(ctx context.Context) ([]*Chunk, error) {
db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Where("batch_hash IS NULL")
db = db.Order("index asc")
var chunks []*Chunk
err := o.db.WithContext(ctx).
Where("batch_hash IS NULL").
Order("index asc").
Find(&chunks).Error
if err != nil {
return nil, err
if err := db.Find(&chunks).Error; err != nil {
return nil, fmt.Errorf("Chunk.GetUnbatchedChunks error: %w", err)
}
return chunks, nil
}
// GetLatestChunk retrieves the latest chunk from the database.
func (o *Chunk) GetLatestChunk(ctx context.Context) (*Chunk, error) {
db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Order("index desc")
var latestChunk Chunk
err := o.db.WithContext(ctx).
Order("index desc").
First(&latestChunk).Error
if err != nil {
return nil, err
if err := db.First(&latestChunk).Error; err != nil {
return nil, fmt.Errorf("Chunk.GetLatestChunk error: %w", err)
}
return &latestChunk, nil
}
@@ -110,17 +116,12 @@ func (o *Chunk) InsertChunk(ctx context.Context, chunk *types.Chunk, dbTX ...*go
return nil, errors.New("invalid args")
}
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
var chunkIndex uint64
var totalL1MessagePoppedBefore uint64
parentChunk, err := o.GetLatestChunk(ctx)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
if err != nil && !errors.Is(errors.Unwrap(err), gorm.ErrRecordNotFound) {
log.Error("failed to get latest chunk", "err", err)
return nil, err
return nil, fmt.Errorf("Chunk.InsertChunk error: %w", err)
}
// if parentChunk==nil then err==gorm.ErrRecordNotFound, which means there's
@@ -134,7 +135,7 @@ func (o *Chunk) InsertChunk(ctx context.Context, chunk *types.Chunk, dbTX ...*go
hash, err := chunk.Hash(totalL1MessagePoppedBefore)
if err != nil {
log.Error("failed to get chunk hash", "err", err)
return nil, err
return nil, fmt.Errorf("Chunk.InsertChunk error: %w", err)
}
var totalL2TxGas uint64
@@ -166,9 +167,15 @@ func (o *Chunk) InsertChunk(ctx context.Context, chunk *types.Chunk, dbTX ...*go
ProvingStatus: int16(types.ProvingTaskUnassigned),
}
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
db = db.WithContext(ctx)
db = db.Model(&Chunk{})
if err := db.Create(&newChunk).Error; err != nil {
log.Error("failed to insert chunk", "hash", hash, "err", err)
return nil, err
return nil, fmt.Errorf("Chunk.InsertChunk error: %w, chunk hash: %v", err, newChunk.Hash)
}
return &newChunk, nil
@@ -176,11 +183,6 @@ func (o *Chunk) InsertChunk(ctx context.Context, chunk *types.Chunk, dbTX ...*go
// UpdateProvingStatus updates the proving status of a chunk.
func (o *Chunk) UpdateProvingStatus(ctx context.Context, hash string, status types.ProvingStatus, dbTX ...*gorm.DB) error {
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
updateFields := make(map[string]interface{})
updateFields["proving_status"] = int(status)
@@ -193,8 +195,16 @@ func (o *Chunk) UpdateProvingStatus(ctx context.Context, hash string, status typ
updateFields["proved_at"] = time.Now()
}
if err := db.Model(&Chunk{}).Where("hash", hash).Updates(updateFields).Error; err != nil {
return err
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
db = db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Where("hash", hash)
if err := db.Updates(updateFields).Error; err != nil {
return fmt.Errorf("Chunk.UpdateProvingStatus error: %w, chunk hash: %v, status: %v", err, hash, status.String())
}
return nil
}
@@ -206,7 +216,12 @@ func (o *Chunk) UpdateBatchHashInRange(ctx context.Context, startIndex uint64, e
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
db = db.Model(&Chunk{}).Where("index >= ? AND index <= ?", startIndex, endIndex)
db = db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Where("index >= ? AND index <= ?", startIndex, endIndex)
return db.Update("batch_hash", batchHash).Error
if err := db.Update("batch_hash", batchHash).Error; err != nil {
return fmt.Errorf("Chunk.UpdateBatchHashInRange error: %w, start index: %v, end index: %v, batch hash: %v", err, startIndex, endIndex, batchHash)
}
return nil
}

View File

@@ -2,8 +2,8 @@ package orm
import (
"context"
"fmt"
"github.com/scroll-tech/go-ethereum/log"
"gorm.io/gorm"
"scroll-tech/common/types"
@@ -34,54 +34,64 @@ func (*L1Block) TableName() string {
}
// GetLatestL1BlockHeight get the latest l1 block height
func (l *L1Block) GetLatestL1BlockHeight() (uint64, error) {
result := l.db.Model(&L1Block{}).Select("COALESCE(MAX(number), 0)").Row()
if result.Err() != nil {
return 0, result.Err()
}
func (o *L1Block) GetLatestL1BlockHeight(ctx context.Context) (uint64, error) {
db := o.db.WithContext(ctx)
db = db.Model(&L1Block{})
db = db.Select("COALESCE(MAX(number), 0)")
var maxNumber uint64
if err := result.Scan(&maxNumber); err != nil {
return 0, err
if err := db.Row().Scan(&maxNumber); err != nil {
return 0, fmt.Errorf("L1Block.GetLatestL1BlockHeight error: %w", err)
}
return maxNumber, nil
}
// GetL1Blocks get the l1 blocks
func (l *L1Block) GetL1Blocks(fields map[string]interface{}) ([]L1Block, error) {
var l1Blocks []L1Block
db := l.db
func (o *L1Block) GetL1Blocks(ctx context.Context, fields map[string]interface{}) ([]L1Block, error) {
db := o.db.WithContext(ctx)
db = db.Model(&L1Block{})
for key, value := range fields {
db = db.Where(key, value)
}
db = db.Order("number ASC")
var l1Blocks []L1Block
if err := db.Find(&l1Blocks).Error; err != nil {
return nil, err
return nil, fmt.Errorf("L1Block.GetL1Blocks error: %w, fields: %v", err, fields)
}
return l1Blocks, nil
}
// InsertL1Blocks batch insert l1 blocks
func (l *L1Block) InsertL1Blocks(ctx context.Context, blocks []L1Block) error {
func (o *L1Block) InsertL1Blocks(ctx context.Context, blocks []L1Block) error {
if len(blocks) == 0 {
return nil
}
err := l.db.WithContext(ctx).Create(&blocks).Error
if err != nil {
log.Error("failed to insert L1 Blocks", "err", err)
db := o.db.WithContext(ctx)
db = db.Model(&L1Block{})
if err := db.Create(&blocks).Error; err != nil {
return fmt.Errorf("L1Block.InsertL1Blocks error: %w", err)
}
return err
return nil
}
// UpdateL1GasOracleStatusAndOracleTxHash update l1 gas oracle status and oracle tx hash
func (l *L1Block) UpdateL1GasOracleStatusAndOracleTxHash(ctx context.Context, blockHash string, status types.GasOracleStatus, txHash string) error {
func (o *L1Block) UpdateL1GasOracleStatusAndOracleTxHash(ctx context.Context, blockHash string, status types.GasOracleStatus, txHash string) error {
updateFields := map[string]interface{}{
"oracle_status": int(status),
"oracle_tx_hash": txHash,
}
if err := l.db.WithContext(ctx).Model(&L1Block{}).Where("hash", blockHash).Updates(updateFields).Error; err != nil {
return err
db := o.db.WithContext(ctx)
db = db.Model(&L1Block{})
db = db.Where("hash", blockHash)
if err := db.Updates(updateFields).Error; err != nil {
return fmt.Errorf("L1Block.UpdateL1GasOracleStatusAndOracleTxHash error: %w, block hash: %v, status: %v, tx hash: %v", err, blockHash, status.String(), txHash)
}
return nil
}

View File

@@ -3,7 +3,6 @@ package orm
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/scroll-tech/go-ethereum/common"
@@ -42,27 +41,30 @@ func (*L2Block) TableName() string {
// GetL2BlocksLatestHeight retrieves the height of the latest L2 block.
// If the l2_block table is empty, it returns 0 to represent the genesis block height.
// In case of an error, it returns -1 along with the error.
func (o *L2Block) GetL2BlocksLatestHeight(ctx context.Context) (int64, error) {
var maxNumber int64
if err := o.db.WithContext(ctx).Model(&L2Block{}).Select("COALESCE(MAX(number), 0)").Row().Scan(&maxNumber); err != nil {
return -1, err
}
func (o *L2Block) GetL2BlocksLatestHeight(ctx context.Context) (uint64, error) {
db := o.db.WithContext(ctx)
db = db.Model(&L2Block{})
db = db.Select("COALESCE(MAX(number), 0)")
var maxNumber uint64
if err := db.Row().Scan(&maxNumber); err != nil {
return 0, fmt.Errorf("L2Block.GetL2BlocksLatestHeight error: %w", err)
}
return maxNumber, nil
}
// GetUnchunkedBlocks get the l2 blocks that have not been put into a chunk.
// The returned blocks are sorted in ascending order by their block number.
func (o *L2Block) GetUnchunkedBlocks(ctx context.Context) ([]*types.WrappedBlock, error) {
var l2Blocks []L2Block
db := o.db.WithContext(ctx)
db = db.Model(&L2Block{})
db = db.Select("header, transactions, withdraw_trie_root")
db = db.Where("chunk_hash IS NULL")
db = db.Order("number ASC")
var l2Blocks []L2Block
if err := db.Find(&l2Blocks).Error; err != nil {
return nil, err
return nil, fmt.Errorf("L2Block.GetUnchunkedBlocks error: %w", err)
}
var wrappedBlocks []*types.WrappedBlock
@@ -70,12 +72,12 @@ func (o *L2Block) GetUnchunkedBlocks(ctx context.Context) ([]*types.WrappedBlock
var wrappedBlock types.WrappedBlock
if err := json.Unmarshal([]byte(v.Transactions), &wrappedBlock.Transactions); err != nil {
return nil, err
return nil, fmt.Errorf("L2Block.GetUnchunkedBlocks error: %w", err)
}
wrappedBlock.Header = &gethTypes.Header{}
if err := json.Unmarshal([]byte(v.Header), wrappedBlock.Header); err != nil {
return nil, err
return nil, fmt.Errorf("L2Block.GetUnchunkedBlocks error: %w", err)
}
wrappedBlock.WithdrawTrieRoot = common.HexToHash(v.WithdrawTrieRoot)
@@ -89,6 +91,7 @@ func (o *L2Block) GetUnchunkedBlocks(ctx context.Context) ([]*types.WrappedBlock
// The returned L2Blocks are sorted in ascending order by their block number.
func (o *L2Block) GetL2Blocks(ctx context.Context, fields map[string]interface{}, orderByList []string, limit int) ([]*L2Block, error) {
db := o.db.WithContext(ctx)
db = db.Model(&L2Block{})
for key, value := range fields {
db = db.Where(key, value)
@@ -106,7 +109,7 @@ func (o *L2Block) GetL2Blocks(ctx context.Context, fields map[string]interface{}
var l2Blocks []*L2Block
if err := db.Find(&l2Blocks).Error; err != nil {
return nil, err
return nil, fmt.Errorf("L2Block.GetL2Blocks error: %w, fields: %v, orderByList: %v", err, fields, orderByList)
}
return l2Blocks, nil
}
@@ -116,22 +119,23 @@ func (o *L2Block) GetL2Blocks(ctx context.Context, fields map[string]interface{}
// The returned blocks are sorted in ascending order by their block number.
func (o *L2Block) GetL2BlocksInRange(ctx context.Context, startBlockNumber uint64, endBlockNumber uint64) ([]*types.WrappedBlock, error) {
if startBlockNumber > endBlockNumber {
return nil, errors.New("start block number should be less than or equal to end block number")
return nil, fmt.Errorf("L2Block.GetL2BlocksInRange: start block number should be less than or equal to end block number, start block: %v, end block: %v", startBlockNumber, endBlockNumber)
}
var l2Blocks []L2Block
db := o.db.WithContext(ctx)
db = db.Model(&L2Block{})
db = db.Select("header, transactions, withdraw_trie_root")
db = db.Where("number >= ? AND number <= ?", startBlockNumber, endBlockNumber)
db = db.Order("number ASC")
var l2Blocks []L2Block
if err := db.Find(&l2Blocks).Error; err != nil {
return nil, err
return nil, fmt.Errorf("L2Block.GetL2BlocksInRange error: %w, start block: %v, end block: %v", err, startBlockNumber, endBlockNumber)
}
// sanity check
if uint64(len(l2Blocks)) != endBlockNumber-startBlockNumber+1 {
return nil, errors.New("number of blocks not expected in the specified range")
return nil, fmt.Errorf("L2Block.GetL2BlocksInRange: unexpected number of results, expected: %v, got: %v", endBlockNumber-startBlockNumber+1, len(l2Blocks))
}
var wrappedBlocks []*types.WrappedBlock
@@ -139,12 +143,12 @@ func (o *L2Block) GetL2BlocksInRange(ctx context.Context, startBlockNumber uint6
var wrappedBlock types.WrappedBlock
if err := json.Unmarshal([]byte(v.Transactions), &wrappedBlock.Transactions); err != nil {
return nil, err
return nil, fmt.Errorf("L2Block.GetL2BlocksInRange error: %w, start block: %v, end block: %v", err, startBlockNumber, endBlockNumber)
}
wrappedBlock.Header = &gethTypes.Header{}
if err := json.Unmarshal([]byte(v.Header), wrappedBlock.Header); err != nil {
return nil, err
return nil, fmt.Errorf("L2Block.GetL2BlocksInRange error: %w, start block: %v, end block: %v", err, startBlockNumber, endBlockNumber)
}
wrappedBlock.WithdrawTrieRoot = common.HexToHash(v.WithdrawTrieRoot)
@@ -161,13 +165,13 @@ func (o *L2Block) InsertL2Blocks(ctx context.Context, blocks []*types.WrappedBlo
header, err := json.Marshal(block.Header)
if err != nil {
log.Error("failed to marshal block header", "hash", block.Header.Hash().String(), "err", err)
return err
return fmt.Errorf("L2Block.InsertL2Blocks error: %w", err)
}
txs, err := json.Marshal(block.Transactions)
if err != nil {
log.Error("failed to marshal transactions", "hash", block.Header.Hash().String(), "err", err)
return err
return fmt.Errorf("L2Block.InsertL2Blocks error: %w", err)
}
l2Block := L2Block{
@@ -184,9 +188,11 @@ func (o *L2Block) InsertL2Blocks(ctx context.Context, blocks []*types.WrappedBlo
l2Blocks = append(l2Blocks, l2Block)
}
if err := o.db.WithContext(ctx).Create(&l2Blocks).Error; err != nil {
log.Error("failed to insert l2Blocks", "err", err)
return err
db := o.db.WithContext(ctx)
db = db.Model(&L2Block{})
if err := db.Create(&l2Blocks).Error; err != nil {
return fmt.Errorf("L2Block.InsertL2Blocks error: %w", err)
}
return nil
}
@@ -200,13 +206,19 @@ func (o *L2Block) UpdateChunkHashInRange(ctx context.Context, startIndex uint64,
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
db = db.WithContext(ctx)
db = db.Model(&L2Block{})
db = db.Where("number >= ? AND number <= ?", startIndex, endIndex)
db = db.WithContext(ctx).Model(&L2Block{}).Where("number >= ? AND number <= ?", startIndex, endIndex)
tx := db.Update("chunk_hash", chunkHash)
if tx.RowsAffected != int64(endIndex-startIndex+1) {
return fmt.Errorf("expected %d rows to be updated, got %d", endIndex-startIndex+1, tx.RowsAffected)
if tx.Error != nil {
return fmt.Errorf("L2Block.UpdateChunkHashInRange error: %w, start index: %v, end index: %v, chunk hash: %v", tx.Error, startIndex, endIndex, chunkHash)
}
return tx.Error
// sanity check
if uint64(tx.RowsAffected) != endIndex-startIndex+1 {
return fmt.Errorf("L2Block.UpdateChunkHashInRange: incorrect number of rows affected, expected: %v, got: %v", endIndex-startIndex+1, tx.RowsAffected)
}
return nil
}

View File

@@ -99,7 +99,7 @@ func TestL2BlockOrm(t *testing.T) {
height, err := l2BlockOrm.GetL2BlocksLatestHeight(context.Background())
assert.NoError(t, err)
assert.Equal(t, int64(3), height)
assert.Equal(t, uint64(3), height)
blocks, err := l2BlockOrm.GetUnchunkedBlocks(context.Background())
assert.NoError(t, err)

View File

@@ -43,10 +43,10 @@ func testImportL1GasPrice(t *testing.T) {
l1BlockOrm := orm.NewL1Block(db)
// check db status
latestBlockHeight, err := l1BlockOrm.GetLatestL1BlockHeight()
latestBlockHeight, err := l1BlockOrm.GetLatestL1BlockHeight(context.Background())
assert.NoError(t, err)
assert.Equal(t, number, latestBlockHeight)
blocks, err := l1BlockOrm.GetL1Blocks(map[string]interface{}{"number": latestBlockHeight})
blocks, err := l1BlockOrm.GetL1Blocks(context.Background(), map[string]interface{}{"number": latestBlockHeight})
assert.NoError(t, err)
assert.Equal(t, len(blocks), 1)
assert.Empty(t, blocks[0].OracleTxHash)
@@ -54,7 +54,7 @@ func testImportL1GasPrice(t *testing.T) {
// relay gas price
l1Relayer.ProcessGasPriceOracle()
blocks, err = l1BlockOrm.GetL1Blocks(map[string]interface{}{"number": latestBlockHeight})
blocks, err = l1BlockOrm.GetL1Blocks(context.Background(), map[string]interface{}{"number": latestBlockHeight})
assert.NoError(t, err)
assert.Equal(t, len(blocks), 1)
assert.NotEmpty(t, blocks[0].OracleTxHash)

View File

@@ -9,9 +9,6 @@ import (
// L1BlockStatus represents current l1 block processing status
type L1BlockStatus int
// GasOracleStatus represents current gas oracle processing status
type GasOracleStatus int
const (
// L1BlockUndefined : undefined l1 block status
L1BlockUndefined L1BlockStatus = iota
@@ -29,6 +26,9 @@ const (
L1BlockFailed
)
// GasOracleStatus represents current gas oracle processing status
type GasOracleStatus int
const (
// GasOracleUndefined : undefined gas oracle status
GasOracleUndefined GasOracleStatus = iota
@@ -46,6 +46,23 @@ const (
GasOracleFailed
)
func (s GasOracleStatus) String() string {
switch s {
case GasOracleUndefined:
return "GasOracleUndefined"
case GasOraclePending:
return "GasOraclePending"
case GasOracleImporting:
return "GasOracleImporting"
case GasOracleImported:
return "GasOracleImported"
case GasOracleFailed:
return "GasOracleFailed"
default:
return fmt.Sprintf("Undefined (%d)", int32(s))
}
}
// L1BlockInfo is structure of stored l1 block
type L1BlockInfo struct {
Number uint64 `json:"number" db:"number"`
@@ -191,7 +208,7 @@ func (ps ProvingStatus) String() string {
case ProvingTaskFailed:
return "failed"
default:
return "undefined"
return fmt.Sprintf("Undefined (%d)", int32(ps))
}
}
@@ -209,6 +226,17 @@ const (
ChunkProofsStatusReady
)
func (s ChunkProofsStatus) String() string {
switch s {
case ChunkProofsStatusPending:
return "ChunkProofsStatusPending"
case ChunkProofsStatusReady:
return "ChunkProofsStatusReady"
default:
return fmt.Sprintf("Undefined (%d)", int32(s))
}
}
// RollupStatus block_batch rollup_status (pending, committing, committed, commit_failed, finalizing, finalized, finalize_skipped, finalize_failed)
type RollupStatus int
@@ -232,3 +260,26 @@ const (
// RollupFinalizeFailed : rollup finalize transaction is confirmed but failed
RollupFinalizeFailed
)
func (s RollupStatus) String() string {
switch s {
case RollupPending:
return "RollupPending"
case RollupCommitting:
return "RollupCommitting"
case RollupCommitted:
return "RollupCommitted"
case RollupFinalizing:
return "RollupFinalizing"
case RollupFinalized:
return "RollupFinalized"
case RollupFinalizationSkipped:
return "RollupFinalizationSkipped"
case RollupCommitFailed:
return "RollupCommitFailed"
case RollupFinalizeFailed:
return "RollupFinalizeFailed"
default:
return fmt.Sprintf("Undefined (%d)", int32(s))
}
}

View File

@@ -80,7 +80,7 @@ func TestProvingStatus(t *testing.T) {
{
"Undefined",
ProvingStatus(999), // Invalid value.
"undefined",
"Undefined (999)",
},
}

View File

@@ -5,7 +5,7 @@ import (
"runtime/debug"
)
var tag = "v4.0.11"
var tag = "v4.0.12"
var commit = func() string {
if info, ok := debug.ReadBuildInfo(); ok {

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"scroll-tech/common/types"
@@ -76,67 +77,67 @@ func (o *Batch) GetUnassignedBatches(ctx context.Context, limit int) ([]*Batch,
return nil, nil
}
var batches []*Batch
db := o.db.WithContext(ctx)
db = db.Where("proving_status = ? AND chunk_proofs_status = ?", types.ProvingTaskUnassigned, types.ChunkProofsStatusReady)
db = db.Order("index ASC")
db = db.Limit(limit)
var batches []*Batch
if err := db.Find(&batches).Error; err != nil {
return nil, err
return nil, fmt.Errorf("Batch.GetUnassignedBatches error: %w", err)
}
return batches, nil
}
// GetAssignedBatches retrieves all batches whose proving_status is either types.ProvingTaskAssigned or types.ProvingTaskProved.
func (o *Batch) GetAssignedBatches(ctx context.Context) ([]*Batch, error) {
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("proving_status IN (?)", []int{int(types.ProvingTaskAssigned), int(types.ProvingTaskProved)})
var assignedBatches []*Batch
err := o.db.WithContext(ctx).
Where("proving_status IN (?)", []int{int(types.ProvingTaskAssigned), int(types.ProvingTaskProved)}).
Find(&assignedBatches).Error
if err != nil {
return nil, err
if err := db.Find(&assignedBatches).Error; err != nil {
return nil, fmt.Errorf("Batch.GetAssignedBatches error: %w", err)
}
return assignedBatches, nil
}
// GetProvingStatusByHash retrieves the proving status of a batch given its hash.
func (o *Batch) GetProvingStatusByHash(ctx context.Context, hash string) (types.ProvingStatus, error) {
var batch Batch
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Select("proving_status")
db = db.Where("hash = ?", hash)
var batch Batch
if err := db.Find(&batch).Error; err != nil {
return types.ProvingStatusUndefined, err
return types.ProvingStatusUndefined, fmt.Errorf("Batch.GetProvingStatusByHash error: %w, batch hash: %v", err, hash)
}
return types.ProvingStatus(batch.ProvingStatus), nil
}
// GetLatestBatch retrieves the latest batch from the database.
func (o *Batch) GetLatestBatch(ctx context.Context) (*Batch, error) {
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Order("index desc")
var latestBatch Batch
err := o.db.WithContext(ctx).Order("index desc").First(&latestBatch).Error
if err != nil {
return nil, err
if err := db.First(&latestBatch).Error; err != nil {
return nil, fmt.Errorf("Batch.GetLatestBatch error: %w", err)
}
return &latestBatch, nil
}
// InsertBatch inserts a new batch into the database.
// for unit test
func (o *Batch) InsertBatch(ctx context.Context, startChunkIndex, endChunkIndex uint64, startChunkHash, endChunkHash string, chunks []*types.Chunk, dbTX ...*gorm.DB) (*Batch, error) {
func (o *Batch) InsertBatch(ctx context.Context, startChunkIndex, endChunkIndex uint64, startChunkHash, endChunkHash string, chunks []*types.Chunk) (*Batch, error) {
if len(chunks) == 0 {
return nil, errors.New("invalid args")
}
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
parentBatch, err := o.GetLatestBatch(ctx)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
if err != nil && !errors.Is(errors.Unwrap(err), gorm.ErrRecordNotFound) {
log.Error("failed to get the latest batch", "err", err)
return nil, err
}
@@ -190,11 +191,13 @@ func (o *Batch) InsertBatch(ctx context.Context, startChunkIndex, endChunkIndex
RollupStatus: int16(types.RollupPending),
}
if err := db.WithContext(ctx).Create(&newBatch).Error; err != nil {
log.Error("failed to insert batch", "batch", newBatch, "err", err)
return nil, err
}
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
if err := db.Create(&newBatch).Error; err != nil {
log.Error("failed to insert batch", "batch", newBatch, "err", err)
return nil, fmt.Errorf("Batch.InsertBatch error: %w", err)
}
return &newBatch, nil
}
@@ -204,16 +207,15 @@ func (o *Chunk) UpdateChunkProofsStatusByBatchHash(ctx context.Context, batchHas
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("hash = ?", batchHash)
return db.Update("chunk_proofs_status", int(status)).Error
if err := db.Update("chunk_proofs_status", status).Error; err != nil {
return fmt.Errorf("Batch.UpdateChunkProofsStatusByBatchHash error: %w, batch hash: %v, status: %v", err, batchHash, status.String())
}
return nil
}
// UpdateProvingStatus updates the proving status of a batch.
func (o *Batch) UpdateProvingStatus(ctx context.Context, hash string, status types.ProvingStatus, dbTX ...*gorm.DB) error {
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
func (o *Batch) UpdateProvingStatus(ctx context.Context, hash string, status types.ProvingStatus) error {
updateFields := make(map[string]interface{})
updateFields["proving_status"] = int(status)
@@ -226,8 +228,12 @@ func (o *Batch) UpdateProvingStatus(ctx context.Context, hash string, status typ
updateFields["proved_at"] = time.Now()
}
if err := db.WithContext(ctx).Model(&Batch{}).Where("hash", hash).Updates(updateFields).Error; err != nil {
return err
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("hash", hash)
if err := db.Updates(updateFields).Error; err != nil {
return fmt.Errorf("Batch.UpdateProvingStatus error: %w, batch hash: %v, status: %v", err, hash, status.String())
}
return nil
}
@@ -242,6 +248,13 @@ func (o *Batch) UpdateProofByHash(ctx context.Context, hash string, proof *messa
updateFields := make(map[string]interface{})
updateFields["proof"] = proofBytes
updateFields["proof_time_sec"] = proofTimeSec
err = o.db.WithContext(ctx).Model(&Batch{}).Where("hash", hash).Updates(updateFields).Error
return err
db := o.db.WithContext(ctx)
db = db.Model(&Batch{})
db = db.Where("hash", hash)
if err := db.Updates(updateFields).Error; err != nil {
return fmt.Errorf("Batch.UpdateProofByHash error: %w, batch hash: %v", err, hash)
}
return nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"scroll-tech/common/types"
@@ -68,14 +69,15 @@ func (o *Chunk) GetUnassignedChunks(ctx context.Context, limit int) ([]*Chunk, e
return nil, nil
}
var chunks []*Chunk
db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Where("proving_status = ?", types.ProvingTaskUnassigned)
db = db.Order("index ASC")
db = db.Limit(limit)
var chunks []*Chunk
if err := db.Find(&chunks).Error; err != nil {
return nil, err
return nil, fmt.Errorf("Chunk.GetUnassignedChunks error: %w", err)
}
return chunks, nil
}
@@ -84,22 +86,22 @@ func (o *Chunk) GetUnassignedChunks(ctx context.Context, limit int) ([]*Chunk, e
// It returns a slice of decoded proofs (message.AggProof) obtained from the database.
// The returned proofs are sorted in ascending order by their associated chunk index.
func (o *Chunk) GetProofsByBatchHash(ctx context.Context, batchHash string) ([]*message.AggProof, error) {
var chunks []*Chunk
db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Where("batch_hash", batchHash)
db = db.Order("index ASC")
var chunks []*Chunk
if err := db.Find(&chunks).Error; err != nil {
return nil, err
return nil, fmt.Errorf("Chunk.GetProofsByBatchHash error: %w, batch hash: %v", err, batchHash)
}
var proofs []*message.AggProof
for _, chunk := range chunks {
var proof message.AggProof
if err := json.Unmarshal(chunk.Proof, &proof); err != nil {
return nil, err
return nil, fmt.Errorf("Chunk.GetProofsByBatchHash error: %w, batch hash: %v, chunk hash: %v", err, batchHash, chunk.Hash)
}
proofs = append(proofs, &proof)
}
@@ -108,85 +110,82 @@ func (o *Chunk) GetProofsByBatchHash(ctx context.Context, batchHash string) ([]*
// GetLatestChunk retrieves the latest chunk from the database.
func (o *Chunk) GetLatestChunk(ctx context.Context) (*Chunk, error) {
db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Order("index desc")
var latestChunk Chunk
err := o.db.WithContext(ctx).
Order("index desc").
First(&latestChunk).Error
if err != nil {
return nil, err
if err := db.First(&latestChunk).Error; err != nil {
return nil, fmt.Errorf("Chunk.GetLatestChunk error: %w", err)
}
return &latestChunk, nil
}
// GetProvingStatusByHash retrieves the proving status of a chunk given its hash.
func (o *Chunk) GetProvingStatusByHash(ctx context.Context, hash string) (types.ProvingStatus, error) {
var chunk Chunk
db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Select("proving_status")
db = db.Where("hash = ?", hash)
var chunk Chunk
if err := db.Find(&chunk).Error; err != nil {
return types.ProvingStatusUndefined, err
return types.ProvingStatusUndefined, fmt.Errorf("Chunk.GetProvingStatusByHash error: %w, chunk hash: %v", err, hash)
}
return types.ProvingStatus(chunk.ProvingStatus), nil
}
// GetAssignedChunks retrieves all chunks whose proving_status is either types.ProvingTaskAssigned or types.ProvingTaskProved.
func (o *Chunk) GetAssignedChunks(ctx context.Context) ([]*Chunk, error) {
db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Where("proving_status IN (?)", []int{int(types.ProvingTaskAssigned), int(types.ProvingTaskProved)})
var chunks []*Chunk
err := o.db.WithContext(ctx).Where("proving_status IN (?)", []int{int(types.ProvingTaskAssigned), int(types.ProvingTaskProved)}).
Find(&chunks).Error
if err != nil {
return nil, err
if err := db.Find(&chunks).Error; err != nil {
return nil, fmt.Errorf("Chunk.GetAssignedChunks error: %w", err)
}
return chunks, nil
}
// CheckIfBatchChunkProofsAreReady checks if all proofs for all chunks of a given batchHash are collected.
func (o *Chunk) CheckIfBatchChunkProofsAreReady(ctx context.Context, batchHash string) (bool, error) {
var count int64
db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Where("batch_hash = ? AND proving_status != ?", batchHash, types.ProvingTaskVerified)
err := db.Count(&count).Error
if err != nil {
return false, err
}
var count int64
if err := db.Count(&count).Error; err != nil {
return false, fmt.Errorf("Chunk.CheckIfBatchChunkProofsAreReady error: %w, batch hash: %v", err, batchHash)
}
return count == 0, nil
}
// GetChunkBatchHash retrieves the batchHash of a given chunk.
func (o *Chunk) GetChunkBatchHash(ctx context.Context, chunkHash string) (string, error) {
var chunk Chunk
db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Where("hash = ?", chunkHash)
db = db.Select("batch_hash")
if err := db.First(&chunk).Error; err != nil {
return "", err
}
var chunk Chunk
if err := db.First(&chunk).Error; err != nil {
return "", fmt.Errorf("Chunk.GetChunkBatchHash error: %w, chunk hash: %v", err, chunkHash)
}
return chunk.BatchHash, nil
}
// InsertChunk inserts a new chunk into the database.
// for unit test
func (o *Chunk) InsertChunk(ctx context.Context, chunk *types.Chunk, dbTX ...*gorm.DB) (*Chunk, error) {
func (o *Chunk) InsertChunk(ctx context.Context, chunk *types.Chunk) (*Chunk, error) {
if chunk == nil || len(chunk.Blocks) == 0 {
return nil, errors.New("invalid args")
}
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
var chunkIndex uint64
var totalL1MessagePoppedBefore uint64
parentChunk, err := o.GetLatestChunk(ctx)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
if err != nil && !errors.Is(errors.Unwrap(err), gorm.ErrRecordNotFound) {
log.Error("failed to get latest chunk", "err", err)
return nil, err
}
@@ -234,7 +233,10 @@ func (o *Chunk) InsertChunk(ctx context.Context, chunk *types.Chunk, dbTX ...*go
ProvingStatus: int16(types.ProvingTaskUnassigned),
}
if err := db.WithContext(ctx).Create(&newChunk).Error; err != nil {
db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
if err := db.Create(&newChunk).Error; err != nil {
log.Error("failed to insert chunk", "hash", hash, "err", err)
return nil, err
}
@@ -243,12 +245,7 @@ func (o *Chunk) InsertChunk(ctx context.Context, chunk *types.Chunk, dbTX ...*go
}
// UpdateProvingStatus updates the proving status of a chunk.
func (o *Chunk) UpdateProvingStatus(ctx context.Context, hash string, status types.ProvingStatus, dbTX ...*gorm.DB) error {
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
func (o *Chunk) UpdateProvingStatus(ctx context.Context, hash string, status types.ProvingStatus) error {
updateFields := make(map[string]interface{})
updateFields["proving_status"] = int(status)
@@ -261,10 +258,14 @@ func (o *Chunk) UpdateProvingStatus(ctx context.Context, hash string, status typ
updateFields["proved_at"] = time.Now()
}
db = db.WithContext(ctx)
db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Where("hash", hash)
return db.Updates(updateFields).Error
if err := db.Updates(updateFields).Error; err != nil {
return fmt.Errorf("Chunk.UpdateProvingStatus error: %w, chunk hash: %v, status: %v", err, hash, status.String())
}
return nil
}
// UpdateProofByHash updates the chunk proof by hash.
@@ -277,24 +278,27 @@ func (o *Chunk) UpdateProofByHash(ctx context.Context, hash string, proof *messa
updateFields := make(map[string]interface{})
updateFields["proof"] = proofBytes
updateFields["proof_time_sec"] = proofTimeSec
db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Where("hash", hash)
return db.Updates(updateFields).Error
if err := db.Updates(updateFields).Error; err != nil {
return fmt.Errorf("Chunk.UpdateProofByHash error: %w, chunk hash: %v", err, hash)
}
return nil
}
// UpdateBatchHashInRange updates the batch_hash for chunks within the specified range (inclusive).
// The range is closed, i.e., it includes both start and end indices.
// for unit test
func (o *Chunk) UpdateBatchHashInRange(ctx context.Context, startIndex uint64, endIndex uint64, batchHash string, dbTX ...*gorm.DB) error {
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
db = db.Model(&Chunk{}).Where("index >= ? AND index <= ?", startIndex, endIndex)
func (o *Chunk) UpdateBatchHashInRange(ctx context.Context, startIndex uint64, endIndex uint64, batchHash string) error {
db := o.db.WithContext(ctx)
db = db.Model(&Chunk{})
db = db.Where("index >= ? AND index <= ?", startIndex, endIndex)
if err := db.Update("batch_hash", batchHash).Error; err != nil {
return err
return fmt.Errorf("Chunk.UpdateBatchHashInRange error: %w, start index: %v, end index: %v, batch hash: %v", err, startIndex, endIndex, batchHash)
}
return nil
}

View File

@@ -3,6 +3,7 @@ package orm
import (
"context"
"encoding/json"
"fmt"
"github.com/scroll-tech/go-ethereum/common"
gethTypes "github.com/scroll-tech/go-ethereum/core/types"
@@ -41,14 +42,15 @@ func (*L2Block) TableName() string {
// GetL2BlocksByChunkHash retrieves the L2 blocks associated with the specified chunk hash.
// The returned blocks are sorted in ascending order by their block number.
func (o *L2Block) GetL2BlocksByChunkHash(ctx context.Context, chunkHash string) ([]*types.WrappedBlock, error) {
var l2Blocks []L2Block
db := o.db.WithContext(ctx)
db = db.Model(&L2Block{})
db = db.Select("header, transactions, withdraw_trie_root")
db = db.Where("chunk_hash = ?", chunkHash)
db = db.Order("number ASC")
var l2Blocks []L2Block
if err := db.Find(&l2Blocks).Error; err != nil {
return nil, err
return nil, fmt.Errorf("L2Block.GetL2BlocksByChunkHash error: %w, chunk hash: %v", err, chunkHash)
}
var wrappedBlocks []*types.WrappedBlock
@@ -56,12 +58,12 @@ func (o *L2Block) GetL2BlocksByChunkHash(ctx context.Context, chunkHash string)
var wrappedBlock types.WrappedBlock
if err := json.Unmarshal([]byte(v.Transactions), &wrappedBlock.Transactions); err != nil {
return nil, err
return nil, fmt.Errorf("L2Block.GetL2BlocksByChunkHash error: %w, chunk hash: %v", err, chunkHash)
}
wrappedBlock.Header = &gethTypes.Header{}
if err := json.Unmarshal([]byte(v.Header), wrappedBlock.Header); err != nil {
return nil, err
return nil, fmt.Errorf("L2Block.GetL2BlocksByChunkHash error: %w, chunk hash: %v", err, chunkHash)
}
wrappedBlock.WithdrawTrieRoot = common.HexToHash(v.WithdrawTrieRoot)
@@ -78,13 +80,13 @@ func (o *L2Block) InsertL2Blocks(ctx context.Context, blocks []*types.WrappedBlo
header, err := json.Marshal(block.Header)
if err != nil {
log.Error("failed to marshal block header", "hash", block.Header.Hash().String(), "err", err)
return err
return fmt.Errorf("L2Block.InsertL2Blocks error: %w, block hash: %v", err, block.Header.Hash().String())
}
txs, err := json.Marshal(block.Transactions)
if err != nil {
log.Error("failed to marshal transactions", "hash", block.Header.Hash().String(), "err", err)
return err
return fmt.Errorf("L2Block.InsertL2Blocks error: %w, block hash: %v", err, block.Header.Hash().String())
}
l2Block := L2Block{
@@ -101,9 +103,11 @@ func (o *L2Block) InsertL2Blocks(ctx context.Context, blocks []*types.WrappedBlo
l2Blocks = append(l2Blocks, l2Block)
}
if err := o.db.WithContext(ctx).Create(&l2Blocks).Error; err != nil {
log.Error("failed to insert l2Blocks", "err", err)
return err
db := o.db.WithContext(ctx)
db = db.Model(&L2Block{})
if err := db.Create(&l2Blocks).Error; err != nil {
return fmt.Errorf("L2Block.InsertL2Blocks error: %w", err)
}
return nil
}

View File

@@ -2,6 +2,7 @@ package orm
import (
"context"
"fmt"
"time"
"gorm.io/gorm"
@@ -45,25 +46,32 @@ func (o *ProverTask) GetProverTasksByHashes(ctx context.Context, hashes []string
if len(hashes) == 0 {
return nil, nil
}
var proverTasks []*ProverTask
db := o.db.WithContext(ctx)
db = db.Model(&ProverTask{})
db = db.Where("task_id IN ?", hashes)
db = db.Order("id asc")
var proverTasks []*ProverTask
if err := db.Find(&proverTasks).Error; err != nil {
return nil, err
return nil, fmt.Errorf("ProverTask.GetProverTasksByHashes error: %w, hashes: %v", err, hashes)
}
return proverTasks, nil
}
// SetProverTask updates or inserts a ProverTask record.
func (o *ProverTask) SetProverTask(ctx context.Context, sessionInfo *ProverTask) error {
func (o *ProverTask) SetProverTask(ctx context.Context, proverTask *ProverTask) error {
db := o.db.WithContext(ctx)
db = db.Model(&ProverTask{})
db = db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "task_type"}, {Name: "task_id"}, {Name: "prover_public_key"}},
DoUpdates: clause.AssignmentColumns([]string{"proving_status"}),
})
return db.Create(&sessionInfo).Error
if err := db.Create(&proverTask).Error; err != nil {
return fmt.Errorf("ProverTask.SetProverTask error: %w, prover task: %v", err, proverTask)
}
return nil
}
// UpdateProverTaskProvingStatus updates the proving_status of a specific ProverTask record.
@@ -72,5 +80,8 @@ func (o *ProverTask) UpdateProverTaskProvingStatus(ctx context.Context, proofTyp
db = db.Model(&ProverTask{})
db = db.Where("task_type = ? AND task_id = ? AND prover_public_key = ?", proofType, taskID, pk)
return db.Update("proving_status", status).Error
if err := db.Update("proving_status", status).Error; err != nil {
return fmt.Errorf("ProverTask.UpdateProverTaskProvingStatus error: %w, proof type: %v, taskID: %v, prover public key: %v, status: %v", err, proofType.String(), taskID, pk, status.String())
}
return nil
}