mirror of
https://github.com/0xPARC/plonkathon.git
synced 2026-01-13 07:37:57 -05:00
167 lines
5.8 KiB
Python
167 lines
5.8 KiB
Python
from utils import *
|
|
from .utils import *
|
|
from typing import Optional
|
|
from dataclasses import dataclass
|
|
|
|
|
|
@dataclass
class GateWires:
    """Names of the variables attached to a gate's left, right and output wires.

    Each slot holds a variable name, or ``None`` when no variable is attached
    to that wire.
    """

    L: Optional[str]  # left input wire
    R: Optional[str]  # right input wire
    O: Optional[str]  # output wire

    def as_list(self) -> list[Optional[str]]:
        """Return the three wire names in ``[L, R, O]`` order."""
        return [getattr(self, slot) for slot in ("L", "R", "O")]
|
|
|
|
|
|
@dataclass
class Gate:
    """Coefficients of a single gate polynomial.

    Holds one ``Scalar`` per term of the gate: the left wire, the right wire,
    the product of the two input wires, the output wire, and the constant.
    Field order is part of the constructor signature and must not change.
    """

    L: Scalar  # coefficient on the left wire
    R: Scalar  # coefficient on the right wire
    M: Scalar  # coefficient on the (left * right) product term
    O: Scalar  # coefficient on the output wire
    C: Scalar  # constant term
|
|
|
|
|
|
@dataclass
class AssemblyEqn:
    """One assembly-level equation: its wires plus a term -> coefficient map.

    ``coeffs`` is keyed by term: a variable name, ``""`` for the constant
    term, a product key for the multiplication term, and the special
    ``"$output_coeff"`` entry for the output wire.
    """

    wires: GateWires
    coeffs: dict[Optional[str], int]

    def _negated(self, key: Optional[str]) -> Scalar:
        # Shared lookup: a missing term counts as 0, and the stored
        # coefficient is negated before being wrapped in a Scalar.
        return Scalar(-self.coeffs.get(key, 0))

    def L(self) -> Scalar:
        """Return the negated coefficient of the left-wire variable."""
        return self._negated(self.wires.L)

    def R(self) -> Scalar:
        """Return the negated coefficient of the right-wire variable.

        When both input wires carry the same variable the coefficient is
        already reported by ``L``, so this returns zero.
        """
        if self.wires.R == self.wires.L:
            return Scalar(0)
        return self._negated(self.wires.R)

    def C(self) -> Scalar:
        """Return the negated constant term (stored under the ``""`` key)."""
        return self._negated("")

    def O(self) -> Scalar:
        """Return the output-wire coefficient, defaulting to 1 when unset."""
        return Scalar(self.coeffs.get("$output_coeff", 1))

    def M(self) -> Scalar:
        """Return the negated coefficient of the L*R product term.

        The product term only exists when all three wires are populated;
        otherwise zero is returned.
        """
        if None in self.wires.as_list():
            return Scalar(0)
        product_key = get_product_key(self.wires.L, self.wires.R)
        return self._negated(product_key)

    def gate(self) -> Gate:
        """Bundle the five coefficients into a ``Gate``."""
        left = self.L()
        right = self.R()
        mul = self.M()
        out = self.O()
        const = self.C()
        return Gate(left, right, mul, out, const)
|
|
|
|
|
|
# Converts an arithmetic expression containing numbers, variables and {+, -, *}
|
|
# into a mapping of term to coefficient
|
|
#
|
|
# For example:
|
|
# ['a', '+', 'b', '*', 'c', '*', '5'] becomes {'a': 1, 'b*c': 5}
|
|
#
|
|
# Note that this is a recursive algo, so the input can be a mix of tokens and
|
|
# mapping expressions
|
|
#
|
|
def evaluate(exprs: list[str], first_is_negative=False) -> dict[Optional[str], int]:
    """Reduce a token list to a mapping of term -> integer coefficient.

    The tokens are split on "+" first, then "-", then "*", so multiplication
    binds tightest. Numeric literals are accumulated under the "" key,
    variables under their own name, and products under whatever key
    get_product_key (defined outside this block) returns for the two factors.

    Args:
        exprs: Tokens of the expression, e.g. ["a", "-", "45", "*", "a"].
            A token may carry a leading "-" (unary minus, e.g. "-a").
        first_is_negative: When True, the leading term of ``exprs`` is
            negated. This is how "6000 - 700 - 80 + 9" evaluates to 5229:
            everything after a "-" is evaluated with the flag set.

    Returns:
        A dict mapping each term key to its signed integer coefficient.

    Raises:
        Exception: If an operand is missing (empty sub-expression or empty
            token), if two units appear with no operator between them, or if
            a token is neither a number nor a valid variable name.
    """
    # An operator with nothing on one side (e.g. "a + " or "b * * c") ends up
    # here with no tokens; report it instead of failing on exprs[0].
    if not exprs:
        raise Exception("Empty sub-expression: operator is missing an operand")

    # Additive operators have the lowest precedence, "+" before "-".
    for op in ("+", "-"):
        if op in exprs:
            split = exprs.index(op)
            left = evaluate(exprs[:split], first_is_negative)
            # Terms after "-" start negated; terms after "+" start positive.
            right = evaluate(exprs[split + 1 :], op == "-")
            return {
                key: left.get(key, 0) + right.get(key, 0)
                for key in left.keys() | right.keys()
            }

    if "*" in exprs:
        split = exprs.index("*")
        left = evaluate(exprs[:split], first_is_negative)
        # The sign of a product must be applied exactly once. Negating both
        # factors would cancel out and turn "x - 45 * a" into "x + 45 * a".
        right = evaluate(exprs[split + 1 :], False)
        return {
            get_product_key(k1, k2): c1 * c2
            for k1, c1 in left.items()
            for k2, c2 in right.items()
        }

    if len(exprs) > 1:
        raise Exception("No ops, expected sub-expr to be a unit: {}".format(exprs[1]))

    token = exprs[0]
    if token == "":
        # Typically caused by consecutive spaces in the source equation.
        raise Exception("Empty token in expression")
    if token[0] == "-":
        # Unary minus: strip it and flip the pending sign.
        return evaluate([token[1:]], not first_is_negative)

    sign = -1 if first_is_negative else 1
    if token.isnumeric():
        return {"": int(token) * sign}
    if is_valid_variable_name(token):
        return {token: sign}
    raise Exception("ok wtf is {}".format(token))
|
|
|
|
|
|
# Converts an equation to a mapping of term to coefficient, and verifies that
|
|
# the operations in the equation are valid.
|
|
#
|
|
# Also outputs a triple containing the L and R input variables and the output
|
|
# variable
|
|
#
|
|
# Think of the list of (variable triples, coeffs) pairs as this language's
|
|
# version of "assembly"
|
|
#
|
|
# Example valid equations, and output:
|
|
# a === 9 ([None, None, 'a'], {'': 9})
|
|
# b <== a * c (['a', 'c', 'b'], {'a*c': 1})
|
|
# d <== a * c - 45 * a + 987 (['a', 'c', 'd'], {'a*c': 1, 'a': -45, '': 987})
|
|
#
|
|
# Example invalid equations:
|
|
# 7 === 7 # Can't assign to non-variable
|
|
# a <== b * * c # Two times signs in a row
|
|
# e <== a + b * c * d # Multiplicative degree > 2
|
|
#
|
|
def eq_to_assembly(eq: str) -> AssemblyEqn:
    """Parse one space-separated equation line into an ``AssemblyEqn``.

    Two statement forms are recognised, selected by the second token:

    * ``<out> <== <expr>`` / ``<out> === <expr>``: the expression is reduced
      to a coefficient map, checked to use at most two distinct variables and
      only permitted terms, and paired with the (left, right, output) wires.
    * ``<name> public``: declares ``name`` on the left wire with no right or
      output wire.

    Raises:
        Exception: On an unsupported operator, an invalid output variable
            name, more than two variables, or a term that is not permitted.
    """
    tokens = eq.rstrip("\n").split(" ")
    op = tokens[1]

    if op == "public":
        name = tokens[0]
        public_coeffs = {name: -1, "$output_coeff": 0, "$public": True}
        return AssemblyEqn(GateWires(name, None, None), public_coeffs)

    if op not in ("<==", "==="):
        raise Exception("Unsupported op: {}".format(op))

    # Assignment / constraint form: the first token names the output variable
    # and everything after the operator is the expression.
    out = tokens[0]
    rhs = tokens[2:]
    coeffs = evaluate(rhs)

    # "-x === a * b": drop the sign from the name and record it as the
    # output coefficient instead.
    if out[0] == "-":
        out = out[1:]
        coeffs["$output_coeff"] = -1

    if not is_valid_variable_name(out):
        raise Exception("Invalid out variable name: {}".format(out))

    # Distinct variables of the expression, in order of first appearance
    # (unary minus signs are ignored when collecting names).
    stripped = (tok.lstrip("-") for tok in rhs)
    variables = list(
        dict.fromkeys(name for name in stripped if is_valid_variable_name(name))
    )
    if len(variables) > 2:
        raise Exception("Max 2 variables, found {}".format(variables))

    allowed_coeffs = [*variables, "", "$output_coeff"]
    if len(variables) == 1:
        # A lone variable occupies both input wires, which permits its square.
        variables = variables * 2
    if variables:
        allowed_coeffs.append(get_product_key(*variables))

    # Reject any term the gate cannot express.
    for key in coeffs:
        if key not in allowed_coeffs:
            raise Exception("Disallowed multiplication: {}".format(key))

    # variables now has either zero or two entries; pad the empty case.
    left, right = (variables + [None, None])[:2]
    return AssemblyEqn(GateWires(left, right, out), coeffs)
|