migrate over zkas

This commit is contained in:
narodnik
2021-10-30 17:59:31 +02:00
parent 78c325fc2c
commit a19c438cd6
7 changed files with 588 additions and 0 deletions

5
script/to_html.sh Executable file
View File

@@ -0,0 +1,5 @@
#!/bin/bash
#nvim -c ":TOhtml" $1
sed -i "s/PreProc { color: #5fd7ff; }/PreProc { color: #8f2722; }/" $1
sed -i "s/Comment { color: #00ffff; }/Comment { color: #0055ff; }/" $1

41
script/zk.vim Normal file
View File

@@ -0,0 +1,41 @@
"For autoload, add this to your VIM config:
" VIM: .vimrc
" NeoVIM: .config/nvim/init.vim
"
"autocmd BufRead *.pism call SetPismOptions()
"function SetPismOptions()
" set syntax=pism
" source /home/narodnik/src/drk/scripts/pism.vim
"endfunction
if exists('b:current_syntax')
finish
endif
syn keyword drkKeyword constant contract circuit
"syn keyword drkAttr
syn keyword drkType EcFixedPoint Base Scalar
"syn keyword drkFunctionKeyword enforce lc0_add_one lc1_add_one lc2_add_one lc_coeff_reset lc_coeff_double lc0_sub_one lc1_sub_one lc2_sub_one dump_alloc dump_local
syn match drkFunction "^[ ]*[a-z_0-9]* "
syn match drkComment "#.*$"
syn match drkNumber ' \zs\d\+\ze'
syn match drkHexNumber ' \zs0x[a-z0-9]\+\ze'
syn match drkConst '[A-Z_]\{2,}[A-Z0-9_]*'
syn keyword drkBoolVal true false
syn match drkPreproc "{%.*%}"
syn match drkPreproc2 "{{.*}}"
hi def link drkKeyword Statement
"hi def link drkAttr StorageClass
hi def link drkPreproc PreProc
hi def link drkPreproc2 PreProc
hi def link drkType Type
hi def link drkFunction Function
hi def link drkFunctionKeyword Function
hi def link drkComment Comment
hi def link drkNumber Constant
hi def link drkHexNumber Constant
hi def link drkConst Constant
hi def link drkBoolVal Constant
let b:current_syntax = "pism"

402
script/zkas.py Normal file
View File

@@ -0,0 +1,402 @@
import argparse
import sys
from zkas.types import *
class CompileException(Exception):
def __init__(self, error_message, line):
super().__init__(error_message)
self.error_message = error_message
self.line = line
class Constants:
def __init__(self):
self.table = []
self.map = {}
def add(self, variable, type_id):
idx = len(self.table)
self.table.append(type_id)
self.map[variable] = idx
def lookup(self, variable):
idx = self.map[variable]
return self.table[idx]
def variables(self):
return self.map.keys()
class SyntaxStruct:
def __init__(self):
self.contracts = {}
self.circuits = {}
self.constants = Constants()
def parse_contract(self, line, it):
assert line.tokens[0] == "contract"
if len(line.tokens) != 3 or line.tokens[2] != "{":
raise CompileException("malformed contract opening", line)
name = line.tokens[1]
if name in self.contracts:
raise CompileException(f"duplicate contract {name}", line)
lines = []
while True:
try:
line = next(it)
except StopIteration:
raise CompileException(
f"premature end of file while parsing {name} contract", line)
assert len(line.tokens) > 0
if line.tokens[0] == "}":
break
lines.append(line)
self.contracts[name] = lines
def parse_circuit(self, line, it):
assert line.tokens[0] == "circuit"
if len(line.tokens) != 3 or line.tokens[2] != "{":
raise CompileException("malformed circuit opening", line)
name = line.tokens[1]
if name in self.circuits:
raise CompileException(f"duplicate contract {name}", line)
lines = []
while True:
try:
line = next(it)
except StopIteration:
raise CompileException(
f"premature end of file while parsing {name} circuit", line)
assert len(line.tokens) > 0
if line.tokens[0] == "}":
break
lines.append(line)
self.circuits[name] = lines
def parse_constant(self, line):
assert line.tokens[0] == "constant"
if len(line.tokens) != 3:
raise CompileException("malformed constant line", line)
_, type_name, variable = line.tokens
if type_name not in allowed_types:
raise CompileException("unknown type '{type}'", line)
type_id = allowed_types[type_name]
self.constants.add(variable, type_id)
def verify(self):
self.static_checks()
schema = self.format_data()
self.trace_circuits(schema)
return schema
def static_checks(self):
for name, lines in self.contracts.items():
for line in lines:
if len(line.tokens) != 2:
raise CompileException("incorrect number of tokens", line)
type, variable = line.tokens
if type not in allowed_types:
raise CompileException(
f"unknown type specifier for variable {variable}", line)
for name, lines in self.circuits.items():
for line in lines:
assert len(line.tokens) > 0
func_name, args = line.tokens[0], line.tokens[1:]
if func_name not in function_formats:
raise CompileException(f"unknown function call {func_name}",
line)
func_format = function_formats[func_name]
if len(args) != func_format.total_arguments():
raise CompileException(
f"incorrect number of arguments for function call {func_name}", line)
# Finally check there are matching circuits and contracts
all_names = set(self.circuits.keys()) | set(self.contracts.keys())
for name in all_names:
if name not in self.contracts:
raise CompileException(f"missing contract for {name}", None)
if name not in self.circuits:
raise CompileException(f"missing circuit for {name}", None)
def format_data(self):
schema = []
for name, circuit in self.circuits.items():
assert name in self.contracts
contract = self.contracts[name]
witness = []
for line in contract:
assert len(line.tokens) == 2
type_name, variable = line.tokens
assert type_name in allowed_types
type_id = allowed_types[type_name]
witness.append((type_id, variable, line))
code = []
for line in circuit:
assert len(line.tokens) > 0
func_name, args = line.tokens[0], line.tokens[1:]
assert func_name in function_formats
func_format = function_formats[func_name]
assert len(args) == func_format.total_arguments()
return_values = []
if func_format.return_type_ids:
rv_len = len(func_format.return_type_ids)
return_values, args = args[:rv_len], args[rv_len:]
func_id = func_format.func_id
code.append((func_format, return_values, args, line))
schema.append((name, witness, code))
return schema
def trace_circuits(self, schema):
for name, witness, code in schema:
tracer = DynamicTracer(name, witness, code, self.constants)
tracer.execute()
class DynamicTracer:
def __init__(self, name, contract_witness, circuit_code, constants):
self.name = name
self.witness = contract_witness
self.code = circuit_code
self.constants = constants
def execute(self):
stack = {}
# Load constants
for variable in self.constants.variables():
stack[variable] = self.constants.lookup(variable)
# Preload stack with our witness values
for type_id, variable, line in self.witness:
stack[variable] = type_id
for i, (func_format, return_values, args, code_line) \
in enumerate(self.code):
assert len(args) == len(func_format.param_types)
for variable, type_id in zip(args, func_format.param_types):
if variable not in stack:
raise CompileException(
f"variable '{variable}' is not defined", code_line)
stack_type_id = stack[variable]
if stack_type_id != type_id:
type_name = type_id_to_name[type]
stack_type_name = type_id_to_name[stack_type]
raise CompileException(
f"variable '{variable}' has incorrect type. "
f"Found {type_name} but expected variable of "
f"type {stack_type_name}", code_line)
assert len(return_values) == len(func_format.return_type_ids)
for return_variable, return_type_id \
in zip(return_values, func_format.return_type_ids):
# Note that later variables shadow earlier ones.
# We accept this.
stack[return_variable] = return_type_id
class CodeLine:
def __init__(self, func_format, return_values, args, arg_idxs, code_line):
self.func_format = func_format
self.return_values = return_values
self.args = args
self.arg_idxs = arg_idxs
self.code_line = code_line
def func_name(self):
return func_id_to_name[self.func_format.func_id]
class CompiledContract:
def __init__(self, name, witness, code):
self.name = name
self.witness = witness
self.code = code
class Compiler:
def __init__(self, witness, uncompiled_code, constants):
self.witness = witness
self.uncompiled_code = uncompiled_code
self.constants = constants
def compile(self):
code = []
# Each unique type_id has its own stack
stacks = [[] for i in range(TYPE_ID_LAST)]
# Map from variable name to stacks above
stack_vars = {}
def alloc(variable, type_id):
assert type_id <= len(stacks)
idx = len(stacks[type_id])
# Add variable to the stack for its type_id
stacks[type_id].append(variable)
# Create mapping from variable name
stack_vars[variable] = (type_id, idx)
# Load constants
for variable in self.constants.variables():
type_id = self.constants.lookup(variable)
alloc(variable, type_id)
# Preload stack with our witness values
for type_id, variable, line in self.witness:
alloc(variable, type_id)
for i, (func_format, return_values, args, code_line) \
in enumerate(self.uncompiled_code):
assert len(args) == len(func_format.param_types)
arg_idxs = []
# Loop through all arguments
for variable, type_id in zip(args, func_format.param_types):
assert type_id <= len(stacks)
assert variable in stack_vars
# Find the index for the M by N matrix of our variable
loc_type_id, loc_idx = stack_vars[variable]
assert type_id == loc_type_id
assert stacks[loc_type_id][loc_idx] == variable
# This is the info to be serialized, not the variable names
arg_idxs.append(loc_idx)
assert len(return_values) == len(func_format.return_type_ids)
for return_variable, return_type_id \
in zip(return_values, func_format.return_type_ids):
# Allocate returned values so they can be used by
# subsequent function calls.
alloc(return_variable, return_type_id)
code.append(CodeLine(func_format, return_values, args,
arg_idxs, code_line))
return code
class Line:
def __init__(self, tokens, original_line, number):
self.tokens = tokens
self.orig = original_line
self.number = number
def __repr__(self):
return f"Line({self.number}: {str(self.tokens)})"
def load(src_file):
source = []
for i, original_line in enumerate(src_file):
# Remove whitespace on both sides
line = original_line.strip()
# Strip out comments
line = line.split("#")[0]
# Split at whitespace
line = line.split()
if not line:
continue
line_number = i + 1
source.append(Line(line, original_line, line_number))
return source
def parse(source):
syntax = SyntaxStruct()
it = iter(source)
while True:
try:
line = next(it)
except StopIteration:
break
assert len(line.tokens) > 0
if line.tokens[0] == "contract":
syntax.parse_contract(line, it)
elif line.tokens[0] == "circuit":
syntax.parse_circuit(line, it)
elif line.tokens[0] == "constant":
syntax.parse_constant(line)
elif line.tokens[0] == "}":
raise CompileException("unmatched delimiter '}'", line)
return syntax
def main():
parser = argparse.ArgumentParser()
parser.add_argument("SOURCE", help="ZK script to compile")
parser.add_argument("--output", default=None, help="output file")
group = parser.add_mutually_exclusive_group()
group.add_argument('--display', action='store_true',
help="show the compiled code in human readable format")
group.add_argument('--bincode', action='store_true',
help="output compiled code to zkvm supervisor")
args = parser.parse_args()
with open(args.SOURCE, "r") as src_file:
source = load(src_file)
try:
syntax = parse(source)
schema = syntax.verify()
contracts = []
for name, witness, uncompiled_code in schema:
compiler = Compiler(witness, uncompiled_code, syntax.constants)
code = compiler.compile()
contracts.append(CompiledContract(name, witness, code))
constants = syntax.constants
if args.display:
from zkas.text_output import output
if args.output is None:
output(sys.stdout, contracts, constants)
else:
with open(outpath, "w") as file:
output(file, contracts, constants)
elif args.bincode:
from zkas.bincode_output import output
outpath = args.output
if args.output is None:
outpath = args.SOURCE + ".bin"
with open(outpath, "wb") as file:
output(file, contracts, constants)
else:
from zkas.text_output import output
if args.output is None:
output(sys.stdout, contracts, constants)
else:
with open(outpath, "w") as file:
output(file, contracts, constants)
except CompileException as ex:
print(f"Error: {ex.error_message}", file=sys.stderr)
if ex.line is not None:
print(f"Line {ex.line.number}: {ex.line.orig}", file=sys.stderr)
#return -1
raise
return 0
if __name__ == "__main__":
sys.exit(main())
# todo: think about extendable payment scheme which
# is like bitcoin soft forks

0
script/zkas/__init__.py Normal file
View File

View File

@@ -0,0 +1,50 @@
import struct
def varuint(value):
if value <= 0xfc:
return struct.pack("<B", value)
elif value <= 0xffff:
return struct.pack("<BH", 0xfd, value)
elif value <= 0xffffffff:
return struct.pack("<BI", 0xfe, value)
else:
return struct.pack("<BQ", 0xff, value)
def write_len(output, objects):
output.write(varuint(len(objects)))
def write_value(fmt, output, value):
value_bytes = struct.pack("<" + fmt, value)
output.write(value_bytes)
def write_u8(output, value):
write_value("B", output, value)
def write_u32(output, value):
write_value("I", output, value)
def output_contract(output, contract):
write_len(output, contract.name)
output.write(contract.name.encode())
write_len(output, contract.witness)
for type_id, variable, _ in contract.witness:
write_len(output, variable)
output.write(variable.encode())
write_u8(output, type_id)
write_len(output, contract.code)
for code in contract.code:
func_id = code.func_format.func_id
write_u8(output, func_id)
for arg_idx in code.arg_idxs:
write_u32(output, arg_idx)
def output(output, contracts, constants):
write_len(output, constants.variables())
for variable in constants.variables():
write_len(output, variable)
output.write(variable.encode())
type_id = constants.lookup(variable)
write_u8(output, type_id)
write_len(output, contracts)
for contract in contracts:
output_contract(output, contract)

View File

@@ -0,0 +1,21 @@
from .types import type_id_to_name, func_id_to_name
def output(output, contracts, constants):
output.write("Constants:\n")
for variable in constants.variables():
type_id = constants.lookup(variable)
output.write(f" {type_id} {variable}\n")
for contract in contracts:
output.write(f"{contract.name}:\n")
output.write(f" Witness:\n")
for type_id, variable, _ in contract.witness:
type_name = type_id_to_name[type_id]
output.write(f" {type_name} {variable}\n")
output.write(f" Code:\n")
for code in contract.code:
output.write(f" # args = {code.args}\n")
output.write(f" {code.func_name()} {code.return_values} "
f"{code.arg_idxs}\n")

69
script/zkas/types.py Normal file
View File

@@ -0,0 +1,69 @@
TYPE_ID_BASE = 0
TYPE_ID_SCALAR = 1
TYPE_ID_EC_POINT = 2
TYPE_ID_EC_FIXED_POINT = 3
# This is so we know the number of TYPE_ID stacks
TYPE_ID_LAST = 4
allowed_types = {
"Base": TYPE_ID_BASE,
"Scalar": TYPE_ID_SCALAR,
"EcFixedPoint": TYPE_ID_EC_FIXED_POINT
}
# Used for debug and error messages
type_id_to_name = dict((value, key) for key, value in allowed_types.items())
FUNC_ID_POSEIDON_HASH = 0
FUNC_ID_ADD = 1
FUNC_ID_CONSTRAIN_INSTANCE = 2
FUNC_ID_EC_MUL_SHORT = 3
FUNC_ID_EC_MUL = 4
FUNC_ID_EC_ADD = 5
FUNC_ID_EC_GET_X = 6
FUNC_ID_EC_GET_Y = 7
class FuncFormat:
def __init__(self, func_id, return_type_ids, param_types):
self.func_id = func_id
self.return_type_ids = return_type_ids
self.param_types = param_types
def total_arguments(self):
return len(self.return_type_ids) + len(self.param_types)
function_formats = {
"poseidon_hash": FuncFormat(
# Funcion ID Type ID Parameter types
FUNC_ID_POSEIDON_HASH, [TYPE_ID_BASE], [TYPE_ID_BASE,
TYPE_ID_BASE]
),
"add": FuncFormat(
FUNC_ID_ADD, [TYPE_ID_BASE], [TYPE_ID_BASE,
TYPE_ID_BASE]
),
"constrain_instance": FuncFormat(
FUNC_ID_CONSTRAIN_INSTANCE, [], [TYPE_ID_BASE]
),
"ec_mul_short": FuncFormat(
FUNC_ID_EC_MUL_SHORT, [TYPE_ID_EC_POINT], [TYPE_ID_BASE,
TYPE_ID_EC_FIXED_POINT]
),
"ec_mul": FuncFormat(
FUNC_ID_EC_MUL, [TYPE_ID_EC_POINT], [TYPE_ID_SCALAR,
TYPE_ID_EC_FIXED_POINT]
),
"ec_add": FuncFormat(
FUNC_ID_EC_ADD, [TYPE_ID_EC_POINT], [TYPE_ID_EC_POINT,
TYPE_ID_EC_POINT]
),
"ec_get_x": FuncFormat(
FUNC_ID_EC_GET_X, [TYPE_ID_BASE], [TYPE_ID_EC_POINT]
),
"ec_get_y": FuncFormat(
FUNC_ID_EC_GET_Y, [TYPE_ID_BASE], [TYPE_ID_EC_POINT]
),
}
func_id_to_name = dict((fmt.func_id, key) for key, fmt
in function_formats.items())