reorganize research into script/research/

This commit is contained in:
narodnik
2021-09-16 12:07:51 +02:00
parent 743366b673
commit 74f919d3cb
71 changed files with 0 additions and 2551 deletions

View File

@@ -1,131 +0,0 @@
# Functions here are called from pism.py using getattr()
# and the function name as a string.
def witness(line, out, point):
return \
r"""let %s = ecc::EdwardsPoint::witness(
cs.namespace(|| "%s"),
%s.map(jubjub::ExtendedPoint::from))?;""" % (out, line, point)
def assert_not_small_order(line, point):
return '%s.assert_not_small_order(cs.namespace(|| "%s"))?;' % (point, line)
def u64_as_binary_le(line, out, val):
return \
r"""let %s = boolean::u64_into_boolean_vec_le(
cs.namespace(|| "%s"),
%s,
)?;""" % (out, line, val)
def fr_as_binary_le(line, out, fr):
return \
r"""let %s = boolean::field_into_boolean_vec_le(
cs.namespace(|| "%s"), %s)?;""" % (out, line, fr)
def ec_mul_const(line, out, fr, base):
return \
r"""let %s = ecc::fixed_base_multiplication(
cs.namespace(|| "%s"),
&%s,
&%s,
)?;""" % (out, line, base, fr)
def ec_mul(line, out, fr, base):
return 'let %s = %s.mul(cs.namespace(|| "%s"), &%s)?;' % (
out, base, line, fr)
def ec_add(line, out, a, b):
return 'let %s = %s.add(cs.namespace(|| "%s"), &%s)?;' % (out, a, line, b)
def ec_repr(line, out, point):
return 'let %s = %s.repr(cs.namespace(|| "%s"))?;' % (out, point, line)
def ec_get_u(line, out, point):
return "let mut %s = %s.get_u().clone();" % (out, point)
def emit_ec(line, point):
return '%s.inputize(cs.namespace(|| "%s"))?;' % (point, line)
def alloc_binary(line, out):
return "let mut %s = vec![];" % out
def binary_clone(line, out, binary):
return "let mut %s: Vec<_> = %s.iter().cloned().collect();" % (out, binary)
def binary_extend(line, binary, value):
return "%s.extend(%s);" % (binary, value)
def binary_push(line, binary, bit):
return "%s.push(%s);" % (binary, bit)
def binary_truncate(line, binary, size):
return "%s.truncate(%s);" % (binary, size)
def static_assert_binary_size(line, binary, size):
return "assert_eq!(%s.len(), %s);" % (binary, size)
def blake2s(line, out, input, personalization):
return \
r"""let mut %s = blake2s::blake2s(
cs.namespace(|| "%s"),
&%s,
%s,
)?;""" % (out, line, input, personalization)
def pedersen_hash(line, out, input, personalization):
return \
r"""let mut %s = pedersen_hash::pedersen_hash(
cs.namespace(|| "%s"),
%s,
&%s,
)?;""" % (out, line, personalization, input)
def emit_binary(line, binary):
return 'multipack::pack_into_inputs(cs.namespace(|| "%s"), &%s)?;' % (
line, binary)
def alloc_bit(line, out, value):
return \
r"""let %s = boolean::Boolean::from(boolean::AllocatedBit::alloc(
cs.namespace(|| "%s"),
%s
)?);""" % (out, line, value)
def alloc_const_bit(line, out, value):
return "let %s = Boolean::constant(%s);" % (out, value)
def clone_bit(line, out, value):
return "let %s = %s.clone();" % (out, value)
def alloc_scalar(line, out, scalar):
return \
r"""let %s =
num::AllocatedNum::alloc(cs.namespace(|| "%s"), || Ok(*%s.get()?))?;""" % (
out, line, scalar)
def scalar_as_binary(line, out, scalar):
return 'let %s = %s.to_bits_le(cs.namespace(|| "%s"))?;' % (out, scalar,
line)
def emit_scalar(line, scalar):
return '%s.inputize(cs.namespace(|| "%s"))?;' % (scalar, line)
def scalar_enforce_equal(line, scalar_left, scalar_right):
return \
r"""cs.enforce(
|| "%s",
|lc| lc + %s.get_variable(),
|lc| lc + CS::one(),
|lc| lc + %s.get_variable(),
);""" % (line, scalar_left, scalar_right)
def conditionally_reverse(line, out_left, out_right, in_left, in_right,
condition):
return \
r"""let (%s, %s) = num::AllocatedNum::conditionally_reverse(
cs.namespace(|| "%s"),
&%s,
&%s,
&%s,
)?;""" % (out_left, out_right, line, in_left, in_right, condition)

View File

@@ -1,460 +0,0 @@
import argparse
import sys
from enum import Enum
alloc_commands = {
"param": 1,
"private": 1,
"public": 1,
}
op_commands = {
"set": 2,
"mul": 2,
"add": 2,
"sub": 2,
"divide": 2,
"double": 1,
"square": 1,
"invert": 1,
"unpack_bits": 3,
"load": 2,
"local": 1,
"debug": 1,
"dump_alloc": 0,
"dump_local": 0,
}
constraint_commands = {
"lc0_add": 1,
"lc1_add": 1,
"lc2_add": 1,
"lc0_sub": 1,
"lc1_sub": 1,
"lc2_sub": 1,
"lc0_add_one": 0,
"lc1_add_one": 0,
"lc2_add_one": 0,
"lc0_sub_one": 0,
"lc1_sub_one": 0,
"lc2_sub_one": 0,
"lc0_add_coeff": 2,
"lc1_add_coeff": 2,
"lc2_add_coeff": 2,
"lc0_add_constant": 1,
"lc1_add_constant": 1,
"lc2_add_constant": 1,
"enforce": 0,
"lc_coeff_reset": 0,
"lc_coeff_double": 0,
}
def eprint(*args):
print(*args, file=sys.stderr)
class Line:
def __init__(self, text, line_number):
self.text = text
self.orig = text
self.lineno = line_number
self.clean()
def clean(self):
# Remove the comments
self.text = self.text.split("#", 1)[0]
# Remove whitespace
self.text = self.text.strip()
def is_empty(self):
return bool(self.text)
def __repr__(self):
return "Line %s: %s" % (self.lineno, self.orig.lstrip())
def command(self):
if not self.is_empty():
return None
return self.text.split(" ")[0]
def args(self):
if not self.is_empty():
return None
return self.text.split()[1:]
def clean(contents):
# Split input into lines
contents = contents.split("\n")
contents = [Line(line, i + 1) for i, line in enumerate(contents)]
# Remove empty blank lines
contents = [line for line in contents if line.is_empty()]
return contents
def divide_sections(contents):
state = "NOSCOPE"
segments = {}
current_segment = []
contract_name = None
for line in contents:
if line.command() == "contract":
if len(line.args()) != 1:
eprint("error: missing contract name")
eprint(line)
return None
contract_name = line.args()[0]
if state == "NOSCOPE":
assert not current_segment
state = "INSCOPE"
continue
else:
assert state == "INSCOPE"
eprint("error: double contract entry violation")
eprint(line)
return None
elif line.command() == "end":
if len(line.args()) != 0:
eprint("error: end takes no args")
eprint(line)
return None
if state == "NOSCOPE":
eprint("error: missing contract start for end")
eprint(line)
return None
else:
assert state == "INSCOPE"
state = "NOSCOPE"
segments[contract_name] = current_segment
current_segment = []
continue
elif state == "NOSCOPE":
# Ignore lines outside any contract
continue
current_segment.append(line)
if state != "NOSCOPE":
eprint("error: reached end of file with unclosed scope")
return None
return segments
def extract_relevant_lines(contract, commands_table):
relevant_lines = []
for line in contract:
command = line.command()
if command not in commands_table.keys():
continue
define = commands_table[command]
if len(line.args()) != define:
eprint("error: wrong number of args")
return None
relevant_lines.append(line)
return relevant_lines
class VariableType(Enum):
PUBLIC = 1
PRIVATE = 2
class Variable:
def __init__(self, symbol, index, type, is_param):
self.symbol = symbol
self.index = index
self.type = type
self.is_param = is_param
def __repr__(self):
return "<Variable %s:%s>" % (self.symbol, self.index)
def generate_alloc_table(contract):
relevant_lines = extract_relevant_lines(contract, alloc_commands)
alloc_table = {}
for i, line in enumerate(relevant_lines):
assert len(line.args()) == 1
symbol = line.args()[0]
command = line.command()
if command == "param":
type = VariableType.PRIVATE
is_param = True
elif command == "private":
type = VariableType.PRIVATE
is_param = False
elif command == "public":
type = VariableType.PUBLIC
is_param = False
else:
assert False
if symbol in alloc_table:
eprint("error: duplicate symbol '%s'" % symbol)
eprint(line)
return None
alloc_table[symbol] = Variable(symbol, i, type, is_param)
return alloc_table
class Operation:
def __init__(self, line, indexes):
self.command = line.command()
self.args = indexes
self.line = line
class VariableRefType(Enum):
AUX = 1
LOCAL = 2
CONST = 3
class VariableRef:
def __init__(self, type, index):
self.type = type
self.index = index
def __repr__(self):
return "%s(%s)" % (self.type.name, self.index)
def symbols_list_to_refs(line, alloc, local_vars, constants):
indexes = []
for symbol in line.args():
if symbol in alloc:
# Lookup variable index
index = alloc[symbol].index
index = VariableRef(VariableRefType.AUX, index)
elif symbol in local_vars:
index = local_vars[symbol]
index = VariableRef(VariableRefType.LOCAL, index)
elif symbol in constants:
index = constants[symbol][0]
index = VariableRef(VariableRefType.CONST, index)
else:
eprint("error: missing unallocated symbol '%s'" % symbol)
eprint(line)
return None
indexes.append(index)
return indexes
def generate_ops_table(contract, alloc, constants):
relevant_lines = extract_relevant_lines(contract, op_commands)
ops = []
local_vars = {}
for line in relevant_lines:
# This is a special case which creates a new local stack value
if line.command() == "local":
assert len(line.args()) == 1
symbol = line.args()[0]
local_vars[symbol] = len(local_vars)
indexes = []
else:
if (indexes := symbols_list_to_refs(line, alloc,
local_vars, constants)) is None:
return None
# Handle this here directly since only the
# load command deals with constants
if line.command() == "load":
assert len(indexes) == 2
# This is the only command which uses consts
if indexes[1].type != VariableRefType.CONST:
eprint("error: load command takes a const argument")
eprint(line)
return None
elif any(index.type == VariableRefType.CONST for index in indexes):
eprint("error: invalid const arg")
eprint(line)
return None
ops.append(Operation(line, indexes))
return ops
class Constraint:
def __init__(self, line, lcargs):
self.command = line.command()
self.args = lcargs
self.line = line
def args_comment(self):
return ", ".join("%s" % symbol for symbol in self.line.args())
def symbols_list_to_lcargs(line, alloc, constants):
lcargs = []
for symbol in line.args():
if symbol in alloc:
# Lookup variable index
index = alloc[symbol].index
lcargs.append(index)
elif symbol in constants:
value = constants[symbol]
lcargs.append(value)
else:
eprint("error: missing unallocated symbol '%s'" % symbol)
eprint(line)
return None
return lcargs
def generate_constraints_table(contract, alloc, constants):
relevant_lines = extract_relevant_lines(contract, constraint_commands)
constraints = []
for line in relevant_lines:
if (lcargs := symbols_list_to_lcargs(line, alloc, constants)) is None:
return None
constraints.append(Constraint(line, lcargs))
return constraints
class Contract:
def __init__(self, constants, alloc, ops, constraints):
self.constants = constants
self.alloc = alloc
self.ops = ops
self.constraints = constraints
def __repr__(self):
repr_str = ""
repr_str += "Constants:\n"
for symbol, value in self.constants.items():
repr_str += " // %s\n" % symbol
repr_str += " %s: %s\n" % value
repr_str += "Alloc table:\n"
for symbol, variable in self.alloc.items():
repr_str += " // %s\n" % symbol
repr_str += " %s %s\n" % (variable.type, variable.index)
repr_str += "Operations:\n"
for op in self.ops:
repr_str += " // %s\n" % op.line
repr_str += " %s %s\n" % (op.command, op.args)
repr_str += "Constraints:\n"
for constraint in self.constraints:
if constraint.args:
repr_str += " // %s\n" % constraint.args_comment()
repr_str += " %s %s\n" % (constraint.command, constraint.args)
repr_str += "Stats:\n"
repr_str += " Constants: %s\n" % len(self.constants)
repr_str += " Alloc: %s\n" % len(self.alloc)
repr_str += " Operations: %s\n" % len(self.ops)
repr_str += " Constraint Instructions: %s\n" % len(self.constraints)
return repr_str
def compile(contract, constants):
# Allocation table
# symbol: Private/Public, is_param, index
if (alloc := generate_alloc_table(contract)) is None:
return None
# Operations lines list
if (ops := generate_ops_table(contract, alloc, constants)) is None:
return None
# Constraint commands
if (constraints := generate_constraints_table(
contract, alloc, constants)) is None:
return None
return Contract(constants, alloc, ops, constraints)
def parse_constants(contents):
relevant_lines = [line for line in contents if line.command() == "constant"]
constants = {}
for line in relevant_lines:
assert line.command() == "constant"
if len(line.args()) != 2:
eprint("error: wrong number of args for constant")
eprint(line)
return None
symbol, value = line.args()
try:
int(value, 16)
except ValueError:
eprint("error: invalid constant value for '%s'" % symbol)
eprint(line)
return None
if len(value) != 32*2 + 2 or value[:2] != "0x":
eprint("error: invalid hex value for constant")
eprint(line)
return None
# Remove 0x prefix
value = value[2:]
constants[symbol] = (len(constants), value)
return constants
def process(contents):
# Remove left whitespace
contents = clean(contents)
# Parse all constants
if (constants := parse_constants(contents)) is None:
return None
# Divide into contract sections
if (pre_contracts := divide_sections(contents)) is None:
return None
# Process each contract
contracts = {}
for contract_name, pre_contract in pre_contracts.items():
if (contract := compile(pre_contract, constants)) is None:
return None
contracts[contract_name] = contract
return contracts
def main(argv):
parser = argparse.ArgumentParser()
parser.add_argument("filename", help="VM PISM file: proofs/vm.pism")
parser.add_argument("--output", type=argparse.FileType('wb', 0),
default=sys.stdout.buffer, help="Output file")
group = parser.add_mutually_exclusive_group()
group.add_argument('--display', action='store_true',
help="show the compiled code in human readable format")
group.add_argument('--rust', action='store_true',
help="output compiled code to rust for testing")
group.add_argument('--supervisor', action='store_true',
help="output compiled code to zkvm supervisor")
args = parser.parse_args()
src_filename = args.filename
contents = open(src_filename).read()
if (contracts := process(contents)) is None:
return -2
def default_display():
for contract_name, contract in contracts.items():
print("Contract:", contract_name)
print(contract)
if args.display:
default_display()
elif args.rust:
import compile_export_rust
for contract_name, contract in contracts.items():
compile_export_rust.display(contract)
elif args.supervisor:
import compile_export_supervisor
for contract_name, contract in contracts.items():
compile_export_supervisor.export(args.output, contract_name,
contract)
else:
default_display()
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv))

View File

@@ -1,122 +0,0 @@
from compile import VariableType, VariableRefType
def to_initial_caps(snake_str):
components = snake_str.split("_")
return "".join(x.title() for x in components)
def display(contract):
indent = " " * 4
print(r"""use super::vm::{ZkVirtualMachine, CryptoOperation, AllocType, ConstraintInstruction, VariableIndex, VariableRef};
use bls12_381::Scalar;
pub fn load_params(params: Vec<Scalar>) -> Vec<(VariableIndex, Scalar)> {""")
params = [(symbol, var) for symbol, var in contract.alloc.items() if var.is_param]
print("%sassert_eq!(params.len(), %s);" % (indent, len(params)))
print("%slet mut result = vec![(0, Scalar::zero()); %s];" % (
indent, len(params)))
for i, (symbol, variable) in enumerate(params):
assert variable.is_param
print("%s// %s" % (indent, symbol))
print("%sresult[%s] = (%s, params[%s]);" % (
indent, i, variable.index, i))
print("%sresult" % indent)
print("}\n")
print(r"""pub fn load_zkvm() -> ZkVirtualMachine {
ZkVirtualMachine {
constants: vec![""")
constants = list(contract.constants.items())
constants.sort(key=lambda obj: obj[1][0])
constants = [(obj[0], obj[1][1]) for obj in constants]
for symbol, value in constants:
print("%s// %s" % (indent * 3, symbol))
assert len(value) == 32*2
chunk_str = lambda line, n: \
[line[i:i + n] for i in range(0, len(line), n)]
chunks = chunk_str(value, 2)
# Reverse the endianness
# We allow literal numbers but rust wants little endian
chunks = chunks[::-1]
print("%sScalar::from_bytes(&[" % (indent * 3))
for i in range(0, 32, 4):
print("%s0x%s, 0x%s, 0x%s, 0x%s," % (indent * 4,
chunks[i], chunks[i + 1], chunks[i + 2], chunks[i + 3]))
print("%s]).unwrap()," % (indent * 3))
print("%s]," % (indent * 2))
print("%salloc: vec![" % (indent * 2))
for symbol, variable in contract.alloc.items():
print("%s// %s" % (indent * 3, symbol))
if variable.type.name == VariableType.PRIVATE.name:
typestring = "Private"
elif variable.type.name == VariableType.PUBLIC.name:
typestring = "Public"
else:
assert False
print("%s(AllocType::%s, %s)," % (indent * 3, typestring,
variable.index))
print("%s]," % (indent * 2))
print("%sops: vec![" % (indent * 2))
def var_ref_str(var_ref):
if var_ref.type.name == VariableRefType.AUX.name:
return "VariableRef::Aux(%s)" % var_ref.index
elif var_ref.type.name == VariableRefType.LOCAL.name:
return "VariableRef::Local(%s)" % var_ref.index
else:
assert False
for op in contract.ops:
print("%s// %s" % (indent * 3, op.line))
args_part = ""
if op.command == "load":
assert len(op.args) == 2
args_part = "(%s, %s)" % (var_ref_str(op.args[0]), op.args[1].index)
elif op.command == "debug":
assert len(op.args) == 1
args_part = '(String::from("%s"), %s)' % (
op.line, var_ref_str(op.args[0]))
elif op.args:
args_part = ", ".join(var_ref_str(var_ref) for var_ref in op.args)
args_part = "(%s)" % args_part
print("%sCryptoOperation::%s%s," % (
indent * 3,
to_initial_caps(op.command),
args_part
))
print("%s]," % (indent * 2))
print("%sconstraints: vec![" % (indent * 2))
for constraint in contract.constraints:
args_part = ""
if constraint.args:
print("%s// %s" % (indent *3, constraint.args_comment()))
args = constraint.args[:]
if (constraint.command == "lc0_add_coeff" or
constraint.command == "lc1_add_coeff" or
constraint.command == "lc2_add_coeff" or
constraint.command == "lc0_add_one_coeff" or
constraint.command == "lc1_add_one_coeff" or
constraint.command == "lc2_add_one_coeff"):
args[0] = args[0][0]
args_part = ", ".join(str(index) for index in args)
args_part = "(%s)" % args_part
print("%sConstraintInstruction::%s%s," % (
indent * 3,
to_initial_caps(constraint.command),
args_part
))
print(r""" ],
aux: vec![],
params: None,
verifying_key: None,
}
}""")

View File

@@ -1,193 +0,0 @@
import struct
from compile import VariableType, VariableRefType
class Operation:
def __init__(self, ident, args):
self.ident = ident
self.args = args
class ArgVarRef:
def __init__(self, type, index):
self.type = type
self.index = index
def bytes(self):
return struct.pack("<BI", self.type, self.index)
class ArgVarIndex:
def __init__(self, _, index):
self.index = index
def bytes(self):
return struct.pack("<I", self.index)
class ArgString:
def __init__(self, description, index):
self.description = description
self.index = index
ops_table = {
"set": Operation(0, [ArgVarRef, ArgVarRef]),
"mul": Operation(1, [ArgVarRef, ArgVarRef]),
"add": Operation(2, [ArgVarRef, ArgVarRef]),
"sub": Operation(3, [ArgVarRef, ArgVarRef]),
"divide": Operation(4, [ArgVarRef, ArgVarRef]),
"double": Operation(5, [ArgVarRef]),
"square": Operation(6, [ArgVarRef]),
"invert": Operation(7, [ArgVarRef]),
"unpack_bits": Operation(8, [ArgVarRef, ArgVarRef, ArgVarRef]),
"local": Operation(9, []),
"load": Operation(10, [ArgVarRef, ArgVarIndex]),
"debug": Operation(11, [ArgString, ArgVarRef]),
"dump_alloc": Operation(12, []),
"dump_local": Operation(13, []),
}
constraint_ident_map = {
"lc0_add": 0,
"lc1_add": 1,
"lc2_add": 2,
"lc0_sub": 3,
"lc1_sub": 4,
"lc2_sub": 5,
"lc0_add_one": 6,
"lc1_add_one": 7,
"lc2_add_one": 8,
"lc0_sub_one": 9,
"lc1_sub_one": 10,
"lc2_sub_one": 11,
"lc0_add_coeff": 12,
"lc1_add_coeff": 13,
"lc2_add_coeff": 14,
"lc0_add_constant": 15,
"lc1_add_constant": 16,
"lc2_add_constant": 17,
"enforce": 18,
"lc_coeff_reset": 19,
"lc_coeff_double": 20,
}
def varuint(value):
if value <= 0xfc:
return struct.pack("<B", value)
elif value <= 0xffff:
return struct.pack("<BH", 0xfd, value)
elif value <= 0xffffffff:
return struct.pack("<BI", 0xfe, value)
else:
return struct.pack("<BQ", 0xff, value)
def export(output, contract_name, contract):
output.write(varuint(len(contract_name)))
output.write(contract_name.encode())
constants = list(contract.constants.items())
constants.sort(key=lambda obj: obj[1][0])
constants = [(obj[0], obj[1][1]) for obj in constants]
# Constants
output.write(varuint(len(constants)))
for symbol, value in constants:
print("Constant '%s' = %s" % (symbol, value))
# Bellman uses little endian for Scalars from_bytes function
const_bytes = bytearray.fromhex(value)[::-1]
assert len(const_bytes) == 32
output.write(const_bytes)
# Alloc
output.write(varuint(len(contract.alloc)))
for symbol, variable in contract.alloc.items():
print("Alloc '%s' = (%s, %s)" % (symbol,
variable.type.name, variable.index))
if variable.type.name == VariableType.PRIVATE.name:
typeval = 0
elif variable.type.name == VariableType.PUBLIC.name:
typeval = 1
else:
assert False
alloc_bytes = struct.pack("<BI", typeval, variable.index)
assert len(alloc_bytes) == 5
output.write(alloc_bytes)
# Ops
output.write(varuint(len(contract.ops)))
for op in contract.ops:
op_form = ops_table[op.command]
output.write(struct.pack("B", op_form.ident))
if op.command == "debug":
# Special case
assert len(op.args) == 1
line_str = str(op.line).encode()
output.write(varuint(len(line_str)))
output.write(line_str)
op_arg = op.args[0]
if op_arg.type.name == VariableRefType.AUX.name:
arg_type = 0
elif op_arg.type.name == VariableRefType.LOCAL.name:
arg_type = 1
arg = ArgVarRef(arg_type, op_arg.index)
output.write(arg.bytes())
continue
assert len(op_form.args) == len(op.args)
for arg_form, op_arg in zip(op_form.args, op.args):
if op_arg.type.name == VariableRefType.AUX.name:
arg_type = 0
elif op_arg.type.name == VariableRefType.LOCAL.name:
arg_type = 1
arg = arg_form(arg_type, op_arg.index)
output.write(arg.bytes())
print("#", op.line)
print("Operation", op.command,
[(arg.type.name, arg.index) for arg in op.args])
# Constraints
output.write(varuint(len(contract.constraints)))
for constraint in contract.constraints:
args = constraint.args[:]
if (constraint.command == "lc0_add_coeff" or
constraint.command == "lc1_add_coeff" or
constraint.command == "lc2_add_coeff" or
constraint.command == "lc0_add_constant" or
constraint.command == "lc1_add_constant" or
constraint.command == "lc2_add_constant"):
args[0] = args[0][0]
print("#", constraint.line)
print("Constraint", constraint.command, args if args else "")
enum_ident = constraint_ident_map[constraint.command]
output.write(struct.pack("B", enum_ident))
for arg in args:
output.write(struct.pack("<I", arg))
# Params Map
param_alloc = [(symbol, variable) for (symbol, variable)
in contract.alloc.items() if variable.is_param]
output.write(varuint(len(param_alloc)))
for symbol, variable in param_alloc:
assert variable.is_param
print("Parameter '%s' = %s" % (symbol, variable.index))
symbol = symbol.encode()
output.write(varuint(len(symbol)))
output.write(symbol)
output.write(struct.pack("<I", variable.index))
# Public Map
public_alloc = [(symbol, variable) for (symbol, variable)
in contract.alloc.items()
if variable.type.name == VariableType.PUBLIC.name]
output.write(varuint(len(public_alloc)))
for symbol, variable in public_alloc:
assert not variable.is_param
assert variable.type.name == VariableType.PUBLIC.name
print("Public '%s' = %s" % (symbol, variable.index))
symbol = symbol.encode()
output.write(varuint(len(symbol)))
output.write(symbol)
output.write(struct.pack("<I", variable.index))

View File

@@ -1,21 +0,0 @@
if exists('b:current_syntax')
finish
endif
syn keyword drkKeyword assert enforce for in def return const as let emit contract private proof
syn keyword drkAttr mut
syn keyword drkType BinaryNumber Point Fr SubgroupPoint EdwardsPoint Scalar EncryptedNum list Bool U64 Num Binary
syn match drkFunction "\zs[a-zA-Z0-9_]*\ze("
syn match drkComment "#.*$"
syn match drkNumber '\d\+'
syn match drkConst '[A-Z_]\{2,}[A-Z0-9_]*'
hi def link drkKeyword Statement
hi def link drkAttr StorageClass
hi def link drkType Type
hi def link drkFunction Function
hi def link drkComment Comment
hi def link drkNumber Constant
hi def link drkConst Constant
let b:current_syntax = "drk"

View File

@@ -1,2 +0,0 @@
#!/bin/bash
python scripts/parser.py proofs/sapling3.prf | rustfmt > proofs/sapling3.rs

View File

@@ -1,4 +0,0 @@
#!/bin/bash -x
python scripts/pism.py proofs/simple.pism | rustfmt > src/simple_circuit.rs
cargo run --release --bin simple

View File

@@ -1,5 +0,0 @@
# Dark Client
$ python3 -m venv env
$ source env/bin/activate
$ pip install -r requirements.txt

View File

@@ -1,89 +0,0 @@
#!/usr/bin/env python
from util import arg_parser
import aiohttp
import asyncio
class DarkClient:
# TODO: generate random ID (4 byte unsigned int) (rand range 0 - max size
# uint32
def __init__(self, client_session):
self.url = "http://localhost:8000/"
self.client_session = client_session
self.payload = {
"method": [],
"params": [],
"jsonrpc": [],
"id": [],
}
async def key_gen(self, payload):
payload['method'] = "key_gen"
payload['jsonrpc'] = "2.0"
payload['id'] = "0"
key = await self.__request(payload)
print(key)
async def get_info(self, payload):
payload['method'] = "get_info"
payload['jsonrpc'] = "2.0"
payload['id'] = "0"
info = await self.__request(payload)
print(info)
async def stop(self, payload):
payload['method'] = "stop"
payload['jsonrpc'] = "2.0"
payload['id'] = "0"
stop = await self.__request(payload)
print(stop)
async def say_hello(self, payload):
payload['method'] = "say_hello"
payload['jsonrpc'] = "2.0"
payload['id'] = "0"
hello = await self.__request(payload)
print(hello)
async def create_wallet(self, payload):
payload['method'] = "create_wallet"
payload['jsonrpc'] = "2.0"
payload['id'] = "0"
wallet = await self.__request(payload)
print(wallet)
async def create_cashier_wallet(self, payload):
payload['method'] = "create_cashier_wallet"
payload['jsonrpc'] = "2.0"
payload['id'] = "0"
wallet = await self.__request(payload)
print(wallet)
async def test_wallet(self, payload):
payload['method'] = "test_wallet"
payload['jsonrpc'] = "2.0"
payload['id'] = "0"
test = await self.__request(payload)
print(test)
async def __request(self, payload):
async with self.client_session.post(self.url, json=payload) as response:
resp = await response.text()
print(resp)
async def main():
try:
async with aiohttp.ClientSession() as session:
client = DarkClient(session)
await arg_parser(client)
except aiohttp.ClientConnectorError as err:
print('CONNECTION ERROR:', str(err))
except Exception as err:
print("ERROR: ", str(err))
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(main())

View File

@@ -1,14 +0,0 @@
aiodns==3.0.0
aiohttp==3.7.4.post0
async-timeout==3.0.1
attrs==21.2.0
brotlipy==0.7.0
cchardet==2.1.7
cffi==1.14.5
chardet==4.0.0
idna==3.2
multidict==5.1.0
pycares==4.0.0
pycparser==2.20
typing-extensions==3.10.0.0
yarl==1.6.3

View File

@@ -1,53 +0,0 @@
import argparse
async def arg_parser(client):
parser = argparse.ArgumentParser(
prog='drk',
usage='%(prog)s [commands]',
description="""DarkFi wallet command-line tool"""
)
parser.add_argument('-c', '--cashier', action='store_true', help='Create a cashier wallet')
parser.add_argument('-w', '--wallet', action='store_true', help='Create a new wallet')
parser.add_argument('-k', '--key', action='store_true', help='Test key')
parser.add_argument('-i', '--info', action='store_true', help='Request info from daemon')
parser.add_argument('-hi', '--hello', action='store_true', help='Test hello')
parser.add_argument("-s", "--stop", action='store_true', help="Send a stop signal to the daemon")
parser.add_argument("-t", "--test", action='store_true', help="Test writing to the wallet")
try:
args = parser.parse_args()
if args.key:
print("Attemping to generate a create key pair...")
await client.key_gen(client.payload)
if args.wallet:
print("Attemping to create a wallet...")
await client.create_wallet(client.payload)
if args.info:
print("Info was entered")
await client.get_info(client.payload)
print("Requesting daemon info...")
if args.stop:
print("Stop was entered")
await client.stop(client.payload)
print("Sending a stop signal...")
if args.hello:
print("Hello was entered")
await client.say_hello(client.payload)
if args.cashier:
print("Attempting to generate a cashier wallet...")
await client.create_cashier_wallet(client.payload)
if args.test:
print("Testing wallet write")
await client.test_wallet(client.payload)
except Exception:
raise

View File

@@ -1,5 +0,0 @@
#!/bin/bash -x
python scripts/preprocess.py proofs/bits.psm > /tmp/bits.psm
python scripts/vm.py --rust /tmp/bits.psm > src/bits_contract.rs
cargo run --release --bin bits

View File

@@ -1,5 +0,0 @@
#!/bin/bash -x
python scripts/preprocess.py proofs/mimc.psm > /tmp/mimc.psm
python scripts/vm.py --rust /tmp/mimc.psm > src/zkmimc_contract.rs
cargo run --release --bin zkmimc

View File

@@ -1,5 +0,0 @@
#!/bin/bash -x
python scripts/preprocess.py proofs/mint2.psm > /tmp/mint2.psm || exit $?
python scripts/vm.py --rust /tmp/mint2.psm > src/mint2_contract.rs || exit $?
cargo run --release --bin mint2

View File

@@ -1,4 +0,0 @@
#!/bin/bash -x
python scripts/preprocess.py proofs/mint.pism > /tmp/mint.pism
python scripts/pism.py /tmp/mint.pism proofs/mint.aux | rustfmt > src/mint_contract.rs
cargo run --release --bin mint

View File

@@ -1,4 +0,0 @@
#!/bin/bash -x
python scripts/preprocess.py proofs/spend.pism > /tmp/spend.pism
python scripts/pism.py /tmp/spend.pism proofs/mint.aux | rustfmt > src/spend_contract.rs
cargo run --release --bin spend

View File

@@ -1,5 +0,0 @@
#!/bin/bash -x
python scripts/preprocess.py proofs/jubjub.pism > /tmp/jubjub.pism
python scripts/vm.py --rust /tmp/jubjub.pism > src/vm_load.rs
cargo run --release --bin vmtest

View File

@@ -1,835 +0,0 @@
import lark
import pprint
import re
import sys
class LineDesc:
def __init__(self, level, text, lineno):
self.level = level
self.text = text
self.lineno = lineno
assert self.text[0] != ' '
def __repr__(self):
return "<%s:'%s'>" % (self.level, self.text)
def clean_line(line):
lead_spaces = len(line) - len(line.lstrip(" "))
level = lead_spaces / 4
# Remove leading spaces
if line.strip(" ") == "":
return None
line = line.lstrip(" ")
# Remove all comments
line = re.sub('#.*$', '', line).strip()
if not line:
return None
return level, line
def parse(text):
lines = text.split("\n")
linedescs = []
# These are to join open parenthesis
current_line = ""
paren_level = 0
for lineno, line in enumerate(lines):
if (lineinfo := clean_line(line)) is None:
continue
level, line = lineinfo
for c in line:
if c == "(":
paren_level += 1
elif c == ")":
paren_level -= 1
#print(level, paren_level, current_line)
if paren_level < 0:
print("error: too many closing paren )", file=sys.stderr)
print("line:", lineno)
return
if current_line:
current_line += " " + line
else:
current_line = line
if paren_level > 0:
continue
#print(level, current_line)
ldesc = LineDesc(level, current_line, lineno)
linedescs.append(ldesc)
current_line = ""
if paren_level > 0:
print("error: missing closing paren )", file=sys.stderr)
return None
return linedescs
def section(linedescs):
sections = []
current_section = None
for desc in linedescs:
if desc.level == 0:
if current_section:
sections.append(current_section)
current_section = [desc]
continue
current_section.append(desc)
sections.append(current_section)
return sections
def classify(sections):
consts = []
funcs = []
contracts = []
for section in sections:
assert len(section)
if section[0].text == "const:":
consts.append(section)
elif section[0].text.startswith("def"):
funcs.append(section)
elif section[0].text.startswith("contract"):
contracts.append(section)
return consts, funcs, contracts
def tokenize_const(text):
parser = lark.Lark(r"""
value_map: name ":" type_def
name: NAME
?type_def: point
| blake2s_personalization
| pedersen_personalization
| list
point: "Point"
blake2s_personalization: "Blake2sPersonalization"
pedersen_personalization: "PedersenPersonalization"
list: "list<" type_def ">"
%import common.CNAME -> NAME
%import common.WS
%ignore WS
""", start="value_map")
return parser.parse(text)
class ConstTransformer(lark.Transformer):
def name(self, name):
return str(name[0])
def point(self, _):
return "Point"
def blake2s_personalization(self, _):
return "Blake2sPersonalization"
def pedersen_personalization(self, _):
return "PedersenPersonalization"
value_map = tuple
list = list
def read_consts(consts):
consts_map = {}
for subsection in consts:
assert subsection[0].text == "const:"
for ldesc in subsection[1:]:
tree = tokenize_const(ldesc.text)
tokens = ConstTransformer().transform(tree)
#print(tokens)
name, typedesc = tokens
consts_map[name] = typedesc
#pprint.pprint(consts_map)
return consts_map
class FuncDefTransformer(lark.Transformer):
def func_name(self, name):
return str(name[0])
def param(self, obj):
return tuple(obj)
def param_name(self, name):
return str(name[0])
def u64(self, _):
return "U64"
def scalar(self, _):
return "Scalar"
def point(self, _):
return "Point"
def binary(self, _):
return "Binary"
def type(self, obj):
return obj[0]
func_def = list
params = list
type_list = list
def parse_func_def(text):
parser = lark.Lark(r"""
func_def: "def" func_name "(" params+ ")" "->" type_list ":"
func_name: NAME
params: param ("," param)*
type_list: type
| "(" type ("," type)* ")"
param: param_name ":" type
param_name: NAME
type: u64 | scalar | point | binary
u64: "U64"
scalar: "Scalar"
point: "Point"
binary: "Binary"
%import common.CNAME -> NAME
%import common.WS
%ignore WS
""", start="func_def")
tree = parser.parse(text)
tokens = FuncDefTransformer().transform(tree)
assert len(tokens) == 3
return tokens
def compile_func_header(func_def):
func_name, params, retvals = func_def
#print("Function:", func_name)
#print("Params:", params)
#print("Return values:", retvals)
#print()
param_str = ""
for param, type in params:
if param_str:
param_str += ", "
param_str += param + ": Option<"
if type == "U64":
param_str += "u64"
elif type == "Scalar":
param_str += "jubjub::Fr"
elif type == "Point":
param_str += "jubjub::SubgroupPoint"
else:
print("error: unsupported param type", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
param_str += ">"
converted_retvals = []
for type in retvals:
if type == "Binary":
converted_retvals.append("boolean::Boolean")
else:
print("error: unsupported return type", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
retvals = converted_retvals
if len(retvals) == 1:
retstr = retvals[0]
else:
retstr = "(" + ", ".join(retvals) + ")"
header = r"""fn %s<CS>(
mut cs: CS,
%s
) -> Result<%s, SynthesisError>
where
CS: ConstraintSystem<bls12_381::Scalar>,
{
""" % (func_name, param_str, retstr)
return header
def as_expr(line, stack, consts, expr, code):
var_from, type_to = expr.children
if var_from not in stack:
print("error: variable from not in stack frame:", var_from,
file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
type_from = stack[var_from]
if type_from == "U64" and type_to == "Binary":
code += "boolean::u64_into_boolean_vec_le(" + \
"cs.namespace(|| \"" + line.text + "\"), " + var_from + \
")?;"
elif type_from == "Scalar" and type_to == "Binary":
code += "boolean::field_into_boolean_vec_le(" + \
"cs.namespace(|| \"" + line.text + "\"), " + var_from + \
")?;"
else:
print("error: unknown type conversion!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
#print(var_from, type_from, type_to)
return code, type_to
def mul_expr(line, stack, consts, expr, code):
var_a, var_b = expr.children
#print("MUL", var_a, var_b)
if var_b not in consts:
print("error: unknown base!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
base_type = consts[var_b]
if base_type != "Point":
print("error: unknown base type!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
code += "ecc::fixed_base_multiplication(" + \
"cs.namespace(|| \"" + line.text + "\"), &" + var_b + \
", &" + var_a + ")?;"
return code, base_type
def add_expr(line, stack, consts, expr, code):
var_a, var_b = expr.children
if var_a not in stack or var_b not in stack:
print("error: missing stack item!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
result_type = stack[var_a]
if stack[var_b] != result_type:
print("error: non matching items for addition!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
code += var_a + ".add(cs.namespace(|| \"" + line.text \
+ "\"), &" + var_b + ")?;"
return (code, result_type)
def compile_let(line, stack, consts, statement):
is_mutable = False
if statement[0] == "mut":
is_mutable = True
statement = statement[1:]
variable_name, variable_type = statement[0], statement[1]
expr = statement[2]
#print("LET", is_mutable, variable_name, variable_type)
#print(" ", expr)
code = "let " + ("mut " if is_mutable else "") + variable_name + " = "
if expr.data == "as_expr":
ceval = as_expr(line, stack, consts, expr, code)
elif expr.data == "mul_expr":
ceval = mul_expr(line, stack, consts, expr, code)
elif expr.data == "add_expr":
ceval = add_expr(line, stack, consts, expr, code)
if ceval is None:
return None
code, type_to = ceval
if variable_type != type_to:
print("error: sub expr does not evaluate to correct type",
file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
stack[variable_name] = variable_type
return code
def interpret_func(func, consts):
func_def = parse_func_def(func[0].text)
header = compile_func_header(func_def)
if header is None:
return
subroutine = header
indent = " " * 4
stack = dict(func_def[1])
emitted_types = []
for line in func[1:]:
statement_type, statement = interpret_func_line(line.text, stack, consts)
if statement_type == "let":
code = compile_let(line, stack, consts, statement)
if code is None:
return
subroutine += indent + code + "\n"
elif statement_type == "return":
for var in statement:
if var not in stack:
print("error: missing variable in stack!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
if len(statement) == 1:
code = "Ok(" + statement[0] + ")"
else:
code = "Ok(" + ",".join(statement) + ")"
subroutine += indent + code + "\n"
elif statement_type == "emit":
assert len(statement) == 1
variable = statement[0]
if variable not in stack:
print("error: missing variable in stack!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
variable_type = stack[variable]
if variable_type == "Point":
code = variable + ".inputize(cs.namespace(|| \"" + \
line.text + "\"))?;"
else:
print("error: unable to inputize type!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
emitted_types.append(variable_type)
subroutine += indent + code + "\n"
subroutine += "}\n\n"
return subroutine, emitted_types, func_def
class CodeLineTransformer(lark.Transformer):
def variable_name(self, name):
return str(name[0])
def let_statement(self, obj):
return ("let", obj)
def return_statement(self, obj):
return ("return", obj)
def emit_statement(self, obj):
return ("emit", obj)
def point(self, _):
return "Point"
def scalar(self, _):
return "Scalar"
def binary(self, _):
return "Binary"
def u64(self, _):
return "U64"
def type(self, typename):
return str(typename[0])
def mutable(self, _):
return "mut"
statement = list
def interpret_func_line(text, stack, consts):
parser = lark.Lark(r"""
statement: let_statement
| return_statement
| emit_statement
let_statement: "let" [mutable] variable_name ":" type "=" expr
mutable: "mut"
?expr: as_expr
| mul_expr
| add_expr
as_expr: variable_name "as" type
mul_expr: variable_name "*" variable_name
add_expr: variable_name "+" variable_name
return_statement: "return" variable_name
| "return" variable_tuple
variable_tuple: "(" variable_name ("," variable_name)* ")"
emit_statement: "emit" variable_name
variable_name: NAME
type: u64 | scalar | point | binary
u64: "U64"
scalar: "Scalar"
point: "Point"
binary: "Binary"
%import common.CNAME -> NAME
%import common.WS
%ignore WS
""", start="statement")
tree = parser.parse(text)
tokens = CodeLineTransformer().transform(tree)[0]
return tokens
class ContractDefTransformer(lark.Transformer):
def contract_name(self, name):
return str(name[0])
def param(self, obj):
return tuple(obj)
def param_name(self, name):
return str(name[0])
def u64(self, _):
return "U64"
def scalar(self, _):
return "Scalar"
def point(self, _):
return "Point"
def binary(self, _):
return "Binary"
def type(self, obj):
return obj[0]
contract_def = list
params = list
type_list = list
def parse_contract_def(text):
parser = lark.Lark(r"""
contract_def: "contract" contract_name "(" params+ ")" "->" type_list ":"
contract_name: NAME
params: param ("," param)*
type_list: type
| "(" type ("," type)* ")"
param: param_name ":" type
param_name: NAME
type: u64 | scalar | point | binary
u64: "U64"
scalar: "Scalar"
point: "Point"
binary: "Binary"
%import common.CNAME -> NAME
%import common.WS
%ignore WS
""", start="contract_def")
tree = parser.parse(text)
tokens = ContractDefTransformer().transform(tree)
assert len(tokens) == 3
return tokens
class ContractCodeLineTransformer(lark.Transformer):
def variable_name(self, name):
return str(name[0])
def let_statement(self, obj):
return ("let", obj)
def return_statement(self, obj):
return ("return", obj)
def emit_statement(self, obj):
return ("emit", obj)
def method_statement(self, obj):
return ("method", obj)
def point(self, _):
return "Point"
def scalar(self, _):
return "Scalar"
def binary(self, _):
return "Binary"
def u64(self, _):
return "U64"
def type(self, typename):
return str(typename[0])
def mutable(self, _):
return "mut"
def function_name(self, name):
return str(name[0])
statement = list
variable_assign = list
variable_decl = tuple
def interpret_contract_line(text, stack, consts):
parser = lark.Lark(r"""
statement: let_statement
| return_statement
| emit_statement
| method_statement
let_statement: "let" variable_assign "=" expr
mutable: "mut"
variable_assign: variable_decl
| "(" variable_decl ("," variable_decl)* ")"
variable_decl: [mutable] variable_name ":" type
?expr: as_expr
| mul_expr
| add_expr
| funccall_expr
| empty_list_expr
as_expr: variable_name "as" type
mul_expr: variable_name "*" variable_name
add_expr: variable_name "+" variable_name
funccall_expr: function_name "(" [variable_name ("," variable_name)*] ")"
empty_list_expr: "[]"
return_statement: "return" variable_name
| "return" variable_tuple
variable_tuple: "(" variable_name ("," variable_name)* ")"
emit_statement: "emit" variable_name
method_statement: variable_name "." funccall_expr
variable_name: NAME
function_name: NAME
type: u64 | scalar | point | binary
u64: "U64"
scalar: "Scalar"
point: "Point"
binary: "Binary"
%import common.CNAME -> NAME
%import common.WS
%ignore WS
""", start="statement")
tree = parser.parse(text)
tokens = ContractCodeLineTransformer().transform(tree)[0]
return tokens
def to_initial_caps(snake_str):
components = snake_str.split("_")
return "".join(x.title() for x in components)
def create_contract_header(contract_def):
contract_name, params, retvals = contract_def
contract_name = to_initial_caps(contract_name)
header = "pub struct %s {\n" % contract_name
for param_name, param_type in params:
if param_type == "U64":
param_type = "u64"
elif param_type == "Scalar":
param_type = "jubjub::Fr"
elif param_type == "Point":
param_type = "jubjub::SubgroupPoint"
header += " " * 4 + "pub %s: Option<%s>,\n" % (param_name, param_type)
header += "}\n\n"
header += r"""impl Circuit<bls12_381::Scalar> for %s {
fn synthesize<CS: ConstraintSystem<bls12_381::Scalar>>(
self,
cs: &mut CS,
) -> Result<(), SynthesisError> {
""" % contract_name
return header
# Worst code ever
def compile_let2(line, stack, consts, funcs, selfvars, statement):
lhs = []
for variable_decl in statement[0]:
assert len(variable_decl) == 2 or \
(len(variable_decl) == 3 and variable_decl[0] == "mut")
if len(variable_decl) == 2:
mutable = False
elif len(variable_decl) == 3:
assert variable_decl[0] == "mut"
mutable = True
variable_decl = variable_decl[1:]
#else:
# Error!
lhs.append(list(variable_decl) + [mutable])
variable_types = []
code = "let "
if len(lhs) == 1:
name, type, is_mutable = lhs[0]
variable_types.append(type)
code += ("mut " if is_mutable else "") + name
else:
code += "("
start = True
for name, type, is_mutable in lhs:
if not start:
code += ", "
start = False
code += name
variable_types.append(type)
code += ")"
code += " = "
expr = statement[1]
expr_type = expr.data
expr = expr.children
if expr_type == "funccall_expr":
ceval = funccall_expr(line, stack, consts, funcs, selfvars, expr, code)
elif expr_type == "empty_list_expr":
ceval = code + "vec![];", ["Binary"]
#code = "let " + ("mut " if is_mutable else "") + variable_name + " = "
if ceval is None:
return None
code, types_to = ceval
if variable_types != types_to:
print("error: sub expr does not evaluate to correct type",
file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
for name, type, _ in lhs:
stack[name] = type
return code
def funccall_expr(line, stack, consts, funcs, selfvars, expr, code):
func_name, arguments = expr[0], expr[1:]
if func_name not in funcs:
print("error: non-existant function call",
file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
arguments = [("self." + arg if arg in selfvars else arg)
for arg in arguments]
code += "%s(cs.namespace(|| \"%s\"), %s)?;" % (
func_name, line.text, ", ".join(arguments))
return_type = funcs[func_name][-1][-1]
return code, return_type
def compile_method_call(line, stack, consts, funcs, selfvars, statement):
variable = statement[0]
method = statement[1].children[0]
arguments = statement[1].children[1:]
arguments = [("self." + arg if arg in selfvars else arg)
for arg in arguments]
return "%s.%s(%s);" % (variable, method, ", ".join(arguments))
def interpret_contract(contract, consts, funcs):
contract_def = parse_contract_def(contract[0].text)
contract_code = create_contract_header(contract_def)
selfvars = set(varname[0] for varname in contract_def[1])
stack = dict(contract_def[1])
for line in contract[1:10]:
indent = " " * 4 * int(line.level + 1)
statement_type, statement = interpret_contract_line(line.text, stack, consts)
#pprint.pprint(statement_type)
if statement_type == "let":
code = compile_let2(line, stack, consts, funcs, selfvars, statement)
if code is None:
return
elif statement_type == "method":
code = compile_method_call(line, stack, consts, funcs,
selfvars, statement)
if code is None:
return
contract_code += indent + code + "\n"
contract_code += " " * 8 + "Ok(())\n"
contract_code += " " * 4 + "}\n"
contract_code += "}\n\n"
#print("-------------------------------")
#print(contract_code)
return contract_code, contract_def
def main(argv):
if len(argv) == 1:
print("error: missing proof file", file=sys.stderr)
return -1
filename = sys.argv[1]
text = open(filename, "r").read()
if (linedescs := parse(text)) is None:
return -1
sections = section(linedescs)
consts, funcs, contracts = classify(sections)
consts = read_consts(consts)
compiled_funcs = {}
for func in funcs:
if (compiled := interpret_func(func, consts)) is None:
return -1
_, _, func_def = compiled
func_name, _, _ = func_def
compiled_funcs[func_name] = compiled
funcs = compiled_funcs
compiled_contracts = {}
for contract in contracts[1:]:
if (compiled := interpret_contract(contract, consts, funcs)) is None:
return -1
#print(contract)
_, contract_def = compiled
contract_name, _, _ = contract_def
compiled_contracts[contract_name] = compiled
contracts = compiled_contracts
# Concat
output = ""
for _, func in funcs.items():
output += func[0]
for _, contract in contracts.items():
output += contract[0]
#print(output)
if __name__ == "__main__":
main(sys.argv)

View File

@@ -1,528 +0,0 @@
import json
import os
import sys
import codegen
symbol_table = {
"contract": 1,
"param": 2,
"start": 0,
"end": 0,
}
types_map = {
"U64": "u64",
"Fr": "jubjub::Fr",
"Point": "jubjub::SubgroupPoint",
"Scalar": "bls12_381::Scalar",
"Bool": "bool"
}
feature_includes = {"G_SPEND": "use crate::crypto::merkle_node::SAPLING_COMMITMENT_TREE_DEPTH;\n"}
command_desc = {
"witness": (
("EdwardsPoint", True),
("Point", False)
),
"assert_not_small_order": (
("EdwardsPoint", False),
),
"u64_as_binary_le": (
("Vec<Boolean>", True),
("U64", False),
),
"fr_as_binary_le": (
("Vec<Boolean>", True),
("Fr", False)
),
"ec_mul_const": (
("EdwardsPoint", True),
("Vec<Boolean>", False),
("FixedGenerator", False)
),
"ec_mul": (
("EdwardsPoint", True),
("Vec<Boolean>", False),
("EdwardsPoint", False),
),
"ec_add": (
("EdwardsPoint", True),
("EdwardsPoint", False),
("EdwardsPoint", False),
),
"ec_repr": (
("Vec<Boolean>", True),
("EdwardsPoint", False),
),
"ec_get_u": (
("ScalarNum", True),
("EdwardsPoint", False),
),
"emit_ec": (
("EdwardsPoint", False),
),
"alloc_binary": (
("Vec<Boolean>", True),
),
"binary_clone": (
("Vec<Boolean>", True),
("Vec<Boolean>", False),
),
"binary_extend": (
("Vec<Boolean>", False),
("Vec<Boolean>", False),
),
"binary_push": (
("Vec<Boolean>", False),
("Boolean", False),
),
"binary_truncate": (
("Vec<Boolean>", False),
("BinarySize", False),
),
"static_assert_binary_size": (
("Vec<Boolean>", False),
("INTEGER", False),
),
"blake2s": (
("Vec<Boolean>", True),
("Vec<Boolean>", False),
("BlakePersonalization", False),
),
"pedersen_hash": (
("EdwardsPoint", True),
("Vec<Boolean>", False),
("PedersenPersonalization", False),
),
"emit_binary": (
("Vec<Boolean>", False),
),
"alloc_bit": (
("Boolean", True),
("Bool", False),
),
"alloc_const_bit": (
("Boolean", True),
("BOOL_CONST", False),
),
"clone_bit": (
("Boolean", True),
("Boolean", False),
),
"alloc_scalar": (
("ScalarNum", True),
("Scalar", False),
),
"scalar_as_binary": (
("Vec<Boolean>", True),
("ScalarNum", False),
),
"emit_scalar": (
("ScalarNum", False),
),
"scalar_enforce_equal": (
("ScalarNum", False),
("ScalarNum", False),
),
"conditionally_reverse": (
("ScalarNum", True),
("ScalarNum", True),
("ScalarNum", False),
("ScalarNum", False),
("Boolean", False),
),
}
def eprint(*args):
print(*args, file=sys.stderr)
class Line:
def __init__(self, text, line_number):
self.text = text
self.orig = text
self.lineno = line_number
self.clean()
def clean(self):
# Remove the comments
self.text = self.text.split("#", 1)[0]
# Remove whitespace
self.text = self.text.strip()
def is_empty(self):
return bool(self.text)
def __repr__(self):
return "Line %s: %s" % (self.lineno, self.orig.lstrip())
def command(self):
if not self.is_empty():
return None
return self.text.split(" ")[0]
def args(self):
if not self.is_empty():
return None
return self.text.split(" ")[1:]
def clean(contents):
# Split input into lines
contents = contents.split("\n")
contents = [Line(line, i) for i, line in enumerate(contents)]
# Remove empty blank lines
contents = [line for line in contents if line.is_empty()]
return contents
def make_segments(contents):
constants = [line for line in contents if line.command() == "constant"]
segments = []
current_segment = []
for line in contents:
if line.command() == "contract":
current_segment = []
current_segment.append(line)
if line.command() == "end":
segments.append(current_segment)
current_segment = []
return constants, segments
def build_constants_table(constants):
table = {}
for line in constants:
args = line.args()
if len(args) != 2:
eprint("error: wrong number of args")
eprint(line)
return None
name, type = args
table[name] = type
return table
def extract(segment):
assert segment
# Does it have a declaration?
if not segment[0].command() == "contract":
eprint("error: missing contract declaration")
eprint(segment[0])
return None
# Does it have an end?
if not segment[-1].command() == "end":
eprint("error: missing contract end")
eprint(segment[-1])
return None
# Does it have a start?
if not [line for line in segment if line.command() == "start"]:
eprint("error: missing contract start")
eprint(segment[0])
return None
for line in segment:
command, args = line.command(), line.args()
if command in symbol_table:
if symbol_table[command] != len(args):
eprint("error: wrong number of args for command '%s'" % command)
eprint(line)
return None
elif command in command_desc:
if len(command_desc[command]) != len(args):
eprint("error: wrong number of args for command '%s'" % command)
eprint(line)
return None
else:
eprint("error: missing symbol for command '%s'" % command)
eprint(line)
return None
contract_name = segment[0].args()[0]
start_index = [index for index, line in enumerate(segment)
if line.command() == "start"]
if len(start_index) > 1:
eprint("error: multiple start statements in contract '%s'" %
contract_name)
for index in start_index:
eprint(segment[index])
eprint("Aborting.")
return None
assert len(start_index) == 1
start_index = start_index[0]
header = segment[1:start_index]
code = segment[start_index + 1:-1]
params = {}
for param_decl in header:
args = param_decl.args()
assert len(args) == 2
name, type = args
params[name] = type
program = []
for line in code:
command, args = line.command(), line.args()
program.append((command, args, line))
return Contract(contract_name, params, program)
def to_initial_caps(snake_str):
components = snake_str.split("_")
return "".join(x.title() for x in components)
class Contract:
def __init__(self, name, params, program):
self.name = name
self.params = params
self.program = program
def _includes(self):
return \
r"""#![allow(unused_imports)]
#![allow(unused_mut)]
use bellman::{
gadgets::{
boolean,
boolean::{AllocatedBit, Boolean},
multipack,
blake2s,
num,
Assignment,
},
groth16, Circuit, ConstraintSystem, SynthesisError,
};
use bls12_381::Bls12;
use ff::{PrimeField, Field};
use group::Curve;
use zcash_proofs::circuit::{ecc, pedersen_hash};
"""
def _compile_header(self):
code = "pub struct %s {\n" % to_initial_caps(self.name)
for param_name, param_type in self.params.items():
try:
mapped_type = types_map[param_type]
except KeyError:
return None
code += " pub %s: Option<%s>,\n" % (param_name, mapped_type)
code += "}\n"
return code
def _compile_body(self):
self.stack = {}
code = "\n"
#indent = " " * 8
for command, args, line in self.program:
if (code_text := self._compile_line(command, args, line)) is None:
return None
code += "// %s\n" % str(line)
code += code_text + "\n\n"
return code
def _preprocess_args(self, args, line):
nargs = []
for arg in args:
if not arg.startswith("param:"):
nargs.append((arg, False))
continue
_, argname = arg.split(":", 1)
if argname not in self.params:
eprint("error: non-existant param referenced")
eprint(line)
return None
nargs.append((argname, True))
return nargs
def type_checking(self, command, args, line):
assert command in command_desc
type_list = command_desc[command]
if len(type_list) != len(args):
eprint("error: wrong number of arguments!")
eprint(line)
return False
for (expected_type, new_val), (argname, is_param) in \
zip(type_list, args):
# Only type check input arguments, not output values
if new_val:
continue
if expected_type == "INTEGER" or expected_type == "BOOL_CONST":
continue
if is_param:
actual_type = self.params[argname]
elif argname in self.constants:
actual_type = self.constants[argname]
else:
# Check the stack here
if argname not in self.stack:
eprint("error: cannot find value '%s' on the stack!" %
argname)
eprint(line)
return False
actual_type = self.stack[argname]
if expected_type != actual_type:
eprint("error: wrong type for arg '%s'!" % argname)
eprint(line)
return False
return True
def _check_args(self, command, args, line):
assert command in command_desc
type_list = command_desc[command]
assert len(type_list) == len(args)
for (expected_type, is_new_val), (arg, is_param) in zip(type_list, args):
if is_param:
continue
if is_new_val:
continue
if arg in self.stack:
continue
if arg in self.constants:
continue
if expected_type == "INTEGER" or expected_type == "BOOL_CONST":
continue
eprint("error: cannot find '%s' in the stack" % arg)
eprint(line)
return False
return True
def _compile_line(self, command, args, line):
if (args := self._preprocess_args(args, line)) is None:
return None
if not self.type_checking(command, args, line):
return None
if not self._check_args(command, args, line):
return None
self.modify_stack(command, args)
args = [self.carg(arg) for arg in args]
try:
codegen_method = getattr(codegen, command)
except AttributeError:
eprint("error: missing command '%s' does not exist" % command)
eprint(line)
return None
return codegen_method(line, *args)
def carg(self, arg):
argname, is_param = arg
if is_param:
return "self.%s" % argname
if argname in self.rename_consts:
return self.rename_consts[argname]
return argname
def modify_stack(self, command, args):
type_list = command_desc[command]
assert len(type_list) == len(args)
for (expected_type, new_val), (argname, is_param) in \
zip(type_list, args):
if is_param:
assert not new_val
continue
# Now apply the new values to the stack
if new_val:
self.stack[argname] = expected_type
def compile(self, constants, aux):
self.constants = constants
code = ""
code += self._includes()
self.rename_consts = {}
if "constants" in aux:
for const_name, value in aux["constants"].items():
if "maps_to" not in value:
eprint("error: bad aux config '%s', missing maps_to" %
const_name)
return None
if const_name in feature_includes:
code += feature_includes[const_name]
mapped_type = value["maps_to"]
self.rename_consts[const_name] = mapped_type
code += "\n"
if (header := self._compile_header()) is None:
return None
code += header
code += \
r"""impl Circuit<bls12_381::Scalar> for %s {
fn synthesize<CS: ConstraintSystem<bls12_381::Scalar>>(
self,
cs: &mut CS,
) -> Result<(), SynthesisError> {
""" % to_initial_caps(self.name)
if (body := self._compile_body()) is None:
return None
code += body
code += "Ok(())\n"
code += " }\n"
code += "}\n"
return code
def process(contents, aux):
contents = clean(contents)
constants, segments = make_segments(contents)
if (constants := build_constants_table(constants)) is None:
return False
codes = []
for segment in segments:
if (contract := extract(segment)) is None:
return False
if (code := contract.compile(constants, aux)) is None:
return False
codes.append(code)
# Success! Output finished product.
[print(code) for code in codes]
return True
def main(argv):
if len(argv) != 3:
eprint("pism FILENAME AUX_FILENAME")
return -1
aux_filename = argv[2]
aux = json.loads(open(aux_filename).read())
src_filename = argv[1]
contents = open(src_filename).read()
if not process(contents, aux):
return -2
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv))

View File

@@ -1,41 +0,0 @@
"For autoload, add this to your VIM config:
" VIM: .vimrc
" NeoVIM: .config/nvim/init.vim
"
"autocmd BufRead *.pism call SetPismOptions()
"function SetPismOptions()
" set syntax=pism
" source /home/narodnik/src/drk/scripts/pism.vim
"endfunction
if exists('b:current_syntax')
finish
endif
syn keyword drkKeyword constant contract start end constraint
"syn keyword drkAttr
syn keyword drkType FixedGenerator BlakePersonalization PedersenPersonalization ByteSize U64 Fr Point Bool Scalar BinarySize
syn keyword drkFunctionKeyword enforce lc0_add_one lc1_add_one lc2_add_one lc_coeff_reset lc_coeff_double lc0_sub_one lc1_sub_one lc2_sub_one dump_alloc dump_local
syn match drkFunction "^[ ]*[a-z_0-9]* "
syn match drkComment "#.*$"
syn match drkNumber ' \zs\d\+\ze'
syn match drkHexNumber ' \zs0x[a-z0-9]\+\ze'
syn match drkConst '[A-Z_]\{2,}[A-Z0-9_]*'
syn keyword drkBoolVal true false
syn match drkPreproc "{%.*%}"
syn match drkPreproc2 "{{.*}}"
hi def link drkKeyword Statement
"hi def link drkAttr StorageClass
hi def link drkPreproc PreProc
hi def link drkPreproc2 PreProc
hi def link drkType Type
hi def link drkFunction Function
hi def link drkFunctionKeyword Function
hi def link drkComment Comment
hi def link drkNumber Constant
hi def link drkHexNumber Constant
hi def link drkConst Constant
hi def link drkBoolVal Constant
let b:current_syntax = "pism"

View File

@@ -1,20 +0,0 @@
import os.path
import sys
from jinja2 import Environment, FileSystemLoader, Template
def main(argv):
if len(argv) != 2:
print("error: missing arg", file=sys.stderr)
return -1
path = argv[1]
dirname, filename = os.path.dirname(path), os.path.basename(path)
env = Environment(loader = FileSystemLoader([dirname]))
template = env.get_template(filename)
print(template.render())
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv))

View File

@@ -1,5 +0,0 @@
#!/bin/bash
#nvim -c ":TOhtml" $1
sed -i "s/PreProc { color: #5fd7ff; }/PreProc { color: #8f2722; }/" $1
sed -i "s/Comment { color: #00ffff; }/Comment { color: #0055ff; }/" $1