mirror of
https://github.com/darkrenaissance/darkfi.git
synced 2026-01-07 22:04:03 -05:00
reorganize research into script/research/
This commit is contained in:
@@ -1,131 +0,0 @@
|
||||
# Functions here are called from pism.py using getattr()
|
||||
# and the function name as a string.
|
||||
|
||||
def witness(line, out, point):
|
||||
return \
|
||||
r"""let %s = ecc::EdwardsPoint::witness(
|
||||
cs.namespace(|| "%s"),
|
||||
%s.map(jubjub::ExtendedPoint::from))?;""" % (out, line, point)
|
||||
|
||||
def assert_not_small_order(line, point):
|
||||
return '%s.assert_not_small_order(cs.namespace(|| "%s"))?;' % (point, line)
|
||||
|
||||
def u64_as_binary_le(line, out, val):
|
||||
return \
|
||||
r"""let %s = boolean::u64_into_boolean_vec_le(
|
||||
cs.namespace(|| "%s"),
|
||||
%s,
|
||||
)?;""" % (out, line, val)
|
||||
|
||||
def fr_as_binary_le(line, out, fr):
|
||||
return \
|
||||
r"""let %s = boolean::field_into_boolean_vec_le(
|
||||
cs.namespace(|| "%s"), %s)?;""" % (out, line, fr)
|
||||
|
||||
def ec_mul_const(line, out, fr, base):
|
||||
return \
|
||||
r"""let %s = ecc::fixed_base_multiplication(
|
||||
cs.namespace(|| "%s"),
|
||||
&%s,
|
||||
&%s,
|
||||
)?;""" % (out, line, base, fr)
|
||||
|
||||
def ec_mul(line, out, fr, base):
|
||||
return 'let %s = %s.mul(cs.namespace(|| "%s"), &%s)?;' % (
|
||||
out, base, line, fr)
|
||||
|
||||
def ec_add(line, out, a, b):
|
||||
return 'let %s = %s.add(cs.namespace(|| "%s"), &%s)?;' % (out, a, line, b)
|
||||
|
||||
def ec_repr(line, out, point):
|
||||
return 'let %s = %s.repr(cs.namespace(|| "%s"))?;' % (out, point, line)
|
||||
|
||||
def ec_get_u(line, out, point):
|
||||
return "let mut %s = %s.get_u().clone();" % (out, point)
|
||||
|
||||
def emit_ec(line, point):
|
||||
return '%s.inputize(cs.namespace(|| "%s"))?;' % (point, line)
|
||||
|
||||
def alloc_binary(line, out):
|
||||
return "let mut %s = vec![];" % out
|
||||
|
||||
def binary_clone(line, out, binary):
|
||||
return "let mut %s: Vec<_> = %s.iter().cloned().collect();" % (out, binary)
|
||||
|
||||
def binary_extend(line, binary, value):
|
||||
return "%s.extend(%s);" % (binary, value)
|
||||
|
||||
def binary_push(line, binary, bit):
|
||||
return "%s.push(%s);" % (binary, bit)
|
||||
|
||||
def binary_truncate(line, binary, size):
|
||||
return "%s.truncate(%s);" % (binary, size)
|
||||
|
||||
def static_assert_binary_size(line, binary, size):
|
||||
return "assert_eq!(%s.len(), %s);" % (binary, size)
|
||||
|
||||
def blake2s(line, out, input, personalization):
|
||||
return \
|
||||
r"""let mut %s = blake2s::blake2s(
|
||||
cs.namespace(|| "%s"),
|
||||
&%s,
|
||||
%s,
|
||||
)?;""" % (out, line, input, personalization)
|
||||
|
||||
def pedersen_hash(line, out, input, personalization):
|
||||
return \
|
||||
r"""let mut %s = pedersen_hash::pedersen_hash(
|
||||
cs.namespace(|| "%s"),
|
||||
%s,
|
||||
&%s,
|
||||
)?;""" % (out, line, personalization, input)
|
||||
|
||||
def emit_binary(line, binary):
|
||||
return 'multipack::pack_into_inputs(cs.namespace(|| "%s"), &%s)?;' % (
|
||||
line, binary)
|
||||
|
||||
def alloc_bit(line, out, value):
|
||||
return \
|
||||
r"""let %s = boolean::Boolean::from(boolean::AllocatedBit::alloc(
|
||||
cs.namespace(|| "%s"),
|
||||
%s
|
||||
)?);""" % (out, line, value)
|
||||
|
||||
def alloc_const_bit(line, out, value):
|
||||
return "let %s = Boolean::constant(%s);" % (out, value)
|
||||
|
||||
def clone_bit(line, out, value):
|
||||
return "let %s = %s.clone();" % (out, value)
|
||||
|
||||
def alloc_scalar(line, out, scalar):
|
||||
return \
|
||||
r"""let %s =
|
||||
num::AllocatedNum::alloc(cs.namespace(|| "%s"), || Ok(*%s.get()?))?;""" % (
|
||||
out, line, scalar)
|
||||
|
||||
def scalar_as_binary(line, out, scalar):
|
||||
return 'let %s = %s.to_bits_le(cs.namespace(|| "%s"))?;' % (out, scalar,
|
||||
line)
|
||||
|
||||
def emit_scalar(line, scalar):
|
||||
return '%s.inputize(cs.namespace(|| "%s"))?;' % (scalar, line)
|
||||
|
||||
def scalar_enforce_equal(line, scalar_left, scalar_right):
|
||||
return \
|
||||
r"""cs.enforce(
|
||||
|| "%s",
|
||||
|lc| lc + %s.get_variable(),
|
||||
|lc| lc + CS::one(),
|
||||
|lc| lc + %s.get_variable(),
|
||||
);""" % (line, scalar_left, scalar_right)
|
||||
|
||||
def conditionally_reverse(line, out_left, out_right, in_left, in_right,
|
||||
condition):
|
||||
return \
|
||||
r"""let (%s, %s) = num::AllocatedNum::conditionally_reverse(
|
||||
cs.namespace(|| "%s"),
|
||||
&%s,
|
||||
&%s,
|
||||
&%s,
|
||||
)?;""" % (out_left, out_right, line, in_left, in_right, condition)
|
||||
|
||||
@@ -1,460 +0,0 @@
|
||||
import argparse
|
||||
import sys
|
||||
from enum import Enum
|
||||
|
||||
alloc_commands = {
|
||||
"param": 1,
|
||||
"private": 1,
|
||||
"public": 1,
|
||||
}
|
||||
|
||||
op_commands = {
|
||||
"set": 2,
|
||||
"mul": 2,
|
||||
"add": 2,
|
||||
"sub": 2,
|
||||
"divide": 2,
|
||||
"double": 1,
|
||||
"square": 1,
|
||||
"invert": 1,
|
||||
"unpack_bits": 3,
|
||||
"load": 2,
|
||||
"local": 1,
|
||||
"debug": 1,
|
||||
"dump_alloc": 0,
|
||||
"dump_local": 0,
|
||||
}
|
||||
|
||||
constraint_commands = {
|
||||
"lc0_add": 1,
|
||||
"lc1_add": 1,
|
||||
"lc2_add": 1,
|
||||
"lc0_sub": 1,
|
||||
"lc1_sub": 1,
|
||||
"lc2_sub": 1,
|
||||
"lc0_add_one": 0,
|
||||
"lc1_add_one": 0,
|
||||
"lc2_add_one": 0,
|
||||
"lc0_sub_one": 0,
|
||||
"lc1_sub_one": 0,
|
||||
"lc2_sub_one": 0,
|
||||
"lc0_add_coeff": 2,
|
||||
"lc1_add_coeff": 2,
|
||||
"lc2_add_coeff": 2,
|
||||
"lc0_add_constant": 1,
|
||||
"lc1_add_constant": 1,
|
||||
"lc2_add_constant": 1,
|
||||
"enforce": 0,
|
||||
"lc_coeff_reset": 0,
|
||||
"lc_coeff_double": 0,
|
||||
}
|
||||
|
||||
def eprint(*args):
|
||||
print(*args, file=sys.stderr)
|
||||
|
||||
class Line:
|
||||
|
||||
def __init__(self, text, line_number):
|
||||
self.text = text
|
||||
self.orig = text
|
||||
self.lineno = line_number
|
||||
|
||||
self.clean()
|
||||
|
||||
def clean(self):
|
||||
# Remove the comments
|
||||
self.text = self.text.split("#", 1)[0]
|
||||
# Remove whitespace
|
||||
self.text = self.text.strip()
|
||||
|
||||
def is_empty(self):
|
||||
return bool(self.text)
|
||||
|
||||
def __repr__(self):
|
||||
return "Line %s: %s" % (self.lineno, self.orig.lstrip())
|
||||
|
||||
def command(self):
|
||||
if not self.is_empty():
|
||||
return None
|
||||
return self.text.split(" ")[0]
|
||||
|
||||
def args(self):
|
||||
if not self.is_empty():
|
||||
return None
|
||||
return self.text.split()[1:]
|
||||
|
||||
def clean(contents):
|
||||
# Split input into lines
|
||||
contents = contents.split("\n")
|
||||
contents = [Line(line, i + 1) for i, line in enumerate(contents)]
|
||||
# Remove empty blank lines
|
||||
contents = [line for line in contents if line.is_empty()]
|
||||
return contents
|
||||
|
||||
def divide_sections(contents):
|
||||
state = "NOSCOPE"
|
||||
segments = {}
|
||||
current_segment = []
|
||||
contract_name = None
|
||||
|
||||
for line in contents:
|
||||
if line.command() == "contract":
|
||||
if len(line.args()) != 1:
|
||||
eprint("error: missing contract name")
|
||||
eprint(line)
|
||||
return None
|
||||
contract_name = line.args()[0]
|
||||
|
||||
if state == "NOSCOPE":
|
||||
assert not current_segment
|
||||
state = "INSCOPE"
|
||||
continue
|
||||
else:
|
||||
assert state == "INSCOPE"
|
||||
eprint("error: double contract entry violation")
|
||||
eprint(line)
|
||||
return None
|
||||
elif line.command() == "end":
|
||||
if len(line.args()) != 0:
|
||||
eprint("error: end takes no args")
|
||||
eprint(line)
|
||||
return None
|
||||
|
||||
if state == "NOSCOPE":
|
||||
eprint("error: missing contract start for end")
|
||||
eprint(line)
|
||||
return None
|
||||
else:
|
||||
assert state == "INSCOPE"
|
||||
state = "NOSCOPE"
|
||||
segments[contract_name] = current_segment
|
||||
current_segment = []
|
||||
continue
|
||||
elif state == "NOSCOPE":
|
||||
# Ignore lines outside any contract
|
||||
continue
|
||||
|
||||
current_segment.append(line)
|
||||
|
||||
if state != "NOSCOPE":
|
||||
eprint("error: reached end of file with unclosed scope")
|
||||
return None
|
||||
|
||||
return segments
|
||||
|
||||
def extract_relevant_lines(contract, commands_table):
|
||||
relevant_lines = []
|
||||
|
||||
for line in contract:
|
||||
command = line.command()
|
||||
|
||||
if command not in commands_table.keys():
|
||||
continue
|
||||
|
||||
define = commands_table[command]
|
||||
|
||||
if len(line.args()) != define:
|
||||
eprint("error: wrong number of args")
|
||||
return None
|
||||
|
||||
relevant_lines.append(line)
|
||||
|
||||
return relevant_lines
|
||||
|
||||
class VariableType(Enum):
|
||||
PUBLIC = 1
|
||||
PRIVATE = 2
|
||||
|
||||
class Variable:
|
||||
|
||||
def __init__(self, symbol, index, type, is_param):
|
||||
self.symbol = symbol
|
||||
self.index = index
|
||||
self.type = type
|
||||
self.is_param = is_param
|
||||
|
||||
def __repr__(self):
|
||||
return "<Variable %s:%s>" % (self.symbol, self.index)
|
||||
|
||||
def generate_alloc_table(contract):
|
||||
relevant_lines = extract_relevant_lines(contract, alloc_commands)
|
||||
alloc_table = {}
|
||||
for i, line in enumerate(relevant_lines):
|
||||
assert len(line.args()) == 1
|
||||
symbol = line.args()[0]
|
||||
|
||||
command = line.command()
|
||||
|
||||
if command == "param":
|
||||
type = VariableType.PRIVATE
|
||||
is_param = True
|
||||
elif command == "private":
|
||||
type = VariableType.PRIVATE
|
||||
is_param = False
|
||||
elif command == "public":
|
||||
type = VariableType.PUBLIC
|
||||
is_param = False
|
||||
else:
|
||||
assert False
|
||||
|
||||
if symbol in alloc_table:
|
||||
eprint("error: duplicate symbol '%s'" % symbol)
|
||||
eprint(line)
|
||||
return None
|
||||
|
||||
alloc_table[symbol] = Variable(symbol, i, type, is_param)
|
||||
|
||||
return alloc_table
|
||||
|
||||
class Operation:
|
||||
|
||||
def __init__(self, line, indexes):
|
||||
self.command = line.command()
|
||||
self.args = indexes
|
||||
self.line = line
|
||||
|
||||
class VariableRefType(Enum):
|
||||
AUX = 1
|
||||
LOCAL = 2
|
||||
CONST = 3
|
||||
|
||||
class VariableRef:
|
||||
|
||||
def __init__(self, type, index):
|
||||
self.type = type
|
||||
self.index = index
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(%s)" % (self.type.name, self.index)
|
||||
|
||||
def symbols_list_to_refs(line, alloc, local_vars, constants):
|
||||
indexes = []
|
||||
for symbol in line.args():
|
||||
if symbol in alloc:
|
||||
# Lookup variable index
|
||||
index = alloc[symbol].index
|
||||
index = VariableRef(VariableRefType.AUX, index)
|
||||
elif symbol in local_vars:
|
||||
index = local_vars[symbol]
|
||||
index = VariableRef(VariableRefType.LOCAL, index)
|
||||
elif symbol in constants:
|
||||
index = constants[symbol][0]
|
||||
index = VariableRef(VariableRefType.CONST, index)
|
||||
else:
|
||||
eprint("error: missing unallocated symbol '%s'" % symbol)
|
||||
eprint(line)
|
||||
return None
|
||||
indexes.append(index)
|
||||
return indexes
|
||||
|
||||
def generate_ops_table(contract, alloc, constants):
|
||||
relevant_lines = extract_relevant_lines(contract, op_commands)
|
||||
ops = []
|
||||
local_vars = {}
|
||||
for line in relevant_lines:
|
||||
# This is a special case which creates a new local stack value
|
||||
if line.command() == "local":
|
||||
assert len(line.args()) == 1
|
||||
symbol = line.args()[0]
|
||||
local_vars[symbol] = len(local_vars)
|
||||
indexes = []
|
||||
else:
|
||||
if (indexes := symbols_list_to_refs(line, alloc,
|
||||
local_vars, constants)) is None:
|
||||
return None
|
||||
|
||||
# Handle this here directly since only the
|
||||
# load command deals with constants
|
||||
if line.command() == "load":
|
||||
assert len(indexes) == 2
|
||||
# This is the only command which uses consts
|
||||
if indexes[1].type != VariableRefType.CONST:
|
||||
eprint("error: load command takes a const argument")
|
||||
eprint(line)
|
||||
return None
|
||||
elif any(index.type == VariableRefType.CONST for index in indexes):
|
||||
eprint("error: invalid const arg")
|
||||
eprint(line)
|
||||
return None
|
||||
|
||||
ops.append(Operation(line, indexes))
|
||||
return ops
|
||||
|
||||
class Constraint:
|
||||
|
||||
def __init__(self, line, lcargs):
|
||||
self.command = line.command()
|
||||
self.args = lcargs
|
||||
self.line = line
|
||||
|
||||
def args_comment(self):
|
||||
return ", ".join("%s" % symbol for symbol in self.line.args())
|
||||
|
||||
def symbols_list_to_lcargs(line, alloc, constants):
|
||||
lcargs = []
|
||||
for symbol in line.args():
|
||||
if symbol in alloc:
|
||||
# Lookup variable index
|
||||
index = alloc[symbol].index
|
||||
lcargs.append(index)
|
||||
elif symbol in constants:
|
||||
value = constants[symbol]
|
||||
lcargs.append(value)
|
||||
else:
|
||||
eprint("error: missing unallocated symbol '%s'" % symbol)
|
||||
eprint(line)
|
||||
return None
|
||||
return lcargs
|
||||
|
||||
def generate_constraints_table(contract, alloc, constants):
|
||||
relevant_lines = extract_relevant_lines(contract, constraint_commands)
|
||||
constraints = []
|
||||
for line in relevant_lines:
|
||||
if (lcargs := symbols_list_to_lcargs(line, alloc, constants)) is None:
|
||||
return None
|
||||
constraints.append(Constraint(line, lcargs))
|
||||
return constraints
|
||||
|
||||
class Contract:
|
||||
|
||||
def __init__(self, constants, alloc, ops, constraints):
|
||||
self.constants = constants
|
||||
self.alloc = alloc
|
||||
self.ops = ops
|
||||
self.constraints = constraints
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = ""
|
||||
|
||||
repr_str += "Constants:\n"
|
||||
for symbol, value in self.constants.items():
|
||||
repr_str += " // %s\n" % symbol
|
||||
repr_str += " %s: %s\n" % value
|
||||
|
||||
repr_str += "Alloc table:\n"
|
||||
for symbol, variable in self.alloc.items():
|
||||
repr_str += " // %s\n" % symbol
|
||||
repr_str += " %s %s\n" % (variable.type, variable.index)
|
||||
|
||||
repr_str += "Operations:\n"
|
||||
for op in self.ops:
|
||||
repr_str += " // %s\n" % op.line
|
||||
repr_str += " %s %s\n" % (op.command, op.args)
|
||||
|
||||
repr_str += "Constraints:\n"
|
||||
for constraint in self.constraints:
|
||||
if constraint.args:
|
||||
repr_str += " // %s\n" % constraint.args_comment()
|
||||
repr_str += " %s %s\n" % (constraint.command, constraint.args)
|
||||
|
||||
repr_str += "Stats:\n"
|
||||
repr_str += " Constants: %s\n" % len(self.constants)
|
||||
repr_str += " Alloc: %s\n" % len(self.alloc)
|
||||
repr_str += " Operations: %s\n" % len(self.ops)
|
||||
repr_str += " Constraint Instructions: %s\n" % len(self.constraints)
|
||||
|
||||
return repr_str
|
||||
|
||||
def compile(contract, constants):
|
||||
# Allocation table
|
||||
# symbol: Private/Public, is_param, index
|
||||
if (alloc := generate_alloc_table(contract)) is None:
|
||||
return None
|
||||
# Operations lines list
|
||||
if (ops := generate_ops_table(contract, alloc, constants)) is None:
|
||||
return None
|
||||
# Constraint commands
|
||||
if (constraints := generate_constraints_table(
|
||||
contract, alloc, constants)) is None:
|
||||
return None
|
||||
return Contract(constants, alloc, ops, constraints)
|
||||
|
||||
def parse_constants(contents):
|
||||
relevant_lines = [line for line in contents if line.command() == "constant"]
|
||||
constants = {}
|
||||
for line in relevant_lines:
|
||||
assert line.command() == "constant"
|
||||
if len(line.args()) != 2:
|
||||
eprint("error: wrong number of args for constant")
|
||||
eprint(line)
|
||||
return None
|
||||
symbol, value = line.args()
|
||||
|
||||
try:
|
||||
int(value, 16)
|
||||
except ValueError:
|
||||
eprint("error: invalid constant value for '%s'" % symbol)
|
||||
eprint(line)
|
||||
return None
|
||||
|
||||
if len(value) != 32*2 + 2 or value[:2] != "0x":
|
||||
eprint("error: invalid hex value for constant")
|
||||
eprint(line)
|
||||
return None
|
||||
|
||||
# Remove 0x prefix
|
||||
value = value[2:]
|
||||
|
||||
constants[symbol] = (len(constants), value)
|
||||
return constants
|
||||
|
||||
def process(contents):
|
||||
# Remove left whitespace
|
||||
contents = clean(contents)
|
||||
# Parse all constants
|
||||
if (constants := parse_constants(contents)) is None:
|
||||
return None
|
||||
# Divide into contract sections
|
||||
if (pre_contracts := divide_sections(contents)) is None:
|
||||
return None
|
||||
# Process each contract
|
||||
contracts = {}
|
||||
for contract_name, pre_contract in pre_contracts.items():
|
||||
if (contract := compile(pre_contract, constants)) is None:
|
||||
return None
|
||||
contracts[contract_name] = contract
|
||||
return contracts
|
||||
|
||||
def main(argv):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("filename", help="VM PISM file: proofs/vm.pism")
|
||||
parser.add_argument("--output", type=argparse.FileType('wb', 0),
|
||||
default=sys.stdout.buffer, help="Output file")
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument('--display', action='store_true',
|
||||
help="show the compiled code in human readable format")
|
||||
group.add_argument('--rust', action='store_true',
|
||||
help="output compiled code to rust for testing")
|
||||
group.add_argument('--supervisor', action='store_true',
|
||||
help="output compiled code to zkvm supervisor")
|
||||
args = parser.parse_args()
|
||||
|
||||
src_filename = args.filename
|
||||
contents = open(src_filename).read()
|
||||
if (contracts := process(contents)) is None:
|
||||
return -2
|
||||
|
||||
def default_display():
|
||||
for contract_name, contract in contracts.items():
|
||||
print("Contract:", contract_name)
|
||||
print(contract)
|
||||
|
||||
if args.display:
|
||||
default_display()
|
||||
elif args.rust:
|
||||
import compile_export_rust
|
||||
for contract_name, contract in contracts.items():
|
||||
compile_export_rust.display(contract)
|
||||
elif args.supervisor:
|
||||
import compile_export_supervisor
|
||||
for contract_name, contract in contracts.items():
|
||||
compile_export_supervisor.export(args.output, contract_name,
|
||||
contract)
|
||||
else:
|
||||
default_display()
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main(sys.argv))
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
from compile import VariableType, VariableRefType
|
||||
|
||||
def to_initial_caps(snake_str):
|
||||
components = snake_str.split("_")
|
||||
return "".join(x.title() for x in components)
|
||||
|
||||
def display(contract):
|
||||
indent = " " * 4
|
||||
|
||||
print(r"""use super::vm::{ZkVirtualMachine, CryptoOperation, AllocType, ConstraintInstruction, VariableIndex, VariableRef};
|
||||
use bls12_381::Scalar;
|
||||
|
||||
pub fn load_params(params: Vec<Scalar>) -> Vec<(VariableIndex, Scalar)> {""")
|
||||
params = [(symbol, var) for symbol, var in contract.alloc.items() if var.is_param]
|
||||
print("%sassert_eq!(params.len(), %s);" % (indent, len(params)))
|
||||
print("%slet mut result = vec![(0, Scalar::zero()); %s];" % (
|
||||
indent, len(params)))
|
||||
for i, (symbol, variable) in enumerate(params):
|
||||
assert variable.is_param
|
||||
print("%s// %s" % (indent, symbol))
|
||||
print("%sresult[%s] = (%s, params[%s]);" % (
|
||||
indent, i, variable.index, i))
|
||||
print("%sresult" % indent)
|
||||
print("}\n")
|
||||
|
||||
print(r"""pub fn load_zkvm() -> ZkVirtualMachine {
|
||||
ZkVirtualMachine {
|
||||
constants: vec![""")
|
||||
|
||||
constants = list(contract.constants.items())
|
||||
constants.sort(key=lambda obj: obj[1][0])
|
||||
constants = [(obj[0], obj[1][1]) for obj in constants]
|
||||
for symbol, value in constants:
|
||||
print("%s// %s" % (indent * 3, symbol))
|
||||
assert len(value) == 32*2
|
||||
chunk_str = lambda line, n: \
|
||||
[line[i:i + n] for i in range(0, len(line), n)]
|
||||
chunks = chunk_str(value, 2)
|
||||
# Reverse the endianness
|
||||
# We allow literal numbers but rust wants little endian
|
||||
chunks = chunks[::-1]
|
||||
print("%sScalar::from_bytes(&[" % (indent * 3))
|
||||
for i in range(0, 32, 4):
|
||||
print("%s0x%s, 0x%s, 0x%s, 0x%s," % (indent * 4,
|
||||
chunks[i], chunks[i + 1], chunks[i + 2], chunks[i + 3]))
|
||||
print("%s]).unwrap()," % (indent * 3))
|
||||
|
||||
print("%s]," % (indent * 2))
|
||||
print("%salloc: vec![" % (indent * 2))
|
||||
|
||||
for symbol, variable in contract.alloc.items():
|
||||
print("%s// %s" % (indent * 3, symbol))
|
||||
|
||||
if variable.type.name == VariableType.PRIVATE.name:
|
||||
typestring = "Private"
|
||||
elif variable.type.name == VariableType.PUBLIC.name:
|
||||
typestring = "Public"
|
||||
else:
|
||||
assert False
|
||||
|
||||
print("%s(AllocType::%s, %s)," % (indent * 3, typestring,
|
||||
variable.index))
|
||||
|
||||
print("%s]," % (indent * 2))
|
||||
print("%sops: vec![" % (indent * 2))
|
||||
|
||||
def var_ref_str(var_ref):
|
||||
if var_ref.type.name == VariableRefType.AUX.name:
|
||||
return "VariableRef::Aux(%s)" % var_ref.index
|
||||
elif var_ref.type.name == VariableRefType.LOCAL.name:
|
||||
return "VariableRef::Local(%s)" % var_ref.index
|
||||
else:
|
||||
assert False
|
||||
|
||||
for op in contract.ops:
|
||||
print("%s// %s" % (indent * 3, op.line))
|
||||
args_part = ""
|
||||
if op.command == "load":
|
||||
assert len(op.args) == 2
|
||||
args_part = "(%s, %s)" % (var_ref_str(op.args[0]), op.args[1].index)
|
||||
elif op.command == "debug":
|
||||
assert len(op.args) == 1
|
||||
args_part = '(String::from("%s"), %s)' % (
|
||||
op.line, var_ref_str(op.args[0]))
|
||||
elif op.args:
|
||||
args_part = ", ".join(var_ref_str(var_ref) for var_ref in op.args)
|
||||
args_part = "(%s)" % args_part
|
||||
print("%sCryptoOperation::%s%s," % (
|
||||
indent * 3,
|
||||
to_initial_caps(op.command),
|
||||
args_part
|
||||
))
|
||||
|
||||
print("%s]," % (indent * 2))
|
||||
print("%sconstraints: vec![" % (indent * 2))
|
||||
|
||||
for constraint in contract.constraints:
|
||||
args_part = ""
|
||||
if constraint.args:
|
||||
print("%s// %s" % (indent *3, constraint.args_comment()))
|
||||
args = constraint.args[:]
|
||||
if (constraint.command == "lc0_add_coeff" or
|
||||
constraint.command == "lc1_add_coeff" or
|
||||
constraint.command == "lc2_add_coeff" or
|
||||
constraint.command == "lc0_add_one_coeff" or
|
||||
constraint.command == "lc1_add_one_coeff" or
|
||||
constraint.command == "lc2_add_one_coeff"):
|
||||
args[0] = args[0][0]
|
||||
args_part = ", ".join(str(index) for index in args)
|
||||
args_part = "(%s)" % args_part
|
||||
print("%sConstraintInstruction::%s%s," % (
|
||||
indent * 3,
|
||||
to_initial_caps(constraint.command),
|
||||
args_part
|
||||
))
|
||||
print(r""" ],
|
||||
aux: vec![],
|
||||
params: None,
|
||||
verifying_key: None,
|
||||
}
|
||||
}""")
|
||||
|
||||
@@ -1,193 +0,0 @@
|
||||
import struct
|
||||
from compile import VariableType, VariableRefType
|
||||
|
||||
class Operation:
|
||||
|
||||
def __init__(self, ident, args):
|
||||
self.ident = ident
|
||||
self.args = args
|
||||
|
||||
class ArgVarRef:
|
||||
|
||||
def __init__(self, type, index):
|
||||
self.type = type
|
||||
self.index = index
|
||||
|
||||
def bytes(self):
|
||||
return struct.pack("<BI", self.type, self.index)
|
||||
|
||||
class ArgVarIndex:
|
||||
|
||||
def __init__(self, _, index):
|
||||
self.index = index
|
||||
|
||||
def bytes(self):
|
||||
return struct.pack("<I", self.index)
|
||||
|
||||
class ArgString:
|
||||
|
||||
def __init__(self, description, index):
|
||||
self.description = description
|
||||
self.index = index
|
||||
|
||||
ops_table = {
|
||||
"set": Operation(0, [ArgVarRef, ArgVarRef]),
|
||||
"mul": Operation(1, [ArgVarRef, ArgVarRef]),
|
||||
"add": Operation(2, [ArgVarRef, ArgVarRef]),
|
||||
"sub": Operation(3, [ArgVarRef, ArgVarRef]),
|
||||
"divide": Operation(4, [ArgVarRef, ArgVarRef]),
|
||||
"double": Operation(5, [ArgVarRef]),
|
||||
"square": Operation(6, [ArgVarRef]),
|
||||
"invert": Operation(7, [ArgVarRef]),
|
||||
"unpack_bits": Operation(8, [ArgVarRef, ArgVarRef, ArgVarRef]),
|
||||
"local": Operation(9, []),
|
||||
"load": Operation(10, [ArgVarRef, ArgVarIndex]),
|
||||
"debug": Operation(11, [ArgString, ArgVarRef]),
|
||||
"dump_alloc": Operation(12, []),
|
||||
"dump_local": Operation(13, []),
|
||||
}
|
||||
|
||||
constraint_ident_map = {
|
||||
"lc0_add": 0,
|
||||
"lc1_add": 1,
|
||||
"lc2_add": 2,
|
||||
"lc0_sub": 3,
|
||||
"lc1_sub": 4,
|
||||
"lc2_sub": 5,
|
||||
"lc0_add_one": 6,
|
||||
"lc1_add_one": 7,
|
||||
"lc2_add_one": 8,
|
||||
"lc0_sub_one": 9,
|
||||
"lc1_sub_one": 10,
|
||||
"lc2_sub_one": 11,
|
||||
"lc0_add_coeff": 12,
|
||||
"lc1_add_coeff": 13,
|
||||
"lc2_add_coeff": 14,
|
||||
"lc0_add_constant": 15,
|
||||
"lc1_add_constant": 16,
|
||||
"lc2_add_constant": 17,
|
||||
"enforce": 18,
|
||||
"lc_coeff_reset": 19,
|
||||
"lc_coeff_double": 20,
|
||||
}
|
||||
|
||||
def varuint(value):
|
||||
if value <= 0xfc:
|
||||
return struct.pack("<B", value)
|
||||
elif value <= 0xffff:
|
||||
return struct.pack("<BH", 0xfd, value)
|
||||
elif value <= 0xffffffff:
|
||||
return struct.pack("<BI", 0xfe, value)
|
||||
else:
|
||||
return struct.pack("<BQ", 0xff, value)
|
||||
|
||||
def export(output, contract_name, contract):
|
||||
output.write(varuint(len(contract_name)))
|
||||
output.write(contract_name.encode())
|
||||
|
||||
constants = list(contract.constants.items())
|
||||
constants.sort(key=lambda obj: obj[1][0])
|
||||
constants = [(obj[0], obj[1][1]) for obj in constants]
|
||||
|
||||
# Constants
|
||||
output.write(varuint(len(constants)))
|
||||
for symbol, value in constants:
|
||||
print("Constant '%s' = %s" % (symbol, value))
|
||||
# Bellman uses little endian for Scalars from_bytes function
|
||||
const_bytes = bytearray.fromhex(value)[::-1]
|
||||
assert len(const_bytes) == 32
|
||||
output.write(const_bytes)
|
||||
|
||||
# Alloc
|
||||
output.write(varuint(len(contract.alloc)))
|
||||
for symbol, variable in contract.alloc.items():
|
||||
print("Alloc '%s' = (%s, %s)" % (symbol,
|
||||
variable.type.name, variable.index))
|
||||
if variable.type.name == VariableType.PRIVATE.name:
|
||||
typeval = 0
|
||||
elif variable.type.name == VariableType.PUBLIC.name:
|
||||
typeval = 1
|
||||
else:
|
||||
assert False
|
||||
alloc_bytes = struct.pack("<BI", typeval, variable.index)
|
||||
assert len(alloc_bytes) == 5
|
||||
output.write(alloc_bytes)
|
||||
|
||||
# Ops
|
||||
output.write(varuint(len(contract.ops)))
|
||||
for op in contract.ops:
|
||||
op_form = ops_table[op.command]
|
||||
output.write(struct.pack("B", op_form.ident))
|
||||
|
||||
if op.command == "debug":
|
||||
# Special case
|
||||
assert len(op.args) == 1
|
||||
line_str = str(op.line).encode()
|
||||
output.write(varuint(len(line_str)))
|
||||
output.write(line_str)
|
||||
|
||||
op_arg = op.args[0]
|
||||
if op_arg.type.name == VariableRefType.AUX.name:
|
||||
arg_type = 0
|
||||
elif op_arg.type.name == VariableRefType.LOCAL.name:
|
||||
arg_type = 1
|
||||
arg = ArgVarRef(arg_type, op_arg.index)
|
||||
output.write(arg.bytes())
|
||||
continue
|
||||
|
||||
assert len(op_form.args) == len(op.args)
|
||||
for arg_form, op_arg in zip(op_form.args, op.args):
|
||||
if op_arg.type.name == VariableRefType.AUX.name:
|
||||
arg_type = 0
|
||||
elif op_arg.type.name == VariableRefType.LOCAL.name:
|
||||
arg_type = 1
|
||||
arg = arg_form(arg_type, op_arg.index)
|
||||
output.write(arg.bytes())
|
||||
print("#", op.line)
|
||||
print("Operation", op.command,
|
||||
[(arg.type.name, arg.index) for arg in op.args])
|
||||
|
||||
# Constraints
|
||||
output.write(varuint(len(contract.constraints)))
|
||||
for constraint in contract.constraints:
|
||||
args = constraint.args[:]
|
||||
if (constraint.command == "lc0_add_coeff" or
|
||||
constraint.command == "lc1_add_coeff" or
|
||||
constraint.command == "lc2_add_coeff" or
|
||||
constraint.command == "lc0_add_constant" or
|
||||
constraint.command == "lc1_add_constant" or
|
||||
constraint.command == "lc2_add_constant"):
|
||||
args[0] = args[0][0]
|
||||
print("#", constraint.line)
|
||||
print("Constraint", constraint.command, args if args else "")
|
||||
enum_ident = constraint_ident_map[constraint.command]
|
||||
output.write(struct.pack("B", enum_ident))
|
||||
for arg in args:
|
||||
output.write(struct.pack("<I", arg))
|
||||
|
||||
# Params Map
|
||||
param_alloc = [(symbol, variable) for (symbol, variable)
|
||||
in contract.alloc.items() if variable.is_param]
|
||||
output.write(varuint(len(param_alloc)))
|
||||
for symbol, variable in param_alloc:
|
||||
assert variable.is_param
|
||||
print("Parameter '%s' = %s" % (symbol, variable.index))
|
||||
symbol = symbol.encode()
|
||||
output.write(varuint(len(symbol)))
|
||||
output.write(symbol)
|
||||
output.write(struct.pack("<I", variable.index))
|
||||
|
||||
# Public Map
|
||||
public_alloc = [(symbol, variable) for (symbol, variable)
|
||||
in contract.alloc.items()
|
||||
if variable.type.name == VariableType.PUBLIC.name]
|
||||
output.write(varuint(len(public_alloc)))
|
||||
for symbol, variable in public_alloc:
|
||||
assert not variable.is_param
|
||||
assert variable.type.name == VariableType.PUBLIC.name
|
||||
print("Public '%s' = %s" % (symbol, variable.index))
|
||||
symbol = symbol.encode()
|
||||
output.write(varuint(len(symbol)))
|
||||
output.write(symbol)
|
||||
output.write(struct.pack("<I", variable.index))
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
if exists('b:current_syntax')
|
||||
finish
|
||||
endif
|
||||
|
||||
syn keyword drkKeyword assert enforce for in def return const as let emit contract private proof
|
||||
syn keyword drkAttr mut
|
||||
syn keyword drkType BinaryNumber Point Fr SubgroupPoint EdwardsPoint Scalar EncryptedNum list Bool U64 Num Binary
|
||||
syn match drkFunction "\zs[a-zA-Z0-9_]*\ze("
|
||||
syn match drkComment "#.*$"
|
||||
syn match drkNumber '\d\+'
|
||||
syn match drkConst '[A-Z_]\{2,}[A-Z0-9_]*'
|
||||
|
||||
hi def link drkKeyword Statement
|
||||
hi def link drkAttr StorageClass
|
||||
hi def link drkType Type
|
||||
hi def link drkFunction Function
|
||||
hi def link drkComment Comment
|
||||
hi def link drkNumber Constant
|
||||
hi def link drkConst Constant
|
||||
|
||||
let b:current_syntax = "drk"
|
||||
@@ -1,2 +0,0 @@
|
||||
#!/bin/bash
|
||||
python scripts/parser.py proofs/sapling3.prf | rustfmt > proofs/sapling3.rs
|
||||
@@ -1,4 +0,0 @@
|
||||
#!/bin/bash -x
|
||||
python scripts/pism.py proofs/simple.pism | rustfmt > src/simple_circuit.rs
|
||||
cargo run --release --bin simple
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
# Dark Client
|
||||
|
||||
$ python3 -m venv env
|
||||
$ source env/bin/activate
|
||||
$ pip install -r requirements.txt
|
||||
@@ -1,89 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from util import arg_parser
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
|
||||
class DarkClient:
|
||||
# TODO: generate random ID (4 byte unsigned int) (rand range 0 - max size
|
||||
# uint32
|
||||
def __init__(self, client_session):
|
||||
self.url = "http://localhost:8000/"
|
||||
self.client_session = client_session
|
||||
self.payload = {
|
||||
"method": [],
|
||||
"params": [],
|
||||
"jsonrpc": [],
|
||||
"id": [],
|
||||
}
|
||||
|
||||
async def key_gen(self, payload):
|
||||
payload['method'] = "key_gen"
|
||||
payload['jsonrpc'] = "2.0"
|
||||
payload['id'] = "0"
|
||||
key = await self.__request(payload)
|
||||
print(key)
|
||||
|
||||
async def get_info(self, payload):
|
||||
payload['method'] = "get_info"
|
||||
payload['jsonrpc'] = "2.0"
|
||||
payload['id'] = "0"
|
||||
info = await self.__request(payload)
|
||||
print(info)
|
||||
|
||||
async def stop(self, payload):
|
||||
payload['method'] = "stop"
|
||||
payload['jsonrpc'] = "2.0"
|
||||
payload['id'] = "0"
|
||||
stop = await self.__request(payload)
|
||||
print(stop)
|
||||
|
||||
async def say_hello(self, payload):
|
||||
payload['method'] = "say_hello"
|
||||
payload['jsonrpc'] = "2.0"
|
||||
payload['id'] = "0"
|
||||
hello = await self.__request(payload)
|
||||
print(hello)
|
||||
|
||||
async def create_wallet(self, payload):
|
||||
payload['method'] = "create_wallet"
|
||||
payload['jsonrpc'] = "2.0"
|
||||
payload['id'] = "0"
|
||||
wallet = await self.__request(payload)
|
||||
print(wallet)
|
||||
|
||||
async def create_cashier_wallet(self, payload):
|
||||
payload['method'] = "create_cashier_wallet"
|
||||
payload['jsonrpc'] = "2.0"
|
||||
payload['id'] = "0"
|
||||
wallet = await self.__request(payload)
|
||||
print(wallet)
|
||||
|
||||
async def test_wallet(self, payload):
|
||||
payload['method'] = "test_wallet"
|
||||
payload['jsonrpc'] = "2.0"
|
||||
payload['id'] = "0"
|
||||
test = await self.__request(payload)
|
||||
print(test)
|
||||
|
||||
async def __request(self, payload):
|
||||
async with self.client_session.post(self.url, json=payload) as response:
|
||||
resp = await response.text()
|
||||
print(resp)
|
||||
|
||||
async def main():
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
client = DarkClient(session)
|
||||
await arg_parser(client)
|
||||
except aiohttp.ClientConnectorError as err:
|
||||
print('CONNECTION ERROR:', str(err))
|
||||
except Exception as err:
|
||||
print("ERROR: ", str(err))
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(main())
|
||||
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
aiodns==3.0.0
|
||||
aiohttp==3.7.4.post0
|
||||
async-timeout==3.0.1
|
||||
attrs==21.2.0
|
||||
brotlipy==0.7.0
|
||||
cchardet==2.1.7
|
||||
cffi==1.14.5
|
||||
chardet==4.0.0
|
||||
idna==3.2
|
||||
multidict==5.1.0
|
||||
pycares==4.0.0
|
||||
pycparser==2.20
|
||||
typing-extensions==3.10.0.0
|
||||
yarl==1.6.3
|
||||
@@ -1,53 +0,0 @@
|
||||
|
||||
import argparse
|
||||
|
||||
async def arg_parser(client):
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='drk',
|
||||
usage='%(prog)s [commands]',
|
||||
description="""DarkFi wallet command-line tool"""
|
||||
)
|
||||
|
||||
parser.add_argument('-c', '--cashier', action='store_true', help='Create a cashier wallet')
|
||||
parser.add_argument('-w', '--wallet', action='store_true', help='Create a new wallet')
|
||||
parser.add_argument('-k', '--key', action='store_true', help='Test key')
|
||||
parser.add_argument('-i', '--info', action='store_true', help='Request info from daemon')
|
||||
parser.add_argument('-hi', '--hello', action='store_true', help='Test hello')
|
||||
parser.add_argument("-s", "--stop", action='store_true', help="Send a stop signal to the daemon")
|
||||
parser.add_argument("-t", "--test", action='store_true', help="Test writing to the wallet")
|
||||
|
||||
try:
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.key:
|
||||
print("Attemping to generate a create key pair...")
|
||||
await client.key_gen(client.payload)
|
||||
|
||||
if args.wallet:
|
||||
print("Attemping to create a wallet...")
|
||||
await client.create_wallet(client.payload)
|
||||
|
||||
if args.info:
|
||||
print("Info was entered")
|
||||
await client.get_info(client.payload)
|
||||
print("Requesting daemon info...")
|
||||
|
||||
if args.stop:
|
||||
print("Stop was entered")
|
||||
await client.stop(client.payload)
|
||||
print("Sending a stop signal...")
|
||||
|
||||
if args.hello:
|
||||
print("Hello was entered")
|
||||
await client.say_hello(client.payload)
|
||||
|
||||
if args.cashier:
|
||||
print("Attempting to generate a cashier wallet...")
|
||||
await client.create_cashier_wallet(client.payload)
|
||||
|
||||
if args.test:
|
||||
print("Testing wallet write")
|
||||
await client.test_wallet(client.payload)
|
||||
|
||||
except Exception:
|
||||
raise
|
||||
@@ -1,5 +0,0 @@
|
||||
#!/bin/bash -x
|
||||
python scripts/preprocess.py proofs/bits.psm > /tmp/bits.psm
|
||||
python scripts/vm.py --rust /tmp/bits.psm > src/bits_contract.rs
|
||||
cargo run --release --bin bits
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
#!/bin/bash -x
|
||||
python scripts/preprocess.py proofs/mimc.psm > /tmp/mimc.psm
|
||||
python scripts/vm.py --rust /tmp/mimc.psm > src/zkmimc_contract.rs
|
||||
cargo run --release --bin zkmimc
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
#!/bin/bash -x
|
||||
python scripts/preprocess.py proofs/mint2.psm > /tmp/mint2.psm || exit $?
|
||||
python scripts/vm.py --rust /tmp/mint2.psm > src/mint2_contract.rs || exit $?
|
||||
cargo run --release --bin mint2
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
#!/bin/bash -x
|
||||
python scripts/preprocess.py proofs/mint.pism > /tmp/mint.pism
|
||||
python scripts/pism.py /tmp/mint.pism proofs/mint.aux | rustfmt > src/mint_contract.rs
|
||||
cargo run --release --bin mint
|
||||
@@ -1,4 +0,0 @@
|
||||
#!/bin/bash -x
|
||||
python scripts/preprocess.py proofs/spend.pism > /tmp/spend.pism
|
||||
python scripts/pism.py /tmp/spend.pism proofs/mint.aux | rustfmt > src/spend_contract.rs
|
||||
cargo run --release --bin spend
|
||||
@@ -1,5 +0,0 @@
|
||||
#!/bin/bash -x
|
||||
python scripts/preprocess.py proofs/jubjub.pism > /tmp/jubjub.pism
|
||||
python scripts/vm.py --rust /tmp/jubjub.pism > src/vm_load.rs
|
||||
cargo run --release --bin vmtest
|
||||
|
||||
@@ -1,835 +0,0 @@
|
||||
import lark
|
||||
import pprint
|
||||
import re
|
||||
import sys
|
||||
|
||||
class LineDesc:
|
||||
|
||||
def __init__(self, level, text, lineno):
|
||||
self.level = level
|
||||
self.text = text
|
||||
self.lineno = lineno
|
||||
assert self.text[0] != ' '
|
||||
|
||||
def __repr__(self):
|
||||
return "<%s:'%s'>" % (self.level, self.text)
|
||||
|
||||
def clean_line(line):
|
||||
lead_spaces = len(line) - len(line.lstrip(" "))
|
||||
level = lead_spaces / 4
|
||||
# Remove leading spaces
|
||||
if line.strip(" ") == "":
|
||||
return None
|
||||
line = line.lstrip(" ")
|
||||
# Remove all comments
|
||||
line = re.sub('#.*$', '', line).strip()
|
||||
if not line:
|
||||
return None
|
||||
return level, line
|
||||
|
||||
def parse(text):
|
||||
lines = text.split("\n")
|
||||
|
||||
linedescs = []
|
||||
|
||||
# These are to join open parenthesis
|
||||
current_line = ""
|
||||
paren_level = 0
|
||||
|
||||
for lineno, line in enumerate(lines):
|
||||
if (lineinfo := clean_line(line)) is None:
|
||||
continue
|
||||
level, line = lineinfo
|
||||
|
||||
for c in line:
|
||||
if c == "(":
|
||||
paren_level += 1
|
||||
elif c == ")":
|
||||
paren_level -= 1
|
||||
|
||||
#print(level, paren_level, current_line)
|
||||
|
||||
if paren_level < 0:
|
||||
print("error: too many closing paren )", file=sys.stderr)
|
||||
print("line:", lineno)
|
||||
return
|
||||
|
||||
if current_line:
|
||||
current_line += " " + line
|
||||
else:
|
||||
current_line = line
|
||||
|
||||
if paren_level > 0:
|
||||
continue
|
||||
|
||||
#print(level, current_line)
|
||||
|
||||
ldesc = LineDesc(level, current_line, lineno)
|
||||
linedescs.append(ldesc)
|
||||
|
||||
current_line = ""
|
||||
|
||||
if paren_level > 0:
|
||||
print("error: missing closing paren )", file=sys.stderr)
|
||||
return None
|
||||
|
||||
return linedescs
|
||||
|
||||
def section(linedescs):
|
||||
sections = []
|
||||
|
||||
current_section = None
|
||||
for desc in linedescs:
|
||||
if desc.level == 0:
|
||||
if current_section:
|
||||
sections.append(current_section)
|
||||
current_section = [desc]
|
||||
continue
|
||||
|
||||
current_section.append(desc)
|
||||
sections.append(current_section)
|
||||
|
||||
return sections
|
||||
|
||||
def classify(sections):
|
||||
consts = []
|
||||
funcs = []
|
||||
contracts = []
|
||||
|
||||
for section in sections:
|
||||
assert len(section)
|
||||
|
||||
if section[0].text == "const:":
|
||||
consts.append(section)
|
||||
elif section[0].text.startswith("def"):
|
||||
funcs.append(section)
|
||||
elif section[0].text.startswith("contract"):
|
||||
contracts.append(section)
|
||||
|
||||
return consts, funcs, contracts
|
||||
|
||||
def tokenize_const(text):
|
||||
parser = lark.Lark(r"""
|
||||
value_map: name ":" type_def
|
||||
|
||||
name: NAME
|
||||
|
||||
?type_def: point
|
||||
| blake2s_personalization
|
||||
| pedersen_personalization
|
||||
| list
|
||||
|
||||
point: "Point"
|
||||
|
||||
blake2s_personalization: "Blake2sPersonalization"
|
||||
|
||||
pedersen_personalization: "PedersenPersonalization"
|
||||
|
||||
list: "list<" type_def ">"
|
||||
|
||||
%import common.CNAME -> NAME
|
||||
%import common.WS
|
||||
%ignore WS
|
||||
""", start="value_map")
|
||||
return parser.parse(text)
|
||||
|
||||
class ConstTransformer(lark.Transformer):
|
||||
def name(self, name):
|
||||
return str(name[0])
|
||||
|
||||
def point(self, _):
|
||||
return "Point"
|
||||
def blake2s_personalization(self, _):
|
||||
return "Blake2sPersonalization"
|
||||
def pedersen_personalization(self, _):
|
||||
return "PedersenPersonalization"
|
||||
value_map = tuple
|
||||
list = list
|
||||
|
||||
def read_consts(consts):
|
||||
consts_map = {}
|
||||
|
||||
for subsection in consts:
|
||||
assert subsection[0].text == "const:"
|
||||
|
||||
for ldesc in subsection[1:]:
|
||||
tree = tokenize_const(ldesc.text)
|
||||
tokens = ConstTransformer().transform(tree)
|
||||
#print(tokens)
|
||||
name, typedesc = tokens
|
||||
consts_map[name] = typedesc
|
||||
|
||||
#pprint.pprint(consts_map)
|
||||
return consts_map
|
||||
|
||||
class FuncDefTransformer(lark.Transformer):
|
||||
def func_name(self, name):
|
||||
return str(name[0])
|
||||
|
||||
def param(self, obj):
|
||||
return tuple(obj)
|
||||
|
||||
def param_name(self, name):
|
||||
return str(name[0])
|
||||
|
||||
def u64(self, _):
|
||||
return "U64"
|
||||
def scalar(self, _):
|
||||
return "Scalar"
|
||||
def point(self, _):
|
||||
return "Point"
|
||||
def binary(self, _):
|
||||
return "Binary"
|
||||
|
||||
def type(self, obj):
|
||||
return obj[0]
|
||||
|
||||
func_def = list
|
||||
params = list
|
||||
type_list = list
|
||||
|
||||
def parse_func_def(text):
|
||||
parser = lark.Lark(r"""
|
||||
func_def: "def" func_name "(" params+ ")" "->" type_list ":"
|
||||
|
||||
func_name: NAME
|
||||
params: param ("," param)*
|
||||
|
||||
type_list: type
|
||||
| "(" type ("," type)* ")"
|
||||
|
||||
param: param_name ":" type
|
||||
param_name: NAME
|
||||
|
||||
type: u64 | scalar | point | binary
|
||||
|
||||
u64: "U64"
|
||||
scalar: "Scalar"
|
||||
point: "Point"
|
||||
binary: "Binary"
|
||||
|
||||
%import common.CNAME -> NAME
|
||||
%import common.WS
|
||||
%ignore WS
|
||||
""", start="func_def")
|
||||
tree = parser.parse(text)
|
||||
tokens = FuncDefTransformer().transform(tree)
|
||||
assert len(tokens) == 3
|
||||
return tokens
|
||||
|
||||
def compile_func_header(func_def):
|
||||
func_name, params, retvals = func_def
|
||||
#print("Function:", func_name)
|
||||
#print("Params:", params)
|
||||
#print("Return values:", retvals)
|
||||
#print()
|
||||
|
||||
param_str = ""
|
||||
for param, type in params:
|
||||
if param_str:
|
||||
param_str += ", "
|
||||
param_str += param + ": Option<"
|
||||
if type == "U64":
|
||||
param_str += "u64"
|
||||
elif type == "Scalar":
|
||||
param_str += "jubjub::Fr"
|
||||
elif type == "Point":
|
||||
param_str += "jubjub::SubgroupPoint"
|
||||
else:
|
||||
print("error: unsupported param type", file=sys.stderr)
|
||||
print("line:", line.text, "line:", line.lineno)
|
||||
return None
|
||||
param_str += ">"
|
||||
|
||||
converted_retvals = []
|
||||
for type in retvals:
|
||||
if type == "Binary":
|
||||
converted_retvals.append("boolean::Boolean")
|
||||
else:
|
||||
print("error: unsupported return type", file=sys.stderr)
|
||||
print("line:", line.text, "line:", line.lineno)
|
||||
return None
|
||||
retvals = converted_retvals
|
||||
|
||||
if len(retvals) == 1:
|
||||
retstr = retvals[0]
|
||||
else:
|
||||
retstr = "(" + ", ".join(retvals) + ")"
|
||||
|
||||
header = r"""fn %s<CS>(
|
||||
mut cs: CS,
|
||||
%s
|
||||
) -> Result<%s, SynthesisError>
|
||||
where
|
||||
CS: ConstraintSystem<bls12_381::Scalar>,
|
||||
{
|
||||
""" % (func_name, param_str, retstr)
|
||||
return header
|
||||
|
||||
def as_expr(line, stack, consts, expr, code):
|
||||
var_from, type_to = expr.children
|
||||
if var_from not in stack:
|
||||
print("error: variable from not in stack frame:", var_from,
|
||||
file=sys.stderr)
|
||||
print("line:", line.text, "line:", line.lineno)
|
||||
return None
|
||||
|
||||
type_from = stack[var_from]
|
||||
|
||||
if type_from == "U64" and type_to == "Binary":
|
||||
code += "boolean::u64_into_boolean_vec_le(" + \
|
||||
"cs.namespace(|| \"" + line.text + "\"), " + var_from + \
|
||||
")?;"
|
||||
elif type_from == "Scalar" and type_to == "Binary":
|
||||
code += "boolean::field_into_boolean_vec_le(" + \
|
||||
"cs.namespace(|| \"" + line.text + "\"), " + var_from + \
|
||||
")?;"
|
||||
else:
|
||||
print("error: unknown type conversion!", file=sys.stderr)
|
||||
print("line:", line.text, "line:", line.lineno)
|
||||
return None
|
||||
|
||||
#print(var_from, type_from, type_to)
|
||||
return code, type_to
|
||||
|
||||
def mul_expr(line, stack, consts, expr, code):
|
||||
var_a, var_b = expr.children
|
||||
#print("MUL", var_a, var_b)
|
||||
|
||||
if var_b not in consts:
|
||||
print("error: unknown base!", file=sys.stderr)
|
||||
print("line:", line.text, "line:", line.lineno)
|
||||
return None
|
||||
|
||||
base_type = consts[var_b]
|
||||
if base_type != "Point":
|
||||
print("error: unknown base type!", file=sys.stderr)
|
||||
print("line:", line.text, "line:", line.lineno)
|
||||
return None
|
||||
|
||||
code += "ecc::fixed_base_multiplication(" + \
|
||||
"cs.namespace(|| \"" + line.text + "\"), &" + var_b + \
|
||||
", &" + var_a + ")?;"
|
||||
|
||||
return code, base_type
|
||||
|
||||
def add_expr(line, stack, consts, expr, code):
|
||||
var_a, var_b = expr.children
|
||||
|
||||
if var_a not in stack or var_b not in stack:
|
||||
print("error: missing stack item!", file=sys.stderr)
|
||||
print("line:", line.text, "line:", line.lineno)
|
||||
return None
|
||||
|
||||
result_type = stack[var_a]
|
||||
if stack[var_b] != result_type:
|
||||
print("error: non matching items for addition!", file=sys.stderr)
|
||||
print("line:", line.text, "line:", line.lineno)
|
||||
return None
|
||||
|
||||
code += var_a + ".add(cs.namespace(|| \"" + line.text \
|
||||
+ "\"), &" + var_b + ")?;"
|
||||
|
||||
return (code, result_type)
|
||||
|
||||
def compile_let(line, stack, consts, statement):
|
||||
is_mutable = False
|
||||
if statement[0] == "mut":
|
||||
is_mutable = True
|
||||
statement = statement[1:]
|
||||
variable_name, variable_type = statement[0], statement[1]
|
||||
expr = statement[2]
|
||||
#print("LET", is_mutable, variable_name, variable_type)
|
||||
#print(" ", expr)
|
||||
code = "let " + ("mut " if is_mutable else "") + variable_name + " = "
|
||||
|
||||
if expr.data == "as_expr":
|
||||
ceval = as_expr(line, stack, consts, expr, code)
|
||||
elif expr.data == "mul_expr":
|
||||
ceval = mul_expr(line, stack, consts, expr, code)
|
||||
elif expr.data == "add_expr":
|
||||
ceval = add_expr(line, stack, consts, expr, code)
|
||||
|
||||
if ceval is None:
|
||||
return None
|
||||
|
||||
code, type_to = ceval
|
||||
|
||||
if variable_type != type_to:
|
||||
print("error: sub expr does not evaluate to correct type",
|
||||
file=sys.stderr)
|
||||
print("line:", line.text, "line:", line.lineno)
|
||||
return None
|
||||
|
||||
stack[variable_name] = variable_type
|
||||
|
||||
return code
|
||||
|
||||
def interpret_func(func, consts):
|
||||
func_def = parse_func_def(func[0].text)
|
||||
|
||||
header = compile_func_header(func_def)
|
||||
if header is None:
|
||||
return
|
||||
|
||||
subroutine = header
|
||||
indent = " " * 4
|
||||
|
||||
stack = dict(func_def[1])
|
||||
emitted_types = []
|
||||
for line in func[1:]:
|
||||
statement_type, statement = interpret_func_line(line.text, stack, consts)
|
||||
if statement_type == "let":
|
||||
code = compile_let(line, stack, consts, statement)
|
||||
if code is None:
|
||||
return
|
||||
subroutine += indent + code + "\n"
|
||||
|
||||
elif statement_type == "return":
|
||||
for var in statement:
|
||||
if var not in stack:
|
||||
print("error: missing variable in stack!", file=sys.stderr)
|
||||
print("line:", line.text, "line:", line.lineno)
|
||||
return None
|
||||
|
||||
if len(statement) == 1:
|
||||
code = "Ok(" + statement[0] + ")"
|
||||
else:
|
||||
code = "Ok(" + ",".join(statement) + ")"
|
||||
subroutine += indent + code + "\n"
|
||||
|
||||
elif statement_type == "emit":
|
||||
assert len(statement) == 1
|
||||
variable = statement[0]
|
||||
if variable not in stack:
|
||||
print("error: missing variable in stack!", file=sys.stderr)
|
||||
print("line:", line.text, "line:", line.lineno)
|
||||
return None
|
||||
|
||||
variable_type = stack[variable]
|
||||
|
||||
if variable_type == "Point":
|
||||
code = variable + ".inputize(cs.namespace(|| \"" + \
|
||||
line.text + "\"))?;"
|
||||
else:
|
||||
print("error: unable to inputize type!", file=sys.stderr)
|
||||
print("line:", line.text, "line:", line.lineno)
|
||||
return None
|
||||
|
||||
emitted_types.append(variable_type)
|
||||
subroutine += indent + code + "\n"
|
||||
|
||||
subroutine += "}\n\n"
|
||||
return subroutine, emitted_types, func_def
|
||||
|
||||
class CodeLineTransformer(lark.Transformer):
|
||||
def variable_name(self, name):
|
||||
return str(name[0])
|
||||
|
||||
def let_statement(self, obj):
|
||||
return ("let", obj)
|
||||
def return_statement(self, obj):
|
||||
return ("return", obj)
|
||||
def emit_statement(self, obj):
|
||||
return ("emit", obj)
|
||||
|
||||
def point(self, _):
|
||||
return "Point"
|
||||
def scalar(self, _):
|
||||
return "Scalar"
|
||||
def binary(self, _):
|
||||
return "Binary"
|
||||
def u64(self, _):
|
||||
return "U64"
|
||||
|
||||
def type(self, typename):
|
||||
return str(typename[0])
|
||||
|
||||
def mutable(self, _):
|
||||
return "mut"
|
||||
|
||||
statement = list
|
||||
|
||||
def interpret_func_line(text, stack, consts):
|
||||
parser = lark.Lark(r"""
|
||||
statement: let_statement
|
||||
| return_statement
|
||||
| emit_statement
|
||||
|
||||
let_statement: "let" [mutable] variable_name ":" type "=" expr
|
||||
mutable: "mut"
|
||||
|
||||
?expr: as_expr
|
||||
| mul_expr
|
||||
| add_expr
|
||||
|
||||
as_expr: variable_name "as" type
|
||||
mul_expr: variable_name "*" variable_name
|
||||
add_expr: variable_name "+" variable_name
|
||||
|
||||
return_statement: "return" variable_name
|
||||
| "return" variable_tuple
|
||||
variable_tuple: "(" variable_name ("," variable_name)* ")"
|
||||
|
||||
emit_statement: "emit" variable_name
|
||||
|
||||
variable_name: NAME
|
||||
type: u64 | scalar | point | binary
|
||||
|
||||
u64: "U64"
|
||||
scalar: "Scalar"
|
||||
point: "Point"
|
||||
binary: "Binary"
|
||||
|
||||
%import common.CNAME -> NAME
|
||||
%import common.WS
|
||||
%ignore WS
|
||||
""", start="statement")
|
||||
tree = parser.parse(text)
|
||||
tokens = CodeLineTransformer().transform(tree)[0]
|
||||
return tokens
|
||||
|
||||
class ContractDefTransformer(lark.Transformer):
|
||||
def contract_name(self, name):
|
||||
return str(name[0])
|
||||
|
||||
def param(self, obj):
|
||||
return tuple(obj)
|
||||
|
||||
def param_name(self, name):
|
||||
return str(name[0])
|
||||
|
||||
def u64(self, _):
|
||||
return "U64"
|
||||
def scalar(self, _):
|
||||
return "Scalar"
|
||||
def point(self, _):
|
||||
return "Point"
|
||||
def binary(self, _):
|
||||
return "Binary"
|
||||
|
||||
def type(self, obj):
|
||||
return obj[0]
|
||||
|
||||
contract_def = list
|
||||
params = list
|
||||
type_list = list
|
||||
|
||||
def parse_contract_def(text):
|
||||
parser = lark.Lark(r"""
|
||||
contract_def: "contract" contract_name "(" params+ ")" "->" type_list ":"
|
||||
|
||||
contract_name: NAME
|
||||
params: param ("," param)*
|
||||
|
||||
type_list: type
|
||||
| "(" type ("," type)* ")"
|
||||
|
||||
param: param_name ":" type
|
||||
param_name: NAME
|
||||
|
||||
type: u64 | scalar | point | binary
|
||||
|
||||
u64: "U64"
|
||||
scalar: "Scalar"
|
||||
point: "Point"
|
||||
binary: "Binary"
|
||||
|
||||
%import common.CNAME -> NAME
|
||||
%import common.WS
|
||||
%ignore WS
|
||||
""", start="contract_def")
|
||||
tree = parser.parse(text)
|
||||
tokens = ContractDefTransformer().transform(tree)
|
||||
assert len(tokens) == 3
|
||||
return tokens
|
||||
|
||||
class ContractCodeLineTransformer(lark.Transformer):
|
||||
def variable_name(self, name):
|
||||
return str(name[0])
|
||||
|
||||
def let_statement(self, obj):
|
||||
return ("let", obj)
|
||||
def return_statement(self, obj):
|
||||
return ("return", obj)
|
||||
def emit_statement(self, obj):
|
||||
return ("emit", obj)
|
||||
def method_statement(self, obj):
|
||||
return ("method", obj)
|
||||
|
||||
def point(self, _):
|
||||
return "Point"
|
||||
def scalar(self, _):
|
||||
return "Scalar"
|
||||
def binary(self, _):
|
||||
return "Binary"
|
||||
def u64(self, _):
|
||||
return "U64"
|
||||
|
||||
def type(self, typename):
|
||||
return str(typename[0])
|
||||
|
||||
def mutable(self, _):
|
||||
return "mut"
|
||||
|
||||
def function_name(self, name):
|
||||
return str(name[0])
|
||||
|
||||
statement = list
|
||||
variable_assign = list
|
||||
variable_decl = tuple
|
||||
|
||||
def interpret_contract_line(text, stack, consts):
|
||||
parser = lark.Lark(r"""
|
||||
statement: let_statement
|
||||
| return_statement
|
||||
| emit_statement
|
||||
| method_statement
|
||||
|
||||
let_statement: "let" variable_assign "=" expr
|
||||
mutable: "mut"
|
||||
|
||||
variable_assign: variable_decl
|
||||
| "(" variable_decl ("," variable_decl)* ")"
|
||||
variable_decl: [mutable] variable_name ":" type
|
||||
|
||||
?expr: as_expr
|
||||
| mul_expr
|
||||
| add_expr
|
||||
| funccall_expr
|
||||
| empty_list_expr
|
||||
|
||||
as_expr: variable_name "as" type
|
||||
mul_expr: variable_name "*" variable_name
|
||||
add_expr: variable_name "+" variable_name
|
||||
funccall_expr: function_name "(" [variable_name ("," variable_name)*] ")"
|
||||
empty_list_expr: "[]"
|
||||
|
||||
return_statement: "return" variable_name
|
||||
| "return" variable_tuple
|
||||
variable_tuple: "(" variable_name ("," variable_name)* ")"
|
||||
|
||||
emit_statement: "emit" variable_name
|
||||
|
||||
method_statement: variable_name "." funccall_expr
|
||||
|
||||
variable_name: NAME
|
||||
function_name: NAME
|
||||
type: u64 | scalar | point | binary
|
||||
|
||||
u64: "U64"
|
||||
scalar: "Scalar"
|
||||
point: "Point"
|
||||
binary: "Binary"
|
||||
|
||||
%import common.CNAME -> NAME
|
||||
%import common.WS
|
||||
%ignore WS
|
||||
""", start="statement")
|
||||
tree = parser.parse(text)
|
||||
tokens = ContractCodeLineTransformer().transform(tree)[0]
|
||||
return tokens
|
||||
|
||||
def to_initial_caps(snake_str):
|
||||
components = snake_str.split("_")
|
||||
return "".join(x.title() for x in components)
|
||||
|
||||
def create_contract_header(contract_def):
|
||||
contract_name, params, retvals = contract_def
|
||||
contract_name = to_initial_caps(contract_name)
|
||||
|
||||
header = "pub struct %s {\n" % contract_name
|
||||
|
||||
for param_name, param_type in params:
|
||||
if param_type == "U64":
|
||||
param_type = "u64"
|
||||
elif param_type == "Scalar":
|
||||
param_type = "jubjub::Fr"
|
||||
elif param_type == "Point":
|
||||
param_type = "jubjub::SubgroupPoint"
|
||||
header += " " * 4 + "pub %s: Option<%s>,\n" % (param_name, param_type)
|
||||
|
||||
header += "}\n\n"
|
||||
|
||||
header += r"""impl Circuit<bls12_381::Scalar> for %s {
|
||||
fn synthesize<CS: ConstraintSystem<bls12_381::Scalar>>(
|
||||
self,
|
||||
cs: &mut CS,
|
||||
) -> Result<(), SynthesisError> {
|
||||
""" % contract_name
|
||||
|
||||
return header
|
||||
|
||||
# Worst code ever
|
||||
def compile_let2(line, stack, consts, funcs, selfvars, statement):
|
||||
lhs = []
|
||||
for variable_decl in statement[0]:
|
||||
assert len(variable_decl) == 2 or \
|
||||
(len(variable_decl) == 3 and variable_decl[0] == "mut")
|
||||
|
||||
if len(variable_decl) == 2:
|
||||
mutable = False
|
||||
elif len(variable_decl) == 3:
|
||||
assert variable_decl[0] == "mut"
|
||||
mutable = True
|
||||
variable_decl = variable_decl[1:]
|
||||
#else:
|
||||
# Error!
|
||||
|
||||
lhs.append(list(variable_decl) + [mutable])
|
||||
|
||||
variable_types = []
|
||||
|
||||
code = "let "
|
||||
if len(lhs) == 1:
|
||||
name, type, is_mutable = lhs[0]
|
||||
variable_types.append(type)
|
||||
code += ("mut " if is_mutable else "") + name
|
||||
else:
|
||||
code += "("
|
||||
start = True
|
||||
for name, type, is_mutable in lhs:
|
||||
if not start:
|
||||
code += ", "
|
||||
start = False
|
||||
|
||||
code += name
|
||||
|
||||
variable_types.append(type)
|
||||
code += ")"
|
||||
code += " = "
|
||||
|
||||
expr = statement[1]
|
||||
expr_type = expr.data
|
||||
expr = expr.children
|
||||
if expr_type == "funccall_expr":
|
||||
ceval = funccall_expr(line, stack, consts, funcs, selfvars, expr, code)
|
||||
elif expr_type == "empty_list_expr":
|
||||
ceval = code + "vec![];", ["Binary"]
|
||||
#code = "let " + ("mut " if is_mutable else "") + variable_name + " = "
|
||||
|
||||
if ceval is None:
|
||||
return None
|
||||
|
||||
code, types_to = ceval
|
||||
|
||||
if variable_types != types_to:
|
||||
print("error: sub expr does not evaluate to correct type",
|
||||
file=sys.stderr)
|
||||
print("line:", line.text, "line:", line.lineno)
|
||||
return None
|
||||
|
||||
for name, type, _ in lhs:
|
||||
stack[name] = type
|
||||
|
||||
return code
|
||||
|
||||
def funccall_expr(line, stack, consts, funcs, selfvars, expr, code):
|
||||
func_name, arguments = expr[0], expr[1:]
|
||||
|
||||
if func_name not in funcs:
|
||||
print("error: non-existant function call",
|
||||
file=sys.stderr)
|
||||
print("line:", line.text, "line:", line.lineno)
|
||||
return None
|
||||
|
||||
arguments = [("self." + arg if arg in selfvars else arg)
|
||||
for arg in arguments]
|
||||
|
||||
code += "%s(cs.namespace(|| \"%s\"), %s)?;" % (
|
||||
func_name, line.text, ", ".join(arguments))
|
||||
|
||||
return_type = funcs[func_name][-1][-1]
|
||||
return code, return_type
|
||||
|
||||
def compile_method_call(line, stack, consts, funcs, selfvars, statement):
|
||||
variable = statement[0]
|
||||
method = statement[1].children[0]
|
||||
arguments = statement[1].children[1:]
|
||||
|
||||
arguments = [("self." + arg if arg in selfvars else arg)
|
||||
for arg in arguments]
|
||||
|
||||
return "%s.%s(%s);" % (variable, method, ", ".join(arguments))
|
||||
|
||||
def interpret_contract(contract, consts, funcs):
|
||||
contract_def = parse_contract_def(contract[0].text)
|
||||
contract_code = create_contract_header(contract_def)
|
||||
|
||||
selfvars = set(varname[0] for varname in contract_def[1])
|
||||
stack = dict(contract_def[1])
|
||||
for line in contract[1:10]:
|
||||
indent = " " * 4 * int(line.level + 1)
|
||||
statement_type, statement = interpret_contract_line(line.text, stack, consts)
|
||||
#pprint.pprint(statement_type)
|
||||
if statement_type == "let":
|
||||
code = compile_let2(line, stack, consts, funcs, selfvars, statement)
|
||||
if code is None:
|
||||
return
|
||||
elif statement_type == "method":
|
||||
code = compile_method_call(line, stack, consts, funcs,
|
||||
selfvars, statement)
|
||||
if code is None:
|
||||
return
|
||||
|
||||
contract_code += indent + code + "\n"
|
||||
|
||||
contract_code += " " * 8 + "Ok(())\n"
|
||||
contract_code += " " * 4 + "}\n"
|
||||
contract_code += "}\n\n"
|
||||
|
||||
#print("-------------------------------")
|
||||
#print(contract_code)
|
||||
return contract_code, contract_def
|
||||
|
||||
def main(argv):
|
||||
if len(argv) == 1:
|
||||
print("error: missing proof file", file=sys.stderr)
|
||||
return -1
|
||||
filename = sys.argv[1]
|
||||
text = open(filename, "r").read()
|
||||
|
||||
if (linedescs := parse(text)) is None:
|
||||
return -1
|
||||
|
||||
sections = section(linedescs)
|
||||
|
||||
consts, funcs, contracts = classify(sections)
|
||||
|
||||
consts = read_consts(consts)
|
||||
|
||||
compiled_funcs = {}
|
||||
for func in funcs:
|
||||
if (compiled := interpret_func(func, consts)) is None:
|
||||
return -1
|
||||
|
||||
_, _, func_def = compiled
|
||||
func_name, _, _ = func_def
|
||||
|
||||
compiled_funcs[func_name] = compiled
|
||||
funcs = compiled_funcs
|
||||
|
||||
compiled_contracts = {}
|
||||
for contract in contracts[1:]:
|
||||
if (compiled := interpret_contract(contract, consts, funcs)) is None:
|
||||
return -1
|
||||
#print(contract)
|
||||
|
||||
_, contract_def = compiled
|
||||
contract_name, _, _ = contract_def
|
||||
|
||||
compiled_contracts[contract_name] = compiled
|
||||
contracts = compiled_contracts
|
||||
|
||||
# Concat
|
||||
output = ""
|
||||
for _, func in funcs.items():
|
||||
output += func[0]
|
||||
for _, contract in contracts.items():
|
||||
output += contract[0]
|
||||
|
||||
#print(output)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv)
|
||||
|
||||
528
scripts/pism.py
528
scripts/pism.py
@@ -1,528 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import codegen
|
||||
|
||||
symbol_table = {
|
||||
"contract": 1,
|
||||
"param": 2,
|
||||
"start": 0,
|
||||
"end": 0,
|
||||
}
|
||||
|
||||
types_map = {
|
||||
"U64": "u64",
|
||||
"Fr": "jubjub::Fr",
|
||||
"Point": "jubjub::SubgroupPoint",
|
||||
"Scalar": "bls12_381::Scalar",
|
||||
"Bool": "bool"
|
||||
}
|
||||
|
||||
feature_includes = {"G_SPEND": "use crate::crypto::merkle_node::SAPLING_COMMITMENT_TREE_DEPTH;\n"}
|
||||
|
||||
command_desc = {
|
||||
"witness": (
|
||||
("EdwardsPoint", True),
|
||||
("Point", False)
|
||||
),
|
||||
"assert_not_small_order": (
|
||||
("EdwardsPoint", False),
|
||||
),
|
||||
"u64_as_binary_le": (
|
||||
("Vec<Boolean>", True),
|
||||
("U64", False),
|
||||
),
|
||||
"fr_as_binary_le": (
|
||||
("Vec<Boolean>", True),
|
||||
("Fr", False)
|
||||
),
|
||||
"ec_mul_const": (
|
||||
("EdwardsPoint", True),
|
||||
("Vec<Boolean>", False),
|
||||
("FixedGenerator", False)
|
||||
),
|
||||
"ec_mul": (
|
||||
("EdwardsPoint", True),
|
||||
("Vec<Boolean>", False),
|
||||
("EdwardsPoint", False),
|
||||
),
|
||||
"ec_add": (
|
||||
("EdwardsPoint", True),
|
||||
("EdwardsPoint", False),
|
||||
("EdwardsPoint", False),
|
||||
),
|
||||
"ec_repr": (
|
||||
("Vec<Boolean>", True),
|
||||
("EdwardsPoint", False),
|
||||
),
|
||||
"ec_get_u": (
|
||||
("ScalarNum", True),
|
||||
("EdwardsPoint", False),
|
||||
),
|
||||
"emit_ec": (
|
||||
("EdwardsPoint", False),
|
||||
),
|
||||
"alloc_binary": (
|
||||
("Vec<Boolean>", True),
|
||||
),
|
||||
"binary_clone": (
|
||||
("Vec<Boolean>", True),
|
||||
("Vec<Boolean>", False),
|
||||
),
|
||||
"binary_extend": (
|
||||
("Vec<Boolean>", False),
|
||||
("Vec<Boolean>", False),
|
||||
),
|
||||
"binary_push": (
|
||||
("Vec<Boolean>", False),
|
||||
("Boolean", False),
|
||||
),
|
||||
"binary_truncate": (
|
||||
("Vec<Boolean>", False),
|
||||
("BinarySize", False),
|
||||
),
|
||||
"static_assert_binary_size": (
|
||||
("Vec<Boolean>", False),
|
||||
("INTEGER", False),
|
||||
),
|
||||
"blake2s": (
|
||||
("Vec<Boolean>", True),
|
||||
("Vec<Boolean>", False),
|
||||
("BlakePersonalization", False),
|
||||
),
|
||||
"pedersen_hash": (
|
||||
("EdwardsPoint", True),
|
||||
("Vec<Boolean>", False),
|
||||
("PedersenPersonalization", False),
|
||||
),
|
||||
"emit_binary": (
|
||||
("Vec<Boolean>", False),
|
||||
),
|
||||
"alloc_bit": (
|
||||
("Boolean", True),
|
||||
("Bool", False),
|
||||
),
|
||||
"alloc_const_bit": (
|
||||
("Boolean", True),
|
||||
("BOOL_CONST", False),
|
||||
),
|
||||
"clone_bit": (
|
||||
("Boolean", True),
|
||||
("Boolean", False),
|
||||
),
|
||||
"alloc_scalar": (
|
||||
("ScalarNum", True),
|
||||
("Scalar", False),
|
||||
),
|
||||
"scalar_as_binary": (
|
||||
("Vec<Boolean>", True),
|
||||
("ScalarNum", False),
|
||||
),
|
||||
"emit_scalar": (
|
||||
("ScalarNum", False),
|
||||
),
|
||||
"scalar_enforce_equal": (
|
||||
("ScalarNum", False),
|
||||
("ScalarNum", False),
|
||||
),
|
||||
"conditionally_reverse": (
|
||||
("ScalarNum", True),
|
||||
("ScalarNum", True),
|
||||
("ScalarNum", False),
|
||||
("ScalarNum", False),
|
||||
("Boolean", False),
|
||||
),
|
||||
}
|
||||
|
||||
def eprint(*args):
|
||||
print(*args, file=sys.stderr)
|
||||
|
||||
class Line:
|
||||
|
||||
def __init__(self, text, line_number):
|
||||
self.text = text
|
||||
self.orig = text
|
||||
self.lineno = line_number
|
||||
|
||||
self.clean()
|
||||
|
||||
def clean(self):
|
||||
# Remove the comments
|
||||
self.text = self.text.split("#", 1)[0]
|
||||
# Remove whitespace
|
||||
self.text = self.text.strip()
|
||||
|
||||
def is_empty(self):
|
||||
return bool(self.text)
|
||||
|
||||
def __repr__(self):
|
||||
return "Line %s: %s" % (self.lineno, self.orig.lstrip())
|
||||
|
||||
def command(self):
|
||||
if not self.is_empty():
|
||||
return None
|
||||
return self.text.split(" ")[0]
|
||||
|
||||
def args(self):
|
||||
if not self.is_empty():
|
||||
return None
|
||||
return self.text.split(" ")[1:]
|
||||
|
||||
def clean(contents):
|
||||
# Split input into lines
|
||||
contents = contents.split("\n")
|
||||
contents = [Line(line, i) for i, line in enumerate(contents)]
|
||||
# Remove empty blank lines
|
||||
contents = [line for line in contents if line.is_empty()]
|
||||
return contents
|
||||
|
||||
def make_segments(contents):
|
||||
constants = [line for line in contents if line.command() == "constant"]
|
||||
|
||||
segments = []
|
||||
current_segment = []
|
||||
for line in contents:
|
||||
if line.command() == "contract":
|
||||
current_segment = []
|
||||
|
||||
current_segment.append(line)
|
||||
|
||||
if line.command() == "end":
|
||||
segments.append(current_segment)
|
||||
current_segment = []
|
||||
|
||||
return constants, segments
|
||||
|
||||
def build_constants_table(constants):
|
||||
table = {}
|
||||
for line in constants:
|
||||
args = line.args()
|
||||
if len(args) != 2:
|
||||
eprint("error: wrong number of args")
|
||||
eprint(line)
|
||||
return None
|
||||
name, type = args
|
||||
table[name] = type
|
||||
return table
|
||||
|
||||
def extract(segment):
|
||||
assert segment
|
||||
# Does it have a declaration?
|
||||
if not segment[0].command() == "contract":
|
||||
eprint("error: missing contract declaration")
|
||||
eprint(segment[0])
|
||||
return None
|
||||
# Does it have an end?
|
||||
if not segment[-1].command() == "end":
|
||||
eprint("error: missing contract end")
|
||||
eprint(segment[-1])
|
||||
return None
|
||||
# Does it have a start?
|
||||
if not [line for line in segment if line.command() == "start"]:
|
||||
eprint("error: missing contract start")
|
||||
eprint(segment[0])
|
||||
return None
|
||||
|
||||
for line in segment:
|
||||
command, args = line.command(), line.args()
|
||||
|
||||
if command in symbol_table:
|
||||
if symbol_table[command] != len(args):
|
||||
eprint("error: wrong number of args for command '%s'" % command)
|
||||
eprint(line)
|
||||
return None
|
||||
elif command in command_desc:
|
||||
if len(command_desc[command]) != len(args):
|
||||
eprint("error: wrong number of args for command '%s'" % command)
|
||||
eprint(line)
|
||||
return None
|
||||
else:
|
||||
eprint("error: missing symbol for command '%s'" % command)
|
||||
eprint(line)
|
||||
return None
|
||||
|
||||
contract_name = segment[0].args()[0]
|
||||
|
||||
start_index = [index for index, line in enumerate(segment)
|
||||
if line.command() == "start"]
|
||||
if len(start_index) > 1:
|
||||
eprint("error: multiple start statements in contract '%s'" %
|
||||
contract_name)
|
||||
for index in start_index:
|
||||
eprint(segment[index])
|
||||
eprint("Aborting.")
|
||||
return None
|
||||
assert len(start_index) == 1
|
||||
start_index = start_index[0]
|
||||
|
||||
header = segment[1:start_index]
|
||||
code = segment[start_index + 1:-1]
|
||||
|
||||
params = {}
|
||||
for param_decl in header:
|
||||
args = param_decl.args()
|
||||
assert len(args) == 2
|
||||
name, type = args
|
||||
params[name] = type
|
||||
|
||||
program = []
|
||||
for line in code:
|
||||
command, args = line.command(), line.args()
|
||||
program.append((command, args, line))
|
||||
|
||||
return Contract(contract_name, params, program)
|
||||
|
||||
def to_initial_caps(snake_str):
|
||||
components = snake_str.split("_")
|
||||
return "".join(x.title() for x in components)
|
||||
|
||||
class Contract:
|
||||
|
||||
def __init__(self, name, params, program):
|
||||
self.name = name
|
||||
self.params = params
|
||||
self.program = program
|
||||
|
||||
def _includes(self):
|
||||
return \
|
||||
r"""#![allow(unused_imports)]
|
||||
#![allow(unused_mut)]
|
||||
use bellman::{
|
||||
gadgets::{
|
||||
boolean,
|
||||
boolean::{AllocatedBit, Boolean},
|
||||
multipack,
|
||||
blake2s,
|
||||
num,
|
||||
Assignment,
|
||||
},
|
||||
groth16, Circuit, ConstraintSystem, SynthesisError,
|
||||
};
|
||||
use bls12_381::Bls12;
|
||||
use ff::{PrimeField, Field};
|
||||
use group::Curve;
|
||||
use zcash_proofs::circuit::{ecc, pedersen_hash};
|
||||
"""
|
||||
|
||||
def _compile_header(self):
|
||||
code = "pub struct %s {\n" % to_initial_caps(self.name)
|
||||
for param_name, param_type in self.params.items():
|
||||
try:
|
||||
mapped_type = types_map[param_type]
|
||||
except KeyError:
|
||||
return None
|
||||
code += " pub %s: Option<%s>,\n" % (param_name, mapped_type)
|
||||
code += "}\n"
|
||||
return code
|
||||
|
||||
def _compile_body(self):
|
||||
self.stack = {}
|
||||
code = "\n"
|
||||
#indent = " " * 8
|
||||
for command, args, line in self.program:
|
||||
if (code_text := self._compile_line(command, args, line)) is None:
|
||||
return None
|
||||
code += "// %s\n" % str(line)
|
||||
code += code_text + "\n\n"
|
||||
return code
|
||||
|
||||
def _preprocess_args(self, args, line):
|
||||
nargs = []
|
||||
for arg in args:
|
||||
if not arg.startswith("param:"):
|
||||
nargs.append((arg, False))
|
||||
continue
|
||||
_, argname = arg.split(":", 1)
|
||||
if argname not in self.params:
|
||||
eprint("error: non-existant param referenced")
|
||||
eprint(line)
|
||||
return None
|
||||
nargs.append((argname, True))
|
||||
return nargs
|
||||
|
||||
def type_checking(self, command, args, line):
|
||||
assert command in command_desc
|
||||
type_list = command_desc[command]
|
||||
if len(type_list) != len(args):
|
||||
eprint("error: wrong number of arguments!")
|
||||
eprint(line)
|
||||
return False
|
||||
|
||||
for (expected_type, new_val), (argname, is_param) in \
|
||||
zip(type_list, args):
|
||||
# Only type check input arguments, not output values
|
||||
if new_val:
|
||||
continue
|
||||
|
||||
if expected_type == "INTEGER" or expected_type == "BOOL_CONST":
|
||||
continue
|
||||
|
||||
if is_param:
|
||||
actual_type = self.params[argname]
|
||||
elif argname in self.constants:
|
||||
actual_type = self.constants[argname]
|
||||
else:
|
||||
# Check the stack here
|
||||
if argname not in self.stack:
|
||||
eprint("error: cannot find value '%s' on the stack!" %
|
||||
argname)
|
||||
eprint(line)
|
||||
return False
|
||||
|
||||
actual_type = self.stack[argname]
|
||||
|
||||
if expected_type != actual_type:
|
||||
eprint("error: wrong type for arg '%s'!" % argname)
|
||||
eprint(line)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _check_args(self, command, args, line):
|
||||
assert command in command_desc
|
||||
type_list = command_desc[command]
|
||||
assert len(type_list) == len(args)
|
||||
|
||||
for (expected_type, is_new_val), (arg, is_param) in zip(type_list, args):
|
||||
if is_param:
|
||||
continue
|
||||
if is_new_val:
|
||||
continue
|
||||
if arg in self.stack:
|
||||
continue
|
||||
if arg in self.constants:
|
||||
continue
|
||||
|
||||
if expected_type == "INTEGER" or expected_type == "BOOL_CONST":
|
||||
continue
|
||||
|
||||
eprint("error: cannot find '%s' in the stack" % arg)
|
||||
eprint(line)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _compile_line(self, command, args, line):
|
||||
if (args := self._preprocess_args(args, line)) is None:
|
||||
return None
|
||||
if not self.type_checking(command, args, line):
|
||||
return None
|
||||
|
||||
if not self._check_args(command, args, line):
|
||||
return None
|
||||
|
||||
self.modify_stack(command, args)
|
||||
|
||||
args = [self.carg(arg) for arg in args]
|
||||
|
||||
try:
|
||||
codegen_method = getattr(codegen, command)
|
||||
except AttributeError:
|
||||
eprint("error: missing command '%s' does not exist" % command)
|
||||
eprint(line)
|
||||
return None
|
||||
|
||||
return codegen_method(line, *args)
|
||||
|
||||
def carg(self, arg):
|
||||
argname, is_param = arg
|
||||
if is_param:
|
||||
return "self.%s" % argname
|
||||
if argname in self.rename_consts:
|
||||
return self.rename_consts[argname]
|
||||
return argname
|
||||
|
||||
def modify_stack(self, command, args):
|
||||
type_list = command_desc[command]
|
||||
assert len(type_list) == len(args)
|
||||
for (expected_type, new_val), (argname, is_param) in \
|
||||
zip(type_list, args):
|
||||
if is_param:
|
||||
assert not new_val
|
||||
continue
|
||||
|
||||
# Now apply the new values to the stack
|
||||
if new_val:
|
||||
self.stack[argname] = expected_type
|
||||
|
||||
def compile(self, constants, aux):
|
||||
self.constants = constants
|
||||
code = ""
|
||||
|
||||
code += self._includes()
|
||||
|
||||
self.rename_consts = {}
|
||||
if "constants" in aux:
|
||||
for const_name, value in aux["constants"].items():
|
||||
if "maps_to" not in value:
|
||||
eprint("error: bad aux config '%s', missing maps_to" %
|
||||
const_name)
|
||||
return None
|
||||
|
||||
if const_name in feature_includes:
|
||||
code += feature_includes[const_name]
|
||||
|
||||
mapped_type = value["maps_to"]
|
||||
self.rename_consts[const_name] = mapped_type
|
||||
|
||||
code += "\n"
|
||||
|
||||
if (header := self._compile_header()) is None:
|
||||
return None
|
||||
code += header
|
||||
|
||||
code += \
|
||||
r"""impl Circuit<bls12_381::Scalar> for %s {
|
||||
fn synthesize<CS: ConstraintSystem<bls12_381::Scalar>>(
|
||||
self,
|
||||
cs: &mut CS,
|
||||
) -> Result<(), SynthesisError> {
|
||||
""" % to_initial_caps(self.name)
|
||||
|
||||
if (body := self._compile_body()) is None:
|
||||
return None
|
||||
code += body
|
||||
code += "Ok(())\n"
|
||||
|
||||
code += " }\n"
|
||||
code += "}\n"
|
||||
|
||||
return code
|
||||
|
||||
def process(contents, aux):
|
||||
contents = clean(contents)
|
||||
constants, segments = make_segments(contents)
|
||||
if (constants := build_constants_table(constants)) is None:
|
||||
return False
|
||||
|
||||
codes = []
|
||||
for segment in segments:
|
||||
if (contract := extract(segment)) is None:
|
||||
return False
|
||||
if (code := contract.compile(constants, aux)) is None:
|
||||
return False
|
||||
codes.append(code)
|
||||
|
||||
# Success! Output finished product.
|
||||
[print(code) for code in codes]
|
||||
|
||||
return True
|
||||
|
||||
def main(argv):
|
||||
if len(argv) != 3:
|
||||
eprint("pism FILENAME AUX_FILENAME")
|
||||
return -1
|
||||
|
||||
aux_filename = argv[2]
|
||||
aux = json.loads(open(aux_filename).read())
|
||||
|
||||
src_filename = argv[1]
|
||||
contents = open(src_filename).read()
|
||||
if not process(contents, aux):
|
||||
return -2
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main(sys.argv))
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
"For autoload, add this to your VIM config:
|
||||
" VIM: .vimrc
|
||||
" NeoVIM: .config/nvim/init.vim
|
||||
"
|
||||
"autocmd BufRead *.pism call SetPismOptions()
|
||||
"function SetPismOptions()
|
||||
" set syntax=pism
|
||||
" source /home/narodnik/src/drk/scripts/pism.vim
|
||||
"endfunction
|
||||
|
||||
if exists('b:current_syntax')
|
||||
finish
|
||||
endif
|
||||
|
||||
syn keyword drkKeyword constant contract start end constraint
|
||||
"syn keyword drkAttr
|
||||
syn keyword drkType FixedGenerator BlakePersonalization PedersenPersonalization ByteSize U64 Fr Point Bool Scalar BinarySize
|
||||
syn keyword drkFunctionKeyword enforce lc0_add_one lc1_add_one lc2_add_one lc_coeff_reset lc_coeff_double lc0_sub_one lc1_sub_one lc2_sub_one dump_alloc dump_local
|
||||
syn match drkFunction "^[ ]*[a-z_0-9]* "
|
||||
syn match drkComment "#.*$"
|
||||
syn match drkNumber ' \zs\d\+\ze'
|
||||
syn match drkHexNumber ' \zs0x[a-z0-9]\+\ze'
|
||||
syn match drkConst '[A-Z_]\{2,}[A-Z0-9_]*'
|
||||
syn keyword drkBoolVal true false
|
||||
syn match drkPreproc "{%.*%}"
|
||||
syn match drkPreproc2 "{{.*}}"
|
||||
|
||||
hi def link drkKeyword Statement
|
||||
"hi def link drkAttr StorageClass
|
||||
hi def link drkPreproc PreProc
|
||||
hi def link drkPreproc2 PreProc
|
||||
hi def link drkType Type
|
||||
hi def link drkFunction Function
|
||||
hi def link drkFunctionKeyword Function
|
||||
hi def link drkComment Comment
|
||||
hi def link drkNumber Constant
|
||||
hi def link drkHexNumber Constant
|
||||
hi def link drkConst Constant
|
||||
hi def link drkBoolVal Constant
|
||||
|
||||
let b:current_syntax = "pism"
|
||||
@@ -1,20 +0,0 @@
|
||||
import os.path
|
||||
import sys
|
||||
from jinja2 import Environment, FileSystemLoader, Template
|
||||
|
||||
def main(argv):
|
||||
if len(argv) != 2:
|
||||
print("error: missing arg", file=sys.stderr)
|
||||
return -1
|
||||
|
||||
path = argv[1]
|
||||
dirname, filename = os.path.dirname(path), os.path.basename(path)
|
||||
env = Environment(loader = FileSystemLoader([dirname]))
|
||||
template = env.get_template(filename)
|
||||
print(template.render())
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main(sys.argv))
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
#!/bin/bash
|
||||
#nvim -c ":TOhtml" $1
|
||||
sed -i "s/PreProc { color: #5fd7ff; }/PreProc { color: #8f2722; }/" $1
|
||||
sed -i "s/Comment { color: #00ffff; }/Comment { color: #0055ff; }/" $1
|
||||
|
||||
Reference in New Issue
Block a user