mirror of
https://github.com/qwang98/PyChiquito.git
synced 2026-04-22 03:00:16 -04:00
updated tracecontext and cleaned up debug methods
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
from typing import Callable, List, Dict, Optional, Any, Tuple
|
||||
from dataclasses import dataclass, field, asdict
|
||||
|
||||
from wit_gen import TraceContext, FixedGenContext, StepInstance
|
||||
from wit_gen import FixedGenContext, StepInstance
|
||||
from expr import Expr
|
||||
from util import uuid
|
||||
from query import Queriable
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from __future__ import annotations
|
||||
from enum import Enum
|
||||
from typing import Callable, Any
|
||||
from dataclasses import dataclass
|
||||
import rust_chiquito # rust bindings
|
||||
import json
|
||||
|
||||
from chiquito_ast import ASTCircuit, ASTStepType, ExposeOffset
|
||||
from query import Internal, Forward, Queriable, Shared, Fixed
|
||||
from wit_gen import FixedGenContext, TraceContext, StepInstance, TraceWitness
|
||||
from wit_gen import FixedGenContext, StepInstance, TraceWitness
|
||||
from cb import Constraint, Typing, ToConstraint, to_constraint
|
||||
from util import CustomEncoder, F
|
||||
|
||||
@@ -21,7 +20,7 @@ class CircuitMode(Enum):
|
||||
class Circuit:
|
||||
def __init__(self: Circuit):
|
||||
self.ast = ASTCircuit()
|
||||
self.trace_context = TraceContext()
|
||||
self.witness = TraceWitness()
|
||||
self.rust_ast_id = 0
|
||||
self.mode = CircuitMode.SETUP
|
||||
self.setup()
|
||||
@@ -68,12 +67,10 @@ class Circuit:
|
||||
def pragma_first_step(self: Circuit, step_type: StepType) -> None:
|
||||
assert self.mode == CircuitMode.SETUP
|
||||
self.ast.first_step = step_type.step_type.id
|
||||
# print(f"first step id: {step_type.step_type.id}")
|
||||
|
||||
def pragma_last_step(self: Circuit, step_type: StepType) -> None:
|
||||
assert self.mode == CircuitMode.SETUP
|
||||
self.ast.last_step = step_type.step_type.id
|
||||
# print(f"last step id: {step_type.step_type.id}")
|
||||
|
||||
def pragma_num_steps(self: Circuit, num_steps: int) -> None:
|
||||
assert self.mode == CircuitMode.SETUP
|
||||
@@ -85,27 +82,21 @@ class Circuit:
|
||||
|
||||
def add(self: Circuit, step_type: StepType, args: Any):
|
||||
assert self.mode == CircuitMode.Trace
|
||||
self.trace_context.add(self, step_type, args)
|
||||
step_instance: StepInstance = step_type.gen_step_instance(args)
|
||||
self.witness.step_instances.append(step_instance)
|
||||
|
||||
def gen_witness(self: Circuit, args: Any) -> TraceWitness:
|
||||
self.mode = CircuitMode.Trace
|
||||
self.trace_context = TraceContext()
|
||||
self.witness = TraceWitness()
|
||||
self.trace(args)
|
||||
self.mode = CircuitMode.NoMode
|
||||
witness = self.trace_context.witness
|
||||
del self.trace_context
|
||||
witness = self.witness
|
||||
del self.witness
|
||||
return witness
|
||||
|
||||
def get_ast_json(self: Circuit) -> str:
|
||||
return json.dumps(self.ast, cls=CustomEncoder, indent=4)
|
||||
|
||||
def convert_and_print_ast(self: Circuit):
|
||||
ast_json: str = self.get_ast_json()
|
||||
print(
|
||||
"Call rust bindings, parse json to Chiquito ASTCircuit, and print using Debug trait:"
|
||||
)
|
||||
rust_chiquito.convert_and_print_ast(ast_json)
|
||||
|
||||
def ast_to_halo2(self: Circuit):
|
||||
ast_json: str = self.get_ast_json()
|
||||
self.rust_ast_id: int = rust_chiquito.ast_to_halo2(ast_json)
|
||||
@@ -117,24 +108,6 @@ class Circuit:
|
||||
rust_chiquito.verify_proof(witness_json, self.rust_ast_id)
|
||||
|
||||
|
||||
# Debug method
|
||||
def convert_and_print_witness(witness: TraceWitness):
|
||||
witness_json: str = witness.get_witness_json()
|
||||
rust_chiquito.convert_and_print_trace_witness(witness_json)
|
||||
|
||||
|
||||
# Debug method
|
||||
def print_ast(ast: ASTCircuit):
|
||||
print("Print ASTCircuit using custom __str__ method in python:")
|
||||
print(ast)
|
||||
|
||||
|
||||
# Debug method
|
||||
def print_witness(witness: TraceWitness):
|
||||
print("Print TraceWitness using custom __str__ method in python:")
|
||||
print(witness)
|
||||
|
||||
|
||||
class StepTypeMode(Enum):
|
||||
NoMode = 0
|
||||
SETUP = 1
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
from typing import Tuple
|
||||
|
||||
from dsl import Circuit, StepType, print_ast, print_witness, convert_and_print_witness
|
||||
from dsl import Circuit, StepType
|
||||
from cb import eq
|
||||
from query import Queriable
|
||||
from util import F
|
||||
@@ -64,11 +64,5 @@ class FiboLastStep(StepType):
|
||||
|
||||
fibo = Fibonacci()
|
||||
fibo_witness = fibo.gen_witness(None)
|
||||
fibo.convert_and_print_ast()
|
||||
fibo.ast_to_halo2()
|
||||
fibo.verify_proof(fibo_witness)
|
||||
|
||||
# Debug methods
|
||||
# print_ast(fibo.ast)
|
||||
# print_witness(fibo_witness)
|
||||
# convert_and_print_witness(fibo_witness)
|
||||
|
||||
@@ -89,31 +89,6 @@ class TraceWitness:
|
||||
return json.dumps(self, cls=CustomEncoder, indent=4)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraceContext:
|
||||
witness: TraceWitness = field(default_factory=TraceWitness)
|
||||
|
||||
def add(self: TraceContext, circuit: Circuit, step: StepType, args: Any):
|
||||
step_instance: StepInstance = step.gen_step_instance(args)
|
||||
self.witness.step_instances.append(step_instance)
|
||||
|
||||
def set_height(self: TraceContext, height: int):
|
||||
self.witness.height = height
|
||||
|
||||
|
||||
Trace = Callable[[TraceContext, Any], None] # TraceArgs are Any.
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraceGenerator:
|
||||
trace: Trace
|
||||
|
||||
def generate(self: TraceGenerator, args: Any) -> TraceWitness: # Args are Any.
|
||||
ctx = TraceContext()
|
||||
self.trace(ctx, args)
|
||||
return ctx.witness
|
||||
|
||||
|
||||
FixedAssigment = Dict[Queriable, List[F]]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user