mirror of
https://github.com/darkrenaissance/darkfi.git
synced 2026-01-09 14:48:08 -05:00
migrate over zkas
This commit is contained in:
5
script/to_html.sh
Executable file
5
script/to_html.sh
Executable file
@@ -0,0 +1,5 @@
|
||||
#!/bin/bash
|
||||
#nvim -c ":TOhtml" $1
|
||||
sed -i "s/PreProc { color: #5fd7ff; }/PreProc { color: #8f2722; }/" $1
|
||||
sed -i "s/Comment { color: #00ffff; }/Comment { color: #0055ff; }/" $1
|
||||
|
||||
41
script/zk.vim
Normal file
41
script/zk.vim
Normal file
@@ -0,0 +1,41 @@
|
||||
"For autoload, add this to your VIM config:
|
||||
" VIM: .vimrc
|
||||
" NeoVIM: .config/nvim/init.vim
|
||||
"
|
||||
"autocmd BufRead *.pism call SetPismOptions()
|
||||
"function SetPismOptions()
|
||||
" set syntax=pism
|
||||
" source /home/narodnik/src/drk/scripts/pism.vim
|
||||
"endfunction
|
||||
|
||||
if exists('b:current_syntax')
|
||||
finish
|
||||
endif
|
||||
|
||||
syn keyword drkKeyword constant contract circuit
|
||||
"syn keyword drkAttr
|
||||
syn keyword drkType EcFixedPoint Base Scalar
|
||||
"syn keyword drkFunctionKeyword enforce lc0_add_one lc1_add_one lc2_add_one lc_coeff_reset lc_coeff_double lc0_sub_one lc1_sub_one lc2_sub_one dump_alloc dump_local
|
||||
syn match drkFunction "^[ ]*[a-z_0-9]* "
|
||||
syn match drkComment "#.*$"
|
||||
syn match drkNumber ' \zs\d\+\ze'
|
||||
syn match drkHexNumber ' \zs0x[a-z0-9]\+\ze'
|
||||
syn match drkConst '[A-Z_]\{2,}[A-Z0-9_]*'
|
||||
syn keyword drkBoolVal true false
|
||||
syn match drkPreproc "{%.*%}"
|
||||
syn match drkPreproc2 "{{.*}}"
|
||||
|
||||
hi def link drkKeyword Statement
|
||||
"hi def link drkAttr StorageClass
|
||||
hi def link drkPreproc PreProc
|
||||
hi def link drkPreproc2 PreProc
|
||||
hi def link drkType Type
|
||||
hi def link drkFunction Function
|
||||
hi def link drkFunctionKeyword Function
|
||||
hi def link drkComment Comment
|
||||
hi def link drkNumber Constant
|
||||
hi def link drkHexNumber Constant
|
||||
hi def link drkConst Constant
|
||||
hi def link drkBoolVal Constant
|
||||
|
||||
let b:current_syntax = "pism"
|
||||
402
script/zkas.py
Normal file
402
script/zkas.py
Normal file
@@ -0,0 +1,402 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from zkas.types import *
|
||||
|
||||
class CompileException(Exception):
|
||||
|
||||
def __init__(self, error_message, line):
|
||||
super().__init__(error_message)
|
||||
self.error_message = error_message
|
||||
self.line = line
|
||||
|
||||
class Constants:
|
||||
|
||||
def __init__(self):
|
||||
self.table = []
|
||||
self.map = {}
|
||||
|
||||
def add(self, variable, type_id):
|
||||
idx = len(self.table)
|
||||
self.table.append(type_id)
|
||||
self.map[variable] = idx
|
||||
|
||||
def lookup(self, variable):
|
||||
idx = self.map[variable]
|
||||
return self.table[idx]
|
||||
|
||||
def variables(self):
|
||||
return self.map.keys()
|
||||
|
||||
class SyntaxStruct:
|
||||
|
||||
def __init__(self):
|
||||
self.contracts = {}
|
||||
self.circuits = {}
|
||||
self.constants = Constants()
|
||||
|
||||
def parse_contract(self, line, it):
|
||||
assert line.tokens[0] == "contract"
|
||||
if len(line.tokens) != 3 or line.tokens[2] != "{":
|
||||
raise CompileException("malformed contract opening", line)
|
||||
name = line.tokens[1]
|
||||
if name in self.contracts:
|
||||
raise CompileException(f"duplicate contract {name}", line)
|
||||
lines = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
line = next(it)
|
||||
except StopIteration:
|
||||
raise CompileException(
|
||||
f"premature end of file while parsing {name} contract", line)
|
||||
|
||||
assert len(line.tokens) > 0
|
||||
if line.tokens[0] == "}":
|
||||
break
|
||||
|
||||
lines.append(line)
|
||||
|
||||
self.contracts[name] = lines
|
||||
|
||||
def parse_circuit(self, line, it):
|
||||
assert line.tokens[0] == "circuit"
|
||||
if len(line.tokens) != 3 or line.tokens[2] != "{":
|
||||
raise CompileException("malformed circuit opening", line)
|
||||
name = line.tokens[1]
|
||||
if name in self.circuits:
|
||||
raise CompileException(f"duplicate contract {name}", line)
|
||||
lines = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
line = next(it)
|
||||
except StopIteration:
|
||||
raise CompileException(
|
||||
f"premature end of file while parsing {name} circuit", line)
|
||||
|
||||
assert len(line.tokens) > 0
|
||||
if line.tokens[0] == "}":
|
||||
break
|
||||
|
||||
lines.append(line)
|
||||
|
||||
self.circuits[name] = lines
|
||||
|
||||
def parse_constant(self, line):
|
||||
assert line.tokens[0] == "constant"
|
||||
if len(line.tokens) != 3:
|
||||
raise CompileException("malformed constant line", line)
|
||||
_, type_name, variable = line.tokens
|
||||
if type_name not in allowed_types:
|
||||
raise CompileException("unknown type '{type}'", line)
|
||||
type_id = allowed_types[type_name]
|
||||
self.constants.add(variable, type_id)
|
||||
|
||||
def verify(self):
|
||||
self.static_checks()
|
||||
schema = self.format_data()
|
||||
self.trace_circuits(schema)
|
||||
return schema
|
||||
|
||||
def static_checks(self):
|
||||
for name, lines in self.contracts.items():
|
||||
for line in lines:
|
||||
if len(line.tokens) != 2:
|
||||
raise CompileException("incorrect number of tokens", line)
|
||||
type, variable = line.tokens
|
||||
if type not in allowed_types:
|
||||
raise CompileException(
|
||||
f"unknown type specifier for variable {variable}", line)
|
||||
|
||||
for name, lines in self.circuits.items():
|
||||
for line in lines:
|
||||
assert len(line.tokens) > 0
|
||||
func_name, args = line.tokens[0], line.tokens[1:]
|
||||
if func_name not in function_formats:
|
||||
raise CompileException(f"unknown function call {func_name}",
|
||||
line)
|
||||
func_format = function_formats[func_name]
|
||||
if len(args) != func_format.total_arguments():
|
||||
raise CompileException(
|
||||
f"incorrect number of arguments for function call {func_name}", line)
|
||||
|
||||
# Finally check there are matching circuits and contracts
|
||||
all_names = set(self.circuits.keys()) | set(self.contracts.keys())
|
||||
|
||||
for name in all_names:
|
||||
if name not in self.contracts:
|
||||
raise CompileException(f"missing contract for {name}", None)
|
||||
if name not in self.circuits:
|
||||
raise CompileException(f"missing circuit for {name}", None)
|
||||
|
||||
def format_data(self):
|
||||
schema = []
|
||||
for name, circuit in self.circuits.items():
|
||||
assert name in self.contracts
|
||||
contract = self.contracts[name]
|
||||
|
||||
witness = []
|
||||
for line in contract:
|
||||
assert len(line.tokens) == 2
|
||||
type_name, variable = line.tokens
|
||||
assert type_name in allowed_types
|
||||
type_id = allowed_types[type_name]
|
||||
witness.append((type_id, variable, line))
|
||||
|
||||
code = []
|
||||
for line in circuit:
|
||||
assert len(line.tokens) > 0
|
||||
func_name, args = line.tokens[0], line.tokens[1:]
|
||||
assert func_name in function_formats
|
||||
func_format = function_formats[func_name]
|
||||
assert len(args) == func_format.total_arguments()
|
||||
|
||||
return_values = []
|
||||
if func_format.return_type_ids:
|
||||
rv_len = len(func_format.return_type_ids)
|
||||
return_values, args = args[:rv_len], args[rv_len:]
|
||||
|
||||
func_id = func_format.func_id
|
||||
code.append((func_format, return_values, args, line))
|
||||
|
||||
schema.append((name, witness, code))
|
||||
return schema
|
||||
|
||||
def trace_circuits(self, schema):
|
||||
for name, witness, code in schema:
|
||||
tracer = DynamicTracer(name, witness, code, self.constants)
|
||||
tracer.execute()
|
||||
|
||||
class DynamicTracer:
|
||||
|
||||
def __init__(self, name, contract_witness, circuit_code, constants):
|
||||
self.name = name
|
||||
self.witness = contract_witness
|
||||
self.code = circuit_code
|
||||
self.constants = constants
|
||||
|
||||
def execute(self):
|
||||
stack = {}
|
||||
|
||||
# Load constants
|
||||
for variable in self.constants.variables():
|
||||
stack[variable] = self.constants.lookup(variable)
|
||||
|
||||
# Preload stack with our witness values
|
||||
for type_id, variable, line in self.witness:
|
||||
stack[variable] = type_id
|
||||
|
||||
for i, (func_format, return_values, args, code_line) \
|
||||
in enumerate(self.code):
|
||||
|
||||
assert len(args) == len(func_format.param_types)
|
||||
for variable, type_id in zip(args, func_format.param_types):
|
||||
if variable not in stack:
|
||||
raise CompileException(
|
||||
f"variable '{variable}' is not defined", code_line)
|
||||
|
||||
stack_type_id = stack[variable]
|
||||
if stack_type_id != type_id:
|
||||
type_name = type_id_to_name[type]
|
||||
stack_type_name = type_id_to_name[stack_type]
|
||||
raise CompileException(
|
||||
f"variable '{variable}' has incorrect type. "
|
||||
f"Found {type_name} but expected variable of "
|
||||
f"type {stack_type_name}", code_line)
|
||||
|
||||
assert len(return_values) == len(func_format.return_type_ids)
|
||||
|
||||
for return_variable, return_type_id \
|
||||
in zip(return_values, func_format.return_type_ids):
|
||||
|
||||
# Note that later variables shadow earlier ones.
|
||||
# We accept this.
|
||||
|
||||
stack[return_variable] = return_type_id
|
||||
|
||||
class CodeLine:
|
||||
|
||||
def __init__(self, func_format, return_values, args, arg_idxs, code_line):
|
||||
self.func_format = func_format
|
||||
self.return_values = return_values
|
||||
self.args = args
|
||||
self.arg_idxs = arg_idxs
|
||||
self.code_line = code_line
|
||||
|
||||
def func_name(self):
|
||||
return func_id_to_name[self.func_format.func_id]
|
||||
|
||||
class CompiledContract:
|
||||
|
||||
def __init__(self, name, witness, code):
|
||||
self.name = name
|
||||
self.witness = witness
|
||||
self.code = code
|
||||
|
||||
class Compiler:
|
||||
|
||||
def __init__(self, witness, uncompiled_code, constants):
|
||||
self.witness = witness
|
||||
self.uncompiled_code = uncompiled_code
|
||||
self.constants = constants
|
||||
|
||||
def compile(self):
|
||||
code = []
|
||||
|
||||
# Each unique type_id has its own stack
|
||||
stacks = [[] for i in range(TYPE_ID_LAST)]
|
||||
# Map from variable name to stacks above
|
||||
stack_vars = {}
|
||||
|
||||
def alloc(variable, type_id):
|
||||
assert type_id <= len(stacks)
|
||||
idx = len(stacks[type_id])
|
||||
# Add variable to the stack for its type_id
|
||||
stacks[type_id].append(variable)
|
||||
# Create mapping from variable name
|
||||
stack_vars[variable] = (type_id, idx)
|
||||
|
||||
# Load constants
|
||||
for variable in self.constants.variables():
|
||||
type_id = self.constants.lookup(variable)
|
||||
alloc(variable, type_id)
|
||||
|
||||
# Preload stack with our witness values
|
||||
for type_id, variable, line in self.witness:
|
||||
alloc(variable, type_id)
|
||||
|
||||
for i, (func_format, return_values, args, code_line) \
|
||||
in enumerate(self.uncompiled_code):
|
||||
|
||||
assert len(args) == len(func_format.param_types)
|
||||
|
||||
arg_idxs = []
|
||||
|
||||
# Loop through all arguments
|
||||
for variable, type_id in zip(args, func_format.param_types):
|
||||
assert type_id <= len(stacks)
|
||||
assert variable in stack_vars
|
||||
# Find the index for the M by N matrix of our variable
|
||||
loc_type_id, loc_idx = stack_vars[variable]
|
||||
assert type_id == loc_type_id
|
||||
assert stacks[loc_type_id][loc_idx] == variable
|
||||
|
||||
# This is the info to be serialized, not the variable names
|
||||
arg_idxs.append(loc_idx)
|
||||
|
||||
assert len(return_values) == len(func_format.return_type_ids)
|
||||
|
||||
for return_variable, return_type_id \
|
||||
in zip(return_values, func_format.return_type_ids):
|
||||
|
||||
# Allocate returned values so they can be used by
|
||||
# subsequent function calls.
|
||||
alloc(return_variable, return_type_id)
|
||||
|
||||
code.append(CodeLine(func_format, return_values, args,
|
||||
arg_idxs, code_line))
|
||||
|
||||
return code
|
||||
|
||||
class Line:
|
||||
|
||||
def __init__(self, tokens, original_line, number):
|
||||
self.tokens = tokens
|
||||
self.orig = original_line
|
||||
self.number = number
|
||||
|
||||
def __repr__(self):
|
||||
return f"Line({self.number}: {str(self.tokens)})"
|
||||
|
||||
def load(src_file):
|
||||
source = []
|
||||
for i, original_line in enumerate(src_file):
|
||||
# Remove whitespace on both sides
|
||||
line = original_line.strip()
|
||||
# Strip out comments
|
||||
line = line.split("#")[0]
|
||||
# Split at whitespace
|
||||
line = line.split()
|
||||
if not line:
|
||||
continue
|
||||
line_number = i + 1
|
||||
source.append(Line(line, original_line, line_number))
|
||||
return source
|
||||
|
||||
def parse(source):
|
||||
syntax = SyntaxStruct()
|
||||
it = iter(source)
|
||||
while True:
|
||||
try:
|
||||
line = next(it)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
assert len(line.tokens) > 0
|
||||
if line.tokens[0] == "contract":
|
||||
syntax.parse_contract(line, it)
|
||||
elif line.tokens[0] == "circuit":
|
||||
syntax.parse_circuit(line, it)
|
||||
elif line.tokens[0] == "constant":
|
||||
syntax.parse_constant(line)
|
||||
elif line.tokens[0] == "}":
|
||||
raise CompileException("unmatched delimiter '}'", line)
|
||||
return syntax
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("SOURCE", help="ZK script to compile")
|
||||
parser.add_argument("--output", default=None, help="output file")
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument('--display', action='store_true',
|
||||
help="show the compiled code in human readable format")
|
||||
group.add_argument('--bincode', action='store_true',
|
||||
help="output compiled code to zkvm supervisor")
|
||||
args = parser.parse_args()
|
||||
|
||||
with open(args.SOURCE, "r") as src_file:
|
||||
source = load(src_file)
|
||||
try:
|
||||
syntax = parse(source)
|
||||
schema = syntax.verify()
|
||||
contracts = []
|
||||
for name, witness, uncompiled_code in schema:
|
||||
compiler = Compiler(witness, uncompiled_code, syntax.constants)
|
||||
code = compiler.compile()
|
||||
contracts.append(CompiledContract(name, witness, code))
|
||||
constants = syntax.constants
|
||||
if args.display:
|
||||
from zkas.text_output import output
|
||||
if args.output is None:
|
||||
output(sys.stdout, contracts, constants)
|
||||
else:
|
||||
with open(outpath, "w") as file:
|
||||
output(file, contracts, constants)
|
||||
elif args.bincode:
|
||||
from zkas.bincode_output import output
|
||||
outpath = args.output
|
||||
if args.output is None:
|
||||
outpath = args.SOURCE + ".bin"
|
||||
with open(outpath, "wb") as file:
|
||||
output(file, contracts, constants)
|
||||
else:
|
||||
from zkas.text_output import output
|
||||
if args.output is None:
|
||||
output(sys.stdout, contracts, constants)
|
||||
else:
|
||||
with open(outpath, "w") as file:
|
||||
output(file, contracts, constants)
|
||||
except CompileException as ex:
|
||||
print(f"Error: {ex.error_message}", file=sys.stderr)
|
||||
if ex.line is not None:
|
||||
print(f"Line {ex.line.number}: {ex.line.orig}", file=sys.stderr)
|
||||
#return -1
|
||||
raise
|
||||
return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
# todo: think about extendable payment scheme which
|
||||
# is like bitcoin soft forks
|
||||
0
script/zkas/__init__.py
Normal file
0
script/zkas/__init__.py
Normal file
50
script/zkas/bincode_output.py
Normal file
50
script/zkas/bincode_output.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import struct
|
||||
|
||||
def varuint(value):
|
||||
if value <= 0xfc:
|
||||
return struct.pack("<B", value)
|
||||
elif value <= 0xffff:
|
||||
return struct.pack("<BH", 0xfd, value)
|
||||
elif value <= 0xffffffff:
|
||||
return struct.pack("<BI", 0xfe, value)
|
||||
else:
|
||||
return struct.pack("<BQ", 0xff, value)
|
||||
|
||||
def write_len(output, objects):
|
||||
output.write(varuint(len(objects)))
|
||||
|
||||
def write_value(fmt, output, value):
|
||||
value_bytes = struct.pack("<" + fmt, value)
|
||||
output.write(value_bytes)
|
||||
|
||||
def write_u8(output, value):
|
||||
write_value("B", output, value)
|
||||
def write_u32(output, value):
|
||||
write_value("I", output, value)
|
||||
|
||||
def output_contract(output, contract):
|
||||
write_len(output, contract.name)
|
||||
output.write(contract.name.encode())
|
||||
write_len(output, contract.witness)
|
||||
for type_id, variable, _ in contract.witness:
|
||||
write_len(output, variable)
|
||||
output.write(variable.encode())
|
||||
write_u8(output, type_id)
|
||||
write_len(output, contract.code)
|
||||
for code in contract.code:
|
||||
func_id = code.func_format.func_id
|
||||
write_u8(output, func_id)
|
||||
for arg_idx in code.arg_idxs:
|
||||
write_u32(output, arg_idx)
|
||||
|
||||
def output(output, contracts, constants):
|
||||
write_len(output, constants.variables())
|
||||
for variable in constants.variables():
|
||||
write_len(output, variable)
|
||||
output.write(variable.encode())
|
||||
type_id = constants.lookup(variable)
|
||||
write_u8(output, type_id)
|
||||
write_len(output, contracts)
|
||||
for contract in contracts:
|
||||
output_contract(output, contract)
|
||||
|
||||
21
script/zkas/text_output.py
Normal file
21
script/zkas/text_output.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from .types import type_id_to_name, func_id_to_name
|
||||
|
||||
def output(output, contracts, constants):
|
||||
output.write("Constants:\n")
|
||||
for variable in constants.variables():
|
||||
type_id = constants.lookup(variable)
|
||||
output.write(f" {type_id} {variable}\n")
|
||||
for contract in contracts:
|
||||
output.write(f"{contract.name}:\n")
|
||||
|
||||
output.write(f" Witness:\n")
|
||||
for type_id, variable, _ in contract.witness:
|
||||
type_name = type_id_to_name[type_id]
|
||||
output.write(f" {type_name} {variable}\n")
|
||||
|
||||
output.write(f" Code:\n")
|
||||
for code in contract.code:
|
||||
output.write(f" # args = {code.args}\n")
|
||||
output.write(f" {code.func_name()} {code.return_values} "
|
||||
f"{code.arg_idxs}\n")
|
||||
|
||||
69
script/zkas/types.py
Normal file
69
script/zkas/types.py
Normal file
@@ -0,0 +1,69 @@
|
||||
TYPE_ID_BASE = 0
|
||||
TYPE_ID_SCALAR = 1
|
||||
TYPE_ID_EC_POINT = 2
|
||||
TYPE_ID_EC_FIXED_POINT = 3
|
||||
# This is so we know the number of TYPE_ID stacks
|
||||
TYPE_ID_LAST = 4
|
||||
|
||||
allowed_types = {
|
||||
"Base": TYPE_ID_BASE,
|
||||
"Scalar": TYPE_ID_SCALAR,
|
||||
"EcFixedPoint": TYPE_ID_EC_FIXED_POINT
|
||||
}
|
||||
# Used for debug and error messages
|
||||
type_id_to_name = dict((value, key) for key, value in allowed_types.items())
|
||||
|
||||
FUNC_ID_POSEIDON_HASH = 0
|
||||
FUNC_ID_ADD = 1
|
||||
FUNC_ID_CONSTRAIN_INSTANCE = 2
|
||||
FUNC_ID_EC_MUL_SHORT = 3
|
||||
FUNC_ID_EC_MUL = 4
|
||||
FUNC_ID_EC_ADD = 5
|
||||
FUNC_ID_EC_GET_X = 6
|
||||
FUNC_ID_EC_GET_Y = 7
|
||||
|
||||
class FuncFormat:
|
||||
|
||||
def __init__(self, func_id, return_type_ids, param_types):
|
||||
self.func_id = func_id
|
||||
self.return_type_ids = return_type_ids
|
||||
self.param_types = param_types
|
||||
|
||||
def total_arguments(self):
|
||||
return len(self.return_type_ids) + len(self.param_types)
|
||||
|
||||
function_formats = {
|
||||
"poseidon_hash": FuncFormat(
|
||||
# Funcion ID Type ID Parameter types
|
||||
FUNC_ID_POSEIDON_HASH, [TYPE_ID_BASE], [TYPE_ID_BASE,
|
||||
TYPE_ID_BASE]
|
||||
),
|
||||
"add": FuncFormat(
|
||||
FUNC_ID_ADD, [TYPE_ID_BASE], [TYPE_ID_BASE,
|
||||
TYPE_ID_BASE]
|
||||
),
|
||||
"constrain_instance": FuncFormat(
|
||||
FUNC_ID_CONSTRAIN_INSTANCE, [], [TYPE_ID_BASE]
|
||||
),
|
||||
"ec_mul_short": FuncFormat(
|
||||
FUNC_ID_EC_MUL_SHORT, [TYPE_ID_EC_POINT], [TYPE_ID_BASE,
|
||||
TYPE_ID_EC_FIXED_POINT]
|
||||
),
|
||||
"ec_mul": FuncFormat(
|
||||
FUNC_ID_EC_MUL, [TYPE_ID_EC_POINT], [TYPE_ID_SCALAR,
|
||||
TYPE_ID_EC_FIXED_POINT]
|
||||
),
|
||||
"ec_add": FuncFormat(
|
||||
FUNC_ID_EC_ADD, [TYPE_ID_EC_POINT], [TYPE_ID_EC_POINT,
|
||||
TYPE_ID_EC_POINT]
|
||||
),
|
||||
"ec_get_x": FuncFormat(
|
||||
FUNC_ID_EC_GET_X, [TYPE_ID_BASE], [TYPE_ID_EC_POINT]
|
||||
),
|
||||
"ec_get_y": FuncFormat(
|
||||
FUNC_ID_EC_GET_Y, [TYPE_ID_BASE], [TYPE_ID_EC_POINT]
|
||||
),
|
||||
}
|
||||
|
||||
func_id_to_name = dict((fmt.func_id, key) for key, fmt
|
||||
in function_formats.items())
|
||||
Reference in New Issue
Block a user