From 340dc45108b5f0da02786aabd00a884c9a5feb0e Mon Sep 17 00:00:00 2001 From: maskpp Date: Tue, 4 Apr 2023 14:35:26 +0800 Subject: [PATCH] add resend unconfirm txs logic --- bridge/relayer/l1_relayer.go | 38 +++++++++++++++++ bridge/sender/sender.go | 57 +++++++++++++++++++++++++ database/orm/interface.go | 6 +-- database/orm/transaction.go | 82 +++++++++++++++++++++++------------- database/orm_test.go | 16 ++++--- 5 files changed, 161 insertions(+), 38 deletions(-) diff --git a/bridge/relayer/l1_relayer.go b/bridge/relayer/l1_relayer.go index 5b20f1601..36ec42328 100644 --- a/bridge/relayer/l1_relayer.go +++ b/bridge/relayer/l1_relayer.go @@ -3,7 +3,9 @@ package relayer import ( "context" "errors" + "fmt" "math/big" + "scroll-tech/common/utils" // not sure if this will make problems when relay with l1geth @@ -109,6 +111,42 @@ func NewLayer1Relayer(ctx context.Context, db database.OrmFactory, cfg *config.R return l1Relayer, nil } +func (r *Layer1Relayer) checkSubmittedMessages() error { + var ( + index uint64 + msgsSize = 100 + db = r.db + ) + for { + l1Index, msgs, err := db.GetL1TxMessages( + map[string]interface{}{"status": types.MsgSubmitted}, + fmt.Sprintf("AND queue_index > %d", index), + fmt.Sprintf("ORDER BY queue_index ASC LIMIT %d", msgsSize), + ) + if err != nil { + log.Error("failed to get l1 submitted messages", "queue_index", index, "err", err) + return err + } + if len(msgs) == 0 { + return nil + } + + index = l1Index + for _, msg := range msgs { + // TODO: restore incomplete transaction. + if !msg.TxHash.Valid { + continue + } + // If pending txs pool is full, wait until pending pool is available. + utils.TryTimes(-1, func() bool { + return !r.messageSender.IsFull() + }) + + } + } + return nil +} + // ProcessSavedEvents relays saved un-processed cross-domain transactions to desired blockchain func (r *Layer1Relayer) ProcessSavedEvents() { // msgs are sorted by nonce in increasing order diff --git a/bridge/sender/sender.go b/bridge/sender/sender.go index 23b42c6ad..2b7e9b10b 100644 --- a/bridge/sender/sender.go +++ b/bridge/sender/sender.go @@ -181,6 +181,63 @@ func (s *Sender) getFeeData(auth *bind.TransactOpts, target *common.Address, val return s.estimateLegacyGas(auth, target, value, data, minGasLimit) } +func (s *Sender) getTransaction(txHash common.Hash) (*types.Transaction, uint64, error) { + tx, isPending, err := s.client.TransactionByHash(s.ctx, txHash) + if err != nil { + return nil, 0, err + } + + if isPending { + return tx, atomic.LoadUint64(&s.blockNumber), nil + } + + receipt, err := s.client.TransactionReceipt(s.ctx, txHash) + if err != nil { + return nil, 0, err + } + return tx, receipt.BlockNumber.Uint64(), nil +} + +// LoadOrResendTx load +func (s *Sender) LoadOrResendTx(destTxHash common.Hash, sender common.Address, nonce uint64, ID string, target *common.Address, value *big.Int, data []byte, minGasLimit uint64) error { + tx, number, err := s.getTransaction(destTxHash) + // check error except `not found` tx. + if err != nil && err.Error() != "not found" { + return err + } + // We occupy the ID, in case some other threads call with the same ID in the same time + if ok := s.pendingTxs.SetIfAbsent(ID, nil); !ok { + return fmt.Errorf("pending pool has repeat ID, ID: %s", ID) + } + defer func() { + if err != nil { + s.pendingTxs.Remove(ID) + } + }() + + auth := s.auths.accounts[sender] + var feeData *FeeData + feeData, err = s.getFeeData(auth, target, value, data, minGasLimit) + if err != nil { + return err + } + // If tx is not in chain node, create and resend it. + if err != nil || tx == nil { + tx, err = s.createAndSendTx(auth, feeData, target, value, data, &nonce) + if err != nil { + return err + } + } + s.pendingTxs.Set(ID, &PendingTransaction{ + submitAt: number, + id: ID, + feeData: feeData, + signer: auth, + tx: tx, + }) + return nil +} + // SendTransaction send a signed L2tL1 transaction. func (s *Sender) SendTransaction(ID string, target *common.Address, value *big.Int, data []byte, minGasLimit uint64) (common.Address, *types.Transaction, error) { if s.IsFull() { diff --git a/database/orm/interface.go b/database/orm/interface.go index e95052a22..aba0731d1 100644 --- a/database/orm/interface.go +++ b/database/orm/interface.go @@ -117,7 +117,7 @@ type TxOrm interface { SaveTx(id, sender string, tx *etypes.Transaction) error UpdateTxMsgByID(hash string, txHash string) error GetTxByID(id string) (*types.TxMessage, error) - GetL1TxMessages(fields map[string]interface{}, args ...string) ([]*types.TxMessage, error) - GetL2TxMessages(fields map[string]interface{}, args ...string) ([]*types.TxMessage, error) - GetBlockBatchTxMessages(fields map[string]interface{}, args ...string) ([]*types.TxMessage, error) + GetL1TxMessages(fields map[string]interface{}, args ...string) (uint64, []*types.TxMessage, error) + GetL2TxMessages(fields map[string]interface{}, args ...string) (uint64, []*types.TxMessage, error) + GetBlockBatchTxMessages(fields map[string]interface{}, args ...string) (uint64, []*types.TxMessage, error) } diff --git a/database/orm/transaction.go b/database/orm/transaction.go index f1f811ac7..5172514e9 100644 --- a/database/orm/transaction.go +++ b/database/orm/transaction.go @@ -2,6 +2,7 @@ package orm import ( "fmt" + "modernc.org/mathutil" "strings" "github.com/jmoiron/sqlx" @@ -71,79 +72,100 @@ func (t *txOrm) GetTxByID(id string) (*stypes.TxMessage, error) { // where 1 = 1 AND status = :status AND queue_index > 0 // ORDER BY queue_index ASC // LIMIT 10) as l1 on tx.id = l1.msg_hash; -func (t *txOrm) GetL1TxMessages(fields map[string]interface{}, args ...string) ([]*stypes.TxMessage, error) { - query := "select msg_hash from l1_message where 1 = 1" +func (t *txOrm) GetL1TxMessages(fields map[string]interface{}, args ...string) (uint64, []*stypes.TxMessage, error) { + query := "select msg_hash, queue_index from l1_message where 1 = 1" for key := range fields { query = query + fmt.Sprintf(" AND %s = :%s", key, key) } query = strings.Join(append([]string{query}, args...), " ") - query = fmt.Sprintf("select l1.msg_hash as id, tx.tx_hash, tx.sender, tx.nonce, tx.target, tx.value, tx.data from transaction as tx right join (%s) as l1 on tx.id = l1.msg_hash;", query) + query = fmt.Sprintf("select l1.queue_index as index, l1.msg_hash as id, tx.tx_hash, tx.sender, tx.nonce, tx.target, tx.value, tx.data from transaction as tx right join (%s) as l1 on tx.id = l1.msg_hash;", query) db := t.db rows, err := db.NamedQuery(db.Rebind(query), fields) if err != nil { - return nil, err + return 0, nil, err } - var txMsgs []*stypes.TxMessage + var ( + index uint64 + txMsgs []*stypes.TxMessage + ) for rows.Next() { - msg := &stypes.TxMessage{} - if err = rows.StructScan(msg); err != nil { - return nil, err + warp := struct { + Index uint64 `db:"index"` + *stypes.TxMessage + }{} + if err = rows.StructScan(&warp); err != nil { + return 0, nil, err } - txMsgs = append(txMsgs, msg) + index = mathutil.MaxUint64(index, warp.Index) + txMsgs = append(txMsgs, warp.TxMessage) } - return txMsgs, nil + return index, txMsgs, nil } // GetL2TxMessages gets tx messages by transaction right join l2_message. -func (t *txOrm) GetL2TxMessages(fields map[string]interface{}, args ...string) ([]*stypes.TxMessage, error) { - query := "select msg_hash from l2_message where 1 = 1" +func (t *txOrm) GetL2TxMessages(fields map[string]interface{}, args ...string) (uint64, []*stypes.TxMessage, error) { + query := "select msg_hash, nonce from l2_message where 1 = 1" for key := range fields { query = query + fmt.Sprintf(" AND %s = :%s", key, key) } query = strings.Join(append([]string{query}, args...), " ") - query = fmt.Sprintf("select l2.msg_hash as id, tx.tx_hash, tx.sender, tx.nonce, tx.target, tx.value, tx.data from transaction as tx right join (%s) as l2 on tx.id = l2.msg_hash;", query) + query = fmt.Sprintf("select l2.nonce as l2_nonce, l2.msg_hash as id, tx.tx_hash, tx.sender, tx.nonce, tx.target, tx.value, tx.data from transaction as tx right join (%s) as l2 on tx.id = l2.msg_hash;", query) db := t.db rows, err := db.NamedQuery(db.Rebind(query), fields) if err != nil { - return nil, err + return 0, nil, err } - var txMsgs []*stypes.TxMessage + var ( + nonce uint64 + txMsgs []*stypes.TxMessage + ) for rows.Next() { - msg := &stypes.TxMessage{} - if err = rows.StructScan(msg); err != nil { - return nil, err + warp := struct { + Nonce uint64 `db:"l2_nonce"` + *stypes.TxMessage + }{} + if err = rows.StructScan(&warp); err != nil { + return 0, nil, err } - txMsgs = append(txMsgs, msg) + nonce = mathutil.MaxUint64(nonce, warp.Nonce) + txMsgs = append(txMsgs, warp.TxMessage) } - return txMsgs, nil + return nonce, txMsgs, nil } // GetBlockBatchTxMessages gets tx messages by transaction right join block_batch. -func (t *txOrm) GetBlockBatchTxMessages(fields map[string]interface{}, args ...string) ([]*stypes.TxMessage, error) { - query := "select hash from block_batch where 1 = 1" +func (t *txOrm) GetBlockBatchTxMessages(fields map[string]interface{}, args ...string) (uint64, []*stypes.TxMessage, error) { + query := "select hash, index from block_batch where 1 = 1" for key := range fields { query = query + fmt.Sprintf(" AND %s = :%s", key, key) } query = strings.Join(append([]string{query}, args...), " ") - query = fmt.Sprintf("select bt.hash as id, tx.tx_hash, tx.sender, tx.nonce, tx.target, tx.value, tx.data from transaction as tx right join (%s) as bt on tx.id = bt.hash;", query) + query = fmt.Sprintf("select bt.index as index, bt.hash as id, tx.tx_hash, tx.sender, tx.nonce, tx.target, tx.value, tx.data from transaction as tx right join (%s) as bt on tx.id = bt.hash;", query) db := t.db rows, err := db.NamedQuery(db.Rebind(query), fields) if err != nil { - return nil, err + return 0, nil, err } - var txMsgs []*stypes.TxMessage + var ( + index uint64 + txMsgs []*stypes.TxMessage + ) for rows.Next() { - msg := &stypes.TxMessage{} - if err = rows.StructScan(msg); err != nil { - return nil, err + warp := struct { + Index uint64 `db:"index"` + *stypes.TxMessage + }{} + if err = rows.StructScan(&warp); err != nil { + return 0, nil, err } - txMsgs = append(txMsgs, msg) + index = mathutil.MaxUint64(index, warp.Index) + txMsgs = append(txMsgs, warp.TxMessage) } - return txMsgs, nil + return index, txMsgs, nil } diff --git a/database/orm_test.go b/database/orm_test.go index 17d135a2b..a3dd72e9d 100644 --- a/database/orm_test.go +++ b/database/orm_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "math/big" + "modernc.org/mathutil" "os" "testing" "time" @@ -93,8 +94,8 @@ var ( func setupEnv(t *testing.T) error { // Init db config and start db container. dbConfig = &database.DBConfig{DriverName: "postgres"} - base.RunImages(t) - dbConfig.DSN = base.DBEndpoint() + //base.RunImages(t) + dbConfig.DSN = "postgres://maskpp:123456@localhost:5432/postgres?sslmode=disable" //base.DBEndpoint() // Create db handler and reset db. factory, err := database.NewOrmFactory(dbConfig) @@ -518,12 +519,14 @@ func testTxOrmGetL1TxMessages(t *testing.T) { assert.NoError(t, err) } - txMsgs, err := ormTx.GetL1TxMessages( + index, txMsgs, err := ormTx.GetL1TxMessages( map[string]interface{}{"status": types.MsgSubmitted}, fmt.Sprintf("AND queue_index > %d", 0), fmt.Sprintf("ORDER BY queue_index ASC LIMIT %d", 10), ) assert.NoError(t, err) + // check index is the biggest one. + assert.Equal(t, templateL1Message[1].QueueIndex, index) assert.Equal(t, len(templateL1Message), len(txMsgs)) // The first field is full. assert.Equal(t, templateL1Message[0].MsgHash, txMsgs[0].ID) @@ -552,12 +555,13 @@ func testTxOrmGetL2TxMessages(t *testing.T) { assert.NoError(t, err) } - txMsgs, err := ormTx.GetL2TxMessages( + nonce, txMsgs, err := ormTx.GetL2TxMessages( map[string]interface{}{"status": types.MsgSubmitted}, fmt.Sprintf("AND nonce > %d", 0), fmt.Sprintf("ORDER BY nonce ASC LIMIT %d", 10), ) assert.NoError(t, err) + assert.Equal(t, templateL2Message[1].Nonce, nonce) assert.Equal(t, len(templateL2Message), len(txMsgs)) assert.Equal(t, templateL2Message[0].MsgHash, txMsgs[0].ID) assert.Equal(t, false, txMsgs[1].TxHash.Valid) @@ -583,12 +587,14 @@ func testTxOrmGetBlockBatchTxMessages(t *testing.T) { err = ormTx.SaveTx(batchData1.Hash().String(), auth.From.String(), signedTx) assert.Nil(t, err) - txMsgs, err := ormTx.GetBlockBatchTxMessages( + batchIndex, txMsgs, err := ormTx.GetBlockBatchTxMessages( map[string]interface{}{"rollup_status": types.RollupPending}, fmt.Sprintf("AND index > %d", 0), fmt.Sprintf("ORDER BY index ASC LIMIT %d", 10), ) assert.NoError(t, err) + // Check bath index is the biggest one. + assert.Equal(t, mathutil.MaxUint64(batchData1.Batch.BatchIndex, batchData2.Batch.BatchIndex), batchIndex) assert.Equal(t, 2, len(txMsgs)) assert.Equal(t, batchData1.Hash().String(), txMsgs[0].ID) assert.Equal(t, false, txMsgs[1].TxHash.Valid)