mirror of
https://github.com/qwang98/PyChiquito.git
synced 2026-04-22 03:00:16 -04:00
refactored everything
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import List
|
||||
from util import F
|
||||
from expr import Expr, Const, Neg, to_expr, ToExpr
|
||||
from query import StepTypeNext
|
||||
from chiquito_ast import StepType
|
||||
from chiquito_ast import ASTStepType
|
||||
|
||||
##########
|
||||
# dsl/cb #
|
||||
@@ -165,7 +165,7 @@ def isz(constraint: ToConstraint) -> Constraint:
|
||||
)
|
||||
|
||||
|
||||
def if_next_step(step_type: StepType, constraint: ToConstraint) -> Constraint:
|
||||
def if_next_step(step_type: ASTStepType, constraint: ToConstraint) -> Constraint:
|
||||
constraint = to_constraint(constraint)
|
||||
return Constraint(
|
||||
f"if(next step is {step_type.annotation})then({constraint.annotation})",
|
||||
@@ -174,7 +174,7 @@ def if_next_step(step_type: StepType, constraint: ToConstraint) -> Constraint:
|
||||
)
|
||||
|
||||
|
||||
def next_step_must_be(step_type: StepType) -> Constraint:
|
||||
def next_step_must_be(step_type: ASTStepType) -> Constraint:
|
||||
return Constraint(
|
||||
f"next step must be {step_type.annotation}",
|
||||
Constraint.cb_not(StepTypeNext(step_type)),
|
||||
@@ -182,7 +182,7 @@ def next_step_must_be(step_type: StepType) -> Constraint:
|
||||
)
|
||||
|
||||
|
||||
def next_step_must_not_be(step_type: StepType) -> Constraint:
|
||||
def next_step_must_not_be(step_type: ASTStepType) -> Constraint:
|
||||
return Constraint(
|
||||
f"next step must not be {step_type.annotation}",
|
||||
StepTypeNext(step_type),
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
from typing import Callable, List, Dict, Optional, Any, Tuple
|
||||
from dataclasses import dataclass, field, asdict
|
||||
|
||||
from wit_gen import TraceContext, FixedGenContext, StepInstance
|
||||
from wit_gen import TraceContext, FixedGenContext, StepInstance, TraceWitness
|
||||
from expr import Expr
|
||||
from util import uuid
|
||||
from query import Queriable
|
||||
@@ -11,8 +11,8 @@ from query import Queriable
|
||||
# ast #
|
||||
#######
|
||||
|
||||
# pub struct Circuit<F, TraceArgs> {
|
||||
# pub step_types: HashMap<u32, Rc<StepType<F>>>,
|
||||
# pub struct AST<F, TraceArgs> {
|
||||
# pub step_types: HashMap<u32, Rc<ASTStepType<F>>>,
|
||||
|
||||
# pub forward_signals: Vec<ForwardSignal>,
|
||||
# pub shared_signals: Vec<SharedSignal>,
|
||||
@@ -26,15 +26,15 @@ from query import Queriable
|
||||
# pub trace: Option<Rc<Trace<F, TraceArgs>>>,
|
||||
# pub fixed_gen: Option<Rc<FixedGen<F>>>,
|
||||
|
||||
# pub first_step: Option<StepTypeUUID>,
|
||||
# pub last_step: Option<StepTypeUUID>,
|
||||
# pub first_step: Option<ASTStepTypeUUID>,
|
||||
# pub last_step: Option<ASTStepTypeUUID>,
|
||||
# pub num_steps: usize,
|
||||
# }
|
||||
|
||||
|
||||
@dataclass
|
||||
class Circuit:
|
||||
step_types: Dict[int, StepType] = field(default_factory=dict)
|
||||
class AST:
|
||||
step_types: Dict[int, ASTStepType] = field(default_factory=dict)
|
||||
forward_signals: List[ForwardSignal] = field(default_factory=list)
|
||||
shared_signals: List[SharedSignal] = field(default_factory=list)
|
||||
fixed_signals: List[FixedSignal] = field(default_factory=list)
|
||||
@@ -48,7 +48,7 @@ class Circuit:
|
||||
q_enable: bool = True
|
||||
id: int = uuid()
|
||||
|
||||
def __str__(self: Circuit):
|
||||
def __str__(self: AST):
|
||||
step_types_str = (
|
||||
"\n\t\t"
|
||||
+ ",\n\t\t".join(f"{k}: {v}" for k, v in self.step_types.items())
|
||||
@@ -90,7 +90,7 @@ class Circuit:
|
||||
)
|
||||
|
||||
return (
|
||||
f"Circuit(\n"
|
||||
f"AST(\n"
|
||||
f"\tstep_types={{{step_types_str}}},\n"
|
||||
f"\tforward_signals=[{forward_signals_str}],\n"
|
||||
f"\tshared_signals=[{shared_signals_str}],\n"
|
||||
@@ -106,7 +106,7 @@ class Circuit:
|
||||
f")"
|
||||
)
|
||||
|
||||
def __json__(self: Circuit):
|
||||
def __json__(self: AST):
|
||||
return {
|
||||
"step_types": {k: v.__json__() for k, v in self.step_types.items()},
|
||||
"forward_signals": [x.__json__() for x in self.forward_signals],
|
||||
@@ -124,55 +124,55 @@ class Circuit:
|
||||
"id": self.id,
|
||||
}
|
||||
|
||||
def add_forward(self: Circuit, name: str, phase: int) -> ForwardSignal:
|
||||
def add_forward(self: AST, name: str, phase: int) -> ForwardSignal:
|
||||
signal = ForwardSignal(phase, name)
|
||||
self.forward_signals.append(signal)
|
||||
self.annotations[signal.id] = name
|
||||
return signal
|
||||
|
||||
def add_shared(self: Circuit, name: str, phase: int) -> SharedSignal:
|
||||
def add_shared(self: AST, name: str, phase: int) -> SharedSignal:
|
||||
signal = SharedSignal(phase, name)
|
||||
self.shared_signals.append(signal)
|
||||
self.annotations[signal.id] = name
|
||||
return signal
|
||||
|
||||
def add_fixed(self: Circuit, name: str) -> FixedSignal:
|
||||
def add_fixed(self: AST, name: str) -> FixedSignal:
|
||||
signal = FixedSignal(name)
|
||||
self.fixed_signals.append(signal)
|
||||
self.annotations[signal.id] = name
|
||||
return signal
|
||||
|
||||
def expose(self: Circuit, signal: Queriable, offset: ExposeOffset):
|
||||
def expose(self: AST, signal: Queriable, offset: ExposeOffset):
|
||||
self.exposed.append((signal, offset))
|
||||
|
||||
def add_step_type(self: Circuit, step_type: StepType, name: str):
|
||||
def add_step_type(self: AST, step_type: ASTStepType, name: str):
|
||||
self.annotations[step_type.id] = name
|
||||
self.step_types[step_type.id] = step_type
|
||||
|
||||
def set_trace(
|
||||
self: Circuit, trace_def: Callable[[TraceContext, Any], None]
|
||||
self: AST, trace_def: Callable[[TraceContext, Any], None]
|
||||
): # TraceArgs are Any.
|
||||
if self.trace is not None:
|
||||
raise Exception(
|
||||
"Circuit cannot have more than one trace generator.")
|
||||
"AST cannot have more than one trace generator.")
|
||||
else:
|
||||
self.trace = trace_def
|
||||
|
||||
def set_fixed_gen(self, fixed_gen_def: Callable[[FixedGenContext], None]):
|
||||
if self.fixed_gen is not None:
|
||||
raise Exception(
|
||||
"Circuit cannot have more than one fixed generator.")
|
||||
"AST cannot have more than one fixed generator.")
|
||||
else:
|
||||
self.fixed_gen = fixed_gen_def
|
||||
|
||||
def get_step_type(self, uuid: int) -> StepType:
|
||||
def get_step_type(self, uuid: int) -> ASTStepType:
|
||||
if uuid in self.step_types.keys():
|
||||
return self.step_types[uuid]
|
||||
else:
|
||||
raise ValueError("StepType not found.")
|
||||
raise ValueError("ASTStepType not found.")
|
||||
|
||||
|
||||
# pub struct StepType<F> {
|
||||
# pub struct ASTStepType<F> {
|
||||
# id: StepTypeUUID,
|
||||
|
||||
# pub name: String,
|
||||
@@ -185,7 +185,7 @@ class Circuit:
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepType:
|
||||
class ASTStepType:
|
||||
id: int
|
||||
name: str
|
||||
signals: List[InternalSignal]
|
||||
@@ -196,8 +196,8 @@ class StepType:
|
||||
Callable[[StepInstance, Any], None]
|
||||
] # Args are Any. Not passed to Rust Chiquito.
|
||||
|
||||
def new(name: str) -> StepType:
|
||||
return StepType(uuid(), name, [], [], [], {}, None)
|
||||
def new(name: str) -> ASTStepType:
|
||||
return ASTStepType(uuid(), name, [], [], [], {}, None)
|
||||
|
||||
def __str__(self):
|
||||
signals_str = (
|
||||
@@ -232,7 +232,7 @@ class StepType:
|
||||
)
|
||||
|
||||
return (
|
||||
f"StepType(\n"
|
||||
f"ASTStepType(\n"
|
||||
f"\t\t\tid={self.id},\n"
|
||||
f"\t\t\tname='{self.name}',\n"
|
||||
f"\t\t\tsignals=[{signals_str}],\n"
|
||||
@@ -254,17 +254,17 @@ class StepType:
|
||||
"annotations": self.annotations,
|
||||
}
|
||||
|
||||
def add_signal(self: StepType, name: str) -> InternalSignal:
|
||||
def add_signal(self: ASTStepType, name: str) -> InternalSignal:
|
||||
signal = InternalSignal(name)
|
||||
self.signals.append(signal)
|
||||
self.annotations[signal.id] = name
|
||||
return signal
|
||||
|
||||
def add_constr(self: StepType, annotation: str, expr: Expr):
|
||||
def add_constr(self: ASTStepType, annotation: str, expr: Expr):
|
||||
condition = ASTConstraint(annotation, expr)
|
||||
self.constraints.append(condition)
|
||||
|
||||
def add_transition(self: StepType, annotation: str, expr: Expr):
|
||||
def add_transition(self: ASTStepType, annotation: str, expr: Expr):
|
||||
condition = TransitionConstraint(annotation, expr)
|
||||
self.transition_constraints.append(condition)
|
||||
|
||||
@@ -272,15 +272,15 @@ class StepType:
|
||||
def set_wg(self, wg_def: Callable[[StepInstance, Any], None]):
|
||||
self.wg = wg_def
|
||||
|
||||
def __eq__(self: StepType, other: StepType) -> bool:
|
||||
if isinstance(self, StepType) and isinstance(other, StepType):
|
||||
def __eq__(self: ASTStepType, other: ASTStepType) -> bool:
|
||||
if isinstance(self, ASTStepType) and isinstance(other, ASTStepType):
|
||||
return self.id == other.id
|
||||
return False
|
||||
|
||||
def __req__(other: StepType, self: StepType) -> bool:
|
||||
return StepType.__eq__(self, other)
|
||||
def __req__(other: ASTStepType, self: ASTStepType) -> bool:
|
||||
return ASTStepType.__eq__(self, other)
|
||||
|
||||
def __hash__(self: StepType):
|
||||
def __hash__(self: ASTStepType):
|
||||
return hash(self.id)
|
||||
|
||||
|
||||
|
||||
@@ -3,9 +3,9 @@ from enum import Enum
|
||||
from typing import Callable, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from chiquito_ast import Circuit, StepType, ExposeOffset, ForwardSignal, SharedSignal
|
||||
from chiquito_ast import AST, ASTStepType, ExposeOffset, ForwardSignal, SharedSignal
|
||||
from query import Internal, Forward, Queriable, Shared, Fixed
|
||||
from wit_gen import FixedGenContext, TraceContext
|
||||
from wit_gen import FixedGenContext, TraceContext, StepInstance, TraceWitness
|
||||
from cb import Constraint, Typing, ToConstraint, to_constraint
|
||||
|
||||
|
||||
@@ -13,27 +13,43 @@ from cb import Constraint, Typing, ToConstraint, to_constraint
|
||||
# dsl #
|
||||
#######
|
||||
|
||||
class CircuitMode(Enum):
|
||||
NoMode = 0
|
||||
SETUP = 1
|
||||
Trace = 2
|
||||
|
||||
class CircuitContext:
|
||||
def __init__(self):
|
||||
self.circuit = Circuit()
|
||||
class Circuit:
|
||||
def __init__(self: Circuit):
|
||||
self.circuit = AST()
|
||||
self.trace_context = TraceContext
|
||||
self.mode = CircuitMode.SETUP
|
||||
self.setup()
|
||||
self.mode = CircuitMode.Trace
|
||||
self.circuit.set_trace(self.trace)
|
||||
self.mode = CircuitMode.NoMode
|
||||
|
||||
def forward(self: CircuitContext, name: str) -> Forward:
|
||||
def forward(self: Circuit, name: str) -> Forward:
|
||||
assert (self.mode == CircuitMode.SETUP)
|
||||
return Forward(self.circuit.add_forward(name, 0), False)
|
||||
|
||||
def forward_with_phase(self: CircuitContext, name: str, phase: int) -> Forward:
|
||||
def forward_with_phase(self: Circuit, name: str, phase: int) -> Forward:
|
||||
assert (self.mode == CircuitMode.SETUP)
|
||||
return Forward(self.circuit.add_forward(name, phase), False)
|
||||
|
||||
def shared(self: CircuitContext, name: str) -> Shared:
|
||||
def shared(self: Circuit, name: str) -> Shared:
|
||||
assert (self.mode == CircuitMode.SETUP)
|
||||
return Shared(self.circuit.add_shared(name, 0), 0)
|
||||
|
||||
def shared_with_phase(self: CircuitContext, name: str, phase: int) -> Shared:
|
||||
def shared_with_phase(self: Circuit, name: str, phase: int) -> Shared:
|
||||
assert (self.mode == CircuitMode.SETUP)
|
||||
return Shared(self.circuit.add_shared(name, phase), 0)
|
||||
|
||||
def fixed(self: CircuitContext, name: str) -> Fixed:
|
||||
def fixed(self: Circuit, name: str) -> Fixed:
|
||||
assert (self.mode == CircuitMode.SETUP)
|
||||
return Fixed(self.circuit.add_fixed(name), 0)
|
||||
|
||||
def expose(self: CircuitContext, signal: Queriable, offset: ExposeOffset):
|
||||
def expose(self: Circuit, signal: Queriable, offset: ExposeOffset):
|
||||
assert (self.mode == CircuitMode.SETUP)
|
||||
if isinstance(signal, (Forward, Shared)):
|
||||
self.circuit.expose(signal, offset)
|
||||
else:
|
||||
@@ -42,44 +58,54 @@ class CircuitContext:
|
||||
# import_halo2_advice and import_halo2_fixed are ignored.
|
||||
|
||||
def step_type(
|
||||
self: CircuitContext, step_type_context: StepTypeContext
|
||||
) -> StepTypeContext:
|
||||
self: Circuit, step_type_context: StepType
|
||||
) -> StepType:
|
||||
assert (self.mode == CircuitMode.SETUP)
|
||||
self.circuit.add_step_type(
|
||||
step_type_context.step_type, step_type_context.step_type.name
|
||||
)
|
||||
return step_type_context
|
||||
|
||||
def step_type_def(self: StepTypeContext) -> StepTypeContext:
|
||||
def step_type_def(self: StepType) -> StepType:
|
||||
assert (self.mode == CircuitMode.SETUP)
|
||||
self.circuit.add_step_type_def()
|
||||
|
||||
def trace(
|
||||
self: CircuitContext, trace_def: Callable[[TraceContext, Any], None]
|
||||
): # TraceArgs are Any.
|
||||
self.circuit.set_trace(trace_def)
|
||||
# def trace(
|
||||
# self: Circuit, trace_def: Callable[[TraceContext, Any], None]
|
||||
# ): # TraceArgs are Any.
|
||||
# self.circuit.set_trace(trace_def)
|
||||
|
||||
def fixed_gen(
|
||||
self: CircuitContext, fixed_gen_def: Callable[[FixedGenContext], None]
|
||||
self: Circuit, fixed_gen_def: Callable[[FixedGenContext], None]
|
||||
):
|
||||
self.circuit.set_fixed_gen(fixed_gen_def)
|
||||
|
||||
def pragma_first_step(
|
||||
self: CircuitContext, step_type_context: StepTypeContext
|
||||
self: Circuit, step_type_context: StepType
|
||||
) -> None:
|
||||
assert (self.mode == CircuitMode.SETUP)
|
||||
self.circuit.first_step = step_type_context.step_type.id
|
||||
print(f"first step id: {step_type_context.step_type.id}")
|
||||
|
||||
def pragma_last_step(
|
||||
self: CircuitContext, step_type_context: StepTypeContext
|
||||
self: Circuit, step_type_context: StepType
|
||||
) -> None:
|
||||
assert (self.mode == CircuitMode.SETUP)
|
||||
self.circuit.last_step = step_type_context.step_type.id
|
||||
print(f"last step id: {step_type_context.step_type.id}")
|
||||
|
||||
def pragma_num_steps(self: CircuitContext, num_steps: int) -> None:
|
||||
def pragma_num_steps(self: Circuit, num_steps: int) -> None:
|
||||
assert (self.mode == CircuitMode.SETUP)
|
||||
self.circuit.num_steps = num_steps
|
||||
|
||||
def pragma_disable_q_enable(self: CircuitContext) -> None:
|
||||
def pragma_disable_q_enable(self: Circuit) -> None:
|
||||
assert (self.mode == CircuitMode.SETUP)
|
||||
self.circuit.q_enable = False
|
||||
|
||||
def add(self: Circuit, step_type: StepType, args: Any):
|
||||
assert (self.mode == CircuitMode.Trace)
|
||||
self.trace_context.add(self, step_type, args)
|
||||
|
||||
|
||||
class StepTypeMode(Enum):
|
||||
NoMode = 0
|
||||
@@ -87,37 +113,40 @@ class StepTypeMode(Enum):
|
||||
WG = 2
|
||||
|
||||
|
||||
class StepTypeContext:
|
||||
class StepType:
|
||||
|
||||
def __init__(self: StepTypeContext, circuit, step_type_name: str, ):
|
||||
self.step_type = StepType.new(step_type_name)
|
||||
def __init__(self: StepType, circuit, step_type_name: str):
|
||||
self.step_type = ASTStepType.new(step_type_name)
|
||||
self.circuit = circuit
|
||||
self.step_instance = StepInstance.new(self.step_type.id)
|
||||
self.mode = StepTypeMode.SETUP
|
||||
self.setup()
|
||||
self.mode = StepTypeMode.WG
|
||||
self.step_type.set_wg(self.wg)
|
||||
self.mode = StepTypeMode.NoMode
|
||||
|
||||
def internal(self: StepTypeContext, name: str) -> Internal:
|
||||
def internal(self: StepType, name: str) -> Internal:
|
||||
assert (self.mode == StepTypeMode.SETUP)
|
||||
|
||||
return Internal(self.step_type.add_signal(name))
|
||||
|
||||
def wg(
|
||||
self: StepTypeContext, wg_def: Callable[[TraceContext, Any], None]
|
||||
): # Args are Any.
|
||||
self.step_type.set_wg(wg_def)
|
||||
# def wg(
|
||||
# self: StepType, wg_def: Callable[[TraceContext, Any], None]
|
||||
# ): # Args are Any.
|
||||
# self.step_type.set_wg(wg_def)
|
||||
|
||||
def constr(self: StepTypeContext, constraint: ToConstraint):
|
||||
def constr(self: StepType, constraint: ToConstraint):
|
||||
assert (self.mode == StepTypeMode.SETUP)
|
||||
|
||||
constraint = to_constraint(constraint)
|
||||
StepTypeContext.enforce_constraint_typing(constraint)
|
||||
StepType.enforce_constraint_typing(constraint)
|
||||
self.step_type.add_constr(constraint.annotation, constraint.expr)
|
||||
|
||||
def transition(self: StepTypeContext, constraint: ToConstraint):
|
||||
def transition(self: StepType, constraint: ToConstraint):
|
||||
assert (self.mode == StepTypeMode.SETUP)
|
||||
|
||||
constraint = to_constraint(constraint)
|
||||
StepTypeContext.enforce_constraint_typing(constraint)
|
||||
StepType.enforce_constraint_typing(constraint)
|
||||
self.step_type.add_transition(constraint.annotation, constraint.expr)
|
||||
|
||||
def enforce_constraint_typing(constraint: Constraint):
|
||||
@@ -125,13 +154,18 @@ class StepTypeContext:
|
||||
raise ValueError(
|
||||
f"Expected AntiBooly constraint, got {constraint.typing} (constraint: {constraint.annotation})"
|
||||
)
|
||||
|
||||
def assign(self: StepType, lhs: Queriable, rhs: F):
|
||||
assert (self.mode == StepTypeMode.WG)
|
||||
|
||||
self.step_instance.assign(lhs, rhs)
|
||||
|
||||
# TODO: Implement add_lookup after lookup abstraction PR is merged.
|
||||
|
||||
|
||||
def circuit(
|
||||
name: str, circuit_context_def: Callable[[CircuitContext], None]
|
||||
) -> Circuit:
|
||||
ctx = CircuitContext()
|
||||
name: str, circuit_context_def: Callable[[Circuit], None]
|
||||
) -> AST:
|
||||
ctx = Circuit()
|
||||
circuit_context_def(ctx)
|
||||
return ctx.circuit
|
||||
|
||||
@@ -2,10 +2,10 @@ from __future__ import annotations
|
||||
from typing import Any, Tuple
|
||||
from py_ecc import bn128
|
||||
import json
|
||||
import rust_chiquito # rust bindings
|
||||
# import rust_chiquito # rust bindings
|
||||
|
||||
from dsl import CircuitContext, StepTypeContext
|
||||
from chiquito_ast import StepType, First, Last, Step
|
||||
from dsl import Circuit, StepType
|
||||
from chiquito_ast import ASTStepType, First, Last, Step
|
||||
from cb import eq
|
||||
from query import Queriable
|
||||
from wit_gen import TraceContext, StepInstance, TraceGenerator
|
||||
@@ -13,12 +13,9 @@ from wit_gen import TraceContext, StepInstance, TraceGenerator
|
||||
F = bn128.FQ
|
||||
|
||||
|
||||
class Fibonacci(CircuitContext):
|
||||
def __init__(self: Fibonacci):
|
||||
super().__init__()
|
||||
self.a: Queriable = self.forward(
|
||||
"a"
|
||||
) # `self.a` is required instead of `a`, because steps need to access `circuit.a`.
|
||||
class Fibonacci(Circuit):
|
||||
def setup(self: Fibonacci):
|
||||
self.a: Queriable = self.forward("a")
|
||||
self.b: Queriable = self.forward("b")
|
||||
|
||||
self.fibo_step = self.step_type(
|
||||
@@ -37,21 +34,32 @@ class Fibonacci(CircuitContext):
|
||||
# self.expose(self.a, Step(1))
|
||||
|
||||
def trace(self: Fibonacci):
|
||||
def trace_def(ctx: TraceContext, _: Any): # Any instead of TraceArgs
|
||||
ctx.add(self, self.fibo_step, (1, 1))
|
||||
a = 1
|
||||
b = 2
|
||||
for i in range(1, 10):
|
||||
ctx.add(self, self.fibo_step, (a, b))
|
||||
prev_a = a
|
||||
a = b
|
||||
b += prev_a
|
||||
ctx.add(self, self.fibo_last_step, (a, b))
|
||||
self.add(self.fibo_step, (1, 1))
|
||||
a = 1
|
||||
b = 2
|
||||
for i in range(1, 10):
|
||||
self.add(self.fibo_step, (a, b))
|
||||
prev_a = a
|
||||
a = b
|
||||
b += prev_a
|
||||
self.add(self.fibo_last_step, (a, b))
|
||||
|
||||
super().trace(trace_def)
|
||||
# def trace(self: Fibonacci):
|
||||
# def trace_def(ctx: TraceContext, _: Any): # Any instead of TraceArgs
|
||||
# ctx.add(self, self.fibo_step, (1, 1))
|
||||
# a = 1
|
||||
# b = 2
|
||||
# for i in range(1, 10):
|
||||
# ctx.add(self, self.fibo_step, (a, b))
|
||||
# prev_a = a
|
||||
# a = b
|
||||
# b += prev_a
|
||||
# ctx.add(self, self.fibo_last_step, (a, b))
|
||||
|
||||
# super().trace(trace_def)
|
||||
|
||||
|
||||
class FiboStep(StepTypeContext):
|
||||
class FiboStep(StepType):
|
||||
def setup(self: FiboStep):
|
||||
self.c = self.internal(
|
||||
"c"
|
||||
@@ -60,75 +68,74 @@ class FiboStep(StepTypeContext):
|
||||
self.transition(eq(self.circuit.b, self.circuit.a.next()))
|
||||
self.transition(eq(self.c, self.circuit.b.next()))
|
||||
|
||||
def wg(self: FiboStep, circuit: Fibonacci):
|
||||
# Any instead of Args
|
||||
def wg_def(ctx: StepInstance, values: Tuple[int, int]):
|
||||
a_value, b_value = values
|
||||
# print(f"fib step wg: {a_value}, {b_value}, {a_value + b_value}")
|
||||
ctx.assign(circuit.a, F(a_value))
|
||||
ctx.assign(circuit.b, F(b_value))
|
||||
ctx.assign(self.c, F(a_value + b_value))
|
||||
def wg(self: FiboStep, values: Tuple[int, int]):
|
||||
a_value, b_value = values
|
||||
self.assign(self.circuit.a, F(a_value))
|
||||
self.assign(self.circuit.b, F(b_value))
|
||||
self.assign(self.c, F(a_value + b_value))
|
||||
|
||||
super().wg(wg_def)
|
||||
# def wg(self: FiboStep, circuit: Fibonacci):
|
||||
# # Any instead of Args
|
||||
# def wg_def(ctx: StepInstance, values: Tuple[int, int]):
|
||||
# a_value, b_value = values
|
||||
# # print(f"fib step wg: {a_value}, {b_value}, {a_value + b_value}")
|
||||
# ctx.assign(circuit.a, F(a_value))
|
||||
# ctx.assign(circuit.b, F(b_value))
|
||||
# ctx.assign(self.c, F(a_value + b_value))
|
||||
|
||||
# super().wg(wg_def)
|
||||
|
||||
|
||||
class FiboLastStep(StepTypeContext):
|
||||
class FiboLastStep(StepType):
|
||||
def setup(self: FiboLastStep):
|
||||
self.c = self.internal("c")
|
||||
self.constr(eq(self.circuit.a + self.circuit.b, self.c))
|
||||
|
||||
def wg(self: FiboLastStep, circuit: Fibonacci):
|
||||
# Any instead of Args
|
||||
def wg_def(ctx: StepInstance, values: Tuple[int, int]):
|
||||
a_value, b_value = values
|
||||
print(
|
||||
f"fib last step wg: {a_value}, {b_value}, {a_value + b_value}\n")
|
||||
ctx.assign(circuit.a, F(a_value))
|
||||
ctx.assign(circuit.b, F(b_value))
|
||||
ctx.assign(self.c, F(a_value + b_value))
|
||||
def wg(self: FiboLastStep, values = Tuple[int, int]):
|
||||
a_value, b_value = values
|
||||
self.assign(self.circuit.a, F(a_value))
|
||||
self.assign(self.circuit.b, F(b_value))
|
||||
self.assign(self.c, F(a_value + b_value))
|
||||
|
||||
super().wg(wg_def)
|
||||
|
||||
|
||||
# Print Circuit
|
||||
# Print AST
|
||||
fibo = Fibonacci()
|
||||
fibo.trace()
|
||||
print("Print Circuit using custom __str__ method in python:")
|
||||
print(fibo.circuit)
|
||||
print("Print Circuit using __json__ method in python:")
|
||||
# print("Print AST using custom __str__ method in python:")
|
||||
# print(fibo.circuit)
|
||||
# print("Print AST using __json__ method in python:")
|
||||
|
||||
|
||||
class CustomEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
if hasattr(obj, "__json__"):
|
||||
return obj.__json__()
|
||||
return super().default(obj)
|
||||
# class CustomEncoder(json.JSONEncoder):
|
||||
# def default(self, obj):
|
||||
# if hasattr(obj, "__json__"):
|
||||
# return obj.__json__()
|
||||
# return super().default(obj)
|
||||
|
||||
|
||||
# Print Circuit
|
||||
print("Print Circuit using custom __str__ method in python:")
|
||||
print(fibo.circuit)
|
||||
print("Print Circuit using __json__ method in python:")
|
||||
circuit_json = json.dumps(fibo.circuit, cls=CustomEncoder, indent=4)
|
||||
print(circuit_json)
|
||||
# # Print AST
|
||||
# print("Print AST using custom __str__ method in python:")
|
||||
# print(fibo.circuit)
|
||||
# print("Print AST using __json__ method in python:")
|
||||
# circuit_json = json.dumps(fibo.circuit, cls=CustomEncoder, indent=4)
|
||||
# print(circuit_json)
|
||||
|
||||
# Print TraceWitness
|
||||
trace_generator = TraceGenerator(fibo.circuit.trace)
|
||||
trace_witness = trace_generator.generate(None)
|
||||
print("Print TraceWitness using custom __str__ method in python:")
|
||||
print(trace_witness)
|
||||
print("Print TraceWitness using __json__ method in python:")
|
||||
trace_witness_json = json.dumps(trace_witness, cls=CustomEncoder, indent=4)
|
||||
print(trace_witness_json)
|
||||
# # Print TraceWitness
|
||||
# trace_generator = TraceGenerator(fibo.circuit.trace)
|
||||
# trace_witness = trace_generator.generate(None)
|
||||
# print("Print TraceWitness using custom __str__ method in python:")
|
||||
# print(trace_witness)
|
||||
# print("Print TraceWitness using __json__ method in python:")
|
||||
# trace_witness_json = json.dumps(trace_witness, cls=CustomEncoder, indent=4)
|
||||
# print(trace_witness_json)
|
||||
|
||||
# Rust bindings for Circuit
|
||||
print("Call rust bindings, parse json to Chiquito ast, and print using Debug trait:")
|
||||
rust_chiquito.convert_and_print_ast(circuit_json)
|
||||
print(
|
||||
"Call rust bindings, parse json to Chiquito TraceWitness, and print using Debug trait:"
|
||||
)
|
||||
rust_chiquito.convert_and_print_trace_witness(trace_witness_json)
|
||||
print("Parse json to Chiquito Halo2, and obtain UUID:")
|
||||
ast_uuid: int = rust_chiquito.ast_to_halo2(circuit_json)
|
||||
print("Verify ciruit with ast uuid and trace witness json:")
|
||||
rust_chiquito.verify_proof(trace_witness_json, ast_uuid)
|
||||
# # Rust bindings for AST
|
||||
# print("Call rust bindings, parse json to Chiquito ast, and print using Debug trait:")
|
||||
# rust_chiquito.convert_and_print_ast(circuit_json)
|
||||
# print(
|
||||
# "Call rust bindings, parse json to Chiquito TraceWitness, and print using Debug trait:"
|
||||
# )
|
||||
# rust_chiquito.convert_and_print_trace_witness(trace_witness_json)
|
||||
# print("Parse json to Chiquito Halo2, and obtain UUID:")
|
||||
# ast_uuid: int = rust_chiquito.ast_to_halo2(circuit_json)
|
||||
# print("Verify ciruit with ast uuid and trace witness json:")
|
||||
# rust_chiquito.verify_proof(trace_witness_json, ast_uuid)
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from expr import Expr
|
||||
|
||||
# Commented out to avoid circular reference
|
||||
# from chiquito_ast import InternalSignal, ForwardSignal, SharedSignal, FixedSignal, StepType
|
||||
# from chiquito_ast import InternalSignal, ForwardSignal, SharedSignal, FixedSignal, ASTStepType
|
||||
|
||||
|
||||
######################
|
||||
@@ -15,7 +15,7 @@ from expr import Expr
|
||||
# Forward(ForwardSignal, bool),
|
||||
# Shared(SharedSignal, i32),
|
||||
# Fixed(FixedSignal, i32),
|
||||
# StepTypeNext(StepTypeHandler),
|
||||
# StepTypeNext(ASTStepTypeHandler),
|
||||
# Halo2AdviceQuery(ImportedHalo2Advice, i32),
|
||||
# Halo2FixedQuery(ImportedHalo2Fixed, i32),
|
||||
# #[allow(non_camel_case_types)]
|
||||
@@ -127,13 +127,13 @@ class Fixed(Queriable):
|
||||
|
||||
|
||||
class StepTypeNext(Queriable):
|
||||
def __init__(self: StepTypeNext, step_type: StepType):
|
||||
def __init__(self: StepTypeNext, step_type: ASTStepType):
|
||||
self.step_type = step_type
|
||||
|
||||
def uuid(self: StepType) -> int:
|
||||
def uuid(self: ASTStepType) -> int:
|
||||
return self.id
|
||||
|
||||
def __str__(self: StepType) -> str:
|
||||
def __str__(self: ASTStepType) -> str:
|
||||
return self.name
|
||||
|
||||
def __json__(self):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from chiquito_ast import (
|
||||
StepType,
|
||||
ASTStepType,
|
||||
ASTConstraint,
|
||||
TransitionConstraint,
|
||||
InternalSignal,
|
||||
@@ -17,11 +17,11 @@ from expr import Const, Sum, Mul
|
||||
# print(Forward(ForwardSignal(1, "a"), True).__json__())
|
||||
# print(Shared(SharedSignal(0, "a"), 2).__json__())
|
||||
# print(Fixed(FixedSignal("a"), 2).__json__())
|
||||
# print(StepTypeNext(StepType.new("fibo")).__json__())
|
||||
# print(StepTypeNext(ASTStepType.new("fibo")).__json__())
|
||||
# print(ASTConstraint("constraint", Sum([Const(1), Mul([Internal(InternalSignal("a")), Const(3)])])).__json__())
|
||||
# print(TransitionConstraint("trans", Sum([Const(1), Mul([Internal(InternalSignal("a")), Const(3)])])).__json__())
|
||||
print(
|
||||
StepType(
|
||||
ASTStepType(
|
||||
1,
|
||||
"fibo",
|
||||
[InternalSignal("a"), InternalSignal("b")],
|
||||
|
||||
@@ -6,7 +6,7 @@ from query import Queriable, Fixed
|
||||
from util import F
|
||||
|
||||
# Commented out to avoid circular reference
|
||||
# from dsl import CircuitContext, StepTypeContext
|
||||
# from dsl import Circuit, StepType
|
||||
|
||||
###########
|
||||
# wit_gen #
|
||||
@@ -90,8 +90,8 @@ class TraceContext:
|
||||
witness: TraceWitness = field(default_factory=TraceWitness)
|
||||
|
||||
def add(
|
||||
self: TraceContext, circuit: CircuitContext, step: StepTypeContext, args: Any
|
||||
): # Use StepTypeContext instead of StepTypeWGHandler, because StepTypeContext contains step type id and `wg` method that returns witness generation function.
|
||||
self: TraceContext, circuit: Circuit, step: StepType, args: Any
|
||||
): # Use StepType instead of StepTypeWGHandler, because StepType contains step type id and `wg` method that returns witness generation function.
|
||||
witness = StepInstance.new(step.step_type.id)
|
||||
step.wg(circuit)
|
||||
if step.step_type.wg is None:
|
||||
|
||||
Reference in New Issue
Block a user