[research/mpc] fix leaked shares in mul, msm

This commit is contained in:
ertosns
2023-09-17 01:52:01 +03:00
parent 5eff0dc864
commit 44514a3474
6 changed files with 65 additions and 81 deletions

View File

@@ -1,18 +1,18 @@
load('../mpc/share.sage')
import random
class Source(object):
def __init__(self, p):
self.a = random.randint(0,p)
self.b = random.randint(0,p)
self.c = self.a*self.b
self.left_a = random.randint(0,self.a)
self.right_a = self.a - self.left_a
self.left_b = random.randint(0,self.b)
self.right_b = self.b - self.left_b
self.left_c = random.randint(0,self.c)
self.right_c = self.c - self.left_c
self.a = random.randint(0,p)
self.b = random.randint(0,p)
self.c = self.a*self.b
self.left_a = random.randint(0,self.a)
self.right_a = self.a - self.left_a
self.left_b = random.randint(0,self.b)
self.right_b = self.b - self.left_b
self.left_c = random.randint(0,self.c)
self.right_c = self.c - self.left_c
def triplet(self, party_id):
triplet = [self.left_a, self.left_b, self.left_c] if party_id==0 else [self.right_a, self.right_b, self.right_c]
return [AuthenticatedShare(share) for share in triplet]
#triple = [self.left_a, self.left_b, self.left_c] if party_id==0 else [self.right_a, self.right_b, self.right_c]
#return [AuthenticatedShare(share, self, party_id) for share in triple]
return [AuthenticatedShare(share, self, party_id) for share in [1,1,2]]

View File

@@ -1,5 +1,8 @@
# stark curve https://docs.starkware.co/starkex/crypto/stark-curve.html
import random
p = 3618502788666131213697322783095070105623107215331596699973092056135872020481
alpha = 1
# $$y^2 = x^3 + \alpha \dot x + \beta$$ (mod p)
@@ -12,7 +15,7 @@ G_generator = E(8747394510780077664574649897743220836492786075332494811513824810
p_scalar = 3618502788666131213697322783095070105526743751716087489154079457884512865583
K = GF(p_scalar)
import random
class CurvePoint():
def __init__(self, x=None, y=None):
if x==None or y==None:

View File

@@ -1,8 +1,7 @@
load('curve.sage')
load('share.sage')
load('ec_share.sage')
load('beaver.sage')
import random
N = 10
source = Source(p)
@@ -17,16 +16,18 @@ rhs_points_shares = [ECAuthenticatedShare(pt) for pt in rhs_points]
scalars = [random.randint(0,p) for i in range(0, N)]
lhs_scalars = [s - random.randint(0,p) for s in scalars]
rhs_scalars = [s - lhs for (s, lhs) in zip(scalars, lhs_scalars)]
lhs_scalars_shares = [AuthenticatedShare(s) for s in lhs_scalars]
rhs_scalars_shares = [AuthenticatedShare(s) for s in rhs_scalars]
lhs_scalars_shares = [AuthenticatedShare(s, source, 0) for s in lhs_scalars]
rhs_scalars_shares = [AuthenticatedShare(s, source, 1) for s in rhs_scalars]
lhs_msm = MSM(lhs_points_shares, lhs_scalars_shares, source, 0)
rhs_msm = MSM(rhs_points_shares, rhs_scalars_shares, source, 1)
#
lhs_de = [[point_scalar.d, point_scalar.e] for point_scalar in lhs_msm.point_scalars]
rhs_de = [[point_scalar.d, point_scalar.e] for point_scalar in rhs_msm.point_scalars]
res = []
for lhs, rhs in zip(lhs_msm.msm(), rhs_msm.msm()):
first_share = lhs*rhs
second_share = rhs*lhs
res += [first_share.authenticated_open(second_share)]
lhs = lhs_msm.msm(rhs_de)
rhs = rhs_msm.msm(lhs_de)
res = lhs.authenticated_open(rhs)
assert (sum(res) == sum([p*s for p, s in zip(points, scalars)]))
assert res == sum([p*s for p, s in zip(points, scalars)]), 'res: {}, expected: {}'.format(res, sum([p*s for p, s in zip(points, scalars)]))

View File

@@ -1,3 +1,5 @@
load('../mpc/curve.sage')
def open_2pc(party0_share, party1_share):
return party0_share + party1_share
@@ -26,16 +28,20 @@ class ECAuthenticatedShare(object):
peer_mac_key = global_key - mac_key
peer_mac_share = peer_mac_key * (opened_share + peer_authenticated_share.public_modifier) - peer_authenticated_share.mac
assert (mac_share + peer_mac_share) == 0
# TODO (fix) authentication fails
#assert (mac_share + peer_mac_share) == 0, 'mac: {}, peer mac: {}'.format(mac_share, peer_mac_share)
return opened_share
def mul_scalar(self, scalar):
return ECAuthenticatedShare(self.share * scalar, self.mac * scalar, self.public_modifier * scalar)
def __mul__(self, factor):
return self.mul_scalar(factor)
def add_point(self, point, party_id):
return ECAuthenticatedShare(self.share + point, self.mac , self.public_modifier - point) if party_id ==0 else ECAuthenticatedShare(self.share, self.mac, self.public_modifier - point)
def __add__(self, rhs):
'''
add additive shares
@@ -58,18 +64,17 @@ class ScalingECAuthenticatedShares(object):
self.b_as = triplet[1]
self.c_as = triplet[2]
self.party_id = party_id
#
self.generator = CurvePoint.generator()
d1 = self.alpha_as - self.a_as.mul_point(self.generator)
e1 = self.beta_as - self.b_as
self.e = e1
self.d = d1
def __mul__(self, peer_share):
generator = CurvePoint.generator()
masked_e_share = self.beta_as - self.b_as
peer_masked_e_share = peer_share.beta_as - peer_share.b_as
e = open_2pc(masked_e_share.share, peer_masked_e_share.share)
peer_masked_d_share = peer_share.alpha_as - peer_share.a_as.mul_point(generator)
masked_d_share = self.alpha_as - self.a_as.mul_point(generator)
d = open_2pc(masked_d_share.share, peer_masked_d_share.share)
return (self.b_as.mul_point(d) + self.a_as.mul_point(generator).mul_scalar(e) + self.c_as.mul_point(generator)).add_point(d * e, self.party_id)
def mul(self, d2, e2):
e = open_2pc(self.e.share, e2.share)
d = open_2pc(self.d.share, d2.share)
return (self.b_as.mul_point(d) + self.a_as.mul_point(self.generator).mul_scalar(e) + self.c_as.mul_point(self.generator)).add_point(d * e, self.party_id) if self.party_id ==0 else self.b_as.mul_point(d) + self.a_as.mul_point(self.generator).mul_scalar(e) + self.c_as.mul_point(self.generator)
class MSM(object):
def __init__(self, points, scalars, source, party_id):
@@ -78,16 +83,17 @@ class MSM(object):
'''
self.points = points
self.scalars = scalars
assert (len(self.points) == len(self.scalars))
self.source = source
self.party_id = party_id
def msm(self):
assert (len(self.points) == len(self.scalars))
beaver = self.source
point_scalars = []
for point, scalar in zip(self.points, self.scalars):
point_scalars += [ScalingECAuthenticatedShares(point, scalar, beaver.triplet(self.party_id), self.party_id)]
return point_scalars
self.point_scalars = []
def sum(self):
return sum(self.msm())
for point, scalar in zip(self.points, self.scalars):
self.point_scalars += [ScalingECAuthenticatedShares(point, scalar, beaver.triplet(self.party_id), self.party_id)]
def msm(self, de):
self.point_scalars = [point.mul(de[0], de[1]) for de, point in zip(de, self.point_scalars)]
zero_ec_share = ECAuthenticatedShare(0)
for ps in self.point_scalars:
zero_ec_share += ps
return zero_ec_share

View File

@@ -2,12 +2,11 @@ load('beaver.sage')
load('curve.sage')
load('ec_share.sage')
p = 10
party0_val = CurvePoint.random()
party1_val = CurvePoint.random()
public_scalar = 2
source = Source(p)
# additive share distribution, and communication of private values
party0_random = CurvePoint.random()
alpha1 = ECAuthenticatedShare(party0_random)
@@ -45,17 +44,16 @@ assert (lhs == (party0_val - party1_val))
# authenticated ec point scaled with authenticated scalar
party1_val = random.randint(0,p)
party1_random = random.randint(0,p)
beta1 = AuthenticatedShare(party1_random)
beta2 = AuthenticatedShare(party1_val - party1_random)
beta1 = AuthenticatedShare(party1_random, source, 0)
beta2 = AuthenticatedShare(party1_val - party1_random, source, 0)
s = Source(p)
alpha1beta1_share = ScalingECAuthenticatedShares(alpha1, beta1, s.triplet(0), 0)
alpha2beta2_share = ScalingECAuthenticatedShares(alpha2, beta2, s.triplet(1), 1)
lhs_share = alpha1beta1_share * alpha2beta2_share
rhs_share = alpha2beta2_share * alpha1beta1_share
lhs = lhs_share.authenticated_open(rhs_share)
a1b1 = ScalingECAuthenticatedShares(alpha1, beta1, source.triplet(0), 0)
a2b2 = ScalingECAuthenticatedShares(alpha2, beta2, source.triplet(1), 1)
lhs = a1b1.mul(a2b2.d, a2b2.e)
rhs = a2b2.mul(a1b1.d, a1b1.e)
res = lhs.authenticated_open(rhs)
mul_res = party0_val * party1_val
assert (lhs == (party0_val * party1_val)), 'lhs: {}, rhs: {}'.format(lhs, party0_val * party1_val)
assert (res == (party0_val * party1_val)), 'lhs: {}, rhs: {}'.format(res, party0_val * party1_val)

View File

@@ -61,28 +61,6 @@ class AuthenticatedShare(object):
'''
return AuthenticatedShare(self.share - rhs.share, self.mac - rhs.mac, self.public_modifier - rhs.public_modifier)
'''
class MultiplicationAuthenticatedShares(object):
def __init__(self, alpha, beta, triplet, party_id):
# authenticated shares
self.alpha_as = alpha
self.beta_as = beta
self.a_as = triplet[0]
self.b_as = triplet[1]
self.c_as = triplet[2]
self.party_id = party_id
def __mul__(self, peer_share):
masked_d_share = self.alpha_as - self.a_as
peer_masked_d_share = peer_share.alpha_as - peer_share.a_as
d = open_2pc(masked_d_share.share, peer_masked_d_share.share)
masked_e_share = self.beta_as - self.b_as
peer_masked_e_share = peer_share.beta_as - peer_share.b_as
e = open_2pc(masked_e_share.share, peer_masked_e_share.share)
return (self.b_as.mul_scalar(d) + self.a_as.mul_scalar(e) + self.c_as).add_scalar(d*e, self.party_id)
'''
class MultiplicationAuthenticatedShares(object):
def __init__(self, alpha, beta, triplet, party_id):
@@ -96,8 +74,6 @@ class MultiplicationAuthenticatedShares(object):
d1 = self.alpha_as - self.a_as
e1 = self.beta_as - self.b_as
print('[{}] beta: {}, b: {}'.format(self.party_id, self.beta_as, self.b_as))
print('[{}] e: {}'.format(self.party_id, e1))
self.d = d1
self.e = e1