Files
scroll/roller/prover/prover.go
Lawliet-Chan f2c8b733c3 fix(libzkp):fix libzkp.so path (#486)
Co-authored-by: xinran chen <lawliet@xinran-m1x.local>
Co-authored-by: HAOYUatHZ <37070449+HAOYUatHZ@users.noreply.github.com>
2023-05-17 13:27:42 +08:00

105 lines
2.5 KiB
Go

//go:build !mock_prover
package prover
/*
#cgo LDFLAGS: -L${SRCDIR}/lib/ -lzkp -lm -ldl -lzktrie -L${SRCDIR}/lib/ -Wl,-rpath=${SRCDIR}/lib
#cgo gpu LDFLAGS: -L${SRCDIR}/lib/ -lzkp -lm -ldl -lgmp -lstdc++ -lprocps -lzktrie -L/usr/local/cuda/lib64/ -L${SRCDIR}/lib/ -lcudart -Wl,-rpath=${SRCDIR}/lib
#include <stdlib.h>
#include "./lib/libzkp.h"
*/
import "C" //nolint:typecheck
import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"unsafe"

	"github.com/scroll-tech/go-ethereum/log"

	"scroll-tech/common/types/message"
	"scroll-tech/roller/config"
)
// Prover forwards block traces to the rust prover over FFI and hands back the
// resulting zk proof.
type Prover struct {
	// cfg carries the prover settings: params/seed paths, the optional proof
	// dump directory, and which kind of proof to produce.
	cfg *config.ProverConfig
}
// NewProver initializes the rust prover through FFI and returns a Prover bound
// to cfg.
//
// When cfg.DumpDir is set, the directory is created (with os.ModePerm, subject
// to umask) before C.init_prover is invoked, so a bad dump location is reported
// without first paying for prover initialization. The returned error wraps the
// underlying os error.
func NewProver(cfg *config.ProverConfig) (*Prover, error) {
	if cfg.DumpDir != "" {
		if err := os.MkdirAll(cfg.DumpDir, os.ModePerm); err != nil {
			return nil, fmt.Errorf("create dump dir %q: %w", cfg.DumpDir, err)
		}
		log.Info("Enabled dump_proof", "dir", cfg.DumpDir)
	}

	paramsPathStr := C.CString(cfg.ParamsPath)
	defer C.free(unsafe.Pointer(paramsPathStr))
	seedPathStr := C.CString(cfg.SeedPath)
	defer C.free(unsafe.Pointer(seedPathStr))

	// The C strings are freed when NewProver returns; this assumes init_prover
	// copies whatever it needs from them (same assumption as before) — confirm
	// against libzkp.h.
	C.init_prover(paramsPathStr, seedPathStr)

	return &Prover{cfg: cfg}, nil
}
// Prove generates a proof for task via the rust prover over FFI, optionally
// dumps the raw proof bytes to cfg.DumpDir, and decodes them into an AggProof.
//
// Only message.BasicProve is implemented; message.AggregatorProve and any
// unknown prove type return an error instead of silently producing an empty
// proof. A failure to dump the proof is logged but does not fail the call.
// No retry is attempted. On any error the returned proof is nil.
func (p *Prover) Prove(task *message.TaskMsg) (*message.AggProof, error) {
	var proofByt []byte
	switch p.cfg.ProveType {
	case message.BasicProve:
		tracesByt, err := json.Marshal(task.Traces)
		if err != nil {
			return nil, err
		}
		proofByt = p.prove(tracesByt)
	case message.AggregatorProve:
		// TODO: aggregator prove
		return nil, fmt.Errorf("prove type %v (aggregator) is not implemented yet", p.cfg.ProveType)
	default:
		return nil, fmt.Errorf("unsupported prove type: %v", p.cfg.ProveType)
	}

	// Dump before decoding so a malformed proof can still be inspected on disk.
	if err := p.dumpProof(task.ID, proofByt); err != nil {
		log.Error("Dump proof failed", "task-id", task.ID, "error", err)
	}

	zkProof := &message.AggProof{}
	if err := json.Unmarshal(proofByt, zkProof); err != nil {
		return nil, err
	}
	return zkProof, nil
}
// prove passes the JSON-encoded traces to the rust prover through cgo and
// returns the proof text it produces, copied into Go-managed memory.
//
// NOTE(review): the C string returned by create_agg_proof_multi is never
// released here. Whether and how it must be freed is defined by libzkp (a
// string allocated on the Rust side generally cannot be released with
// C.free) — confirm against libzkp.h.
func (p *Prover) prove(tracesByt []byte) []byte {
	cTraces := C.CString(string(tracesByt))
	defer C.free(unsafe.Pointer(cTraces))

	log.Info("Start to create agg proof ...")
	rawProof := C.create_agg_proof_multi(cTraces)
	log.Info("Finish creating agg proof!")

	return []byte(C.GoString(rawProof))
}
// dumpProof writes proofByt to the file <cfg.DumpDir>/<id>, creating or
// truncating it. It is a no-op when DumpDir is empty.
//
// The file is closed on every path; a write error takes precedence, otherwise
// the error from Close is returned rather than dropped.
//
// NOTE(review): id is joined into the path unchecked. If task IDs can come
// from an untrusted source, an id containing ".." or path separators could
// write outside DumpDir — confirm where task IDs originate.
func (p *Prover) dumpProof(id string, proofByt []byte) error {
	if p.cfg.DumpDir == "" {
		return nil
	}
	path := filepath.Join(p.cfg.DumpDir, id)
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	log.Info("Saving proof", "task-id", id)
	if _, err := f.Write(proofByt); err != nil {
		_ = f.Close() // the write error is the one worth reporting
		return err
	}
	return f.Close()
}