trying to fix circular reference: collapsed files to only ast and dsl

This commit is contained in:
Steve Wang
2023-07-07 13:43:24 +08:00
parent a40a410cd5
commit 963465f370
8 changed files with 629 additions and 630 deletions

View File

@@ -1,201 +0,0 @@
from __future__ import annotations
from typing import Callable, List, Dict, Optional
from util import uuid
from expr import Expr
from dsl import StepTypeHandler
from dataclasses import dataclass
# pub struct Circuit<F, TraceArgs> {
# pub step_types: HashMap<u32, Rc<StepType<F>>>,
# pub forward_signals: Vec<ForwardSignal>,
# pub shared_signals: Vec<SharedSignal>,
# pub fixed_signals: Vec<FixedSignal>,
# pub halo2_advice: Vec<ImportedHalo2Advice>,
# pub halo2_fixed: Vec<ImportedHalo2Fixed>,
# pub exposed: Vec<ForwardSignal>,
# pub annotations: HashMap<u32, String>,
# pub trace: Option<Rc<Trace<F, TraceArgs>>>,
# pub fixed_gen: Option<Rc<FixedGen<F>>>,
# pub first_step: Option<StepTypeUUID>,
# pub last_step: Option<StepTypeUUID>,
# pub num_steps: usize,
# }
@dataclass
class Circuit:
step_types: Dict[int, StepType]
forward_signals: List[ForwardSignal]
shared_signals: List[SharedSignal]
fixed_signals: List[FixedSignal]
# halo2_advice: List[ImportedHalo2Advice]
# halo2_fixed: List[ImportedHalo2Fixed]
exposed: List[ForwardSignal]
annotations: Dict[int, str]
trace: Optional[callable]
fixed_gen: Optional[callable]
first_step: Optional[int]
last_step: Optional[int]
num_steps: int
# TODO: __str__ (low priority as it's only used for debugging)
def add_forward(self: Circuit, name: str, phase: int) -> ForwardSignal:
signal = ForwardSignal(phase, name)
self.forward_signals.append(signal)
self.annotations[signal.id] = name
return signal
def add_shared(self: Circuit, name: str, phase: int) -> SharedSignal:
signal = SharedSignal(phase, name)
self.shared_signals.append(signal)
self.annotations[signal.id] = name
return signal
def add_fixed(self: Circuit, name: str) -> FixedSignal:
signal = FixedSignal(name)
self.fixed_signals.append(signal)
self.annotations[signal.id] = name
return signal
def expose(self: Circuit, forward_signal: ForwardSignal):
self.exposed.append(forward_signal)
def add_step_type(self: Circuit, handler: StepTypeHandler, name: str):
self.annotations[handler.id] = name
def add_step_type_def(self, step: StepType) -> int:
self.step_types[step.id] = step
return step.id
def set_trace(self, trace_def: Callable[[TraceContext, TraceArgs], None]):
if self.trace is not None:
raise Exception("Circuit cannot have more than one trace generator.")
else:
self.trace = trace_def
def set_fixed_gen(self, fixed_gen_def: Callable[[FixedGenContext], None]):
if self.fixed_gen is not None:
raise Exception("Circuit cannot have more than one fixed generator.")
else:
self.fixed_gen = fixed_gen_def
def get_step_type(self, uuid: int) -> StepType:
if uuid in self.step_types.keys():
return self.step_types[uuid]
else:
raise ValueError("StepType not found.")
# pub struct StepType<F> {
# id: StepTypeUUID,
# pub name: String,
# pub signals: Vec<InternalSignal>,
# pub constraints: Vec<Constraint<F>>,
# pub transition_constraints: Vec<TransitionConstraint<F>>,
# pub lookups: Vec<Lookup<F>>,
# pub annotations: HashMap<u32, String>,
# }
@dataclass
class StepType:
id: int
name: str
signals: List[InternalSignal]
constraints: List[Constraint]
transition_constraints: List[TransitionConstraint]
# lookups: List[Lookup]
annotations: Dict[int, str]
# TODO: __str__ (low priority as it's only used for debugging)
def __init__(self: StepType, id: int, name: str):
self.id = id
self.name = name
self.signals: List[InternalSignal] = []
self.constraints: List[Constraint] = []
self.transition_constraints: List[TransitionConstraint] = []
self.annotations: Dict[int, str] = {}
def add_signal(self: StepType, name: str) -> InternalSignal:
signal = InternalSignal(name)
self.signals.append(signal)
self.annotations[signal.id] = name
return signal
def add_constr(self: StepType, annotation: str, expr: Expr):
condition = Constraint(annotation, expr)
self.constraints.append(condition)
def add_transition(self, annotation: str, expr: Expr):
condition = TransitionConstraint(annotation, expr)
self.transition_constraints.append(condition)
def __eq__(self: StepType, other: StepType) -> bool:
if isinstance(self, StepType) and isinstance(other, StepType):
return self.id == other.id
return False
def __req__(other: StepType, self: StepType) -> bool:
return StepType.__eq__(self, other)
def __hash__(self: StepType):
return hash(self.id)
@dataclass
class Constraint:
annotation: str
expr: Expr
@dataclass
class TransitionConstraint:
annotation: str
expr: Expr
@dataclass
class ForwardSignal:
id: int
phase: int
annotation: str
def __init__(self, phase: int, annotation: str):
self.id: int = uuid()
self.phase = phase
self.annotation = annotation
@dataclass
class SharedSignal:
id: int
phase: int
anotation: str
def __init__(self: SharedSignal, phase: int, annotation: str):
self.id: int = uuid()
self.phase = phase
self.annotation = annotation
@dataclass
class FixedSignal:
id: int
annotation: str
def __init__(self: FixedSignal, annotation: str):
self.id: int = uuid()
self.annotation = annotation
@dataclass
class InternalSignal:
id: int
annotation: str
def __init__(self, annotation: str):
self.id = uuid()
self.annotation = annotation

155
dsl.py
View File

@@ -1,155 +0,0 @@
from __future__ import annotations
from chiquito_ast import Circuit, StepType
from typing import Callable
from query import Queriable, Forward, Internal, Shared, Fixed, StepTypeNext
from util import uuid
from cb import Constraint, ToConstraint, to_constraint, Typing
from dataclasses import dataclass
class CircuitContext:
def __init__(self):
self.circuit = Circuit()
def forward(self: CircuitContext, name: str) -> Queriable:
return Queriable(Forward(self.circuit.add_forward(name, 0), False))
def forward_with_phase(self: CircuitContext, name: str, phase: int) -> Queriable:
return Queriable(Forward(self.circuit.add_forward(name, phase), False))
def shared(self: CircuitContext, name: str) -> Queriable:
return Queriable(Shared(self.circuit.add_shared(name, 0), 0))
def shared_with_phase(self: CircuitContext, name: str, phase: int) -> Queriable:
return Queriable(Shared(self.circuit.add_shared(name, phase), 0))
def fixed(self: CircuitContext, name: str) -> Queriable:
return Queriable(Fixed(self.circuit.add_fixed(name), 0))
def expose(self: CircuitContext, forward: Queriable):
match forward:
case Queriable(Forward(forward_signal, _)):
self.circuit.expose(forward_signal)
case _:
raise ValueError("Can only expose a forward signal.")
# TODO: Implement import_halo2_advice and import_halo2_fixed. Currently we ignore imported query types.
def step_type(self: CircuitContext, name: str) -> StepTypeHandler:
handler = StepTypeHandler.new(name)
self.circuit.add_step_type(handler, name)
return handler
def step_type_def(self: CircuitContext, step: StepTypeDefInput, step_type_func: Callable[[StepTypeContext], StepTypeWGHandler]) -> StepTypeWGHandler:
match step:
case StepTypeDefInput(Handler(h)):
handler = h
case StepTypeDefInput(String(name)):
handler = StepTypeHandler.new(name)
self.circuit.add_step_type(handler, name)
case _:
raise ValueError("Invalid StepTypeDefInput enum type.")
context = StepTypeContext.new(handler.id, handler.annotation)
result = step_type_func(context)
self.circuit.add_step_type_def(context.step_type)
return result
def trace(self: CircuitContext, trace_def: Callable[[TraceContext, TraceArgs], None]):
self.circuit.set_trace(trace_def)
def fixed_gen(self: CircuitContext, fixed_gen_def: Callable[[FixedGenContext], None]):
self.circuit.set_fixed_gen(fixed_gen_def)
def pragma_first_step(self: CircuitContext, step_type: StepTypeHandler) -> None: # TODO: define and use ToStepTypeHandler instead
self.circuit.first_step = step_type.id
def pragma_last_step(self, step_type: StepTypeHandler) -> None:
self.circuit.last_step = step_type.id
def pragma_num_steps(self, num_steps: int) -> None:
self.circuit.num_steps = num_steps
ToStepTypeDefInput = StepTypeHandler | str
@dataclass
class Handler:
handler: StepTypeHandler
@dataclass
class String:
string: str
@dataclass
class StepTypeDefInput:
enum: Handler | String
def from_handler(h: StepTypeHandler) -> StepTypeDefInput:
return StepTypeDefInput(Handler(h))
def from_string(s: str) -> StepTypeDefInput:
return StepTypeDefInput(String(s))
@dataclass
class StepTypeContext:
step_type: StepType
def new(id: int, name: str) -> StepTypeContext:
return StepTypeContext(StepType(id, name))
def internal(self: StepTypeContext, name: str) -> Queriable:
return Queriable(Internal(self.step_type.add_signal(name)))
def setup(self: StepTypeContext, setup_def: Callable[[StepTypeSetupContext], None]) -> None: # def is a keyword in python
ctx = StepTypeSetupContext(self.step_type)
setup_def(ctx)
def wg(self, wg_def: Callable[[StepInstance, Args], None]) -> StepTypeWGHandler:
return StepTypeWGHandler(
self.step_type.id,
self.step_type.name,
wg_def
)
@dataclass
class StepTypeSetupContext:
step_type: StepType
def constr(self: StepTypeSetupContext, constraint: ToConstraint):
constraint = to_constraint(constraint)
StepTypeSetupContext.enforce_constraint_typing(constraint)
self.step_type.add_constr(constraint.annotation, constraint.expr)
def transition(self: StepTypeSetupContext, constraint: ToConstraint):
constraint = to_constraint(constraint)
StepTypeSetupContext.enforce_constraint_typing(constraint)
self.step_type.add_transition(constraint.annotation, constraint.expr)
def enforce_constraint_typing(constraint: Constraint):
if constraint.typing != Typing.AntiBooly:
raise ValueError(f"Expected AntiBooly constraint, got {constraint.typing} (constraint: {constraint.annotation})")
# TODO: Implement add_lookup after lookup abstraction PR is merged.
@dataclass
class StepTypeHandler:
id: int
annotation: str
def new(annotation: str) -> StepTypeHandler:
return StepTypeHandler(uuid(), annotation)
def next(self: StepTypeHandler) -> Queriable:
return Queriable(StepTypeNext(self))
def from_wg_handler(wg_handler: StepTypeWGHandler) -> StepTypeHandler:
return StepTypeHandler(wg_handler.id, wg_handler.annotation)
@dataclass
class StepTypeWGHandler:
id: int
annotation: str
wg: Callable[[StepInstance, Args], None]
def circuit(name: str, circuit_context_def: Callable[[CircuitContext], None]) -> Circuit:
ctx = CircuitContext()
circuit_context_def(ctx)
return ctx.circuit

145
expr.py
View File

@@ -1,145 +0,0 @@
from __future__ import annotations
from typing import List
from cb import Constraint
from py_ecc import bn128
from dataclasses import dataclass
from query import Queriable
F = bn128.FQ
# pub enum Expr<F> {
# Const(F),
# Sum(Vec<Expr<F>>),
# Mul(Vec<Expr<F>>),
# Neg(Box<Expr<F>>),
# Pow(Box<Expr<F>>, u32),
# Query(Queriable<F>),
# Halo2Expr(Expression<F>),
# }
@dataclass
class Const:
value: F
@dataclass
class Sum:
exprs: List[Expr]
@dataclass
class Mul:
exprs: List[Expr]
@dataclass
class Neg:
expr: Expr
@dataclass
class Pow:
expr: Expr
pow: int
@dataclass
class Query:
queriable: Queriable
# Ignored Expr::Halo2Expr.
@dataclass
class Expr:
enum: Const | Sum | Mul | Neg | Pow | Query
def __str__(self: Expr) -> str:
match expr.enum:
case Const(value):
return str(value)
case Sum(exprs):
result = "("
for i, expr in enumerate(exprs):
match expr.enum:
case Neg(expr):
if i == 0:
result += "-"
else:
result += " - "
case _:
result += " + "
result += str(expr)
result += ")"
return result
case Mul(exprs):
return "*".join([str(expr) for expr in exprs])
case Neg(expr):
return "(-" + str(expr) + ")"
case Pow(expr, pow):
return str(expr) + "^" + str(pow)
case Query(queriable):
return str(queriable)
case _:
raise ValueError("Invalid Expr enum type.")
def __neg__(self: Expr) -> Expr:
match self.enum:
case Neg(e):
return e
case _:
return Expr(Neg(self))
def __add__(lhs: ToExpr, rhs: ToExpr) -> Expr:
(lhs, rhs) = (to_expr(lhs), to_expr(rhs))
match (lhs.enum, rhs.enum):
case (Sum(exprs), _):
return Expr(Sum(exprs + [rhs]))
case (_, Sum(exprs)):
return Expr(Sum([lhs] + exprs))
case _:
return Expr(Sum([lhs, rhs]))
def __radd__(rhs: ToExpr, lhs: ToExpr) -> Expr:
return Expr.__add__(lhs, rhs)
def __sub__(lhs: ToExpr, rhs: ToExpr) -> Expr:
(lhs, rhs) = (to_expr(lhs), to_expr(rhs))
match (lhs.enum, rhs.enum):
case (Sum(exprs), _):
return Expr(Sum(exprs + [Expr(Neg(rhs))]))
case _:
return Expr(Sum([lhs, Expr(Neg(rhs))]))
def __rsub__(rhs: ToExpr, lhs: ToExpr) -> Expr:
return Expr.__sub__(lhs, rhs)
def __mul__(lhs: ToExpr, rhs: ToExpr) -> Expr:
(lhs, rhs) = (to_expr(lhs), to_expr(rhs))
match (lhs.enum, rhs.enum):
case (Mul(exprs), _):
return Expr(Mul(exprs + [rhs]))
case (_, Mul(exprs)):
return Expr(Mul([lhs] + exprs))
case _:
return Expr(Mul([lhs, rhs]))
def __rmul__(rhs: ToExpr, lhs: ToExpr) -> Expr:
return Expr.__mul__(lhs, rhs)
def __pow__(lhs: ToExpr, rhs: int) -> Expr:
lhs = to_expr(lhs)
return Expr(Pow(lhs, rhs))
ToExpr = Expr | str | int | F | Queriable | Constraint
def to_expr(v: ToExpr) -> Expr:
if isinstance(v, Expr):
return v
elif isinstance(v, F):
return Expr(Const(v))
elif isinstance(v, int):
if v >= 0:
return Expr(Const(F(v)))
else:
return Expr(Neg(Expr(Const(F(-v)))))
elif isinstance(v, Queriable):
return Expr(Query(v))
elif isinstance(v, Constraint):
return v.expr
else:
raise TypeError(f"Type `{type(v)}` is not ToExpr (one of Expr, str, int, F, Queriable, or Constraint).")

119
query.py
View File

@@ -1,119 +0,0 @@
from __future__ import annotations
from chiquito_ast import InternalSignal, ForwardSignal, SharedSignal, FixedSignal
from dataclasses import dataclass
from dsl import StepTypeHandler
from py_ecc import bn128
F = bn128.FQ
# pub enum Queriable<F> {
# Internal(InternalSignal),
# Forward(ForwardSignal, bool),
# Shared(SharedSignal, i32),
# Fixed(FixedSignal, i32),
# StepTypeNext(StepTypeHandler),
# Halo2AdviceQuery(ImportedHalo2Advice, i32),
# Halo2FixedQuery(ImportedHalo2Fixed, i32),
# #[allow(non_camel_case_types)]
# _unaccessible(PhantomData<F>),
# }
@dataclass
class Internal:
signal: InternalSignal
@dataclass
class Forward:
signal: ForwardSignal
rotation: bool
@dataclass
class Shared:
signal: SharedSignal
rotation: int
@dataclass
class Fixed:
signal: FixedSignal
rotation: int
@dataclass
class StepTypeNext:
handler: StepTypeHandler
# Ignored Queriable::Halo2AdviceQuery and Queriable::Halo2FixedQuery
@dataclass
class Queriable:
enum: Internal | Forward | Shared | Fixed | StepTypeNext
def next(self: Queriable) -> Queriable:
match self.enum:
case Forward(signal, rotation):
if rotation:
raise ValueError("Cannot rotate Forward twice.")
else:
return Queriable(Forward(signal, True))
case Shared(signal, rotation):
return Queriable(Shared(signal, rotation + 1))
case Fixed(signal, rotation):
return Queriable(Fixed(signal, rotation + 1))
case _:
raise ValueError("Can only call `next` on Forward, Shared, or Fixed.")
def prev(self: Queriable) -> Queriable:
match self.enum:
case Shared(signal, rotation):
return Queriable(Shared(signal, rotation - 1))
case Fixed(signal, rotation):
return Queriable(Fixed(signal, rotation - 1))
case _:
raise ValueError("Can only call `prev` on Shared or Fixed.")
def rot(self: Queriable, rotation: int) -> Queriable:
match self.enum:
case Shared(signal, rot):
return Queriable(Shared(signal, rot + rotation))
case Fixed(signal, rot):
return Queriable(Fixed(signal, rot + rotation))
case _:
raise ValueError("Can only call `rot` on Shared or Fixed.")
def uuid(self: Queriable) -> int:
match self.enum:
case Internal(signal):
return signal.id
case Forward(signal, _):
return signal.id
case Shared(signal, _):
return signal.id
case Fixed(signal, _):
return signal.id
case StepTypeNext(handler):
return handler.id
case _:
raise ValueError("Invalid Queriable enum type.")
def annotation(self: Queriable) -> str:
match self.enum:
case Internal(signal):
return signal.annotation
case Forward(signal, rotation):
if not rotation:
return signal.annotation
else:
return f"next({signal.annotation})"
case Shared(signal, rotation):
if rotation == 0:
return signal.annotation
else:
return f"{signal.annotation}(rot {rotation})"
case Fixed(signal, rotation):
if rotation == 0:
return signal.annotation
else:
return f"{signal.annotation}(rot {rotation})"
case StepTypeNext(handler):
return handler.annotation
case _:
raise ValueError("Invalid Queriable enum type.")

465
src/chiquito_ast.py Normal file
View File

@@ -0,0 +1,465 @@
from __future__ import annotations
from typing import Callable, List, Dict, Optional
from dataclasses import dataclass
from py_ecc import bn128
from util import uuid
from dsl import StepTypeHandler, Constraint
F = bn128.FQ
#######
# ast #
#######
# pub struct Circuit<F, TraceArgs> {
# pub step_types: HashMap<u32, Rc<StepType<F>>>,
# pub forward_signals: Vec<ForwardSignal>,
# pub shared_signals: Vec<SharedSignal>,
# pub fixed_signals: Vec<FixedSignal>,
# pub halo2_advice: Vec<ImportedHalo2Advice>,
# pub halo2_fixed: Vec<ImportedHalo2Fixed>,
# pub exposed: Vec<ForwardSignal>,
# pub annotations: HashMap<u32, String>,
# pub trace: Option<Rc<Trace<F, TraceArgs>>>,
# pub fixed_gen: Option<Rc<FixedGen<F>>>,
# pub first_step: Option<StepTypeUUID>,
# pub last_step: Option<StepTypeUUID>,
# pub num_steps: usize,
# }
@dataclass
class Circuit:
step_types: Dict[int, StepType]
forward_signals: List[ForwardSignal]
shared_signals: List[SharedSignal]
fixed_signals: List[FixedSignal]
# halo2_advice: List[ImportedHalo2Advice]
# halo2_fixed: List[ImportedHalo2Fixed]
exposed: List[ForwardSignal]
annotations: Dict[int, str]
trace: Optional[callable]
fixed_gen: Optional[callable]
first_step: Optional[int]
last_step: Optional[int]
num_steps: int
# TODO: __str__ (low priority as it's only used for debugging)
def add_forward(self: Circuit, name: str, phase: int) -> ForwardSignal:
signal = ForwardSignal(phase, name)
self.forward_signals.append(signal)
self.annotations[signal.id] = name
return signal
def add_shared(self: Circuit, name: str, phase: int) -> SharedSignal:
signal = SharedSignal(phase, name)
self.shared_signals.append(signal)
self.annotations[signal.id] = name
return signal
def add_fixed(self: Circuit, name: str) -> FixedSignal:
signal = FixedSignal(name)
self.fixed_signals.append(signal)
self.annotations[signal.id] = name
return signal
def expose(self: Circuit, forward_signal: ForwardSignal):
self.exposed.append(forward_signal)
def add_step_type(self: Circuit, handler: StepTypeHandler, name: str):
self.annotations[handler.id] = name
def add_step_type_def(self, step: StepType) -> int:
self.step_types[step.id] = step
return step.id
def set_trace(self, trace_def: Callable[[TraceContext, TraceArgs], None]):
if self.trace is not None:
raise Exception("Circuit cannot have more than one trace generator.")
else:
self.trace = trace_def
def set_fixed_gen(self, fixed_gen_def: Callable[[FixedGenContext], None]):
if self.fixed_gen is not None:
raise Exception("Circuit cannot have more than one fixed generator.")
else:
self.fixed_gen = fixed_gen_def
def get_step_type(self, uuid: int) -> StepType:
if uuid in self.step_types.keys():
return self.step_types[uuid]
else:
raise ValueError("StepType not found.")
# pub struct StepType<F> {
# id: StepTypeUUID,
# pub name: String,
# pub signals: Vec<InternalSignal>,
# pub constraints: Vec<Constraint<F>>,
# pub transition_constraints: Vec<TransitionConstraint<F>>,
# pub lookups: Vec<Lookup<F>>,
# pub annotations: HashMap<u32, String>,
# }
@dataclass
class StepType:
id: int
name: str
signals: List[InternalSignal]
constraints: List[Constraint]
transition_constraints: List[TransitionConstraint]
# lookups: List[Lookup]
annotations: Dict[int, str]
# TODO: __str__ (low priority as it's only used for debugging)
def __init__(self: StepType, id: int, name: str):
self.id = id
self.name = name
self.signals: List[InternalSignal] = []
self.constraints: List[Constraint] = []
self.transition_constraints: List[TransitionConstraint] = []
self.annotations: Dict[int, str] = {}
def add_signal(self: StepType, name: str) -> InternalSignal:
signal = InternalSignal(name)
self.signals.append(signal)
self.annotations[signal.id] = name
return signal
def add_constr(self: StepType, annotation: str, expr: Expr):
condition = Constraint(annotation, expr)
self.constraints.append(condition)
def add_transition(self, annotation: str, expr: Expr):
condition = TransitionConstraint(annotation, expr)
self.transition_constraints.append(condition)
def __eq__(self: StepType, other: StepType) -> bool:
if isinstance(self, StepType) and isinstance(other, StepType):
return self.id == other.id
return False
def __req__(other: StepType, self: StepType) -> bool:
return StepType.__eq__(self, other)
def __hash__(self: StepType):
return hash(self.id)
@dataclass
class Constraint:
annotation: str
expr: Expr
@dataclass
class TransitionConstraint:
annotation: str
expr: Expr
@dataclass
class ForwardSignal:
id: int
phase: int
annotation: str
def __init__(self, phase: int, annotation: str):
self.id: int = uuid()
self.phase = phase
self.annotation = annotation
@dataclass
class SharedSignal:
id: int
phase: int
anotation: str
def __init__(self: SharedSignal, phase: int, annotation: str):
self.id: int = uuid()
self.phase = phase
self.annotation = annotation
@dataclass
class FixedSignal:
id: int
annotation: str
def __init__(self: FixedSignal, annotation: str):
self.id: int = uuid()
self.annotation = annotation
@dataclass
class InternalSignal:
id: int
annotation: str
def __init__(self, annotation: str):
self.id = uuid()
self.annotation = annotation
############
# ast/expr #
############
# pub enum Expr<F> {
# Const(F),
# Sum(Vec<Expr<F>>),
# Mul(Vec<Expr<F>>),
# Neg(Box<Expr<F>>),
# Pow(Box<Expr<F>>, u32),
# Query(Queriable<F>),
# Halo2Expr(Expression<F>),
# }
@dataclass
class Const:
value: F
@dataclass
class Sum:
exprs: List[Expr]
@dataclass
class Mul:
exprs: List[Expr]
@dataclass
class Neg:
expr: Expr
@dataclass
class Pow:
expr: Expr
pow: int
@dataclass
class Query:
queriable: Queriable
# Ignored Expr::Halo2Expr.
@dataclass
class Expr:
enum: Const | Sum | Mul | Neg | Pow | Query
def __str__(self: Expr) -> str:
match expr.enum:
case Const(value):
return str(value)
case Sum(exprs):
result = "("
for i, expr in enumerate(exprs):
match expr.enum:
case Neg(expr):
if i == 0:
result += "-"
else:
result += " - "
case _:
result += " + "
result += str(expr)
result += ")"
return result
case Mul(exprs):
return "*".join([str(expr) for expr in exprs])
case Neg(expr):
return "(-" + str(expr) + ")"
case Pow(expr, pow):
return str(expr) + "^" + str(pow)
case Query(queriable):
return str(queriable)
case _:
raise ValueError("Invalid Expr enum type.")
def __neg__(self: Expr) -> Expr:
match self.enum:
case Neg(e):
return e
case _:
return Expr(Neg(self))
def __add__(lhs: ToExpr, rhs: ToExpr) -> Expr:
(lhs, rhs) = (to_expr(lhs), to_expr(rhs))
match (lhs.enum, rhs.enum):
case (Sum(exprs), _):
return Expr(Sum(exprs + [rhs]))
case (_, Sum(exprs)):
return Expr(Sum([lhs] + exprs))
case _:
return Expr(Sum([lhs, rhs]))
def __radd__(rhs: ToExpr, lhs: ToExpr) -> Expr:
return Expr.__add__(lhs, rhs)
def __sub__(lhs: ToExpr, rhs: ToExpr) -> Expr:
(lhs, rhs) = (to_expr(lhs), to_expr(rhs))
match (lhs.enum, rhs.enum):
case (Sum(exprs), _):
return Expr(Sum(exprs + [Expr(Neg(rhs))]))
case _:
return Expr(Sum([lhs, Expr(Neg(rhs))]))
def __rsub__(rhs: ToExpr, lhs: ToExpr) -> Expr:
return Expr.__sub__(lhs, rhs)
def __mul__(lhs: ToExpr, rhs: ToExpr) -> Expr:
(lhs, rhs) = (to_expr(lhs), to_expr(rhs))
match (lhs.enum, rhs.enum):
case (Mul(exprs), _):
return Expr(Mul(exprs + [rhs]))
case (_, Mul(exprs)):
return Expr(Mul([lhs] + exprs))
case _:
return Expr(Mul([lhs, rhs]))
def __rmul__(rhs: ToExpr, lhs: ToExpr) -> Expr:
return Expr.__mul__(lhs, rhs)
def __pow__(lhs: ToExpr, rhs: int) -> Expr:
lhs = to_expr(lhs)
return Expr(Pow(lhs, rhs))
######################
# ast/expr/queriable #
######################
# pub enum Queriable<F> {
# Internal(InternalSignal),
# Forward(ForwardSignal, bool),
# Shared(SharedSignal, i32),
# Fixed(FixedSignal, i32),
# StepTypeNext(StepTypeHandler),
# Halo2AdviceQuery(ImportedHalo2Advice, i32),
# Halo2FixedQuery(ImportedHalo2Fixed, i32),
# #[allow(non_camel_case_types)]
# _unaccessible(PhantomData<F>),
# }
@dataclass
class Internal:
signal: InternalSignal
@dataclass
class Forward:
signal: ForwardSignal
rotation: bool
@dataclass
class Shared:
signal: SharedSignal
rotation: int
@dataclass
class Fixed:
signal: FixedSignal
rotation: int
@dataclass
class StepTypeNext:
handler: StepTypeHandler
# Ignored Queriable::Halo2AdviceQuery and Queriable::Halo2FixedQuery
@dataclass
class Queriable:
enum: Internal | Forward | Shared | Fixed | StepTypeNext
def next(self: Queriable) -> Queriable:
match self.enum:
case Forward(signal, rotation):
if rotation:
raise ValueError("Cannot rotate Forward twice.")
else:
return Queriable(Forward(signal, True))
case Shared(signal, rotation):
return Queriable(Shared(signal, rotation + 1))
case Fixed(signal, rotation):
return Queriable(Fixed(signal, rotation + 1))
case _:
raise ValueError("Can only call `next` on Forward, Shared, or Fixed.")
def prev(self: Queriable) -> Queriable:
match self.enum:
case Shared(signal, rotation):
return Queriable(Shared(signal, rotation - 1))
case Fixed(signal, rotation):
return Queriable(Fixed(signal, rotation - 1))
case _:
raise ValueError("Can only call `prev` on Shared or Fixed.")
def rot(self: Queriable, rotation: int) -> Queriable:
match self.enum:
case Shared(signal, rot):
return Queriable(Shared(signal, rot + rotation))
case Fixed(signal, rot):
return Queriable(Fixed(signal, rot + rotation))
case _:
raise ValueError("Can only call `rot` on Shared or Fixed.")
def uuid(self: Queriable) -> int:
match self.enum:
case Internal(signal):
return signal.id
case Forward(signal, _):
return signal.id
case Shared(signal, _):
return signal.id
case Fixed(signal, _):
return signal.id
case StepTypeNext(handler):
return handler.id
case _:
raise ValueError("Invalid Queriable enum type.")
def annotation(self: Queriable) -> str:
match self.enum:
case Internal(signal):
return signal.annotation
case Forward(signal, rotation):
if not rotation:
return signal.annotation
else:
return f"next({signal.annotation})"
case Shared(signal, rotation):
if rotation == 0:
return signal.annotation
else:
return f"{signal.annotation}(rot {rotation})"
case Fixed(signal, rotation):
if rotation == 0:
return signal.annotation
else:
return f"{signal.annotation}(rot {rotation})"
case StepTypeNext(handler):
return handler.annotation
case _:
raise ValueError("Invalid Queriable enum type.")
ToExpr = Expr | str | int | F | Queriable | Constraint
def to_expr(v: ToExpr) -> Expr:
if isinstance(v, Expr):
return v
elif isinstance(v, F):
return Expr(Const(v))
elif isinstance(v, int):
if v >= 0:
return Expr(Const(F(v)))
else:
return Expr(Neg(Expr(Const(F(-v)))))
elif isinstance(v, Queriable):
return Expr(Query(v))
elif isinstance(v, Constraint):
return v.expr
else:
raise TypeError(f"Type `{type(v)}` is not ToExpr (one of Expr, str, int, F, Queriable, or Constraint).")

View File

@@ -1,16 +1,170 @@
from __future__ import annotations
from typing import Any, TypeVar, Generic, List, Dict, List
from chiquito_ast import InternalSignal, ForwardSignal, SharedSignal, FixedSignal
from expr import Expr, Query, StepTypeNext, Const, Neg
from query import Queriable
from enum import Enum, auto
from dsl import StepTypeHandler
from typing import Callable, List
from dataclasses import dataclass
from enum import Enum, auto
from py_ecc import bn128
from expr import ToExpr, to_expr
from chiquito_ast import Circuit, StepType, Queriable, Forward, Internal, Shared, Fixed, StepTypeNext, Const, Neg, ToExpr, to_expr, Expr, Query
from util import uuid
F = bn128.FQ
#######
# dsl #
#######
class CircuitContext:
def __init__(self):
self.circuit = Circuit()
def forward(self: CircuitContext, name: str) -> Queriable:
return Queriable(Forward(self.circuit.add_forward(name, 0), False))
def forward_with_phase(self: CircuitContext, name: str, phase: int) -> Queriable:
return Queriable(Forward(self.circuit.add_forward(name, phase), False))
def shared(self: CircuitContext, name: str) -> Queriable:
return Queriable(Shared(self.circuit.add_shared(name, 0), 0))
def shared_with_phase(self: CircuitContext, name: str, phase: int) -> Queriable:
return Queriable(Shared(self.circuit.add_shared(name, phase), 0))
def fixed(self: CircuitContext, name: str) -> Queriable:
return Queriable(Fixed(self.circuit.add_fixed(name), 0))
def expose(self: CircuitContext, forward: Queriable):
match forward:
case Queriable(Forward(forward_signal, _)):
self.circuit.expose(forward_signal)
case _:
raise ValueError("Can only expose a forward signal.")
# TODO: Implement import_halo2_advice and import_halo2_fixed. Currently we ignore imported query types.
def step_type(self: CircuitContext, name: str) -> StepTypeHandler:
handler = StepTypeHandler.new(name)
self.circuit.add_step_type(handler, name)
return handler
def step_type_def(self: CircuitContext, step: StepTypeDefInput, step_type_func: Callable[[StepTypeContext], StepTypeWGHandler]) -> StepTypeWGHandler:
match step:
case StepTypeDefInput(Handler(h)):
handler = h
case StepTypeDefInput(String(name)):
handler = StepTypeHandler.new(name)
self.circuit.add_step_type(handler, name)
case _:
raise ValueError("Invalid StepTypeDefInput enum type.")
context = StepTypeContext.new(handler.id, handler.annotation)
result = step_type_func(context)
self.circuit.add_step_type_def(context.step_type)
return result
def trace(self: CircuitContext, trace_def: Callable[[TraceContext, TraceArgs], None]):
self.circuit.set_trace(trace_def)
def fixed_gen(self: CircuitContext, fixed_gen_def: Callable[[FixedGenContext], None]):
self.circuit.set_fixed_gen(fixed_gen_def)
def pragma_first_step(self: CircuitContext, step_type: StepTypeHandler) -> None: # TODO: define and use ToStepTypeHandler instead
self.circuit.first_step = step_type.id
def pragma_last_step(self, step_type: StepTypeHandler) -> None:
self.circuit.last_step = step_type.id
def pragma_num_steps(self, num_steps: int) -> None:
self.circuit.num_steps = num_steps
@dataclass
class StepTypeDefInput:
enum: Handler | String
def from_handler(h: StepTypeHandler) -> StepTypeDefInput:
return StepTypeDefInput(Handler(h))
def from_string(s: str) -> StepTypeDefInput:
return StepTypeDefInput(String(s))
@dataclass
class StepTypeContext:
step_type: StepType
def new(id: int, name: str) -> StepTypeContext:
return StepTypeContext(StepType(id, name))
def internal(self: StepTypeContext, name: str) -> Queriable:
return Queriable(Internal(self.step_type.add_signal(name)))
def setup(self: StepTypeContext, setup_def: Callable[[StepTypeSetupContext], None]) -> None: # def is a keyword in python
ctx = StepTypeSetupContext(self.step_type)
setup_def(ctx)
def wg(self, wg_def: Callable[[StepInstance, Args], None]) -> StepTypeWGHandler:
return StepTypeWGHandler(
self.step_type.id,
self.step_type.name,
wg_def
)
@dataclass
class StepTypeSetupContext:
step_type: StepType
def constr(self: StepTypeSetupContext, constraint: ToConstraint):
constraint = to_constraint(constraint)
StepTypeSetupContext.enforce_constraint_typing(constraint)
self.step_type.add_constr(constraint.annotation, constraint.expr)
def transition(self: StepTypeSetupContext, constraint: ToConstraint):
constraint = to_constraint(constraint)
StepTypeSetupContext.enforce_constraint_typing(constraint)
self.step_type.add_transition(constraint.annotation, constraint.expr)
def enforce_constraint_typing(constraint: Constraint):
if constraint.typing != Typing.AntiBooly:
raise ValueError(f"Expected AntiBooly constraint, got {constraint.typing} (constraint: {constraint.annotation})")
# TODO: Implement add_lookup after lookup abstraction PR is merged.
@dataclass
class StepTypeHandler:
id: int
annotation: str
def new(annotation: str) -> StepTypeHandler:
return StepTypeHandler(uuid(), annotation)
def next(self: StepTypeHandler) -> Queriable:
return Queriable(StepTypeNext(self))
def from_wg_handler(wg_handler: StepTypeWGHandler) -> StepTypeHandler:
return StepTypeHandler(wg_handler.id, wg_handler.annotation)
@dataclass
class StepTypeWGHandler:
id: int
annotation: str
wg: Callable[[StepInstance, Args], None]
def circuit(name: str, circuit_context_def: Callable[[CircuitContext], None]) -> Circuit:
ctx = CircuitContext()
circuit_context_def(ctx)
return ctx.circuit
ToStepTypeDefInput = StepTypeHandler | str
@dataclass
class Handler:
handler: StepTypeHandler
@dataclass
class String:
string: str
##########
# dsl/cb #
##########
class Typing(Enum):
Unknown = auto()
Boolean = auto()

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from dsl import CircuitContext, StepTypeContext, StepTypeSetupContext, StepTypeWGHandler, StepTypeHandler
from cb import Constraint
from query import Queriable
from dsl import CircuitContext, StepTypeContext, StepTypeSetupContext, StepTypeWGHandler, StepTypeHandler, Constraint
from chiquito_ast import Queriable
class Fibonacci(CircuitContext):
def __init__(self: Fibonacci):