Compare commits

...

2 Commits

Author SHA1 Message Date
Ho
0e6f9d14f5 fix unittest 2025-11-04 09:36:44 +09:00
Ho
96d1a247a1 refactor for early checking of assigned task 2025-11-03 22:13:32 +09:00
6 changed files with 140 additions and 98 deletions

View File

@@ -23,7 +23,8 @@ import (
// GetTaskController the get prover task api controller
type GetTaskController struct {
proverTasks map[message.ProofType]provertask.ProverTask
proverTasks map[message.ProofType]provertask.ProverTask
proverTaskManager *provertask.ProverTaskManager
getTaskAccessCounter *prometheus.CounterVec
@@ -32,12 +33,15 @@ type GetTaskController struct {
// NewGetTaskController create a get prover task controller
func NewGetTaskController(cfg *config.Config, chainCfg *params.ChainConfig, db *gorm.DB, verifier *verifier.Verifier, reg prometheus.Registerer) *GetTaskController {
proverTaskManager := provertask.NewProverTaskManager(db)
chunkProverTask := provertask.NewChunkProverTask(cfg, chainCfg, db, verifier.ChunkVk, reg)
batchProverTask := provertask.NewBatchProverTask(cfg, chainCfg, db, verifier.BatchVk, reg)
bundleProverTask := provertask.NewBundleProverTask(cfg, chainCfg, db, verifier.BundleVk, reg)
ptc := &GetTaskController{
proverTasks: make(map[message.ProofType]provertask.ProverTask),
proverTasks: make(map[message.ProofType]provertask.ProverTask),
proverTaskManager: proverTaskManager,
getTaskAccessCounter: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{
Name: "coordinator_get_task_access_count",
Help: "Multi dimensions get task counter.",
@@ -99,7 +103,19 @@ func (ptc *GetTaskController) GetTasks(ctx *gin.Context) {
}
}
proofType := ptc.proofType(&getTaskParameter)
assigned, err := ptc.proverTaskManager.CheckParameter(ctx)
if err != nil {
nerr := fmt.Errorf("check prover task parameter failed, error:%w", err)
types.RenderFailure(ctx, types.ErrCoordinatorGetTaskFailure, nerr)
return
}
var proofType message.ProofType
if assigned != nil {
proofType = message.ProofType(assigned.TaskType)
} else {
proofType = ptc.proofType(&getTaskParameter)
}
proverTask, isExist := ptc.proverTasks[proofType]
if !isExist {
nerr := fmt.Errorf("parameter wrong proof type:%v", proofType)

View File

@@ -39,15 +39,14 @@ type BatchProverTask struct {
func NewBatchProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *gorm.DB, expectedVk map[string][]byte, reg prometheus.Registerer) *BatchProverTask {
bp := &BatchProverTask{
BaseProverTask: BaseProverTask{
db: db,
cfg: cfg,
chainCfg: chainCfg,
expectedVk: expectedVk,
blockOrm: orm.NewL2Block(db),
chunkOrm: orm.NewChunk(db),
batchOrm: orm.NewBatch(db),
proverTaskOrm: orm.NewProverTask(db),
proverBlockListOrm: orm.NewProverBlockList(db),
db: db,
cfg: cfg,
chainCfg: chainCfg,
expectedVk: expectedVk,
blockOrm: orm.NewL2Block(db),
chunkOrm: orm.NewChunk(db),
batchOrm: orm.NewBatch(db),
proverTaskOrm: orm.NewProverTask(db),
},
batchTaskGetTaskTotal: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{
Name: "coordinator_batch_get_task_total",
@@ -60,9 +59,9 @@ func NewBatchProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *go
// Assign load and assign batch tasks
func (bp *BatchProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) {
taskCtx, err := bp.checkParameter(ctx)
if err != nil || taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter failed, error:%w", err)
taskCtx := bp.checkParameter(ctx)
if taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter missed")
}
maxActiveAttempts := bp.cfg.ProverManager.ProversPerSession

View File

@@ -36,16 +36,15 @@ type BundleProverTask struct {
func NewBundleProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *gorm.DB, expectedVk map[string][]byte, reg prometheus.Registerer) *BundleProverTask {
bp := &BundleProverTask{
BaseProverTask: BaseProverTask{
db: db,
chainCfg: chainCfg,
cfg: cfg,
expectedVk: expectedVk,
blockOrm: orm.NewL2Block(db),
chunkOrm: orm.NewChunk(db),
batchOrm: orm.NewBatch(db),
bundleOrm: orm.NewBundle(db),
proverTaskOrm: orm.NewProverTask(db),
proverBlockListOrm: orm.NewProverBlockList(db),
db: db,
chainCfg: chainCfg,
cfg: cfg,
expectedVk: expectedVk,
blockOrm: orm.NewL2Block(db),
chunkOrm: orm.NewChunk(db),
batchOrm: orm.NewBatch(db),
bundleOrm: orm.NewBundle(db),
proverTaskOrm: orm.NewProverTask(db),
},
bundleTaskGetTaskTotal: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{
Name: "coordinator_bundle_get_task_total",
@@ -58,9 +57,9 @@ func NewBundleProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *g
// Assign load and assign batch tasks
func (bp *BundleProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) {
taskCtx, err := bp.checkParameter(ctx)
if err != nil || taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter failed, error:%w", err)
taskCtx := bp.checkParameter(ctx)
if taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter missed")
}
maxActiveAttempts := bp.cfg.ProverManager.ProversPerSession

View File

@@ -36,14 +36,13 @@ type ChunkProverTask struct {
func NewChunkProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *gorm.DB, expectedVk map[string][]byte, reg prometheus.Registerer) *ChunkProverTask {
cp := &ChunkProverTask{
BaseProverTask: BaseProverTask{
db: db,
cfg: cfg,
chainCfg: chainCfg,
expectedVk: expectedVk,
chunkOrm: orm.NewChunk(db),
blockOrm: orm.NewL2Block(db),
proverTaskOrm: orm.NewProverTask(db),
proverBlockListOrm: orm.NewProverBlockList(db),
db: db,
cfg: cfg,
chainCfg: chainCfg,
expectedVk: expectedVk,
chunkOrm: orm.NewChunk(db),
blockOrm: orm.NewL2Block(db),
proverTaskOrm: orm.NewProverTask(db),
},
chunkTaskGetTaskTotal: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{
Name: "coordinator_chunk_get_task_total",
@@ -56,9 +55,9 @@ func NewChunkProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *go
// Assign the chunk proof which need to prove
func (cp *ChunkProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) {
taskCtx, err := cp.checkParameter(ctx)
if err != nil || taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter failed, error:%w", err)
taskCtx := cp.checkParameter(ctx)
if taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter missed")
}
maxActiveAttempts := cp.cfg.ProverManager.ProversPerSession

View File

@@ -37,6 +37,81 @@ type ProverTask interface {
Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error)
}
// ProverTaskManager manage task which has been assigned
type ProverTaskManager struct {
proverTaskOrm *orm.ProverTask
proverBlockListOrm *orm.ProverBlockList
}
const proverTaskCtxKey = "prover_task_context_key"
// NewProverTaskManager new a prover task manager
func NewProverTaskManager(db *gorm.DB) *ProverTaskManager {
return &ProverTaskManager{
proverTaskOrm: orm.NewProverTask(db),
proverBlockListOrm: orm.NewProverBlockList(db),
}
}
// checkParameter check the prover task parameter illegal
func (b *ProverTaskManager) CheckParameter(ctx *gin.Context) (*orm.ProverTask, error) {
var ptc proverTaskContext
ptc.HardForkNames = make(map[string]struct{})
publicKey, publicKeyExist := ctx.Get(coordinatorType.PublicKey)
if !publicKeyExist {
return nil, errors.New("get public key from context failed")
}
ptc.PublicKey = publicKey.(string)
proverName, proverNameExist := ctx.Get(coordinatorType.ProverName)
if !proverNameExist {
return nil, errors.New("get prover name from context failed")
}
ptc.ProverName = proverName.(string)
proverVersion, proverVersionExist := ctx.Get(coordinatorType.ProverVersion)
if !proverVersionExist {
return nil, errors.New("get prover version from context failed")
}
ptc.ProverVersion = proverVersion.(string)
ProverProviderType, ProverProviderTypeExist := ctx.Get(coordinatorType.ProverProviderTypeKey)
if !ProverProviderTypeExist {
// for backward compatibility, set ProverProviderType as internal
ProverProviderType = float64(coordinatorType.ProverProviderTypeInternal)
}
ptc.ProverProviderType = uint8(ProverProviderType.(float64))
hardForkNamesStr, hardForkNameExist := ctx.Get(coordinatorType.HardForkName)
if !hardForkNameExist {
return nil, errors.New("get hard fork name from context failed")
}
hardForkNames := strings.Split(hardForkNamesStr.(string), ",")
for _, hardForkName := range hardForkNames {
ptc.HardForkNames[hardForkName] = struct{}{}
}
isBlocked, err := b.proverBlockListOrm.IsPublicKeyBlocked(ctx.Copy(), publicKey.(string))
if err != nil {
return nil, fmt.Errorf("failed to check whether the public key %s is blocked before assigning a chunk task, err: %w, proverName: %s, proverVersion: %s", publicKey, err, proverName, proverVersion)
}
if isBlocked {
return nil, fmt.Errorf("public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", publicKey, proverName, proverVersion)
}
assigned, err := b.proverTaskOrm.IsProverAssigned(ctx.Copy(), publicKey.(string))
if err != nil {
return nil, fmt.Errorf("failed to check if prover %s is assigned a task, err: %w", publicKey.(string), err)
}
ptc.hasAssignedTask = assigned
ctx.Set(proverTaskCtxKey, &ptc)
return assigned, nil
}
// BaseProverTask a base prover task which contain series functions
type BaseProverTask struct {
cfg *config.Config
@@ -44,12 +119,12 @@ type BaseProverTask struct {
db *gorm.DB
expectedVk map[string][]byte
batchOrm *orm.Batch
chunkOrm *orm.Chunk
bundleOrm *orm.Bundle
blockOrm *orm.L2Block
proverTaskOrm *orm.ProverTask
proverBlockListOrm *orm.ProverBlockList
batchOrm *orm.Batch
chunkOrm *orm.Chunk
bundleOrm *orm.Bundle
blockOrm *orm.L2Block
proverTaskOrm *orm.ProverTask
}
type proverTaskContext struct {
@@ -132,59 +207,13 @@ func (b *BaseProverTask) hardForkSanityCheck(ctx *gin.Context, taskCtx *proverTa
}
// checkParameter check the prover task parameter illegal
func (b *BaseProverTask) checkParameter(ctx *gin.Context) (*proverTaskContext, error) {
var ptc proverTaskContext
ptc.HardForkNames = make(map[string]struct{})
publicKey, publicKeyExist := ctx.Get(coordinatorType.PublicKey)
if !publicKeyExist {
return nil, errors.New("get public key from context failed")
}
ptc.PublicKey = publicKey.(string)
proverName, proverNameExist := ctx.Get(coordinatorType.ProverName)
if !proverNameExist {
return nil, errors.New("get prover name from context failed")
}
ptc.ProverName = proverName.(string)
proverVersion, proverVersionExist := ctx.Get(coordinatorType.ProverVersion)
if !proverVersionExist {
return nil, errors.New("get prover version from context failed")
}
ptc.ProverVersion = proverVersion.(string)
ProverProviderType, ProverProviderTypeExist := ctx.Get(coordinatorType.ProverProviderTypeKey)
if !ProverProviderTypeExist {
// for backward compatibility, set ProverProviderType as internal
ProverProviderType = float64(coordinatorType.ProverProviderTypeInternal)
}
ptc.ProverProviderType = uint8(ProverProviderType.(float64))
hardForkNamesStr, hardForkNameExist := ctx.Get(coordinatorType.HardForkName)
if !hardForkNameExist {
return nil, errors.New("get hard fork name from context failed")
}
hardForkNames := strings.Split(hardForkNamesStr.(string), ",")
for _, hardForkName := range hardForkNames {
ptc.HardForkNames[hardForkName] = struct{}{}
func (b *BaseProverTask) checkParameter(ctx *gin.Context) *proverTaskContext {
pctx, exist := ctx.Get(proverTaskCtxKey)
if !exist {
return nil
}
isBlocked, err := b.proverBlockListOrm.IsPublicKeyBlocked(ctx.Copy(), publicKey.(string))
if err != nil {
return nil, fmt.Errorf("failed to check whether the public key %s is blocked before assigning a chunk task, err: %w, proverName: %s, proverVersion: %s", publicKey, err, proverName, proverVersion)
}
if isBlocked {
return nil, fmt.Errorf("public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", publicKey, proverName, proverVersion)
}
assigned, err := b.proverTaskOrm.IsProverAssigned(ctx.Copy(), publicKey.(string))
if err != nil {
return nil, fmt.Errorf("failed to check if prover %s is assigned a task, err: %w", publicKey.(string), err)
}
ptc.hasAssignedTask = assigned
return &ptc, nil
return pctx.(*proverTaskContext)
}
func (b *BaseProverTask) applyUniversal(schema *coordinatorType.GetTaskSchema) (*coordinatorType.GetTaskSchema, []byte, error) {

View File

@@ -234,7 +234,7 @@ func testGetTaskBlocked(t *testing.T) {
err := proverBlockListOrm.InsertProverPublicKey(context.Background(), chunkProver.proverName, chunkProver.publicKey())
assert.NoError(t, err)
expectedErr := fmt.Errorf("return prover task err:check prover task parameter failed, error:public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", chunkProver.publicKey(), chunkProver.proverName, chunkProver.proverVersion)
expectedErr := fmt.Errorf("check prover task parameter failed, error:public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", chunkProver.publicKey(), chunkProver.proverName, chunkProver.proverVersion)
code, errMsg := chunkProver.tryGetProverTask(t, message.ProofTypeChunk)
assert.Equal(t, types.ErrCoordinatorGetTaskFailure, code)
assert.Equal(t, expectedErr, errors.New(errMsg))
@@ -255,7 +255,7 @@ func testGetTaskBlocked(t *testing.T) {
assert.Equal(t, types.ErrCoordinatorEmptyProofData, code)
assert.Equal(t, expectedErr, errors.New(errMsg))
expectedErr = fmt.Errorf("return prover task err:check prover task parameter failed, error:public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", batchProver.publicKey(), batchProver.proverName, batchProver.proverVersion)
expectedErr = fmt.Errorf("check prover task parameter failed, error:public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", batchProver.publicKey(), batchProver.proverName, batchProver.proverVersion)
code, errMsg = batchProver.tryGetProverTask(t, message.ProofTypeBatch)
assert.Equal(t, types.ErrCoordinatorGetTaskFailure, code)
assert.Equal(t, expectedErr, errors.New(errMsg))