updated tracecontext and cleaned up debug methods

This commit is contained in:
Steve Wang
2023-07-28 14:25:35 +08:00
parent daf1201f2f
commit 59ecf65daa
4 changed files with 9 additions and 67 deletions

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Callable, List, Dict, Optional, Any, Tuple
from dataclasses import dataclass, field, asdict
from wit_gen import TraceContext, FixedGenContext, StepInstance
from wit_gen import FixedGenContext, StepInstance
from expr import Expr
from util import uuid
from query import Queriable

View File

@@ -1,13 +1,12 @@
from __future__ import annotations
from enum import Enum
from typing import Callable, Any
from dataclasses import dataclass
import rust_chiquito # rust bindings
import json
from chiquito_ast import ASTCircuit, ASTStepType, ExposeOffset
from query import Internal, Forward, Queriable, Shared, Fixed
from wit_gen import FixedGenContext, TraceContext, StepInstance, TraceWitness
from wit_gen import FixedGenContext, StepInstance, TraceWitness
from cb import Constraint, Typing, ToConstraint, to_constraint
from util import CustomEncoder, F
@@ -21,7 +20,7 @@ class CircuitMode(Enum):
class Circuit:
def __init__(self: Circuit):
self.ast = ASTCircuit()
self.trace_context = TraceContext()
self.witness = TraceWitness()
self.rust_ast_id = 0
self.mode = CircuitMode.SETUP
self.setup()
@@ -68,12 +67,10 @@ class Circuit:
def pragma_first_step(self: Circuit, step_type: StepType) -> None:
assert self.mode == CircuitMode.SETUP
self.ast.first_step = step_type.step_type.id
# print(f"first step id: {step_type.step_type.id}")
def pragma_last_step(self: Circuit, step_type: StepType) -> None:
assert self.mode == CircuitMode.SETUP
self.ast.last_step = step_type.step_type.id
# print(f"last step id: {step_type.step_type.id}")
def pragma_num_steps(self: Circuit, num_steps: int) -> None:
assert self.mode == CircuitMode.SETUP
@@ -85,27 +82,21 @@ class Circuit:
def add(self: Circuit, step_type: StepType, args: Any):
assert self.mode == CircuitMode.Trace
self.trace_context.add(self, step_type, args)
step_instance: StepInstance = step_type.gen_step_instance(args)
self.witness.step_instances.append(step_instance)
def gen_witness(self: Circuit, args: Any) -> TraceWitness:
self.mode = CircuitMode.Trace
self.trace_context = TraceContext()
self.witness = TraceWitness()
self.trace(args)
self.mode = CircuitMode.NoMode
witness = self.trace_context.witness
del self.trace_context
witness = self.witness
del self.witness
return witness
def get_ast_json(self: Circuit) -> str:
return json.dumps(self.ast, cls=CustomEncoder, indent=4)
def convert_and_print_ast(self: Circuit):
ast_json: str = self.get_ast_json()
print(
"Call rust bindings, parse json to Chiquito ASTCircuit, and print using Debug trait:"
)
rust_chiquito.convert_and_print_ast(ast_json)
def ast_to_halo2(self: Circuit):
ast_json: str = self.get_ast_json()
self.rust_ast_id: int = rust_chiquito.ast_to_halo2(ast_json)
@@ -117,24 +108,6 @@ class Circuit:
rust_chiquito.verify_proof(witness_json, self.rust_ast_id)
# Debug method
def convert_and_print_witness(witness: TraceWitness):
witness_json: str = witness.get_witness_json()
rust_chiquito.convert_and_print_trace_witness(witness_json)
# Debug method
def print_ast(ast: ASTCircuit):
print("Print ASTCircuit using custom __str__ method in python:")
print(ast)
# Debug method
def print_witness(witness: TraceWitness):
print("Print TraceWitness using custom __str__ method in python:")
print(witness)
class StepTypeMode(Enum):
NoMode = 0
SETUP = 1

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from typing import Tuple
from dsl import Circuit, StepType, print_ast, print_witness, convert_and_print_witness
from dsl import Circuit, StepType
from cb import eq
from query import Queriable
from util import F
@@ -64,11 +64,5 @@ class FiboLastStep(StepType):
fibo = Fibonacci()
fibo_witness = fibo.gen_witness(None)
fibo.convert_and_print_ast()
fibo.ast_to_halo2()
fibo.verify_proof(fibo_witness)
# Debug methods
# print_ast(fibo.ast)
# print_witness(fibo_witness)
# convert_and_print_witness(fibo_witness)

View File

@@ -89,31 +89,6 @@ class TraceWitness:
return json.dumps(self, cls=CustomEncoder, indent=4)
@dataclass
class TraceContext:
witness: TraceWitness = field(default_factory=TraceWitness)
def add(self: TraceContext, circuit: Circuit, step: StepType, args: Any):
step_instance: StepInstance = step.gen_step_instance(args)
self.witness.step_instances.append(step_instance)
def set_height(self: TraceContext, height: int):
self.witness.height = height
Trace = Callable[[TraceContext, Any], None] # TraceArgs are Any.
@dataclass
class TraceGenerator:
trace: Trace
def generate(self: TraceGenerator, args: Any) -> TraceWitness: # Args are Any.
ctx = TraceContext()
self.trace(ctx, args)
return ctx.witness
FixedAssigment = Dict[Queriable, List[F]]