impl of curve trees

This commit is contained in:
x
2023-01-06 19:04:43 +01:00
parent 9558338545
commit e549746c87
2 changed files with 640 additions and 0 deletions

View File

@@ -0,0 +1,232 @@
import hashlib
from collections import namedtuple
# Your Funds Are Safu
p = [
0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001,
0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001
]
# Pallas, Vesta
K = [GF(p_i) for p_i in p]
E = [EllipticCurve(K_i, (0, 5)) for K_i in K]
# Scalar fields
Scalar = [K[1], K[0]]
base_G = [E_i.gens()[0] for E_i in E]
assert all(base_G_i.order() == p_i for p_i, base_G_i in zip(reversed(p), base_G))
E1, E2 = 0, 1
gens = [
[E[E1].random_point() for _ in range(5)],
[E[E2].random_point() for _ in range(5)],
]
def hash_nodes(Ei, P1, P2, r):
G1, G2, G3, G4, H = gens[Ei]
(P1_x, P1_y), (P2_x, P2_y) = P1.xy(), P2.xy()
v1G1 = int(P1_x) * G1
v2G2 = int(P1_y) * G2
v3G3 = int(P2_x) * G3
v4G4 = int(P2_y) * G4
rH = int(r) * H
return v1G1 + v2G2 + v3G3 + v4G4 + rH
def hash_point(Ei, P, b):
G1, G2, G3, G4, H = gens[Ei]
x, y = P.xy()
return int(x)*G1 + int(y)*G2 + int(b)*H
# You can ignore this particular impl.
# Just some rough code to illustrate the main concept.
# The proofs enforce these relations:
#
# σ ∈ {0, 1}
# C = x1 G1 + y1 G2 + x2 G3 + y2 G4 + rH
# Ĉ = x_i G1 + y_i G2 + bH
#
# where
#
# x_i = { x0 if σ = 0
# { x1 if σ = 1
#
# y_i = { y0 if σ = 0
# { y1 if σ = 1
#
# It is just a quick hackjob proof of concept and horribly inefficient
load("curve_tree_proofs.sage")
test_proof()
# Our tree is a height of D=3
def create_tree(C3):
assert len(C3) == 2**3
# j = 2
C2 = []
for i in range(4):
C2_i = hash_nodes(E2, C3[2*i], C3[2*i + 1], 0)
C2.append(C2_i)
# j = 1
C1 = []
for i in range(2):
C1_i = hash_nodes(E1, C2[2*i], C2[2*i + 1], 0)
C1.append(C1_i)
# j = 0 (root)
C0 = hash_nodes(E2, C1[0], C1[1], 0)
return C0
def create_path(C3):
# To make things easier, we assume that our coin is
# always on the left hand side of the tree.
X3 = C3[1]
X2 = hash_nodes(E2, C3[2], C3[3], 0)
X1 = hash_nodes(
E1,
hash_nodes(E2, C3[4], C3[5], 0),
hash_nodes(E2, C3[6], C3[7], 0),
0
)
return (X3, X2, X1)
def main():
coins = [E[E1].random_point() for _ in range(2**3)]
root = create_tree(coins)
path = create_path(coins)
# Test the path works
X3, X2, X1 = path
C3 = coins[0]
C2 = hash_nodes(
E2,
C3,
X3,
0
)
C1 = hash_nodes(
E1,
C2,
X2,
0
)
C0 = hash_nodes(
E2,
C1,
X1,
0
)
assert C0 == root
# E1 point
C3 = coins[0]
Ĉ0 = root
r0 = 0
# Same as this:
# Ĉ0 = hash_nodes(E1, C2, X2, 0)
# j = 1
b1 = int(Scalar[E2].random_element())
Ĉ1 = hash_point(E2, C1, b1)
C1_x, C1_y = C1.xy()
X1_x, X1_y = X1.xy()
proof1, public1 = make_proof(
E2,
ProofWitness(
C1_x,
C1_y,
X1_x,
X1_y,
r0,
b1,
0
)
)
public1.C = Ĉ0
public1.D = Ĉ1
assert verify_proof(E2, proof1, public1)
# j = 2
# Now we know that Ĉ1 is the root of a new subtree
# But Ĉ1 ∈ E2, whereas we need to produce a blinded
# Ĉ1 ∈ E1.
# The reason this system uses curve cycles is because
# EC arithmetic is efficient to represent.
# We skip this part so assume these next to lines are
# part of the previous proof.
r1 = int(Scalar[E1].random_element())
Ĉ1 = hash_nodes(E1, C2, X2, r1)
################################
b2 = int(Scalar[E1].random_element())
Ĉ2 = hash_point(E1, C2, b2)
C2_x, C2_y = C2.xy()
X2_x, X2_y = X2.xy()
proof2, public2 = make_proof(
E1,
ProofWitness(
C2_x,
C2_y,
X2_x,
X2_y,
r1,
b2,
0
)
)
public2.C = Ĉ1
public2.D = Ĉ2
assert verify_proof(E1, proof2, public2)
# j = 3
# Same as before. We now have a randomized C2
r2 = int(Scalar[E2].random_element())
Ĉ2 = hash_nodes(E2, C3, X3, r2)
#################################
b3 = int(Scalar[E2].random_element())
Ĉ3 = hash_point(E2, C3, b3)
C3_x, C3_y = C3.xy()
X3_x, X3_y = X3.xy()
proof3, public3 = make_proof(
E2,
ProofWitness(
C3_x,
C3_y,
X3_x,
X3_y,
r2,
b3,
0
)
)
public3.C = Ĉ2
public3.D = Ĉ3
assert verify_proof(E2, proof3, public3)
# Now just unblind Ĉ3
main()

View File

@@ -0,0 +1,408 @@
ProofWitness = namedtuple("ProofWitness", [
"v1", "v2", "v3", "v4", "r", "b", "σ"
])
class ProofPublic:
def __init__(self):
self.C = None
self.D = None
self.X = None
self.Y = None
self.Z = None
class ProofCommits:
def __init__(self):
self.v1 = None
self.v2 = None
self.v3 = None
self.v4 = None
self.r = None
self.b = None
self.σ = None
self.σ_G1 = None
self.σ_G2 = None
self.v3_G1 = None
self.v4_G2 = None
self.blind_x = None
self.blind_y = None
self.blind_z = None
# Used by inner product
self.C0 = None
self.C1 = None
def transcript(self):
points = [
self.v1, self.v2, self.v3, self.v4, self.r, self.b, #self.σ,
self.σ_G1, self.σ_G2, self.v3_G1, self.v4_G2,
self.blind_x, self.blind_y, self.blind_z, self.C0, self.C1
]
assert all(P is not None for P in points)
points = [P.xy() for P in points]
return list(zip(*points))
class ProofResponses:
def __init__(self):
self.v1 = None
self.v2 = None
self.v3 = None
self.v4 = None
self.r = None
self.b = None
self.σ = None
self.blind_x = None
self.blind_y = None
self.blind_z = None
self.txy = None
Proof = namedtuple("Proof", [
"R", "s", "boolean_check"
])
RingProof = namedtuple("RingProof", [
"c0", "s0", "s1"
])
def make_proof(Ei, witness):
G1, G2, G3, G4, H = gens[Ei]
S = Scalar[Ei]
blind_x = int(S.random_element())
blind_y = int(S.random_element())
blind_xy = int(S.random_element())
# We want that blind_xy + blind_z == witness.b
blind_z = int(S(witness.b - blind_xy))
k_v1 = int(S.random_element())
k_v2 = int(S.random_element())
k_v3 = int(S.random_element())
k_v4 = int(S.random_element())
k_r = int(S.random_element())
k_b = int(S.random_element())
k_σ = int(S.random_element())
k_blind_x = int(S.random_element())
k_blind_y = int(S.random_element())
k_blind_xy = int(S.random_element())
k_blind_z = int(S.random_element())
# Used for inner product
k_t0 = int(S.random_element())
k_t1 = int(S.random_element())
R = ProofCommits()
R.v1 = k_v1 * G1
R.v2 = k_v2 * G2
R.v3 = k_v3 * G3
R.v4 = k_v4 * G4
R.r = k_r * H
R.b = k_b * H
# Used for 2nd proof
R.σ_G1 = k_σ * G1
R.σ_G2 = k_σ * G2
R.v3_G1 = k_v3 * G1
R.v4_G2 = k_v4 * G2
R.blind_x = k_blind_x * H
R.blind_y = k_blind_y * H
R.blind_xy = k_blind_xy * H
R.blind_z = k_blind_z * H
# σ (v1 - v3)
# σ (v2 - v4)
# (k_σ + c σ)(k_v1 - k_v3 + c*(v1 - v3))
#
# sage: var("k_σ c σ k_v1 k_v3 v1 v3")
# sage: ((k_σ + c*σ)*(k_v1 - k_v3 + c*(v1 - v3))).expand().collect(c)
# (v1*σ - v3*σ)*c^2 + (k_σ*v1 - k_σ*v3 + k_v1*σ - k_v3*σ)*c + k_v1*k_σ - k_v3*k_σ
R.C0 = (
(k_v3*k_σ - k_v1*k_σ) * G1 +
(k_v4*k_σ - k_v2*k_σ) * G2 +
k_t0 * H
)
R.C1 = (
(k_σ*witness.v3 - k_σ*witness.v1 + k_v3*witness.σ - k_v1*witness.σ) * G1 +
(k_σ*witness.v4 - k_σ*witness.v2 + k_v4*witness.σ - k_v2*witness.σ) * G2 +
k_t1 * H
)
c = hash_scalar(Ei, R.transcript())
s = ProofResponses()
s.v1 = int( k_v1 + c*witness.v1 )
s.v2 = int( k_v2 + c*witness.v2 )
s.v3 = int( k_v3 + c*witness.v3 )
s.v4 = int( k_v4 + c*witness.v4 )
s.r = int( k_r + c*witness.r )
s.b = int( k_b + c*witness.b )
s.σ = int( k_σ + c*witness.σ )
s.blind_x = int(k_blind_x + c*blind_x)
s.blind_y = int(k_blind_y + c*blind_y)
s.blind_xy = int(k_blind_xy + c*blind_xy)
s.blind_z = int(k_blind_z + c*blind_z)
s.txy = c**2 * blind_xy + c * k_t1 + k_t0
public = ProofPublic()
public.X = ((witness.v3 - witness.v1) * G1 +
(witness.v4 - witness.v2) * G2 +
blind_x * H)
public.Y = witness.σ * G1 + witness.σ * G2 + blind_y * H
public.XY = (
witness.σ * (witness.v3 - witness.v1) * G1 +
witness.σ * (witness.v4 - witness.v2) * G2 +
blind_xy * H
)
public.Z = witness.v1 * G1 + witness.v2 * G2 + blind_z * H
assert witness.σ in (0, 1)
if witness.σ == 0:
assert public.XY == blind_xy * H
assert (
public.XY + public.Z
==
witness.v1*G1 + witness.v2*G2 + (blind_xy + blind_z)*H
)
else:
assert witness.σ == 1
assert (
public.XY
==
(witness.v3 - witness.v1) * G1 +
(witness.v4 - witness.v2) * G2 +
blind_xy * H
)
assert (
public.XY + public.Z
==
witness.v3*G1 + witness.v4*G2 + (blind_xy + blind_z)*H
)
assert blind_xy + blind_z == witness.b
P1 = public.Y
P2 = public.Y - G1 - G2
assert blind_y*H == [P1, P2][witness.σ]
if witness.σ == 0:
assert blind_y*H == P1
assert blind_y*H - G1 - G2 == P2
else:
assert witness.σ == 1
assert blind_y*H + G1 + G2 == P1
assert blind_y*H == P2
boolean_check = make_ring_sig(Ei, [P1, P2], blind_y, int(witness.σ))
assert verify_ring_sig(Ei, boolean_check, [P1, P2])
return Proof(R, s, boolean_check), public
def make_ring_sig(Ei, public_keys, secret, j):
H, S = gens[Ei][-1], Scalar[Ei]
assert len(public_keys) == 2
assert secret*H == public_keys[j]
assert j in (0, 1)
k0 = int(S.random_element())
R0 = k0*H
c1 = hash_scalar(Ei, R0.xy())
s1 = int(S.random_element())
R1 = s1*H - c1*public_keys[(j + 1) % 2]
c0 = hash_scalar(Ei, R1.xy())
s0 = k0 + c0*secret
if j == 1:
c0 = c1
s0, s1 = s1, s0
proof = RingProof(c0, s0, s1)
return proof
def verify_ring_sig(Ei, proof, public_keys):
H = gens[Ei][-1]
S = Scalar[Ei]
assert len(public_keys) == 2
R1 = proof.s0*H - proof.c0*public_keys[0]
c1 = hash_scalar(Ei, R1.xy())
R2 = proof.s1*H - c1*public_keys[1]
c2 = hash_scalar(Ei, R2.xy())
return c2 == proof.c0
def verify_proof(Ei, proof, public):
G1, G2, G3, G4, H = gens[Ei]
S = Scalar[Ei]
R, s = proof.R, proof.s
c = hash_scalar(Ei, R.transcript())
if (s.v1 * G1 +
s.v2 * G2 +
s.v3 * G3 +
s.v4 * G4 +
s.r * H
!=
R.v1 + R.v2 + R.v3 + R.v4 + R.r + c*public.C
):
return False
# Now we want to prove that
# D = v1 G1 + v2 G2 + b H
# or
# D = v3 G1 + v4 G2 + b H
# We do this by checking:
# X = (v1 - v2)G + b_X H
# Y = σ G + b_Y H
# D = xy G + v2 G + b_D H
# σ ∈ {0, 1}
# X = (v1 - v2)G + b_X H
if (s.v3 * G1 - s.v1 * G1 +
s.v4 * G2 - s.v2 * G2 +
s.blind_x * H
!=
R.v3_G1 - R.v1 + R.v4_G2 - R.v2 + R.blind_x + c*public.X
):
return False
# Y = σ G + b_Y H
if (s.σ * G1 + s.σ * G2 + s.blind_y * H
!=
R.σ_G1 + R.σ_G2 + R.blind_y + c*public.Y
):
return False
# Z = v1 G1 + v2 G2
if (s.v1 * G1 + s.v2 * G2 + s.blind_z * H
!=
R.v1 + R.v2 + R.blind_z + c*public.Z
):
return False
# Inner product verification. We select either P1 or P2
# prove D1 = x1 y1 G1 + b1 H
if (s.σ*(s.v3 - s.v1)*G1 + s.σ*(s.v4 - s.v2)*G2 + s.txy*H
!=
c**2*public.XY + c*R.C1 + R.C0
):
return False
# check D is correct
if public.D != public.XY + public.Z:
return False
# boolean check proof for s
P1 = public.Y
P2 = public.Y - G1 - G2
if not verify_ring_sig(Ei, proof.boolean_check, [P1, P2]):
return False
return True
def hash_scalar(Ei, values):
S = Scalar[Ei]
hasher = hashlib.sha256()
for value in values:
hasher.update(str(value).encode())
return S(int(hasher.hexdigest(), 16))
# Test proving system
def test_proof():
G1, G2, G3, G4, H = gens[E1]
S = Scalar[E1]
P1, P2 = [E[E2].random_point() for _ in range(2)]
(P1_x, P1_y), (P2_x, P2_y) = P1.xy(), P2.xy()
r, b = [S.random_element() for _ in range(2)]
C = hash_nodes(E1, P1, P2, r)
# σ = 0 for P1, or σ = 1 for P2
σ = S(1)
D = hash_point(E1, P2, b)
proof, public = make_proof(
E1,
ProofWitness(
P1_x,
P1_y,
P2_x,
P2_y,
r,
b,
σ
)
)
public.C = C
public.D = D
assert verify_proof(E1, proof, public)
# Now try the other side too
σ = S(0)
D = hash_point(E1, P1, b)
proof, public = make_proof(
E1,
ProofWitness(
P1_x,
P1_y,
P2_x,
P2_y,
r,
b,
σ
)
)
public.C = C
public.D = D
assert verify_proof(E1, proof, public)
# Test the ring sigs too
secret = int(S.random_element())
P1 = secret*H
P2 = E[E1].random_point()
proof = make_ring_sig(E1, [P1, P2], secret, 0)
assert verify_ring_sig(E1, proof, [P1, P2])
# Also try in reverse
P1, P2 = P2, P1
proof = make_ring_sig(E1, [P1, P2], secret, 1)
assert verify_ring_sig(E1, proof, [P1, P2])
# Ring sigs is our boolean proof for σ
σ = S(0)
b = S.random_element()
# Verifier only has P
# We prove that σ ∈ {0, 1}
P = σ*G1 + b*H
# They can only make a ring signature on H
# if σ is 0 or 1
# P1 = P represents σ = 0
P1 = P
# P2 = P - G1 represents σ = 1
P2 = P - G1
proof = make_ring_sig(E1, [P1, P2], b, 0)
assert verify_ring_sig(E1, proof, [P1, P2])
# Also try σ = 1
σ = S(1)
P = σ*G1 + b*H
P1 = P
P2 = P - G1
proof = make_ring_sig(E1, [P1, P2], b, 1)
assert verify_ring_sig(E1, proof, [P1, P2])