Compare commits

...

5 Commits

Author SHA1 Message Date
Mengran Lan
4270d34aa9 tmp changes for local test 2024-06-03 19:58:03 +08:00
Mengran Lan
89614c8dcb upgrade legacy_vk vkeys to v0.10.3 2024-06-03 16:32:26 +08:00
Mengran Lan
f9842621bc fix test issue 2024-06-03 16:32:26 +08:00
Mengran Lan
e8643b5206 no longer send acceptversions when login 2024-06-03 16:32:26 +08:00
Mengran Lan
2407db9f71 coordinator support multi-circuits prover 2024-06-03 16:32:26 +08:00
15 changed files with 584 additions and 116 deletions

View File

@@ -154,10 +154,11 @@ func (t *TestcontainerApps) GetL2GethEndPoint() (string, error) {
// GetGormDBClient returns a gorm.DB by connecting to the running postgres container
func (t *TestcontainerApps) GetGormDBClient() (*gorm.DB, error) {
endpoint, err := t.GetDBEndPoint()
if err != nil {
return nil, err
}
// endpoint, err := t.GetDBEndPoint()
// if err != nil {
// return nil, err
// }
endpoint := "postgres://lmr:@localhost:5432/unittest?sslmode=disable"
dbCfg := &database.Config{
DSN: endpoint,
DriverName: "postgres",

View File

@@ -18,6 +18,7 @@ import (
"scroll-tech/coordinator/internal/logic/provertask"
"scroll-tech/coordinator/internal/logic/verifier"
coordinatorType "scroll-tech/coordinator/internal/types"
itypes "scroll-tech/coordinator/internal/types"
)
// GetTaskController the get prover task api controller
@@ -70,6 +71,10 @@ func (ptc *GetTaskController) incGetTaskAccessCounter(ctx *gin.Context) error {
// GetTasks get assigned chunk/batch task
func (ptc *GetTaskController) GetTasks(ctx *gin.Context) {
ctx.Set(itypes.PublicKey, "fake_public_key2")
ctx.Set(itypes.ProverName, "test")
ctx.Set(itypes.ProverVersion, "v4.4.9-000000-000000-000000")
var getTaskParameter coordinatorType.GetTaskParameter
if err := ctx.ShouldBind(&getTaskParameter); err != nil {
nerr := fmt.Errorf("prover task parameter invalid, err:%w", err)

View File

@@ -42,6 +42,7 @@ func NewBatchProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *go
bp := &BatchProverTask{
BaseProverTask: BaseProverTask{
vkMap: vkMap,
reverseVkMap: reverseMap(vkMap),
db: db,
cfg: cfg,
nameForkMap: nameForkMap,
@@ -64,48 +65,31 @@ func NewBatchProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *go
return bp
}
// Assign load and assign batch tasks
func (bp *BatchProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) {
taskCtx, err := bp.checkParameter(ctx, getTaskParameter)
if err != nil || taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter failed, error:%w", err)
}
type chunkIndexRange struct {
start uint64
end uint64
}
hardForkNumber, err := bp.getHardForkNumberByName(taskCtx.HardForkName)
if err != nil {
log.Error("batch assign failure because of the hard fork name don't exist", "fork name", taskCtx.HardForkName)
return nil, err
func (r *chunkIndexRange) merge(o chunkIndexRange) *chunkIndexRange {
var start, end = r.start, r.end
if o.start < r.start {
start = o.start
}
if o.end > r.end {
end = o.end
}
return &chunkIndexRange{start, end}
}
// if the hard fork number set, rollup relayer must generate the chunk from hard fork number,
// so the hard fork chunk's start_block_number must be ForkBlockNumber
var startChunkIndex uint64 = 0
var endChunkIndex uint64 = math.MaxInt64
fromBlockNum, toBlockNum := forks.BlockRange(hardForkNumber, bp.forkHeights)
if fromBlockNum != 0 {
startChunk, chunkErr := bp.chunkOrm.GetChunkByStartBlockNumber(ctx.Copy(), fromBlockNum)
if chunkErr != nil {
log.Error("failed to get fork start chunk index", "forkName", taskCtx.HardForkName, "fromBlockNumber", fromBlockNum, "err", chunkErr)
return nil, ErrCoordinatorInternalFailure
}
if startChunk == nil {
return nil, nil
}
startChunkIndex = startChunk.Index
}
if toBlockNum != math.MaxInt64 {
toChunk, chunkErr := bp.chunkOrm.GetChunkByStartBlockNumber(ctx.Copy(), toBlockNum)
if chunkErr != nil {
log.Error("failed to get fork end chunk index", "forkName", taskCtx.HardForkName, "toBlockNumber", toBlockNum, "err", chunkErr)
return nil, ErrCoordinatorInternalFailure
}
if toChunk != nil {
// toChunk being nil only indicates that we haven't yet reached the fork boundary
// don't need change the endChunkIndex of math.MaxInt64
endChunkIndex = toChunk.Index
}
}
func (r *chunkIndexRange) contains(start, end uint64) bool {
return r.start <= start && r.end >= end+1
}
type getHardForkNameByBatchFunc func(*orm.Batch) (string, error)
func (bp *BatchProverTask) doAssignTaskWithinChunkRange(ctx *gin.Context, taskCtx *proverTaskContext,
chunkRange *chunkIndexRange, getTaskParameter *coordinatorType.GetTaskParameter, getHardForkName getHardForkNameByBatchFunc) (*coordinatorType.GetTaskSchema, error) {
startChunkIndex, endChunkIndex := chunkRange.start, chunkRange.end
maxActiveAttempts := bp.cfg.ProverManager.ProversPerSession
maxTotalAttempts := bp.cfg.ProverManager.SessionAttempts
var batchTask *orm.Batch
@@ -154,13 +138,25 @@ func (bp *BatchProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinato
}
log.Info("start batch proof generation session", "id", batchTask.Hash, "public key", taskCtx.PublicKey, "prover name", taskCtx.ProverName)
var (
proverVersion = taskCtx.ProverVersion
hardForkName = taskCtx.HardForkName
)
var err error
if getHardForkName != nil {
hardForkName, err = getHardForkName(batchTask)
if err != nil {
log.Error("failed to get version by chunk", "error", err.Error())
return nil, ErrCoordinatorInternalFailure
}
}
proverTask := orm.ProverTask{
TaskID: batchTask.Hash,
ProverPublicKey: taskCtx.PublicKey,
TaskType: int16(message.ProofTypeBatch),
ProverName: taskCtx.ProverName,
ProverVersion: taskCtx.ProverVersion,
ProverVersion: proverVersion,
ProvingStatus: int16(types.ProverAssigned),
FailureType: int16(types.ProverTaskFailureTypeUndefined),
// here why need use UTC time. see scroll/common/databased/db.go
@@ -181,7 +177,7 @@ func (bp *BatchProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinato
return nil, ErrCoordinatorInternalFailure
}
bp.batchTaskGetTaskTotal.WithLabelValues(taskCtx.HardForkName).Inc()
bp.batchTaskGetTaskTotal.WithLabelValues(hardForkName).Inc()
bp.batchTaskGetTaskProver.With(prometheus.Labels{
coordinatorType.LabelProverName: proverTask.ProverName,
coordinatorType.LabelProverPublicKey: proverTask.ProverPublicKey,
@@ -191,6 +187,108 @@ func (bp *BatchProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinato
return taskMsg, nil
}
func (bp *BatchProverTask) getChunkRangeByName(ctx *gin.Context, hardForkName string) (*chunkIndexRange, error) {
hardForkNumber, err := bp.getHardForkNumberByName(hardForkName)
if err != nil {
log.Error("batch assign failure because of the hard fork name don't exist", "fork name", hardForkName)
return nil, err
}
// if the hard fork number set, rollup relayer must generate the chunk from hard fork number,
// so the hard fork chunk's start_block_number must be ForkBlockNumber
var startChunkIndex uint64 = 0
var endChunkIndex uint64 = math.MaxInt64
fromBlockNum, toBlockNum := forks.BlockRange(hardForkNumber, bp.forkHeights)
if fromBlockNum != 0 {
startChunk, chunkErr := bp.chunkOrm.GetChunkByStartBlockNumber(ctx.Copy(), fromBlockNum)
if chunkErr != nil {
log.Error("failed to get fork start chunk index", "forkName", hardForkName, "fromBlockNumber", fromBlockNum, "err", chunkErr)
return nil, ErrCoordinatorInternalFailure
}
if startChunk == nil {
return nil, nil
}
startChunkIndex = startChunk.Index
}
if toBlockNum != math.MaxInt64 {
toChunk, chunkErr := bp.chunkOrm.GetChunkByStartBlockNumber(ctx.Copy(), toBlockNum)
if chunkErr != nil {
log.Error("failed to get fork end chunk index", "forkName", hardForkName, "toBlockNumber", toBlockNum, "err", chunkErr)
return nil, ErrCoordinatorInternalFailure
}
if toChunk != nil {
// toChunk being nil only indicates that we haven't yet reached the fork boundary
// don't need change the endChunkIndex of math.MaxInt64
endChunkIndex = toChunk.Index
}
}
fmt.Printf("%s index range %+v\n", hardForkName, &chunkIndexRange{startChunkIndex, endChunkIndex})
return &chunkIndexRange{startChunkIndex, endChunkIndex}, nil
}
func (bp *BatchProverTask) assignWithSingleCircuit(ctx *gin.Context, taskCtx *proverTaskContext, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) {
chunkRange, err := bp.getChunkRangeByName(ctx, taskCtx.HardForkName)
if err != nil {
return nil, err
}
if chunkRange == nil {
return nil, nil
}
return bp.doAssignTaskWithinChunkRange(ctx, taskCtx, chunkRange, getTaskParameter, nil)
}
func (bp *BatchProverTask) assignWithTwoCircuits(ctx *gin.Context, taskCtx *proverTaskContext, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) {
var (
hardForkNames [2]string
chunkRanges [2]*chunkIndexRange
err error
)
for i := 0; i < 2; i++ {
hardForkNames[i] = bp.reverseVkMap[getTaskParameter.VKs[i]]
chunkRanges[i], err = bp.getChunkRangeByName(ctx, hardForkNames[i])
if err != nil {
return nil, err
}
if chunkRanges[i] == nil {
return nil, nil
}
}
chunkRange := chunkRanges[0].merge(*chunkRanges[1])
var hardForkName string
getHardForkName := func(batch *orm.Batch) (string, error) {
for i := 0; i < 2; i++ {
if chunkRanges[i].contains(batch.StartChunkIndex, batch.EndChunkIndex) {
hardForkName = hardForkNames[i]
break
}
}
if hardForkName == "" {
log.Warn("get batch not belongs to any hard fork name", "batch id", batch.Index)
return "", fmt.Errorf("get batch not belongs to any hard fork name, batch id: %d", batch.Index)
}
return hardForkName, nil
}
schema, err := bp.doAssignTaskWithinChunkRange(ctx, taskCtx, chunkRange, getTaskParameter, getHardForkName)
if schema != nil && err == nil {
schema.HardForkName = hardForkName
return schema, nil
}
return schema, err
}
// Assign load and assign batch tasks
func (bp *BatchProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) {
taskCtx, err := bp.checkParameter(ctx, getTaskParameter)
if err != nil || taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter failed, error:%w", err)
}
if len(getTaskParameter.VKs) > 0 {
return bp.assignWithTwoCircuits(ctx, taskCtx, getTaskParameter)
}
return bp.assignWithSingleCircuit(ctx, taskCtx, getTaskParameter)
}
func (bp *BatchProverTask) formatProverTask(ctx context.Context, task *orm.ProverTask) (*coordinatorType.GetTaskSchema, error) {
// get chunk from db
chunks, err := bp.chunkOrm.GetChunksByBatchHash(ctx, task.TaskID)

View File

@@ -39,6 +39,7 @@ func NewChunkProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *go
cp := &ChunkProverTask{
BaseProverTask: BaseProverTask{
vkMap: vkMap,
reverseVkMap: reverseMap(vkMap),
db: db,
cfg: cfg,
nameForkMap: nameForkMap,
@@ -61,20 +62,11 @@ func NewChunkProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *go
return cp
}
// Assign the chunk proof which need to prove
func (cp *ChunkProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) {
taskCtx, err := cp.checkParameter(ctx, getTaskParameter)
if err != nil || taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter failed, error:%w", err)
}
type getHardForkNameByChunkFunc func(*orm.Chunk) (string, error)
hardForkNumber, err := cp.getHardForkNumberByName(taskCtx.HardForkName)
if err != nil {
log.Error("chunk assign failure because of the hard fork name don't exist", "fork name", taskCtx.HardForkName)
return nil, err
}
fromBlockNum, toBlockNum := forks.BlockRange(hardForkNumber, cp.forkHeights)
func (cp *ChunkProverTask) doAssignTaskWithinBlockRange(ctx *gin.Context, taskCtx *proverTaskContext,
blockRange *blockRange, getTaskParameter *coordinatorType.GetTaskParameter, getHardForkName getHardForkNameByChunkFunc) (*coordinatorType.GetTaskSchema, error) {
fromBlockNum, toBlockNum := blockRange.from, blockRange.to
if toBlockNum > getTaskParameter.ProverHeight {
toBlockNum = getTaskParameter.ProverHeight + 1
}
@@ -127,13 +119,25 @@ func (cp *ChunkProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinato
}
log.Info("start chunk generation session", "id", chunkTask.Hash, "public key", taskCtx.PublicKey, "prover name", taskCtx.ProverName)
var (
proverVersion = taskCtx.ProverVersion
hardForkName = taskCtx.HardForkName
)
var err error
if getHardForkName != nil {
hardForkName, err = getHardForkName(chunkTask)
if err != nil {
log.Error("failed to get version by chunk", "error", err.Error())
return nil, ErrCoordinatorInternalFailure
}
}
proverTask := orm.ProverTask{
TaskID: chunkTask.Hash,
ProverPublicKey: taskCtx.PublicKey,
TaskType: int16(message.ProofTypeChunk),
ProverName: taskCtx.ProverName,
ProverVersion: taskCtx.ProverVersion,
ProverVersion: proverVersion,
ProvingStatus: int16(types.ProverAssigned),
FailureType: int16(types.ProverTaskFailureTypeUndefined),
// here why need use UTC time. see scroll/common/databased/db.go
@@ -153,7 +157,7 @@ func (cp *ChunkProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinato
return nil, ErrCoordinatorInternalFailure
}
cp.chunkTaskGetTaskTotal.WithLabelValues(taskCtx.HardForkName).Inc()
cp.chunkTaskGetTaskTotal.WithLabelValues(hardForkName).Inc()
cp.chunkTaskGetTaskProver.With(prometheus.Labels{
coordinatorType.LabelProverName: proverTask.ProverName,
coordinatorType.LabelProverPublicKey: proverTask.ProverPublicKey,
@@ -163,6 +167,96 @@ func (cp *ChunkProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinato
return taskMsg, nil
}
func (cp *ChunkProverTask) assignWithSingleCircuit(ctx *gin.Context, taskCtx *proverTaskContext, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) {
blockRange, err := cp.getBlockRangeByName(taskCtx.HardForkName)
if err != nil {
return nil, err
}
return cp.doAssignTaskWithinBlockRange(ctx, taskCtx, blockRange, getTaskParameter, nil)
}
func (cp *ChunkProverTask) assignWithTwoCircuits(ctx *gin.Context, taskCtx *proverTaskContext, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) {
var (
hardForkNames [2]string
blockRanges [2]*blockRange
err error
)
for i := 0; i < 2; i++ {
hardForkNames[i] = cp.reverseVkMap[getTaskParameter.VKs[i]]
blockRanges[i], err = cp.getBlockRangeByName(hardForkNames[i])
if err != nil {
return nil, err
}
}
blockRange, err := blockRanges[0].merge(*blockRanges[1])
if err != nil {
return nil, err
}
var hardForkName string
getHardForkName := func(chunk *orm.Chunk) (string, error) {
for i := 0; i < 2; i++ {
if blockRanges[i].contains(chunk.StartBlockNumber, chunk.EndBlockNumber) {
hardForkName = hardForkNames[i]
break
}
}
if hardForkName == "" {
log.Warn("get chunk not belongs to any hard fork name", "chunk id", chunk.Index)
return "", fmt.Errorf("get chunk not belongs to any hard fork name, chunk id: %d", chunk.Index)
}
return hardForkName, nil
}
schema, err := cp.doAssignTaskWithinBlockRange(ctx, taskCtx, blockRange, getTaskParameter, getHardForkName)
if schema != nil && err == nil {
schema.HardForkName = hardForkName
return schema, nil
}
return schema, err
}
type blockRange struct {
from uint64
to uint64
}
func (r *blockRange) merge(o blockRange) (*blockRange, error) {
if r.from == o.to {
return &blockRange{o.from, r.to}, nil
} else if r.to == o.from {
return &blockRange{r.from, o.to}, nil
}
return nil, fmt.Errorf("two ranges are not adjacent")
}
func (r *blockRange) contains(start, end uint64) bool {
return r.from <= start && r.to >= end+1
}
func (cp *ChunkProverTask) getBlockRangeByName(hardForkName string) (*blockRange, error) {
hardForkNumber, err := cp.getHardForkNumberByName(hardForkName)
if err != nil {
log.Error("chunk assign failure because of the hard fork name don't exist", "fork name", hardForkName)
return nil, err
}
fromBlockNum, toBlockNum := forks.BlockRange(hardForkNumber, cp.forkHeights)
return &blockRange{fromBlockNum, toBlockNum}, nil
}
// Assign the chunk proof which need to prove
func (cp *ChunkProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) {
taskCtx, err := cp.checkParameter(ctx, getTaskParameter)
if err != nil || taskCtx == nil {
return nil, fmt.Errorf("check prover task parameter failed, error:%w", err)
}
if len(getTaskParameter.VKs) > 0 {
return cp.assignWithTwoCircuits(ctx, taskCtx, getTaskParameter)
}
return cp.assignWithSingleCircuit(ctx, taskCtx, getTaskParameter)
}
func (cp *ChunkProverTask) formatProverTask(ctx context.Context, task *orm.ProverTask) (*coordinatorType.GetTaskSchema, error) {
// Get block hashes.
blockHashes, dbErr := cp.blockOrm.GetL2BlockHashesByChunkHash(ctx, task.TaskID)

View File

@@ -29,14 +29,25 @@ type ProverTask interface {
Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error)
}
func reverseMap(input map[string]string) map[string]string {
output := make(map[string]string, len(input))
for k, v := range input {
output[v] = k
}
return output
}
// BaseProverTask a base prover task which contain series functions
type BaseProverTask struct {
cfg *config.Config
db *gorm.DB
vkMap map[string]string
nameForkMap map[string]uint64
forkHeights []uint64
// key is hardForkName, value is vk
vkMap map[string]string
// key is vk, value is hardForkName
reverseVkMap map[string]string
nameForkMap map[string]uint64
forkHeights []uint64
batchOrm *orm.Batch
chunkOrm *orm.Chunk
@@ -74,38 +85,56 @@ func (b *BaseProverTask) checkParameter(ctx *gin.Context, getTaskParameter *coor
}
ptc.ProverVersion = proverVersion.(string)
hardForkName, hardForkNameExist := ctx.Get(coordinatorType.HardForkName)
if !hardForkNameExist {
return nil, fmt.Errorf("get hard fork name from context failed")
}
ptc.HardForkName = hardForkName.(string)
if !version.CheckScrollRepoVersion(proverVersion.(string), b.cfg.ProverManager.MinProverVersion) {
return nil, fmt.Errorf("incompatible prover version. please upgrade your prover, minimum allowed version: %s, actual version: %s", b.cfg.ProverManager.MinProverVersion, proverVersion.(string))
}
vk, vkExist := b.vkMap[ptc.HardForkName]
if !vkExist {
return nil, fmt.Errorf("can't get vk for hard fork:%s, vkMap:%v", ptc.HardForkName, b.vkMap)
}
// if the prover has a different vk
if getTaskParameter.VK != vk {
log.Error("vk inconsistency", "prover vk", getTaskParameter.VK, "vk", vk, "hardForkName", ptc.HardForkName)
// if the prover reports a different prover version
if !version.CheckScrollProverVersion(proverVersion.(string)) {
return nil, fmt.Errorf("incompatible prover version. please upgrade your prover, expect version: %s, actual version: %s", version.Version, proverVersion.(string))
// signals that the prover is multi-circuits version
if len(getTaskParameter.VKs) > 0 {
if len(getTaskParameter.VKs) != 2 {
return nil, fmt.Errorf("parameter vks length must be 2")
}
// min prover version supporting multi circuits, maybe put it to config file?
var minMultiCircuitsProverVersion = "v4.4.7"
if !version.CheckScrollRepoVersion(ptc.ProverVersion, minMultiCircuitsProverVersion) {
return nil, fmt.Errorf("incompatible prover version. please upgrade your prover, minimum allowed version: %s, actual version: %s", minMultiCircuitsProverVersion, ptc.ProverVersion)
}
for _, vk := range getTaskParameter.VKs {
fmt.Printf("%+v\n", b.reverseVkMap)
if _, exists := b.reverseVkMap[vk]; !exists {
return nil, fmt.Errorf("incompatible vk. vk %s is invalid", vk)
}
}
} else {
hardForkName, hardForkNameExist := ctx.Get(coordinatorType.HardForkName)
if !hardForkNameExist {
return nil, fmt.Errorf("get hard fork name from context failed")
}
ptc.HardForkName = hardForkName.(string)
vk, vkExist := b.vkMap[ptc.HardForkName]
if !vkExist {
return nil, fmt.Errorf("can't get vk for hard fork:%s, vkMap:%v", ptc.HardForkName, b.vkMap)
}
// if the prover has a different vk
if getTaskParameter.VK != vk {
log.Error("vk inconsistency", "prover vk", getTaskParameter.VK, "vk", vk, "hardForkName", ptc.HardForkName)
// if the prover reports a different prover version
if !version.CheckScrollProverVersion(proverVersion.(string)) {
return nil, fmt.Errorf("incompatible prover version. please upgrade your prover, expect version: %s, actual version: %s", version.Version, proverVersion.(string))
}
// if the prover reports a same prover version
return nil, fmt.Errorf("incompatible vk. please check your params files or config files")
}
// if the prover reports a same prover version
return nil, fmt.Errorf("incompatible vk. please check your params files or config files")
}
isBlocked, err := b.proverBlockListOrm.IsPublicKeyBlocked(ctx.Copy(), publicKey.(string))
if err != nil {
return nil, fmt.Errorf("failed to check whether the public key %s is blocked before assigning a chunk task, err: %w, proverName: %s, proverVersion: %s", publicKey, err, proverName, proverVersion)
return nil, fmt.Errorf("failed to check whether the public key %s is blocked before assigning a chunk task, err: %w, proverName: %s", publicKey, err, proverName)
}
if isBlocked {
return nil, fmt.Errorf("public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", publicKey, proverName, proverVersion)
return nil, fmt.Errorf("public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", publicKey, proverName, ptc.ProverVersion)
}
isAssigned, err := b.proverTaskOrm.IsProverAssigned(ctx.Copy(), publicKey.(string))
@@ -114,7 +143,7 @@ func (b *BaseProverTask) checkParameter(ctx *gin.Context, getTaskParameter *coor
}
if isAssigned {
return nil, fmt.Errorf("prover with publicKey %s is already assigned a task. ProverName: %s, ProverVersion: %s", publicKey, proverName, proverVersion)
return nil, fmt.Errorf("prover with publicKey %s is already assigned a task. ProverName: %s", publicKey, proverName)
}
return &ptc, nil
}

View File

@@ -134,7 +134,12 @@ func (m *ProofReceiverLogic) HandleZkProof(ctx *gin.Context, proofMsg *message.P
if len(pv) == 0 {
return fmt.Errorf("get ProverVersion from context failed")
}
hardForkName := ctx.GetString(coordinatorType.HardForkName)
// use hard_fork_name from parameter first
// if prover support multi hard_forks, the real hard_fork_name is not set to the gin context
hardForkName := proofParameter.HardForkName
if hardForkName == "" {
hardForkName = ctx.GetString(coordinatorType.HardForkName)
}
var proverTask *orm.ProverTask
var err error

View File

@@ -12,17 +12,17 @@ import (
func NewVerifier(cfg *config.VerifierConfig) (*Verifier, error) {
batchVKMap := map[string]string{
"shanghai": "",
"bernoulli": "",
"london": "",
"istanbul": "",
"bernoulli": "bernoulli",
"london": "london",
"istanbul": "istanbul",
"homestead": "",
"eip155": "",
}
chunkVKMap := map[string]string{
"shanghai": "",
"bernoulli": "",
"london": "",
"istanbul": "",
"bernoulli": "bernoulli",
"london": "london",
"istanbul": "istanbul",
"homestead": "",
"eip155": "",
}

View File

@@ -164,8 +164,8 @@ func (v *Verifier) loadEmbedVK() error {
return err
}
v.BatchVKMap["shanghai"] = base64.StdEncoding.EncodeToString(batchVKBytes)
v.ChunkVKMap["shanghai"] = base64.StdEncoding.EncodeToString(chunkVkBytes)
v.BatchVKMap["bernoulli"] = base64.StdEncoding.EncodeToString(batchVKBytes)
v.ChunkVKMap["bernoulli"] = base64.StdEncoding.EncodeToString(chunkVkBytes)
v.BatchVKMap[""] = base64.StdEncoding.EncodeToString(batchVKBytes)
v.ChunkVKMap[""] = base64.StdEncoding.EncodeToString(chunkVkBytes)
return nil

View File

@@ -32,7 +32,7 @@ func v1(router *gin.RouterGroup, conf *config.Config) {
r.POST("/login", challengeMiddleware.MiddlewareFunc(), loginMiddleware.LoginHandler)
// need jwt token api
r.Use(loginMiddleware.MiddlewareFunc())
// r.Use(loginMiddleware.MiddlewareFunc())
{
r.POST("/get_task", api.GetTask.GetTasks)
r.POST("/submit_proof", api.SubmitProof.SubmitProof)

View File

@@ -2,15 +2,17 @@ package types
// GetTaskParameter for ProverTasks request parameter
type GetTaskParameter struct {
ProverHeight uint64 `form:"prover_height" json:"prover_height"`
TaskType int `form:"task_type" json:"task_type"`
VK string `form:"vk" json:"vk"`
ProverHeight uint64 `form:"prover_height" json:"prover_height"`
TaskType int `form:"task_type" json:"task_type"`
VK string `form:"vk" json:"vk"`
VKs []string `form:"vks" json:"vks"`
}
// GetTaskSchema the schema data return to prover for get prover task
type GetTaskSchema struct {
UUID string `json:"uuid"`
TaskID string `json:"task_id"`
TaskType int `json:"task_type"`
TaskData string `json:"task_data"`
UUID string `json:"uuid"`
TaskID string `json:"task_id"`
TaskType int `json:"task_type"`
TaskData string `json:"task_data"`
HardForkName string `json:"hard_fork_name"`
}

View File

@@ -3,11 +3,12 @@ package types
// SubmitProofParameter the SubmitProof api request parameter
type SubmitProofParameter struct {
// TODO when prover have upgrade, need change this field to required
UUID string `form:"uuid" json:"uuid"`
TaskID string `form:"task_id" json:"task_id" binding:"required"`
TaskType int `form:"task_type" json:"task_type" binding:"required"`
Status int `form:"status" json:"status"`
Proof string `form:"proof" json:"proof"`
FailureType int `form:"failure_type" json:"failure_type"`
FailureMsg string `form:"failure_msg" json:"failure_msg"`
UUID string `form:"uuid" json:"uuid"`
TaskID string `form:"task_id" json:"task_id" binding:"required"`
TaskType int `form:"task_type" json:"task_type" binding:"required"`
Status int `form:"status" json:"status"`
Proof string `form:"proof" json:"proof"`
FailureType int `form:"failure_type" json:"failure_type"`
FailureMsg string `form:"failure_msg" json:"failure_msg"`
HardForkName string `form:"hard_fork_name" json:"hard_fork_name"`
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/scroll-tech/go-ethereum/log"
"github.com/scroll-tech/go-ethereum/params"
"github.com/stretchr/testify/assert"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"scroll-tech/common/testcontainers"
@@ -26,6 +27,7 @@ import (
"scroll-tech/common/version"
"scroll-tech/database/migrate"
cutils "scroll-tech/common/utils"
"scroll-tech/coordinator/internal/config"
"scroll-tech/coordinator/internal/controller/api"
"scroll-tech/coordinator/internal/controller/cron"
@@ -80,6 +82,25 @@ func randomURL() string {
return fmt.Sprintf("localhost:%d", 10000+2000+id.Int64())
}
func useLocalDB(dsn string) *gorm.DB {
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
// Logger: &tmpGormLogger,
NowFunc: func() time.Time {
// why set time to UTC.
// if now set this, the inserted data time will use local timezone. like 2023-07-18 18:24:00 CST+8
// but when inserted, store to postgres is 2023-07-18 18:24:00 UTC+0 the timezone is incorrect.
// As mysql dsn user:pass@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local, we cant set
// the timezone by loc=Local. but postgres's dsn don't have loc option to set timezone, so just need set the gorm option like that.
return cutils.NowUTC()
},
})
if err != nil {
fmt.Println("failed to init db", err.Error())
panic(err.Error())
}
return db
}
func setupCoordinator(t *testing.T, proversPerSession uint8, coordinatorURL string, nameForkMap map[string]int64) (*cron.Collector, *http.Server) {
var err error
db, err = testApps.GetGormDBClient()
@@ -200,14 +221,14 @@ func TestApis(t *testing.T) {
// Set up the test environment.
setEnv(t)
t.Run("TestHandshake", testHandshake)
t.Run("TestFailedHandshake", testFailedHandshake)
t.Run("TestGetTaskBlocked", testGetTaskBlocked)
t.Run("TestOutdatedProverVersion", testOutdatedProverVersion)
t.Run("TestValidProof", testValidProof)
t.Run("TestInvalidProof", testInvalidProof)
t.Run("TestProofGeneratedFailed", testProofGeneratedFailed)
t.Run("TestTimeoutProof", testTimeoutProof)
// t.Run("TestHandshake", testHandshake)
// t.Run("TestFailedHandshake", testFailedHandshake)
// t.Run("TestGetTaskBlocked", testGetTaskBlocked)
// t.Run("TestOutdatedProverVersion", testOutdatedProverVersion)
// t.Run("TestValidProof", testValidProof)
// t.Run("TestInvalidProof", testInvalidProof)
// t.Run("TestProofGeneratedFailed", testProofGeneratedFailed)
// t.Run("TestTimeoutProof", testTimeoutProof)
t.Run("TestHardFork", testHardForkAssignTask)
}
@@ -477,6 +498,106 @@ func testHardForkAssignTask(t *testing.T) {
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
coordinatorURL := randomURL()
collector, httpHandler := setupCoordinator(t, 3, coordinatorURL, tt.forkNumbers)
defer func() {
collector.Stop()
assert.NoError(t, httpHandler.Shutdown(context.Background()))
}()
chunkProof := &message.ChunkProof{
StorageTrace: []byte("testStorageTrace"),
Protocol: []byte("testProtocol"),
Proof: []byte("testProof"),
Instances: []byte("testInstance"),
Vk: []byte("testVk"),
ChunkInfo: nil,
}
// the insert block number is 2 and 3
// chunk1 batch1 contains block number 2
// chunk2 batch2 contains block number 3
err := l2BlockOrm.InsertL2Blocks(context.Background(), []*encoding.Block{block1, block2})
assert.NoError(t, err)
dbHardForkChunk1, err := chunkOrm.InsertChunk(context.Background(), hardForkChunk1)
assert.NoError(t, err)
err = l2BlockOrm.UpdateChunkHashInRange(context.Background(), 0, 2, dbHardForkChunk1.Hash)
assert.NoError(t, err)
err = chunkOrm.UpdateProofAndProvingStatusByHash(context.Background(), dbHardForkChunk1.Hash, chunkProof, types.ProvingTaskUnassigned, 1)
assert.NoError(t, err)
dbHardForkBatch1, err := batchOrm.InsertBatch(context.Background(), hardForkBatch1)
assert.NoError(t, err)
err = chunkOrm.UpdateBatchHashInRange(context.Background(), 0, 0, dbHardForkBatch1.Hash)
assert.NoError(t, err)
err = batchOrm.UpdateChunkProofsStatusByBatchHash(context.Background(), dbHardForkBatch1.Hash, types.ChunkProofsStatusReady)
assert.NoError(t, err)
dbHardForkChunk2, err := chunkOrm.InsertChunk(context.Background(), hardForkChunk2)
assert.NoError(t, err)
err = l2BlockOrm.UpdateChunkHashInRange(context.Background(), 3, 100, dbHardForkChunk2.Hash)
assert.NoError(t, err)
err = chunkOrm.UpdateProofAndProvingStatusByHash(context.Background(), dbHardForkChunk2.Hash, chunkProof, types.ProvingTaskUnassigned, 1)
assert.NoError(t, err)
dbHardForkBatch2, err := batchOrm.InsertBatch(context.Background(), hardForkBatch2)
assert.NoError(t, err)
err = chunkOrm.UpdateBatchHashInRange(context.Background(), 1, 1, dbHardForkBatch2.Hash)
assert.NoError(t, err)
err = batchOrm.UpdateChunkProofsStatusByBatchHash(context.Background(), dbHardForkBatch2.Hash, types.ChunkProofsStatusReady)
assert.NoError(t, err)
fmt.Println("data inserted")
time.Sleep(50 * time.Second)
getTaskNumber := 0
for i := 0; i < 2; i++ {
mockProver := newMockProver(t, fmt.Sprintf("mock_prover_%d", i), coordinatorURL, tt.proofType, version.Version)
proverTask, errCode, errMsg := mockProver.getProverTask(t, tt.proofType, tt.proverForkNames[i])
assert.Equal(t, tt.exceptGetTaskErrCodes[i], errCode)
assert.Equal(t, tt.exceptGetTaskErrMsgs[i], errMsg)
if errCode != types.Success {
continue
}
getTaskNumber++
mockProver.submitProof(t, proverTask, verifiedSuccess, types.Success, tt.proverForkNames[i])
}
assert.Equal(t, getTaskNumber, tt.exceptTaskNumber)
})
}
}
func testHardForkAssignTaskMultiCircuits(t *testing.T) {
tests := []struct {
name string
proofType message.ProofType
forkNumbers map[string]int64
proverForkNames []string
exceptTaskNumber int
exceptGetTaskErrCodes []int
exceptGetTaskErrMsgs []string
}{
{ // hard fork 4, prover 4 block [2-3]
name: "noTaskForkChunkProverVersionLargeOrEqualThanHardFork",
proofType: message.ProofTypeChunk,
forkNumbers: map[string]int64{"bernoulli": forkNumberFour},
exceptTaskNumber: 0,
proverForkNames: []string{"bernoulli", "bernoulli"},
exceptGetTaskErrCodes: []int{types.ErrCoordinatorEmptyProofData, types.ErrCoordinatorEmptyProofData},
exceptGetTaskErrMsgs: []string{"get empty prover task", "get empty prover task"},
},
{
name: "noTaskForkBatchProverVersionLargeOrEqualThanHardFork",
proofType: message.ProofTypeBatch,
forkNumbers: map[string]int64{"bernoulli": forkNumberFour},
exceptTaskNumber: 0,
proverForkNames: []string{"bernoulli", "bernoulli"},
exceptGetTaskErrCodes: []int{types.ErrCoordinatorEmptyProofData, types.ErrCoordinatorEmptyProofData},
exceptGetTaskErrMsgs: []string{"get empty prover task", "get empty prover task"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
coordinatorURL := randomURL()

View File

@@ -0,0 +1,112 @@
package main
import (
"context"
"errors"
"math/big"
"net/http"
"scroll-tech/common/database"
"scroll-tech/common/version"
"scroll-tech/coordinator/internal/config"
"scroll-tech/coordinator/internal/controller/api"
"scroll-tech/coordinator/internal/controller/cron"
"scroll-tech/coordinator/internal/route"
"time"
"github.com/gin-gonic/gin"
"github.com/scroll-tech/go-ethereum/params"
"gorm.io/gorm"
)
// GetGormDBClient returns a gorm.DB by connecting to the running postgres container
func GetGormDBClient() (*gorm.DB, error) {
// endpoint, err := t.GetDBEndPoint()
// if err != nil {
// return nil, err
// }
endpoint := "postgres://lmr:@localhost:5432/unittest?sslmode=disable"
dbCfg := &database.Config{
DSN: endpoint,
DriverName: "postgres",
MaxOpenNum: 200,
MaxIdleNum: 20,
}
return database.InitDB(dbCfg)
}
func setupCoordinator(proversPerSession uint8, coordinatorURL string, nameForkMap map[string]int64) (*cron.Collector, *http.Server) {
db, err := GetGormDBClient()
if err != nil {
panic(err.Error())
}
tokenTimeout := 6
conf := &config.Config{
L2: &config.L2{
ChainID: 111,
},
ProverManager: &config.ProverManager{
ProversPerSession: proversPerSession,
Verifier: &config.VerifierConfig{
MockMode: true,
},
BatchCollectionTimeSec: 10,
ChunkCollectionTimeSec: 10,
MaxVerifierWorkers: 10,
SessionAttempts: 5,
MinProverVersion: version.Version,
},
Auth: &config.Auth{
ChallengeExpireDurationSec: tokenTimeout,
LoginExpireDurationSec: tokenTimeout,
},
}
var chainConf params.ChainConfig
for forkName, forkNumber := range nameForkMap {
switch forkName {
case "shanghai":
chainConf.ShanghaiBlock = big.NewInt(forkNumber)
case "bernoulli":
chainConf.BernoulliBlock = big.NewInt(forkNumber)
case "london":
chainConf.LondonBlock = big.NewInt(forkNumber)
case "istanbul":
chainConf.IstanbulBlock = big.NewInt(forkNumber)
case "homestead":
chainConf.HomesteadBlock = big.NewInt(forkNumber)
case "eip155":
chainConf.EIP155Block = big.NewInt(forkNumber)
}
}
proofCollector := cron.NewCollector(context.Background(), db, conf, nil)
router := gin.New()
api.InitController(conf, &chainConf, db, nil)
route.Route(router, conf, nil)
srv := &http.Server{
Addr: coordinatorURL,
Handler: router,
}
go func() {
runErr := srv.ListenAndServe()
if runErr != nil && !errors.Is(runErr, http.ErrServerClosed) {
panic(runErr.Error())
}
}()
time.Sleep(time.Second * 2)
return proofCollector, srv
}
func main() {
coordinatorURL := ":9091"
nameForkMap := map[string]int64{"london": 2,
"istanbul": 3,
"bernoulli": 4}
setupCoordinator(1, coordinatorURL, nameForkMap)
var c = make(chan struct{}, 1)
_ = <-c
}