finished wit_gen and obtained TraceWitness object from fibo example

This commit is contained in:
Steve Wang
2023-07-07 19:43:06 +08:00
parent df3284624d
commit 3bf005c7ff
2 changed files with 39 additions and 22 deletions

View File

@@ -1,7 +1,11 @@
from __future__ import annotations
import pprint
from typing import Any, Tuple
from py_ecc import bn128
from pychiquito import CircuitContext, StepTypeContext, StepTypeSetupContext, StepTypeWGHandler, StepTypeHandler, Constraint, Queriable
from pychiquito import CircuitContext, StepTypeContext, StepTypeSetupContext, StepTypeWGHandler, StepTypeHandler, Constraint, Queriable, TraceContext, StepInstance, TraceGenerator
F = bn128.FQ
class Fibonacci(CircuitContext):
def __init__(self: Fibonacci):
@@ -19,7 +23,7 @@ class Fibonacci(CircuitContext):
self.fibo_last_step: StepTypeWGHandler = self.step_type_def(FiboLastStep(self, fibo_last_step))
def trace(self: Fibonacci):
def trace_def(ctx: TraceContext, values: TraceArgs): # TODO: Complete wit_gen.py and update.
def trace_def(ctx: TraceContext, _: Any): # Any instead of TraceArgs
ctx.add(self.fibo_step, (1, 1))
a = 1
b = 2
@@ -34,35 +38,44 @@ class Fibonacci(CircuitContext):
class FiboStep(StepTypeContext):
def __init__(self: FiboStep, circuit: Fibonacci, handler: StepTypeHandler):
super().__init__(handler) # Pass the id and annotation of handler to a new StepTypeContext instance.
c = self.internal("c")
self.c = self.internal("c") # `self.c` is required instead of `c`, because wg needs to access `self.c`.
def setup_def(ctx: StepTypeSetupContext):
ctx.constr(Constraint.eq(circuit.a + circuit.b, c))
ctx.constr(Constraint.eq(circuit.a + circuit.b, self.c))
ctx.transition(Constraint.eq(circuit.b, circuit.a.next()))
ctx.transition(Constraint.eq(c, circuit.b.next()))
ctx.transition(Constraint.eq(self.c, circuit.b.next()))
self.setup(setup_def)
def wg(self: FiboStep) -> StepTypeWGHandler:
def wg_def(ctx: StepInstance, values: Args):
# TODO: Implement after wit_gen.py is completed.
pass
def wg(self: FiboStep, circuit: Fibonacci) -> StepTypeWGHandler:
def wg_def(ctx: StepInstance, values: Tuple[int, int]): # Any instead of Args
a_value, b_value = values
print(f"fib step wg: {a_value}, {b_value}, {a_value + b_value}")
ctx.assign(circuit.a, F(a_value))
ctx.assign(circuit.b, F(b_value))
ctx.assign(self.c, F(a_value + b_value))
return super().wg(wg_def)
class FiboLastStep(StepTypeContext):
def __init__(self: FiboStep, circuit: Fibonacci, handler: StepTypeHandler):
super().__init__(handler)
c = self.internal("c")
self.c = self.internal("c")
def setup_def(ctx: StepTypeSetupContext):
ctx.constr(Constraint.eq(circuit.a + circuit.b, c))
ctx.constr(Constraint.eq(circuit.a + circuit.b, self.c))
self.setup(setup_def)
def wg(self: FiboLastStep) -> StepTypeWGHandler:
def wg_def(ctx: StepInstance, values: Args):
# TODO: Implement after wit_gen.py is completed.
pass
return super().wg(wg_def)
def wg(self: FiboLastStep, circuit: Fibonacci) -> StepTypeWGHandler:
def wg_def(ctx: StepInstance, values: Tuple[int, int]): # Any instead of Args
a_value, b_value = values
print(f"fib last step wg: {a_value}, {b_value}, {a_value + b_value}\n")
ctx.assign(circuit.a, F(a_value))
ctx.assign(circuit.b, F(b_value))
ctx.assign(self.c, F(a_value + b_value))
return super().wg(wg_def)
fibo = Fibonacci()
# pprint.pprint(fibo.circuit)
print(fibo.circuit) # Print ast::Circuit.
fibo.trace()
trace_generator = TraceGenerator(fibo.circuit.trace)
print(trace_generator.generate(None)) # Print TraceWitness

View File

@@ -49,10 +49,10 @@ class CircuitContext:
# StepTypeContext is generated by initialising a custom-defined step type class.
# The step type class should have a wg function that returns a StepTypeWGHandler.
def step_type_def(
self: CircuitContext, context: StepTypeContext
self: CircuitContext, step_type_context: StepTypeContext
) -> StepTypeWGHandler:
self.circuit.add_step_type_def(context.step_type)
return context.wg()
self.circuit.add_step_type_def(step_type_context.step_type)
return step_type_context.wg(self)
# def step_type_def(self: CircuitContext, step: StepTypeDefInput, step_type_func: Callable[[StepTypeContext], StepTypeWGHandler]) -> StepTypeWGHandler:
# match step:
@@ -423,7 +423,7 @@ class Circuit:
# halo2_fixed: List[ImportedHalo2Fixed] = field(default_factory=list)
exposed: List[ForwardSignal] = field(default_factory=list)
annotations: Dict[int, str] = field(default_factory=dict)
trace: Optional[Callable] = None
trace: Optional[Callable[[TraceContext, Any], None]] = None
fixed_gen: Optional[Callable] = None
first_step: Optional[int] = None
last_step: Optional[int] = None
@@ -999,6 +999,10 @@ class Queriable:
def __pow__(self: Queriable, rhs: int) -> Expr:
lhs = to_expr(self)
return Expr(Pow(lhs, rhs))
# __hash__ method is required, because Queriable is used as a key in the assignment dictionary.
def __hash__(self: Queriable):
return hash(self.uuid())
################
@@ -1087,10 +1091,10 @@ class TraceWitness:
@dataclass
class TraceContext:
witness: TraceWitness = TraceWitness()
witness: TraceWitness = field(default_factory=TraceWitness)
def add(self: TraceContext, step: StepTypeWGHandler, args: Any):
witness = StepInstance.new(step.uuid)
witness = StepInstance.new(step.id)
step.wg(witness, args)
self.witness.step_instances.append(witness)