readded scripts directory

This commit is contained in:
lunar-mining
2021-09-16 11:59:41 +02:00
parent 19c74bfc4f
commit e5a7ad8d1c
71 changed files with 8107 additions and 0 deletions

131
scripts/codegen.py Normal file
View File

@@ -0,0 +1,131 @@
# Functions here are called from pism.py using getattr()
# and the function name as a string.
def witness(line, out, point):
return \
r"""let %s = ecc::EdwardsPoint::witness(
cs.namespace(|| "%s"),
%s.map(jubjub::ExtendedPoint::from))?;""" % (out, line, point)
def assert_not_small_order(line, point):
return '%s.assert_not_small_order(cs.namespace(|| "%s"))?;' % (point, line)
def u64_as_binary_le(line, out, val):
return \
r"""let %s = boolean::u64_into_boolean_vec_le(
cs.namespace(|| "%s"),
%s,
)?;""" % (out, line, val)
def fr_as_binary_le(line, out, fr):
return \
r"""let %s = boolean::field_into_boolean_vec_le(
cs.namespace(|| "%s"), %s)?;""" % (out, line, fr)
def ec_mul_const(line, out, fr, base):
return \
r"""let %s = ecc::fixed_base_multiplication(
cs.namespace(|| "%s"),
&%s,
&%s,
)?;""" % (out, line, base, fr)
def ec_mul(line, out, fr, base):
return 'let %s = %s.mul(cs.namespace(|| "%s"), &%s)?;' % (
out, base, line, fr)
def ec_add(line, out, a, b):
return 'let %s = %s.add(cs.namespace(|| "%s"), &%s)?;' % (out, a, line, b)
def ec_repr(line, out, point):
return 'let %s = %s.repr(cs.namespace(|| "%s"))?;' % (out, point, line)
def ec_get_u(line, out, point):
return "let mut %s = %s.get_u().clone();" % (out, point)
def emit_ec(line, point):
return '%s.inputize(cs.namespace(|| "%s"))?;' % (point, line)
def alloc_binary(line, out):
return "let mut %s = vec![];" % out
def binary_clone(line, out, binary):
return "let mut %s: Vec<_> = %s.iter().cloned().collect();" % (out, binary)
def binary_extend(line, binary, value):
return "%s.extend(%s);" % (binary, value)
def binary_push(line, binary, bit):
return "%s.push(%s);" % (binary, bit)
def binary_truncate(line, binary, size):
return "%s.truncate(%s);" % (binary, size)
def static_assert_binary_size(line, binary, size):
return "assert_eq!(%s.len(), %s);" % (binary, size)
def blake2s(line, out, input, personalization):
return \
r"""let mut %s = blake2s::blake2s(
cs.namespace(|| "%s"),
&%s,
%s,
)?;""" % (out, line, input, personalization)
def pedersen_hash(line, out, input, personalization):
return \
r"""let mut %s = pedersen_hash::pedersen_hash(
cs.namespace(|| "%s"),
%s,
&%s,
)?;""" % (out, line, personalization, input)
def emit_binary(line, binary):
return 'multipack::pack_into_inputs(cs.namespace(|| "%s"), &%s)?;' % (
line, binary)
def alloc_bit(line, out, value):
return \
r"""let %s = boolean::Boolean::from(boolean::AllocatedBit::alloc(
cs.namespace(|| "%s"),
%s
)?);""" % (out, line, value)
def alloc_const_bit(line, out, value):
return "let %s = Boolean::constant(%s);" % (out, value)
def clone_bit(line, out, value):
return "let %s = %s.clone();" % (out, value)
def alloc_scalar(line, out, scalar):
return \
r"""let %s =
num::AllocatedNum::alloc(cs.namespace(|| "%s"), || Ok(*%s.get()?))?;""" % (
out, line, scalar)
def scalar_as_binary(line, out, scalar):
return 'let %s = %s.to_bits_le(cs.namespace(|| "%s"))?;' % (out, scalar,
line)
def emit_scalar(line, scalar):
return '%s.inputize(cs.namespace(|| "%s"))?;' % (scalar, line)
def scalar_enforce_equal(line, scalar_left, scalar_right):
return \
r"""cs.enforce(
|| "%s",
|lc| lc + %s.get_variable(),
|lc| lc + CS::one(),
|lc| lc + %s.get_variable(),
);""" % (line, scalar_left, scalar_right)
def conditionally_reverse(line, out_left, out_right, in_left, in_right,
condition):
return \
r"""let (%s, %s) = num::AllocatedNum::conditionally_reverse(
cs.namespace(|| "%s"),
&%s,
&%s,
&%s,
)?;""" % (out_left, out_right, line, in_left, in_right, condition)

460
scripts/compile.py Normal file
View File

@@ -0,0 +1,460 @@
import argparse
import sys
from enum import Enum
alloc_commands = {
"param": 1,
"private": 1,
"public": 1,
}
op_commands = {
"set": 2,
"mul": 2,
"add": 2,
"sub": 2,
"divide": 2,
"double": 1,
"square": 1,
"invert": 1,
"unpack_bits": 3,
"load": 2,
"local": 1,
"debug": 1,
"dump_alloc": 0,
"dump_local": 0,
}
constraint_commands = {
"lc0_add": 1,
"lc1_add": 1,
"lc2_add": 1,
"lc0_sub": 1,
"lc1_sub": 1,
"lc2_sub": 1,
"lc0_add_one": 0,
"lc1_add_one": 0,
"lc2_add_one": 0,
"lc0_sub_one": 0,
"lc1_sub_one": 0,
"lc2_sub_one": 0,
"lc0_add_coeff": 2,
"lc1_add_coeff": 2,
"lc2_add_coeff": 2,
"lc0_add_constant": 1,
"lc1_add_constant": 1,
"lc2_add_constant": 1,
"enforce": 0,
"lc_coeff_reset": 0,
"lc_coeff_double": 0,
}
def eprint(*args):
print(*args, file=sys.stderr)
class Line:
def __init__(self, text, line_number):
self.text = text
self.orig = text
self.lineno = line_number
self.clean()
def clean(self):
# Remove the comments
self.text = self.text.split("#", 1)[0]
# Remove whitespace
self.text = self.text.strip()
def is_empty(self):
return bool(self.text)
def __repr__(self):
return "Line %s: %s" % (self.lineno, self.orig.lstrip())
def command(self):
if not self.is_empty():
return None
return self.text.split(" ")[0]
def args(self):
if not self.is_empty():
return None
return self.text.split()[1:]
def clean(contents):
# Split input into lines
contents = contents.split("\n")
contents = [Line(line, i + 1) for i, line in enumerate(contents)]
# Remove empty blank lines
contents = [line for line in contents if line.is_empty()]
return contents
def divide_sections(contents):
state = "NOSCOPE"
segments = {}
current_segment = []
contract_name = None
for line in contents:
if line.command() == "contract":
if len(line.args()) != 1:
eprint("error: missing contract name")
eprint(line)
return None
contract_name = line.args()[0]
if state == "NOSCOPE":
assert not current_segment
state = "INSCOPE"
continue
else:
assert state == "INSCOPE"
eprint("error: double contract entry violation")
eprint(line)
return None
elif line.command() == "end":
if len(line.args()) != 0:
eprint("error: end takes no args")
eprint(line)
return None
if state == "NOSCOPE":
eprint("error: missing contract start for end")
eprint(line)
return None
else:
assert state == "INSCOPE"
state = "NOSCOPE"
segments[contract_name] = current_segment
current_segment = []
continue
elif state == "NOSCOPE":
# Ignore lines outside any contract
continue
current_segment.append(line)
if state != "NOSCOPE":
eprint("error: reached end of file with unclosed scope")
return None
return segments
def extract_relevant_lines(contract, commands_table):
relevant_lines = []
for line in contract:
command = line.command()
if command not in commands_table.keys():
continue
define = commands_table[command]
if len(line.args()) != define:
eprint("error: wrong number of args")
return None
relevant_lines.append(line)
return relevant_lines
class VariableType(Enum):
PUBLIC = 1
PRIVATE = 2
class Variable:
def __init__(self, symbol, index, type, is_param):
self.symbol = symbol
self.index = index
self.type = type
self.is_param = is_param
def __repr__(self):
return "<Variable %s:%s>" % (self.symbol, self.index)
def generate_alloc_table(contract):
relevant_lines = extract_relevant_lines(contract, alloc_commands)
alloc_table = {}
for i, line in enumerate(relevant_lines):
assert len(line.args()) == 1
symbol = line.args()[0]
command = line.command()
if command == "param":
type = VariableType.PRIVATE
is_param = True
elif command == "private":
type = VariableType.PRIVATE
is_param = False
elif command == "public":
type = VariableType.PUBLIC
is_param = False
else:
assert False
if symbol in alloc_table:
eprint("error: duplicate symbol '%s'" % symbol)
eprint(line)
return None
alloc_table[symbol] = Variable(symbol, i, type, is_param)
return alloc_table
class Operation:
def __init__(self, line, indexes):
self.command = line.command()
self.args = indexes
self.line = line
class VariableRefType(Enum):
AUX = 1
LOCAL = 2
CONST = 3
class VariableRef:
def __init__(self, type, index):
self.type = type
self.index = index
def __repr__(self):
return "%s(%s)" % (self.type.name, self.index)
def symbols_list_to_refs(line, alloc, local_vars, constants):
indexes = []
for symbol in line.args():
if symbol in alloc:
# Lookup variable index
index = alloc[symbol].index
index = VariableRef(VariableRefType.AUX, index)
elif symbol in local_vars:
index = local_vars[symbol]
index = VariableRef(VariableRefType.LOCAL, index)
elif symbol in constants:
index = constants[symbol][0]
index = VariableRef(VariableRefType.CONST, index)
else:
eprint("error: missing unallocated symbol '%s'" % symbol)
eprint(line)
return None
indexes.append(index)
return indexes
def generate_ops_table(contract, alloc, constants):
relevant_lines = extract_relevant_lines(contract, op_commands)
ops = []
local_vars = {}
for line in relevant_lines:
# This is a special case which creates a new local stack value
if line.command() == "local":
assert len(line.args()) == 1
symbol = line.args()[0]
local_vars[symbol] = len(local_vars)
indexes = []
else:
if (indexes := symbols_list_to_refs(line, alloc,
local_vars, constants)) is None:
return None
# Handle this here directly since only the
# load command deals with constants
if line.command() == "load":
assert len(indexes) == 2
# This is the only command which uses consts
if indexes[1].type != VariableRefType.CONST:
eprint("error: load command takes a const argument")
eprint(line)
return None
elif any(index.type == VariableRefType.CONST for index in indexes):
eprint("error: invalid const arg")
eprint(line)
return None
ops.append(Operation(line, indexes))
return ops
class Constraint:
def __init__(self, line, lcargs):
self.command = line.command()
self.args = lcargs
self.line = line
def args_comment(self):
return ", ".join("%s" % symbol for symbol in self.line.args())
def symbols_list_to_lcargs(line, alloc, constants):
lcargs = []
for symbol in line.args():
if symbol in alloc:
# Lookup variable index
index = alloc[symbol].index
lcargs.append(index)
elif symbol in constants:
value = constants[symbol]
lcargs.append(value)
else:
eprint("error: missing unallocated symbol '%s'" % symbol)
eprint(line)
return None
return lcargs
def generate_constraints_table(contract, alloc, constants):
relevant_lines = extract_relevant_lines(contract, constraint_commands)
constraints = []
for line in relevant_lines:
if (lcargs := symbols_list_to_lcargs(line, alloc, constants)) is None:
return None
constraints.append(Constraint(line, lcargs))
return constraints
class Contract:
def __init__(self, constants, alloc, ops, constraints):
self.constants = constants
self.alloc = alloc
self.ops = ops
self.constraints = constraints
def __repr__(self):
repr_str = ""
repr_str += "Constants:\n"
for symbol, value in self.constants.items():
repr_str += " // %s\n" % symbol
repr_str += " %s: %s\n" % value
repr_str += "Alloc table:\n"
for symbol, variable in self.alloc.items():
repr_str += " // %s\n" % symbol
repr_str += " %s %s\n" % (variable.type, variable.index)
repr_str += "Operations:\n"
for op in self.ops:
repr_str += " // %s\n" % op.line
repr_str += " %s %s\n" % (op.command, op.args)
repr_str += "Constraints:\n"
for constraint in self.constraints:
if constraint.args:
repr_str += " // %s\n" % constraint.args_comment()
repr_str += " %s %s\n" % (constraint.command, constraint.args)
repr_str += "Stats:\n"
repr_str += " Constants: %s\n" % len(self.constants)
repr_str += " Alloc: %s\n" % len(self.alloc)
repr_str += " Operations: %s\n" % len(self.ops)
repr_str += " Constraint Instructions: %s\n" % len(self.constraints)
return repr_str
def compile(contract, constants):
# Allocation table
# symbol: Private/Public, is_param, index
if (alloc := generate_alloc_table(contract)) is None:
return None
# Operations lines list
if (ops := generate_ops_table(contract, alloc, constants)) is None:
return None
# Constraint commands
if (constraints := generate_constraints_table(
contract, alloc, constants)) is None:
return None
return Contract(constants, alloc, ops, constraints)
def parse_constants(contents):
relevant_lines = [line for line in contents if line.command() == "constant"]
constants = {}
for line in relevant_lines:
assert line.command() == "constant"
if len(line.args()) != 2:
eprint("error: wrong number of args for constant")
eprint(line)
return None
symbol, value = line.args()
try:
int(value, 16)
except ValueError:
eprint("error: invalid constant value for '%s'" % symbol)
eprint(line)
return None
if len(value) != 32*2 + 2 or value[:2] != "0x":
eprint("error: invalid hex value for constant")
eprint(line)
return None
# Remove 0x prefix
value = value[2:]
constants[symbol] = (len(constants), value)
return constants
def process(contents):
# Remove left whitespace
contents = clean(contents)
# Parse all constants
if (constants := parse_constants(contents)) is None:
return None
# Divide into contract sections
if (pre_contracts := divide_sections(contents)) is None:
return None
# Process each contract
contracts = {}
for contract_name, pre_contract in pre_contracts.items():
if (contract := compile(pre_contract, constants)) is None:
return None
contracts[contract_name] = contract
return contracts
def main(argv):
parser = argparse.ArgumentParser()
parser.add_argument("filename", help="VM PISM file: proofs/vm.pism")
parser.add_argument("--output", type=argparse.FileType('wb', 0),
default=sys.stdout.buffer, help="Output file")
group = parser.add_mutually_exclusive_group()
group.add_argument('--display', action='store_true',
help="show the compiled code in human readable format")
group.add_argument('--rust', action='store_true',
help="output compiled code to rust for testing")
group.add_argument('--supervisor', action='store_true',
help="output compiled code to zkvm supervisor")
args = parser.parse_args()
src_filename = args.filename
contents = open(src_filename).read()
if (contracts := process(contents)) is None:
return -2
def default_display():
for contract_name, contract in contracts.items():
print("Contract:", contract_name)
print(contract)
if args.display:
default_display()
elif args.rust:
import compile_export_rust
for contract_name, contract in contracts.items():
compile_export_rust.display(contract)
elif args.supervisor:
import compile_export_supervisor
for contract_name, contract in contracts.items():
compile_export_supervisor.export(args.output, contract_name,
contract)
else:
default_display()
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv))

View File

@@ -0,0 +1,122 @@
from compile import VariableType, VariableRefType
def to_initial_caps(snake_str):
components = snake_str.split("_")
return "".join(x.title() for x in components)
def display(contract):
indent = " " * 4
print(r"""use super::vm::{ZkVirtualMachine, CryptoOperation, AllocType, ConstraintInstruction, VariableIndex, VariableRef};
use bls12_381::Scalar;
pub fn load_params(params: Vec<Scalar>) -> Vec<(VariableIndex, Scalar)> {""")
params = [(symbol, var) for symbol, var in contract.alloc.items() if var.is_param]
print("%sassert_eq!(params.len(), %s);" % (indent, len(params)))
print("%slet mut result = vec![(0, Scalar::zero()); %s];" % (
indent, len(params)))
for i, (symbol, variable) in enumerate(params):
assert variable.is_param
print("%s// %s" % (indent, symbol))
print("%sresult[%s] = (%s, params[%s]);" % (
indent, i, variable.index, i))
print("%sresult" % indent)
print("}\n")
print(r"""pub fn load_zkvm() -> ZkVirtualMachine {
ZkVirtualMachine {
constants: vec![""")
constants = list(contract.constants.items())
constants.sort(key=lambda obj: obj[1][0])
constants = [(obj[0], obj[1][1]) for obj in constants]
for symbol, value in constants:
print("%s// %s" % (indent * 3, symbol))
assert len(value) == 32*2
chunk_str = lambda line, n: \
[line[i:i + n] for i in range(0, len(line), n)]
chunks = chunk_str(value, 2)
# Reverse the endianness
# We allow literal numbers but rust wants little endian
chunks = chunks[::-1]
print("%sScalar::from_bytes(&[" % (indent * 3))
for i in range(0, 32, 4):
print("%s0x%s, 0x%s, 0x%s, 0x%s," % (indent * 4,
chunks[i], chunks[i + 1], chunks[i + 2], chunks[i + 3]))
print("%s]).unwrap()," % (indent * 3))
print("%s]," % (indent * 2))
print("%salloc: vec![" % (indent * 2))
for symbol, variable in contract.alloc.items():
print("%s// %s" % (indent * 3, symbol))
if variable.type.name == VariableType.PRIVATE.name:
typestring = "Private"
elif variable.type.name == VariableType.PUBLIC.name:
typestring = "Public"
else:
assert False
print("%s(AllocType::%s, %s)," % (indent * 3, typestring,
variable.index))
print("%s]," % (indent * 2))
print("%sops: vec![" % (indent * 2))
def var_ref_str(var_ref):
if var_ref.type.name == VariableRefType.AUX.name:
return "VariableRef::Aux(%s)" % var_ref.index
elif var_ref.type.name == VariableRefType.LOCAL.name:
return "VariableRef::Local(%s)" % var_ref.index
else:
assert False
for op in contract.ops:
print("%s// %s" % (indent * 3, op.line))
args_part = ""
if op.command == "load":
assert len(op.args) == 2
args_part = "(%s, %s)" % (var_ref_str(op.args[0]), op.args[1].index)
elif op.command == "debug":
assert len(op.args) == 1
args_part = '(String::from("%s"), %s)' % (
op.line, var_ref_str(op.args[0]))
elif op.args:
args_part = ", ".join(var_ref_str(var_ref) for var_ref in op.args)
args_part = "(%s)" % args_part
print("%sCryptoOperation::%s%s," % (
indent * 3,
to_initial_caps(op.command),
args_part
))
print("%s]," % (indent * 2))
print("%sconstraints: vec![" % (indent * 2))
for constraint in contract.constraints:
args_part = ""
if constraint.args:
print("%s// %s" % (indent *3, constraint.args_comment()))
args = constraint.args[:]
if (constraint.command == "lc0_add_coeff" or
constraint.command == "lc1_add_coeff" or
constraint.command == "lc2_add_coeff" or
constraint.command == "lc0_add_one_coeff" or
constraint.command == "lc1_add_one_coeff" or
constraint.command == "lc2_add_one_coeff"):
args[0] = args[0][0]
args_part = ", ".join(str(index) for index in args)
args_part = "(%s)" % args_part
print("%sConstraintInstruction::%s%s," % (
indent * 3,
to_initial_caps(constraint.command),
args_part
))
print(r""" ],
aux: vec![],
params: None,
verifying_key: None,
}
}""")

View File

@@ -0,0 +1,193 @@
import struct
from compile import VariableType, VariableRefType
class Operation:
def __init__(self, ident, args):
self.ident = ident
self.args = args
class ArgVarRef:
def __init__(self, type, index):
self.type = type
self.index = index
def bytes(self):
return struct.pack("<BI", self.type, self.index)
class ArgVarIndex:
def __init__(self, _, index):
self.index = index
def bytes(self):
return struct.pack("<I", self.index)
class ArgString:
def __init__(self, description, index):
self.description = description
self.index = index
ops_table = {
"set": Operation(0, [ArgVarRef, ArgVarRef]),
"mul": Operation(1, [ArgVarRef, ArgVarRef]),
"add": Operation(2, [ArgVarRef, ArgVarRef]),
"sub": Operation(3, [ArgVarRef, ArgVarRef]),
"divide": Operation(4, [ArgVarRef, ArgVarRef]),
"double": Operation(5, [ArgVarRef]),
"square": Operation(6, [ArgVarRef]),
"invert": Operation(7, [ArgVarRef]),
"unpack_bits": Operation(8, [ArgVarRef, ArgVarRef, ArgVarRef]),
"local": Operation(9, []),
"load": Operation(10, [ArgVarRef, ArgVarIndex]),
"debug": Operation(11, [ArgString, ArgVarRef]),
"dump_alloc": Operation(12, []),
"dump_local": Operation(13, []),
}
constraint_ident_map = {
"lc0_add": 0,
"lc1_add": 1,
"lc2_add": 2,
"lc0_sub": 3,
"lc1_sub": 4,
"lc2_sub": 5,
"lc0_add_one": 6,
"lc1_add_one": 7,
"lc2_add_one": 8,
"lc0_sub_one": 9,
"lc1_sub_one": 10,
"lc2_sub_one": 11,
"lc0_add_coeff": 12,
"lc1_add_coeff": 13,
"lc2_add_coeff": 14,
"lc0_add_constant": 15,
"lc1_add_constant": 16,
"lc2_add_constant": 17,
"enforce": 18,
"lc_coeff_reset": 19,
"lc_coeff_double": 20,
}
def varuint(value):
if value <= 0xfc:
return struct.pack("<B", value)
elif value <= 0xffff:
return struct.pack("<BH", 0xfd, value)
elif value <= 0xffffffff:
return struct.pack("<BI", 0xfe, value)
else:
return struct.pack("<BQ", 0xff, value)
def export(output, contract_name, contract):
output.write(varuint(len(contract_name)))
output.write(contract_name.encode())
constants = list(contract.constants.items())
constants.sort(key=lambda obj: obj[1][0])
constants = [(obj[0], obj[1][1]) for obj in constants]
# Constants
output.write(varuint(len(constants)))
for symbol, value in constants:
print("Constant '%s' = %s" % (symbol, value))
# Bellman uses little endian for Scalars from_bytes function
const_bytes = bytearray.fromhex(value)[::-1]
assert len(const_bytes) == 32
output.write(const_bytes)
# Alloc
output.write(varuint(len(contract.alloc)))
for symbol, variable in contract.alloc.items():
print("Alloc '%s' = (%s, %s)" % (symbol,
variable.type.name, variable.index))
if variable.type.name == VariableType.PRIVATE.name:
typeval = 0
elif variable.type.name == VariableType.PUBLIC.name:
typeval = 1
else:
assert False
alloc_bytes = struct.pack("<BI", typeval, variable.index)
assert len(alloc_bytes) == 5
output.write(alloc_bytes)
# Ops
output.write(varuint(len(contract.ops)))
for op in contract.ops:
op_form = ops_table[op.command]
output.write(struct.pack("B", op_form.ident))
if op.command == "debug":
# Special case
assert len(op.args) == 1
line_str = str(op.line).encode()
output.write(varuint(len(line_str)))
output.write(line_str)
op_arg = op.args[0]
if op_arg.type.name == VariableRefType.AUX.name:
arg_type = 0
elif op_arg.type.name == VariableRefType.LOCAL.name:
arg_type = 1
arg = ArgVarRef(arg_type, op_arg.index)
output.write(arg.bytes())
continue
assert len(op_form.args) == len(op.args)
for arg_form, op_arg in zip(op_form.args, op.args):
if op_arg.type.name == VariableRefType.AUX.name:
arg_type = 0
elif op_arg.type.name == VariableRefType.LOCAL.name:
arg_type = 1
arg = arg_form(arg_type, op_arg.index)
output.write(arg.bytes())
print("#", op.line)
print("Operation", op.command,
[(arg.type.name, arg.index) for arg in op.args])
# Constraints
output.write(varuint(len(contract.constraints)))
for constraint in contract.constraints:
args = constraint.args[:]
if (constraint.command == "lc0_add_coeff" or
constraint.command == "lc1_add_coeff" or
constraint.command == "lc2_add_coeff" or
constraint.command == "lc0_add_constant" or
constraint.command == "lc1_add_constant" or
constraint.command == "lc2_add_constant"):
args[0] = args[0][0]
print("#", constraint.line)
print("Constraint", constraint.command, args if args else "")
enum_ident = constraint_ident_map[constraint.command]
output.write(struct.pack("B", enum_ident))
for arg in args:
output.write(struct.pack("<I", arg))
# Params Map
param_alloc = [(symbol, variable) for (symbol, variable)
in contract.alloc.items() if variable.is_param]
output.write(varuint(len(param_alloc)))
for symbol, variable in param_alloc:
assert variable.is_param
print("Parameter '%s' = %s" % (symbol, variable.index))
symbol = symbol.encode()
output.write(varuint(len(symbol)))
output.write(symbol)
output.write(struct.pack("<I", variable.index))
# Public Map
public_alloc = [(symbol, variable) for (symbol, variable)
in contract.alloc.items()
if variable.type.name == VariableType.PUBLIC.name]
output.write(varuint(len(public_alloc)))
for symbol, variable in public_alloc:
assert not variable.is_param
assert variable.type.name == VariableType.PUBLIC.name
print("Public '%s' = %s" % (symbol, variable.index))
symbol = symbol.encode()
output.write(varuint(len(symbol)))
output.write(symbol)
output.write(struct.pack("<I", variable.index))

21
scripts/drk.vim Normal file
View File

@@ -0,0 +1,21 @@
if exists('b:current_syntax')
finish
endif
syn keyword drkKeyword assert enforce for in def return const as let emit contract private proof
syn keyword drkAttr mut
syn keyword drkType BinaryNumber Point Fr SubgroupPoint EdwardsPoint Scalar EncryptedNum list Bool U64 Num Binary
syn match drkFunction "\zs[a-zA-Z0-9_]*\ze("
syn match drkComment "#.*$"
syn match drkNumber '\d\+'
syn match drkConst '[A-Z_]\{2,}[A-Z0-9_]*'
hi def link drkKeyword Statement
hi def link drkAttr StorageClass
hi def link drkType Type
hi def link drkFunction Function
hi def link drkComment Comment
hi def link drkNumber Constant
hi def link drkConst Constant
let b:current_syntax = "drk"

View File

@@ -0,0 +1,40 @@
from finite_fields import finitefield
def add(x_1, y_1, x_2, y_2):
if (x_1, y_1) == (x_2, y_2):
if y_1 == 0:
return None
# slope of the tangent line
m = (3 * x_1 * x_1 + a) / (2 * y_1)
return None
else:
if x_1 == x_2:
return None
# slope of the secant line
m = (y_2 - y_1) / (x_2 - x_1)
x_3 = m*m - x_1 - x_2
y_3 = m*(x_1 - x_3) - y_1
return (x_3, y_3)
if __name__ == "__main__":
# Vesta
q = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001
fq = finitefield.IntegersModP(q)
a, b = fq(0x00), fq(0x05)
p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
C = (fq(0x1ca18c7c3fcb110f9e92c694ce552238f95e9f9b911599cedaff6018cfc5ed52), fq(0x3ad6133a791e41f3e062d370b40e97e77d20effc00b7ee88c4bb097d245cb438))
D = (fq(0x3e544e611bb895166afe1a46c6e551c47968daf962d824f79f795cb53585b098), fq(0x2fd03c4da47baf2dfd251e85d18864d4885ddd0e8df648550565b850b79349e3))
C_plus_D = (fq(0x06f822cbde350215558c46aac9e60eee31afd942ca6da568845ca4f8fe911e17), fq(0x3e294e73970abc197dfff1a14e74cb20c11b81422d9f920c7b0b0c63affdf67b))
result = add(C[0], C[1], D[0], D[1])
print(result)
print(list("%x" % x.n for x in result))
assert result[0] == C_plus_D[0]
assert result[1] == C_plus_D[1]

View File

@@ -0,0 +1 @@
../finite_fields/

View File

@@ -0,0 +1,4 @@
finite-fields
=============
Python code and tests for the post ["Programming with Finite Fields"](http://jeremykun.com/2014/03/13/programming-with-finite-fields/)

View File

@@ -0,0 +1,45 @@
from test import test
from euclidean import *
test(1, gcd(7, 9))
test(2, gcd(8, 18))
test(-12, gcd(-12, 24))
test(12, gcd(12, -24)) # gcd is only unique up to multiplication by a unit, and so sometimes we'll get negatives.
test(38, gcd(4864, 3458))
test((32, -45, 38), extendedEuclideanAlgorithm(4864, 3458))
test((-45, 32, 38), extendedEuclideanAlgorithm(3458, 4864))
from modp import *
Mod2 = IntegersModP(2)
test(Mod2(1), gcd(Mod2(1), Mod2(0)))
test(Mod2(1), gcd(Mod2(1), Mod2(1)))
test(Mod2(0), gcd(Mod2(2), Mod2(2)))
Mod7 = IntegersModP(7)
test(Mod7(6), gcd(Mod7(6), Mod7(14)))
test(Mod7(2), gcd(Mod7(6), Mod7(9)))
ModHuge = IntegersModP(9923)
test(ModHuge(38), gcd(ModHuge(4864), ModHuge(3458)))
test((ModHuge(32), ModHuge(-45), ModHuge(38)),
extendedEuclideanAlgorithm(ModHuge(4864), ModHuge(3458)))
from polynomial import *
p = polynomialsOver(Mod7).factory
test(p([-1, 1]), gcd(p([-1,0,1]), p([-1,0,0,1])))
f = p([-1,0,1])
g = p([-1,0,0,1])
test((p([0,-1]), p([1]), p([-1, 1])), extendedEuclideanAlgorithm(f, g))
test(p([-1,1]), f * p([0,-1]) + g * p([1]))
p = polynomialsOver(Mod2).factory
f = p([1,0,0,0,1,1,1,0,1,1,1]) # x^10 + x^9 + x^8 + x^6 + x^5 + x^4 + 1
g = p([1,0,1,1,0,1,1,0,0,1]) # x^9 + x^6 + x^5 + x^3 + x^1 + 1
theGcd = p([1,1,0,1]) # x^3 + x + 1
x = p([0,0,0,0,1]) # x^4
y = p([1,1,1,1,1,1]) # x^5 + x^4 + x^3 + x^2 + x + 1
test((x, y, theGcd), extendedEuclideanAlgorithm(f, g))

View File

@@ -0,0 +1,34 @@
# a general Euclidean algorithm for any number type with
# a divmod and a valuation abs() whose minimum value is zero
def gcd(a, b):
if abs(a) < abs(b):
return gcd(b, a)
while abs(b) > 0:
_,r = divmod(a,b)
a,b = b,r
return a
# extendedEuclideanAlgorithm: int, int -> int, int, int
# input (a,b) and output three numbers x,y,d such that ax + by = d = gcd(a,b).
# Works for any number type with a divmod and a valuation abs()
# whose minimum value is zero
def extendedEuclideanAlgorithm(a, b):
if abs(b) > abs(a):
(x,y,d) = extendedEuclideanAlgorithm(b, a)
return (y,x,d)
if abs(b) == 0:
return (1, 0, a)
x1, x2, y1, y2 = 0, 1, 1, 0
while abs(b) > 0:
q, r = divmod(a,b)
x = x2 - q*x1
y = y2 - q*y1
a, b, x2, x1, y2, y1 = b, r, x1, x, y1, y
return (x2, y2, a)

View File

@@ -0,0 +1,28 @@
from test import test
from finitefield import *
from polynomial import *
from modp import *
def p(L, q):
f = IntegersModP(q)
Polynomial = polynomialsOver(f).factory
return Polynomial(L)
test(True, isIrreducible(p([0,1], 2), 2))
test(False, isIrreducible(p([1,0,1], 2), 2))
test(True, isIrreducible(p([1,0,1], 3), 3))
test(False, isIrreducible(p([1,0,0,1], 5), 5))
test(False, isIrreducible(p([1,0,0,1], 7), 7))
test(False, isIrreducible(p([1,0,0,1], 11), 11))
test(True, isIrreducible(p([-2, 0, 1], 13), 13))
Z5 = IntegersModP(5)
Poly = polynomialsOver(Z5).factory
f = Poly([3,0,1])
F25 = FiniteField(5, 2, polynomialModulus=f)
x = F25([2,1])
test(Poly([1,2]), x.inverse())

View File

@@ -0,0 +1,128 @@
import random
from .polynomial import polynomialsOver
from .modp import *
# isIrreducible: Polynomial, int -> bool
# determine if the given monic polynomial with coefficients in Z/p is
# irreducible over Z/p where p is the given integer
# Algorithm 4.69 in the Handbook of Applied Cryptography
def isIrreducible(polynomial, p):
ZmodP = IntegersModP(p)
if polynomial.field is not ZmodP:
raise TypeError("Given a polynomial that's not over %s, but instead %r" %
(ZmodP.__name__, polynomial.field.__name__))
poly = polynomialsOver(ZmodP).factory
x = poly([0,1])
powerTerm = x
isUnit = lambda p: p.degree() == 0
for _ in range(int(polynomial.degree() / 2)):
powerTerm = powerTerm.powmod(p, polynomial)
gcdOverZmodp = gcd(polynomial, powerTerm - x)
if not isUnit(gcdOverZmodp):
return False
return True
# generateIrreduciblePolynomial: int, int -> Polynomial
# generate a random irreducible polynomial of a given degree over Z/p, where p
# is given by the integer 'modulus'. This algorithm is expected to terminate
# after 'degree' many irreducilibity tests. By Chernoff bounds the probability
# it deviates from this by very much is exponentially small.
def generateIrreduciblePolynomial(modulus, degree):
Zp = IntegersModP(modulus)
Polynomial = polynomialsOver(Zp)
while True:
coefficients = [Zp(random.randint(0, modulus-1)) for _ in range(degree)]
randomMonicPolynomial = Polynomial(coefficients + [Zp(1)])
print(randomMonicPolynomial)
if isIrreducible(randomMonicPolynomial, modulus):
return randomMonicPolynomial
# create a type constructor for the finite field of order p^m for p prime, m >= 1
@memoize
def FiniteField(p, m, polynomialModulus=None):
Zp = IntegersModP(p)
if m == 1:
return Zp
Polynomial = polynomialsOver(Zp)
if polynomialModulus is None:
polynomialModulus = generateIrreduciblePolynomial(modulus=p, degree=m)
class Fq(FieldElement):
fieldSize = int(p ** m)
primeSubfield = Zp
idealGenerator = polynomialModulus
operatorPrecedence = 3
def __init__(self, poly):
if type(poly) is Fq:
self.poly = poly.poly
elif type(poly) is int or type(poly) is Zp:
self.poly = Polynomial([Zp(poly)])
elif isinstance(poly, Polynomial):
self.poly = poly % polynomialModulus
else:
self.poly = Polynomial([Zp(x) for x in poly]) % polynomialModulus
self.field = Fq
@typecheck
def __add__(self, other): return Fq(self.poly + other.poly)
@typecheck
def __sub__(self, other): return Fq(self.poly - other.poly)
@typecheck
def __mul__(self, other): return Fq(self.poly * other.poly)
@typecheck
def __eq__(self, other): return isinstance(other, Fq) and self.poly == other.poly
@typecheck
def __ne__(self, other): return not self == other
def __pow__(self, n):
if n==0: return Fq([1])
if n==1: return self
if n%2==0:
sqrut = self**(n//2)
return sqrut*sqrut
if n%2==1: return (self**(n-1))*self
#def __pow__(self, n): return Fq(pow(self.poly, n))
def __neg__(self): return Fq(-self.poly)
def __abs__(self): return abs(self.poly)
def __repr__(self): return repr(self.poly) + ' \u2208 ' + self.__class__.__name__
@typecheck
def __divmod__(self, divisor):
q,r = divmod(self.poly, divisor.poly)
return (Fq(q), Fq(r))
def inverse(self):
if self == Fq(0):
raise ZeroDivisionError
x,y,d = extendedEuclideanAlgorithm(self.poly, self.idealGenerator)
if d.degree() != 0:
raise Exception('Somehow, this element has no inverse! Maybe intialized with a non-prime?')
return Fq(x) * Fq(d.coefficients[0].inverse())
Fq.__name__ = 'F_{%d^%d}' % (p,m)
return Fq
if __name__ == "__main__":
F23 = FiniteField(2,3)
x = F23([1,1])
F35 = FiniteField(3,5)
y = F35([1,1,2])

View File

@@ -0,0 +1,12 @@
from modp import *
from test import test
mod7 = IntegersModP(7)
test(mod7(5), mod7(5)) # Sanity check
test(mod7(5), 1 / mod7(3))
test(mod7(1), mod7(3) * mod7(5))
test(mod7(3), mod7(3) * 1)
test(mod7(2), mod7(5) + mod7(4))
test(True, mod7(0) == mod7(3) + mod7(4))

View File

@@ -0,0 +1,84 @@
from .euclidean import *
from .numbertype import *
# so all IntegersModP are instances of the same base class
class _Modular(FieldElement):
pass
@memoize
def IntegersModP(p):
# assume p is prime
class IntegerModP(_Modular):
def __init__(self, n):
try:
self.n = int(n) % IntegerModP.p
except:
raise TypeError("Can't cast type %s to %s in __init__" %
(type(n).__name__, type(self).__name__))
self.field = IntegerModP
@typecheck
def __add__(self, other):
return IntegerModP(self.n + other.n)
@typecheck
def __sub__(self, other):
return IntegerModP(self.n - other.n)
@typecheck
def __mul__(self, other):
return IntegerModP(self.n * other.n)
def __neg__(self):
return IntegerModP(-self.n)
@typecheck
def __eq__(self, other):
return isinstance(other, IntegerModP) and self.n == other.n
@typecheck
def __ne__(self, other):
return isinstance(other, IntegerModP) is False or self.n != other.n
@typecheck
def __divmod__(self, divisor):
q,r = divmod(self.n, divisor.n)
return (IntegerModP(q), IntegerModP(r))
def inverse(self):
# need to use the division algorithm *as integers* because we're
# doing it on the modulus itself (which would otherwise be zero)
x,y,d = extendedEuclideanAlgorithm(self.n, self.p)
if d != 1:
raise Exception("Error: p is not prime in %s!" % (self.__name__))
return IntegerModP(x)
def __abs__(self):
return abs(self.n)
def __str__(self):
return str(self.n)
def __repr__(self):
return '%d (mod %d)' % (self.n, self.p)
def __int__(self):
return self.n
def __hash__(self):
return hash((self.n, self.p))
IntegerModP.p = p
IntegerModP.__name__ = 'Z/%d' % (p)
IntegerModP.englishName = 'IntegersMod%d' % (p)
return IntegerModP
if __name__ == "__main__":
mod7 = IntegersModP(7)

View File

@@ -0,0 +1,97 @@
# memoize calls to the class constructors for fields
# this helps typechecking by never creating two separate
# instances of a number class.
def memoize(f):
cache = {}
def memoizedFunction(*args, **kwargs):
argTuple = args + tuple(kwargs)
if argTuple not in cache:
cache[argTuple] = f(*args, **kwargs)
return cache[argTuple]
memoizedFunction.cache = cache
return memoizedFunction
# type check a binary operation, and silently typecast 0 or 1
def typecheck(f):
def newF(self, other):
if (hasattr(other.__class__, 'operatorPrecedence') and
other.__class__.operatorPrecedence > self.__class__.operatorPrecedence):
return NotImplemented
if type(self) is not type(other):
try:
other = self.__class__(other)
except TypeError:
message = 'Not able to typecast %s of type %s to type %s in function %s'
raise TypeError(message % (other, type(other).__name__, type(self).__name__, f.__name__))
except Exception as e:
message = 'Type error on arguments %r, %r for functon %s. Reason:%s'
raise TypeError(message % (self, other, f.__name__, type(other).__name__, type(self).__name__, e))
return f(self, other)
return newF
# require a subclass to implement +-* neg and to perform typechecks on all of
# the binary operations finally, the __init__ must operate when given a single
# argument, provided that argument is the int zero or one
class DomainElement(object):
operatorPrecedence = 1
# the 'r'-operators are only used when typecasting ints
def __radd__(self, other): return self + other
def __rsub__(self, other): return -self + other
def __rmul__(self, other): return self * other
# square-and-multiply algorithm for fast exponentiation
def __pow__(self, n):
if type(n) is not int:
raise TypeError
Q = self
R = self if n & 1 else self.__class__(1)
i = 2
while i <= n:
Q = (Q * Q)
if n & i == i:
R = (Q * R)
i = i << 1
return R
# requires the additional % operator (i.e. a Euclidean Domain)
def powmod(self, n, modulus):
if type(n) is not int:
raise TypeError
Q = self
R = self if n & 1 else self.__class__(1)
i = 2
while i <= n:
Q = (Q * Q) % modulus
if n & i == i:
R = (Q * R) % modulus
i = i << 1
return R
# additionally require inverse() on subclasses
class FieldElement(DomainElement):
def __truediv__(self, other): return self * other.inverse()
def __rtruediv__(self, other): return self.inverse() * other
def __div__(self, other): return self.__truediv__(other)
def __rdiv__(self, other): return self.__rtruediv__(other)

View File

@@ -0,0 +1,51 @@
from __future__ import division
from test import test
from fractions import Fraction
from polynomial import *
from modp import *
Mod5 = IntegersModP(5)
Mod11 = IntegersModP(11)
polysOverQ = polynomialsOver(Fraction).factory
polysMod5 = polynomialsOver(Mod5).factory
polysMod11 = polynomialsOver(Mod11).factory
for p in [polysOverQ, polysMod5, polysMod11]:
# equality
test(True, p([]) == p([]))
test(True, p([1,2]) == p([1,2]))
test(True, p([1,2,0]) == p([1,2,0,0]))
# addition
test(p([1,2,3]), p([1,0,3]) + p([0,2]))
test(p([1,2,3]), p([1,2,3]) + p([]))
test(p([5,2,3]), p([4]) + p([1,2,3]))
test(p([1,2]), p([1,2,3]) + p([0,0,-3]))
# subtraction
test(p([1,-2,3]), p([1,0,3]) - p([0,2]))
test(p([1,2,3]), p([1,2,3]) - p([]))
test(p([-1,-2,-3]), p([]) - p([1,2,3]))
# multiplication
test(p([1,2,1]), p([1,1]) * p([1,1]))
test(p([2,5,5,3]), p([2,3]) * p([1,1,1]))
test(p([0,7,49]), p([0,1,7]) * p([7]))
# division
test(p([1,1,1,1,1,1]), p([-1,0,0,0,0,0,1]) / p([-1,1]))
test(p([-1,1,-1,1,-1,1]), p([1,0,0,0,0,0,1]) / p([1,1]))
test(p([]), p([]) / p([1,1]))
test(p([1,1]), p([1,1]) / p([1]))
test(p([1,1]), p([2,2]) / p([2]))
# modulus
test(p([]), p([1,7,49]) % p([7]))
test(p([-7]), p([-3,10,-5,3]) % p([1,3]))
test(polysOverQ([Fraction(1,7), 1, 7]), polysOverQ([1,7,49]) / polysOverQ([7]))
test(polysMod5([1 / Mod5(7), 1, 7]), polysMod5([1,7,49]) / polysMod5([7]))
test(polysMod11([1 / Mod11(7), 1, 7]), polysMod11([1,7,49]) / polysMod11([7]))

View File

@@ -0,0 +1,143 @@
try:
from itertools import zip_longest
except ImportError:
from itertools import izip_longest as zip_longest
import fractions
from .numbertype import *
# strip all copies of elt from the end of the list
def strip(L, elt):
if len(L) == 0: return L
i = len(L) - 1
while i >= 0 and L[i] == elt:
i -= 1
return L[:i+1]
# create a polynomial with coefficients in a field; coefficients are in
# increasing order of monomial degree so that, for example, [1,2,3]
# corresponds to 1 + 2x + 3x^2
@memoize
def polynomialsOver(field=fractions.Fraction):
class Polynomial(DomainElement):
operatorPrecedence = 2
@classmethod
def factory(cls, L):
return Polynomial([cls.field(x) for x in L])
def __init__(self, c):
if type(c) is Polynomial:
self.coefficients = c.coefficients
elif isinstance(c, field):
self.coefficients = [c]
elif not hasattr(c, '__iter__') and not hasattr(c, 'iter'):
self.coefficients = [field(c)]
else:
self.coefficients = c
self.coefficients = strip(self.coefficients, field(0))
def isZero(self): return self.coefficients == []
def __repr__(self):
if self.isZero():
return '0'
return ' + '.join(['%s x^%d' % (a,i) if i > 0 else '%s'%a
for i,a in enumerate(self.coefficients)])
def __abs__(self): return len(self.coefficients) # the valuation only gives 0 to the zero polynomial, i.e. 1+degree
def __len__(self): return len(self.coefficients)
def __sub__(self, other): return self + (-other)
def __iter__(self): return iter(self.coefficients)
def __neg__(self): return Polynomial([-a for a in self])
def iter(self): return self.__iter__()
def leadingCoefficient(self): return self.coefficients[-1]
def degree(self): return abs(self) - 1
@typecheck
def __eq__(self, other):
return self.degree() == other.degree() and all([x==y for (x,y) in zip(self, other)])
@typecheck
def __ne__(self, other):
return self.degree() != other.degree() or any([x!=y for (x,y) in zip(self, other)])
@typecheck
def __add__(self, other):
newCoefficients = [sum(x) for x in zip_longest(self, other, fillvalue=self.field(0))]
return Polynomial(newCoefficients)
@typecheck
def __mul__(self, other):
if self.isZero() or other.isZero():
return Zero()
newCoeffs = [self.field(0) for _ in range(len(self) + len(other) - 1)]
for i,a in enumerate(self):
for j,b in enumerate(other):
newCoeffs[i+j] += a*b
return Polynomial(newCoeffs)
@typecheck
def __divmod__(self, divisor):
quotient, remainder = Zero(), self
divisorDeg = divisor.degree()
divisorLC = divisor.leadingCoefficient()
while remainder.degree() >= divisorDeg:
monomialExponent = remainder.degree() - divisorDeg
monomialZeros = [self.field(0) for _ in range(monomialExponent)]
monomialDivisor = Polynomial(monomialZeros + [remainder.leadingCoefficient() / divisorLC])
quotient += monomialDivisor
remainder -= monomialDivisor * divisor
return quotient, remainder
@typecheck
def __truediv__(self, divisor):
if divisor.isZero():
raise ZeroDivisionError
return divmod(self, divisor)[0]
@typecheck
def __mod__(self, divisor):
if divisor.isZero():
raise ZeroDivisionError
return divmod(self, divisor)[1]
def __call__(self, x):
if type(x) is int:
x = self.field(x)
assert type(x) is self.field
if self.isZero():
return self.field(0)
y = self.leadingCoefficient()
for coeff in self.coefficients[-2::-1]:
y = y * x + coeff
return y
def Zero():
return Polynomial([])
Polynomial.field = field
Polynomial.__name__ = '(%s)[x]' % field.__name__
Polynomial.englishName = 'Polynomials in one variable over %s' % field.__name__
return Polynomial

View File

@@ -0,0 +1,8 @@
def test(expected, actual):
if expected != actual:
import sys, traceback
(filename, lineno, container, code) = traceback.extract_stack()[-2]
print("Test: %r failed on line %d in file %r.\nExpected %r but got %r\n" %
(code, lineno, filename, expected, actual))
sys.exit(1)

View File

@@ -0,0 +1,11 @@
from modp import *
from polynomial import *
mod3 = IntegersModP(3)
Polynomial = polynomialsOver(mod3)
x = mod3(1)
p = Polynomial([1,2])
x+p
p+x

88
scripts/halo/bootle16.py Normal file
View File

@@ -0,0 +1,88 @@
# Notes from paper:
# "Efficient Zero-Knowledge Arguments for Arithmetic Circuits in the
# Discrete Log Setting" by Bootle and others (EUROCRYPT 2016)
from finite_fields import finitefield
import numpy as np
p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
fp = finitefield.IntegersModP(p)
# Number of variables
m = 16
# Number of rows for multiplication statements
n = 3
N = n * m
# Initialize zeroed table
aux = np.full(m, fp(0))
# From the zk-explainer document, we will represent the function:
#
# def foo(w, a, b):
# if w:
# return a * b
# else:
# return a + b
#
# Which can be translated mathematically to the statements:
#
# ab = m
# w(m - a - b) = v - a - b
# w^2 = w
#
# Where m is an intermediate value.
var_one = 0
aux[var_one] = fp(1)
var_a = 1
var_b = 2
var_w = 3
aux[var_a] = fp(110)
aux[var_b] = fp(4)
aux[var_w] = fp(1)
# Calculate intermediate advice values
var_m = 4
aux[var_m] = aux[var_a] * aux[var_b]
# Calculate public input values
var_v = 5
aux[var_v] = aux[var_w] * (aux[var_a] * aux[var_b]) + \
(aux[var_one] - aux[var_w]) * (aux[var_a] + aux[var_b])
# Just a quick enforcement check:
assert aux[var_a] * aux[var_b] == aux[var_m]
assert aux[var_w] * (aux[var_m] - aux[var_a] - aux[var_b]) == \
aux[var_v] - aux[var_a] - aux[var_b]
assert aux[var_w] * aux[var_w] == aux[var_w]
# Setup the gates. For each row of a, b and c, the statement a b = c holds
# R1CS, more info here:
# http://www.zeroknowledgeblog.com/index.php/the-pinocchio-protocol/r1cs
left = np.full((n, m), fp(0))
right = np.full((n, m), fp(0))
output = np.full((n, m), fp(0))
# ab = m
left[0][var_a] = fp(1)
right[0][var_b] = fp(1)
output[0][var_m] = fp(1)
assert aux.dot(left[0]) * aux.dot(right[0]) == aux.dot(output[0])
# w(m - a - b) = v - a - b
left[1][var_w] = fp(1)
right[1][var_m] = fp(1)
right[1][var_a] = fp(-1)
right[1][var_b] = fp(-1)
output[1][var_v] = fp(1)
output[1][var_a] = fp(-1)
output[1][var_b] = fp(-1)
assert aux.dot(left[1]) * aux.dot(right[1]) == aux.dot(output[1])
# w^2 = w
left[2][var_w] = fp(1)
right[2][var_w] = fp(1)
output[2][var_w] = fp(1)
assert aux.dot(left[2]) * aux.dot(right[2]) == aux.dot(output[2])

64
scripts/halo/fft.sage Normal file
View File

@@ -0,0 +1,64 @@
# See also https://cp-algorithms.com/algebra/fft.html
q = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001
K = GF(q)
P.<X> = K[]
def get_omega():
generator = K(5)
assert (q - 1) % 2^32 == 0
# Root of unity
t = (q - 1) / 2^32
omega = generator**t
assert omega != 1
assert omega^(2^16) != 1
assert omega^(2^31) != 1
assert omega^(2^32) == 1
return omega
# Order of this element is 2^32
omega = get_omega()
k = 3
n = 2^k
omega = omega^(2^32 / n)
assert omega^n == 1
f = 6*X^7 + 7*X^5 + 3*X^2 + X
def fft(F):
print(f"fft({F})")
# On the first invocation:
#assert len(F) == n
N = len(F)
if N == 1:
print(" returning 1")
return F
omega_prime = omega^(n/N)
assert omega_prime^(n - 1) != 1
assert omega_prime^N == 1
# Split into even and odd powers of X
F_e = [a for a in F[::2]]
print(" Evens:", F_e)
F_o = [a for a in F[1::2]]
print(" Odds:", F_o)
y_e, y_o = fft(F_e), fft(F_o)
print(f"y_e = {y_e}, y_o = {y_o}")
y = [0] * N
for j in range(N / 2):
y[j] = y_e[j] + omega_prime^j * y_o[j]
y[j + N / 2] = y_e[j] - omega_prime^j * y_o[j]
print(f" returning y = {y}")
return y
print("f =", f)
evals = fft(list(f))
print("evals =", evals)
print("{omega^i : i in {0, 1, ..., n - 1}} =", [omega^i for i in range(n)])
evals2 = [f(omega^i) for i in range(n)]
print("{f(omega^i) for all omega^i} =", evals2)
assert evals == evals2

1
scripts/halo/finite_fields Symbolic link
View File

@@ -0,0 +1 @@
../finite_fields/

View File

@@ -0,0 +1,89 @@
import numpy as np
# Implementation of Groth09 inner product proof
q = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001
K = GF(q)
a = K(0x00)
b = K(0x05)
E = EllipticCurve(K, (a, b))
G = E(0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000000, 0x02)
p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
assert E.order() == p
Scalar = GF(p)
x = np.array([
Scalar(110), Scalar(56), Scalar(89), Scalar(6543), Scalar(2)
])
y = np.array([
Scalar(4), Scalar(88), Scalar(14), Scalar(33), Scalar(6)
])
z = x.dot(y)
assert len(x) == len(y)
# Create some generator points. Normally we would use hash to curve.
# All these points will be generators since the curve is a cyclic group
H = E.random_element()
G_vec = [E.random_element() for _ in range(len(x))]
# We will now construct a proof
# Commitments
def dot_product(x, y):
result = None
for x_i, y_i in zip(x, y):
if result is None:
result = int(x_i) * y_i
else:
result += int(x_i) * y_i
return result
t = Scalar.random_element()
r = Scalar.random_element()
s = Scalar.random_element()
C_z = int(t) * H + int(z) * G
C_x = int(r) * H + dot_product(x, G_vec)
C_y = int(s) * H + dot_product(y, G_vec)
d_x = np.array([Scalar.random_element() for _ in range(len(x))])
d_y = np.array([Scalar.random_element() for _ in range(len(x))])
r_d = Scalar.random_element()
s_d = Scalar.random_element()
A_d = int(r_d) * H + dot_product(d_x, G_vec)
B_d = int(s_d) * H + dot_product(d_y, G_vec)
# (cx + d_x)(cy + d_y) = d_x d_y + c(x d_y + y d_x) + c^2 xy
t_0 = Scalar.random_element()
t_1 = Scalar.random_element()
C_0 = int(t_0) * H + int(d_x.dot(d_y)) * G
C_1 = int(t_1) * H + int(x.dot(d_y) + y.dot(d_x)) * G
# Challenge
# Using the Fiat-Shamir transform, we would hash the transcript
c = Scalar.random_element()
# Responses
f_x = c * x + d_x
f_y = c * y + d_y
r_x = c * r + r_d
s_y = c * s + s_d
t_z = c**2 * t + c * t_1 + t_0
# Verify
assert int(c) * C_x + A_d == int(r_x) * H + dot_product(f_x, G_vec)
assert int(c) * C_y + B_d == int(s_y) * H + dot_product(f_y, G_vec)
# Actual inner product check
# Comm(f_x f_y) == e^2 C_z + c Comm(x d_y + y d_x) + Comm(d_x d_y)
assert int(t_z) * H + int(f_x.dot(f_y)) * G == int(c**2) * C_z + int(c) * C_1 + C_0

View File

@@ -0,0 +1,177 @@
# This file was *autogenerated* from the file groth_poly_commit.sage
from sage.all_cmdline import * # import sage library
_sage_const_0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001 = Integer(0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001); _sage_const_0x00 = Integer(0x00); _sage_const_0x05 = Integer(0x05); _sage_const_0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000000 = Integer(0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000000); _sage_const_0x02 = Integer(0x02); _sage_const_0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001 = Integer(0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001); _sage_const_1000 = Integer(1000); _sage_const_1 = Integer(1); _sage_const_110 = Integer(110); _sage_const_2 = Integer(2); _sage_const_56 = Integer(56); _sage_const_89 = Integer(89); _sage_const_6543 = Integer(6543); _sage_const_0 = Integer(0); _sage_const_77 = Integer(77)
import numpy as np
from collections import namedtuple
PolyProof = namedtuple("PolyProof", [
"poly_commit",
"poly_blind_commit",
"poly_response",
"poly_blind_respond",
"x_blind_factors",
"evaluation_commits",
"evaluation_response",
"value"
])
# Implementation of Groth09 inner product proof
q = _sage_const_0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001
K = GF(q)
a = K(_sage_const_0x00 )
b = K(_sage_const_0x05 )
E = EllipticCurve(K, (a, b))
G = E(_sage_const_0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000000 , _sage_const_0x02 )
p = _sage_const_0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
assert E.order() == p
Scalar = GF(p)
# Create some generator points. Normally we would use hash to curve.
# All these points will be generators since the curve is a cyclic group
H = E.random_element()
G_vec = [E.random_element() for _ in range(_sage_const_1000 )]
def dot_product(x, y):
result = None
for x_i, y_i in zip(x, y):
if result is None:
result = int(x_i) * y_i
else:
result += int(x_i) * y_i
return result
def poly_commit(p):
# Sage randomly orders terms. No guarantee about ordering.
#a = np.array(p.coefficients())
a = np.array([p[i] for i in range(p.degree() + _sage_const_1 )])
r = Scalar.random_element()
C_x = int(r) * H + dot_product(a, G_vec)
return (r, C_x)
def create_proof(p, r, x):
a = np.array([p[i] for i in range(p.degree() + _sage_const_1 )])
#a = np.array(p.coefficients())
x = np.array([x**i for i in range(p.degree() + _sage_const_1 )])
# Evaluate the polynomial
z = a.dot(x)
assert len(a) == len(x)
# We will now construct a proof
# Commitments
t = Scalar.random_element()
#r = Scalar.random_element()
s = Scalar.random_element()
C_z = int(t) * H + int(z) * G
C_x = int(r) * H + dot_product(a, G_vec)
C_y = int(s) * H + dot_product(x, G_vec)
d_x = np.array([Scalar.random_element() for _ in range(len(x))])
d_y = np.array([Scalar.random_element() for _ in range(len(x))])
r_d = Scalar.random_element()
s_d = Scalar.random_element()
A_d = int(r_d) * H + dot_product(d_x, G_vec)
B_d = int(s_d) * H + dot_product(d_y, G_vec)
# (cx + d_x)(cy + d_y) = d_x d_y + c(x d_y + y d_x) + c^2 xy
t_0 = Scalar.random_element()
t_1 = Scalar.random_element()
C_0 = int(t_0) * H + int(d_x.dot(d_y)) * G
C_1 = int(t_1) * H + int(a.dot(d_y) + x.dot(d_x)) * G
# Challenge
# Using the Fiat-Shamir transform, we would hash the transcript
#c = Scalar.random_element()
c = _sage_const_110
# Responses
f_x = c * a + d_x
f_y = c * x + d_y
r_x = c * r + r_d
s_y = c * s + s_d
t_z = c**_sage_const_2 * t + c * t_1 + t_0
# Verify
#B_d = int(s_d) * H + dot_product(d_y, G_vec)
#C_y = int(s) * H + dot_product(x, G_vec)
assert int(c) * C_x + A_d == int(r_x) * H + dot_product(f_x, G_vec)
assert int(c) * C_y + B_d == int(s_y) * H + dot_product(f_y, G_vec)
# Actual inner product check
# Comm(f_x f_y) == e^2 C_z + c Comm(x d_y + y d_x) + Comm(d_x d_y)
assert int(t_z) * H + int(f_x.dot(f_y)) * G == int(c**_sage_const_2 ) * C_z + int(c) * C_1 + C_0
return PolyProof(
poly_commit=C_x,
poly_blind_commit=A_d,
poly_response=f_x,
poly_blind_respond=r_x,
x_blind_factors=(s_d, d_y, s),
evaluation_commits=(C_0, C_1, C_z),
evaluation_response=t_z,
value=z
)
def verify_proof(proof, x):
C_x = proof.poly_commit
A_d = proof.poly_blind_commit
f_x = proof.poly_response
r_x = proof.poly_blind_respond
(s_d, d_y, s) = proof.x_blind_factors
(C_0, C_1, C_z) = proof.evaluation_commits
t_z = proof.evaluation_response
z = proof.value
x = np.array([x**i for i in range(len(a))])
c = _sage_const_110
f_y = c * x + d_y
s_y = c * s + s_d
B_d = int(s_d) * H + dot_product(d_y, G_vec)
C_y = int(s) * H + dot_product(x, G_vec)
if int(c) * C_x + A_d != int(r_x) * H + dot_product(f_x, G_vec):
return False
if int(c) * C_y + B_d != int(s_y) * H + dot_product(f_y, G_vec):
return False
# Actual inner product check
# Comm(f_x f_y) == e^2 C_z + c Comm(x d_y + y d_x) + Comm(d_x d_y)
if int(t_z) * H + int(f_x.dot(f_y)) * G != int(c**_sage_const_2 ) * C_z + int(c) * C_1 + C_0:
return False
return True
#R = LaurentPolynomialRing(Scalar, names=('x',)); (x,) = R._first_ngens(1)
#a = np.array([
# Scalar(_sage_const_110 ), Scalar(_sage_const_56 ), Scalar(_sage_const_89 ), Scalar(_sage_const_6543 ), Scalar(_sage_const_2 )
#])
#p = _sage_const_0
#for i, a_i in enumerate(a):
# p += a_i * x**i
#print(p)
#xx = Scalar(_sage_const_77 )
#r, commit = poly_commit(p)
#proof = create_proof(p, r, xx)
#assert verify_proof(proof, xx)
#assert proof.poly_commit == commit
#assert proof.value == p(xx)

View File

@@ -0,0 +1,170 @@
import numpy as np
from collections import namedtuple
PolyProof = namedtuple("PolyProof", [
"poly_commit",
"poly_blind_commit",
"poly_response",
"poly_blind_respond",
"x_blind_factors",
"evaluation_commits",
"evaluation_response",
"value"
])
# Implementation of Groth09 inner product proof
q = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001
K = GF(q)
a = K(0x00)
b = K(0x05)
E = EllipticCurve(K, (a, b))
G = E(0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000000, 0x02)
p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
assert E.order() == p
Scalar = GF(p)
# Create some generator points. Normally we would use hash to curve.
# All these points will be generators since the curve is a cyclic group
H = E.random_element()
G_vec = [E.random_element() for _ in range(1000)]
def dot_product(x, y):
result = None
for x_i, y_i in zip(x, y):
if result is None:
result = int(x_i) * y_i
else:
result += int(x_i) * y_i
return result
def poly_commit(p):
# Sage randomly orders terms. No guarantee about ordering.
#a = np.array(p.coefficients())
a = np.array([p[i] for i in range(p.degree() + 1)])
r = Scalar.random_element()
C_x = int(r) * H + dot_product(a, G_vec)
return (r, C_x)
def create_proof(p, r, x):
a = np.array([p[i] for i in range(p.degree() + 1)])
#a = np.array(p.coefficients())
x = np.array([x**i for i in range(p.degree() + 1)])
# Evaluate the polynomial
z = a.dot(x)
assert len(a) == len(x)
# We will now construct a proof
# Commitments
t = Scalar.random_element()
#r = Scalar.random_element()
s = Scalar.random_element()
C_z = int(t) * H + int(z) * G
C_x = int(r) * H + dot_product(a, G_vec)
C_y = int(s) * H + dot_product(x, G_vec)
d_x = np.array([Scalar.random_element() for _ in range(len(x))])
d_y = np.array([Scalar.random_element() for _ in range(len(x))])
r_d = Scalar.random_element()
s_d = Scalar.random_element()
A_d = int(r_d) * H + dot_product(d_x, G_vec)
B_d = int(s_d) * H + dot_product(d_y, G_vec)
# (cx + d_x)(cy + d_y) = d_x d_y + c(x d_y + y d_x) + c^2 xy
t_0 = Scalar.random_element()
t_1 = Scalar.random_element()
C_0 = int(t_0) * H + int(d_x.dot(d_y)) * G
C_1 = int(t_1) * H + int(a.dot(d_y) + x.dot(d_x)) * G
# Challenge
# Using the Fiat-Shamir transform, we would hash the transcript
#c = Scalar.random_element()
c = 110
# Responses
f_x = c * a + d_x
f_y = c * x + d_y
r_x = c * r + r_d
s_y = c * s + s_d
t_z = c**2 * t + c * t_1 + t_0
# Verify
#B_d = int(s_d) * H + dot_product(d_y, G_vec)
#C_y = int(s) * H + dot_product(x, G_vec)
assert int(c) * C_x + A_d == int(r_x) * H + dot_product(f_x, G_vec)
assert int(c) * C_y + B_d == int(s_y) * H + dot_product(f_y, G_vec)
# Actual inner product check
# Comm(f_x f_y) == e^2 C_z + c Comm(x d_y + y d_x) + Comm(d_x d_y)
assert int(t_z) * H + int(f_x.dot(f_y)) * G == int(c**2) * C_z + int(c) * C_1 + C_0
return PolyProof(
poly_commit=C_x,
poly_blind_commit=A_d,
poly_response=f_x,
poly_blind_respond=r_x,
x_blind_factors=(s_d, d_y, s),
evaluation_commits=(C_0, C_1, C_z),
evaluation_response=t_z,
value=z
)
def verify_proof(proof, x):
C_x = proof.poly_commit
A_d = proof.poly_blind_commit
f_x = proof.poly_response
r_x = proof.poly_blind_respond
(s_d, d_y, s) = proof.x_blind_factors
(C_0, C_1, C_z) = proof.evaluation_commits
t_z = proof.evaluation_response
z = proof.value
x = np.array([x**i for i in range(len(a))])
c = 110
f_y = c * x + d_y
s_y = c * s + s_d
B_d = int(s_d) * H + dot_product(d_y, G_vec)
C_y = int(s) * H + dot_product(x, G_vec)
if int(c) * C_x + A_d != int(r_x) * H + dot_product(f_x, G_vec):
return False
if int(c) * C_y + B_d != int(s_y) * H + dot_product(f_y, G_vec):
return False
# Actual inner product check
# Comm(f_x f_y) == e^2 C_z + c Comm(x d_y + y d_x) + Comm(d_x d_y)
if int(t_z) * H + int(f_x.dot(f_y)) * G != int(c**2) * C_z + int(c) * C_1 + C_0:
return False
return True
R.<x> = LaurentPolynomialRing(Scalar)
a = np.array([
Scalar(110), Scalar(56), Scalar(89), Scalar(6543), Scalar(2)
])
p = 0
for i, a_i in enumerate(a):
p += a_i * x**i
print(p)
xx = Scalar(77)
r, commit = poly_commit(p)
proof = create_proof(p, r, xx)
assert verify_proof(proof, xx)
assert proof.poly_commit == commit
assert proof.value == p(xx)

282
scripts/halo/halo1.sage Normal file
View File

@@ -0,0 +1,282 @@
import numpy as np
from groth_poly_commit import Scalar, poly_commit, create_proof, verify_proof
K = Scalar
# Just use the same finite field we put in the polynomial commitment scheme file
#p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
#K = FiniteField(p)
R.<x, y> = LaurentPolynomialRing(K)
var_one = K(1)
var_x = K(4)
var_y = K(6)
var_s = K(1)
var_xy = var_x * var_y
var_sxy = var_s * var_xy
var_1_neg_s = var_one - var_s
var_x_y = var_x + var_y
var_1_neg_s_x_y = var_1_neg_s * var_x_y
var_s_neg_1 = -var_1_neg_s
var_zero = K(0)
public_v = var_s * (var_x * var_y) + (1 - var_s) * (var_x + var_y)
a = np.array([
var_one, var_x, var_xy, var_1_neg_s, var_s
])
b = np.array([
var_one, var_y, var_s, var_x_y, var_s_neg_1
])
c = np.array([
var_one, var_xy, var_sxy, var_1_neg_s_x_y, var_zero
])
assert len(a) == len(b)
assert len(b) == len(c)
for i, (a_i, b_i, c_i) in enumerate(zip(a, b, c), 1):
try:
assert a_i * b_i == c_i
except AssertionError:
print("Error for %i" % i)
raise
# 1 - s = -(s - 1)
u1 = np.array([0, 0, 0, 1, 0])
v1 = np.array([0, 0, 0, 0, 1])
w1 = np.array([0, 0, 0, 0, 0])
k1 = 0
assert a.dot(u1) + b.dot(v1) + c.dot(w1) == k1
# xy = xy
u2 = np.array([0, 0, 1, 0, 0])
v2 = np.array([0, 0, 0, 0, 0])
w2 = np.array([0, -1, 0, 0, 0])
k2 = 0
assert a.dot(u2) + b.dot(v2) + c.dot(w2) == k2
# s = s
u3 = np.array([0, 0, 0, 0, -1])
v3 = np.array([0, 0, 1, 0, 0])
w3 = np.array([0, 0, 0, 0, 0])
k3 = 0
assert a.dot(u3) + b.dot(v3) + c.dot(w3) == k3
# zero = 0
u4 = np.array([0, 0, 0, 0, 0])
v4 = np.array([0, 0, 0, 0, 0])
w4 = np.array([0, 0, 0, 0, 1])
k4 = 0
assert a.dot(u4) + b.dot(v4) + c.dot(w4) == k4
# 1 - s
u5 = np.array([1, 0, 0, -1, 0])
v5 = np.array([0, 0, -1, 0, 0])
w5 = np.array([0, 0, 0, 0, 0])
k5 = 0
assert a.dot(u5) + b.dot(v5) + c.dot(w5) == k5
# x + y
u6 = np.array([0, 1, 0, 0, 0])
v6 = np.array([0, 1, 0, -1, 0])
w6 = np.array([0, 0, 0, 0, 0])
k6 = 0
assert a.dot(u6) + b.dot(v6) + c.dot(w6) == k6
# Final check:
# v = s(xy) + (1 - s)(x + y)
u7 = np.array([0, 0, 0, 0, 0])
v7 = np.array([0, 0, 0, 0, 0])
w7 = np.array([0, 0, 1, 1, 0])
k7 = public_v
assert a.dot(u7) + b.dot(v7) + c.dot(w7) == k7
u = np.vstack((u1, u2, u3, u4, u5, u6, u7))
v = np.vstack((v1, v2, v3, v4, v5, v6, v7))
w = np.vstack((w1, w2, w3, w4, w5, w6, w7))
assert u.shape == v.shape
assert u.shape == w.shape
k = np.array((k1, k2, k3, k4, k5, k6, k7))
p = K(0)
for i, (a_i, b_i, c_i) in enumerate(zip(a, b, c), 1):
#print(a_i, "\t", b_i, "\t", c_i)
p += y**i * (a_i * b_i - c_i)
print(p)
p = K(0)
for q, (u_q, v_q, w_q, k_q) in enumerate(zip(u, v, w, k)):
p += y**q * (a.dot(u_q) + b.dot(v_q) + c.dot(w_q) - k_q)
print(p)
n = len(a)
assert len(b) == n
assert len(c) == n
assert u.shape == (7, n)
assert v.shape == u.shape
assert w.shape == u.shape
assert k.shape == (7,)
r_x_y = 0
s_x_y = 0
for i, (a_i, b_i, c_i) in enumerate(zip(a, b, c), 1):
assert 1 <= i <= n
r_x_y += x**i * y**i * a_i
r_x_y += x**-i * y**-i * b_i
r_x_y += x**(-i - n) * y**(-i - n) * c_i
u_i = u.T[i - 1]
v_i = v.T[i - 1]
w_i = w.T[i - 1]
u_i_Y = 0
v_i_Y = 0
w_i_Y = 0
for q, (u_q_i, v_q_i, w_q_i) in enumerate(zip(u_i, v_i, w_i), 1):
assert 1 <= q <= 7
u_i_Y += y**q * u_q_i
v_i_Y += y**q * v_q_i
w_i_Y += y**q * w_q_i
s_x_y += u_i_Y * x**-i + v_i_Y * x**i + w_i_Y * x**(i + n)
k_y = 0
for q, k_q in enumerate(k, 1):
assert 1 <= q <= 7
k_y += y**q * k_q
# Section 6, Figure 2
#
# zkP1
# 4 blinding factors since we evaluate r(X, Y) 3 times
# Blind r(X, Y)
for i in range(1, 4 + 1):
blind_c_i = K.random_element()
r_x_y += x**(-2*n - i) * y**(-2*n - i) * blind_c_i
# Commit to r(X, Y)
s_prime_x_y = y**n * s_x_y
for i in range(1, n):
s_prime_x_y -= (y**i + y**-i) * x**(i + n)
r_x_1 = r_x_y(y=K(1))
t_x_y = r_x_1 * (r_x_y + s_prime_x_y) - y**n * k_y
# This can be opened to r(X, Y) since r(X, Y) = r(XY, 1)
r_x_1_scaled = (r_x_1 * x**(3*n - 1)).univariate_polynomial()
rx1_commit_blind, rx1_commit = poly_commit(r_x_1_scaled)
print("===================")
print(" t(X, Y)")
print("===================")
power_dict = ["", "¹", "²", "³", "", "", "", "", "", ""]
def superscript(number):
sign = ""
if number < 0:
sign = ""
number = -number
return sign + "".join([power_dict[int(digit)] for digit in list(str(number))])
decorated = []
for (x_power, y_power), coeff in t_x_y.dict().items():
if coeff == 1:
coeff = ""
display = "%s X%s Y%s" % (coeff, superscript(x_power), superscript(y_power))
decorated.append([x_power, y_power, display])
decorated.sort(key=lambda x: (x[0], -x[1]))
for _, _, display in decorated:
print(display)
print()
print("Constant coefficient:", t_x_y.constant_coefficient())
print()
# zkV1
# Send a random y
challenge_y = K.random_element()
# zkP2
# Commit to t(X, y)
t_x = t_x_y(y=challenge_y)
t_x = t_x.univariate_polynomial()
print("===================")
print(" t(X, y)")
print("===================")
print(t_x.dict())
print()
print("Constant coefficient:", t_x.constant_coefficient())
# Split the polynomial into low and hi versions
t_lo_x = 0
t_hi_x = 0
smallest_power = -min(t_x.dict().keys())
for power, coeff in t_x.dict().items():
assert power != 0
if power < 0:
t_lo_x += x**(smallest_power + power) * coeff
else:
t_hi_x += x**(power - 1) * coeff
d = t_lo_x.degree() + 1
t_lo_x = t_lo_x.univariate_polynomial()
t_hi_x = t_hi_x.univariate_polynomial()
assert (t_lo_x * x**-d + t_hi_x * x).univariate_polynomial() == t_x
T_lo_commit_blind, T_lo = poly_commit(t_lo_x)
T_hi_commit_blind, T_hi = poly_commit(t_hi_x)
# zkV2
# Send a random z
challenge_z = K.random_element()
# zkP3
# Evaluate a = r(z, 1)
a = r_x_y(x=challenge_z, y=K(1))
# Evaluate b = r(z, y)
b = r_x_y(x=challenge_z, y=challenge_y)
# Evaluate t = t(z, y)
t = t_x_y(x=challenge_z, y=challenge_y)
# Evaluate s = s(z, y)
s = s_prime_x_y(x=challenge_z, y=challenge_y)
# Calculate equivalent openings
# s'(X, Y) is known by both prover and verifier
a_proof = create_proof(r_x_1_scaled, rx1_commit_blind, challenge_z)
assert a_proof.poly_commit == rx1_commit
b_proof = create_proof(r_x_1_scaled, rx1_commit_blind, challenge_y * challenge_z)
assert b_proof.poly_commit == rx1_commit
t_proof_lo = create_proof(t_lo_x, T_lo_commit_blind, challenge_z)
assert t_proof_lo.poly_commit == T_lo
t_proof_hi = create_proof(t_hi_x, T_hi_commit_blind, challenge_z)
assert t_proof_hi.poly_commit == T_hi
# Signature of correct computation not yet implemented
# So just use s for now as is
# Scaling factor
verifier_rescale = challenge_z**(-3*n + 1)
assert a_proof.value * verifier_rescale == a
verifier_rescale = (challenge_y * challenge_z)**(-3*n + 1)
assert b_proof.value * verifier_rescale == b
# zkV3
# Recalculate t from a, b and s
t_new = t_proof_lo.value * challenge_z**-d + t_proof_hi.value * challenge_z
assert t_new == t
t = t_new
k = (y**n * k_y)(y=challenge_y)
t_new = a * (b + s) - k
assert t_new == t
# Verify polynomial commitments

347
scripts/halo/halo2.sage Normal file
View File

@@ -0,0 +1,347 @@
import numpy as np
q = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001
K = GF(q)
P.<X> = K[]
# GENERATOR^{2^s} where t * 2^s + 1 = q with t odd.
# In other words, this is a t root of unity.
generator = K(5)
# There is a large 2^32 order subgroup in this curve because it is 2-adic
t = (K(q) - 1) / 2^32
assert int(t) % 2 != 0
delta = generator^(2^32)
assert delta^t == 1
# The size of the multiplicative group is phi(q) = q - 1
# And inside this group are 2 distinct subgroups of size t and 2^s.
# delta is the generator for the size t subgroup, and omega for the 2^s one.
# Taking powers of these generators and multiplying them will produce
# unique cosets that divide the entire group for q.
def get_omega():
generator = K(5)
assert (q - 1) % 2^32 == 0
# Root of unity
t = (q - 1) / 2^32
omega = generator**t
assert omega != 1
assert omega^(2^16) != 1
assert omega^(2^31) != 1
assert omega^(2^32) == 1
return omega
# Order of this element is 2^32
omega = get_omega()
k = 4
n = 2^k
omega = omega^(2^32 / n)
assert omega^n == 1
# Arithmetization for:
# sxy + (s - 1)(x + y) - z = 0
# s(s - 1) = 0
# F1(A1 - 1) + F2 (A1 - I) + F3((1 - A1)(A2 + A3) - A4) + F4(A1 A2 A3 - A4) = 0
A = []
F = []
var_zero = K(0)
var_x = K(4)
var_y = K(6)
var_s = K(1)
var_sxy = var_s * var_x * var_y
var_1s_xy = (1 - var_s) * (var_x + var_y)
# Public
var_z = var_sxy + var_1s_xy
# 4 advice columns
# 4 fixed columns
# 1 instance column
# Row 1
# z = public z
A_1_1, A_2_1, A_3_1, A_4_1 = var_z, 0, 0, 0
F_1_1, F_2_1, F_3_1, F_4_1 = 1, 0, 0, 0
I_1 = var_z
# Row 2
# ~0 == 0
A_1_2, A_2_2, A_3_2, A_4_2 = var_zero, 0, 0, 0
F_1_2, F_2_2, F_3_2, F_4_2 = 0, 1, 0, 0
I_2 = 0
# Row 3
# Boolean check
# (1 - s)(s + 0) == 0
A_1_3, A_2_3, A_3_3, A_4_3 = var_s, var_s, var_zero, var_zero
F_1_3, F_2_3, F_3_3, F_4_3 = 0, 0, 1, 0
I_3 = 0
# Row 4
# s x y == sxy
A_1_4, A_2_4, A_3_4, A_4_4 = var_s, var_x, var_y, var_sxy
F_1_4, F_2_4, F_3_4, F_4_4 = 0, 0, 0, 1
I_4 = 0
# Row 5
# (1 - s)(x + y) = (1-s)(x+y)
A_1_5, A_2_5, A_3_5, A_4_5 = var_s, var_x, var_y, var_1s_xy
F_1_5, F_2_5, F_3_5, F_4_5 = 0, 0, 1, 0
I_5 = 0
# Row 6
# (1 - 0)(sxy + (1-s)(x+y)) = z
A_1_6, A_2_6, A_3_6, A_4_6 = var_zero, var_sxy, var_1s_xy, var_z
F_1_6, F_2_6, F_3_6, F_4_6 = 0, 0, 1, 0
I_6 = 0
A1 = [A_1_1, A_1_2, A_1_3, A_1_4, A_1_5, A_1_6]
A2 = [A_2_1, A_2_2, A_2_3, A_2_4, A_2_5, A_2_6]
A3 = [A_3_1, A_3_2, A_3_3, A_3_4, A_3_5, A_3_6]
A4 = [A_4_1, A_4_2, A_4_3, A_4_4, A_4_5, A_4_6]
F1 = [F_1_1, F_1_2, F_1_3, F_1_4, F_1_5, F_1_6]
F2 = [F_2_1, F_2_2, F_2_3, F_2_4, F_2_5, F_2_6]
F3 = [F_3_1, F_3_2, F_3_3, F_3_4, F_3_5, F_3_6]
F4 = [F_4_1, F_4_2, F_4_3, F_4_4, F_4_5, F_4_6]
I = [I_1, I_2, I_3, I_4, I_5, I_6]
# There should be 5 unused blinding rows.
# see src/plonk/circuit.rs: fn blinding_factors(&self) -> usize;
# We have 9 so we are perfectly fine.
# Add 9 empty rows
assert n - len(A1) == 10
for i in range(10):
A1.append(K.random_element())
A2.append(K.random_element())
A3.append(K.random_element())
A4.append(K.random_element())
F1.append(0)
F2.append(0)
F3.append(0)
F4.append(0)
I.append(K.random_element())
assert (len(A1) == len(A2) == len(A3) == len(A4) == len(F1) == len(F2)
== len(F3) == len(F4) == len(I) == n)
for A_1_i, A_2_i, A_3_i, A_4_i, F_1_i, F_2_i, F_3_i, F_4_i, I_i in zip(
A1, A2, A3, A4, F1, F2, F3, F4, I):
assert (F_1_i * (A_1_i - I_i)
+ F_2_i * A_1_i
+ F_3_i * ((1 - A_1_i) * (A_2_i + A_3_i) - A_4_i)
+ F_4_i * (A_1_i * A_2_i * A_3_i - A_4_i)) == 0
a_1_X = P.lagrange_polynomial((omega^i, A_1_i) for i, A_1_i in enumerate(A1))
a_2_X = P.lagrange_polynomial((omega^i, A_2_i) for i, A_2_i in enumerate(A2))
a_3_X = P.lagrange_polynomial((omega^i, A_3_i) for i, A_3_i in enumerate(A3))
a_4_X = P.lagrange_polynomial((omega^i, A_4_i) for i, A_4_i in enumerate(A4))
f_1_X = P.lagrange_polynomial((omega^i, F_1_i) for i, F_1_i in enumerate(F1))
f_2_X = P.lagrange_polynomial((omega^i, F_2_i) for i, F_2_i in enumerate(F2))
f_3_X = P.lagrange_polynomial((omega^i, F_3_i) for i, F_3_i in enumerate(F3))
f_4_X = P.lagrange_polynomial((omega^i, F_4_i) for i, F_4_i in enumerate(F4))
# Treat the instance wire as a 5th advice wire
a_5_X = P.lagrange_polynomial((omega^i, A_5_i) for i, A_5_i in enumerate(I))
for i, (A_1_i, A_2_i, A_3_i, A_4_i, F_1_i, F_2_i, F_3_i, F_4_i, I_i) in \
enumerate(zip(A1, A2, A3, A4, F1, F2, F3, F4, I)):
assert a_1_X(omega^i) == A_1_i
assert a_2_X(omega^i) == A_2_i
assert a_3_X(omega^i) == A_3_i
assert a_4_X(omega^i) == A_4_i
assert a_5_X(omega^i) == I_i
assert f_1_X(omega^i) == F_1_i
assert f_2_X(omega^i) == F_2_i
assert f_3_X(omega^i) == F_3_i
assert f_4_X(omega^i) == F_4_i
# beta, gamma
beta = K.random_element()
gamma = K.random_element()
# 0 1 2 3 4 5 ... 15
# A1: z, 0, s, s, s, 0,
#
# 16 17 18 19 20 21 ... 31
# A2: -, -, s, x, x, sxy,
#
# 32 33 34 35 36 37 ... 47
# A3: -, -, 0, y, y, (1-s)(x+y),
#
# 48 49 50 51 52 53 ... 63
# A4: -, -, 0, sxy, (1-s)(x + y), z,
#
# 64 65 66 67 68 69 ... 79
# A5: z, -, -, -, -, -,
# z = (0 53 64)
# 0 = (1 5 34 50)
# s = (2 3 4 18)
# x = (19 20)
# sxy = (21 51)
# y = (35 36)
# (1-s)(x+y) = (37 52)
permuted_indices = list(range(n * 5))
assert len(permuted_indices) == 80
# Apply the actual permutation cycles
# z
permuted_indices[0] = 53
permuted_indices[53] = 64
permuted_indices[64] = 0
# ~0
permuted_indices[1] = 5
permuted_indices[5] = 34
permuted_indices[34] = 50
permuted_indices[50] = 1
# s
permuted_indices[2] = 3
permuted_indices[3] = 4
permuted_indices[4] = 18
permuted_indices[18] = 2
# x
permuted_indices[19] = 20
permuted_indices[20] = 19
# sxy
permuted_indices[21] = 51
permuted_indices[51] = 21
# y
permuted_indices[35] = 36
permuted_indices[36] = 35
# (1-s)(x+y)
permuted_indices[37] = 52
permuted_indices[52] = 37
witness = A1 + A2 + A3 + A4 + I
for i, val in enumerate(witness):
assert val == witness[permuted_indices[i]]
# How to join lists together?
indices = ([omega^i for i in range(n)]
+ [delta * omega^i for i in range(n)]
+ [delta^2 * omega^i for i in range(n)]
+ [delta^3 * omega^i for i in range(n)]
+ [delta^4 * omega^i for i in range(n)])
assert len(indices) == 80
# Permuted indices
sigma_star = [indices[i] for i in permuted_indices]
s = [sigma_star[:n], sigma_star[n:2 * n], sigma_star[2 * n:3 * n],
sigma_star[3 * n:4 * n], sigma_star[4 * n:]]
assert s[0] + s[1] + s[2] + s[3] + s[4] == sigma_star
v = [A1, A2, A3, A4, I]
# We split the columns into sets of size m.
# Here we will use m = 1 for illustration purposes
# We have 6 usable rows
# n = 16 rows total
# row u (q_last) will be the 7th row
# So we have 9 unusable rows
q_blind = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]
q_last = [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# Turn both of these into polynomial form
q_blind = P.lagrange_polynomial((omega^i, q_i) for i, q_i in enumerate(q_blind))
assert q_blind(omega^5) == 0
assert q_blind(omega^6) == 0
assert q_blind(omega^7) == 1
assert q_blind(omega^11) == 1
q_last = P.lagrange_polynomial((omega^i, q_i) for i, q_i in enumerate(q_last))
assert q_last(omega^5) == 0
assert q_last(omega^6) == 1
assert q_last(omega^7) == 0
assert q_last(omega^11) == 0
m = 5
assert n == 16
# 6 usable rows
u = 6
# There are 5 columns
# We will split the columns partitions into 5 partitions to make things easy
# So b = 5, and each partition contains only a single column
# We still iterate over the column to make it more obvious
m = 1
permutation_points = [(1, 1)]
last_y_value = 1
ZP = []
# a is the current column partition we are aggregating
for a in range(5):
# j iterates over the rows
for j in range(u):
current = last_y_value
# i iterates over the columns in our partition
for i in range(a * m, (a + 1)):
current *= v[i][j] + beta * delta^i * omega^j + gamma
current /= v[i][j] + beta * s[i][j] + gamma
last_y_value = current
permutation_points.append((omega^(j + 1), current))
ZP_a = P.lagrange_polynomial(permutation_points)
ZP.append(ZP_a)
permutation_points = [(1, last_y_value)]
# l_0(X) (1 - ZP,0(X)) = 0
# => ZP,0(1) = 1
assert ZP[0](1) == 1
# Checks for l_0(X) (ZP,a(X) - ZP,a-1(omega^u X)) = 1
# => ZP,a(Z) = ZP,a-1(omega^u X)
# This copies the end value from one partition to the next one
assert ZP[1](omega^0) == ZP[0](omega^u)
assert ZP[2](omega^0) == ZP[1](omega^u)
assert ZP[3](omega^0) == ZP[2](omega^u)
assert ZP[4](omega^0) == ZP[3](omega^u)
# Allow the last value to be either 0 or 1 for full ZK
assert ZP[4](omega^u) in (0, 1)
y = K.random_element()
gate_0 = f_1_X * (a_1_X - a_5_X)
gate_1 = f_2_X * a_1_X
gate_2 = f_3_X * ((1 - a_1_X) * (a_2_X + a_3_X) - a_4_X)
gate_3 = f_4_X * (a_1_X * a_2_X * a_3_X - a_4_X)
c = gate_0 + y * gate_1 + y^2 * gate_2 + y^3 * gate_3
t = X^n - 1
for i in range(n):
assert h(omega^i) == 0
# Normally we do:
#h = c / t
# But for some reason sage is producing fractional coefficients
h, rem = c.quo_rem(t)
assert rem == 0
# We send commitments to the terms of h(X)
# h_0(x), ..., h_{d - 1}(x)
# Commitments:
# H = [H_0, ..., H_{d - 1}]
x = K.random_element()
# Send evaluations at x of everything we committed to so far
# A_0(x), ..., A_{m - 1}(x)
# ZP,0(x), ..., ZP,b-1(x)
# H_0(x), ..., H_{d-1}(x)
a_evals = [a_1_X(x), a_2_X(x), a_3_X(x), a_4_X(x), a_5_X(x)]
h_evals = []
# Iterate starting from lowest powers first
h_test = 0
for i, h_i in enumerate(h):
h_evals.append(h_i * x^i)
h_test += h_i * X^i
assert h_test == h
assert sum(h_evals) == h(x)
assert sum(h_evals) * t(x) == (
f_1_X(x) * (a_evals[0] - a_evals[4])
+ y * f_2_X(x) * a_evals[0]
+ y^2 * f_3_X(x) * ((1 - a_evals[0]) * (a_evals[1] + a_evals[2])
- a_evals[3])
+ y^3 * f_4_X(x) * (a_evals[0] * a_evals[1] * a_evals[2] - a_evals[3]))

33
scripts/halo/misc.py Normal file
View File

@@ -0,0 +1,33 @@
import random
def sample_random(fp, seed=None):
rnd = random.Random(seed)
# Range of the field is 0 ... p - 1
return fp(rnd.randint(0, fp.p - 1))
def is_power_of_two(n):
# Power of two number is represented by a single digit
# followed by zeroes.
return n & (n - 1) == 0
#| ## Choosing roots of unity
def get_omega(fp, n, seed=None):
"""
Given a field, this method returns an n^th root of unity.
If the seed is not None then this method will return the
same n'th root of unity for every run with the same seed
This only makes sense if n is a power of 2.
"""
assert is_power_of_two(n)
# https://crypto.stackexchange.com/questions/63614/finding-the-n-th-root-of-unity-in-a-finite-field
while True:
# Sample random x != 0
x = sample_random(fp, seed)
# Compute g = x^{(q - 1)/n}
y = pow(x, (fp.p - 1) // n)
# If g^{n/2} != 1 then g is a primitive root
if y != 1 and pow(y, n // 2) != 1:
assert pow(y, n) == 1, "omega must be 2nd root of unity"
return y

288
scripts/halo/multipoly.py Normal file
View File

@@ -0,0 +1,288 @@
import numpy as np
from finite_fields import finitefield
class Variable:
def __init__(self, name, fp):
self.name = name
self.fp = fp
def __pow__(self, n):
expr = MultiplyExpression(self.fp)
expr.set_symbol(self.name, n)
return expr
def __eq__(self, other):
return self.name == other.name
def __hash__(self):
return hash(self.name)
def termify(self):
expr = MultiplyExpression(self.fp)
expr.set_symbol(self.name, 1)
return expr
class MultiplyExpression:
def __init__(self, fp):
self.coeff = fp(1)
self.symbols = {}
self.fp = fp
def copy(self):
result = MultiplyExpression(self.fp)
result.coeff = self.coeff
result.symbols = self.symbols.copy()
return result
def clean(self):
for symbol in list(self.symbols.keys()):
if self.symbols[symbol] == 0:
del self.symbols[symbol]
def matches(self, other):
return self.symbols == other.symbols
def set_symbol(self, var_name, power):
self.symbols[var_name] = power
def __eq__(self, other):
return (self.coeff == other.coeff and
self.symbols == other.symbols)
def __neg__(self):
result = self.copy()
result.coeff *= -1
return result
def __mul__(self, expr):
result = MultiplyExpression(self.fp)
result.coeff = self.coeff
result.symbols = self.symbols.copy()
if isinstance(expr, np.int64) or isinstance(expr, int):
expr = self.fp(int(expr))
if hasattr(expr, "field"):
result.coeff *= expr
return result
if isinstance(expr, Variable):
expr = expr.termify()
for var_name, power in expr.symbols.items():
if var_name in result.symbols:
result.symbols[var_name] += power
else:
result.symbols[var_name] = power
# Remember to multiply the coefficients
result.coeff *= expr.coeff
return result
def __add__(self, expr):
if isinstance(expr, Variable):
expr = expr.termify()
if self.matches(expr):
result = self.copy()
result.coeff += expr.coeff
return result
return MultivariatePolynomial([self, expr])
def __sub__(self, expr):
expr = -expr
return self + expr
def evaluate(self, symbol_map):
result = MultiplyExpression(self.fp)
for symbol, power in self.symbols.items():
if symbol in symbol_map:
value = symbol_map[symbol]
result *= value**power
else:
result *= Variable(symbol, self.fp)**power
return result
def __str__(self):
repr = ""
first = True
if self.coeff != 1:
repr += str(self.coeff)
first = False
for var_name, power in self.symbols.items():
if first:
first = False
else:
repr += " "
if power == 1:
repr += var_name
else:
repr += var_name + "^" + str(power)
return repr
class MultivariatePolynomial:
def __init__(self, terms=[]):
self.terms = terms
def copy(self):
terms = [term.copy() for term in self.terms]
return MultivariatePolynomial(terms)
# Operations can accept Variables and constants
# so we make sure to convert them to MultiplyExpression types
def _convert_term(self, term):
if isinstance(term, Variable):
term = term.termify()
if hasattr(term, "field"):
expr = MultiplyExpression(term.field)
expr.coeff = term
term = expr
return term
def __bool__(self):
return bool(self.terms)
def __eq__(self, other):
return self.terms == other.terms
def __neg__(self):
terms = [-term for term in self.terms]
return MultivariatePolynomial(terms)
def __add__(self, term):
term = self._convert_term(term)
if isinstance(term, MultivariatePolynomial):
# Recursively apply addition operation
result = self.copy()
for other_term in term.terms:
result += other_term
return result
assert isinstance(term, MultiplyExpression)
# Delete ^0 variables
term.clean()
# Skip terms where the coeff is 0
if term.coeff == 0:
return self
result = self.copy()
result_term = result._find(term)
if result_term is None:
result.terms.append(term)
else:
result_term.coeff += term.coeff
return result
def __sub__(self, term):
term = -term
return self + term
def __mul__(self, term):
term = self._convert_term(term)
if isinstance(term, MultivariatePolynomial):
# Recursively apply addition operation
result = MultivariatePolynomial()
for other_term in term.terms:
result += self * other_term
return result
assert isinstance(term, MultiplyExpression)
# Delete ^0 variables
term.clean()
# Skip terms where the coeff is 0
if term.coeff == 0:
return self
terms = [self_term * term for self_term in self.terms]
result = MultivariatePolynomial(terms)
return result
def divmod(self, poly):
assert isinstance(poly, MultivariatePolynomial)
# https://www.win.tue.nl/~aeb/2WF02/groebner.pdf
def _find(self, other):
for term in self.terms:
if term.matches(other):
return term
return None
def evaluate(self, variable_map):
p = MultivariatePolynomial()
for term in self.terms:
assert isinstance(term, MultiplyExpression)
p += term.evaluate(variable_map)
return p
def _assert_unique_terms(self):
for i, term1 in enumerate(self.terms):
for q, term2 in enumerate(self.terms):
if i == q:
continue
assert not term1.matches(term2)
def filter(self, variables):
p = MultivariatePolynomial()
for term in self.terms:
assert isinstance(term, MultiplyExpression)
skip = False
for variable in variables:
symbol = variable.name
if symbol in term.symbols:
skip = True
if not skip:
p += term
return p
def __str__(self):
if not self.terms:
return "0"
repr = ""
first = True
for term in self.terms:
if first:
first = False
else:
repr += " + "
repr += str(term)
return repr
if __name__ == "__main__":
from finite_fields import finitefield
p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
fp = finitefield.IntegersModP(p)
x = Variable("X")
y = Variable("Y")
z = Variable("Z")
print(y**2 + y**2)
p = x**3 * y**2 * x**2 * fp(5) * fp(2) + x**3 * y + z + fp(6)
q = x**3 * y * fp(3) + y
print(p)
print(q)
print(p + q)
print(p * q)
print(-q)
print(p - q)

5
scripts/halo/pasta.py Normal file
View File

@@ -0,0 +1,5 @@
from finite_fields import finitefield
q = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
fq = finitefield.IntegersModP(q)

View File

@@ -0,0 +1,187 @@
#import numpy as np
from groth_poly_commit import Scalar, poly_commit, create_proof, verify_proof
K = Scalar
#R.<x> = LaurentPolynomialRing(K)
R.<x> = PolynomialRing(K)
var_one = K(1)
var_x = K(4)
var_y = K(6)
var_s = K(1)
var_xy = var_x * var_y
var_x_y = var_x + var_y
var_1_neg_s = var_one - var_s
var_sxy = var_s * var_xy
var_1_neg_s_x_y = var_1_neg_s * var_x_y
#var_s_neg_1 = -var_1_neg_s
var_zero = K(0)
public_value = var_s * (var_x * var_y) + (1 - var_s) * (var_x + var_y)
# x * y = xy
a1 = var_x
b1 = var_y
c1 = var_xy
Ql1 = 0
Qr1 = 0
Qm1 = 1
Qo1 = -1
Qc1 = 0
assert Ql1 * a1 + Qr1 * b1 + Qm1 * a1 * b1 + Qo1 * c1 + Qc1 == 0
# x + y = (x + y)
a2 = var_x
b2 = var_y
c2 = var_x_y
Ql2 = 1
Qr2 = 1
Qm2 = 0
Qo2 = -1
Qc2 = 0
assert Ql2 * a2 + Qr2 * b2 + Qm2 * a2 * b2 + Qo2 * c2 + Qc2 == 0
# 1 - s = (1 - s)
a3 = var_one
b3 = var_s
c3 = var_1_neg_s
Ql3 = 1
Qr3 = -1
Qm3 = 0
Qo3 = -1
Qc3 = 0
assert Ql3 * a3 + Qr3 * b3 + Qm3 * a3 * b3 + Qo3 * c3 + Qc3 == 0
# s * (xy) = sxy
a4 = var_s
b4 = var_xy
c4 = var_sxy
Ql4 = 0
Qr4 = 0
Qm4 = 1
Qo4 = -1
Qc4 = 0
assert Ql4 * a4 + Qr4 * b4 + Qm4 * a4 * b4 + Qo4 * c4 + Qc4 == 0
# (1 - s) * (x + y) = [(1 - s)(x + y)]
a5 = var_1_neg_s
b5 = var_x_y
c5 = var_1_neg_s_x_y
Ql5 = 0
Qr5 = 0
Qm5 = 1
Qo5 = -1
Qc5 = 0
assert Ql5 * a5 + Qr5 * b5 + Qm5 * a5 * b5 + Qo5 * c5 + Qc5 == 0
# (sxy) + [(1 - s)(x + y)] = public_value
a6 = var_sxy
b6 = var_1_neg_s_x_y
# Unused
c6 = var_zero
Ql6 = 1
Qr6 = 1
Qm6 = 0
Qo6 = 0
Qc6 = -public_value
assert Ql6 * a6 + Qr6 * b6 + Qm6 * a6 * b6 + Qo6 * c6 + Qc6 == 0
# one == 1
a7 = var_one
# Unused
b7 = var_zero
# Unused
c7 = var_zero
Ql7 = 1
Qr7 = 0
Qm7 = 0
Qo7 = 0
Qc7 = -1
assert Ql7 * a7 + Qr7 * b7 + Qm7 * a7 * b7 + Qo7 * c7 + Qc7 == 0
a = [a1, a2, a3, a4, a5, a6, a7]
b = [b1, b2, b3, b4, b5, b6, b7]
c = [c1, c2, c3, c4, c5, c6, c7]
Ql = [Ql1, Ql2, Ql3, Ql4, Ql5, Ql6]
Qr = [Qr1, Qr2, Qr3, Qr4, Qr5, Qr6]
Qm = [Qm1, Qm2, Qm3, Qm4, Qm5, Qm6]
Qo = [Qo1, Qo2, Qo3, Qo4, Qo5, Qo6]
Qc = [Qc1, Qc2, Qc3, Qc4, Qc5, Qc6]
# 0 1 2 3 4 5 6
# a: x, x, 1, s, 1 - s, sxy, 1
#
# 7 8 9 10 11 12 13
# b: y, y, s, xy, x + y, (1 - s)(x + y), -
#
# 14 15 16 17 18 19 20
# c: xy, x + y, 1 - s, sxy, (1 - s)(x + y), -, -
permuted_indices = [
1, 0, 6, 9, 16, 17, 2,
8, 7, 3, 14, 15, 18, 13,
10, 11, 4, 5, 12, 19, 20
]
eval_domain = range(0, len(permuted_indices))
witness = a + b + c
for i, val in enumerate(a + b + c):
assert val == witness[permuted_indices[i]]
#def lagrange(domain, codomain):
# S.<x> = PolynomialRing(K)
# p = S.lagrange_polynomial(zip(eval_domain, permuted_indices))
# # Convert to a Laurent polynomial
# return R(p)
# This is what the prover passes to the verifier
witness_y = R.lagrange_polynomial(enumerate(witness))
assert witness_y(12) == witness[12]
witness_x_a = R.lagrange_polynomial(
zip(eval_domain[0:7], eval_domain[0:7]))
witness_x_b = R.lagrange_polynomial(
zip(eval_domain[7:14], eval_domain[7:14]))
witness_x_c = R.lagrange_polynomial(
zip(eval_domain[14:], eval_domain[14:]))
assert witness_x_a(2) == eval_domain[2]
assert witness_x_b(8) == eval_domain[8]
assert witness_x_c(16) == eval_domain[16]
witness_x_a_prime = R.lagrange_polynomial(
zip(eval_domain[0:7], permuted_indices[0:7]))
witness_x_b_prime = R.lagrange_polynomial(
zip(eval_domain[7:14], permuted_indices[7:14]))
witness_x_c_prime = R.lagrange_polynomial(
zip(eval_domain[14:], permuted_indices[14:]))
assert witness_x_a_prime(2) == permuted_indices[2]
assert witness_x_b_prime(8) == permuted_indices[8]
assert witness_x_c_prime(16) == permuted_indices[16]
v1 = K(2)
v2 = K(3)
px = 1
for i in range(0, len(a)):
px *= v1 + witness_x_a(i) + v2 * witness_y(i)
for i in range(len(a), 2 * len(a)):
px *= v1 + witness_x_b(i) + v2 * witness_y(i)
for i in range(2 * len(a), 3 * len(a)):
px *= v1 + witness_x_c(i) + v2 * witness_y(i)
px_prime = 1
for i in range(0, len(a)):
px_prime *= v1 + witness_x_a_prime(i) + v2 * witness_y(i)
for i in range(len(a), 2 * len(a)):
px_prime *= v1 + witness_x_b_prime(i) + v2 * witness_y(i)
for i in range(2 * len(a), 3 * len(a)):
px_prime *= v1 + witness_x_c_prime(i) + v2 * witness_y(i)
assert px == px_prime

View File

@@ -0,0 +1,360 @@
F101 = Integers(101)
E = EllipticCurve(F101, [0, 3])
R.<X> = PolynomialRing(F101)
# Extension field of points of 101, and solutions of x^2 + 2
K.<X> = GF(101**2, modulus=X^2 + 2)
# y^2 = x^3 + 3 in this curve defined over the extension field.
# Needed for pairing.
E2 = EllipticCurve(K, [0, 3])
# Generator point because 1^3 + 3 = 4 which is sqrt of 2
G = E([1, 2])
G2 = E2([36, 31*X])
assert G.order() == 17
F17 = Integers(17)
assert F17.square_roots_of_one() == (1, 16)
# 16 == -1
# so now we have the 4th roots of 1
w = vector(F17, [1, 4, -1, -4])
# omega = 4, omega^0 = 1, omega^1 = 4, omega^3 = -1 = 13, omega^3 = 13
# so now we are still defining our reference string
# we have 4 x labels for our permutation vector but we need 12
# so we generate 2 cosets using quadratic non-residues in F
# all 3 cosets should not share any value in common
k1 = 2
k2 = 3
assert w == vector(F17, [1, 4, 16, 13])
assert k1 * w == vector(F17, [2, 8, 15, 9])
assert k2 * w == vector(F17, [3, 12, 14, 5])
A = matrix(F17, [
[1, 1, 1, 1],
[4^0, 4^1, 4^2, 4^3],
[16^0, 16^1, 16^2, 16^3],
[13^0, 13^1, 13^2, 13^3]
])
Ai = A.inverse()
P.<x> = F17[]
x = P.0
# We only have 3 gates in this example for x^3 + x = 30
# x x^2 x^3
# x x x
# x^2 x^3 0
# public input: 30
# we have 4 w values so the last column is empty (all set to 0 in this case)
fa = P.lagrange_polynomial(zip([1, 4, 16, 13], [3, 9, 27, 0]))
#fa = P(list(Ai * vector(F17, [3, 9, 27, 0])))
assert fa(1) == 3
assert fa(4) == 9
assert fa(16) == 27
assert fa(13) == 0
fb = P(list(Ai * vector(F17, [3, 3, 3, 0])))
assert fb(1) == 3
assert fb(4) == 3
assert fb(16) == 3
assert fb(13) == 0
fc = P(list(Ai * vector(F17, [9, 27, 0, 0])))
assert fc(1) == 9
assert fc(4) == 27
assert fc(16) == 0
assert fc(13) == 0
# List of operations
#
# mul, mul, add/cons, null
ql = P(list(Ai * vector(F17, [0, 0, 1, 0])))
qr = P(list(Ai * vector(F17, [0, 0, 1, 0])))
qm = P(list(Ai * vector(F17, [1, 1, 0, 0])))
qo = P(list(Ai * vector(F17, [-1, -1, 0, 0])))
qc = P(list(Ai * vector(F17, [0, 0, -30, 0])))
# permutation/copy constraints
# We are using the coset values here for a, b, c
# 1 4 16 13
# 2 8 15 9
# 3 12 14 5
# Applying the permutation for:
# x x^2 x^3
# x x x
# x^2 x^3 0
# then we get:
# 2 3 12 13
# 1 15 8 9
# 4 16 14 5
# We swap indices whenever there is an equality between wires:
# a1 = b1
# a2 = c1
# ...
sa = P(list(Ai * vector(F17, [2, 3, 12, 13])))
sb = P(list(Ai * vector(F17, [1, 15, 8, 9])))
sc = P(list(Ai * vector(F17, [4, 16, 14, 5])))
# Setup phase complete
# Prove phase
# Round 1
# Create vanishing polynomial which is zero for every root of unity.
# That is Z(w_1) = Z(w_2) = ... = 0
Z = x^4 - 1
assert Z(1) == 0
assert Z(4) == 0
assert Z(16) == 0
assert Z(13) == 0
# 9 random blinding values. We will use:
# 7, 4, 11, 12, 16, 2
# 14, 11, 7 (used in round 2)
# Blind our witness polynomials
# The blinding factors will disappear at the evaluation points.
a = (7*x + 4) * Z + fa
b = (11*x + 12) * Z + fb
c = (16*x + 2) * Z + fc
# During the SRS phase we created a random s point and its powers
s = 2
# So now we evaluate a, b, c with these powers of G
a_s = ZZ(a(s)) * G
b_s = ZZ(b(s)) * G
c_s = ZZ(c(s)) * G
# Round 2
# Random transcript challenges
beta = 12
gamma = 13
# Build accumulation
acc = 1
accs = []
for i in range(4):
# w_{n + j} corresponds to b(w[i])
# and w_{2n + j} is c(w[i])
accs.append(acc)
acc = acc * (
(a(w[i]) + beta * w[i] + gamma)
* (b(w[i]) + beta * k1 * w[i] + gamma)
* (c(w[i]) + beta * k2 * w[i] + gamma) /
(
(a(w[i]) + beta * sa(w[i]) + gamma)
* (b(w[i]) + beta * sb(w[i]) + gamma)
* (c(w[i]) + beta * sc(w[i]) + gamma)
))
assert accs == [1, 12, 10, 1]
del accs
acc = P(list(Ai * vector(F17, [1, 12, 10, 1])))
Zx = (14*x^2 + 11*x + 7) * Z + acc
# Evaluate z(x) at our secret point
Z_s = ZZ(Zx(s)) * G
# Round 3
alpha = 15
t1Z = a * b * qm + a * ql + b * qr + c * qo + qc
t2Z = ((a + beta * x + gamma)
* (b + beta * k1 * x + gamma)
* (c + beta * k2 * x + gamma)) * Zx * alpha
# w[1] is our first root of unity
Zw = Zx(w[1] * x)
t3Z = -((a + beta * sa + gamma)
* (b + beta * sb + gamma)
* (c + beta * sc + gamma)) * Zw * alpha
# Lagrangian polynomial which evaluates to 1 at 1
# L_1(w_1) = 1 and 0 on the other evaluation points
L = P(list(Ai * vector(F17, [1, 0, 0, 0])))
assert L(1) == 1
# w_2 = 4
assert L(4) == 0
t4Z = (Zx - 1) * L * alpha^2
tZ = t1Z + t2Z + t3Z + t4Z
# and cancel out the factor Z now
t = P(tZ / Z)
# Split t into 3 parts
# t(X) = t_lo(X) + X^n t_mid(X) + X^{2n} t_hi(X)
t_list = t.list()
t_lo = t_list[0:6]
t_mid = t_list[6:12]
t_hi = t_list[12:18]
# and create the evaluations
t_lo_s = ZZ(P(t_lo)(s)) * G
t_mid_s = ZZ(P(t_mid)(s)) * G
t_hi_s = ZZ(P(t_hi)(s)) * G
# Round 4
zeta = 5
a_ = a(zeta)
b_ = b(zeta)
c_ = c(zeta)
sa_ = sa(zeta)
sb_ = sb(zeta)
t_ = t(zeta)
zw_ = Zx(zeta * w[1])
l_ = L(zeta)
assert a_ == 8
assert b_ == 12
assert c_ == 10
assert sa_ == 0
assert sb_ == 16
assert t_ == 3
assert zw_ == 14
r1 = a_ * b_ * qm + a_ * ql + b_ * qr + c_ * qo + qc
r2 = ((a_ + beta * zeta + gamma)
* (b_ + beta * k1 * zeta + gamma)
* (c_ + beta * k2 * zeta + gamma)) * Zx * alpha
r3 = -((a_ + beta * sa_ + gamma)
* (b_ + beta * sb_ + gamma)
* beta * zw_ * sc * alpha)
r4 = Zx * l_ * alpha^2
r = r1 + r2 + r3 + r4
r_ = r(zeta)
assert r_ == 7
# Round 5
vega = 12
v1 = P(t_lo)
# Polynomial was in parts consisting of 6 powers
v2 = zeta^6 * P(t_mid)
v3 = zeta^12 * P(t_hi)
v4 = -t_
assert v4 == 14
v5 = (
vega * (r - r_)
+ vega^2 * (a - a_) + vega^3 * (b - b_) + vega^4 * (c - c_)
+ vega^5 * (sa - sa_) + vega^6 * (sb - sb_)
)
W = v1 + v2 + v3 + v4 + v5
Wz = W / (x - zeta)
# Calculate the opening proof
Wzw = (Zx - zw_) / (x - zeta * w[1])
# Compute evaluations of Wz and Wzw
Wz_s = ZZ(Wz(s)) * G
Wzw_s = ZZ(Wzw(s)) * G
# Finished the proving algo
proof = (a_s, b_s, c_s, Z_s, t_lo_s, t_mid_s, t_hi_s, Wz_s, Wzw_s,
a_, b_, c_, sa_, sb_, r_, zw_)
# Verification
qm_s = ZZ(qm(s)) * G
ql_s = ZZ(ql(s)) * G
qr_s = ZZ(qr(s)) * G
qo_s = ZZ(qo(s)) * G
qc_s = ZZ(qc(s)) * G
sa_s = ZZ(sa(s)) * G
sb_s = ZZ(sb(s)) * G
sc_s = ZZ(sc(s)) * G
# Check all the points are on the curve.
# y^2 = x^3 + 3
# ...
# Also check the scalar values are in the group for F17
# ...
# step 4: random upsilon
upsilon = 4
# step 5
Z_z = F17(zeta^4 - 1)
assert Z_z == 12
# step 6
# Calculate evaluation of L1 at zeta
L1_z = F17((zeta^4 - 1) / (4 * (zeta - 1)))
assert L1_z == 5
# step 7
# no public inputs in this example
# step 8
t_ = (r_ - (a_ + beta * sa_ + gamma)
* (b_ + beta * sb_ + gamma)
* (c_ + gamma) * zw_ * alpha
- L1_z * alpha^2) / Z_z
assert t_ == 3
# step 9
# qx_s are points, and we are multiplying them by scalars
# so convert the values to integers first
d1 = (ZZ(a_ * b_ * vega) * qm_s
+ ZZ(a_ * vega) * ql_s
+ ZZ(b_ * vega) * qr_s
+ ZZ(c_ * vega) * qo_s
+ vega * qc_s)
d2 = ZZ((a_ + beta * zeta + gamma)
* (b_ + beta * k1 * zeta + gamma)
* (c_ + beta * k2 * zeta + gamma)
* alpha * vega
+ L1_z * alpha^2 * vega
+ F17(upsilon)) * Z_s
d3 = -ZZ((a_ + beta * sa_ + gamma)
* (b_ + beta * sb_ + gamma)
* alpha * vega * beta * zw_) * sc_s
d = d1 + d2 + d3
# step 10
f = (t_lo_s + zeta^6 * t_mid_s + zeta^12 * t_hi_s
+ d
+ vega^2 * a_s + vega^3 * b_s + vega^4 * c_s
+ vega^5 * sa_s + vega^6 * sb_s)
# step 11
e = ZZ(t_ + vega * r_
+ vega^2 * a_ + vega^3 * b_ + vega^4 * c_
+ vega^5 * sa_ + vega^6 * sb_
+ upsilon * zw_) * G
# step 12
# construct points for the pairing check
x1 = Wz_s + upsilon * Wzw_s
x2 = s * G2
y1 = zeta * Wz_s + ZZ(upsilon * zeta * w[1]) * Wzw_s + f - e
y2 = G2
# do the pairing check
x1_ = E2(x1)
x2_ = E2(x2)
y1_ = E2(y1)
y2_ = E2(y2)
assert x1_.weil_pairing(x2_, 17) == y1_.weil_pairing(y2_, 17)

325
scripts/halo/plonk.sage Normal file
View File

@@ -0,0 +1,325 @@
q = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001
K = GF(q)
P.<X> = K[]
# The pallas and vesta curves are 2-adic. This means there is a large
# power of 2 subgroup within both of their fields.
# This function finds a generator for this subgroup within the field.
def get_omega():
# Slower alternative:
# generator = K.multiplicative_generator()
# Just hardcode the value here instead
generator = K(5)
assert (q - 1) % 2^32 == 0
# Root of unity
t = (q - 1) / 2^32
omega = generator**t
assert omega != 1
assert omega^(2^16) != 1
assert omega^(2^31) != 1
assert omega^(2^32) == 1
return omega
# Order of this element is 2^32
omega = get_omega()
# f(s, x, y) = sxy + (1 - s)(x + y)
var_one = K(1)
var_x = K(4)
var_y = K(6)
var_s = K(1)
var_xy = var_x * var_y
var_x_y = var_x + var_y
var_1_neg_s = var_one - var_s
var_sxy = var_s * var_xy
var_1_neg_s_x_y = var_1_neg_s * var_x_y
#var_s_neg_1 = -var_1_neg_s
var_zero = K(0)
public_value = -(var_s * (var_x * var_y) + (1 - var_s) * (var_x + var_y))
# Ql a + Qr b + Qm a b + Qo c + Qc + P == 0
# See also the file plonk-naive.sage
# x * y = xy
a1, b1, c1 = var_x, var_y, var_xy
Ql1, Qr1, Qm1, Qo1, Qc1 = 0, 0, 1, -1, 0
assert Ql1 * a1 + Qr1 * b1 + Qm1 * a1 * b1 + Qo1 * c1 + Qc1 == 0
# x + y = (x + y)
a2, b2, c2 = var_x, var_y, var_x_y
Ql2, Qr2, Qm2, Qo2, Qc2 = 1, 1, 0, -1, 0
assert Ql2 * a2 + Qr2 * b2 + Qm2 * a2 * b2 + Qo2 * c2 + Qc2 == 0
# 1 - s = (1 - s)
a3, b3, c3 = var_one, var_s, var_1_neg_s
Ql3, Qr3, Qm3, Qo3, Qc3 = 1, -1, 0, -1, 0
assert Ql3 * a3 + Qr3 * b3 + Qm3 * a3 * b3 + Qo3 * c3 + Qc3 == 0
# s * (xy) = sxy
a4, b4, c4 = var_s, var_xy, var_sxy
Ql4, Qr4, Qm4, Qo4, Qc4 = 0, 0, 1, -1, 0
assert Ql4 * a4 + Qr4 * b4 + Qm4 * a4 * b4 + Qo4 * c4 + Qc4 == 0
# (1 - s) * (x + y) = [(1 - s)(x + y)]
a5, b5, c5 = var_1_neg_s, var_x_y, var_1_neg_s_x_y
Ql5, Qr5, Qm5, Qo5, Qc5 = 0, 0, 1, -1, 0
assert Ql5 * a5 + Qr5 * b5 + Qm5 * a5 * b5 + Qo5 * c5 + Qc5 == 0
# (sxy) + [(1 - s)(x + y)] = public_value
# c6 is unused
a6, b6, c6 = var_sxy, var_1_neg_s_x_y, var_zero
Ql6, Qr6, Qm6, Qo6, Qc6 = 1, 1, 0, 0, 0
assert Ql6 * a6 + Qr6 * b6 + Qm6 * a6 * b6 + Qo6 * c6 + Qc6 + public_value == 0
# one == 1, b7 and c7 unused
a7, b7, c7 = var_one, var_zero, var_zero
Ql7, Qr7, Qm7, Qo7, Qc7 = 1, 0, 0, 0, -1
assert Ql7 * a7 + Qr7 * b7 + Qm7 * a7 * b7 + Qo7 * c7 + Qc7 == 0
# Add a last fake constraint so n is a power of 2
# This is needed since we are working with omega whose size is 2^32
# and we will create a generator from it whose order is 2^3
a8, b8, c8 = var_zero, var_zero, var_zero
Ql8, Qr8, Qm8, Qo8, Qc8 = 0, 0, 0, 0, 0
assert Ql8 * a8 + Qr8 * b8 + Qm8 * a8 * b8 + Qo8 * c8 + Qc8 == 0
a = [a1, a2, a3, a4, a5, a6, a7, a8]
b = [b1, b2, b3, b4, b5, b6, b7, b8]
c = [c1, c2, c3, c4, c5, c6, c7, c8]
Ql = [Ql1, Ql2, Ql3, Ql4, Ql5, Ql6, Ql7, Ql8]
Qr = [Qr1, Qr2, Qr3, Qr4, Qr5, Qr6, Qr7, Qr8]
Qm = [Qm1, Qm2, Qm3, Qm4, Qm5, Qm6, Qm7, Qm8]
Qo = [Qo1, Qo2, Qo3, Qo4, Qo5, Qo6, Qo7, Qo8]
Qc = [Qc1, Qc2, Qc3, Qc4, Qc5, Qc6, Qc7, Qc8]
public_values = [0, 0, 0, 0, 0, public_value, 0, 0]
n = 8
for a_i, b_i, c_i, Ql_i, Qr_i, Qm_i, Qo_i, Qc_i, public_i in \
zip(a, b, c, Ql, Qr, Qm, Qo, Qc, public_values):
assert (Ql_i * a_i + Qr_i * b_i + Qm_i * a_i * b_i + Qo_i * c_i
+ Qc_i + public_i) == 0
# 0 1 2 3 4 5 6 7
# a: x, x, 1, s, 1 - s, sxy, 1 -
#
# 8 9 10 11 12 13 14 15
# b: y, y, s, xy, x + y, (1 - s)(x + y), - -
#
# 16 17 18 19 20 21 22 23
# c: xy, x + y, 1 - s, sxy, (1 - s)(x + y), -, - -
permuted_indices_a = [1, 0, 6, 10, 18, 19, 2, 7]
permuted_indices_b = [8, 9, 3, 16, 17, 20, 14, 15]
permuted_indices_c = [11, 12, 4, 5, 13, 21, 22, 23]
eval_domain = range(0, n * 3)
witness = a + b + c
permuted_indices = permuted_indices_a + permuted_indices_b + permuted_indices_c
for i, val in enumerate(a + b + c):
assert val == witness[permuted_indices[i]]
omega = omega^(2^32 / n)
assert omega^n == 1
# Calculate the vanishing polynomial
# This is the same as (X - omega^0)(X - omega^1)...(X - omega^{n - 1})
Z_H = X^n - 1
assert Z_H(1) == 0
assert Z_H(omega^4) == 0
qL_X = P.lagrange_polynomial((omega^i, Ql_i) for i, Ql_i in enumerate(Ql))
qR_X = P.lagrange_polynomial((omega^i, Qr_i) for i, Qr_i in enumerate(Qr))
qM_X = P.lagrange_polynomial((omega^i, Qm_i) for i, Qm_i in enumerate(Qm))
qO_X = P.lagrange_polynomial((omega^i, Qo_i) for i, Qo_i in enumerate(Qo))
qC_X = P.lagrange_polynomial((omega^i, Qc_i) for i, Qc_i in enumerate(Qc))
PI_X = P.lagrange_polynomial((omega^i, public_i) for i, public_i
in enumerate(public_values))
b_1 = K.random_element()
b_2 = K.random_element()
b_3 = K.random_element()
b_4 = K.random_element()
b_5 = K.random_element()
b_6 = K.random_element()
b_7 = K.random_element()
b_8 = K.random_element()
b_9 = K.random_element()
# Round 1
# Calculate wire witness polynomials
a_X = (b_1 * X + b_2) * Z_H + \
P.lagrange_polynomial((omega^i, a_i) for i, a_i in enumerate(a))
assert a_X(omega^2) == a[2]
b_X = (b_3 * X + b_4) * Z_H + \
P.lagrange_polynomial((omega^i, b_i) for i, b_i in enumerate(b))
assert b_X(omega^5) == b[5]
c_X = (b_5 * X + b_6) * Z_H + \
P.lagrange_polynomial((omega^i, c_i) for i, c_i in enumerate(c))
assert c_X(omega^0) == c[0]
# Commit to a(X), b(X), c(X)
# ...
# Round 2
beta = K.random_element()
gamma = K.random_element()
def find_quadratic_non_residue():
k = K.random_element()
while kronecker(k, q) != -1:
k = K.random_element()
return k
# These values do not have a square root
k1 = find_quadratic_non_residue()
k2 = find_quadratic_non_residue()
assert k1 != k2
indices = ([omega^i for i in range(n)]
+ [k1 * omega^i for i in range(n)]
+ [k2 * omega^i for i in range(n)])
# Permuted indices
sigma_star = [indices[i] for i in permuted_indices]
permutation_points = [(1, 1)]
for i in range(n - 1):
x = omega^(i + 1)
y = 1
for j in range(i + 1):
y *= witness[j] + beta * omega^j + gamma
y *= witness[n + j] + beta * k1 * omega^j + gamma
y *= witness[2 * n + j] + beta * k2 * omega^j + gamma
y /= witness[j] + sigma_star[j] * beta + gamma
y /= witness[n + j] + sigma_star[n + j] * beta + gamma
y /= witness[2 * n + j] + sigma_star[2 * n + j] * beta + gamma
permutation_points.append((x, y))
z_X = (b_7 * X^2 + b_8 * X + b_9) * Z_H + \
P.lagrange_polynomial(permutation_points)
assert witness[0] == 4
assert witness[n] == 6
assert witness[2 * n] == var_xy == 24
assert sigma_star[0] == omega
assert sigma_star[n] == k1 * omega^8
assert sigma_star[2 * n] == k1 * omega^11
assert z_X(omega^0) == 1
assert ((4 + beta + gamma) * (6 + beta * k1 + gamma) * (24 + beta * k2 + gamma)
) == (z_X(omega)
* (4 + omega * beta + gamma)
* (6 + k1 * omega^8 * beta + gamma)
* (24 + k1 * omega^11 * beta + gamma))
assert witness[2] == var_one == 1
assert witness[n + 2] == var_s == 1
assert witness[2 * n + 2] == var_1_neg_s == 0
assert sigma_star[2] == omega^6
assert sigma_star[n + 2] == omega^3
assert sigma_star[2 * n + 2] == omega^4
assert (z_X(omega^2) * (1 + beta * omega^2 + gamma)
* (1 + beta * k1 * omega^2 + gamma)
* (0 + beta * k2 * omega^2 + gamma)
) == (z_X(omega^3) * (1 + omega^6 * beta + gamma)
* (1 + omega^3 * beta + gamma)
* (0 + omega^4 * beta + gamma))
# Round 3
alpha = K.random_element()
Ssigma_1 = P.lagrange_polynomial((omega^i, sigma_star[i]) for i in range(8))
Ssigma_2 = P.lagrange_polynomial((omega^i, sigma_star[n + i]) for i in range(8))
Ssigma_3 = P.lagrange_polynomial((omega^i, sigma_star[2 * n + i])
for i in range(8))
assert Ssigma_1(omega^0) == omega^1
assert Ssigma_1(omega^3) == k1 * omega^10
assert Ssigma_2(omega^2) == omega^3
assert Ssigma_3(omega^7) == k2 * omega^7 == k2 * omega^23
t_X_constraints = ((a_X * b_X * qM_X) + (a_X * qL_X) + (b_X * qR_X)
+ (c_X * qO_X) + qC_X + PI_X)
for i in range(8):
assert t_X_constraints(omega^i) == 0
t_X_permutations = ((a_X + beta * X + gamma)
* (b_X + beta * k1 * X + gamma)
* (c_X + beta * k2 * X + gamma) * z_X
# Permutated accumulator
- (a_X + beta * Ssigma_1 + gamma)
* (b_X + beta * Ssigma_2 + gamma)
* (c_X + beta * Ssigma_3 + gamma) * z_X(X * omega))
for i in range(8):
assert t_X_permutations(omega^i) == 0
L1_X = P.lagrange_polynomial([(1, 1)] + [(omega^i, 0) for i in range(1, n)])
assert L1_X(omega^0) == 1
assert L1_X(omega^2) == 0
t_X_zloops = (z_X - 1) * L1_X
assert t_X_zloops(omega^0) == 0
assert t_X_zloops(omega^2) == 0
assert t_X_zloops(omega^8) == 0
t = (t_X_constraints + t_X_permutations * alpha + t_X_zloops * alpha^2) / Z_H
# Commit to t
# ...
# Round 4
zeta = K.random_element()
a_bar = a_X(zeta)
b_bar = b_X(zeta)
c_bar = c_X(zeta)
s_bar_1 = Ssigma_1(zeta)
s_bar_2 = Ssigma_2(zeta)
z_bar_omega = z_X(zeta * omega)
# Now we provide proofs that all the above values are correct openings
# of the committed polynomials.
# And we prove that a reconstructed version of t(X) from the polynomial
# commitments of the witness and permutation polynomials equals the
# t(X) commitment.
# t(X) - r(X) = 0 where r(X) is the reconstructed polynomial.
# In order to avoid sending Ssigma_1(zeta) and z(zeta), plonk does an
# optimization using the Maller trick documented in section 4 under
# the title "Reducing the number of field elements"
# Round 5
# To reduce the proof by two elements, we construct a linearization polynomial
# which only contains 1 interminate per multiplication expression which is
# enough to prove the polynomial correctly evaluates.
r = (
# This is proving the constraint polynomial has roots at H
(a_bar * b_bar * qM_X) + (a_bar * qL_X) + (b_bar * qR_X)
+ (c_bar * qO_X) + PI_X + qC_X
+ alpha * ((a_bar + beta * zeta + gamma)
* (b_bar + beta * k1 * zeta + gamma)
* (c_bar + beta * k2 * zeta + gamma) * z_X
-
(a_bar + beta * s_bar_1 + gamma)
* (b_bar + beta * s_bar_2 + gamma)
* (c_bar + beta * Ssigma_3 + gamma) * z_bar_omega)
+ alpha^2 * (z_X - 1) * L1_X(zeta)
# t = (t_X_constraints + t_X_permutations * alpha + t_X_zloops * alpha^2)
# -------------------------------------------------------------------
# Z_H
- Z_H(zeta) * t
)
assert r(zeta) == 0
# That is basically the plonk prover. The remaining stuff are details such as
# which polynomial commitment scheme you use (kate, bulletproofs, ...)

View File

@@ -0,0 +1,301 @@
#| # Evaluation Representation of Polynomials and FFT optimizations
#| In addition to the coefficient-based representation of polynomials used
#| in babysnark.py, for performance we will also use an alternative
#| representation where the polynomial is evaluated at a fixed set of points.
#| Some operations, like multiplication and division, are significantly more
#| efficient in this form.
#| We can use FFT-based tools for efficiently converting
#| between coefficient and evaluation representation.
#|
#| This library provides:
#| - Fast fourier transform for finite fields
#| - Interpolation and evaluation using FFT
from finite_fields.finitefield import FiniteField
from finite_fields.polynomial import polynomialsOver
from finite_fields.euclidean import extendedEuclideanAlgorithm
import random
from finite_fields.numbertype import typecheck, memoize, DomainElement
from functools import reduce
import numpy as np
#| ## Fast Fourier Transform on Finite Fields
def fft_helper(a, omega, field):
"""
Given coefficients A of polynomial this method does FFT and returns
the evaluation of the polynomial at [omega^0, omega^(n-1)]
If the polynomial is a0*x^0 + a1*x^1 + ... + an*x^n then the coefficients
list is of the form [a0, a1, ... , an].
"""
n = len(a)
assert not (n & (n - 1)), "n must be a power of 2"
if n == 1:
return a
b, c = a[0::2], a[1::2]
b_bar = fft_helper(b, pow(omega, 2), field)
c_bar = fft_helper(c, pow(omega, 2), field)
a_bar = [field(1)] * (n)
for j in range(n):
k = j % (n // 2)
a_bar[j] = b_bar[k] + pow(omega, j) * c_bar[k]
return a_bar
#| ## Representing a polynomial by evaluation at fixed points
@memoize
def make_polynomial_evalrep(field, omega, n):
assert n & n - 1 == 0, "n must be a power of 2"
# Check that omega is an n'th primitive root of unity
assert type(omega) is field
omega = field(omega)
assert omega**(n) == 1
_powers = [omega**i for i in range(n)]
assert len(set(_powers)) == n
_poly_coeff = polynomialsOver(field)
class PolynomialEvalRep(object):
def __init__(self, xs, ys):
# Each element of xs must be a power of omega.
# There must be a corresponding y for every x.
if type(xs) is not tuple:
xs = tuple(xs)
if type(ys) is not tuple:
ys = tuple(ys)
assert len(xs) <= n+1
assert len(xs) == len(ys)
for x in xs:
assert x in _powers
for y in ys:
assert type(y) is field
self.evalmap = dict(zip(xs, ys))
@classmethod
def from_coeffs(cls, poly):
assert type(poly) is _poly_coeff
assert poly.degree() <= n
padded_coeffs = poly.coefficients + [field(0)] * (n - len(poly.coefficients))
ys = fft_helper(padded_coeffs, omega, field)
xs = [omega**i for i in range(n) if ys[i] != 0]
ys = [y for y in ys if y != 0]
return cls(xs, ys)
def to_coeffs(self):
# To convert back to the coefficient form, we use polynomial interpolation.
# The non-zero elements stored in self.evalmap, so we fill in the zero values
# here.
ys = [self.evalmap[x] if x in self.evalmap else field(0) for x in _powers]
coeffs = [b / field(n) for b in fft_helper(ys, 1 / omega, field)]
return _poly_coeff(coeffs)
_lagrange_cache = {}
def __call__(self, x):
if type(x) is int:
x = field(x)
assert type(x) is field
xs = _powers
def lagrange(x, xi):
# Let's cache lagrange values
if (x,xi) in PolynomialEvalRep._lagrange_cache:
return PolynomialEvalRep._lagrange_cache[(x,xi)]
mul = lambda a,b: a*b
num = reduce(mul, [x - xj for xj in xs if xj != xi], field(1))
den = reduce(mul, [xi - xj for xj in xs if xj != xi], field(1))
PolynomialEvalRep._lagrange_cache[(x,xi)] = num / den
return PolynomialEvalRep._lagrange_cache[(x,xi)]
y = field(0)
for xi, yi in self.evalmap.items():
y += yi * lagrange(x, xi)
return y
def __mul__(self, other):
# Scale by integer
if type(other) is int:
other = field(other)
if type(other) is field:
return PolynomialEvalRep(self.evalmap.keys(),
[other * y for y in self.evalmap.values()])
# Multiply another polynomial in the same representation
if type(other) is type(self):
xs = []
ys = []
for x, y in self.evalmap.items():
if x in other.evalmap:
xs.append(x)
ys.append(y * other.evalmap[x])
return PolynomialEvalRep(xs, ys)
@typecheck
def __iadd__(self, other):
# Add another polynomial to this one in place.
# This is especially efficient when the other polynomial is sparse,
# since we only need to add the non-zero elements.
for x, y in other.evalmap.items():
if x not in self.evalmap:
self.evalmap[x] = y
else:
self.evalmap[x] += y
return self
@typecheck
def __add__(self, other):
res = PolynomialEvalRep(self.evalmap.keys(), self.evalmap.values())
res += other
return res
def __sub__(self, other): return self + (-other)
def __neg__(self): return PolynomialEvalRep(self.evalmap.keys(),
[-y for y in self.evalmap.values()])
def __truediv__(self, divisor):
# Scale by integer
if type(divisor) is int:
other = field(divisor)
if type(divisor) is field:
return self * (1/divisor)
if type(divisor) is type(self):
res = PolynomialEvalRep((),())
for x, y in self.evalmap.items():
assert x in divisor.evalmap
res.evalmap[x] = y / divisor.evalmap[x]
return res
return NotImplemented
def __copy__(self):
return PolynomialEvalRep(self.evalmap.keys(), self.evalmap.values())
def __repr__(self):
return f'PolyEvalRep[{hex(omega.n)[:15]}...,{n}]({len(self.evalmap)} elements)'
@classmethod
def divideWithCoset(cls, p, t, c=field(3)):
"""
This assumes that p and t are polynomials in coefficient representation,
and that p is divisible by t.
This function is useful when t has roots at some or all of the powers of omega,
in which case we cannot just convert to evalrep and use division above
(since it would cause a divide by zero.
Instead, we evaluate p(X) at powers of (c*omega) for some constant cofactor c.
To do this efficiently, we create new polynomials, pc(X) = p(cX), tc(X) = t(cX),
and evaluate these at powers of omega. This conversion can be done efficiently
on the coefficient representation.
See also: cosetFFT in libsnark / libfqfft.
https://github.com/scipr-lab/libfqfft/blob/master/libfqfft/evaluation_domain/domains/extended_radix2_domain.tcc
"""
assert type(p) is _poly_coeff
assert type(t) is _poly_coeff
# Compute p(cX), t(cX) by multiplying coefficients
c_acc = field(1)
pc = _poly_coeff(list(p.coefficients)) # make a copy
for i in range(p.degree() + 1):
pc.coefficients[-i-1] *= c_acc
c_acc *= c
c_acc = field(1)
tc = _poly_coeff(list(t.coefficients)) # make a copy
for i in range(t.degree() + 1):
tc.coefficients[-i-1] *= c_acc
c_acc *= c
# Divide using evalrep
pc_rep = cls.from_coeffs(pc)
tc_rep = cls.from_coeffs(tc)
hc_rep = pc_rep / tc_rep
hc = hc_rep.to_coeffs()
# Compute h(X) from h(cX) by dividing coefficients
c_acc = field(1)
h = _poly_coeff(list(hc.coefficients)) # make a copy
for i in range(hc.degree() + 1):
h.coefficients[-i-1] /= c_acc
c_acc *= c
# Correctness checks
# assert pc == tc * hc
# assert p == t * h
return h
return PolynomialEvalRep
#| ## Sparse Matrix
#| In our setting, we have O(m*m) elements in the matrix, and expect the number of
#| elements to be O(m).
#| In this setting, it's appropriate to use a rowdict representation - a dense
#| array of dictionaries, one for each row, where the keys of each dictionary
#| are column indices.
class RowDictSparseMatrix():
# Only a few necessary methods are included here.
# This could be replaced with a generic sparse matrix class, such as scipy.sparse,
# but this does not work as well with custom value types like Fp
def __init__(self, m, n, zero=None):
self.m = m
self.n = n
self.shape = (m,n)
self.zero = zero
self.rowdicts = [dict() for _ in range(m)]
def __setitem__(self, key, v):
i, j = key
self.rowdicts[i][j] = v
def __getitem__(self, key):
i, j = key
return self.rowdicts[i][j] if j in self.rowdicts[i] else self.zero
def items(self):
for i in range(self.m):
for j, v in self.rowdicts[i].items():
yield (i,j), v
def dot(self, other):
if isinstance(other, np.ndarray):
assert other.dtype == 'O'
assert other.shape in ((self.n,),(self.n,1))
ret = np.empty((self.m,), dtype='O')
ret.fill(self.zero)
for i in range(self.m):
for j, v in self.rowdicts[i].items():
ret[i] += other[j] * v
return ret
def to_dense(self):
mat = np.empty((self.m, self.n), dtype='O')
mat.fill(self.zero)
for (i,j), val in self.items():
mat[i,j] = val
return mat
def __repr__(self): return repr(self.rowdicts)
#-
# Examples
if __name__ == '__main__':
import misc
Fp = FiniteField(52435875175126190479447740508185965837690552500527637822603658699938581184513,1) # (# noqa: E501)
Poly = polynomialsOver(Fp)
n = 8
omega = misc.get_omega(Fp, n)
PolyEvalRep = make_polynomial_evalrep(Fp, omega, n)
f = Poly([1,2,3,4,5])
xs = tuple([omega**i for i in range(n)])
ys = tuple(map(f, xs))
# print('xs:', xs)
# print('ys:', ys)
assert f == PolyEvalRep(xs, ys).to_coeffs()

204
scripts/halo/sonic.py Normal file
View File

@@ -0,0 +1,204 @@
# From the Sonic paper
from finite_fields import finitefield
import numpy as np
import misc
from multipoly import Variable, MultivariatePolynomial
p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
fp = finitefield.IntegersModP(p)
var_one = fp(1)
var_x = fp(4)
var_y = fp(6)
var_s = fp(1)
var_xy = var_x * var_y
var_sxy = var_s * var_xy
var_1_neg_s = var_one - var_s
var_x_y = var_x + var_y
var_1_neg_s_x_y = var_1_neg_s * var_x_y
var_s_neg_1 = -var_1_neg_s
var_zero = fp(0)
public_v = var_s * (var_x * var_y) + (1 - var_s) * (var_x + var_y)
a = np.array([
var_one, var_x, var_xy, var_1_neg_s, var_s
])
b = np.array([
var_one, var_y, var_s, var_x_y, var_s_neg_1
])
c = np.array([
var_one, var_xy, var_sxy, var_1_neg_s_x_y, var_zero
])
assert len(a) == len(b)
assert len(b) == len(c)
for i, (a_i, b_i, c_i) in enumerate(zip(a, b, c), 1):
try:
assert a_i * b_i == c_i
except AssertionError:
print("Error for %i" % i)
raise
# 1 - s = -(s - 1)
u1 = np.array([0, 0, 0, 1, 0])
v1 = np.array([0, 0, 0, 0, 1])
w1 = np.array([0, 0, 0, 0, 0])
k1 = 0
assert a.dot(u1) + b.dot(v1) + c.dot(w1) == k1
# xy = xy
u2 = np.array([0, 0, 1, 0, 0])
v2 = np.array([0, 0, 0, 0, 0])
w2 = np.array([0, -1, 0, 0, 0])
k2 = 0
assert a.dot(u2) + b.dot(v2) + c.dot(w2) == k2
# s = s
u3 = np.array([0, 0, 0, 0, -1])
v3 = np.array([0, 0, 1, 0, 0])
w3 = np.array([0, 0, 0, 0, 0])
k3 = 0
assert a.dot(u3) + b.dot(v3) + c.dot(w3) == k3
# zero = 0
u4 = np.array([0, 0, 0, 0, 0])
v4 = np.array([0, 0, 0, 0, 0])
w4 = np.array([0, 0, 0, 0, 1])
k4 = 0
assert a.dot(u4) + b.dot(v4) + c.dot(w4) == k4
# 1 - s
u5 = np.array([1, 0, 0, -1, 0])
v5 = np.array([0, 0, -1, 0, 0])
w5 = np.array([0, 0, 0, 0, 0])
k5 = 0
assert a.dot(u5) + b.dot(v5) + c.dot(w5) == k5
# x + y
u6 = np.array([0, 1, 0, 0, 0])
v6 = np.array([0, 1, 0, -1, 0])
w6 = np.array([0, 0, 0, 0, 0])
k6 = 0
assert a.dot(u6) + b.dot(v6) + c.dot(w6) == k6
# Final check:
# v = s(xy) + (1 - s)(x + y)
u7 = np.array([0, 0, 0, 0, 0])
v7 = np.array([0, 0, 0, 0, 0])
w7 = np.array([0, 0, 1, 1, 0])
k7 = public_v
assert a.dot(u7) + b.dot(v7) + c.dot(w7) == k7
u = np.vstack((u1, u2, u3, u4, u5, u6, u7))
v = np.vstack((v1, v2, v3, v4, v5, v6, v7))
w = np.vstack((w1, w2, w3, w4, w5, w6, w7))
assert u.shape == v.shape
assert u.shape == w.shape
k = np.array((k1, k2, k3, k4, k5, k6, k7))
x = Variable("X", fp)
y = Variable("Y", fp)
p = MultivariatePolynomial()
for i, (a_i, b_i, c_i) in enumerate(zip(a, b, c), 1):
#print(a_i, "\t", b_i, "\t", c_i)
p += y**i * (a_i * b_i - c_i)
assert not p
p = MultivariatePolynomial()
for q, (u_q, v_q, w_q, k_q) in enumerate(zip(u, v, w, k)):
p += y**q * (a.dot(u_q) + b.dot(v_q) + c.dot(w_q) - k_q)
assert not p
n = len(a)
assert len(b) == n
assert len(c) == n
assert u.shape == (7, n)
assert v.shape == u.shape
assert w.shape == u.shape
assert k.shape == (7,)
r_x_y = MultivariatePolynomial()
s_x_y = MultivariatePolynomial()
for i, (a_i, b_i, c_i) in enumerate(zip(a, b, c), 1):
assert 1 <= i <= n
r_x_y += x**i * y**i * a_i
r_x_y += x**-i * y**-i * b_i
r_x_y += x**(-i - n) * y**(-i - n) * c_i
u_i = u.T[i - 1]
v_i = v.T[i - 1]
w_i = w.T[i - 1]
u_i_Y = MultivariatePolynomial()
v_i_Y = MultivariatePolynomial()
w_i_Y = MultivariatePolynomial()
for q, (u_q_i, v_q_i, w_q_i) in enumerate(zip(u_i, v_i, w_i), 1):
assert 1 <= q <= 7
u_i_Y += y**(q + n) * u_q_i
v_i_Y += y**(q + n) * v_q_i
w_i_Y += -y**i - y**(-i) + y**(q + n) * v_q_i
s_x_y += u_i_Y * x**-i + v_i_Y * x**i + w_i_Y * x**(i + n)
k_y = MultivariatePolynomial()
for q, k_q in enumerate(k, 1):
assert 1 <= q <= 7
k_y += y**(q + n) * k_q
r_prime_x_y = r_x_y + s_x_y
r_x_1 = r_x_y.evaluate({y.name: fp(1)})
t_x_y = r_x_1 * r_prime_x_y - k_y
t_x_y._assert_unique_terms()
const_t = t_x_y.filter([x])
print(const_t)
# Section 6, Figure 2
#
# zkP1
# 4 blinding factors since we evaluate r(X, Y) 3 times
# Blind r(X, Y)
for i in range(1, 4):
blind_c_i = misc.sample_random(fp)
r_x_y += x**(-2*n - i) * y**(-2*n - i) * blind_c_i
# Commit to r(X, Y)
# zkV1
# Send a random y
challenge_y = misc.sample_random(fp)
# zkP2
# Commit to t(X, y)
# zkV2
# Send a random z
challenge_z = misc.sample_random(fp)
# zkP3
# Evaluate a = r(z, 1)
a = r_x_y.evaluate({x.name: challenge_z, y.name: fp(1)})
# Evaluate b = r(z, y)
b = r_x_y.evaluate({x.name: challenge_z, y.name: challenge_y})
# Evaluate t = t(z, y)
t = t_x_y.evaluate({x.name: challenge_z, y.name: challenge_y})
# Evaluate s = s(z, y)
s = s_x_y.evaluate({x.name: challenge_z, y.name: challenge_y})
# zkV3
# Recalculate t from a, b and s
k = k_y.evaluate({y.name: challenge_y})
t = a * (b + s) - k
# Verify polynomial commitments

201
scripts/halo/sonic.sage Normal file
View File

@@ -0,0 +1,201 @@
import numpy as np
from groth_poly_commit import K, create_proof, verify_proof
# Just use the same finite field we put in the polynomial commitment scheme file
#p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
#K = FiniteField(p)
R.<x, y> = LaurentPolynomialRing(K)
var_one = K(1)
var_x = K(4)
var_y = K(6)
var_s = K(1)
var_xy = var_x * var_y
var_sxy = var_s * var_xy
var_1_neg_s = var_one - var_s
var_x_y = var_x + var_y
var_1_neg_s_x_y = var_1_neg_s * var_x_y
var_s_neg_1 = -var_1_neg_s
var_zero = K(0)
public_v = var_s * (var_x * var_y) + (1 - var_s) * (var_x + var_y)
a = np.array([
var_one, var_x, var_xy, var_1_neg_s, var_s
])
b = np.array([
var_one, var_y, var_s, var_x_y, var_s_neg_1
])
c = np.array([
var_one, var_xy, var_sxy, var_1_neg_s_x_y, var_zero
])
assert len(a) == len(b)
assert len(b) == len(c)
for i, (a_i, b_i, c_i) in enumerate(zip(a, b, c), 1):
try:
assert a_i * b_i == c_i
except AssertionError:
print("Error for %i" % i)
raise
# 1 - s = -(s - 1)
u1 = np.array([0, 0, 0, 1, 0])
v1 = np.array([0, 0, 0, 0, 1])
w1 = np.array([0, 0, 0, 0, 0])
k1 = 0
assert a.dot(u1) + b.dot(v1) + c.dot(w1) == k1
# xy = xy
u2 = np.array([0, 0, 1, 0, 0])
v2 = np.array([0, 0, 0, 0, 0])
w2 = np.array([0, -1, 0, 0, 0])
k2 = 0
assert a.dot(u2) + b.dot(v2) + c.dot(w2) == k2
# s = s
u3 = np.array([0, 0, 0, 0, -1])
v3 = np.array([0, 0, 1, 0, 0])
w3 = np.array([0, 0, 0, 0, 0])
k3 = 0
assert a.dot(u3) + b.dot(v3) + c.dot(w3) == k3
# zero = 0
u4 = np.array([0, 0, 0, 0, 0])
v4 = np.array([0, 0, 0, 0, 0])
w4 = np.array([0, 0, 0, 0, 1])
k4 = 0
assert a.dot(u4) + b.dot(v4) + c.dot(w4) == k4
# 1 - s
u5 = np.array([1, 0, 0, -1, 0])
v5 = np.array([0, 0, -1, 0, 0])
w5 = np.array([0, 0, 0, 0, 0])
k5 = 0
assert a.dot(u5) + b.dot(v5) + c.dot(w5) == k5
# x + y
u6 = np.array([0, 1, 0, 0, 0])
v6 = np.array([0, 1, 0, -1, 0])
w6 = np.array([0, 0, 0, 0, 0])
k6 = 0
assert a.dot(u6) + b.dot(v6) + c.dot(w6) == k6
# Final check:
# v = s(xy) + (1 - s)(x + y)
u7 = np.array([0, 0, 0, 0, 0])
v7 = np.array([0, 0, 0, 0, 0])
w7 = np.array([0, 0, 1, 1, 0])
k7 = public_v
assert a.dot(u7) + b.dot(v7) + c.dot(w7) == k7
u = np.vstack((u1, u2, u3, u4, u5, u6, u7))
v = np.vstack((v1, v2, v3, v4, v5, v6, v7))
w = np.vstack((w1, w2, w3, w4, w5, w6, w7))
assert u.shape == v.shape
assert u.shape == w.shape
k = np.array((k1, k2, k3, k4, k5, k6, k7))
p = K(0)
for i, (a_i, b_i, c_i) in enumerate(zip(a, b, c), 1):
#print(a_i, "\t", b_i, "\t", c_i)
p += y**i * (a_i * b_i - c_i)
print(p)
p = K(0)
for q, (u_q, v_q, w_q, k_q) in enumerate(zip(u, v, w, k)):
p += y**q * (a.dot(u_q) + b.dot(v_q) + c.dot(w_q) - k_q)
print(p)
n = len(a)
assert len(b) == n
assert len(c) == n
assert u.shape == (7, n)
assert v.shape == u.shape
assert w.shape == u.shape
assert k.shape == (7,)
r_x_y = 0
s_x_y = 0
for i, (a_i, b_i, c_i) in enumerate(zip(a, b, c), 1):
assert 1 <= i <= n
r_x_y += x**i * y**i * a_i
r_x_y += x**-i * y**-i * b_i
r_x_y += x**(-i - n) * y**(-i - n) * c_i
u_i = u.T[i - 1]
v_i = v.T[i - 1]
w_i = w.T[i - 1]
u_i_Y = 0
v_i_Y = 0
w_i_Y = 0
for q, (u_q_i, v_q_i, w_q_i) in enumerate(zip(u_i, v_i, w_i), 1):
assert 1 <= q <= 7
u_i_Y += y**(q + n) * u_q_i
v_i_Y += y**(q + n) * v_q_i
w_i_Y += -y**i - y**(-i) + y**(q + n) * w_q_i
s_x_y += u_i_Y * x**-i + v_i_Y * x**i + w_i_Y * x**(i + n)
k_y = 0
for q, k_q in enumerate(k, 1):
assert 1 <= q <= 7
k_y += y**(q + n) * k_q
# Section 6, Figure 2
#
# zkP1
# 4 blinding factors since we evaluate r(X, Y) 3 times
# Blind r(X, Y)
#for i in range(1, 4 + 1):
# blind_c_i = K.random_element()
# r_x_y += x**(-2*n - i) * y**(-2*n - i) * blind_c_i
# Commit to r(X, Y)
r_prime_x_y = r_x_y + s_x_y
r_x_1 = r_x_y(y=K(1))
t_x_y = r_x_1 * r_prime_x_y - k_y
print(t_x_y.constant_coefficient())
# zkV1
# Send a random y
challenge_y = K.random_element()
# zkP2
# Commit to t(X, y)
t_x = t_x_y(y=challenge_y)
t_x = t_x.univariate_polynomial()
print(t_x.constant_coefficient())
# zkV2
# Send a random z
challenge_z = K.random_element()
# zkP3
# Evaluate a = r(z, 1)
a = r_x_y(x=challenge_z, y=K(1))
# Evaluate b = r(z, y)
b = r_x_y(x=challenge_z, y=challenge_y)
# Evaluate t = t(z, y)
t = t_x_y(x=challenge_z, y=challenge_y)
# Evaluate s = s(z, y)
s = s_x_y(x=challenge_z, y=challenge_y)
# zkV3
# Recalculate t from a, b and s
k = k_y(y=challenge_y)
t_new = a * (b + s) - k
assert t_new == t
# Verify polynomial commitments

116
scripts/halo/test.py Normal file
View File

@@ -0,0 +1,116 @@
import random
import misc
import pasta
from polynomial_evalrep import make_polynomial_evalrep
n = 8
omega_base = misc.get_omega(pasta.fp, 2**32, seed=0)
assert misc.is_power_of_two(8)
omega = omega_base ** (2 ** 32 // n)
# Order of omega is n
assert omega ** n == 1
# Compute complete roots of this group
ROOTS = [omega ** i for i in range(n)]
PolyEvalRep = make_polynomial_evalrep(pasta.fp, omega, n)
import numpy as np
from tabulate import tabulate
# Wires
a = ["x", "v1", "v2", "1", "1", "v3", "e1", "e2"]
b = ["x", "x", "x", "5", "35", "5", "e3", "e4"]
c = ["v1", "v2", "v3", "5", "35", "35", "e5", "e6"]
wires = a + b + c
# Gates
# La + Rb + Oc + Mab + C = 0
add = np.array([1, 1, 0, -1, 0])
mul = np.array([0, 0, 1, -1, 0])
const5 = np.array([0, 1, 0, 0, -5])
public_input = np.array([0, 1, 0, 0, 0])
empty = np.array([0, 0, 0, 0, 0])
gates_matrix = np.array(
[mul, mul, add, const5, public_input, add, empty, empty])
print("Wires:")
print(tabulate([["a ="] + a, ["b ="] + b, ["c ="] + c]))
print()
print("Gates:")
print(gates_matrix)
print()
# The index of the public input in the gates_matrix
# We specify its position and its value
public_input_values = [(4, 35)]
def permute_indices(wires):
size = len(wires)
permutation = [i + 1 for i in range(size)]
for i in range(size):
for j in range(i + 1, size):
if wires[i] == wires[j]:
permutation[i], permutation[j] = permutation[j], permutation[i]
break
return permutation
permutation = permute_indices(wires)
table = [
["Wires"] + wires,
["Indices"] + list(i + 1 for i in range(len(wires))),
["Permutations"] + permutation
]
print(tabulate(table))
print()
import misc
from pasta import fp
def setup(wires, gates_matrix):
# Section 8.1
# The selector polynomials that define the circuit's arithmetisation
gates_matrix = gates_matrix.transpose()
ql = PolyEvalRep(ROOTS, [fp(i) for i in gates_matrix[0]])
qr = PolyEvalRep(ROOTS, [fp(i) for i in gates_matrix[1]])
qm = PolyEvalRep(ROOTS, [fp(i) for i in gates_matrix[2]])
qo = PolyEvalRep(ROOTS, [fp(i) for i in gates_matrix[3]])
qc = PolyEvalRep(ROOTS, [fp(i) for i in gates_matrix[4]])
selector_polys = [ql, qr, qm, qo, qc]
public_input = [fp(0) for i in range(len(ROOTS))]
for (index, value) in public_input_values:
# This is negative because the value is added to
# the output of the const selector poly:
# La + Rb + Oc + Mab + (C + PI) = 0
public_input[index] = fp(-value)
public_input_poly = PolyEvalRep(ROOTS, public_input)
# Identity permutations applied to a, b, c
# Ideally H, k_1 H, k_2 H are distinct cosets of H
# Here we just sample k and assume it's high-order
# Random high order k to form distinct cosets
k = misc.sample_random(fp)
id_domain_a = ROOTS
id_domain_b = [k * root for root in ROOTS]
id_domain_c = [k**2 * root for root in ROOTS]
id_domain = id_domain_a + id_domain_b + id_domain_c
# Intermediate step where we permute the positions of the domain
# generated above
permuted_domain = [id_domain[i - 1] for i in permutation]
permuted_domain_a = permuted_domain[:n]
permuted_domain_b = permuted_domain[n:2 * n]
permuted_domain_c = permuted_domain[2*n:3 * n]
# The copy permuation applied to a, b, c
# Returns the permuted index value (corresponding root of unity coset)
# when evaluated on the domain.
ssigma_1 = PolyEvalRep(ROOTS, permuted_domain_a)
ssigma_2 = PolyEvalRep(ROOTS, permuted_domain_b)
ssigma_3 = PolyEvalRep(ROOTS, permuted_domain_c)
copy_permutes = [ssigma_1, ssigma_2, ssigma_3]
setup(wires, gates_matrix)

80
scripts/jubjub.py Normal file
View File

@@ -0,0 +1,80 @@
from finite_fields.modp import IntegersModP
q = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
modq = IntegersModP(q)
a = modq(-1)
print("a:", hex(a.n))
d = -(modq(10240)/modq(10241))
params = (a, d)
def is_jubjub(params, x, y):
a, d = params
return a * x**2 + y**2 == 1 + d * x**2 * y**2
def add(params, point_1, point_2):
# From here: https://z.cash/technology/jubjub/
a, d = params
x1, y1 = point_1
x2, y2 = point_2
x3 = (x1 * y2 + y1 * x2) / (1 + d * x1 * x2 * y1 * y2)
y3 = (y1 * y2 + x1 * x2) / (1 - d * x1 * x2 * y1 * y2)
return (x3, y3)
def fake_zk_add(params, point_1, point_2):
# From here: https://z.cash/technology/jubjub/
a, d = params
x1, y1 = point_1
x2, y2 = point_2
# Compute U = (u1 + v1) * (v2 - EDWARDS_A*u2)
# = (u1 + v1) * (u2 + v2)
U = (x1 + y1) * (x2 + y2)
assert (x1 + y1) * (x2 + y2) == U
# Compute A = v2 * u1
A = y2 * x1
# Compute B = u2 * v1
B = x2 * y1
# Compute C = d*A*B
C = d * A * B
assert (d * A) * (B) == C
# Compute u3 = (A + B) / (1 + C)
# NOTE: make sure we check for (1 + C) has an inverse
u3 = (A + B) / (1 + C)
assert (1 + C) * (u3) == (A + B)
# Compute v3 = (U - A - B) / (1 - C)
# We will also need to check inverse here as well.
v3 = (U - A - B) / (1 - C)
assert (1 - C) * (v3) == (U - A - B)
return u3, v3
x = 0x15a36d1f0f390d8852a35a8c1908dd87a361ee3fd48fdf77b9819dc82d90607e
y = 0x015d8c7f5b43fe33f7891142c001d9251f3abeeb98fad3e87b0dc53c4ebf1891
x3, y3 = add(params, (x, y), (x, y))
print(hex(x3.n), hex(y3.n))
u3, v3 = fake_zk_add(params, (x, y), (x, y))
print(hex(u3.n), hex(v3.n))
print(is_jubjub(params, x, y))
print(is_jubjub(params, x3, y3))
print()
print("Identity (0, 1) is jubjub?", is_jubjub(params, 0, 1))
print("Torsion (0, -1) is jubjub?", is_jubjub(params, 0, -1))
double_torsion = add(params, (0, -1), (0, -1))
print("Double torsion is:", hex(double_torsion[0].n), hex(double_torsion[1].n))
dbl_ident = add(params, (0, 1), (0, 1))
print("Double identity is:", hex(dbl_ident[0].n), hex(dbl_ident[1].n))

30
scripts/modp.py Normal file
View File

@@ -0,0 +1,30 @@
from finite_fields.modp import IntegersModP
q = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
modq = IntegersModP(q)
a = modq(-1)
print("0x%x" % a.n)
print("\n")
two = modq(2)
inv2 = modq(2).inverse()
print("Inverse of 2 = 0x%x" % inv2.n)
print((two * inv2))
# This is from bellman
inv2_bellman = 0x39f6d3a994cebea4199cec0404d0ec02a9ded2017fff2dff7fffffff80000001
assert inv2.n == inv2_bellman
assert (2 * inv2.n) % q == 1
# Futures contract calculation
multiplier = modq(1)
quantity = modq(100)
entry_price = modq(10000)
exit_price = modq(15000)
initial_margin = multiplier * quantity
print("initial margin =", initial_margin)
price_return = exit_price * entry_price.inverse()
print("R =", price_return)
pnl = initial_margin - (initial_margin * exit_price) * entry_price.inverse()
print("PNL =", pnl)

159
scripts/monitor-p2p.py Normal file
View File

@@ -0,0 +1,159 @@
import asyncio
from tabulate import tabulate
from copy import deepcopy
import re
import os
import sys
import time
lock = asyncio.Lock()
logs_path = "/tmp/darkfi/"
node_info = {
}
ping_times = {
}
def debug(line):
#print(line)
pass
def process(info, line):
regex_listen = re.compile(
".* Listening on (\d+[.]\d+[.]\d+[.]\d+:\d+)")
regex_inbound_connect = re.compile(
".* Connected inbound \[(\d+[.]\d+[.]\d+[.]\d+:\d+)\]")
regex_outbound_slots = re.compile(
".* Starting (\d+) outbound connection slots.")
regex_outbound_connect = re.compile(
".* #(\d+) connected to outbound \[(\d+[.]\d+[.]\d+[.]\d+:\d+)\]")
regex_channel_disconnected = re.compile(
".* Channel (\d+[.]\d+[.]\d+[.]\d+:\d+) disconnected")
regex_pong_recv = re.compile(
".* Received Pong message (\d+)ms from \[(\d+[.]\d+[.]\d+[.]\d+:\d+)\]")
if "net: P2p::start() [BEGIN]" in line:
info["status"] = "p2p-start"
elif "net: SeedSession::start() [START]" in line:
info["status"] = "seed-start"
elif "net: SeedSession::start() [END]" in line:
info["status"] = "seed-done"
elif "net: P2p::start() [END]" in line:
info["status"] = "p2p-done"
elif "net: P2p::run() [BEGIN]" in line:
info["status"] = "p2p-run"
elif "Not configured for accepting incoming connections." in line:
info["inbounds"] = ["Disabled"]
elif (match := regex_listen.match(line)) is not None:
address = match.group(1)
info["listen"] = address
elif (match := regex_inbound_connect.match(line)) is not None:
address = match.group(1)
info["inbounds"].append(address)
elif (match := regex_outbound_slots.match(line)) is not None:
slots = match.group(1)
info["outbounds"] = ["None" for _ in range(int(slots))]
elif (match := regex_outbound_connect.match(line)) is not None:
slot = match.group(1)
address = match.group(2)
info["outbounds"][int(slot)] = address
elif (match := regex_channel_disconnected.match(line)) is not None:
address = match.group(1)
try:
info["inbounds"].remove(address)
except ValueError:
pass
try:
idx = info["outbounds"].index(address)
info["outbounds"][idx] = "None"
except ValueError:
pass
elif (match := regex_pong_recv.match(line)) is not None:
ping_time = match.group(1)
address = match.group(2)
ping_times[address] = ping_time
async def scanner(filename):
global table_data
async with lock:
node_info[filename] = {
"status": "none",
"inbounds": [],
"outbounds": [],
}
info = node_info[filename]
with open(logs_path + filename) as fileh:
while True:
line = fileh.readline()
if line:
debug("R: " + filename + ": " + line[:-1])
async with lock:
process(info, line)
else:
await asyncio.sleep(0.5)
def clear_lines(n):
for i in range(n):
sys.stdout.write('\033[F')
def get_ping(addr):
ping_time = "none"
if addr in ping_times:
ping_time = str(ping_times[addr]) + " ms"
return ping_time
def table_format(ninfo):
table_data = []
for filename, info in ninfo.items():
table_data.append([filename, "", ""])
table_data.append(["", "status", info["status"]])
if "listen" in info:
table_data.append(["", "listen", info["listen"]])
inbounds = info["inbounds"]
if inbounds:
table_data.append(["", "inbounds", inbounds[0],
get_ping(inbounds[0])])
for inbound in inbounds[1:]:
table_data.append(["", "", inbound, get_ping(inbound)])
outbounds = info["outbounds"]
if outbounds:
table_data.append(["", "outbounds", outbounds[0],
get_ping(outbounds[0])])
for outbound in outbounds[1:]:
table_data.append(["", "", outbound, get_ping(outbound)])
headers = ["Name", "Attribute", "Value", "Ping Times"]
return headers, table_data
async def refresh_table(tick=1):
for filename in os.listdir(logs_path):
asyncio.create_task(scanner(filename))
previous_lines = 0
while True:
clear_lines(previous_lines)
async with lock:
ninfo = deepcopy(node_info)
headers, table_data = table_format(ninfo)
lines = tabulate(table_data, headers=headers).split("\n")
debug("-------------------")
for line in lines:
print('\x1b[2K\r', end="")
print(line)
previous_lines = len(lines)
await asyncio.sleep(1)
asyncio.run(refresh_table())

2
scripts/old/compile.sh Executable file
View File

@@ -0,0 +1,2 @@
#!/bin/bash
python scripts/parser.py proofs/sapling3.prf | rustfmt > proofs/sapling3.rs

4
scripts/old/compile_and_run.sh Executable file
View File

@@ -0,0 +1,4 @@
#!/bin/bash -x
python scripts/pism.py proofs/simple.pism | rustfmt > src/simple_circuit.rs
cargo run --release --bin simple

View File

@@ -0,0 +1,5 @@
# Dark Client
$ python3 -m venv env
$ source env/bin/activate
$ pip install -r requirements.txt

89
scripts/old/dark_client/drk.py Executable file
View File

@@ -0,0 +1,89 @@
#!/usr/bin/env python
from util import arg_parser
import aiohttp
import asyncio
class DarkClient:
# TODO: generate random ID (4 byte unsigned int) (rand range 0 - max size
# uint32
def __init__(self, client_session):
self.url = "http://localhost:8000/"
self.client_session = client_session
self.payload = {
"method": [],
"params": [],
"jsonrpc": [],
"id": [],
}
async def key_gen(self, payload):
payload['method'] = "key_gen"
payload['jsonrpc'] = "2.0"
payload['id'] = "0"
key = await self.__request(payload)
print(key)
async def get_info(self, payload):
payload['method'] = "get_info"
payload['jsonrpc'] = "2.0"
payload['id'] = "0"
info = await self.__request(payload)
print(info)
async def stop(self, payload):
payload['method'] = "stop"
payload['jsonrpc'] = "2.0"
payload['id'] = "0"
stop = await self.__request(payload)
print(stop)
async def say_hello(self, payload):
payload['method'] = "say_hello"
payload['jsonrpc'] = "2.0"
payload['id'] = "0"
hello = await self.__request(payload)
print(hello)
async def create_wallet(self, payload):
payload['method'] = "create_wallet"
payload['jsonrpc'] = "2.0"
payload['id'] = "0"
wallet = await self.__request(payload)
print(wallet)
async def create_cashier_wallet(self, payload):
payload['method'] = "create_cashier_wallet"
payload['jsonrpc'] = "2.0"
payload['id'] = "0"
wallet = await self.__request(payload)
print(wallet)
async def test_wallet(self, payload):
payload['method'] = "test_wallet"
payload['jsonrpc'] = "2.0"
payload['id'] = "0"
test = await self.__request(payload)
print(test)
async def __request(self, payload):
async with self.client_session.post(self.url, json=payload) as response:
resp = await response.text()
print(resp)
async def main():
try:
async with aiohttp.ClientSession() as session:
client = DarkClient(session)
await arg_parser(client)
except aiohttp.ClientConnectorError as err:
print('CONNECTION ERROR:', str(err))
except Exception as err:
print("ERROR: ", str(err))
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(main())

View File

@@ -0,0 +1,14 @@
aiodns==3.0.0
aiohttp==3.7.4.post0
async-timeout==3.0.1
attrs==21.2.0
brotlipy==0.7.0
cchardet==2.1.7
cffi==1.14.5
chardet==4.0.0
idna==3.2
multidict==5.1.0
pycares==4.0.0
pycparser==2.20
typing-extensions==3.10.0.0
yarl==1.6.3

View File

@@ -0,0 +1,53 @@
import argparse
async def arg_parser(client):
parser = argparse.ArgumentParser(
prog='drk',
usage='%(prog)s [commands]',
description="""DarkFi wallet command-line tool"""
)
parser.add_argument('-c', '--cashier', action='store_true', help='Create a cashier wallet')
parser.add_argument('-w', '--wallet', action='store_true', help='Create a new wallet')
parser.add_argument('-k', '--key', action='store_true', help='Test key')
parser.add_argument('-i', '--info', action='store_true', help='Request info from daemon')
parser.add_argument('-hi', '--hello', action='store_true', help='Test hello')
parser.add_argument("-s", "--stop", action='store_true', help="Send a stop signal to the daemon")
parser.add_argument("-t", "--test", action='store_true', help="Test writing to the wallet")
try:
args = parser.parse_args()
if args.key:
print("Attemping to generate a create key pair...")
await client.key_gen(client.payload)
if args.wallet:
print("Attemping to create a wallet...")
await client.create_wallet(client.payload)
if args.info:
print("Info was entered")
await client.get_info(client.payload)
print("Requesting daemon info...")
if args.stop:
print("Stop was entered")
await client.stop(client.payload)
print("Sending a stop signal...")
if args.hello:
print("Hello was entered")
await client.say_hello(client.payload)
if args.cashier:
print("Attempting to generate a cashier wallet...")
await client.create_cashier_wallet(client.payload)
if args.test:
print("Testing wallet write")
await client.test_wallet(client.payload)
except Exception:
raise

5
scripts/old/run_bits.sh Executable file
View File

@@ -0,0 +1,5 @@
#!/bin/bash -x
python scripts/preprocess.py proofs/bits.psm > /tmp/bits.psm
python scripts/vm.py --rust /tmp/bits.psm > src/bits_contract.rs
cargo run --release --bin bits

5
scripts/old/run_mimc.sh Executable file
View File

@@ -0,0 +1,5 @@
#!/bin/bash -x
python scripts/preprocess.py proofs/mimc.psm > /tmp/mimc.psm
python scripts/vm.py --rust /tmp/mimc.psm > src/zkmimc_contract.rs
cargo run --release --bin zkmimc

5
scripts/old/run_mint2.sh Executable file
View File

@@ -0,0 +1,5 @@
#!/bin/bash -x
python scripts/preprocess.py proofs/mint2.psm > /tmp/mint2.psm || exit $?
python scripts/vm.py --rust /tmp/mint2.psm > src/mint2_contract.rs || exit $?
cargo run --release --bin mint2

View File

@@ -0,0 +1,4 @@
#!/bin/bash -x
python scripts/preprocess.py proofs/mint.pism > /tmp/mint.pism
python scripts/pism.py /tmp/mint.pism proofs/mint.aux | rustfmt > src/mint_contract.rs
cargo run --release --bin mint

View File

@@ -0,0 +1,4 @@
#!/bin/bash -x
python scripts/preprocess.py proofs/spend.pism > /tmp/spend.pism
python scripts/pism.py /tmp/spend.pism proofs/mint.aux | rustfmt > src/spend_contract.rs
cargo run --release --bin spend

5
scripts/old/run_vmtest.sh Executable file
View File

@@ -0,0 +1,5 @@
#!/bin/bash -x
python scripts/preprocess.py proofs/jubjub.pism > /tmp/jubjub.pism
python scripts/vm.py --rust /tmp/jubjub.pism > src/vm_load.rs
cargo run --release --bin vmtest

835
scripts/parser.py Normal file
View File

@@ -0,0 +1,835 @@
import lark
import pprint
import re
import sys
class LineDesc:
def __init__(self, level, text, lineno):
self.level = level
self.text = text
self.lineno = lineno
assert self.text[0] != ' '
def __repr__(self):
return "<%s:'%s'>" % (self.level, self.text)
def clean_line(line):
lead_spaces = len(line) - len(line.lstrip(" "))
level = lead_spaces / 4
# Remove leading spaces
if line.strip(" ") == "":
return None
line = line.lstrip(" ")
# Remove all comments
line = re.sub('#.*$', '', line).strip()
if not line:
return None
return level, line
def parse(text):
lines = text.split("\n")
linedescs = []
# These are to join open parenthesis
current_line = ""
paren_level = 0
for lineno, line in enumerate(lines):
if (lineinfo := clean_line(line)) is None:
continue
level, line = lineinfo
for c in line:
if c == "(":
paren_level += 1
elif c == ")":
paren_level -= 1
#print(level, paren_level, current_line)
if paren_level < 0:
print("error: too many closing paren )", file=sys.stderr)
print("line:", lineno)
return
if current_line:
current_line += " " + line
else:
current_line = line
if paren_level > 0:
continue
#print(level, current_line)
ldesc = LineDesc(level, current_line, lineno)
linedescs.append(ldesc)
current_line = ""
if paren_level > 0:
print("error: missing closing paren )", file=sys.stderr)
return None
return linedescs
def section(linedescs):
sections = []
current_section = None
for desc in linedescs:
if desc.level == 0:
if current_section:
sections.append(current_section)
current_section = [desc]
continue
current_section.append(desc)
sections.append(current_section)
return sections
def classify(sections):
consts = []
funcs = []
contracts = []
for section in sections:
assert len(section)
if section[0].text == "const:":
consts.append(section)
elif section[0].text.startswith("def"):
funcs.append(section)
elif section[0].text.startswith("contract"):
contracts.append(section)
return consts, funcs, contracts
def tokenize_const(text):
parser = lark.Lark(r"""
value_map: name ":" type_def
name: NAME
?type_def: point
| blake2s_personalization
| pedersen_personalization
| list
point: "Point"
blake2s_personalization: "Blake2sPersonalization"
pedersen_personalization: "PedersenPersonalization"
list: "list<" type_def ">"
%import common.CNAME -> NAME
%import common.WS
%ignore WS
""", start="value_map")
return parser.parse(text)
class ConstTransformer(lark.Transformer):
def name(self, name):
return str(name[0])
def point(self, _):
return "Point"
def blake2s_personalization(self, _):
return "Blake2sPersonalization"
def pedersen_personalization(self, _):
return "PedersenPersonalization"
value_map = tuple
list = list
def read_consts(consts):
consts_map = {}
for subsection in consts:
assert subsection[0].text == "const:"
for ldesc in subsection[1:]:
tree = tokenize_const(ldesc.text)
tokens = ConstTransformer().transform(tree)
#print(tokens)
name, typedesc = tokens
consts_map[name] = typedesc
#pprint.pprint(consts_map)
return consts_map
class FuncDefTransformer(lark.Transformer):
def func_name(self, name):
return str(name[0])
def param(self, obj):
return tuple(obj)
def param_name(self, name):
return str(name[0])
def u64(self, _):
return "U64"
def scalar(self, _):
return "Scalar"
def point(self, _):
return "Point"
def binary(self, _):
return "Binary"
def type(self, obj):
return obj[0]
func_def = list
params = list
type_list = list
def parse_func_def(text):
parser = lark.Lark(r"""
func_def: "def" func_name "(" params+ ")" "->" type_list ":"
func_name: NAME
params: param ("," param)*
type_list: type
| "(" type ("," type)* ")"
param: param_name ":" type
param_name: NAME
type: u64 | scalar | point | binary
u64: "U64"
scalar: "Scalar"
point: "Point"
binary: "Binary"
%import common.CNAME -> NAME
%import common.WS
%ignore WS
""", start="func_def")
tree = parser.parse(text)
tokens = FuncDefTransformer().transform(tree)
assert len(tokens) == 3
return tokens
def compile_func_header(func_def):
func_name, params, retvals = func_def
#print("Function:", func_name)
#print("Params:", params)
#print("Return values:", retvals)
#print()
param_str = ""
for param, type in params:
if param_str:
param_str += ", "
param_str += param + ": Option<"
if type == "U64":
param_str += "u64"
elif type == "Scalar":
param_str += "jubjub::Fr"
elif type == "Point":
param_str += "jubjub::SubgroupPoint"
else:
print("error: unsupported param type", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
param_str += ">"
converted_retvals = []
for type in retvals:
if type == "Binary":
converted_retvals.append("boolean::Boolean")
else:
print("error: unsupported return type", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
retvals = converted_retvals
if len(retvals) == 1:
retstr = retvals[0]
else:
retstr = "(" + ", ".join(retvals) + ")"
header = r"""fn %s<CS>(
mut cs: CS,
%s
) -> Result<%s, SynthesisError>
where
CS: ConstraintSystem<bls12_381::Scalar>,
{
""" % (func_name, param_str, retstr)
return header
def as_expr(line, stack, consts, expr, code):
var_from, type_to = expr.children
if var_from not in stack:
print("error: variable from not in stack frame:", var_from,
file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
type_from = stack[var_from]
if type_from == "U64" and type_to == "Binary":
code += "boolean::u64_into_boolean_vec_le(" + \
"cs.namespace(|| \"" + line.text + "\"), " + var_from + \
")?;"
elif type_from == "Scalar" and type_to == "Binary":
code += "boolean::field_into_boolean_vec_le(" + \
"cs.namespace(|| \"" + line.text + "\"), " + var_from + \
")?;"
else:
print("error: unknown type conversion!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
#print(var_from, type_from, type_to)
return code, type_to
def mul_expr(line, stack, consts, expr, code):
var_a, var_b = expr.children
#print("MUL", var_a, var_b)
if var_b not in consts:
print("error: unknown base!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
base_type = consts[var_b]
if base_type != "Point":
print("error: unknown base type!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
code += "ecc::fixed_base_multiplication(" + \
"cs.namespace(|| \"" + line.text + "\"), &" + var_b + \
", &" + var_a + ")?;"
return code, base_type
def add_expr(line, stack, consts, expr, code):
var_a, var_b = expr.children
if var_a not in stack or var_b not in stack:
print("error: missing stack item!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
result_type = stack[var_a]
if stack[var_b] != result_type:
print("error: non matching items for addition!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
code += var_a + ".add(cs.namespace(|| \"" + line.text \
+ "\"), &" + var_b + ")?;"
return (code, result_type)
def compile_let(line, stack, consts, statement):
is_mutable = False
if statement[0] == "mut":
is_mutable = True
statement = statement[1:]
variable_name, variable_type = statement[0], statement[1]
expr = statement[2]
#print("LET", is_mutable, variable_name, variable_type)
#print(" ", expr)
code = "let " + ("mut " if is_mutable else "") + variable_name + " = "
if expr.data == "as_expr":
ceval = as_expr(line, stack, consts, expr, code)
elif expr.data == "mul_expr":
ceval = mul_expr(line, stack, consts, expr, code)
elif expr.data == "add_expr":
ceval = add_expr(line, stack, consts, expr, code)
if ceval is None:
return None
code, type_to = ceval
if variable_type != type_to:
print("error: sub expr does not evaluate to correct type",
file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
stack[variable_name] = variable_type
return code
def interpret_func(func, consts):
func_def = parse_func_def(func[0].text)
header = compile_func_header(func_def)
if header is None:
return
subroutine = header
indent = " " * 4
stack = dict(func_def[1])
emitted_types = []
for line in func[1:]:
statement_type, statement = interpret_func_line(line.text, stack, consts)
if statement_type == "let":
code = compile_let(line, stack, consts, statement)
if code is None:
return
subroutine += indent + code + "\n"
elif statement_type == "return":
for var in statement:
if var not in stack:
print("error: missing variable in stack!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
if len(statement) == 1:
code = "Ok(" + statement[0] + ")"
else:
code = "Ok(" + ",".join(statement) + ")"
subroutine += indent + code + "\n"
elif statement_type == "emit":
assert len(statement) == 1
variable = statement[0]
if variable not in stack:
print("error: missing variable in stack!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
variable_type = stack[variable]
if variable_type == "Point":
code = variable + ".inputize(cs.namespace(|| \"" + \
line.text + "\"))?;"
else:
print("error: unable to inputize type!", file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
emitted_types.append(variable_type)
subroutine += indent + code + "\n"
subroutine += "}\n\n"
return subroutine, emitted_types, func_def
class CodeLineTransformer(lark.Transformer):
def variable_name(self, name):
return str(name[0])
def let_statement(self, obj):
return ("let", obj)
def return_statement(self, obj):
return ("return", obj)
def emit_statement(self, obj):
return ("emit", obj)
def point(self, _):
return "Point"
def scalar(self, _):
return "Scalar"
def binary(self, _):
return "Binary"
def u64(self, _):
return "U64"
def type(self, typename):
return str(typename[0])
def mutable(self, _):
return "mut"
statement = list
def interpret_func_line(text, stack, consts):
parser = lark.Lark(r"""
statement: let_statement
| return_statement
| emit_statement
let_statement: "let" [mutable] variable_name ":" type "=" expr
mutable: "mut"
?expr: as_expr
| mul_expr
| add_expr
as_expr: variable_name "as" type
mul_expr: variable_name "*" variable_name
add_expr: variable_name "+" variable_name
return_statement: "return" variable_name
| "return" variable_tuple
variable_tuple: "(" variable_name ("," variable_name)* ")"
emit_statement: "emit" variable_name
variable_name: NAME
type: u64 | scalar | point | binary
u64: "U64"
scalar: "Scalar"
point: "Point"
binary: "Binary"
%import common.CNAME -> NAME
%import common.WS
%ignore WS
""", start="statement")
tree = parser.parse(text)
tokens = CodeLineTransformer().transform(tree)[0]
return tokens
class ContractDefTransformer(lark.Transformer):
def contract_name(self, name):
return str(name[0])
def param(self, obj):
return tuple(obj)
def param_name(self, name):
return str(name[0])
def u64(self, _):
return "U64"
def scalar(self, _):
return "Scalar"
def point(self, _):
return "Point"
def binary(self, _):
return "Binary"
def type(self, obj):
return obj[0]
contract_def = list
params = list
type_list = list
def parse_contract_def(text):
parser = lark.Lark(r"""
contract_def: "contract" contract_name "(" params+ ")" "->" type_list ":"
contract_name: NAME
params: param ("," param)*
type_list: type
| "(" type ("," type)* ")"
param: param_name ":" type
param_name: NAME
type: u64 | scalar | point | binary
u64: "U64"
scalar: "Scalar"
point: "Point"
binary: "Binary"
%import common.CNAME -> NAME
%import common.WS
%ignore WS
""", start="contract_def")
tree = parser.parse(text)
tokens = ContractDefTransformer().transform(tree)
assert len(tokens) == 3
return tokens
class ContractCodeLineTransformer(lark.Transformer):
def variable_name(self, name):
return str(name[0])
def let_statement(self, obj):
return ("let", obj)
def return_statement(self, obj):
return ("return", obj)
def emit_statement(self, obj):
return ("emit", obj)
def method_statement(self, obj):
return ("method", obj)
def point(self, _):
return "Point"
def scalar(self, _):
return "Scalar"
def binary(self, _):
return "Binary"
def u64(self, _):
return "U64"
def type(self, typename):
return str(typename[0])
def mutable(self, _):
return "mut"
def function_name(self, name):
return str(name[0])
statement = list
variable_assign = list
variable_decl = tuple
def interpret_contract_line(text, stack, consts):
parser = lark.Lark(r"""
statement: let_statement
| return_statement
| emit_statement
| method_statement
let_statement: "let" variable_assign "=" expr
mutable: "mut"
variable_assign: variable_decl
| "(" variable_decl ("," variable_decl)* ")"
variable_decl: [mutable] variable_name ":" type
?expr: as_expr
| mul_expr
| add_expr
| funccall_expr
| empty_list_expr
as_expr: variable_name "as" type
mul_expr: variable_name "*" variable_name
add_expr: variable_name "+" variable_name
funccall_expr: function_name "(" [variable_name ("," variable_name)*] ")"
empty_list_expr: "[]"
return_statement: "return" variable_name
| "return" variable_tuple
variable_tuple: "(" variable_name ("," variable_name)* ")"
emit_statement: "emit" variable_name
method_statement: variable_name "." funccall_expr
variable_name: NAME
function_name: NAME
type: u64 | scalar | point | binary
u64: "U64"
scalar: "Scalar"
point: "Point"
binary: "Binary"
%import common.CNAME -> NAME
%import common.WS
%ignore WS
""", start="statement")
tree = parser.parse(text)
tokens = ContractCodeLineTransformer().transform(tree)[0]
return tokens
def to_initial_caps(snake_str):
components = snake_str.split("_")
return "".join(x.title() for x in components)
def create_contract_header(contract_def):
contract_name, params, retvals = contract_def
contract_name = to_initial_caps(contract_name)
header = "pub struct %s {\n" % contract_name
for param_name, param_type in params:
if param_type == "U64":
param_type = "u64"
elif param_type == "Scalar":
param_type = "jubjub::Fr"
elif param_type == "Point":
param_type = "jubjub::SubgroupPoint"
header += " " * 4 + "pub %s: Option<%s>,\n" % (param_name, param_type)
header += "}\n\n"
header += r"""impl Circuit<bls12_381::Scalar> for %s {
fn synthesize<CS: ConstraintSystem<bls12_381::Scalar>>(
self,
cs: &mut CS,
) -> Result<(), SynthesisError> {
""" % contract_name
return header
# Worst code ever
def compile_let2(line, stack, consts, funcs, selfvars, statement):
lhs = []
for variable_decl in statement[0]:
assert len(variable_decl) == 2 or \
(len(variable_decl) == 3 and variable_decl[0] == "mut")
if len(variable_decl) == 2:
mutable = False
elif len(variable_decl) == 3:
assert variable_decl[0] == "mut"
mutable = True
variable_decl = variable_decl[1:]
#else:
# Error!
lhs.append(list(variable_decl) + [mutable])
variable_types = []
code = "let "
if len(lhs) == 1:
name, type, is_mutable = lhs[0]
variable_types.append(type)
code += ("mut " if is_mutable else "") + name
else:
code += "("
start = True
for name, type, is_mutable in lhs:
if not start:
code += ", "
start = False
code += name
variable_types.append(type)
code += ")"
code += " = "
expr = statement[1]
expr_type = expr.data
expr = expr.children
if expr_type == "funccall_expr":
ceval = funccall_expr(line, stack, consts, funcs, selfvars, expr, code)
elif expr_type == "empty_list_expr":
ceval = code + "vec![];", ["Binary"]
#code = "let " + ("mut " if is_mutable else "") + variable_name + " = "
if ceval is None:
return None
code, types_to = ceval
if variable_types != types_to:
print("error: sub expr does not evaluate to correct type",
file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
for name, type, _ in lhs:
stack[name] = type
return code
def funccall_expr(line, stack, consts, funcs, selfvars, expr, code):
func_name, arguments = expr[0], expr[1:]
if func_name not in funcs:
print("error: non-existant function call",
file=sys.stderr)
print("line:", line.text, "line:", line.lineno)
return None
arguments = [("self." + arg if arg in selfvars else arg)
for arg in arguments]
code += "%s(cs.namespace(|| \"%s\"), %s)?;" % (
func_name, line.text, ", ".join(arguments))
return_type = funcs[func_name][-1][-1]
return code, return_type
def compile_method_call(line, stack, consts, funcs, selfvars, statement):
variable = statement[0]
method = statement[1].children[0]
arguments = statement[1].children[1:]
arguments = [("self." + arg if arg in selfvars else arg)
for arg in arguments]
return "%s.%s(%s);" % (variable, method, ", ".join(arguments))
def interpret_contract(contract, consts, funcs):
contract_def = parse_contract_def(contract[0].text)
contract_code = create_contract_header(contract_def)
selfvars = set(varname[0] for varname in contract_def[1])
stack = dict(contract_def[1])
for line in contract[1:10]:
indent = " " * 4 * int(line.level + 1)
statement_type, statement = interpret_contract_line(line.text, stack, consts)
#pprint.pprint(statement_type)
if statement_type == "let":
code = compile_let2(line, stack, consts, funcs, selfvars, statement)
if code is None:
return
elif statement_type == "method":
code = compile_method_call(line, stack, consts, funcs,
selfvars, statement)
if code is None:
return
contract_code += indent + code + "\n"
contract_code += " " * 8 + "Ok(())\n"
contract_code += " " * 4 + "}\n"
contract_code += "}\n\n"
#print("-------------------------------")
#print(contract_code)
return contract_code, contract_def
def main(argv):
if len(argv) == 1:
print("error: missing proof file", file=sys.stderr)
return -1
filename = sys.argv[1]
text = open(filename, "r").read()
if (linedescs := parse(text)) is None:
return -1
sections = section(linedescs)
consts, funcs, contracts = classify(sections)
consts = read_consts(consts)
compiled_funcs = {}
for func in funcs:
if (compiled := interpret_func(func, consts)) is None:
return -1
_, _, func_def = compiled
func_name, _, _ = func_def
compiled_funcs[func_name] = compiled
funcs = compiled_funcs
compiled_contracts = {}
for contract in contracts[1:]:
if (compiled := interpret_contract(contract, consts, funcs)) is None:
return -1
#print(contract)
_, contract_def = compiled
contract_name, _, _ = contract_def
compiled_contracts[contract_name] = compiled
contracts = compiled_contracts
# Concat
output = ""
for _, func in funcs.items():
output += func[0]
for _, contract in contracts.items():
output += contract[0]
#print(output)
if __name__ == "__main__":
main(sys.argv)

16
scripts/pasta/Cargo.toml Normal file
View File

@@ -0,0 +1,16 @@
[package]
name = "pasta"
version = "0.1.0"
authors = ["narodnik <x@x.org>"]
edition = "2018"
[dependencies]
pasta_curves = { path = "/home/narodnik/src/sw/zectings/pasta_curves" }
ff = "0.9"
group = "0.9"
rand = "0.8.4"
[[bin]]
name = "pasta"
path = "main.rs"

21
scripts/pasta/main.rs Normal file
View File

@@ -0,0 +1,21 @@
use pasta_curves as pasta;
use group::{Group, Curve};
use rand::rngs::OsRng;
fn main() {
let g = pasta::vesta::Point::generator();
println!("G = {:?}", g.to_affine());
let x = pasta::vesta::Scalar::from(87u64);
println!("x = 87 = {:?}", x);
let b = g * x;
println!("B = xG = {:?}", b.to_affine());
let y = x - pasta::vesta::Scalar::from(90u64);
println!("y = x - 90 = {:?}", y);
let c = pasta::vesta::Point::random(&mut OsRng);
let d = pasta::vesta::Point::random(&mut OsRng);
println!("C = {:?}", c.to_affine());
println!("D = {:?}", d.to_affine());
println!("C + D = {:?}", (c + d).to_affine());
}

22
scripts/pasta/vesta.sage Normal file
View File

@@ -0,0 +1,22 @@
q = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001
K = GF(q)
a = K(0x00)
b = K(0x05)
E = EllipticCurve(K, (a, b))
G = E(0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000000, 0x02)
E.set_order(0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001 * 0x01)
p = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
Scalar = GF(p)
A = E(0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000000, 0x0000000000000000000000000000000000000000000000000000000000000002)
x = Scalar(0x0000000000000000000000000000000000000000000000000000000000000057)
print(int(x))
#print([hex(x) for x in (x * A).xy()])
B = E(0x04cbd122b054187ff98f0726651096de7f55a213c233902764fe87d376eeb99c, 0x2d5c8c632b1f33eb42726cda0af9f95eaf3dfa8511e6a7b657713dbf62bb13a5)
assert int(x) * A == B
yy = Scalar(0x40000000000000000000000000000000224698fc094cf91b992d30ecfffffffe)
y = x - Scalar(90)
print(hex(int(y)))
assert yy == y

528
scripts/pism.py Normal file
View File

@@ -0,0 +1,528 @@
import json
import os
import sys
import codegen
symbol_table = {
"contract": 1,
"param": 2,
"start": 0,
"end": 0,
}
types_map = {
"U64": "u64",
"Fr": "jubjub::Fr",
"Point": "jubjub::SubgroupPoint",
"Scalar": "bls12_381::Scalar",
"Bool": "bool"
}
feature_includes = {"G_SPEND": "use crate::crypto::merkle_node::SAPLING_COMMITMENT_TREE_DEPTH;\n"}
command_desc = {
"witness": (
("EdwardsPoint", True),
("Point", False)
),
"assert_not_small_order": (
("EdwardsPoint", False),
),
"u64_as_binary_le": (
("Vec<Boolean>", True),
("U64", False),
),
"fr_as_binary_le": (
("Vec<Boolean>", True),
("Fr", False)
),
"ec_mul_const": (
("EdwardsPoint", True),
("Vec<Boolean>", False),
("FixedGenerator", False)
),
"ec_mul": (
("EdwardsPoint", True),
("Vec<Boolean>", False),
("EdwardsPoint", False),
),
"ec_add": (
("EdwardsPoint", True),
("EdwardsPoint", False),
("EdwardsPoint", False),
),
"ec_repr": (
("Vec<Boolean>", True),
("EdwardsPoint", False),
),
"ec_get_u": (
("ScalarNum", True),
("EdwardsPoint", False),
),
"emit_ec": (
("EdwardsPoint", False),
),
"alloc_binary": (
("Vec<Boolean>", True),
),
"binary_clone": (
("Vec<Boolean>", True),
("Vec<Boolean>", False),
),
"binary_extend": (
("Vec<Boolean>", False),
("Vec<Boolean>", False),
),
"binary_push": (
("Vec<Boolean>", False),
("Boolean", False),
),
"binary_truncate": (
("Vec<Boolean>", False),
("BinarySize", False),
),
"static_assert_binary_size": (
("Vec<Boolean>", False),
("INTEGER", False),
),
"blake2s": (
("Vec<Boolean>", True),
("Vec<Boolean>", False),
("BlakePersonalization", False),
),
"pedersen_hash": (
("EdwardsPoint", True),
("Vec<Boolean>", False),
("PedersenPersonalization", False),
),
"emit_binary": (
("Vec<Boolean>", False),
),
"alloc_bit": (
("Boolean", True),
("Bool", False),
),
"alloc_const_bit": (
("Boolean", True),
("BOOL_CONST", False),
),
"clone_bit": (
("Boolean", True),
("Boolean", False),
),
"alloc_scalar": (
("ScalarNum", True),
("Scalar", False),
),
"scalar_as_binary": (
("Vec<Boolean>", True),
("ScalarNum", False),
),
"emit_scalar": (
("ScalarNum", False),
),
"scalar_enforce_equal": (
("ScalarNum", False),
("ScalarNum", False),
),
"conditionally_reverse": (
("ScalarNum", True),
("ScalarNum", True),
("ScalarNum", False),
("ScalarNum", False),
("Boolean", False),
),
}
def eprint(*args):
print(*args, file=sys.stderr)
class Line:
def __init__(self, text, line_number):
self.text = text
self.orig = text
self.lineno = line_number
self.clean()
def clean(self):
# Remove the comments
self.text = self.text.split("#", 1)[0]
# Remove whitespace
self.text = self.text.strip()
def is_empty(self):
return bool(self.text)
def __repr__(self):
return "Line %s: %s" % (self.lineno, self.orig.lstrip())
def command(self):
if not self.is_empty():
return None
return self.text.split(" ")[0]
def args(self):
if not self.is_empty():
return None
return self.text.split(" ")[1:]
def clean(contents):
# Split input into lines
contents = contents.split("\n")
contents = [Line(line, i) for i, line in enumerate(contents)]
# Remove empty blank lines
contents = [line for line in contents if line.is_empty()]
return contents
def make_segments(contents):
constants = [line for line in contents if line.command() == "constant"]
segments = []
current_segment = []
for line in contents:
if line.command() == "contract":
current_segment = []
current_segment.append(line)
if line.command() == "end":
segments.append(current_segment)
current_segment = []
return constants, segments
def build_constants_table(constants):
table = {}
for line in constants:
args = line.args()
if len(args) != 2:
eprint("error: wrong number of args")
eprint(line)
return None
name, type = args
table[name] = type
return table
def extract(segment):
assert segment
# Does it have a declaration?
if not segment[0].command() == "contract":
eprint("error: missing contract declaration")
eprint(segment[0])
return None
# Does it have an end?
if not segment[-1].command() == "end":
eprint("error: missing contract end")
eprint(segment[-1])
return None
# Does it have a start?
if not [line for line in segment if line.command() == "start"]:
eprint("error: missing contract start")
eprint(segment[0])
return None
for line in segment:
command, args = line.command(), line.args()
if command in symbol_table:
if symbol_table[command] != len(args):
eprint("error: wrong number of args for command '%s'" % command)
eprint(line)
return None
elif command in command_desc:
if len(command_desc[command]) != len(args):
eprint("error: wrong number of args for command '%s'" % command)
eprint(line)
return None
else:
eprint("error: missing symbol for command '%s'" % command)
eprint(line)
return None
contract_name = segment[0].args()[0]
start_index = [index for index, line in enumerate(segment)
if line.command() == "start"]
if len(start_index) > 1:
eprint("error: multiple start statements in contract '%s'" %
contract_name)
for index in start_index:
eprint(segment[index])
eprint("Aborting.")
return None
assert len(start_index) == 1
start_index = start_index[0]
header = segment[1:start_index]
code = segment[start_index + 1:-1]
params = {}
for param_decl in header:
args = param_decl.args()
assert len(args) == 2
name, type = args
params[name] = type
program = []
for line in code:
command, args = line.command(), line.args()
program.append((command, args, line))
return Contract(contract_name, params, program)
def to_initial_caps(snake_str):
components = snake_str.split("_")
return "".join(x.title() for x in components)
class Contract:
def __init__(self, name, params, program):
self.name = name
self.params = params
self.program = program
def _includes(self):
return \
r"""#![allow(unused_imports)]
#![allow(unused_mut)]
use bellman::{
gadgets::{
boolean,
boolean::{AllocatedBit, Boolean},
multipack,
blake2s,
num,
Assignment,
},
groth16, Circuit, ConstraintSystem, SynthesisError,
};
use bls12_381::Bls12;
use ff::{PrimeField, Field};
use group::Curve;
use zcash_proofs::circuit::{ecc, pedersen_hash};
"""
def _compile_header(self):
code = "pub struct %s {\n" % to_initial_caps(self.name)
for param_name, param_type in self.params.items():
try:
mapped_type = types_map[param_type]
except KeyError:
return None
code += " pub %s: Option<%s>,\n" % (param_name, mapped_type)
code += "}\n"
return code
def _compile_body(self):
self.stack = {}
code = "\n"
#indent = " " * 8
for command, args, line in self.program:
if (code_text := self._compile_line(command, args, line)) is None:
return None
code += "// %s\n" % str(line)
code += code_text + "\n\n"
return code
def _preprocess_args(self, args, line):
nargs = []
for arg in args:
if not arg.startswith("param:"):
nargs.append((arg, False))
continue
_, argname = arg.split(":", 1)
if argname not in self.params:
eprint("error: non-existant param referenced")
eprint(line)
return None
nargs.append((argname, True))
return nargs
def type_checking(self, command, args, line):
assert command in command_desc
type_list = command_desc[command]
if len(type_list) != len(args):
eprint("error: wrong number of arguments!")
eprint(line)
return False
for (expected_type, new_val), (argname, is_param) in \
zip(type_list, args):
# Only type check input arguments, not output values
if new_val:
continue
if expected_type == "INTEGER" or expected_type == "BOOL_CONST":
continue
if is_param:
actual_type = self.params[argname]
elif argname in self.constants:
actual_type = self.constants[argname]
else:
# Check the stack here
if argname not in self.stack:
eprint("error: cannot find value '%s' on the stack!" %
argname)
eprint(line)
return False
actual_type = self.stack[argname]
if expected_type != actual_type:
eprint("error: wrong type for arg '%s'!" % argname)
eprint(line)
return False
return True
def _check_args(self, command, args, line):
assert command in command_desc
type_list = command_desc[command]
assert len(type_list) == len(args)
for (expected_type, is_new_val), (arg, is_param) in zip(type_list, args):
if is_param:
continue
if is_new_val:
continue
if arg in self.stack:
continue
if arg in self.constants:
continue
if expected_type == "INTEGER" or expected_type == "BOOL_CONST":
continue
eprint("error: cannot find '%s' in the stack" % arg)
eprint(line)
return False
return True
def _compile_line(self, command, args, line):
if (args := self._preprocess_args(args, line)) is None:
return None
if not self.type_checking(command, args, line):
return None
if not self._check_args(command, args, line):
return None
self.modify_stack(command, args)
args = [self.carg(arg) for arg in args]
try:
codegen_method = getattr(codegen, command)
except AttributeError:
eprint("error: missing command '%s' does not exist" % command)
eprint(line)
return None
return codegen_method(line, *args)
def carg(self, arg):
argname, is_param = arg
if is_param:
return "self.%s" % argname
if argname in self.rename_consts:
return self.rename_consts[argname]
return argname
def modify_stack(self, command, args):
type_list = command_desc[command]
assert len(type_list) == len(args)
for (expected_type, new_val), (argname, is_param) in \
zip(type_list, args):
if is_param:
assert not new_val
continue
# Now apply the new values to the stack
if new_val:
self.stack[argname] = expected_type
def compile(self, constants, aux):
self.constants = constants
code = ""
code += self._includes()
self.rename_consts = {}
if "constants" in aux:
for const_name, value in aux["constants"].items():
if "maps_to" not in value:
eprint("error: bad aux config '%s', missing maps_to" %
const_name)
return None
if const_name in feature_includes:
code += feature_includes[const_name]
mapped_type = value["maps_to"]
self.rename_consts[const_name] = mapped_type
code += "\n"
if (header := self._compile_header()) is None:
return None
code += header
code += \
r"""impl Circuit<bls12_381::Scalar> for %s {
fn synthesize<CS: ConstraintSystem<bls12_381::Scalar>>(
self,
cs: &mut CS,
) -> Result<(), SynthesisError> {
""" % to_initial_caps(self.name)
if (body := self._compile_body()) is None:
return None
code += body
code += "Ok(())\n"
code += " }\n"
code += "}\n"
return code
def process(contents, aux):
contents = clean(contents)
constants, segments = make_segments(contents)
if (constants := build_constants_table(constants)) is None:
return False
codes = []
for segment in segments:
if (contract := extract(segment)) is None:
return False
if (code := contract.compile(constants, aux)) is None:
return False
codes.append(code)
# Success! Output finished product.
[print(code) for code in codes]
return True
def main(argv):
if len(argv) != 3:
eprint("pism FILENAME AUX_FILENAME")
return -1
aux_filename = argv[2]
aux = json.loads(open(aux_filename).read())
src_filename = argv[1]
contents = open(src_filename).read()
if not process(contents, aux):
return -2
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv))

41
scripts/pism.vim Normal file
View File

@@ -0,0 +1,41 @@
"For autoload, add this to your VIM config:
" VIM: .vimrc
" NeoVIM: .config/nvim/init.vim
"
"autocmd BufRead *.pism call SetPismOptions()
"function SetPismOptions()
" set syntax=pism
" source /home/narodnik/src/drk/scripts/pism.vim
"endfunction
if exists('b:current_syntax')
finish
endif
syn keyword drkKeyword constant contract start end constraint
"syn keyword drkAttr
syn keyword drkType FixedGenerator BlakePersonalization PedersenPersonalization ByteSize U64 Fr Point Bool Scalar BinarySize
syn keyword drkFunctionKeyword enforce lc0_add_one lc1_add_one lc2_add_one lc_coeff_reset lc_coeff_double lc0_sub_one lc1_sub_one lc2_sub_one dump_alloc dump_local
syn match drkFunction "^[ ]*[a-z_0-9]* "
syn match drkComment "#.*$"
syn match drkNumber ' \zs\d\+\ze'
syn match drkHexNumber ' \zs0x[a-z0-9]\+\ze'
syn match drkConst '[A-Z_]\{2,}[A-Z0-9_]*'
syn keyword drkBoolVal true false
syn match drkPreproc "{%.*%}"
syn match drkPreproc2 "{{.*}}"
hi def link drkKeyword Statement
"hi def link drkAttr StorageClass
hi def link drkPreproc PreProc
hi def link drkPreproc2 PreProc
hi def link drkType Type
hi def link drkFunction Function
hi def link drkFunctionKeyword Function
hi def link drkComment Comment
hi def link drkNumber Constant
hi def link drkHexNumber Constant
hi def link drkConst Constant
hi def link drkBoolVal Constant
let b:current_syntax = "pism"

20
scripts/preprocess.py Normal file
View File

@@ -0,0 +1,20 @@
import os.path
import sys
from jinja2 import Environment, FileSystemLoader, Template
def main(argv):
if len(argv) != 2:
print("error: missing arg", file=sys.stderr)
return -1
path = argv[1]
dirname, filename = os.path.dirname(path), os.path.basename(path)
env = Environment(loader = FileSystemLoader([dirname]))
template = env.get_template(filename)
print(template.render())
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv))

29
scripts/reorder-logs.py Normal file
View File

@@ -0,0 +1,29 @@
import datetime as dt
def isotime_to_ms(isotime):
ms = dt.timedelta(microseconds=1)
time = dt.time.fromisoformat(isotime)
ms_time = (dt.datetime.combine(dt.date.min, time) - dt.datetime.min) / ms
return ms_time
def line_time(line):
return isotime_to_ms(line.split()[1])
lines = []
filenames = {"Client": "/tmp/a.txt", "Server": "/tmp/b.txt"}
for label, filename in filenames.items():
with open(filename) as file:
file_lines = file.read().split("\n")
# Cleanup a bit
file_lines = [line for line in file_lines if line and line[0].isdigit()]
# Attach the label to each line
file_lines = ["%s: %s" % (label, line) for line in file_lines]
lines.extend(file_lines)
lines.sort(key=line_time)
for line in lines:
# Now remove timestamps and other info we don't need
line = line.split()
line = line[0] + " " + " ".join(line[4:])
print(line)

5
scripts/to_html.sh Executable file
View File

@@ -0,0 +1,5 @@
#!/bin/bash
#nvim -c ":TOhtml" $1
sed -i "s/PreProc { color: #5fd7ff; }/PreProc { color: #8f2722; }/" $1
sed -i "s/Comment { color: #00ffff; }/Comment { color: #0055ff; }/" $1

View File

@@ -0,0 +1,103 @@
from bls_py import bls12381
from bls_py import pairing
from bls_py import ec
from bls_py.fields import Fq, Fq2, Fq6, Fq12, bls12381_q as Q
import random
import numpy as np
# Section 3.3.4 from "Why and How zk-SNARK Works"
def rand_scalar():
return random.randrange(1, bls12381.q)
#x = rand_scalar()
#y = ec.y_for_x(x)
g1 = ec.generator_Fq(bls12381)
g2 = ec.generator_Fq2(bls12381)
null = ec.AffinePoint(Fq(Q, 0), Fq(Q, 1), True, bls12381)
assert g1 + null == g1
#################################
# Verifier (trusted setup)
#################################
# samples a random value (a secret)
s = rand_scalar()
# calculates encryptions of s for all powers i in 0 to d
# E(s^i) = g^s^i
d = 10
encrypted_powers = [
g1 * (s**i) for i in range(d)
]
# evaluates unencrypted target polynomial with s: t(s)
target = (s - 1) * (s - 2)
# encrypted values of s provided to the prover
# Actual values of s are toxic waste and discarded
#################################
# Prover
#################################
# E(p(s)) = p(s)G
# = c_d s^d G + ... + c_1 s^1 G + c_0 s^0 G
# = s^3 G - 3 s^2 G + 2 s G
# E(h(s)) = sG
# t(s) = s^2 - 3s + 2
# E(h(s)) t(s) = s^3 G - 3 s^2 G + 2 s G
# Lets test these manually:
e_s = encrypted_powers
e_p_s = e_s[3] - 3 * e_s[2] + 2 * e_s[1]
e_h_s = e_s[1]
t_s = s**2 - 3*s + 2
assert t_s == target
assert e_p_s == e_h_s * t_s
#############################
# x^3 - 3x^2 + 2x
main_poly = np.poly1d([1, -3, 2, 0])
# (x - 1)(x - 2)
target_poly = np.poly1d([1, -1]) * np.poly1d([1, -2])
# Calculates polynomial h(x) = p(x) / t(x)
cofactor, remainder = main_poly / target_poly
assert remainder == np.poly1d([0])
# Using encrypted powers and coefficients, evaluates
# E(p(s)) and E(h(s))
def evaluate(poly, encrypted_powers):
coeffs = list(poly.coef)[::-1]
result = null
for power, coeff in zip(encrypted_powers, coeffs):
#print(coeff, power)
coeff = int(coeff)
# I have to do this for some strange reason
# Because if coeff is negative and I do += power * coeff
# then it gives me a different result than what I expect
if coeff < 0:
result -= power * (-coeff)
else:
result += power * coeff
return result
encrypted_poly = evaluate(main_poly, encrypted_powers)
assert encrypted_poly == e_p_s
encrypted_cofactor = evaluate(cofactor, encrypted_powers)
# resulting g^p and g^h are provided to the verifier
#################################
# Verifier
#################################
# Last check that p = t(s) h
assert encrypted_poly == encrypted_cofactor * target

View File

@@ -0,0 +1,119 @@
from bls_py import bls12381
from bls_py import pairing
from bls_py import ec
from bls_py.fields import Fq, Fq2, Fq6, Fq12, bls12381_q as Q
import random
import numpy as np
# Section 3.4 from "Why and How zk-SNARK Works"
def rand_scalar():
return random.randrange(1, bls12381.q)
#x = rand_scalar()
#y = ec.y_for_x(x)
g1 = ec.generator_Fq(bls12381)
g2 = ec.generator_Fq2(bls12381)
null = ec.AffinePoint(Fq(Q, 0), Fq(Q, 1), True, bls12381)
assert g1 + null == g1
#################################
# Verifier (trusted setup)
#################################
# samples a random value (a secret)
s = rand_scalar()
# calculate the shift
a = rand_scalar()
# calculates encryptions of s for all powers i in 0 to d
# E(s^i) = g^s^i
d = 10
encrypted_powers = [
g1 * (s**i) for i in range(d)
]
encrypted_shifted_powers = [
g1 * (a * s**i) for i in range(d)
]
# evaluates unencrypted target polynomial with s: t(s)
target = (s - 1) * (s - 2)
# encrypted values of s provided to the prover
# Actual values of s are toxic waste and discarded
#################################
# Prover
#################################
# E(p(s)) = p(s)G
# = c_d s^d G + ... + c_1 s^1 G + c_0 s^0 G
# = s^3 G - 3 s^2 G + 2 s G
# E(h(s)) = sG
# t(s) = s^2 - 3s + 2
# E(h(s)) t(s) = s^3 G - 3 s^2 G + 2 s G
# Lets test these manually:
e_s = encrypted_powers
e_p_s = e_s[3] - 3 * e_s[2] + 2 * e_s[1]
e_h_s = e_s[1]
t_s = s**2 - 3*s + 2
assert t_s == target
assert e_p_s == e_h_s * t_s
e_as = encrypted_shifted_powers
e_p_as = e_as[3] - 3 * e_as[2] + 2 * e_as[1]
assert e_p_s * a == e_p_as
#############################
# x^3 - 3x^2 + 2x
main_poly = np.poly1d([1, -3, 2, 0])
# (x - 1)(x - 2)
target_poly = np.poly1d([1, -1]) * np.poly1d([1, -2])
# Calculates polynomial h(x) = p(x) / t(x)
cofactor, remainder = main_poly / target_poly
assert remainder == np.poly1d([0])
# Using encrypted powers and coefficients, evaluates
# E(p(s)) and E(h(s))
def evaluate(poly, encrypted_powers):
coeffs = list(poly.coef)[::-1]
result = null
for power, coeff in zip(encrypted_powers, coeffs):
#print(coeff, power)
coeff = int(coeff)
# I have to do this for some strange reason
# Because if coeff is negative and I do += power * coeff
# then it gives me a different result than what I expect
if coeff < 0:
result -= power * (-coeff)
else:
result += power * coeff
return result
encrypted_poly = evaluate(main_poly, encrypted_powers)
assert encrypted_poly == e_p_s
encrypted_cofactor = evaluate(cofactor, encrypted_powers)
# Alpha shifted powers
encrypted_shift_poly = evaluate(main_poly, encrypted_shifted_powers)
# resulting g^p and g^h are provided to the verifier
#################################
# Verifier
#################################
# Last check that p = t(s) h
assert encrypted_poly == encrypted_cofactor * target
# Verify (g^p)^a == g^p'
assert encrypted_poly * a == encrypted_shift_poly

View File

@@ -0,0 +1,129 @@
from bls_py import bls12381
from bls_py import pairing
from bls_py import ec
from bls_py.fields import Fq, Fq2, Fq6, Fq12, bls12381_q as Q
import random
import numpy as np
# Section 3.5 from "Why and How zk-SNARK Works"
def rand_scalar():
return random.randrange(1, bls12381.q)
#x = rand_scalar()
#y = ec.y_for_x(x)
g1 = ec.generator_Fq(bls12381)
g2 = ec.generator_Fq2(bls12381)
null = ec.AffinePoint(Fq(Q, 0), Fq(Q, 1), True, bls12381)
assert g1 + null == g1
#################################
# Verifier (trusted setup)
#################################
# samples a random value (a secret)
s = rand_scalar()
# calculate the shift
a = rand_scalar()
# calculates encryptions of s for all powers i in 0 to d
# E(s^i) = g^s^i
d = 10
encrypted_powers = [
g1 * (s**i) for i in range(d)
]
encrypted_shifted_powers = [
g1 * (a * s**i) for i in range(d)
]
# evaluates unencrypted target polynomial with s: t(s)
target = (s - 1) * (s - 2)
# encrypted values of s provided to the prover
# Actual values of s are toxic waste and discarded
#################################
# Prover
#################################
# delta shift
delta = rand_scalar()
# E(p(s)) = p(s)G
# = c_d s^d G + ... + c_1 s^1 G + c_0 s^0 G
# = s^3 G - 3 s^2 G + 2 s G
# E(h(s)) = sG
# t(s) = s^2 - 3s + 2
# E(h(s)) t(s) = s^3 G - 3 s^2 G + 2 s G
# Lets test these manually:
e_s = encrypted_powers
e_p_s = e_s[3] - 3 * e_s[2] + 2 * e_s[1]
e_h_s = e_s[1]
t_s = s**2 - 3*s + 2
# exponentiate with delta
e_p_s *= delta
e_h_s *= delta
assert t_s == target
assert e_p_s == e_h_s * t_s
e_as = encrypted_shifted_powers
e_p_as = e_as[3] - 3 * e_as[2] + 2 * e_as[1]
# exponentiate with delta
e_p_as *= delta
assert e_p_s * a == e_p_as
#############################
# x^3 - 3x^2 + 2x
main_poly = np.poly1d([1, -3, 2, 0])
# (x - 1)(x - 2)
target_poly = np.poly1d([1, -1]) * np.poly1d([1, -2])
# Calculates polynomial h(x) = p(x) / t(x)
cofactor, remainder = main_poly / target_poly
assert remainder == np.poly1d([0])
# Using encrypted powers and coefficients, evaluates
# E(p(s)) and E(h(s))
def evaluate(poly, encrypted_powers):
coeffs = list(poly.coef)[::-1]
result = null
for power, coeff in zip(encrypted_powers, coeffs):
#print(coeff, power)
coeff = int(coeff)
# I have to do this for some strange reason
# Because if coeff is negative and I do += power * coeff
# then it gives me a different result than what I expect
if coeff < 0:
result -= power * (-coeff)
else:
result += power * coeff
# Add delta to the result
# Free extra obfuscation to the polynomial
return result * delta
encrypted_poly = evaluate(main_poly, encrypted_powers)
assert encrypted_poly == e_p_s
encrypted_cofactor = evaluate(cofactor, encrypted_powers)
# Alpha shifted powers
encrypted_shift_poly = evaluate(main_poly, encrypted_shifted_powers)
# resulting g^p and g^h are provided to the verifier
#################################
# Verifier
#################################
# Last check that p = t(s) h
assert encrypted_poly == encrypted_cofactor * target
# Verify (g^p)^a == g^p'
assert encrypted_poly * a == encrypted_shift_poly

View File

@@ -0,0 +1,152 @@
from bls_py import bls12381
from bls_py import pairing
from bls_py import ec
from bls_py.fields import Fq, Fq2, Fq6, Fq12, bls12381_q as Q
import random
import numpy as np
# Section 3.6 from "Why and How zk-SNARK Works"
def rand_scalar():
return random.randrange(1, bls12381.q)
#x = rand_scalar()
#y = ec.y_for_x(x)
g1 = ec.generator_Fq(bls12381)
g2 = ec.generator_Fq2(bls12381)
null = ec.AffinePoint(Fq(Q, 0), Fq(Q, 1), True, bls12381)
assert g1 + null == g1
null2 = ec.AffinePoint(Fq2.zero(Q), Fq2.zero(Q), True, bls12381)
assert null2 + g2 == g2
#################################
# Verifier (trusted setup)
#################################
# samples a random value (a secret)
s = rand_scalar()
# calculate the shift
a = rand_scalar()
# calculates encryptions of s for all powers i in 0 to d
# E(s^i) = g^s^i
d = 10
encrypted_powers = [
g1 * (s**i) for i in range(d)
]
encrypted_powers_g2 = [
g2 * (s**i) for i in range(d)
]
encrypted_shifted_powers = [
g1 * (a * s**i) for i in range(d)
]
# evaluates unencrypted target polynomial with s: t(s)
target = (s - 1) * (s - 2)
# CRS = common reference string = trusted setup parameters
target_crs = g1 * target
alpha_crs = g2 * a
# Proving key = (encrypted_powers, encrypted_shifted_powers)
# Verify key = (target_crs, alpha_crs)
# encrypted values of s provided to the prover
# Actual values of s are toxic waste and discarded
#################################
# Prover
#################################
# delta shift
delta = rand_scalar()
# E(p(s)) = p(s)G
# = c_d s^d G + ... + c_1 s^1 G + c_0 s^0 G
# = s^3 G - 3 s^2 G + 2 s G
# E(h(s)) = sG
# t(s) = s^2 - 3s + 2
# E(h(s)) t(s) = s^3 G - 3 s^2 G + 2 s G
# Lets test these manually:
e_s = encrypted_powers
e_p_s = e_s[3] - 3 * e_s[2] + 2 * e_s[1]
e_h_s = e_s[1]
t_s = s**2 - 3*s + 2
# exponentiate with delta
e_p_s *= delta
e_h_s *= delta
assert t_s == target
assert e_p_s == e_h_s * t_s
e_as = encrypted_shifted_powers
e_p_as = e_as[3] - 3 * e_as[2] + 2 * e_as[1]
# exponentiate with delta
e_p_as *= delta
assert e_p_s * a == e_p_as
#############################
# x^3 - 3x^2 + 2x
main_poly = np.poly1d([1, -3, 2, 0])
# (x - 1)(x - 2)
target_poly = np.poly1d([1, -1]) * np.poly1d([1, -2])
# Calculates polynomial h(x) = p(x) / t(x)
cofactor, remainder = main_poly / target_poly
assert remainder == np.poly1d([0])
# Using encrypted powers and coefficients, evaluates
# E(p(s)) and E(h(s))
def evaluate(poly, encrypted_powers, identity):
coeffs = list(poly.coef)[::-1]
result = identity
for power, coeff in zip(encrypted_powers, coeffs):
#print(coeff, power)
coeff = int(coeff)
# I have to do this for some strange reason
# Because if coeff is negative and I do += power * coeff
# then it gives me a different result than what I expect
if coeff < 0:
result -= power * (-coeff)
else:
result += power * coeff
# Add delta to the result
# Free extra obfuscation to the polynomial
return result * delta
encrypted_poly = evaluate(main_poly, encrypted_powers, null)
assert encrypted_poly == e_p_s
encrypted_cofactor = evaluate(cofactor, encrypted_powers_g2, null2)
# Alpha shifted powers
encrypted_shift_poly = evaluate(main_poly, encrypted_shifted_powers, null)
# resulting g^p and g^h are provided to the verifier
# proof = (encrypted_poly, encrypted_cofactor, encrypted_shift_poly)
#################################
# Verifier
#################################
# Last check that p = t(s) h
# Check polynomial cofactors:
#assert encrypted_poly == encrypted_cofactor * target
# e(g^p, g) == e(g^t, g^h)
res1 = pairing.ate_pairing(encrypted_poly, g2)
res2 = pairing.ate_pairing(target_crs, encrypted_cofactor)
assert res1 == res2
# Verify (g^p)^a == g^p'
# Check polynomial restriction:
res1 = pairing.ate_pairing(encrypted_shift_poly, g2)
res2 = pairing.ate_pairing(encrypted_poly, alpha_crs)
assert res1 == res2
#assert encrypted_poly * a == encrypted_shift_poly

View File

@@ -0,0 +1,146 @@
from bls_py import bls12381
from bls_py import pairing
from bls_py import ec
from bls_py.fields import Fq, Fq2, Fq6, Fq12, bls12381_q as Q
import random
import numpy as np
# Section 3.6 from "Why and How zk-SNARK Works"
def rand_scalar():
return random.randrange(1, bls12381.q)
#x = rand_scalar()
#y = ec.y_for_x(x)
g1 = ec.generator_Fq(bls12381)
g2 = ec.generator_Fq2(bls12381)
null = ec.AffinePoint(Fq(Q, 0), Fq(Q, 1), True, bls12381)
assert g1 + null == g1
null2 = ec.AffinePoint(Fq2.zero(Q), Fq2.zero(Q), True, bls12381)
assert null2 + g2 == g2
#################################
# Verifier (trusted setup)
#################################
# samples a random value (a secret)
s = rand_scalar()
# calculate the shift
a = rand_scalar()
# calculates encryptions of s for all powers i in 0 to d
# E(s^i) = g^s^i
d = 10
encrypted_powers = [
g1 * (s**i) for i in range(d)
]
encrypted_powers_g2 = [
g2 * (s**i) for i in range(d)
]
encrypted_shifted_powers = [
g1 * (a * s**i) for i in range(d)
]
encrypted_shifted_powers_g2 = [
g2 * (a * s**i) for i in range(d)
]
# evaluates unencrypted target polynomial with s: t(s)
target = (s - 1)
# CRS = common reference string = trusted setup parameters
target_crs = g1 * target
alpha_crs = g2 * a
alpha_crs_g1 = g1 * a
# Proving key = (encrypted_powers, encrypted_shifted_powers)
# Verify key = (target_crs, alpha_crs)
# encrypted values of s provided to the prover
# Actual values of s are toxic waste and discarded
#################################
# Prover
#################################
left_poly = np.poly1d([3])
right_poly = np.poly1d([2])
out_poly = np.poly1d([6])
# x^3 - 3x^2 + 2x
main_poly = left_poly * right_poly - out_poly
# (x - 1)
target_poly = np.poly1d([1, -1])
# Calculates polynomial h(x) = p(x) / t(x)
cofactor, remainder = main_poly / target_poly
assert remainder == np.poly1d([0])
# Using encrypted powers and coefficients, evaluates
# E(p(s)) and E(h(s))
def evaluate(poly, encrypted_powers, identity):
coeffs = list(poly.coef)[::-1]
result = identity
for power, coeff in zip(encrypted_powers, coeffs):
#print(coeff, power)
coeff = int(coeff)
# I have to do this for some strange reason
# Because if coeff is negative and I do += power * coeff
# then it gives me a different result than what I expect
if coeff < 0:
result -= power * (-coeff)
else:
result += power * coeff
return result
assert left_poly * right_poly == out_poly
encrypted_left_poly = evaluate(left_poly, encrypted_powers, null)
encrypted_right_poly = evaluate(right_poly, encrypted_powers_g2, null2)
encrypted_out_poly = evaluate(out_poly, encrypted_powers, null)
#assert encrypted_poly == e_p_s
encrypted_cofactor = evaluate(cofactor, encrypted_powers_g2, null2)
# Alpha shifted powers
encrypted_shift_left_poly = evaluate(left_poly, encrypted_shifted_powers, null)
encrypted_shift_right_poly = evaluate(right_poly, encrypted_shifted_powers_g2, null2)
encrypted_shift_out_poly = evaluate(out_poly, encrypted_shifted_powers, null)
# resulting g^p and g^h are provided to the verifier
# proof = (encrypted_poly, encrypted_cofactor, encrypted_shift_poly)
#################################
# Verifier
#################################
# Last check that p = t(s) h
assert pairing.ate_pairing(2 * g1, g2) == pairing.ate_pairing(g1, g2) * pairing.ate_pairing(g1, g2)
# Verify (g^p)^a == g^p'
# Check polynomial restriction:
def check_polynomial_restriction(encrypted_shift_poly, encrypted_poly):
res1 = pairing.ate_pairing(encrypted_shift_poly, g2)
res2 = pairing.ate_pairing(encrypted_poly, alpha_crs)
assert res1 == res2
def check_polynomial_restriction_swapped(encrypted_shift_poly, encrypted_poly):
res1 = pairing.ate_pairing(g1, encrypted_shift_poly)
res2 = pairing.ate_pairing(alpha_crs_g1, encrypted_poly)
assert res1 == res2
check_polynomial_restriction(encrypted_shift_left_poly, encrypted_left_poly)
check_polynomial_restriction_swapped(encrypted_shift_right_poly, encrypted_right_poly)
check_polynomial_restriction(encrypted_shift_out_poly, encrypted_out_poly)
# Valid operation check
# e(g^l, g^r) == e(g^t, g^h) * e(g^o, g)
res1 = pairing.ate_pairing(encrypted_left_poly, encrypted_right_poly)
res2 = pairing.ate_pairing(target_crs, encrypted_cofactor) * \
pairing.ate_pairing(encrypted_out_poly, g2)
assert res1 == res2

View File

@@ -0,0 +1,30 @@
import numpy as np
def lagrange(points):
result = np.poly1d([0])
for i, (x_i, y_i) in enumerate(points):
poly = np.poly1d([y_i])
for j, (x_j, y_j) in enumerate(points):
if i == j:
continue
poly *= np.poly1d([1, -x_j]) / (x_i - x_j)
#print(poly)
#print(poly(1), poly(2), poly(3))
result += poly
return result
left = lagrange([
(1, 2), (2, 2), (3, 6)
])
print(left)
right = lagrange([
(1, 1), (2, 3), (3, 2)
])
print(right)
out = lagrange([
(1, 2), (2, 6), (3, 12)
])
print(out)

View File

@@ -0,0 +1,167 @@
from bls_py import bls12381
from bls_py import pairing
from bls_py import ec
from bls_py.fields import Fq, Fq2, Fq6, Fq12, bls12381_q as Q
from finite_fields.modp import IntegersModP
from finite_fields.polynomial import polynomialsOver
import random
n = bls12381.n
g1 = ec.generator_Fq(bls12381)
g2 = ec.generator_Fq2(bls12381)
null = ec.AffinePoint(Fq(n, 0), Fq(n, 1), True, bls12381)
assert null + g1 == g1
null2 = ec.AffinePoint(Fq2.zero(n), Fq2.zero(n), True, bls12381)
assert null2 + g2 == g2
mod_field = IntegersModP(n)
poly = polynomialsOver(mod_field).factory
def lagrange(points):
result = poly([0])
for i, (x_i, y_i) in enumerate(points):
p = poly([y_i])
for j, (x_j, y_j) in enumerate(points):
if i == j:
continue
p *= poly([-x_j, 1]) / (x_i - x_j)
#print(poly)
#print(poly(1), poly(2), poly(3))
result += p
return result
def poly_call(poly, x):
result = mod_field(0)
for degree, coeff in enumerate(poly):
result += coeff * (x**degree)
return result.n
left_points = [
(1, 2), (2, 2), (3, 6)
]
left_poly = lagrange(left_points)
#l = poly([2]) * poly([1, -1])
print("Left:")
print(left_poly)
for x, y in left_points:
assert poly_call(left_poly, x) == y
right_points = [
(1, 1), (2, 3), (3, 2)
]
right_poly = lagrange(right_points)
print("Right:")
print(right_poly)
for x, y in right_points:
assert poly_call(right_poly, x) == y
out_points = [
(1, 2), (2, 6), (3, 12)
]
out_poly = lagrange(out_points)
print("Out:")
print(out_poly)
for x, y in out_points:
assert poly_call(out_poly, x) == y
target_poly = poly([-1, 1]) * poly([-2, 1]) * poly([-3, 1])
assert poly_call(target_poly, 1) == 0
assert poly_call(target_poly, 2) == 0
assert poly_call(target_poly, 3) == 0
main_poly = left_poly * right_poly - out_poly
cofactor_poly = main_poly / target_poly
assert left_poly * right_poly - out_poly == target_poly * cofactor_poly
def rand_scalar():
return random.randrange(1, bls12381.q)
#################################
# Verifier (trusted setup)
#################################
# samples a random value (a secret)
toxic_scalar = rand_scalar()
# calculate the shift
alpha_shift = rand_scalar()
# calculates encryptions of s for all powers i in 0 to d
# E(s^i) = g^s^i
degree = 10
enc_s1 = [
g1 * (toxic_scalar**i) for i in range(degree)
]
enc_s2 = [
g2 * (toxic_scalar**i) for i in range(degree)
]
enc_s1_shift = [
g1 * (alpha_shift * toxic_scalar**i) for i in range(degree)
]
enc_s2_shift = [
g2 * (alpha_shift * toxic_scalar**i) for i in range(degree)
]
# evaluates unencrypted target polynomial with s: t(s)
toxic_target = (toxic_scalar - 1) * (toxic_scalar - 2) * (toxic_scalar - 3)
# CRS = common reference string = trusted setup parameters
target_crs = g1 * toxic_target
alpha_crs = g2 * alpha_shift
alpha_crs_g1 = g1 * alpha_shift
# Proving key = (encrypted_powers, encrypted_shifted_powers)
# Verify key = (target_crs, alpha_crs)
# encrypted values of s provided to the prover
# Actual values of s are toxic waste and discarded
#################################
# Prover
#################################
# Using encrypted powers and coefficients, evaluates
# E(p(s)) and E(h(s))
def evaluate(poly, encrypted_powers, identity):
result = identity
for power, coeff in zip(encrypted_powers, poly):
result += power * coeff.n
return result
enc_left = evaluate(left_poly, enc_s1, null)
enc_right = evaluate(right_poly, enc_s2, null2)
enc_out = evaluate(out_poly, enc_s1, null)
enc_cofactor = evaluate(cofactor_poly, enc_s2, null2)
# Alpha shifted powers
enc_left_shift = evaluate(left_poly, enc_s1_shift, null)
enc_right_shift = evaluate(right_poly, enc_s2_shift, null2)
enc_out_shift = evaluate(out_poly, enc_s1_shift, null)
#################################
# Verifier
#################################
def restrict_polynomial_g1(encrypted_shift_poly, encrypted_poly):
res1 = pairing.ate_pairing(encrypted_shift_poly, g2)
res2 = pairing.ate_pairing(encrypted_poly, alpha_crs)
assert res1 == res2
def restrict_polynomial_g2(encrypted_shift_poly, encrypted_poly):
res1 = pairing.ate_pairing(g1, encrypted_shift_poly)
res2 = pairing.ate_pairing(alpha_crs_g1, encrypted_poly)
assert res1 == res2
restrict_polynomial_g1(enc_left_shift, enc_left)
restrict_polynomial_g2(enc_right_shift, enc_right)
restrict_polynomial_g1(enc_out_shift, enc_out)
# Valid operation check
# e(g^l, g^r) == e(g^t, g^h) * e(g^o, g)
res1 = pairing.ate_pairing(enc_left, enc_right)
res2 = pairing.ate_pairing(target_crs, enc_cofactor) * \
pairing.ate_pairing(enc_out, g2)
assert res1 == res2

View File

@@ -0,0 +1,111 @@
from bls_py import bls12381
from bls_py import pairing
from bls_py import ec
from bls_py.fields import Fq, Fq2, Fq6, Fq12, bls12381_q as Q
from finite_fields.modp import IntegersModP
from finite_fields.polynomial import polynomialsOver
import random
n = bls12381.n
g1 = ec.generator_Fq(bls12381)
g2 = ec.generator_Fq2(bls12381)
mod_field = IntegersModP(n)
poly = polynomialsOver(mod_field).factory
def lagrange(points):
result = poly([0])
for i, (x_i, y_i) in enumerate(points):
p = poly([y_i])
for j, (x_j, y_j) in enumerate(points):
if i == j:
continue
p *= poly([-x_j, 1]) / (x_i - x_j)
#print(poly)
#print(poly(1), poly(2), poly(3))
result += p
return result
l_a_points = [
(1, 1), (2, 1), (3, 0)
]
l_a = lagrange(l_a_points)
#print(l_a)
l_d_points = [
(1, 0), (2, 0), (3, 1)
]
l_d = lagrange(l_d_points)
#print(l_d)
# a x b = r_1
# a x c = r_2
# d x c = r_3
# a = 3
# d = 2
L = 3*l_a + 2*l_d
#print(L)
def poly_call(poly, x):
result = mod_field(0)
for degree, coeff in enumerate(poly):
result += coeff * (x**degree)
return result.n
assert poly_call(L, 1) == 3
assert poly_call(L, 2) == 3
assert poly_call(L, 3) == 2
def rand_scalar():
return random.randrange(1, bls12381.q)
#################################
# Verifier (trusted setup)
#################################
# samples a random value (a secret)
toxic_scalar = rand_scalar()
# calculate the shift
alpha_shift = rand_scalar()
l_a_s = poly_call(l_a, toxic_scalar)
l_d_s = poly_call(l_d, toxic_scalar)
enc_a_s = g1 * l_a_s
enc_a_s_alpha = enc_a_s * alpha_shift
enc_d_s = g1 * l_d_s
enc_d_s_alpha = enc_d_s * alpha_shift
# Proving key is enc_* values above
# Actual values of s are toxic waste and discarded
verify_key = g2 * alpha_shift
#################################
# Prover
#################################
a = 3
d = 2
assigned_a = enc_a_s * a
assigned_d = enc_d_s * d
assigned_a_shift = enc_a_s_alpha * a
assigned_d_shift = enc_d_s_alpha * d
operand = assigned_a + assigned_d
operand_shift = assigned_a_shift + assigned_d_shift
# proof = operand, operand_shift
#################################
# Verifier
#################################
e = pairing.ate_pairing
assert e(operand_shift, g2) == e(operand, verify_key)

View File

@@ -0,0 +1,139 @@
# Algorithm:
# if w { a * b } else { a + b }
# Equation:
# f(w, a, b) = w(ab) + (1 - w)(a + b) = v
# w(ab) + a + b - w(ab) = v
# w(ab - a - b) = v - a - b
# Constraints:
# 1: [1 a] [1 b] [1 m]
# 2: [1 w] [1 m, -1 a, -1 b] = [1 v, -1 a, -1 b]
# 3: [1 w] [1 w] [1 w]
# f(1, 4, 2) = 8
from bls_py import bls12381
from bls_py import pairing
from bls_py import ec
from bls_py.fields import Fq, Fq2, Fq6, Fq12, bls12381_q as Q
from finite_fields.modp import IntegersModP
from finite_fields.polynomial import polynomialsOver
import random
n = bls12381.n
g1 = ec.generator_Fq(bls12381)
g2 = ec.generator_Fq2(bls12381)
mod_field = IntegersModP(n)
poly = polynomialsOver(mod_field).factory
def lagrange(points):
result = poly([0])
for i, (x_i, y_i) in enumerate(points):
p = poly([y_i])
for j, (x_j, y_j) in enumerate(points):
if i == j:
continue
p *= poly([-x_j, 1]) / (x_i - x_j)
#print(poly)
#print(poly(1), poly(2), poly(3))
result += p
return result
left_variables = {
"a": lagrange([
(1, 1), (2, 0), (3, 0)
]),
"w": lagrange([
(1, 0), (2, 1), (3, 1)
])
}
right_variables = {
"m": lagrange([
(1, 0), (2, 1), (3, 0)
]),
"a": lagrange([
(1, 0), (2, -1), (3, 0)
]),
"b": lagrange([
(1, 1), (2, -1), (3, 0)
]),
"w": lagrange([
(1, 0), (2, 0), (3, 1)
]),
}
out_variables = {
"m": lagrange([
(1, 1), (2, 0), (3, 0)
]),
"v": lagrange([
(1, 0), (2, 1), (3, 0)
]),
"a": lagrange([
(1, 0), (2, -1), (3, 0)
]),
"b": lagrange([
(1, 0), (2, -1), (3, 0)
]),
"w": lagrange([
(1, 0), (2, 0), (3, 1)
]),
}
private_inputs = {
"w": 1,
"a": 3,
"b": 2
}
private_inputs["m"] = private_inputs["a"] * private_inputs["b"]
private_inputs["v"] = \
private_inputs["w"] * (
private_inputs["m"] - private_inputs["a"] - private_inputs["b"]) \
+ private_inputs["a"] + private_inputs["b"]
assert private_inputs["v"] == 6
left_variable_poly = (
private_inputs["a"] * left_variables["a"]
+ private_inputs["w"] * left_variables["w"]
)
right_variable_poly = (
private_inputs["m"] * right_variables["m"]
+ private_inputs["a"] * right_variables["a"]
+ private_inputs["b"] * right_variables["b"]
+ private_inputs["w"] * right_variables["w"]
)
out_variable_poly = (
private_inputs["m"] * out_variables["m"]
+ private_inputs["v"] * out_variables["v"]
+ private_inputs["a"] * out_variables["a"]
+ private_inputs["b"] * out_variables["b"]
+ private_inputs["w"] * out_variables["w"]
)
# (x - 1)(x - 2)(x - 3)
target_poly = poly([-1, 1]) * poly([-2, 1]) * poly([-3, 1])
def poly_call(poly, x):
result = mod_field(0)
for degree, coeff in enumerate(poly):
result += coeff * (x**degree)
return result.n
assert poly_call(target_poly, 1) == 0
assert poly_call(target_poly, 2) == 0
assert poly_call(target_poly, 3) == 0
main_poly = left_variable_poly * right_variable_poly - out_variable_poly
cofactor_poly = main_poly / target_poly
assert (
left_variable_poly * right_variable_poly == \
cofactor_poly * target_poly + out_variable_poly
)

179
scripts/zk/qap.py Normal file
View File

@@ -0,0 +1,179 @@
import numpy as np
# Lets prove we know the answer to x**3 + x + 5 == 35 (x = 5)
# We break it down into these statements:
# L1: s1 = x * x
# L2: y = s1 * x
# L3: s2 = y + x
# L4: out = s2 + 5
# Statements are of the form:
# a * b = c
# s1 = x * x
# OR a * b = c, where a = x, b = x and c = s1
L1 = np.array([
# a b c
[0, 0, 0], # 1
[1, 1, 0], # x
[0, 0, 0], # out
[0, 0, 1], # s1
[0, 0, 0], # y
[0, 0, 0] # s2
])
# y = s1 * x
L2 = np.array([
# a b c
[0, 0, 0], # 1
[0, 1, 0], # x
[0, 0, 0], # out
[1, 0, 0], # s1
[0, 0, 1], # y
[0, 0, 0] # s2
])
# s2 = y + x
L3 = np.array([
# a b c
[0, 1, 0], # 1
[1, 0, 0], # x
[0, 0, 0], # out
[0, 0, 0], # s1
[1, 0, 0], # y
[0, 0, 1] # s2
])
# out = s2 + 5
L4 = np.array([
# a b c
[5, 1, 0], # 1
[0, 0, 0], # x
[0, 0, 1], # out
[0, 0, 0], # s1
[0, 0, 0], # y
[1, 0, 0] # s2
])
a = np.array([L.transpose()[0] for L in (L1, L2, L3, L4)])
b = np.array([L.transpose()[1] for L in (L1, L2, L3, L4)])
c = np.array([L.transpose()[2] for L in (L1, L2, L3, L4)])
print("A")
print(a)
print("B")
print(b)
print("C")
print(c)
# The witness
s = np.array([
1,
3,
35,
9,
27,
30
])
print()
#print(s * a * s * b - s * c)
for a_i, b_i, c_i in zip(a, b, c):
assert sum(s * a_i) * sum(s * b_i) - sum(s * c_i) == 0
print("R1CS done.")
print()
def factorial(x):
r = 1
for x_i in range(2, x + 1):
r *= x_i
return r
def combinations(n, r):
return factorial(n) / (factorial(n - r) * factorial(r))
def lagrange(points):
result = np.poly1d([0])
for i, (x_i, y_i) in enumerate(points):
poly = np.poly1d([y_i])
for j, (x_j, y_j) in enumerate(points):
if i == j:
continue
poly *= np.poly1d([1, -x_j]) / (x_i - x_j)
#print(poly)
#print(poly(1), poly(2), poly(3))
result += poly
return result
# 1.5, -5.5, 7
#poly = lagrange([(1, 3), (2, 2), (3, 4)])
#print(poly)
def make_qap(a):
a_qap = []
a_polys = []
for a_i in a.transpose():
poly = lagrange(list(enumerate(a_i, start=1)))
coeffs = poly.c.tolist()
if len(coeffs) < 4:
coeffs = [0] * (4 - len(coeffs)) + coeffs
a_qap.append(coeffs)
a_polys.append(poly)
a_qap = np.array(a_qap)
print(a_qap)
return a_polys
print("A")
a_polys = make_qap(a)
print("B")
b_polys = make_qap(b)
print("C")
c_polys = make_qap(c)
def check(polys, x):
results = []
for poly in polys:
results.append(int(poly(x)))
return results
print()
print("A results at x", check(a_polys, 1))
print()
print("B results at x", check(b_polys, 1))
print()
print("C results at x", check(c_polys, 1))
def combine_polys(polys):
r = np.poly1d([0])
for s_i, p_i in zip(s, polys):
r += s_i * p_i
return r
print()
print()
A = combine_polys(a_polys)
print("A =")
print(A)
B = combine_polys(b_polys)
print("B =")
print(B)
C = combine_polys(c_polys)
print("C =")
print(C)
print()
t = A * B - C
print("t =")
print(t)
# 4 statements in our R1CS: L1, L2, L3, L4
divisor_poly = np.poly1d([1])
for x in range(1, 4 + 1):
divisor_poly *= np.poly1d([1, -x])
quot, remainder = np.polydiv(t, divisor_poly)
assert len(remainder.c) == 1
print()
print("Result of QAP:")
print(int(remainder.c[0]))