mirror of
https://github.com/qwang98/PyChiquito.git
synced 2026-04-22 03:00:16 -04:00
fixed type and field namings
This commit is contained in:
@@ -11,7 +11,7 @@ from query import Queriable
|
||||
# ast #
|
||||
#######
|
||||
|
||||
# pub struct AST<F, TraceArgs> {
|
||||
# pub struct ASTCircuit<F, TraceArgs> {
|
||||
# pub step_types: HashMap<u32, Rc<ASTStepType<F>>>,
|
||||
|
||||
# pub forward_signals: Vec<ForwardSignal>,
|
||||
@@ -33,7 +33,7 @@ from query import Queriable
|
||||
|
||||
|
||||
@dataclass
|
||||
class AST:
|
||||
class ASTCircuit:
|
||||
step_types: Dict[int, ASTStepType] = field(default_factory=dict)
|
||||
forward_signals: List[ForwardSignal] = field(default_factory=list)
|
||||
shared_signals: List[SharedSignal] = field(default_factory=list)
|
||||
@@ -48,7 +48,7 @@ class AST:
|
||||
q_enable: bool = True
|
||||
id: int = uuid()
|
||||
|
||||
def __str__(self: AST):
|
||||
def __str__(self: ASTCircuit):
|
||||
step_types_str = (
|
||||
"\n\t\t"
|
||||
+ ",\n\t\t".join(f"{k}: {v}" for k, v in self.step_types.items())
|
||||
@@ -87,7 +87,7 @@ class AST:
|
||||
)
|
||||
|
||||
return (
|
||||
f"AST(\n"
|
||||
f"ASTCircuit(\n"
|
||||
f"\tstep_types={{{step_types_str}}},\n"
|
||||
f"\tforward_signals=[{forward_signals_str}],\n"
|
||||
f"\tshared_signals=[{shared_signals_str}],\n"
|
||||
@@ -103,7 +103,7 @@ class AST:
|
||||
f")"
|
||||
)
|
||||
|
||||
def __json__(self: AST):
|
||||
def __json__(self: ASTCircuit):
|
||||
return {
|
||||
"step_types": {k: v.__json__() for k, v in self.step_types.items()},
|
||||
"forward_signals": [x.__json__() for x in self.forward_signals],
|
||||
@@ -121,42 +121,42 @@ class AST:
|
||||
"id": self.id,
|
||||
}
|
||||
|
||||
def add_forward(self: AST, name: str, phase: int) -> ForwardSignal:
|
||||
def add_forward(self: ASTCircuit, name: str, phase: int) -> ForwardSignal:
|
||||
signal = ForwardSignal(phase, name)
|
||||
self.forward_signals.append(signal)
|
||||
self.annotations[signal.id] = name
|
||||
return signal
|
||||
|
||||
def add_shared(self: AST, name: str, phase: int) -> SharedSignal:
|
||||
def add_shared(self: ASTCircuit, name: str, phase: int) -> SharedSignal:
|
||||
signal = SharedSignal(phase, name)
|
||||
self.shared_signals.append(signal)
|
||||
self.annotations[signal.id] = name
|
||||
return signal
|
||||
|
||||
def add_fixed(self: AST, name: str) -> FixedSignal:
|
||||
def add_fixed(self: ASTCircuit, name: str) -> FixedSignal:
|
||||
signal = FixedSignal(name)
|
||||
self.fixed_signals.append(signal)
|
||||
self.annotations[signal.id] = name
|
||||
return signal
|
||||
|
||||
def expose(self: AST, signal: Queriable, offset: ExposeOffset):
|
||||
def expose(self: ASTCircuit, signal: Queriable, offset: ExposeOffset):
|
||||
self.exposed.append((signal, offset))
|
||||
|
||||
def add_step_type(self: AST, step_type: ASTStepType, name: str):
|
||||
def add_step_type(self: ASTCircuit, step_type: ASTStepType, name: str):
|
||||
self.annotations[step_type.id] = name
|
||||
self.step_types[step_type.id] = step_type
|
||||
|
||||
def set_trace(
|
||||
self: AST, trace_def: Callable[[TraceContext, Any], None]
|
||||
self: ASTCircuit, trace_def: Callable[[TraceContext, Any], None]
|
||||
): # TraceArgs are Any.
|
||||
if self.trace is not None:
|
||||
raise Exception("AST cannot have more than one trace generator.")
|
||||
raise Exception("ASTCircuit cannot have more than one trace generator.")
|
||||
else:
|
||||
self.trace = trace_def
|
||||
|
||||
def set_fixed_gen(self, fixed_gen_def: Callable[[FixedGenContext], None]):
|
||||
if self.fixed_gen is not None:
|
||||
raise Exception("AST cannot have more than one fixed generator.")
|
||||
raise Exception("ASTCircuit cannot have more than one fixed generator.")
|
||||
else:
|
||||
self.fixed_gen = fixed_gen_def
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from dataclasses import dataclass
|
||||
import rust_chiquito # rust bindings
|
||||
import json
|
||||
|
||||
from chiquito_ast import AST, ASTStepType, ExposeOffset
|
||||
from chiquito_ast import ASTCircuit, ASTStepType, ExposeOffset
|
||||
from query import Internal, Forward, Queriable, Shared, Fixed
|
||||
from wit_gen import FixedGenContext, TraceContext, StepInstance
|
||||
from cb import Constraint, Typing, ToConstraint, to_constraint
|
||||
@@ -25,40 +25,40 @@ class CircuitMode(Enum):
|
||||
|
||||
class Circuit:
|
||||
def __init__(self: Circuit):
|
||||
self.circuit = AST()
|
||||
self.ast = ASTCircuit()
|
||||
self.trace_context = TraceContext()
|
||||
self.rust_ast_id = 0
|
||||
self.mode = CircuitMode.SETUP
|
||||
self.setup()
|
||||
self.mode = CircuitMode.Trace
|
||||
self.circuit.set_trace(self.trace)
|
||||
self.ast.set_trace(self.trace)
|
||||
self.trace()
|
||||
# self.mode = CircuitMode.NoMode
|
||||
|
||||
def forward(self: Circuit, name: str) -> Forward:
|
||||
assert self.mode == CircuitMode.SETUP
|
||||
return Forward(self.circuit.add_forward(name, 0), False)
|
||||
return Forward(self.ast.add_forward(name, 0), False)
|
||||
|
||||
def forward_with_phase(self: Circuit, name: str, phase: int) -> Forward:
|
||||
assert self.mode == CircuitMode.SETUP
|
||||
return Forward(self.circuit.add_forward(name, phase), False)
|
||||
return Forward(self.ast.add_forward(name, phase), False)
|
||||
|
||||
def shared(self: Circuit, name: str) -> Shared:
|
||||
assert self.mode == CircuitMode.SETUP
|
||||
return Shared(self.circuit.add_shared(name, 0), 0)
|
||||
return Shared(self.ast.add_shared(name, 0), 0)
|
||||
|
||||
def shared_with_phase(self: Circuit, name: str, phase: int) -> Shared:
|
||||
assert self.mode == CircuitMode.SETUP
|
||||
return Shared(self.circuit.add_shared(name, phase), 0)
|
||||
return Shared(self.ast.add_shared(name, phase), 0)
|
||||
|
||||
def fixed(self: Circuit, name: str) -> Fixed:
|
||||
assert self.mode == CircuitMode.SETUP
|
||||
return Fixed(self.circuit.add_fixed(name), 0)
|
||||
return Fixed(self.ast.add_fixed(name), 0)
|
||||
|
||||
def expose(self: Circuit, signal: Queriable, offset: ExposeOffset):
|
||||
assert self.mode == CircuitMode.SETUP
|
||||
if isinstance(signal, (Forward, Shared)):
|
||||
self.circuit.expose(signal, offset)
|
||||
self.ast.expose(signal, offset)
|
||||
else:
|
||||
raise TypeError(f"Can only expose ForwardSignal or SharedSignal.")
|
||||
|
||||
@@ -66,38 +66,38 @@ class Circuit:
|
||||
|
||||
def step_type(self: Circuit, step_type: StepType) -> StepType:
|
||||
assert self.mode == CircuitMode.SETUP
|
||||
self.circuit.add_step_type(step_type.step_type, step_type.step_type.name)
|
||||
self.ast.add_step_type(step_type.step_type, step_type.step_type.name)
|
||||
return step_type
|
||||
|
||||
def step_type_def(self: StepType) -> StepType:
|
||||
assert self.mode == CircuitMode.SETUP
|
||||
self.circuit.add_step_type_def()
|
||||
self.ast.add_step_type_def()
|
||||
|
||||
# def trace(
|
||||
# self: Circuit, trace_def: Callable[[TraceContext, Any], None]
|
||||
# ): # TraceArgs are Any.
|
||||
# self.circuit.set_trace(trace_def)
|
||||
# self.ast.set_trace(trace_def)
|
||||
|
||||
def fixed_gen(self: Circuit, fixed_gen_def: Callable[[FixedGenContext], None]):
|
||||
self.circuit.set_fixed_gen(fixed_gen_def)
|
||||
self.ast.set_fixed_gen(fixed_gen_def)
|
||||
|
||||
def pragma_first_step(self: Circuit, step_type: StepType) -> None:
|
||||
assert self.mode == CircuitMode.SETUP
|
||||
self.circuit.first_step = step_type.step_type.id
|
||||
self.ast.first_step = step_type.step_type.id
|
||||
print(f"first step id: {step_type.step_type.id}")
|
||||
|
||||
def pragma_last_step(self: Circuit, step_type: StepType) -> None:
|
||||
assert self.mode == CircuitMode.SETUP
|
||||
self.circuit.last_step = step_type.step_type.id
|
||||
self.ast.last_step = step_type.step_type.id
|
||||
print(f"last step id: {step_type.step_type.id}")
|
||||
|
||||
def pragma_num_steps(self: Circuit, num_steps: int) -> None:
|
||||
assert self.mode == CircuitMode.SETUP
|
||||
self.circuit.num_steps = num_steps
|
||||
self.ast.num_steps = num_steps
|
||||
|
||||
def pragma_disable_q_enable(self: Circuit) -> None:
|
||||
assert self.mode == CircuitMode.SETUP
|
||||
self.circuit.q_enable = False
|
||||
self.ast.q_enable = False
|
||||
|
||||
def add(self: Circuit, step_type: StepType, args: Any):
|
||||
print(self)
|
||||
@@ -107,13 +107,13 @@ class Circuit:
|
||||
self.trace_context.add(self, step_type, args)
|
||||
|
||||
def print_ast(self: Circuit):
|
||||
print("Print AST using custom __str__ method in python:")
|
||||
print(self.circuit)
|
||||
print("Print ASTCircuit using custom __str__ method in python:")
|
||||
print(self.ast)
|
||||
|
||||
def get_ast_json(self: Circuit, print_json=False) -> str:
|
||||
ast_json: str = json.dumps(self.circuit, cls=CustomEncoder, indent=4)
|
||||
ast_json: str = json.dumps(self.ast, cls=CustomEncoder, indent=4)
|
||||
if print_json:
|
||||
print("Print AST using __json__ method in python:")
|
||||
print("Print ASTCircuit using __json__ method in python:")
|
||||
print(ast_json)
|
||||
return ast_json
|
||||
|
||||
@@ -134,7 +134,7 @@ class Circuit:
|
||||
ast_json: str = self.get_ast_json()
|
||||
if print_ast:
|
||||
print(
|
||||
"Call rust bindings, parse json to Chiquito AST, and print using Debug trait:"
|
||||
"Call rust bindings, parse json to Chiquito ASTCircuit, and print using Debug trait:"
|
||||
)
|
||||
print(rust_chiquito.convert_and_print_ast(ast_json))
|
||||
|
||||
@@ -173,7 +173,7 @@ class StepTypeMode(Enum):
|
||||
|
||||
|
||||
class StepType:
|
||||
def __init__(self: StepType, circuit, step_type_name: str):
|
||||
def __init__(self: StepType, circuit: Circuit, step_type_name: str):
|
||||
self.step_type = ASTStepType.new(step_type_name)
|
||||
self.circuit = circuit
|
||||
self.step_instance = StepInstance.new(self.step_type.id)
|
||||
|
||||
Reference in New Issue
Block a user