Files
plonkathon/compiler/program.py
2023-01-27 18:17:53 -05:00

193 lines
7.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# A simple zk language, reverse-engineered to match https://zkrepl.dev/ output
from utils import *
from .assembly import *
from .utils import *
from typing import Optional, Set
from poly import Polynomial, Basis
@dataclass
class CommonPreprocessedInput:
    """Common preprocessed input (PLONK terminology) for one program.

    Holds the selector and permutation polynomials that depend only on the
    program's constraints and group order, not on any variable assignment.
    Instances are built by ``Program.common_preprocessed_input``.
    """

    # Number of rows in the circuit; as built by Program, every polynomial
    # below has this many Lagrange-basis evaluations.
    group_order: int
    # q_M(X): multiplication selector polynomial
    QM: Polynomial
    # q_L(X): left-wire selector polynomial
    QL: Polynomial
    # q_R(X): right-wire selector polynomial
    QR: Polynomial
    # q_O(X): output-wire selector polynomial
    QO: Polynomial
    # q_C(X): constant selector polynomial
    QC: Polynomial
    # S_sigma1(X): copy-constraint permutation polynomial for the left column
    S1: Polynomial
    # S_sigma2(X): copy-constraint permutation polynomial for the right column
    S2: Polynomial
    # S_sigma3(X): copy-constraint permutation polynomial for the output column
    S3: Polynomial
class Program:
    """A program in the zk language, compiled to one assembly equation per row.

    Row ``i`` of every gate and permutation polynomial built here describes
    ``self.constraints[i]``. Rows ``len(self.constraints)`` through
    ``group_order - 1`` are padding: their gate selectors are all zero and
    their cells are treated as unused by the permutation.
    """

    constraints: list[AssemblyEqn]
    group_order: int

    def __init__(self, constraints: list[str], group_order: int):
        """Parse each source constraint into assembly form.

        Args:
            constraints: one source-language constraint per entry.
            group_order: number of rows; every polynomial built by this class
                has this many Lagrange-basis evaluations.

        Raises:
            Exception: if there are more constraints than rows.
        """
        if len(constraints) > group_order:
            raise Exception("Group order too small")
        self.constraints = [eq_to_assembly(constraint) for constraint in constraints]
        self.group_order = group_order

    def common_preprocessed_input(self) -> CommonPreprocessedInput:
        """Bundle this program's selector and permutation polynomials."""
        L, R, M, O, C = self.make_gate_polynomials()
        S = self.make_s_polynomials()
        # Pass fields by keyword: make_gate_polynomials returns (L, R, M, O, C)
        # while the dataclass field order begins with QM, so positional
        # arguments are easy to misalign silently.
        return CommonPreprocessedInput(
            group_order=self.group_order,
            QM=M,
            QL=L,
            QR=R,
            QO=O,
            QC=C,
            S1=S[Column.LEFT],
            S2=S[Column.RIGHT],
            S3=S[Column.OUTPUT],
        )

    @classmethod
    def from_str(cls, constraints: str, group_order: int):
        """Build a Program from newline-separated source text.

        Each line is stripped of surrounding whitespace, and lines that are
        then empty are skipped rather than passed to ``eq_to_assembly``. This
        lets indented triple-quoted strings and text ending in a newline be
        used directly.
        """
        lines = [line.strip() for line in constraints.split("\n")]
        return cls([line for line in lines if line], group_order)

    def coeffs(self) -> list[dict[Optional[str], int]]:
        """Return each row's coefficient map, in row order."""
        return [constraint.coeffs for constraint in self.constraints]

    def wires(self) -> list[GateWires]:
        """Return each row's wires, in row order."""
        return [constraint.wires for constraint in self.constraints]

    def make_s_polynomials(self) -> dict[Column, Polynomial]:
        """Build the copy-constraint permutation polynomials.

        Returns:
            A dict mapping Column.LEFT, Column.RIGHT and Column.OUTPUT to a
            Lagrange-basis polynomial with ``group_order`` evaluations.
        """
        # For each variable, collect every (column, row) cell that holds it.
        # Unused cells (None wires and all padding rows) share the None key,
        # so they are permuted among themselves.
        variable_uses: dict[Optional[str], Set[Cell]] = {None: set()}
        for row, constraint in enumerate(self.constraints):
            for column, value in zip(Column.variants(), constraint.wires.as_list()):
                variable_uses.setdefault(value, set()).add(Cell(column, row))
        for row in range(len(self.constraints), self.group_order):
            for column in Column.variants():
                variable_uses[None].add(Cell(column, row))

        # Rotate each group of cells by one position in sorted order: the cell
        # after `cell` in its group stores `cell`'s label.
        #
        # For example, if some variable is used in positions
        # (LEFT, 4), (LEFT, 7) and (OUTPUT, 2), then we store:
        #
        #   at S[LEFT][7]   the field element representing (LEFT, 4)
        #   at S[OUTPUT][2] the field element representing (LEFT, 7)
        #   at S[LEFT][4]   the field element representing (OUTPUT, 2)
        #
        # (The example assumes cells sort by row first; the actual order comes
        # from Cell's comparison methods.)
        S_values = {
            Column.LEFT: [Scalar(0)] * self.group_order,
            Column.RIGHT: [Scalar(0)] * self.group_order,
            Column.OUTPUT: [Scalar(0)] * self.group_order,
        }
        for uses in variable_uses.values():
            cycle = sorted(uses)
            for i, cell in enumerate(cycle):
                successor = cycle[(i + 1) % len(cycle)]
                S_values[successor.column][successor.row] = cell.label(
                    self.group_order
                )

        return {
            column: Polynomial(values, Basis.LAGRANGE)
            for column, values in S_values.items()
        }

    def get_public_assignments(self) -> list[Optional[str]]:
        """Return the names of the public variables, in declaration order.

        Public declarations must form a prefix of the program.

        Raises:
            Exception: if a public declaration appears after a non-public row,
                or a public row's coefficient map is not exactly
                ``{"$public": True, "$output_coeff": 0, <name>: -1}``.
        """
        public_vars = []
        no_more_allowed = False
        for coeff in self.coeffs():
            if coeff.get("$public", False) is True:
                if no_more_allowed:
                    raise Exception("Public var declarations must be at the top")
                # The variable is the one key that is not a "$" marker; a row
                # with no such key falls through to the malformed check below.
                var_name = next((x for x in coeff if "$" not in str(x)), None)
                if coeff != {"$public": True, "$output_coeff": 0, var_name: -1}:
                    raise Exception("Malformatted coeffs: {}".format(coeff))
                public_vars.append(var_name)
            else:
                no_more_allowed = True
        return public_vars

    def make_gate_polynomials(
        self,
    ) -> tuple[Polynomial, Polynomial, Polynomial, Polynomial, Polynomial]:
        """Build the selector polynomials, returned in the order (L, R, M, O, C).

        Each is a Lagrange-basis polynomial with ``group_order`` evaluations:
        row i comes from ``self.constraints[i].gate()``, and padding rows are 0.
        """
        L = [Scalar(0) for _ in range(self.group_order)]
        R = [Scalar(0) for _ in range(self.group_order)]
        M = [Scalar(0) for _ in range(self.group_order)]
        O = [Scalar(0) for _ in range(self.group_order)]
        C = [Scalar(0) for _ in range(self.group_order)]
        for i, constraint in enumerate(self.constraints):
            gate = constraint.gate()
            L[i] = gate.L
            R[i] = gate.R
            M[i] = gate.M
            O[i] = gate.O
            C[i] = gate.C
        return (
            Polynomial(L, Basis.LAGRANGE),
            Polynomial(R, Basis.LAGRANGE),
            Polynomial(M, Basis.LAGRANGE),
            Polynomial(O, Basis.LAGRANGE),
            Polynomial(C, Basis.LAGRANGE),
        )

    def fill_variable_assignments(
        self, starting_assignments: dict[Optional[str], int]
    ) -> dict[Optional[str], int]:
        """"Run" the program to fill in intermediate variable assignments.

        For example, if ``starting_assignments`` is ``{'a': 3, 'b': 5}`` and
        the first line says ``c <== a * b``, the result also has ``c: 15``.
        Constraints are processed in row order. A constraint with no output
        wire, or whose ``$output_coeff`` is not +1/-1, is skipped: it is
        neither solved nor checked.

        Returns:
            Every known assignment as plain ints (via ``Scalar.n``), including
            the ``None`` key for empty wires (always 0).

        Raises:
            Exception: if a constraint's computed output disagrees with a value
                that is already assigned.
            KeyError: if an input wire names a variable that has not been
                assigned yet.
        """
        out = {k: Scalar(v) for k, v in starting_assignments.items()}
        # Empty (None) wires read as zero.
        out[None] = Scalar(0)
        for constraint in self.constraints:
            in_L = constraint.wires.L
            in_R = constraint.wires.R
            output = constraint.wires.O
            coeffs = constraint.coeffs
            out_coeff = coeffs.get("$output_coeff", 1)
            if output is None or out_coeff not in (-1, 1):
                continue
            product_key = get_product_key(in_L, in_R)
            # The "" key presumably holds the constant term. When both inputs
            # are the same variable, its linear coefficient is counted only
            # once (via the in_L term).
            new_value = (
                Scalar(
                    coeffs.get("", 0)
                    + out[in_L] * coeffs.get(in_L, 0)
                    + out[in_R] * coeffs.get(in_R, 0) * (1 if in_R != in_L else 0)
                    + out[in_L] * out[in_R] * coeffs.get(product_key, 0)
                )
                # Should be a division, but for +/-1 multiplying is the same.
                * out_coeff
            )
            if output in out:
                if out[output] != new_value:
                    raise Exception(
                        "Failed assertion: {} = {}".format(out[output], new_value)
                    )
            else:
                out[output] = new_value
        return {k: v.n for k, v in out.items()}