diff --git a/pychiquito/chiquito_ast.py b/pychiquito/chiquito_ast.py index b882cd8..a81cc10 100644 --- a/pychiquito/chiquito_ast.py +++ b/pychiquito/chiquito_ast.py @@ -11,7 +11,7 @@ from query import Queriable # ast # ####### -# pub struct AST { +# pub struct ASTCircuit { # pub step_types: HashMap>>, # pub forward_signals: Vec, @@ -33,7 +33,7 @@ from query import Queriable @dataclass -class AST: +class ASTCircuit: step_types: Dict[int, ASTStepType] = field(default_factory=dict) forward_signals: List[ForwardSignal] = field(default_factory=list) shared_signals: List[SharedSignal] = field(default_factory=list) @@ -48,7 +48,7 @@ class AST: q_enable: bool = True id: int = uuid() - def __str__(self: AST): + def __str__(self: ASTCircuit): step_types_str = ( "\n\t\t" + ",\n\t\t".join(f"{k}: {v}" for k, v in self.step_types.items()) @@ -87,7 +87,7 @@ class AST: ) return ( - f"AST(\n" + f"ASTCircuit(\n" f"\tstep_types={{{step_types_str}}},\n" f"\tforward_signals=[{forward_signals_str}],\n" f"\tshared_signals=[{shared_signals_str}],\n" @@ -103,7 +103,7 @@ class AST: f")" ) - def __json__(self: AST): + def __json__(self: ASTCircuit): return { "step_types": {k: v.__json__() for k, v in self.step_types.items()}, "forward_signals": [x.__json__() for x in self.forward_signals], @@ -121,42 +121,42 @@ class AST: "id": self.id, } - def add_forward(self: AST, name: str, phase: int) -> ForwardSignal: + def add_forward(self: ASTCircuit, name: str, phase: int) -> ForwardSignal: signal = ForwardSignal(phase, name) self.forward_signals.append(signal) self.annotations[signal.id] = name return signal - def add_shared(self: AST, name: str, phase: int) -> SharedSignal: + def add_shared(self: ASTCircuit, name: str, phase: int) -> SharedSignal: signal = SharedSignal(phase, name) self.shared_signals.append(signal) self.annotations[signal.id] = name return signal - def add_fixed(self: AST, name: str) -> FixedSignal: + def add_fixed(self: ASTCircuit, name: str) -> FixedSignal: signal = FixedSignal(name) self.fixed_signals.append(signal) self.annotations[signal.id] = name return signal - def expose(self: AST, signal: Queriable, offset: ExposeOffset): + def expose(self: ASTCircuit, signal: Queriable, offset: ExposeOffset): self.exposed.append((signal, offset)) - def add_step_type(self: AST, step_type: ASTStepType, name: str): + def add_step_type(self: ASTCircuit, step_type: ASTStepType, name: str): self.annotations[step_type.id] = name self.step_types[step_type.id] = step_type def set_trace( - self: AST, trace_def: Callable[[TraceContext, Any], None] + self: ASTCircuit, trace_def: Callable[[TraceContext, Any], None] ): # TraceArgs are Any. if self.trace is not None: - raise Exception("AST cannot have more than one trace generator.") + raise Exception("ASTCircuit cannot have more than one trace generator.") else: self.trace = trace_def def set_fixed_gen(self, fixed_gen_def: Callable[[FixedGenContext], None]): if self.fixed_gen is not None: - raise Exception("AST cannot have more than one fixed generator.") + raise Exception("ASTCircuit cannot have more than one fixed generator.") else: self.fixed_gen = fixed_gen_def diff --git a/pychiquito/dsl.py b/pychiquito/dsl.py index 3b7130c..00dd4ca 100644 --- a/pychiquito/dsl.py +++ b/pychiquito/dsl.py @@ -5,7 +5,7 @@ from dataclasses import dataclass import rust_chiquito # rust bindings import json -from chiquito_ast import AST, ASTStepType, ExposeOffset +from chiquito_ast import ASTCircuit, ASTStepType, ExposeOffset from query import Internal, Forward, Queriable, Shared, Fixed from wit_gen import FixedGenContext, TraceContext, StepInstance from cb import Constraint, Typing, ToConstraint, to_constraint @@ -25,40 +25,40 @@ class CircuitMode(Enum): class Circuit: def __init__(self: Circuit): - self.circuit = AST() + self.ast = ASTCircuit() self.trace_context = TraceContext() self.rust_ast_id = 0 self.mode = CircuitMode.SETUP self.setup() self.mode = CircuitMode.Trace - self.circuit.set_trace(self.trace) + self.ast.set_trace(self.trace) self.trace() # self.mode = CircuitMode.NoMode def forward(self: Circuit, name: str) -> Forward: assert self.mode == CircuitMode.SETUP - return Forward(self.circuit.add_forward(name, 0), False) + return Forward(self.ast.add_forward(name, 0), False) def forward_with_phase(self: Circuit, name: str, phase: int) -> Forward: assert self.mode == CircuitMode.SETUP - return Forward(self.circuit.add_forward(name, phase), False) + return Forward(self.ast.add_forward(name, phase), False) def shared(self: Circuit, name: str) -> Shared: assert self.mode == CircuitMode.SETUP - return Shared(self.circuit.add_shared(name, 0), 0) + return Shared(self.ast.add_shared(name, 0), 0) def shared_with_phase(self: Circuit, name: str, phase: int) -> Shared: assert self.mode == CircuitMode.SETUP - return Shared(self.circuit.add_shared(name, phase), 0) + return Shared(self.ast.add_shared(name, phase), 0) def fixed(self: Circuit, name: str) -> Fixed: assert self.mode == CircuitMode.SETUP - return Fixed(self.circuit.add_fixed(name), 0) + return Fixed(self.ast.add_fixed(name), 0) def expose(self: Circuit, signal: Queriable, offset: ExposeOffset): assert self.mode == CircuitMode.SETUP if isinstance(signal, (Forward, Shared)): - self.circuit.expose(signal, offset) + self.ast.expose(signal, offset) else: raise TypeError(f"Can only expose ForwardSignal or SharedSignal.") @@ -66,38 +66,38 @@ class Circuit: def step_type(self: Circuit, step_type: StepType) -> StepType: assert self.mode == CircuitMode.SETUP - self.circuit.add_step_type(step_type.step_type, step_type.step_type.name) + self.ast.add_step_type(step_type.step_type, step_type.step_type.name) return step_type def step_type_def(self: StepType) -> StepType: assert self.mode == CircuitMode.SETUP - self.circuit.add_step_type_def() + self.ast.add_step_type_def() # def trace( # self: Circuit, trace_def: Callable[[TraceContext, Any], None] # ): # TraceArgs are Any. - # self.circuit.set_trace(trace_def) + # self.ast.set_trace(trace_def) def fixed_gen(self: Circuit, fixed_gen_def: Callable[[FixedGenContext], None]): - self.circuit.set_fixed_gen(fixed_gen_def) + self.ast.set_fixed_gen(fixed_gen_def) def pragma_first_step(self: Circuit, step_type: StepType) -> None: assert self.mode == CircuitMode.SETUP - self.circuit.first_step = step_type.step_type.id + self.ast.first_step = step_type.step_type.id print(f"first step id: {step_type.step_type.id}") def pragma_last_step(self: Circuit, step_type: StepType) -> None: assert self.mode == CircuitMode.SETUP - self.circuit.last_step = step_type.step_type.id + self.ast.last_step = step_type.step_type.id print(f"last step id: {step_type.step_type.id}") def pragma_num_steps(self: Circuit, num_steps: int) -> None: assert self.mode == CircuitMode.SETUP - self.circuit.num_steps = num_steps + self.ast.num_steps = num_steps def pragma_disable_q_enable(self: Circuit) -> None: assert self.mode == CircuitMode.SETUP - self.circuit.q_enable = False + self.ast.q_enable = False def add(self: Circuit, step_type: StepType, args: Any): print(self) @@ -107,13 +107,13 @@ class Circuit: self.trace_context.add(self, step_type, args) def print_ast(self: Circuit): - print("Print AST using custom __str__ method in python:") - print(self.circuit) + print("Print ASTCircuit using custom __str__ method in python:") + print(self.ast) def get_ast_json(self: Circuit, print_json=False) -> str: - ast_json: str = json.dumps(self.circuit, cls=CustomEncoder, indent=4) + ast_json: str = json.dumps(self.ast, cls=CustomEncoder, indent=4) if print_json: - print("Print AST using __json__ method in python:") + print("Print ASTCircuit using __json__ method in python:") print(ast_json) return ast_json @@ -134,7 +134,7 @@ class Circuit: ast_json: str = self.get_ast_json() if print_ast: print( - "Call rust bindings, parse json to Chiquito AST, and print using Debug trait:" + "Call rust bindings, parse json to Chiquito ASTCircuit, and print using Debug trait:" ) print(rust_chiquito.convert_and_print_ast(ast_json)) @@ -173,7 +173,7 @@ class StepTypeMode(Enum): class StepType: - def __init__(self: StepType, circuit, step_type_name: str): + def __init__(self: StepType, circuit: Circuit, step_type_name: str): self.step_type = ASTStepType.new(step_type_name) self.circuit = circuit self.step_instance = StepInstance.new(self.step_type.id)