This commit is contained in:
narodnik
2020-09-20 18:38:26 +02:00
parent 5828bee338
commit 20424ce011
3 changed files with 481 additions and 1 deletions

View File

@@ -125,7 +125,9 @@ clone_bit cur_is_right2 cur_is_right
binary_push position cur_is_right2
alloc_num path_element param:auth_path_0_0
# let (left: Scalar, right: Scalar) = swap_if(is_right, cur, node)
conditionally_reverse ul ur cur path_element is_right
conditionally_reverse ulur cur path_element is_right
get_0 ul ulur
get_1 ur ulur
# let mut preimage: BinaryNumber = []
alloc_binary preimage
# preimage.put(left)

163
proofs/working.pism Normal file
View File

@@ -0,0 +1,163 @@
# :set syntax=pism
# :source ../scripts/pism.vim
constant G_VCV FixedGenerator
constant G_VCR FixedGenerator
constant G_SPEND FixedGenerator
constant G_PROOF FixedGenerator
constant G_NOTE_COMMIT_R FixedGenerator
constant G_NULL FixedGenerator
constant CRH_IVK BlakePersonalization
constant NOTE_COMMIT PedersenPersonalization
constant MERKLE_0 PedersenPersonalization
constant MERKLE_1 PedersenPersonalization
constant MERKLE_2 PedersenPersonalization
constant MERKLE_3 PedersenPersonalization
# ...
constant PRF_NF BlakePersonalization
constant JUBJUB_FR_CAPACITY ByteSize
contract input_spend
param value U64
param randomness Fr
param ak Point
param ar Fr
param nsk Fr
param g_d Point
param commitment_randomness Fr
param auth_path_0_0 Scalar
param auth_path_0_1 Bool
param auth_path_1_0 Scalar
param auth_path_1_1 Bool
# ...
param anchor Scalar
start
# let rk: Point = ak + ar * G_SPEND
witness ak param:ak
#assert_not_small_order ak
#fr_as_binary_le ar param:ar
#ec_mul_const ar ar G_SPEND
#ec_add rk ak ar
## emit rk
#emit_ec rk
#
## let nk: Point = nsk * G_PROOF
#fr_as_binary_le nsk param:nsk
#ec_mul_const nk nsk G_PROOF
#
## let mut ivk_preimage: BinaryNumber = []
#alloc_binary ivk_preimage
## ivk_preimage.put(ak)
#ec_repr repr_ak ak
#binary_extend ivk_preimage repr_ak
#
## let mut nf_preimage: BinaryNumber = []
#alloc_binary nf_preimage
#ec_repr repr_nk nk
#binary_clone repr_nk repr_nk2
## ivk_preimage.put(nk)
#binary_extend ivk_preimage repr_nk
## nf_preimage.put(nk)
#binary_extend ivk_preimage repr_nk2
#
## assert ivk_preimage.len() == 512
#static_assert_binary_size ivk_preimage 512
## assert nf_preimage.len() == 256
#static_assert_binary_size nf_preimage 256
#
## let mut ivk = blake2s(ivk_preimage, CRH_IVK)
#blake2s ivk ivk_preimage CRH_IVK
## ivk.truncate(JUBJUB_FR_CAPACITY)
#binary_truncate ivk JUBJUB_FR_CAPACITY
#
## let pk_d: Point = ivk * g_d
#witness g_d param:g_d
#assert_not_small_order g_d
#ec_mul pk_d ivk g_d
#
## let cv: Point = value * G_VCV + rcv * G_VCR
#u64_as_binary_le value_bits param:value
#ec_mul_const value value_bits G_VCV
#fr_as_binary_le rcv param:randomness
#ec_mul_const rcv rcv G_VCR
#ec_add cv value rcv
## emit cv
#emit_ec cv
#
## let mut note_contents: BinaryNumber = []
#alloc_binary note_contents
#
## note_contents.put(value)
#binary_extend note_contents value
## note_contents.put(g_d)
#ec_repr repr_g_d g_d
#binary_extend note_contents repr_g_d
## note_contents.put(p_k)
#ec_repr repr_p_k p_k
#binary_extend note_contents repr_p_k
## assert note_contents.len() == 64 + 256 + 256
#static_assert_binary_size ivk_preimage 576
#
## let mut cm = pedersen_hash(note_contents, NOTE_COMMIT)
#pedersen_hash cm note_contents NOTE_COMMIT
## cm += commitment_randomness * G_NOTE_COMMIT_R
#fr_as_binary_le rcm param:commitment_randomness
#ec_mul_const cm1 rcm G_NOTE_COMMIT_R
#ec_add cm cm cm1
#
## let mut position = []
#alloc_binary position
## let mut cur: Scalar = cm.u
#ec_get_u cur cm
#
## There are no loops in this language.
## ZK proofs must have a fixed size.
## So in this assembly we UNROLL all loops.
## for i in range(auth_path.size()):
##
## Here we give the example of loop 0.
## Replace the indexes with the value i
## Below line is auth_path[0].1
#
## let (node: Scalar, is_right: Bool) = auth_path[i]
## position.push(is_right)
#alloc_bit cur_is_right param:auth_path_0_1
#clone_bit cur_is_right2 cur_is_right
#binary_push position cur_is_right2
#alloc_num path_element param:auth_path_0_0
## let (left: Scalar, right: Scalar) = swap_if(is_right, cur, node)
#conditionally_reverse ulur cur path_element is_right
#get_0 ul ulur
#get_1 ur ulur
## let mut preimage: BinaryNumber = []
#alloc_binary preimage
## preimage.put(left)
#num_to_binary ul_bits ul
#binary_extend preimage ul_bits
## preimage.put(right)
#num_to_binary ur_bits ur
#binary_extend preimage ur_bits
## cur = pedersen_hash(MERKLE_TREE[i], preimage).u
#pedersen_hash curhash preimage MERKLE_0
#ec_get_u cur curhash
## ... repeat the above N times
#
## enforce cur == rt
#alloc_num rt param:anchor
#num_enforce_equal cur rt
## emit rt
#emit_num rt
#
## let rho: Point = rho + position * G_NULL
#ec_mul_const position position_bits G_NULL
#ec_add rho rho position
## nf_preimage.put(rho)
#ec_repr repr_rho rho
#binary_extend nf_preimage repr_rho
## assert nf_preimage.len() == 512
#static_assert_binary_size nf_preimage 512
#
## let nf: BinaryNumber = blake2s(nf_preimage, PRF_NF)
#blake2s nf nf_preimage PRF_NF
#emit_binary nf
end

315
scripts/pism.py Normal file
View File

@@ -0,0 +1,315 @@
import sys
def eprint(*args):
print(*args, file=sys.stderr)
class Line:
def __init__(self, text, line_number):
self.text = text
self.orig = text
self.lineno = line_number
self.clean()
def clean(self):
# Remove the comments
self.text = self.text.split("#", 1)[0]
# Remove whitespace
self.text = self.text.strip()
def is_empty(self):
return bool(self.text)
def __repr__(self):
return "Line %s: %s" % (self.lineno, self.orig)
def command(self):
if not self.is_empty():
return None
return self.text.split(" ")[0]
def args(self):
if not self.is_empty():
return None
return self.text.split(" ")[1:]
def clean(contents):
# Split input into lines
contents = contents.split("\n")
contents = [Line(line, i) for i, line in enumerate(contents)]
# Remove empty blank lines
contents = [line for line in contents if line.is_empty()]
return contents
def make_segments(contents):
constants = [line for line in contents if line.command() == "constant"]
segments = []
current_segment = []
for line in contents:
if line.command() == "contract":
current_segment = []
current_segment.append(line)
if line.command() == "end":
segments.append(current_segment)
current_segment = []
return constants, segments
def build_constants_table(constants):
table = {}
for line in constants:
args = line.args()
if len(args) != 2:
eprint("error: wrong number of args")
eprint(line)
return None
name, type = args
table[name] = type
return table
symbol_table = {
"contract": 1,
"param": 2,
"start": 0,
"end": 0,
"witness": 2,
}
def extract(segment):
assert segment
# Does it have a declaration?
if not segment[0].command() == "contract":
eprint("error: missing contract declaration")
eprint(segment[0])
return None
# Does it have an end?
if not segment[-1].command() == "end":
eprint("error: missing contract end")
eprint(segment[-1])
return None
# Does it have a start?
if not [line for line in segment if line.command() == "start"]:
eprint("error: missing contract start")
eprint(segment[0])
return None
for line in segment:
command, args = line.command(), line.args()
if symbol_table[command] != len(args):
eprint("error: wrong number of args for command '%s'" % command)
eprint(line)
return None
contract_name = segment[0].args()[0]
start_index = [index for index, line in enumerate(segment)
if line.command() == "start"]
if len(start_index) > 1:
eprint("error: multiple start statements in contract '%s'" %
contract_name)
for index in start_index:
eprint(segment[index])
eprint("Aborting.")
return None
assert len(start_index) == 1
start_index = start_index[0]
header = segment[1:start_index]
code = segment[start_index + 1:-1]
params = {}
for param_decl in header:
args = param_decl.args()
assert len(args) == 2
name, type = args
params[name] = type
program = []
for line in code:
command, args = line.command(), line.args()
program.append((command, args, line))
return Contract(contract_name, params, program)
def to_initial_caps(snake_str):
components = snake_str.split("_")
return "".join(x.title() for x in components)
types_map = {
"U64": "u64",
"Fr": "jubjub::Fr",
"Point": "jubjub::SubgroupPoint",
"Scalar": "bls12_381::Scalar",
"Bool": "bool"
}
command_desc = {
"witness": (("EdwardsPoint", True), ("Point", False))
}
class Contract:
def __init__(self, name, params, program):
self.name = name
self.params = params
self.program = program
def _compile_header(self):
code = "pub struct %s {\n" % to_initial_caps(self.name)
for param_name, param_type in self.params.items():
try:
mapped_type = types_map[param_type]
except KeyError:
return None
code += " pub %s: Option<%s>,\n" % (param_name, mapped_type)
code += "}\n"
return code
def _compile_body(self):
self.stack = {}
code = "\n"
#indent = " " * 8
for command, args, line in self.program:
if (code_text := self._compile_line(command, args, line)) is None:
return None
code += code_text + "\n"
return code
def _preprocess_args(self, args, line):
nargs = []
for arg in args:
if not arg.startswith("param:"):
nargs.append((arg, False))
continue
_, argname = arg.split(":", 1)
if argname not in self.params:
eprint("error: non-existant param referenced")
eprint(line)
return None
nargs.append((argname, True))
return nargs
def type_checking(self, command, args, line):
assert command in command_desc
type_list = command_desc[command]
if len(type_list) != len(args):
eprint("error: wrong number of arguments!")
eprint(line)
return False
for (expected_type, new_val), (argname, is_param) in \
zip(type_list, args):
# Only type check input arguments, not output values
if new_val:
continue
if is_param:
actual_type = self.params[argname]
else:
# Check the stack here
if argname not in self.stack:
eprint("error: cannot find value '%s' on the stack!" %
argname)
eprint(line)
return False
actual_type = self.stack[argname]
return True
def _compile_line(self, command, args, line):
if (args := self._preprocess_args(args, line)) is None:
return None
if not self.type_checking(command, args, line):
return None
self.modify_stack(command, args)
args = [self.carg(arg) for arg in args]
if command == "witness":
out, point = args
return \
r"""let %s = ecc::EdwardsPoint::witness(
cs.namespace(|| "%s"),
%s.map(jubjub::ExtendedPoint::from))?;""" % (out, line, point)
def carg(self, arg):
argname, is_param = arg
if is_param:
return "self.%s" % argname
return argname
def modify_stack(self, command, args):
type_list = command_desc[command]
assert len(type_list) == len(args)
for (expected_type, new_val), (argname, is_param) in \
zip(type_list, args):
if is_param:
assert not new_val
continue
# Now apply the new values to the stack
if new_val:
self.stack[argname] = expected_type
def compile(self):
code = ""
if (header := self._compile_header()) is None:
return None
code += header
code += \
r"""impl Circuit<bls12_381::Scalar> for %s {
fn synthesize<CS: ConstraintSystem<bls12_381::Scalar>>(
self,
cs: &mut CS,
) -> Result<(), SynthesisError> {
""" % to_initial_caps(self.name)
if (body := self._compile_body()) is None:
return None
code += body
code += "}\n"
return code
def process(contents):
contents = clean(contents)
constants, segments = make_segments(contents)
if (constants := build_constants_table(constants)) is None:
return False
codes = []
for segment in segments:
contract = extract(segment)
if (code := contract.compile()) is None:
return False
codes.append(code)
# Success! Output finished product.
[print(code) for code in codes]
return True
def main(argv):
if len(argv) != 2:
eprint("pism FILENAME")
return -1
contents = open(argv[1]).read()
if not process(contents):
return -2
return 0
if __name__ == "__main__":
sys.exit(main(sys.argv))