diff --git a/script/to_html.sh b/script/to_html.sh
new file mode 100755
index 000000000..0e377a668
--- /dev/null
+++ b/script/to_html.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+#nvim -c ":TOhtml" $1
+sed -i "s/PreProc { color: #5fd7ff; }/PreProc { color: #8f2722; }/" $1
+sed -i "s/Comment { color: #00ffff; }/Comment { color: #0055ff; }/" $1
+
diff --git a/script/zk.vim b/script/zk.vim
new file mode 100644
index 000000000..4c5b43d59
--- /dev/null
+++ b/script/zk.vim
@@ -0,0 +1,41 @@
+"For autoload, add this to your VIM config:
+" VIM: .vimrc
+" NeoVIM: .config/nvim/init.vim
+"
+"autocmd BufRead *.pism call SetPismOptions()
+"function SetPismOptions()
+" set syntax=pism
+" source /home/narodnik/src/drk/scripts/pism.vim
+"endfunction
+
+if exists('b:current_syntax')
+ finish
+endif
+
+syn keyword drkKeyword constant contract circuit
+"syn keyword drkAttr
+syn keyword drkType EcFixedPoint Base Scalar
+"syn keyword drkFunctionKeyword enforce lc0_add_one lc1_add_one lc2_add_one lc_coeff_reset lc_coeff_double lc0_sub_one lc1_sub_one lc2_sub_one dump_alloc dump_local
+syn match drkFunction "^[ ]*[a-z_0-9]* "
+syn match drkComment "#.*$"
+syn match drkNumber ' \zs\d\+\ze'
+syn match drkHexNumber ' \zs0x[a-z0-9]\+\ze'
+syn match drkConst '[A-Z_]\{2,}[A-Z0-9_]*'
+syn keyword drkBoolVal true false
+syn match drkPreproc "{%.*%}"
+syn match drkPreproc2 "{{.*}}"
+
+hi def link drkKeyword Statement
+"hi def link drkAttr StorageClass
+hi def link drkPreproc PreProc
+hi def link drkPreproc2 PreProc
+hi def link drkType Type
+hi def link drkFunction Function
+hi def link drkFunctionKeyword Function
+hi def link drkComment Comment
+hi def link drkNumber Constant
+hi def link drkHexNumber Constant
+hi def link drkConst Constant
+hi def link drkBoolVal Constant
+
+let b:current_syntax = "pism"
diff --git a/script/zkas.py b/script/zkas.py
new file mode 100644
index 000000000..a729762ad
--- /dev/null
+++ b/script/zkas.py
@@ -0,0 +1,402 @@
+import argparse
+import sys
+
+from zkas.types import *
+
+class CompileException(Exception):
+
+ def __init__(self, error_message, line):
+ super().__init__(error_message)
+ self.error_message = error_message
+ self.line = line
+
+class Constants:
+
+ def __init__(self):
+ self.table = []
+ self.map = {}
+
+ def add(self, variable, type_id):
+ idx = len(self.table)
+ self.table.append(type_id)
+ self.map[variable] = idx
+
+ def lookup(self, variable):
+ idx = self.map[variable]
+ return self.table[idx]
+
+ def variables(self):
+ return self.map.keys()
+
+class SyntaxStruct:
+
+ def __init__(self):
+ self.contracts = {}
+ self.circuits = {}
+ self.constants = Constants()
+
+ def parse_contract(self, line, it):
+ assert line.tokens[0] == "contract"
+ if len(line.tokens) != 3 or line.tokens[2] != "{":
+ raise CompileException("malformed contract opening", line)
+ name = line.tokens[1]
+ if name in self.contracts:
+ raise CompileException(f"duplicate contract {name}", line)
+ lines = []
+
+ while True:
+ try:
+ line = next(it)
+ except StopIteration:
+ raise CompileException(
+ f"premature end of file while parsing {name} contract", line)
+
+ assert len(line.tokens) > 0
+ if line.tokens[0] == "}":
+ break
+
+ lines.append(line)
+
+ self.contracts[name] = lines
+
+ def parse_circuit(self, line, it):
+ assert line.tokens[0] == "circuit"
+ if len(line.tokens) != 3 or line.tokens[2] != "{":
+ raise CompileException("malformed circuit opening", line)
+ name = line.tokens[1]
+ if name in self.circuits:
+ raise CompileException(f"duplicate contract {name}", line)
+ lines = []
+
+ while True:
+ try:
+ line = next(it)
+ except StopIteration:
+ raise CompileException(
+ f"premature end of file while parsing {name} circuit", line)
+
+ assert len(line.tokens) > 0
+ if line.tokens[0] == "}":
+ break
+
+ lines.append(line)
+
+ self.circuits[name] = lines
+
+ def parse_constant(self, line):
+ assert line.tokens[0] == "constant"
+ if len(line.tokens) != 3:
+ raise CompileException("malformed constant line", line)
+ _, type_name, variable = line.tokens
+ if type_name not in allowed_types:
+ raise CompileException("unknown type '{type}'", line)
+ type_id = allowed_types[type_name]
+ self.constants.add(variable, type_id)
+
+ def verify(self):
+ self.static_checks()
+ schema = self.format_data()
+ self.trace_circuits(schema)
+ return schema
+
+ def static_checks(self):
+ for name, lines in self.contracts.items():
+ for line in lines:
+ if len(line.tokens) != 2:
+ raise CompileException("incorrect number of tokens", line)
+ type, variable = line.tokens
+ if type not in allowed_types:
+ raise CompileException(
+ f"unknown type specifier for variable {variable}", line)
+
+ for name, lines in self.circuits.items():
+ for line in lines:
+ assert len(line.tokens) > 0
+ func_name, args = line.tokens[0], line.tokens[1:]
+ if func_name not in function_formats:
+ raise CompileException(f"unknown function call {func_name}",
+ line)
+ func_format = function_formats[func_name]
+ if len(args) != func_format.total_arguments():
+ raise CompileException(
+ f"incorrect number of arguments for function call {func_name}", line)
+
+ # Finally check there are matching circuits and contracts
+ all_names = set(self.circuits.keys()) | set(self.contracts.keys())
+
+ for name in all_names:
+ if name not in self.contracts:
+ raise CompileException(f"missing contract for {name}", None)
+ if name not in self.circuits:
+ raise CompileException(f"missing circuit for {name}", None)
+
+ def format_data(self):
+ schema = []
+ for name, circuit in self.circuits.items():
+ assert name in self.contracts
+ contract = self.contracts[name]
+
+ witness = []
+ for line in contract:
+ assert len(line.tokens) == 2
+ type_name, variable = line.tokens
+ assert type_name in allowed_types
+ type_id = allowed_types[type_name]
+ witness.append((type_id, variable, line))
+
+ code = []
+ for line in circuit:
+ assert len(line.tokens) > 0
+ func_name, args = line.tokens[0], line.tokens[1:]
+ assert func_name in function_formats
+ func_format = function_formats[func_name]
+ assert len(args) == func_format.total_arguments()
+
+ return_values = []
+ if func_format.return_type_ids:
+ rv_len = len(func_format.return_type_ids)
+ return_values, args = args[:rv_len], args[rv_len:]
+
+ func_id = func_format.func_id
+ code.append((func_format, return_values, args, line))
+
+ schema.append((name, witness, code))
+ return schema
+
+ def trace_circuits(self, schema):
+ for name, witness, code in schema:
+ tracer = DynamicTracer(name, witness, code, self.constants)
+ tracer.execute()
+
+class DynamicTracer:
+
+ def __init__(self, name, contract_witness, circuit_code, constants):
+ self.name = name
+ self.witness = contract_witness
+ self.code = circuit_code
+ self.constants = constants
+
+ def execute(self):
+ stack = {}
+
+ # Load constants
+ for variable in self.constants.variables():
+ stack[variable] = self.constants.lookup(variable)
+
+ # Preload stack with our witness values
+ for type_id, variable, line in self.witness:
+ stack[variable] = type_id
+
+ for i, (func_format, return_values, args, code_line) \
+ in enumerate(self.code):
+
+ assert len(args) == len(func_format.param_types)
+ for variable, type_id in zip(args, func_format.param_types):
+ if variable not in stack:
+ raise CompileException(
+ f"variable '{variable}' is not defined", code_line)
+
+ stack_type_id = stack[variable]
+ if stack_type_id != type_id:
+ type_name = type_id_to_name[type]
+ stack_type_name = type_id_to_name[stack_type]
+ raise CompileException(
+ f"variable '{variable}' has incorrect type. "
+ f"Found {type_name} but expected variable of "
+ f"type {stack_type_name}", code_line)
+
+ assert len(return_values) == len(func_format.return_type_ids)
+
+ for return_variable, return_type_id \
+ in zip(return_values, func_format.return_type_ids):
+
+ # Note that later variables shadow earlier ones.
+ # We accept this.
+
+ stack[return_variable] = return_type_id
+
+class CodeLine:
+
+ def __init__(self, func_format, return_values, args, arg_idxs, code_line):
+ self.func_format = func_format
+ self.return_values = return_values
+ self.args = args
+ self.arg_idxs = arg_idxs
+ self.code_line = code_line
+
+ def func_name(self):
+ return func_id_to_name[self.func_format.func_id]
+
+class CompiledContract:
+
+ def __init__(self, name, witness, code):
+ self.name = name
+ self.witness = witness
+ self.code = code
+
+class Compiler:
+
+ def __init__(self, witness, uncompiled_code, constants):
+ self.witness = witness
+ self.uncompiled_code = uncompiled_code
+ self.constants = constants
+
+ def compile(self):
+ code = []
+
+ # Each unique type_id has its own stack
+ stacks = [[] for i in range(TYPE_ID_LAST)]
+ # Map from variable name to stacks above
+ stack_vars = {}
+
+ def alloc(variable, type_id):
+ assert type_id <= len(stacks)
+ idx = len(stacks[type_id])
+ # Add variable to the stack for its type_id
+ stacks[type_id].append(variable)
+ # Create mapping from variable name
+ stack_vars[variable] = (type_id, idx)
+
+ # Load constants
+ for variable in self.constants.variables():
+ type_id = self.constants.lookup(variable)
+ alloc(variable, type_id)
+
+ # Preload stack with our witness values
+ for type_id, variable, line in self.witness:
+ alloc(variable, type_id)
+
+ for i, (func_format, return_values, args, code_line) \
+ in enumerate(self.uncompiled_code):
+
+ assert len(args) == len(func_format.param_types)
+
+ arg_idxs = []
+
+ # Loop through all arguments
+ for variable, type_id in zip(args, func_format.param_types):
+ assert type_id <= len(stacks)
+ assert variable in stack_vars
+ # Find the index for the M by N matrix of our variable
+ loc_type_id, loc_idx = stack_vars[variable]
+ assert type_id == loc_type_id
+ assert stacks[loc_type_id][loc_idx] == variable
+
+ # This is the info to be serialized, not the variable names
+ arg_idxs.append(loc_idx)
+
+ assert len(return_values) == len(func_format.return_type_ids)
+
+ for return_variable, return_type_id \
+ in zip(return_values, func_format.return_type_ids):
+
+ # Allocate returned values so they can be used by
+ # subsequent function calls.
+ alloc(return_variable, return_type_id)
+
+ code.append(CodeLine(func_format, return_values, args,
+ arg_idxs, code_line))
+
+ return code
+
+class Line:
+
+ def __init__(self, tokens, original_line, number):
+ self.tokens = tokens
+ self.orig = original_line
+ self.number = number
+
+ def __repr__(self):
+ return f"Line({self.number}: {str(self.tokens)})"
+
+def load(src_file):
+ source = []
+ for i, original_line in enumerate(src_file):
+ # Remove whitespace on both sides
+ line = original_line.strip()
+ # Strip out comments
+ line = line.split("#")[0]
+ # Split at whitespace
+ line = line.split()
+ if not line:
+ continue
+ line_number = i + 1
+ source.append(Line(line, original_line, line_number))
+ return source
+
+def parse(source):
+ syntax = SyntaxStruct()
+ it = iter(source)
+ while True:
+ try:
+ line = next(it)
+ except StopIteration:
+ break
+
+ assert len(line.tokens) > 0
+ if line.tokens[0] == "contract":
+ syntax.parse_contract(line, it)
+ elif line.tokens[0] == "circuit":
+ syntax.parse_circuit(line, it)
+ elif line.tokens[0] == "constant":
+ syntax.parse_constant(line)
+ elif line.tokens[0] == "}":
+ raise CompileException("unmatched delimiter '}'", line)
+ return syntax
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("SOURCE", help="ZK script to compile")
+ parser.add_argument("--output", default=None, help="output file")
+ group = parser.add_mutually_exclusive_group()
+ group.add_argument('--display', action='store_true',
+ help="show the compiled code in human readable format")
+ group.add_argument('--bincode', action='store_true',
+ help="output compiled code to zkvm supervisor")
+ args = parser.parse_args()
+
+ with open(args.SOURCE, "r") as src_file:
+ source = load(src_file)
+ try:
+ syntax = parse(source)
+ schema = syntax.verify()
+ contracts = []
+ for name, witness, uncompiled_code in schema:
+ compiler = Compiler(witness, uncompiled_code, syntax.constants)
+ code = compiler.compile()
+ contracts.append(CompiledContract(name, witness, code))
+ constants = syntax.constants
+ if args.display:
+ from zkas.text_output import output
+ if args.output is None:
+ output(sys.stdout, contracts, constants)
+ else:
+ with open(outpath, "w") as file:
+ output(file, contracts, constants)
+ elif args.bincode:
+ from zkas.bincode_output import output
+ outpath = args.output
+ if args.output is None:
+ outpath = args.SOURCE + ".bin"
+ with open(outpath, "wb") as file:
+ output(file, contracts, constants)
+ else:
+ from zkas.text_output import output
+ if args.output is None:
+ output(sys.stdout, contracts, constants)
+ else:
+ with open(outpath, "w") as file:
+ output(file, contracts, constants)
+ except CompileException as ex:
+ print(f"Error: {ex.error_message}", file=sys.stderr)
+ if ex.line is not None:
+ print(f"Line {ex.line.number}: {ex.line.orig}", file=sys.stderr)
+ #return -1
+ raise
+ return 0
+
+if __name__ == "__main__":
+ sys.exit(main())
+
+# todo: think about extendable payment scheme which
+# is like bitcoin soft forks
diff --git a/script/zkas/__init__.py b/script/zkas/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/script/zkas/bincode_output.py b/script/zkas/bincode_output.py
new file mode 100644
index 000000000..4af5873c3
--- /dev/null
+++ b/script/zkas/bincode_output.py
@@ -0,0 +1,50 @@
+import struct
+
+def varuint(value):
+ if value <= 0xfc:
+ return struct.pack("