diff --git a/.gitignore b/.gitignore index 76823da..1d5f27f 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ lib/ include/ target/ maturin -python + python3 python3.11 pip @@ -21,4 +21,5 @@ packages.png .vscode .devcontainer .ipynb_checkpoints -dsl-ipynb_checkpoints \ No newline at end of file +dsl-ipynb_checkpoints +*.so \ No newline at end of file diff --git a/python/chiquito/__init__.py b/python/chiquito/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/chiquito/cb.py b/python/chiquito/cb.py new file mode 100644 index 0000000..f6058b6 --- /dev/null +++ b/python/chiquito/cb.py @@ -0,0 +1,224 @@ +from __future__ import annotations +from dataclasses import dataclass +from enum import Enum, auto +from typing import List + +from chiquito.util import F +from chiquito.expr import Expr, Const, Neg, to_expr, ToExpr +from chiquito.query import StepTypeNext +from chiquito.chiquito_ast import ASTStepType + + +class Typing(Enum): + Unknown = auto() + Boolean = auto() + AntiBooly = auto() + + +@dataclass +class Constraint: + annotation: str + expr: Expr + typing: Typing + + def from_expr( + expr: Expr, + ) -> Constraint: # Cannot call function `from`, a reserved keyword in Python. + annotation: str = str(expr) + if isinstance(expr, StepTypeNext): + return Constraint(annotation, expr, Typing.Boolean) + else: + return Constraint(annotation, expr, Typing.Unknown) + + def __str__(self: Constraint) -> str: + return self.annotation + + +def cb_and( + inputs: List[ToConstraint], +) -> Constraint: # Cannot call function `and`, a reserved keyword in Python + inputs = [to_constraint(input) for input in inputs] + annotations: List[str] = [] + expr = Const(F(1)) + for constraint in inputs: + if constraint.typing == Typing.Boolean or constraint.typing == Typing.Unknown: + annotations.append(constraint.annotation) + expr = expr * constraint.expr + else: + raise ValueError( + f"Expected Boolean or Unknown constraint, got AntiBooly (constraint: {constraint.annotation})" + ) + return Constraint(f"({' AND '.join(annotations)})", expr, Typing.Boolean) + + +def cb_or( + inputs: List[ToConstraint], +) -> Constraint: # Cannot call function `or`, a reserved keyword in Python + inputs = [to_constraint(input) for input in inputs] + annotations: List[str] = [] + exprs: List[Expr] = [] + for constraint in inputs: + if constraint.typing == Typing.Boolean or constraint.typing == Typing.Unknown: + annotations.append(constraint.annotation) + exprs.append(constraint.expr) + else: + raise ValueError( + f"Expected Boolean or Unknown constraint, got AntiBooly (constraint: {constraint.annotation})" + ) + result: Constraint = Constraint.cb_not( + Constraint.cb_and([Constraint.cb_not(expr) for expr in exprs]) + ) + return Constraint(f"({' OR '.join(annotations)})", result.expr, Typing.Boolean) + + +def xor(lhs: ToConstraint, rhs: ToConstraint) -> Constraint: + (lhs, rhs) = (to_constraint(lhs), to_constraint(rhs)) + if (lhs.typing == Typing.Boolean or lhs.typing == Typing.Unknown) and ( + rhs.typing == Typing.Boolean or rhs.typing == Typing.Unknown + ): + return Constraint( + f"({lhs.annotation} XOR {rhs.annotation})", + lhs.expr + rhs.expr - F(2) * lhs.expr * rhs.expr, + Typing.Boolean, + ) + else: + raise ValueError( + f"Expected Boolean or Unknown constraints, got AntiBooly in one of lhs or rhs constraints (lhs constraint: {lhs.annotation}) (rhs constraint: {rhs.annotation})" + ) + + +def eq(lhs: ToConstraint, rhs: ToConstraint) -> Constraint: + (lhs, rhs) = (to_constraint(lhs), to_constraint(rhs)) + return Constraint( + f"({lhs.annotation} == {rhs.annotation})", + lhs.expr - rhs.expr, + Typing.AntiBooly, + ) + + +def select( + selector: ToConstraint, when_true: ToConstraint, when_false: ToConstraint +) -> Constraint: + (selector, when_true, when_false) = ( + to_constraint(selector), + to_constraint(when_true), + to_constraint(when_false), + ) + if selector.typing == Typing.AntiBooly: + raise ValueError( + f"Expected Boolean or Unknown selector, got AntiBooly (selector: {selector.annotation})" + ) + return Constraint( + f"if({selector.annotation})then({when_true.annotation})else({when_false.annotation})", + selector.expr * when_true.expr + (F(1) - selector.expr) * when_false.expr, + when_true.typing if when_true.typing == when_false.typing else Typing.Unknown, + ) + + +def when(selector: ToConstraint, when_true: ToConstraint) -> Constraint: + (selector, when_true) = (to_constraint(selector), to_constraint(when_true)) + if selector.typing == Typing.AntiBooly: + raise ValueError( + f"Expected Boolean or Unknown selector, got AntiBooly (selector: {selector.annotation})" + ) + return Constraint( + f"if({selector.annotation})then({when_true.annotation})", + selector.expr * when_true.expr, + when_true.typing, + ) + + +def unless(selector: ToConstraint, when_false: ToConstraint) -> Constraint: + (selector, when_false) = (to_constraint(selector), to_constraint(when_false)) + if selector.typing == Typing.AntiBooly: + raise ValueError( + f"Expected Boolean or Unknown selector, got AntiBooly (selector: {selector.annotation})" + ) + return Constraint( + f"unless({selector.annotation})then({when_false.annotation})", + (F(1) - selector.expr) * when_false.expr, + when_false.typing, + ) + + +def cb_not( + constraint: ToConstraint, +) -> Constraint: # Cannot call function `not`, a reserved keyword in Python + constraint = to_constraint(constraint) + if constraint.typing == Typing.AntiBooly: + raise ValueError( + f"Expected Boolean or Unknown constraint, got AntiBooly (constraint: {constraint.annotation})" + ) + return Constraint( + f"NOT({constraint.annotation})", F(1) - constraint.expr, Typing.Boolean + ) + + +def isz(constraint: ToConstraint) -> Constraint: + constraint = to_constraint(constraint) + return Constraint( + f"0 == {constraint.annotation}", constraint.expr, Typing.AntiBooly + ) + + +def if_next_step(step_type: ASTStepType, constraint: ToConstraint) -> Constraint: + constraint = to_constraint(constraint) + return Constraint( + f"if(next step is {step_type.annotation})then({constraint.annotation})", + StepTypeNext(step_type) * constraint.expr, + constraint.typing, + ) + + +def next_step_must_be(step_type: ASTStepType) -> Constraint: + return Constraint( + f"next step must be {step_type.annotation}", + Constraint.cb_not(StepTypeNext(step_type)), + Typing.AntiBooly, + ) + + +def next_step_must_not_be(step_type: ASTStepType) -> Constraint: + return Constraint( + f"next step must not be {step_type.annotation}", + StepTypeNext(step_type), + Typing.AntiBooly, + ) + + +def rlc(exprs: List[ToExpr], randomness: Expr) -> Expr: + if len(exprs) > 0: + exprs: List[Expr] = [to_expr(expr) for expr in exprs].reverse() + init: Expr = exprs[0] + for expr in exprs[1:]: + init = init * randomness + expr + return init + else: + return Expr(Const(F(0))) + + +# TODO: Implement lookup table after the lookup abstraction PR is merged. + + +ToConstraint = Constraint | Expr | int | F + + +def to_constraint(v: ToConstraint) -> Constraint: + if isinstance(v, Constraint): + return v + elif isinstance(v, Expr): + if isinstance(v, StepTypeNext): + return Constraint(str(v), v, Typing.Boolean) + else: + return Constraint(str(v), v, Typing.Unknown) + elif isinstance(v, int): + if v >= 0: + return to_constraint(Const(F(v))) + else: + return to_constraint(Neg(Const(F(-v)))) + elif isinstance(v, F): + return to_constraint(Const(v)) + else: + raise TypeError( + f"Type `{type(v)}` is not ToConstraint (one of Constraint, Expr, int, or F)." + ) diff --git a/python/chiquito/chiquito_ast.py b/python/chiquito/chiquito_ast.py new file mode 100644 index 0000000..1f4b28c --- /dev/null +++ b/python/chiquito/chiquito_ast.py @@ -0,0 +1,386 @@ +from __future__ import annotations +from typing import Callable, List, Dict, Optional, Any, Tuple +from dataclasses import dataclass, field, asdict +# from chiquito import wit_gen, expr, query, util + +from chiquito.wit_gen import FixedGenContext, StepInstance +from chiquito.expr import Expr +from chiquito.util import uuid +from chiquito.query import Queriable + + +# pub struct Circuit { +# pub step_types: HashMap>>, + +# pub forward_signals: Vec, +# pub shared_signals: Vec, +# pub fixed_signals: Vec, +# pub halo2_advice: Vec, +# pub halo2_fixed: Vec, +# pub exposed: Vec, + +# pub annotations: HashMap, + +# pub trace: Option>>, +# pub fixed_gen: Option>>, + +# pub first_step: Option, +# pub last_step: Option, +# pub num_steps: usize, +# } + + +@dataclass +class ASTCircuit: + step_types: Dict[int, ASTStepType] = field(default_factory=dict) + forward_signals: List[ForwardSignal] = field(default_factory=list) + shared_signals: List[SharedSignal] = field(default_factory=list) + fixed_signals: List[FixedSignal] = field(default_factory=list) + exposed: List[Tuple[Queriable, ExposeOffset]] = field(default_factory=list) + annotations: Dict[int, str] = field(default_factory=dict) + fixed_gen: Optional[Callable] = None + first_step: Optional[int] = None + last_step: Optional[int] = None + num_steps: int = 0 + q_enable: bool = True + id: int = uuid() + + def __str__(self: ASTCircuit): + step_types_str = ( + "\n\t\t" + + ",\n\t\t".join(f"{k}: {v}" for k, v in self.step_types.items()) + + "\n\t" + if self.step_types + else "" + ) + forward_signals_str = ( + "\n\t\t" + ",\n\t\t".join(str(fs) for fs in self.forward_signals) + "\n\t" + if self.forward_signals + else "" + ) + shared_signals_str = ( + "\n\t\t" + ",\n\t\t".join(str(ss) for ss in self.shared_signals) + "\n\t" + if self.shared_signals + else "" + ) + fixed_signals_str = ( + "\n\t\t" + ",\n\t\t".join(str(fs) for fs in self.fixed_signals) + "\n\t" + if self.fixed_signals + else "" + ) + exposed_str = ( + "\n\t\t" + + ",\n\t\t".join(f"({str(lhs)}, {str(rhs)})" for (lhs, rhs) in self.exposed) + + "\n\t" + if self.exposed + else "" + ) + annotations_str = ( + "\n\t\t" + + ",\n\t\t".join(f"{k}: {v}" for k, v in self.annotations.items()) + + "\n\t" + if self.annotations + else "" + ) + + return ( + f"ASTCircuit(\n" + f"\tstep_types={{{step_types_str}}},\n" + f"\tforward_signals=[{forward_signals_str}],\n" + f"\tshared_signals=[{shared_signals_str}],\n" + f"\tfixed_signals=[{fixed_signals_str}],\n" + f"\texposed=[{exposed_str}],\n" + f"\tannotations={{{annotations_str}}},\n" + f"\tfixed_gen={self.fixed_gen},\n" + f"\tfirst_step={self.first_step},\n" + f"\tlast_step={self.last_step},\n" + f"\tnum_steps={self.num_steps}\n" + f"\tq_enable={self.q_enable}\n" + f")" + ) + + def __json__(self: ASTCircuit): + return { + "step_types": {k: v.__json__() for k, v in self.step_types.items()}, + "forward_signals": [x.__json__() for x in self.forward_signals], + "shared_signals": [x.__json__() for x in self.shared_signals], + "fixed_signals": [x.__json__() for x in self.fixed_signals], + "exposed": [ + [queriable.__json__(), offset.__json__()] + for (queriable, offset) in self.exposed + ], + "annotations": self.annotations, + "first_step": self.first_step, + "last_step": self.last_step, + "num_steps": self.num_steps, + "q_enable": self.q_enable, + "id": self.id, + } + + def add_forward(self: ASTCircuit, name: str, phase: int) -> ForwardSignal: + signal = ForwardSignal(phase, name) + self.forward_signals.append(signal) + self.annotations[signal.id] = name + return signal + + def add_shared(self: ASTCircuit, name: str, phase: int) -> SharedSignal: + signal = SharedSignal(phase, name) + self.shared_signals.append(signal) + self.annotations[signal.id] = name + return signal + + def add_fixed(self: ASTCircuit, name: str) -> FixedSignal: + signal = FixedSignal(name) + self.fixed_signals.append(signal) + self.annotations[signal.id] = name + return signal + + def expose(self: ASTCircuit, signal: Queriable, offset: ExposeOffset): + self.exposed.append((signal, offset)) + + def add_step_type(self: ASTCircuit, step_type: ASTStepType, name: str): + self.annotations[step_type.id] = name + self.step_types[step_type.id] = step_type + + def set_fixed_gen(self, fixed_gen_def: Callable[[FixedGenContext], None]): + if self.fixed_gen is not None: + raise Exception("ASTCircuit cannot have more than one fixed generator.") + else: + self.fixed_gen = fixed_gen_def + + def get_step_type(self, uuid: int) -> ASTStepType: + if uuid in self.step_types.keys(): + return self.step_types[uuid] + else: + raise ValueError("ASTStepType not found.") + + +# pub struct StepType { +# id: StepTypeUUID, + +# pub name: String, +# pub signals: Vec, +# pub constraints: Vec>, +# pub transition_constraints: Vec>, +# pub lookups: Vec>, +# pub annotations: HashMap, +# } + + +@dataclass +class ASTStepType: + id: int + name: str + signals: List[InternalSignal] + constraints: List[ASTConstraint] + transition_constraints: List[TransitionConstraint] + annotations: Dict[int, str] + + def new(name: str) -> ASTStepType: + return ASTStepType(uuid(), name, [], [], [], {}) + + def __str__(self): + signals_str = ( + "\n\t\t\t\t" + + ",\n\t\t\t\t".join(str(signal) for signal in self.signals) + + "\n\t\t\t" + if self.signals + else "" + ) + constraints_str = ( + "\n\t\t\t\t" + + ",\n\t\t\t\t".join(str(constraint) for constraint in self.constraints) + + "\n\t\t\t" + if self.constraints + else "" + ) + transition_constraints_str = ( + "\n\t\t\t\t" + + ",\n\t\t\t\t".join(str(tc) for tc in self.transition_constraints) + + "\n\t\t\t" + if self.transition_constraints + else "" + ) + annotations_str = ( + "\n\t\t\t\t" + + ",\n\t\t\t\t".join(f"{k}: {v}" for k, v in self.annotations.items()) + + "\n\t\t\t" + if self.annotations + else "" + ) + + return ( + f"ASTStepType(\n" + f"\t\t\tid={self.id},\n" + f"\t\t\tname='{self.name}',\n" + f"\t\t\tsignals=[{signals_str}],\n" + f"\t\t\tconstraints=[{constraints_str}],\n" + f"\t\t\ttransition_constraints=[{transition_constraints_str}],\n" + f"\t\t\tannotations={{{annotations_str}}}\n" + f"\t\t)" + ) + + def __json__(self): + return { + "id": self.id, + "name": self.name, + "signals": [x.__json__() for x in self.signals], + "constraints": [x.__json__() for x in self.constraints], + "transition_constraints": [ + x.__json__() for x in self.transition_constraints + ], + "annotations": self.annotations, + } + + def add_signal(self: ASTStepType, name: str) -> InternalSignal: + signal = InternalSignal(name) + self.signals.append(signal) + self.annotations[signal.id] = name + return signal + + def add_constr(self: ASTStepType, annotation: str, expr: Expr): + condition = ASTConstraint(annotation, expr) + self.constraints.append(condition) + + def add_transition(self: ASTStepType, annotation: str, expr: Expr): + condition = TransitionConstraint(annotation, expr) + self.transition_constraints.append(condition) + + def __eq__(self: ASTStepType, other: ASTStepType) -> bool: + if isinstance(self, ASTStepType) and isinstance(other, ASTStepType): + return self.id == other.id + return False + + def __req__(other: ASTStepType, self: ASTStepType) -> bool: + return ASTStepType.__eq__(self, other) + + def __hash__(self: ASTStepType): + return hash(self.id) + + +@dataclass +class ASTConstraint: + annotation: str + expr: Expr + + def __str__(self: ASTConstraint): + return ( + f"Constraint(\n" + f"\t\t\t\t\tannotation='{self.annotation}',\n" + f"\t\t\t\t\texpr={self.expr}\n" + f"\t\t\t\t)" + ) + + def __json__(self: ASTConstraint): + return {"annotation": self.annotation, "expr": self.expr.__json__()} + + +@dataclass +class TransitionConstraint: + annotation: str + expr: Expr + + def __str__(self: TransitionConstraint): + return f"TransitionConstraint({self.annotation})" + + def __json__(self: TransitionConstraint): + return {"annotation": self.annotation, "expr": self.expr.__json__()} + + +@dataclass +class ForwardSignal: + id: int + phase: int + annotation: str + + def __init__(self: ForwardSignal, phase: int, annotation: str): + self.id: int = uuid() + self.phase = phase + self.annotation = annotation + + def __str__(self: ForwardSignal): + return f"ForwardSignal(id={self.id}, phase={self.phase}, annotation='{self.annotation}')" + + def __json__(self: ForwardSignal): + return asdict(self) + + +@dataclass +class SharedSignal: + id: int + phase: int + annotation: str + + def __init__(self: SharedSignal, phase: int, annotation: str): + self.id: int = uuid() + self.phase = phase + self.annotation = annotation + + def __str__(self: SharedSignal): + return f"SharedSignal(id={self.id}, phase={self.phase}, annotation='{self.annotation}')" + + def __json__(self: SharedSignal): + return asdict(self) + + +class ExposeOffset: + pass + + +class First(ExposeOffset): + def __str__(self: First): + return "First" + + def __json__(self: First): + return {"First": 0} + + +class Last(ExposeOffset): + def __str__(self: Last): + return "Last" + + def __json__(self: Last): + return {"Last": -1} + + +@dataclass +class Step(ExposeOffset): + offset: int + + def __str__(self: Step): + return f"Step({self.offset})" + + def __json__(self: Step): + return {"Step": self.offset} + + +@dataclass +class FixedSignal: + id: int + annotation: str + + def __init__(self: FixedSignal, annotation: str): + self.id: int = uuid() + self.annotation = annotation + + def __str__(self: FixedSignal): + return f"FixedSignal(id={self.id}, annotation='{self.annotation}')" + + def __json__(self: FixedSignal): + return asdict(self) + + +@dataclass +class InternalSignal: + id: int + annotation: str + + def __init__(self: InternalSignal, annotation: str): + self.id = uuid() + self.annotation = annotation + + def __str__(self: InternalSignal): + return f"InternalSignal(id={self.id}, annotation='{self.annotation}')" + + def __json__(self: InternalSignal): + return asdict(self) diff --git a/python/chiquito/dsl.py b/python/chiquito/dsl.py new file mode 100644 index 0000000..2361d66 --- /dev/null +++ b/python/chiquito/dsl.py @@ -0,0 +1,164 @@ +from __future__ import annotations +from enum import Enum +from typing import Callable, Any +import rust_chiquito # rust bindings +import json +from chiquito import (chiquito_ast, wit_gen) + +from chiquito.chiquito_ast import ASTCircuit, ASTStepType, ExposeOffset +from chiquito.query import Internal, Forward, Queriable, Shared, Fixed +from chiquito.wit_gen import FixedGenContext, StepInstance, TraceWitness +from chiquito.cb import Constraint, Typing, ToConstraint, to_constraint +from chiquito.util import CustomEncoder, F + + +class CircuitMode(Enum): + NoMode = 0 + SETUP = 1 + Trace = 2 + + +class Circuit: + def __init__(self: Circuit): + self.ast = ASTCircuit() + self.witness = TraceWitness() + self.rust_ast_id = 0 + self.mode = CircuitMode.SETUP + self.setup() + + def forward(self: Circuit, name: str) -> Forward: + assert self.mode == CircuitMode.SETUP + return Forward(self.ast.add_forward(name, 0), False) + + def forward_with_phase(self: Circuit, name: str, phase: int) -> Forward: + assert self.mode == CircuitMode.SETUP + return Forward(self.ast.add_forward(name, phase), False) + + def shared(self: Circuit, name: str) -> Shared: + assert self.mode == CircuitMode.SETUP + return Shared(self.ast.add_shared(name, 0), 0) + + def shared_with_phase(self: Circuit, name: str, phase: int) -> Shared: + assert self.mode == CircuitMode.SETUP + return Shared(self.ast.add_shared(name, phase), 0) + + def fixed(self: Circuit, name: str) -> Fixed: + assert self.mode == CircuitMode.SETUP + return Fixed(self.ast.add_fixed(name), 0) + + def expose(self: Circuit, signal: Queriable, offset: ExposeOffset): + assert self.mode == CircuitMode.SETUP + if isinstance(signal, (Forward, Shared)): + self.ast.expose(signal, offset) + else: + raise TypeError(f"Can only expose ForwardSignal or SharedSignal.") + + def step_type(self: Circuit, step_type: StepType) -> StepType: + assert self.mode == CircuitMode.SETUP + self.ast.add_step_type(step_type.step_type, step_type.step_type.name) + return step_type + + def step_type_def(self: StepType) -> StepType: + assert self.mode == CircuitMode.SETUP + self.ast.add_step_type_def() + + def fixed_gen(self: Circuit, fixed_gen_def: Callable[[FixedGenContext], None]): + self.ast.set_fixed_gen(fixed_gen_def) + + def pragma_first_step(self: Circuit, step_type: StepType) -> None: + assert self.mode == CircuitMode.SETUP + self.ast.first_step = step_type.step_type.id + + def pragma_last_step(self: Circuit, step_type: StepType) -> None: + assert self.mode == CircuitMode.SETUP + self.ast.last_step = step_type.step_type.id + + def pragma_num_steps(self: Circuit, num_steps: int) -> None: + assert self.mode == CircuitMode.SETUP + self.ast.num_steps = num_steps + + def pragma_disable_q_enable(self: Circuit) -> None: + assert self.mode == CircuitMode.SETUP + self.ast.q_enable = False + + def add(self: Circuit, step_type: StepType, args: Any): + assert self.mode == CircuitMode.Trace + step_instance: StepInstance = step_type.gen_step_instance(args) + self.witness.step_instances.append(step_instance) + + def gen_witness(self: Circuit, args: Any) -> TraceWitness: + self.mode = CircuitMode.Trace + self.witness = TraceWitness() + self.trace(args) + self.mode = CircuitMode.NoMode + witness = self.witness + del self.witness + return witness + + def get_ast_json(self: Circuit) -> str: + return json.dumps(self.ast, cls=CustomEncoder, indent=4) + + def halo2_mock_prover(self: Circuit, witness: TraceWitness): + if self.rust_ast_id == 0: + ast_json: str = self.get_ast_json() + self.rust_ast_id: int = rust_chiquito.ast_to_halo2(ast_json) + witness_json: str = witness.get_witness_json() + rust_chiquito.halo2_mock_prover(witness_json, self.rust_ast_id) + + def __str__(self: Circuit) -> str: + return self.ast.__str__() + + +class StepTypeMode(Enum): + NoMode = 0 + SETUP = 1 + WG = 2 + + +class StepType: + def __init__(self: StepType, circuit: Circuit, step_type_name: str): + self.step_type = ASTStepType.new(step_type_name) + self.circuit = circuit + self.mode = StepTypeMode.SETUP + self.setup() + + def gen_step_instance(self: StepType, args: Any) -> StepInstance: + self.mode = StepTypeMode.WG + self.step_instance = StepInstance.new(self.step_type.id) + self.wg(args) + self.mode = StepTypeMode.NoMode + step_instance = self.step_instance + del self.step_instance + return step_instance + + def internal(self: StepType, name: str) -> Internal: + assert self.mode == StepTypeMode.SETUP + + return Internal(self.step_type.add_signal(name)) + + def constr(self: StepType, constraint: ToConstraint): + assert self.mode == StepTypeMode.SETUP + + constraint = to_constraint(constraint) + StepType.enforce_constraint_typing(constraint) + self.step_type.add_constr(constraint.annotation, constraint.expr) + + def transition(self: StepType, constraint: ToConstraint): + assert self.mode == StepTypeMode.SETUP + + constraint = to_constraint(constraint) + StepType.enforce_constraint_typing(constraint) + self.step_type.add_transition(constraint.annotation, constraint.expr) + + def enforce_constraint_typing(constraint: Constraint): + if constraint.typing != Typing.AntiBooly: + raise ValueError( + f"Expected AntiBooly constraint, got {constraint.typing} (constraint: {constraint.annotation})" + ) + + def assign(self: StepType, lhs: Queriable, rhs: F): + assert self.mode == StepTypeMode.WG + + self.step_instance.assign(lhs, rhs) + + # TODO: Implement add_lookup after lookup abstraction PR is merged. diff --git a/python/chiquito/expr.py b/python/chiquito/expr.py new file mode 100644 index 0000000..52f9936 --- /dev/null +++ b/python/chiquito/expr.py @@ -0,0 +1,157 @@ +from __future__ import annotations +from typing import List +from dataclasses import dataclass + +from chiquito.util import F + + +# pub enum Expr { +# Const(F), +# Sum(Vec>), +# Mul(Vec>), +# Neg(Box>), +# Pow(Box>, u32), +# Query(Queriable), +# Halo2Expr(Expression), +# } + + +@dataclass +class Expr: + def __neg__(self: Expr) -> Neg: + return Neg(self) + + def __add__(self: Expr, rhs: ToExpr) -> Sum: + rhs = to_expr(rhs) + return Sum([self, rhs]) + + def __radd__(self: Expr, lhs: ToExpr) -> Sum: + return Expr.__add__(lhs, self) + + def __sub__(self: Expr, rhs: ToExpr) -> Sum: + rhs = to_expr(rhs) + return Sum([self, Neg(rhs)]) + + def __rsub__(self: Expr, lhs: ToExpr) -> Sum: + return Expr.__sub__(lhs, self) + + def __mul__(self: Expr, rhs: ToExpr) -> Mul: + rhs = to_expr(rhs) + return Mul([self, rhs]) + + def __rmul__(self: Expr, lhs: ToExpr) -> Mul: + return Expr.__mul__(lhs, self) + + def __pow__(self: Expr, rhs: int) -> Pow: + return Pow(self, rhs) + + +@dataclass +class Const(Expr): + value: F + + def __str__(self: Const) -> str: + return str(self.value) + + def __json__(self): + return {"Const": self.value} + + +@dataclass +class Sum(Expr): + exprs: List[Expr] + + def __str__(self: Sum) -> str: + result = "(" + for i, expr in enumerate(self.exprs): + if type(expr) is Neg: + if i == 0: + result += "-" + else: + result += " - " + else: + if i > 0: + result += " + " + result += str(expr) + result += ")" + return result + + def __json__(self): + return {"Sum": [expr.__json__() for expr in self.exprs]} + + def __add__(self: Sum, rhs: ToExpr) -> Sum: + rhs = to_expr(rhs) + return Sum(self.exprs + [rhs]) + + def __radd__(self: Sum, lhs: ToExpr) -> Sum: + return Sum.__add__(lhs, self) + + def __sub__(self: Sum, rhs: ToExpr) -> Sum: + rhs = to_expr(rhs) + return Sum(self.exprs + [Neg(rhs)]) + + def __rsub__(self: Sum, lhs: ToExpr) -> Sum: + return Sum.__sub__(lhs, self) + + +@dataclass +class Mul(Expr): + exprs: List[Expr] + + def __str__(self: Mul) -> str: + return "*".join([str(expr) for expr in self.exprs]) + + def __json__(self): + return {"Mul": [expr.__json__() for expr in self.exprs]} + + def __mul__(self: Mul, rhs: ToExpr) -> Mul: + rhs = to_expr(rhs) + return Mul(self.exprs + [rhs]) + + def __rmul__(self: Mul, lhs: ToExpr) -> Mul: + return Mul.__mul__(lhs, self) + + +@dataclass +class Neg(Expr): + expr: Expr + + def __str__(self: Neg) -> str: + return "(-" + str(self.expr) + ")" + + def __json__(self): + return {"Neg": self.expr.__json__()} + + def __neg__(self: Neg) -> Expr: + return self.expr + + +@dataclass +class Pow(Expr): + expr: Expr + pow: int + + def __str__(self: Pow) -> str: + return str(self.expr) + "^" + str(self.pow) + + def __json__(self): + return {"Pow": [self.expr.__json__(), self.pow]} + + +ToExpr = Expr | int | F + + +def to_expr(v: ToExpr) -> Expr: + if isinstance(v, Expr): + return v + elif isinstance(v, int): + if v >= 0: + return Const(F(v)) + else: + return Neg(Const(F(-v))) + elif isinstance(v, F): + return Const(v) + else: + raise TypeError( + f"Type {type(v)} is not ToExpr (one of Expr, int, F, or Constraint)." + ) diff --git a/python/chiquito/query.py b/python/chiquito/query.py new file mode 100644 index 0000000..9dafb6c --- /dev/null +++ b/python/chiquito/query.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from chiquito.expr import Expr + +# Commented out to avoid circular reference +# from chiquito_ast import InternalSignal, ForwardSignal, SharedSignal, FixedSignal, ASTStepType + + +# pub enum Queriable { +# Internal(InternalSignal), +# Forward(ForwardSignal, bool), +# Shared(SharedSignal, i32), +# Fixed(FixedSignal, i32), +# StepTypeNext(ASTStepTypeHandler), +# Halo2AdviceQuery(ImportedHalo2Advice, i32), +# Halo2FixedQuery(ImportedHalo2Fixed, i32), +# #[allow(non_camel_case_types)] +# _unaccessible(PhantomData), +# } + + +class Queriable(Expr): + # __hash__ method is required, because Queriable is used as a key in the assignment dictionary. + def __hash__(self: Queriable): + return hash(self.uuid()) + + # Implemented in all children classes, and only children instances will ever be created for Queriable. + def uuid(self: Queriable) -> int: + pass + + +# Not defined as @dataclass, because inherited __hash__ will be set to None. +class Internal(Queriable): + def __init__(self: Internal, signal: InternalSignal): + self.signal = signal + + def uuid(self: Internal) -> int: + return self.signal.id + + def __str__(self: Internal) -> str: + return self.signal.annotation + + def __json__(self): + return {"Internal": self.signal.__json__()} + + +class Forward(Queriable): + def __init__(self: Forward, signal: ForwardSignal, rotation: bool): + self.signal = signal + self.rotation = rotation + + def next(self: Forward) -> Forward: + if self.rotation: + raise ValueError("Cannot rotate Forward twice.") + else: + return Forward(self.signal, True) + + def uuid(self: Forward) -> int: + return self.signal.id + + def __str__(self: Forward) -> str: + if not self.rotation: + return self.signal.annotation + else: + return f"next({self.signal.annotation})" + + def __json__(self): + return {"Forward": [self.signal.__json__(), self.rotation]} + + +class Shared(Queriable): + def __init__(self: Shared, signal: SharedSignal, rotation: int): + self.signal = signal + self.rotation = rotation + + def next(self: Shared) -> Shared: + return Shared(self.signal, self.rotation + 1) + + def prev(self: Shared) -> Shared: + return Shared(self.signal, self.rotation - 1) + + def rot(self: Shared, rotation: int) -> Shared: + return Shared(self.signal, self.rotation + rotation) + + def uuid(self: Shared) -> int: + return self.signal.id + + def __str__(self: Shared) -> str: + if self.rotation == 0: + return self.signal.annotation + else: + return f"{self.signal.annotation}(rot {self.rotation})" + + def __json__(self): + return {"Shared": [self.signal.__json__(), self.rotation]} + + +class Fixed(Queriable): + def __init__(self: Fixed, signal: FixedSignal, rotation: int): + self.signal = signal + self.rotation = rotation + + def next(self: Fixed) -> Fixed: + return Fixed(self.signal, self.rotation + 1) + + def prev(self: Fixed) -> Fixed: + return Fixed(self.signal, self.rotation - 1) + + def rot(self: Fixed, rotation: int) -> Fixed: + return Fixed(self.signal, self.rotation + rotation) + + def uuid(self: Fixed) -> int: + return self.signal.id + + def __str__(self: Fixed) -> str: + if self.rotation == 0: + return self.signal.annotation + else: + return f"{self.signal.annotation}(rot {self.rotation})" + + def __json__(self): + return {"Fixed": [self.signal.__json__(), self.rotation]} + + +class StepTypeNext(Queriable): + def __init__(self: StepTypeNext, step_type: ASTStepType): + self.step_type = step_type + + def uuid(self: ASTStepType) -> int: + return self.id + + def __str__(self: ASTStepType) -> str: + return self.name + + def __json__(self): + return { + "StepTypeNext": {"id": self.step_type.id, "annotation": self.step_type.name} + } diff --git a/python/chiquito/util.py b/python/chiquito/util.py new file mode 100644 index 0000000..d838aa1 --- /dev/null +++ b/python/chiquito/util.py @@ -0,0 +1,31 @@ +from __future__ import annotations +from py_ecc import bn128 +from uuid import uuid1 +import json + +F = bn128.FQ + + +def json_method(self: F): + # Convert the integer to a byte array + byte_array = self.n.to_bytes(32, "little") + + # Split into four 64-bit integers + ints = [int.from_bytes(byte_array[i * 8 : i * 8 + 8], "little") for i in range(4)] + + return ints + + +F.__json__ = json_method + + +class CustomEncoder(json.JSONEncoder): + def default(self, obj): + if hasattr(obj, "__json__"): + return obj.__json__() + return super().default(obj) + + +# int field is the u128 version of uuid. +def uuid() -> int: + return uuid1(node=int.from_bytes([10, 10, 10, 10, 10, 10], byteorder="little")).int diff --git a/python/chiquito/wit_gen.py b/python/chiquito/wit_gen.py new file mode 100644 index 0000000..b740875 --- /dev/null +++ b/python/chiquito/wit_gen.py @@ -0,0 +1,122 @@ +from __future__ import annotations +from dataclasses import dataclass, field +from typing import Dict, List, Callable, Any +import json + +from chiquito.query import Queriable, Fixed +from chiquito.util import F, CustomEncoder + +# Commented out to avoid circular reference +# from dsl import Circuit, StepType + + +@dataclass +class StepInstance: + step_type_uuid: int = 0 + assignments: Dict[Queriable, F] = field(default_factory=dict) + + def new(step_type_uuid: int) -> StepInstance: + return StepInstance(step_type_uuid, {}) + + def assign(self: StepInstance, lhs: Queriable, rhs: F): + self.assignments[lhs] = rhs + + def __str__(self: StepInstance): + assignments_str = ( + "\n\t\t\t\t" + + ",\n\t\t\t\t".join( + f"{str(lhs)} = {rhs}" for (lhs, rhs) in self.assignments.items() + ) + + "\n\t\t\t" + if self.assignments + else "" + ) + return ( + f"StepInstance(\n" + f"\t\t\tstep_type_uuid={self.step_type_uuid},\n" + f"\t\t\tassignments={{{assignments_str}}},\n" + f"\t\t)" + ) + + # For assignments, return "uuid: (Queriable, F)" rather than "Queriable: F", because JSON doesn't accept Dict as key. + def __json__(self: StepInstance): + return { + "step_type_uuid": self.step_type_uuid, + "assignments": { + lhs.uuid(): [lhs, rhs] for (lhs, rhs) in self.assignments.items() + }, + } + + +Witness = List[StepInstance] + + +@dataclass +class TraceWitness: + step_instances: Witness = field(default_factory=list) + + def __str__(self: TraceWitness): + step_instances_str = ( + "\n\t\t" + + ",\n\t\t".join( + str(step_instance) for step_instance in self.step_instances + ) + + "\n\t" + if self.step_instances + else "" + ) + return f"TraceWitness(\n" f"\tstep_instances={{{step_instances_str}}},\n" f")" + + def __json__(self: TraceWitness): + return { + "step_instances": [ + step_instance.__json__() for step_instance in self.step_instances + ] + } + + def get_witness_json(self: TraceWitness) -> str: + return json.dumps(self, cls=CustomEncoder, indent=4) + + def evil_witness_test( + self: TraceWitness, + step_instance_indices: List[int], + assignment_indices: List[int], + rhs: List[F], + ) -> TraceWitness: + if not len(step_instance_indices) == len(assignment_indices) == len(rhs): + raise ValueError(f"`evil_witness_test` inputs have different lengths.") + new_step_instances = self.step_instances.copy() + for i in range(len(step_instance_indices)): + keys = list(new_step_instances[step_instance_indices[i]].assignments.keys()) + new_step_instances[step_instance_indices[i]].assignments[ + keys[assignment_indices[i]] + ] = rhs[i] + return TraceWitness(new_step_instances) + + +FixedAssigment = Dict[Queriable, List[F]] + + +@dataclass +class FixedGenContext: + assignments: FixedAssigment = field(default_factory=dict) + num_steps: int = 0 + + def new(num_steps: int) -> FixedGenContext: + return FixedGenContext({}, num_steps) + + def assign(self: FixedGenContext, offset: int, lhs: Queriable, rhs: F): + if not FixedGenContext.is_fixed_queriable(lhs): + raise ValueError(f"Cannot assign to non-fixed signal.") + if lhs in self.assignments.keys(): + self.assignments[lhs][offset] = rhs + else: + self.assignments[lhs] = [F.zero()] * self.num_steps + self.assignments[lhs][offset] = rhs + + def is_fixed_queriable(q: Queriable) -> bool: + match q.enum: + case Fixed(_, _): + return True + case _: + return False