first tests did not pass without python dir, adding it

This commit is contained in:
trangnv
2023-08-11 10:45:25 +07:00
parent cfda2ce461
commit 5693ada9d8
9 changed files with 1225 additions and 2 deletions

5
.gitignore vendored
View File

@@ -6,7 +6,7 @@ lib/
include/
target/
maturin
python
python3
python3.11
pip
@@ -21,4 +21,5 @@ packages.png
.vscode
.devcontainer
.ipynb_checkpoints
dsl-ipynb_checkpoints
dsl-ipynb_checkpoints
*.so

View File

224
python/chiquito/cb.py Normal file
View File

@@ -0,0 +1,224 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum, auto
from typing import List
from chiquito.util import F
from chiquito.expr import Expr, Const, Neg, to_expr, ToExpr
from chiquito.query import StepTypeNext
from chiquito.chiquito_ast import ASTStepType
class Typing(Enum):
Unknown = auto()
Boolean = auto()
AntiBooly = auto()
@dataclass
class Constraint:
annotation: str
expr: Expr
typing: Typing
def from_expr(
expr: Expr,
) -> Constraint: # Cannot call function `from`, a reserved keyword in Python.
annotation: str = str(expr)
if isinstance(expr, StepTypeNext):
return Constraint(annotation, expr, Typing.Boolean)
else:
return Constraint(annotation, expr, Typing.Unknown)
def __str__(self: Constraint) -> str:
return self.annotation
def cb_and(
inputs: List[ToConstraint],
) -> Constraint: # Cannot call function `and`, a reserved keyword in Python
inputs = [to_constraint(input) for input in inputs]
annotations: List[str] = []
expr = Const(F(1))
for constraint in inputs:
if constraint.typing == Typing.Boolean or constraint.typing == Typing.Unknown:
annotations.append(constraint.annotation)
expr = expr * constraint.expr
else:
raise ValueError(
f"Expected Boolean or Unknown constraint, got AntiBooly (constraint: {constraint.annotation})"
)
return Constraint(f"({' AND '.join(annotations)})", expr, Typing.Boolean)
def cb_or(
inputs: List[ToConstraint],
) -> Constraint: # Cannot call function `or`, a reserved keyword in Python
inputs = [to_constraint(input) for input in inputs]
annotations: List[str] = []
exprs: List[Expr] = []
for constraint in inputs:
if constraint.typing == Typing.Boolean or constraint.typing == Typing.Unknown:
annotations.append(constraint.annotation)
exprs.append(constraint.expr)
else:
raise ValueError(
f"Expected Boolean or Unknown constraint, got AntiBooly (constraint: {constraint.annotation})"
)
result: Constraint = Constraint.cb_not(
Constraint.cb_and([Constraint.cb_not(expr) for expr in exprs])
)
return Constraint(f"({' OR '.join(annotations)})", result.expr, Typing.Boolean)
def xor(lhs: ToConstraint, rhs: ToConstraint) -> Constraint:
(lhs, rhs) = (to_constraint(lhs), to_constraint(rhs))
if (lhs.typing == Typing.Boolean or lhs.typing == Typing.Unknown) and (
rhs.typing == Typing.Boolean or rhs.typing == Typing.Unknown
):
return Constraint(
f"({lhs.annotation} XOR {rhs.annotation})",
lhs.expr + rhs.expr - F(2) * lhs.expr * rhs.expr,
Typing.Boolean,
)
else:
raise ValueError(
f"Expected Boolean or Unknown constraints, got AntiBooly in one of lhs or rhs constraints (lhs constraint: {lhs.annotation}) (rhs constraint: {rhs.annotation})"
)
def eq(lhs: ToConstraint, rhs: ToConstraint) -> Constraint:
(lhs, rhs) = (to_constraint(lhs), to_constraint(rhs))
return Constraint(
f"({lhs.annotation} == {rhs.annotation})",
lhs.expr - rhs.expr,
Typing.AntiBooly,
)
def select(
selector: ToConstraint, when_true: ToConstraint, when_false: ToConstraint
) -> Constraint:
(selector, when_true, when_false) = (
to_constraint(selector),
to_constraint(when_true),
to_constraint(when_false),
)
if selector.typing == Typing.AntiBooly:
raise ValueError(
f"Expected Boolean or Unknown selector, got AntiBooly (selector: {selector.annotation})"
)
return Constraint(
f"if({selector.annotation})then({when_true.annotation})else({when_false.annotation})",
selector.expr * when_true.expr + (F(1) - selector.expr) * when_false.expr,
when_true.typing if when_true.typing == when_false.typing else Typing.Unknown,
)
def when(selector: ToConstraint, when_true: ToConstraint) -> Constraint:
(selector, when_true) = (to_constraint(selector), to_constraint(when_true))
if selector.typing == Typing.AntiBooly:
raise ValueError(
f"Expected Boolean or Unknown selector, got AntiBooly (selector: {selector.annotation})"
)
return Constraint(
f"if({selector.annotation})then({when_true.annotation})",
selector.expr * when_true.expr,
when_true.typing,
)
def unless(selector: ToConstraint, when_false: ToConstraint) -> Constraint:
(selector, when_false) = (to_constraint(selector), to_constraint(when_false))
if selector.typing == Typing.AntiBooly:
raise ValueError(
f"Expected Boolean or Unknown selector, got AntiBooly (selector: {selector.annotation})"
)
return Constraint(
f"unless({selector.annotation})then({when_false.annotation})",
(F(1) - selector.expr) * when_false.expr,
when_false.typing,
)
def cb_not(
constraint: ToConstraint,
) -> Constraint: # Cannot call function `not`, a reserved keyword in Python
constraint = to_constraint(constraint)
if constraint.typing == Typing.AntiBooly:
raise ValueError(
f"Expected Boolean or Unknown constraint, got AntiBooly (constraint: {constraint.annotation})"
)
return Constraint(
f"NOT({constraint.annotation})", F(1) - constraint.expr, Typing.Boolean
)
def isz(constraint: ToConstraint) -> Constraint:
constraint = to_constraint(constraint)
return Constraint(
f"0 == {constraint.annotation}", constraint.expr, Typing.AntiBooly
)
def if_next_step(step_type: ASTStepType, constraint: ToConstraint) -> Constraint:
constraint = to_constraint(constraint)
return Constraint(
f"if(next step is {step_type.annotation})then({constraint.annotation})",
StepTypeNext(step_type) * constraint.expr,
constraint.typing,
)
def next_step_must_be(step_type: ASTStepType) -> Constraint:
return Constraint(
f"next step must be {step_type.annotation}",
Constraint.cb_not(StepTypeNext(step_type)),
Typing.AntiBooly,
)
def next_step_must_not_be(step_type: ASTStepType) -> Constraint:
return Constraint(
f"next step must not be {step_type.annotation}",
StepTypeNext(step_type),
Typing.AntiBooly,
)
def rlc(exprs: List[ToExpr], randomness: Expr) -> Expr:
if len(exprs) > 0:
exprs: List[Expr] = [to_expr(expr) for expr in exprs].reverse()
init: Expr = exprs[0]
for expr in exprs[1:]:
init = init * randomness + expr
return init
else:
return Expr(Const(F(0)))
# TODO: Implement lookup table after the lookup abstraction PR is merged.
ToConstraint = Constraint | Expr | int | F
def to_constraint(v: ToConstraint) -> Constraint:
if isinstance(v, Constraint):
return v
elif isinstance(v, Expr):
if isinstance(v, StepTypeNext):
return Constraint(str(v), v, Typing.Boolean)
else:
return Constraint(str(v), v, Typing.Unknown)
elif isinstance(v, int):
if v >= 0:
return to_constraint(Const(F(v)))
else:
return to_constraint(Neg(Const(F(-v))))
elif isinstance(v, F):
return to_constraint(Const(v))
else:
raise TypeError(
f"Type `{type(v)}` is not ToConstraint (one of Constraint, Expr, int, or F)."
)

View File

@@ -0,0 +1,386 @@
from __future__ import annotations
from typing import Callable, List, Dict, Optional, Any, Tuple
from dataclasses import dataclass, field, asdict
# from chiquito import wit_gen, expr, query, util
from chiquito.wit_gen import FixedGenContext, StepInstance
from chiquito.expr import Expr
from chiquito.util import uuid
from chiquito.query import Queriable
# pub struct Circuit<F, TraceArgs> {
# pub step_types: HashMap<u32, Rc<ASTStepType<F>>>,
# pub forward_signals: Vec<ForwardSignal>,
# pub shared_signals: Vec<SharedSignal>,
# pub fixed_signals: Vec<FixedSignal>,
# pub halo2_advice: Vec<ImportedHalo2Advice>,
# pub halo2_fixed: Vec<ImportedHalo2Fixed>,
# pub exposed: Vec<ForwardSignal>,
# pub annotations: HashMap<u32, String>,
# pub trace: Option<Rc<Trace<F, TraceArgs>>>,
# pub fixed_gen: Option<Rc<FixedGen<F>>>,
# pub first_step: Option<ASTStepTypeUUID>,
# pub last_step: Option<ASTStepTypeUUID>,
# pub num_steps: usize,
# }
@dataclass
class ASTCircuit:
step_types: Dict[int, ASTStepType] = field(default_factory=dict)
forward_signals: List[ForwardSignal] = field(default_factory=list)
shared_signals: List[SharedSignal] = field(default_factory=list)
fixed_signals: List[FixedSignal] = field(default_factory=list)
exposed: List[Tuple[Queriable, ExposeOffset]] = field(default_factory=list)
annotations: Dict[int, str] = field(default_factory=dict)
fixed_gen: Optional[Callable] = None
first_step: Optional[int] = None
last_step: Optional[int] = None
num_steps: int = 0
q_enable: bool = True
id: int = uuid()
def __str__(self: ASTCircuit):
step_types_str = (
"\n\t\t"
+ ",\n\t\t".join(f"{k}: {v}" for k, v in self.step_types.items())
+ "\n\t"
if self.step_types
else ""
)
forward_signals_str = (
"\n\t\t" + ",\n\t\t".join(str(fs) for fs in self.forward_signals) + "\n\t"
if self.forward_signals
else ""
)
shared_signals_str = (
"\n\t\t" + ",\n\t\t".join(str(ss) for ss in self.shared_signals) + "\n\t"
if self.shared_signals
else ""
)
fixed_signals_str = (
"\n\t\t" + ",\n\t\t".join(str(fs) for fs in self.fixed_signals) + "\n\t"
if self.fixed_signals
else ""
)
exposed_str = (
"\n\t\t"
+ ",\n\t\t".join(f"({str(lhs)}, {str(rhs)})" for (lhs, rhs) in self.exposed)
+ "\n\t"
if self.exposed
else ""
)
annotations_str = (
"\n\t\t"
+ ",\n\t\t".join(f"{k}: {v}" for k, v in self.annotations.items())
+ "\n\t"
if self.annotations
else ""
)
return (
f"ASTCircuit(\n"
f"\tstep_types={{{step_types_str}}},\n"
f"\tforward_signals=[{forward_signals_str}],\n"
f"\tshared_signals=[{shared_signals_str}],\n"
f"\tfixed_signals=[{fixed_signals_str}],\n"
f"\texposed=[{exposed_str}],\n"
f"\tannotations={{{annotations_str}}},\n"
f"\tfixed_gen={self.fixed_gen},\n"
f"\tfirst_step={self.first_step},\n"
f"\tlast_step={self.last_step},\n"
f"\tnum_steps={self.num_steps}\n"
f"\tq_enable={self.q_enable}\n"
f")"
)
def __json__(self: ASTCircuit):
return {
"step_types": {k: v.__json__() for k, v in self.step_types.items()},
"forward_signals": [x.__json__() for x in self.forward_signals],
"shared_signals": [x.__json__() for x in self.shared_signals],
"fixed_signals": [x.__json__() for x in self.fixed_signals],
"exposed": [
[queriable.__json__(), offset.__json__()]
for (queriable, offset) in self.exposed
],
"annotations": self.annotations,
"first_step": self.first_step,
"last_step": self.last_step,
"num_steps": self.num_steps,
"q_enable": self.q_enable,
"id": self.id,
}
def add_forward(self: ASTCircuit, name: str, phase: int) -> ForwardSignal:
signal = ForwardSignal(phase, name)
self.forward_signals.append(signal)
self.annotations[signal.id] = name
return signal
def add_shared(self: ASTCircuit, name: str, phase: int) -> SharedSignal:
signal = SharedSignal(phase, name)
self.shared_signals.append(signal)
self.annotations[signal.id] = name
return signal
def add_fixed(self: ASTCircuit, name: str) -> FixedSignal:
signal = FixedSignal(name)
self.fixed_signals.append(signal)
self.annotations[signal.id] = name
return signal
def expose(self: ASTCircuit, signal: Queriable, offset: ExposeOffset):
self.exposed.append((signal, offset))
def add_step_type(self: ASTCircuit, step_type: ASTStepType, name: str):
self.annotations[step_type.id] = name
self.step_types[step_type.id] = step_type
def set_fixed_gen(self, fixed_gen_def: Callable[[FixedGenContext], None]):
if self.fixed_gen is not None:
raise Exception("ASTCircuit cannot have more than one fixed generator.")
else:
self.fixed_gen = fixed_gen_def
def get_step_type(self, uuid: int) -> ASTStepType:
if uuid in self.step_types.keys():
return self.step_types[uuid]
else:
raise ValueError("ASTStepType not found.")
# pub struct StepType<F> {
# id: StepTypeUUID,
# pub name: String,
# pub signals: Vec<InternalSignal>,
# pub constraints: Vec<Constraint<F>>,
# pub transition_constraints: Vec<TransitionConstraint<F>>,
# pub lookups: Vec<Lookup<F>>,
# pub annotations: HashMap<u32, String>,
# }
@dataclass
class ASTStepType:
id: int
name: str
signals: List[InternalSignal]
constraints: List[ASTConstraint]
transition_constraints: List[TransitionConstraint]
annotations: Dict[int, str]
def new(name: str) -> ASTStepType:
return ASTStepType(uuid(), name, [], [], [], {})
def __str__(self):
signals_str = (
"\n\t\t\t\t"
+ ",\n\t\t\t\t".join(str(signal) for signal in self.signals)
+ "\n\t\t\t"
if self.signals
else ""
)
constraints_str = (
"\n\t\t\t\t"
+ ",\n\t\t\t\t".join(str(constraint) for constraint in self.constraints)
+ "\n\t\t\t"
if self.constraints
else ""
)
transition_constraints_str = (
"\n\t\t\t\t"
+ ",\n\t\t\t\t".join(str(tc) for tc in self.transition_constraints)
+ "\n\t\t\t"
if self.transition_constraints
else ""
)
annotations_str = (
"\n\t\t\t\t"
+ ",\n\t\t\t\t".join(f"{k}: {v}" for k, v in self.annotations.items())
+ "\n\t\t\t"
if self.annotations
else ""
)
return (
f"ASTStepType(\n"
f"\t\t\tid={self.id},\n"
f"\t\t\tname='{self.name}',\n"
f"\t\t\tsignals=[{signals_str}],\n"
f"\t\t\tconstraints=[{constraints_str}],\n"
f"\t\t\ttransition_constraints=[{transition_constraints_str}],\n"
f"\t\t\tannotations={{{annotations_str}}}\n"
f"\t\t)"
)
def __json__(self):
return {
"id": self.id,
"name": self.name,
"signals": [x.__json__() for x in self.signals],
"constraints": [x.__json__() for x in self.constraints],
"transition_constraints": [
x.__json__() for x in self.transition_constraints
],
"annotations": self.annotations,
}
def add_signal(self: ASTStepType, name: str) -> InternalSignal:
signal = InternalSignal(name)
self.signals.append(signal)
self.annotations[signal.id] = name
return signal
def add_constr(self: ASTStepType, annotation: str, expr: Expr):
condition = ASTConstraint(annotation, expr)
self.constraints.append(condition)
def add_transition(self: ASTStepType, annotation: str, expr: Expr):
condition = TransitionConstraint(annotation, expr)
self.transition_constraints.append(condition)
def __eq__(self: ASTStepType, other: ASTStepType) -> bool:
if isinstance(self, ASTStepType) and isinstance(other, ASTStepType):
return self.id == other.id
return False
def __req__(other: ASTStepType, self: ASTStepType) -> bool:
return ASTStepType.__eq__(self, other)
def __hash__(self: ASTStepType):
return hash(self.id)
@dataclass
class ASTConstraint:
annotation: str
expr: Expr
def __str__(self: ASTConstraint):
return (
f"Constraint(\n"
f"\t\t\t\t\tannotation='{self.annotation}',\n"
f"\t\t\t\t\texpr={self.expr}\n"
f"\t\t\t\t)"
)
def __json__(self: ASTConstraint):
return {"annotation": self.annotation, "expr": self.expr.__json__()}
@dataclass
class TransitionConstraint:
annotation: str
expr: Expr
def __str__(self: TransitionConstraint):
return f"TransitionConstraint({self.annotation})"
def __json__(self: TransitionConstraint):
return {"annotation": self.annotation, "expr": self.expr.__json__()}
@dataclass
class ForwardSignal:
id: int
phase: int
annotation: str
def __init__(self: ForwardSignal, phase: int, annotation: str):
self.id: int = uuid()
self.phase = phase
self.annotation = annotation
def __str__(self: ForwardSignal):
return f"ForwardSignal(id={self.id}, phase={self.phase}, annotation='{self.annotation}')"
def __json__(self: ForwardSignal):
return asdict(self)
@dataclass
class SharedSignal:
id: int
phase: int
annotation: str
def __init__(self: SharedSignal, phase: int, annotation: str):
self.id: int = uuid()
self.phase = phase
self.annotation = annotation
def __str__(self: SharedSignal):
return f"SharedSignal(id={self.id}, phase={self.phase}, annotation='{self.annotation}')"
def __json__(self: SharedSignal):
return asdict(self)
class ExposeOffset:
pass
class First(ExposeOffset):
def __str__(self: First):
return "First"
def __json__(self: First):
return {"First": 0}
class Last(ExposeOffset):
def __str__(self: Last):
return "Last"
def __json__(self: Last):
return {"Last": -1}
@dataclass
class Step(ExposeOffset):
offset: int
def __str__(self: Step):
return f"Step({self.offset})"
def __json__(self: Step):
return {"Step": self.offset}
@dataclass
class FixedSignal:
id: int
annotation: str
def __init__(self: FixedSignal, annotation: str):
self.id: int = uuid()
self.annotation = annotation
def __str__(self: FixedSignal):
return f"FixedSignal(id={self.id}, annotation='{self.annotation}')"
def __json__(self: FixedSignal):
return asdict(self)
@dataclass
class InternalSignal:
id: int
annotation: str
def __init__(self: InternalSignal, annotation: str):
self.id = uuid()
self.annotation = annotation
def __str__(self: InternalSignal):
return f"InternalSignal(id={self.id}, annotation='{self.annotation}')"
def __json__(self: InternalSignal):
return asdict(self)

164
python/chiquito/dsl.py Normal file
View File

@@ -0,0 +1,164 @@
from __future__ import annotations
from enum import Enum
from typing import Callable, Any
import rust_chiquito # rust bindings
import json
from chiquito import (chiquito_ast, wit_gen)
from chiquito.chiquito_ast import ASTCircuit, ASTStepType, ExposeOffset
from chiquito.query import Internal, Forward, Queriable, Shared, Fixed
from chiquito.wit_gen import FixedGenContext, StepInstance, TraceWitness
from chiquito.cb import Constraint, Typing, ToConstraint, to_constraint
from chiquito.util import CustomEncoder, F
class CircuitMode(Enum):
NoMode = 0
SETUP = 1
Trace = 2
class Circuit:
def __init__(self: Circuit):
self.ast = ASTCircuit()
self.witness = TraceWitness()
self.rust_ast_id = 0
self.mode = CircuitMode.SETUP
self.setup()
def forward(self: Circuit, name: str) -> Forward:
assert self.mode == CircuitMode.SETUP
return Forward(self.ast.add_forward(name, 0), False)
def forward_with_phase(self: Circuit, name: str, phase: int) -> Forward:
assert self.mode == CircuitMode.SETUP
return Forward(self.ast.add_forward(name, phase), False)
def shared(self: Circuit, name: str) -> Shared:
assert self.mode == CircuitMode.SETUP
return Shared(self.ast.add_shared(name, 0), 0)
def shared_with_phase(self: Circuit, name: str, phase: int) -> Shared:
assert self.mode == CircuitMode.SETUP
return Shared(self.ast.add_shared(name, phase), 0)
def fixed(self: Circuit, name: str) -> Fixed:
assert self.mode == CircuitMode.SETUP
return Fixed(self.ast.add_fixed(name), 0)
def expose(self: Circuit, signal: Queriable, offset: ExposeOffset):
assert self.mode == CircuitMode.SETUP
if isinstance(signal, (Forward, Shared)):
self.ast.expose(signal, offset)
else:
raise TypeError(f"Can only expose ForwardSignal or SharedSignal.")
def step_type(self: Circuit, step_type: StepType) -> StepType:
assert self.mode == CircuitMode.SETUP
self.ast.add_step_type(step_type.step_type, step_type.step_type.name)
return step_type
def step_type_def(self: StepType) -> StepType:
assert self.mode == CircuitMode.SETUP
self.ast.add_step_type_def()
def fixed_gen(self: Circuit, fixed_gen_def: Callable[[FixedGenContext], None]):
self.ast.set_fixed_gen(fixed_gen_def)
def pragma_first_step(self: Circuit, step_type: StepType) -> None:
assert self.mode == CircuitMode.SETUP
self.ast.first_step = step_type.step_type.id
def pragma_last_step(self: Circuit, step_type: StepType) -> None:
assert self.mode == CircuitMode.SETUP
self.ast.last_step = step_type.step_type.id
def pragma_num_steps(self: Circuit, num_steps: int) -> None:
assert self.mode == CircuitMode.SETUP
self.ast.num_steps = num_steps
def pragma_disable_q_enable(self: Circuit) -> None:
assert self.mode == CircuitMode.SETUP
self.ast.q_enable = False
def add(self: Circuit, step_type: StepType, args: Any):
assert self.mode == CircuitMode.Trace
step_instance: StepInstance = step_type.gen_step_instance(args)
self.witness.step_instances.append(step_instance)
def gen_witness(self: Circuit, args: Any) -> TraceWitness:
self.mode = CircuitMode.Trace
self.witness = TraceWitness()
self.trace(args)
self.mode = CircuitMode.NoMode
witness = self.witness
del self.witness
return witness
def get_ast_json(self: Circuit) -> str:
return json.dumps(self.ast, cls=CustomEncoder, indent=4)
def halo2_mock_prover(self: Circuit, witness: TraceWitness):
if self.rust_ast_id == 0:
ast_json: str = self.get_ast_json()
self.rust_ast_id: int = rust_chiquito.ast_to_halo2(ast_json)
witness_json: str = witness.get_witness_json()
rust_chiquito.halo2_mock_prover(witness_json, self.rust_ast_id)
def __str__(self: Circuit) -> str:
return self.ast.__str__()
class StepTypeMode(Enum):
NoMode = 0
SETUP = 1
WG = 2
class StepType:
def __init__(self: StepType, circuit: Circuit, step_type_name: str):
self.step_type = ASTStepType.new(step_type_name)
self.circuit = circuit
self.mode = StepTypeMode.SETUP
self.setup()
def gen_step_instance(self: StepType, args: Any) -> StepInstance:
self.mode = StepTypeMode.WG
self.step_instance = StepInstance.new(self.step_type.id)
self.wg(args)
self.mode = StepTypeMode.NoMode
step_instance = self.step_instance
del self.step_instance
return step_instance
def internal(self: StepType, name: str) -> Internal:
assert self.mode == StepTypeMode.SETUP
return Internal(self.step_type.add_signal(name))
def constr(self: StepType, constraint: ToConstraint):
assert self.mode == StepTypeMode.SETUP
constraint = to_constraint(constraint)
StepType.enforce_constraint_typing(constraint)
self.step_type.add_constr(constraint.annotation, constraint.expr)
def transition(self: StepType, constraint: ToConstraint):
assert self.mode == StepTypeMode.SETUP
constraint = to_constraint(constraint)
StepType.enforce_constraint_typing(constraint)
self.step_type.add_transition(constraint.annotation, constraint.expr)
def enforce_constraint_typing(constraint: Constraint):
if constraint.typing != Typing.AntiBooly:
raise ValueError(
f"Expected AntiBooly constraint, got {constraint.typing} (constraint: {constraint.annotation})"
)
def assign(self: StepType, lhs: Queriable, rhs: F):
assert self.mode == StepTypeMode.WG
self.step_instance.assign(lhs, rhs)
# TODO: Implement add_lookup after lookup abstraction PR is merged.

157
python/chiquito/expr.py Normal file
View File

@@ -0,0 +1,157 @@
from __future__ import annotations
from typing import List
from dataclasses import dataclass
from chiquito.util import F
# pub enum Expr<F> {
# Const(F),
# Sum(Vec<Expr<F>>),
# Mul(Vec<Expr<F>>),
# Neg(Box<Expr<F>>),
# Pow(Box<Expr<F>>, u32),
# Query(Queriable<F>),
# Halo2Expr(Expression<F>),
# }
@dataclass
class Expr:
def __neg__(self: Expr) -> Neg:
return Neg(self)
def __add__(self: Expr, rhs: ToExpr) -> Sum:
rhs = to_expr(rhs)
return Sum([self, rhs])
def __radd__(self: Expr, lhs: ToExpr) -> Sum:
return Expr.__add__(lhs, self)
def __sub__(self: Expr, rhs: ToExpr) -> Sum:
rhs = to_expr(rhs)
return Sum([self, Neg(rhs)])
def __rsub__(self: Expr, lhs: ToExpr) -> Sum:
return Expr.__sub__(lhs, self)
def __mul__(self: Expr, rhs: ToExpr) -> Mul:
rhs = to_expr(rhs)
return Mul([self, rhs])
def __rmul__(self: Expr, lhs: ToExpr) -> Mul:
return Expr.__mul__(lhs, self)
def __pow__(self: Expr, rhs: int) -> Pow:
return Pow(self, rhs)
@dataclass
class Const(Expr):
value: F
def __str__(self: Const) -> str:
return str(self.value)
def __json__(self):
return {"Const": self.value}
@dataclass
class Sum(Expr):
exprs: List[Expr]
def __str__(self: Sum) -> str:
result = "("
for i, expr in enumerate(self.exprs):
if type(expr) is Neg:
if i == 0:
result += "-"
else:
result += " - "
else:
if i > 0:
result += " + "
result += str(expr)
result += ")"
return result
def __json__(self):
return {"Sum": [expr.__json__() for expr in self.exprs]}
def __add__(self: Sum, rhs: ToExpr) -> Sum:
rhs = to_expr(rhs)
return Sum(self.exprs + [rhs])
def __radd__(self: Sum, lhs: ToExpr) -> Sum:
return Sum.__add__(lhs, self)
def __sub__(self: Sum, rhs: ToExpr) -> Sum:
rhs = to_expr(rhs)
return Sum(self.exprs + [Neg(rhs)])
def __rsub__(self: Sum, lhs: ToExpr) -> Sum:
return Sum.__sub__(lhs, self)
@dataclass
class Mul(Expr):
exprs: List[Expr]
def __str__(self: Mul) -> str:
return "*".join([str(expr) for expr in self.exprs])
def __json__(self):
return {"Mul": [expr.__json__() for expr in self.exprs]}
def __mul__(self: Mul, rhs: ToExpr) -> Mul:
rhs = to_expr(rhs)
return Mul(self.exprs + [rhs])
def __rmul__(self: Mul, lhs: ToExpr) -> Mul:
return Mul.__mul__(lhs, self)
@dataclass
class Neg(Expr):
expr: Expr
def __str__(self: Neg) -> str:
return "(-" + str(self.expr) + ")"
def __json__(self):
return {"Neg": self.expr.__json__()}
def __neg__(self: Neg) -> Expr:
return self.expr
@dataclass
class Pow(Expr):
expr: Expr
pow: int
def __str__(self: Pow) -> str:
return str(self.expr) + "^" + str(self.pow)
def __json__(self):
return {"Pow": [self.expr.__json__(), self.pow]}
ToExpr = Expr | int | F
def to_expr(v: ToExpr) -> Expr:
if isinstance(v, Expr):
return v
elif isinstance(v, int):
if v >= 0:
return Const(F(v))
else:
return Neg(Const(F(-v)))
elif isinstance(v, F):
return Const(v)
else:
raise TypeError(
f"Type {type(v)} is not ToExpr (one of Expr, int, F, or Constraint)."
)

138
python/chiquito/query.py Normal file
View File

@@ -0,0 +1,138 @@
from __future__ import annotations
from chiquito.expr import Expr
# Commented out to avoid circular reference
# from chiquito_ast import InternalSignal, ForwardSignal, SharedSignal, FixedSignal, ASTStepType
# pub enum Queriable<F> {
# Internal(InternalSignal),
# Forward(ForwardSignal, bool),
# Shared(SharedSignal, i32),
# Fixed(FixedSignal, i32),
# StepTypeNext(ASTStepTypeHandler),
# Halo2AdviceQuery(ImportedHalo2Advice, i32),
# Halo2FixedQuery(ImportedHalo2Fixed, i32),
# #[allow(non_camel_case_types)]
# _unaccessible(PhantomData<F>),
# }
class Queriable(Expr):
# __hash__ method is required, because Queriable is used as a key in the assignment dictionary.
def __hash__(self: Queriable):
return hash(self.uuid())
# Implemented in all children classes, and only children instances will ever be created for Queriable.
def uuid(self: Queriable) -> int:
pass
# Not defined as @dataclass, because inherited __hash__ will be set to None.
class Internal(Queriable):
def __init__(self: Internal, signal: InternalSignal):
self.signal = signal
def uuid(self: Internal) -> int:
return self.signal.id
def __str__(self: Internal) -> str:
return self.signal.annotation
def __json__(self):
return {"Internal": self.signal.__json__()}
class Forward(Queriable):
def __init__(self: Forward, signal: ForwardSignal, rotation: bool):
self.signal = signal
self.rotation = rotation
def next(self: Forward) -> Forward:
if self.rotation:
raise ValueError("Cannot rotate Forward twice.")
else:
return Forward(self.signal, True)
def uuid(self: Forward) -> int:
return self.signal.id
def __str__(self: Forward) -> str:
if not self.rotation:
return self.signal.annotation
else:
return f"next({self.signal.annotation})"
def __json__(self):
return {"Forward": [self.signal.__json__(), self.rotation]}
class Shared(Queriable):
def __init__(self: Shared, signal: SharedSignal, rotation: int):
self.signal = signal
self.rotation = rotation
def next(self: Shared) -> Shared:
return Shared(self.signal, self.rotation + 1)
def prev(self: Shared) -> Shared:
return Shared(self.signal, self.rotation - 1)
def rot(self: Shared, rotation: int) -> Shared:
return Shared(self.signal, self.rotation + rotation)
def uuid(self: Shared) -> int:
return self.signal.id
def __str__(self: Shared) -> str:
if self.rotation == 0:
return self.signal.annotation
else:
return f"{self.signal.annotation}(rot {self.rotation})"
def __json__(self):
return {"Shared": [self.signal.__json__(), self.rotation]}
class Fixed(Queriable):
def __init__(self: Fixed, signal: FixedSignal, rotation: int):
self.signal = signal
self.rotation = rotation
def next(self: Fixed) -> Fixed:
return Fixed(self.signal, self.rotation + 1)
def prev(self: Fixed) -> Fixed:
return Fixed(self.signal, self.rotation - 1)
def rot(self: Fixed, rotation: int) -> Fixed:
return Fixed(self.signal, self.rotation + rotation)
def uuid(self: Fixed) -> int:
return self.signal.id
def __str__(self: Fixed) -> str:
if self.rotation == 0:
return self.signal.annotation
else:
return f"{self.signal.annotation}(rot {self.rotation})"
def __json__(self):
return {"Fixed": [self.signal.__json__(), self.rotation]}
class StepTypeNext(Queriable):
def __init__(self: StepTypeNext, step_type: ASTStepType):
self.step_type = step_type
def uuid(self: ASTStepType) -> int:
return self.id
def __str__(self: ASTStepType) -> str:
return self.name
def __json__(self):
return {
"StepTypeNext": {"id": self.step_type.id, "annotation": self.step_type.name}
}

31
python/chiquito/util.py Normal file
View File

@@ -0,0 +1,31 @@
from __future__ import annotations
from py_ecc import bn128
from uuid import uuid1
import json
F = bn128.FQ
def json_method(self: F):
# Convert the integer to a byte array
byte_array = self.n.to_bytes(32, "little")
# Split into four 64-bit integers
ints = [int.from_bytes(byte_array[i * 8 : i * 8 + 8], "little") for i in range(4)]
return ints
F.__json__ = json_method
class CustomEncoder(json.JSONEncoder):
def default(self, obj):
if hasattr(obj, "__json__"):
return obj.__json__()
return super().default(obj)
# int field is the u128 version of uuid.
def uuid() -> int:
return uuid1(node=int.from_bytes([10, 10, 10, 10, 10, 10], byteorder="little")).int

122
python/chiquito/wit_gen.py Normal file
View File

@@ -0,0 +1,122 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Callable, Any
import json
from chiquito.query import Queriable, Fixed
from chiquito.util import F, CustomEncoder
# Commented out to avoid circular reference
# from dsl import Circuit, StepType
@dataclass
class StepInstance:
step_type_uuid: int = 0
assignments: Dict[Queriable, F] = field(default_factory=dict)
def new(step_type_uuid: int) -> StepInstance:
return StepInstance(step_type_uuid, {})
def assign(self: StepInstance, lhs: Queriable, rhs: F):
self.assignments[lhs] = rhs
def __str__(self: StepInstance):
assignments_str = (
"\n\t\t\t\t"
+ ",\n\t\t\t\t".join(
f"{str(lhs)} = {rhs}" for (lhs, rhs) in self.assignments.items()
)
+ "\n\t\t\t"
if self.assignments
else ""
)
return (
f"StepInstance(\n"
f"\t\t\tstep_type_uuid={self.step_type_uuid},\n"
f"\t\t\tassignments={{{assignments_str}}},\n"
f"\t\t)"
)
# For assignments, return "uuid: (Queriable, F)" rather than "Queriable: F", because JSON doesn't accept Dict as key.
def __json__(self: StepInstance):
return {
"step_type_uuid": self.step_type_uuid,
"assignments": {
lhs.uuid(): [lhs, rhs] for (lhs, rhs) in self.assignments.items()
},
}
Witness = List[StepInstance]
@dataclass
class TraceWitness:
step_instances: Witness = field(default_factory=list)
def __str__(self: TraceWitness):
step_instances_str = (
"\n\t\t"
+ ",\n\t\t".join(
str(step_instance) for step_instance in self.step_instances
)
+ "\n\t"
if self.step_instances
else ""
)
return f"TraceWitness(\n" f"\tstep_instances={{{step_instances_str}}},\n" f")"
def __json__(self: TraceWitness):
return {
"step_instances": [
step_instance.__json__() for step_instance in self.step_instances
]
}
def get_witness_json(self: TraceWitness) -> str:
return json.dumps(self, cls=CustomEncoder, indent=4)
def evil_witness_test(
self: TraceWitness,
step_instance_indices: List[int],
assignment_indices: List[int],
rhs: List[F],
) -> TraceWitness:
if not len(step_instance_indices) == len(assignment_indices) == len(rhs):
raise ValueError(f"`evil_witness_test` inputs have different lengths.")
new_step_instances = self.step_instances.copy()
for i in range(len(step_instance_indices)):
keys = list(new_step_instances[step_instance_indices[i]].assignments.keys())
new_step_instances[step_instance_indices[i]].assignments[
keys[assignment_indices[i]]
] = rhs[i]
return TraceWitness(new_step_instances)
FixedAssigment = Dict[Queriable, List[F]]
@dataclass
class FixedGenContext:
assignments: FixedAssigment = field(default_factory=dict)
num_steps: int = 0
def new(num_steps: int) -> FixedGenContext:
return FixedGenContext({}, num_steps)
def assign(self: FixedGenContext, offset: int, lhs: Queriable, rhs: F):
if not FixedGenContext.is_fixed_queriable(lhs):
raise ValueError(f"Cannot assign to non-fixed signal.")
if lhs in self.assignments.keys():
self.assignments[lhs][offset] = rhs
else:
self.assignments[lhs] = [F.zero()] * self.num_steps
self.assignments[lhs][offset] = rhs
def is_fixed_queriable(q: Queriable) -> bool:
match q.enum:
case Fixed(_, _):
return True
case _:
return False