fixed type and field namings

This commit is contained in:
Steve Wang
2023-07-28 10:03:16 +08:00
parent f84a85a88d
commit 85afef7ce1
2 changed files with 36 additions and 36 deletions

View File

@@ -11,7 +11,7 @@ from query import Queriable
# ast #
#######
# pub struct AST<F, TraceArgs> {
# pub struct ASTCircuit<F, TraceArgs> {
# pub step_types: HashMap<u32, Rc<ASTStepType<F>>>,
# pub forward_signals: Vec<ForwardSignal>,
@@ -33,7 +33,7 @@ from query import Queriable
@dataclass
class AST:
class ASTCircuit:
step_types: Dict[int, ASTStepType] = field(default_factory=dict)
forward_signals: List[ForwardSignal] = field(default_factory=list)
shared_signals: List[SharedSignal] = field(default_factory=list)
@@ -48,7 +48,7 @@ class AST:
q_enable: bool = True
id: int = uuid()
def __str__(self: AST):
def __str__(self: ASTCircuit):
step_types_str = (
"\n\t\t"
+ ",\n\t\t".join(f"{k}: {v}" for k, v in self.step_types.items())
@@ -87,7 +87,7 @@ class AST:
)
return (
f"AST(\n"
f"ASTCircuit(\n"
f"\tstep_types={{{step_types_str}}},\n"
f"\tforward_signals=[{forward_signals_str}],\n"
f"\tshared_signals=[{shared_signals_str}],\n"
@@ -103,7 +103,7 @@ class AST:
f")"
)
def __json__(self: AST):
def __json__(self: ASTCircuit):
return {
"step_types": {k: v.__json__() for k, v in self.step_types.items()},
"forward_signals": [x.__json__() for x in self.forward_signals],
@@ -121,42 +121,42 @@ class AST:
"id": self.id,
}
def add_forward(self: AST, name: str, phase: int) -> ForwardSignal:
def add_forward(self: ASTCircuit, name: str, phase: int) -> ForwardSignal:
signal = ForwardSignal(phase, name)
self.forward_signals.append(signal)
self.annotations[signal.id] = name
return signal
def add_shared(self: AST, name: str, phase: int) -> SharedSignal:
def add_shared(self: ASTCircuit, name: str, phase: int) -> SharedSignal:
signal = SharedSignal(phase, name)
self.shared_signals.append(signal)
self.annotations[signal.id] = name
return signal
def add_fixed(self: AST, name: str) -> FixedSignal:
def add_fixed(self: ASTCircuit, name: str) -> FixedSignal:
signal = FixedSignal(name)
self.fixed_signals.append(signal)
self.annotations[signal.id] = name
return signal
def expose(self: AST, signal: Queriable, offset: ExposeOffset):
def expose(self: ASTCircuit, signal: Queriable, offset: ExposeOffset):
self.exposed.append((signal, offset))
def add_step_type(self: AST, step_type: ASTStepType, name: str):
def add_step_type(self: ASTCircuit, step_type: ASTStepType, name: str):
self.annotations[step_type.id] = name
self.step_types[step_type.id] = step_type
def set_trace(
self: AST, trace_def: Callable[[TraceContext, Any], None]
self: ASTCircuit, trace_def: Callable[[TraceContext, Any], None]
): # TraceArgs are Any.
if self.trace is not None:
raise Exception("AST cannot have more than one trace generator.")
raise Exception("ASTCircuit cannot have more than one trace generator.")
else:
self.trace = trace_def
def set_fixed_gen(self, fixed_gen_def: Callable[[FixedGenContext], None]):
if self.fixed_gen is not None:
raise Exception("AST cannot have more than one fixed generator.")
raise Exception("ASTCircuit cannot have more than one fixed generator.")
else:
self.fixed_gen = fixed_gen_def

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass
import rust_chiquito # rust bindings
import json
from chiquito_ast import AST, ASTStepType, ExposeOffset
from chiquito_ast import ASTCircuit, ASTStepType, ExposeOffset
from query import Internal, Forward, Queriable, Shared, Fixed
from wit_gen import FixedGenContext, TraceContext, StepInstance
from cb import Constraint, Typing, ToConstraint, to_constraint
@@ -25,40 +25,40 @@ class CircuitMode(Enum):
class Circuit:
def __init__(self: Circuit):
self.circuit = AST()
self.ast = ASTCircuit()
self.trace_context = TraceContext()
self.rust_ast_id = 0
self.mode = CircuitMode.SETUP
self.setup()
self.mode = CircuitMode.Trace
self.circuit.set_trace(self.trace)
self.ast.set_trace(self.trace)
self.trace()
# self.mode = CircuitMode.NoMode
def forward(self: Circuit, name: str) -> Forward:
assert self.mode == CircuitMode.SETUP
return Forward(self.circuit.add_forward(name, 0), False)
return Forward(self.ast.add_forward(name, 0), False)
def forward_with_phase(self: Circuit, name: str, phase: int) -> Forward:
assert self.mode == CircuitMode.SETUP
return Forward(self.circuit.add_forward(name, phase), False)
return Forward(self.ast.add_forward(name, phase), False)
def shared(self: Circuit, name: str) -> Shared:
assert self.mode == CircuitMode.SETUP
return Shared(self.circuit.add_shared(name, 0), 0)
return Shared(self.ast.add_shared(name, 0), 0)
def shared_with_phase(self: Circuit, name: str, phase: int) -> Shared:
assert self.mode == CircuitMode.SETUP
return Shared(self.circuit.add_shared(name, phase), 0)
return Shared(self.ast.add_shared(name, phase), 0)
def fixed(self: Circuit, name: str) -> Fixed:
assert self.mode == CircuitMode.SETUP
return Fixed(self.circuit.add_fixed(name), 0)
return Fixed(self.ast.add_fixed(name), 0)
def expose(self: Circuit, signal: Queriable, offset: ExposeOffset):
assert self.mode == CircuitMode.SETUP
if isinstance(signal, (Forward, Shared)):
self.circuit.expose(signal, offset)
self.ast.expose(signal, offset)
else:
raise TypeError(f"Can only expose ForwardSignal or SharedSignal.")
@@ -66,38 +66,38 @@ class Circuit:
def step_type(self: Circuit, step_type: StepType) -> StepType:
assert self.mode == CircuitMode.SETUP
self.circuit.add_step_type(step_type.step_type, step_type.step_type.name)
self.ast.add_step_type(step_type.step_type, step_type.step_type.name)
return step_type
def step_type_def(self: StepType) -> StepType:
assert self.mode == CircuitMode.SETUP
self.circuit.add_step_type_def()
self.ast.add_step_type_def()
# def trace(
# self: Circuit, trace_def: Callable[[TraceContext, Any], None]
# ): # TraceArgs are Any.
# self.circuit.set_trace(trace_def)
# self.ast.set_trace(trace_def)
def fixed_gen(self: Circuit, fixed_gen_def: Callable[[FixedGenContext], None]):
self.circuit.set_fixed_gen(fixed_gen_def)
self.ast.set_fixed_gen(fixed_gen_def)
def pragma_first_step(self: Circuit, step_type: StepType) -> None:
assert self.mode == CircuitMode.SETUP
self.circuit.first_step = step_type.step_type.id
self.ast.first_step = step_type.step_type.id
print(f"first step id: {step_type.step_type.id}")
def pragma_last_step(self: Circuit, step_type: StepType) -> None:
assert self.mode == CircuitMode.SETUP
self.circuit.last_step = step_type.step_type.id
self.ast.last_step = step_type.step_type.id
print(f"last step id: {step_type.step_type.id}")
def pragma_num_steps(self: Circuit, num_steps: int) -> None:
assert self.mode == CircuitMode.SETUP
self.circuit.num_steps = num_steps
self.ast.num_steps = num_steps
def pragma_disable_q_enable(self: Circuit) -> None:
assert self.mode == CircuitMode.SETUP
self.circuit.q_enable = False
self.ast.q_enable = False
def add(self: Circuit, step_type: StepType, args: Any):
print(self)
@@ -107,13 +107,13 @@ class Circuit:
self.trace_context.add(self, step_type, args)
def print_ast(self: Circuit):
print("Print AST using custom __str__ method in python:")
print(self.circuit)
print("Print ASTCircuit using custom __str__ method in python:")
print(self.ast)
def get_ast_json(self: Circuit, print_json=False) -> str:
ast_json: str = json.dumps(self.circuit, cls=CustomEncoder, indent=4)
ast_json: str = json.dumps(self.ast, cls=CustomEncoder, indent=4)
if print_json:
print("Print AST using __json__ method in python:")
print("Print ASTCircuit using __json__ method in python:")
print(ast_json)
return ast_json
@@ -134,7 +134,7 @@ class Circuit:
ast_json: str = self.get_ast_json()
if print_ast:
print(
"Call rust bindings, parse json to Chiquito AST, and print using Debug trait:"
"Call rust bindings, parse json to Chiquito ASTCircuit, and print using Debug trait:"
)
print(rust_chiquito.convert_and_print_ast(ast_json))
@@ -173,7 +173,7 @@ class StepTypeMode(Enum):
class StepType:
def __init__(self: StepType, circuit, step_type_name: str):
def __init__(self: StepType, circuit: Circuit, step_type_name: str):
self.step_type = ASTStepType.new(step_type_name)
self.circuit = circuit
self.step_instance = StepInstance.new(self.step_type.id)