[research/mpc] inner product over mpc, test msm

This commit is contained in:
ertosns
2023-09-21 16:14:48 +03:00
parent 44514a3474
commit d35942e566
2 changed files with 72 additions and 0 deletions

View File

@@ -0,0 +1,36 @@
load('share.sage')
load('beaver.sage')
import random
import numpy as np
party0_val = [1,2]
party1_val = [2,4]
source = Source(p)
party0_random = [1,1]
party1_random = [1,1]
# additive share distribution, and communication of private values
# party 0 shares
alpha1_l = [AuthenticatedShare(party0_random[i], source, 0) for i in range(2)]
beta1_l = [AuthenticatedShare(party1_random[i], source, 1) for i in range(2)]
# party 1 shares
alpha2_l = [AuthenticatedShare(party0_val[i] - party0_random[i], source, 0) for i in range(2)]
beta2_l = [AuthenticatedShare(party1_val[i] - party1_random[i], source, 1) for i in range(2)]
# party 0 c
a1b1_l = [MultiplicationAuthenticatedShares(alpha1, beta1, source.triplet(0), 0) for alpha1, beta1 in zip(alpha1_l, beta1_l)]
# party 1 c
a2b2_l = [MultiplicationAuthenticatedShares(alpha2, beta2, source.triplet(1), 1) for alpha2, beta2 in zip(alpha2_l, beta2_l)]
# party 0 de
for a1b1 in a1b1_l:
print('a1b1: d/e: {}/{}'.format(a1b1.d, a1b1.e))
# party 1 de
for a2b2 in a2b2_l:
print('a2b2: d/e: {}/{}'.format(a2b2.d, a2b2.e))
lhs_l = [a1b1.mul(a2b2.d, a2b2.e) for a1b1, a2b2 in zip(a1b1_l, a2b2_l)]
rhs_l = [a2b2.mul(a1b1.d, a1b1.e) for a1b1, a2b2 in zip(a1b1_l, a2b2_l)]
res = [lhs.authenticated_open(rhs) for lhs, rhs in zip(lhs_l, rhs_l)]
assert (sum(res) == np.dot(party0_val,party1_val)), 'mul: {}, expected mul: {}'.format(res, party0_val*party1_val)

View File

@@ -0,0 +1,36 @@
load('beaver.sage')
import random
N = 2
source = Source(p)
points = [CurvePoint.random() for _ in range(0, N)]
lhs_points = [pt - CurvePoint.random() for pt in points]
rhs_points = [point - lhs for (point, lhs) in zip(points, lhs_points)]
lhs_points_shares = [ECAuthenticatedShare(pt) for pt in lhs_points]
rhs_points_shares = [ECAuthenticatedShare(pt) for pt in rhs_points]
assert [lhs_pt_share.authenticated_open(rhs_pt_share) for lhs_pt_share, rhs_pt_share in zip(lhs_points_shares, rhs_points_shares)] == points
scalars = [random.randint(0,p) for i in range(0, N)]
lhs_scalars = [s - random.randint(0,p) for s in scalars]
rhs_scalars = [s - lhs for (s, lhs) in zip(scalars, lhs_scalars)]
lhs_scalars_shares = [AuthenticatedShare(s, source, 0) for s in lhs_scalars]
rhs_scalars_shares = [AuthenticatedShare(s, source, 1) for s in rhs_scalars]
assert [lhs_scalar_share.authenticated_open(rhs_scalar_share) for lhs_scalar_share, rhs_scalar_share in zip(lhs_scalars_shares, rhs_scalars_shares)] == scalars
lhs_msm = MSM(lhs_points_shares, lhs_scalars_shares, source, 0)
rhs_msm = MSM(rhs_points_shares, rhs_scalars_shares, source, 1)
#
lhs_de = [[point_scalar.d.copy(), point_scalar.e.copy()] for point_scalar in lhs_msm.point_scalars]
rhs_de = [[point_scalar.d.copy(), point_scalar.e.copy()] for point_scalar in rhs_msm.point_scalars]
res = []
lhs = lhs_msm.msm(rhs_de)
rhs = rhs_msm.msm(lhs_de)
result = sum([lhs_pt_scalar.authenticated_open(rhs_pt_scalar) for lhs_pt_scalar, rhs_pt_scalar in zip (lhs_msm.point_scalars , rhs_msm.point_scalars)])
res = lhs.authenticated_open(rhs)
assert result == res
assert result == sum([p*s for p, s in zip(points, scalars)]), 'res: {}, expected: {}'.format(res, sum([p*s for p, s in zip(points, scalars)]))