Decision tree training.

This commit is contained in:
Marcel Keller
2022-11-09 11:21:34 +11:00
parent 90336561f8
commit cd25c2e9f1
187 changed files with 2357 additions and 329 deletions

View File

@@ -235,6 +235,9 @@ public:
template <class T>
static void ands(T& processor, const vector<int>& args) { processor.ands(args); }
template <class T>
static void andrsvec(T& processor, const vector<int>& args)
{ processor.andrsvec(args); }
template <class T>
static void xors(T& processor, const vector<int>& args) { processor.xors(args); }
template <class T>
static void inputb(T& processor, const vector<int>& args) { processor.input(args); }

View File

@@ -1,5 +1,16 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.3.4 (Nov 9, 2022)
- Decision tree learning
- Optimized oblivious shuffle in Rep3
- Optimized daBit generation in Rep3 and semi-honest HE-based 2PC
- Optimized element-vector AND in SemiBin
- Optimized input protocol in Shamir-based protocols
- Square-root ORAM (@Quitlox)
- Improved ORAM in binary circuits
- UTF-8 outputs
## 0.3.3 (Aug 25, 2022)
- Use SoftSpokenOT to avoid unclear security of KOS OT extension candidate

4
CONFIG
View File

@@ -67,8 +67,11 @@ endif
# MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5
LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS)
LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib
LDLIBS += -lboost_system -lssl -lcrypto
CFLAGS += -I./local/include
ifeq ($(USE_NTL),1)
CFLAGS += -DUSE_NTL
LDLIBS := -lntl $(LDLIBS)
@@ -100,5 +103,4 @@ ifeq ($(USE_KOS),1)
CFLAGS += -DUSE_KOS
else
CFLAGS += -std=c++17
LDLIBS += -llibOTe -lcryptoTools
endif

View File

@@ -13,6 +13,7 @@ import Compiler.instructions as spdz
import Compiler.tools as tools
import collections
import itertools
import math
class SecretBitsAF(base.RegisterArgFormat):
reg_type = 'sb'
@@ -50,6 +51,7 @@ opcodes = dict(
INPUTBVEC = 0x247,
SPLIT = 0x248,
CONVCBIT2S = 0x249,
ANDRSVEC = 0x24a,
XORCBI = 0x210,
BITDECC = 0x211,
NOTCB = 0x212,
@@ -155,6 +157,52 @@ class andrs(BinaryVectorInstruction):
def add_usage(self, req_node):
req_node.increment(('bit', 'triple'), sum(self.args[::4]))
req_node.increment(('bit', 'mixed'),
sum(int(math.ceil(x / 64)) for x in self.args[::4]))
class andrsvec(base.VarArgsInstruction, base.Mergeable,
base.DynFormatInstruction):
""" Constant-vector AND of secret bit registers (vectorized version).
:param: total number of arguments to follow (int)
:param: number of arguments to follow for one operation /
operation vector size plus three (int)
:param: vector size (int)
:param: result vector (sbit)
:param: (repeat)...
:param: constant operand (sbits)
:param: vector operand
:param: (repeat)...
:param: (repeat from number of arguments to follow for one operation)...
"""
code = opcodes['ANDRSVEC']
def __init__(self, *args, **kwargs):
super(andrsvec, self).__init__(*args, **kwargs)
for i, n in self.bases(iter(self.args)):
size = self.args[i + 1]
for x in self.args[i + 2:i + n]:
assert x.n == size
@classmethod
def dynamic_arg_format(cls, args):
yield 'int'
for i, n in cls.bases(args):
yield 'int'
n_args = (n - 3) // 2
assert n_args > 0
for j in range(n_args):
yield 'sbw'
for j in range(n_args + 1):
yield 'sb'
yield 'int'
def add_usage(self, req_node):
for i, n in self.bases(iter(self.args)):
size = self.args[i + 1]
req_node.increment(('bit', 'triple'), size * (n - 3) // 2)
req_node.increment(('bit', 'mixed'), size)
class ands(BinaryVectorInstruction):
""" Bitwise AND of secret bit register vector.
@@ -605,6 +653,7 @@ class inputbvec(base.DoNotEliminateInstruction, base.VarArgsInstruction,
for i, n in cls.bases(args):
yield 'int'
yield 'p'
assert n > 3
for j in range(n - 3):
yield 'sbw'
yield 'int'

View File

@@ -652,7 +652,7 @@ class sbitvec(_vec, _bit):
You can access the rows by member :py:obj:`v` and the columns by calling
:py:obj:`elements`.
There are three ways to create an instance:
There are four ways to create an instance:
1. By transposition::
@@ -685,6 +685,11 @@ class sbitvec(_vec, _bit):
This should output::
[1, 0, 1]
4. Private input::
x = sbitvec.get_type(32).get_input_from(player)
"""
bit_extend = staticmethod(lambda v, n: v[:n] + [0] * (n - len(v)))
is_clear = False
@@ -904,6 +909,34 @@ class sbitvec(_vec, _bit):
def __mul__(self, other):
if isinstance(other, int):
return self.from_vec(x * other for x in self.v)
if isinstance(other, sbitvec):
if len(other.v) == 1:
other = other.v[0]
elif len(self.v) == 1:
self, other = other, self.v[0]
else:
raise CompilerError('no operand of lenght 1: %d/%d',
(len(self.v), len(other.v)))
if not isinstance(other, sbits):
return NotImplemented
ops = []
for x in self.v:
if not util.is_zero(x):
assert x.n == other.n
ops.append(x)
if ops:
prods = [sbits.get_type(other.n)() for i in ops]
inst.andrsvec(3 + 2 * len(ops), other.n, *prods, other, *ops)
res = []
i = 0
for x in self.v:
if util.is_zero(x):
res.append(0)
else:
res.append(prods[i])
i += 1
return sbitvec.from_vec(res)
__rmul__ = __mul__
def __add__(self, other):
return self.from_vec(x + y for x, y in zip(self.v, other))
def bit_and(self, other):
@@ -945,6 +978,13 @@ class sbitvec(_vec, _bit):
else:
res.append([x.expand(m) if (expand and isinstance(x, bits)) else x for x in y.v])
return res
def demux(self):
if len(self) == 1:
return sbitvec.from_vec([self.v[0].bit_not(), self.v[0]])
a = sbitvec.from_vec(self.v[:len(self) // 2]).demux()
b = sbitvec.from_vec(self.v[len(self) // 2:]).demux()
prod = [a * bb for bb in b.v]
return sbitvec.from_vec(reduce(operator.add, (x.v for x in prod)))
class bit(object):
n = 1
@@ -1243,20 +1283,19 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
return other * self.v[0]
elif isinstance(other, sbitfixvec):
return NotImplemented
_, other_bits = self.expand(other, False)
my_bits, other_bits = self.expand(other, False)
matrix = []
m = float('inf')
for x in itertools.chain(self.v, other_bits):
for x in itertools.chain(my_bits, other_bits):
try:
m = min(m, x.n)
except:
pass
if m == 1:
op = operator.mul
else:
op = operator.and_
matrix = []
for i, b in enumerate(other_bits):
matrix.append([op(x, b) for x in self.v[:len(self.v)-i]])
if m == 1:
matrix.append([x * b for x in my_bits[:len(self.v)-i]])
else:
matrix.append((sbitvec.from_vec(my_bits[:len(self.v)-i]) * b).v)
v = sbitint.wallace_tree_from_matrix(matrix)
return self.from_vec(v[:len(self.v)])
__rmul__ = __mul__
@@ -1366,7 +1405,7 @@ class sbitfix(_fix):
cls.set_precision(f, k)
return cls._new(cls.int_type(other), k, f)
class sbitfixvec(_fix):
class sbitfixvec(_fix, _vec):
""" Vector of fixed-point numbers for parallel binary computation.
Use :py:obj:`set_precision()` to change the precision.

View File

@@ -261,6 +261,7 @@ class Merger:
instructions = self.instructions
merge_nodes = self.open_nodes
depths = self.depths
self.req_num = defaultdict(lambda: 0)
if not merge_nodes:
return 0
@@ -281,6 +282,7 @@ class Merger:
print('Merging %d %s in round %d/%d' % \
(len(merge), t.__name__, i, len(merges)))
self.do_merge(merge)
self.req_num[t.__name__, 'round'] += 1
preorder = None
@@ -530,7 +532,9 @@ class Merger:
can_eliminate_defs = True
for reg in inst.get_def():
for dup in reg.duplicates:
if not dup.can_eliminate:
if not (dup.can_eliminate and reduce(
operator.and_,
(x.can_eliminate for x in dup.vector), True)):
can_eliminate_defs = False
break
# remove if instruction has result that isn't used

View File

@@ -137,8 +137,6 @@ def sha3_256(x):
0x4a43f8804b0ad882fa493be44dff80f562d661a05647c15166d71ebff8c6ffa7
0xf0d7aa0ab2d92d580bb080e17cbb52627932ba37f085d3931270d31c39357067
Note that :py:obj:`sint` to :py:obj:`sbitvec` conversion is only
implemented for computation modulo a power of two.
"""
global Keccak_f

View File

@@ -1,5 +1,6 @@
from Compiler.path_oram import *
from Compiler.oram import *
from Compiler.path_oram import PathORAM, XOR
from Compiler.util import bit_compose
def first_diff(a_bits, b_bits):

View File

@@ -125,6 +125,13 @@ class Compiler:
default=defaults.binary,
help="bit length of sint in binary circuit (default: 0 for arithmetic)",
)
parser.add_option(
"-G",
"--garbled-circuit",
dest="garbled",
action="store_true",
help="compile for binary circuits only (default: false)",
)
parser.add_option(
"-F",
"--field",

504
Compiler/decision_tree.py Normal file
View File

@@ -0,0 +1,504 @@
from Compiler.types import *
from Compiler.sorting import *
from Compiler.library import *
from Compiler import util, oram
from itertools import accumulate
import math
debug = False
debug_split = False
debug_layers = False
max_leaves = None
def get_type(x):
if isinstance(x, (Array, SubMultiArray)):
return x.value_type
elif isinstance(x, (tuple, list)):
x = x[0] + x[-1]
if util.is_constant(x):
return cint
else:
return type(x)
else:
return type(x)
def PrefixSum(x):
return x.get_vector().prefix_sum()
def PrefixSumR(x):
tmp = get_type(x).Array(len(x))
tmp.assign_vector(x)
break_point()
tmp[:] = tmp.get_reverse_vector().prefix_sum()
break_point()
return tmp.get_reverse_vector()
def PrefixSum_inv(x):
tmp = get_type(x).Array(len(x) + 1)
tmp.assign_vector(x, base=1)
tmp[0] = 0
return tmp.get_vector(size=len(x), base=1) - tmp.get_vector(size=len(x))
def PrefixSumR_inv(x):
tmp = get_type(x).Array(len(x) + 1)
tmp.assign_vector(x)
tmp[-1] = 0
return tmp.get_vector(size=len(x)) - tmp.get_vector(base=1, size=len(x))
class SortPerm:
def __init__(self, x):
B = sint.Matrix(len(x), 2)
B.set_column(0, 1 - x.get_vector())
B.set_column(1, x.get_vector())
self.perm = Array.create_from(dest_comp(B))
def apply(self, x):
res = Array.create_from(x)
reveal_sort(self.perm, res, False)
return res
def unapply(self, x):
res = Array.create_from(x)
reveal_sort(self.perm, res, True)
return res
def Sort(keys, *to_sort, n_bits=None, time=False):
if time:
start_timer(1)
for k in keys:
assert len(k) == len(keys[0])
n_bits = n_bits or [None] * len(keys)
bs = Matrix.create_from(
sum([k.get_vector().bit_decompose(nb)
for k, nb in reversed(list(zip(keys, n_bits)))], []))
res = Matrix.create_from(to_sort)
res = res.transpose()
if time:
start_timer(11)
print_ln('sort')
radix_sort_from_matrix(bs, res)
if time:
stop_timer(11)
stop_timer(1)
return res.transpose()
def VectMax(key, *data):
def reducer(x, y):
b = x[0] > y[0]
return [b.if_else(xx, yy) for xx, yy in zip(x, y)]
if debug:
key = list(key)
data = [list(x) for x in data]
print_ln('vect max key=%s data=%s', util.reveal(key), util.reveal(data))
return util.tree_reduce(reducer, zip(key, *data))[1:]
def GroupSum(g, x):
assert len(g) == len(x)
p = PrefixSumR(x) * g
pi = SortPerm(g.get_vector().bit_not())
p1 = pi.apply(p)
s1 = PrefixSumR_inv(p1)
d1 = PrefixSum_inv(s1)
d = pi.unapply(d1) * g
return PrefixSum(d)
def GroupPrefixSum(g, x):
assert len(g) == len(x)
s = get_type(x).Array(len(x) + 1)
s[0] = 0
s.assign_vector(PrefixSum(x), base=1)
q = get_type(s).Array(len(x))
q.assign_vector(s.get_vector(size=len(x)) * g)
return s.get_vector(size=len(x), base=1) - GroupSum(g, q)
def GroupMax(g, keys, *x):
if debug:
print_ln('group max input g=%s keys=%s x=%s', util.reveal(g),
util.reveal(keys), util.reveal(x))
assert len(keys) == len(g)
for xx in x:
assert len(xx) == len(g)
n = len(g)
m = int(math.ceil(math.log(n, 2)))
keys = Array.create_from(keys)
x = [Array.create_from(xx) for xx in x]
g_new = Array.create_from(g)
g_old = g_new.same_shape()
for d in range(m):
w = 2 ** d
g_old[:] = g_new[:]
break_point()
vsize = n - w
g_new.assign_vector(g_old.get_vector(size=vsize).bit_or(
g_old.get_vector(size=vsize, base=w)), base=w)
b = keys.get_vector(size=vsize) > keys.get_vector(size=vsize, base=w)
for xx in [keys] + x:
a = b.if_else(xx.get_vector(size=vsize),
xx.get_vector(size=vsize, base=w))
xx.assign_vector(g_old.get_vector(size=vsize, base=w).if_else(
xx.get_vector(size=vsize, base=w), a), base=w)
break_point()
if debug:
print_ln('group max w=%s b=%s a=%s keys=%s x=%s g=%s', w, b.reveal(),
util.reveal(a), util.reveal(keys),
util.reveal(x), g_new.reveal())
t = sint.Array(len(g))
t[-1] = 1
t.assign_vector(g.get_vector(size=n - 1, base=1))
if debug:
print_ln('group max end g=%s t=%s keys=%s x=%s', util.reveal(g),
util.reveal(t), util.reveal(keys), util.reveal(x))
return [GroupSum(g, t[:] * xx) for xx in [keys] + x]
def ModifiedGini(g, y, debug=False):
assert len(g) == len(y)
y = [y.get_vector().bit_not(), y]
u = [GroupPrefixSum(g, yy) for yy in y]
s = [GroupSum(g, yy) for yy in y]
w = [ss - uu for ss, uu in zip(s, u)]
us = sum(u)
ws = sum(w)
uqs = u[0] ** 2 + u[1] ** 2
wqs = w[0] ** 2 + w[1] ** 2
res = sfix(uqs) / us + sfix(wqs) / ws
if debug:
print_ln('u0=%s', util.reveal(u[0]))
print_ln('u0=%s', util.reveal(u[1]))
print_ln('us=%s', util.reveal(us))
print_ln('w0=%s', util.reveal(w[0]))
print_ln('w1=%s', util.reveal(w[1]))
print_ln('ws=%s', util.reveal(ws))
print_ln('p=%s', util.reveal(p))
print_ln('q=%s', util.reveal(q))
print_ln('g=%s y=%s s=%s',
util.reveal(g), util.reveal(y),
util.reveal(s))
if debug:
print_ln('gini %s %s', str(res), util.reveal(res))
return res
MIN_VALUE = -10000
def FormatLayer(h, g, *a):
return CropLayer(h, *FormatLayer_without_crop(g, *a))
def FormatLayer_without_crop(g, *a):
for x in a:
assert len(x) == len(g)
v = [g.if_else(aa, 0) for aa in a]
v = Sort([g.bit_not()], *v, n_bits=[1])
return v
def CropLayer(k, *v):
if max_leaves:
n = min(2 ** k, max_leaves)
else:
n = 2 ** k
return [vv[:min(n, len(vv))] for vv in v]
def TrainLeafNodes(h, g, y, NID):
assert len(g) == len(y)
assert len(g) == len(NID)
Label = GroupSum(g, y.bit_not()) < GroupSum(g, y)
return FormatLayer(h, g, NID, Label)
def GroupSame(g, y):
assert len(g) == len(y)
s = GroupSum(g, [sint(1)] * len(g))
s0 = GroupSum(g, y.bit_not())
s1 = GroupSum(g, y)
if debug_split:
print_ln('group same g=%s', util.reveal(g))
print_ln('group same y=%s', util.reveal(y))
return (s == s0).bit_or(s == s1)
def GroupFirstOne(g, b):
assert len(g) == len(b)
s = GroupPrefixSum(g, b)
return s * b == 1
class TreeTrainer:
""" Decision tree training by `Hamada et al.`_
:param x: sample data (by attribute, list or
:py:obj:`~Compiler.types.Matrix`)
:param y: binary labels (list or sint vector)
:param h: height (int)
:param binary: binary attributes instead of continuous
:param attr_lengths: attribute description for mixed data
(list of 0/1 for continuous/binary)
:param n_threads: number of threads (default: single thread)
.. _`Hamada et al.`: https://arxiv.org/abs/2112.12906
"""
def ApplyTests(self, x, AID, Threshold):
m = len(x)
n = len(AID)
assert len(AID) == len(Threshold)
for xx in x:
assert len(xx) == len(AID)
e = sint.Matrix(m, n)
AID = Array.create_from(AID)
@for_range_multithread(self.n_threads, 1, m)
def _(j):
e[j][:] = AID[:] == j
xx = sum(x[j] * e[j] for j in range(m))
if debug:
print_ln('apply e=%s xx=%s', util.reveal(e), util.reveal(xx))
return 2 * xx < Threshold
def AttributeWiseTestSelection(self, g, x, y, time=False, debug=False):
assert len(g) == len(x)
assert len(g) == len(y)
if time:
start_timer(2)
s = ModifiedGini(g, y, debug=debug)
if time:
stop_timer(2)
if debug:
print_ln('gini %s', s.reveal())
xx = x
t = get_type(x).Array(len(x))
t[-1] = MIN_VALUE
t.assign_vector(xx.get_vector(size=len(x) - 1) + \
xx.get_vector(size=len(x) - 1, base=1))
gg = g
p = sint.Array(len(x))
p[-1] = 1
p.assign_vector(gg.get_vector(base=1, size=len(x) - 1).bit_or(
xx.get_vector(size=len(x) - 1) == \
xx.get_vector(size=len(x) - 1, base=1)))
break_point()
if debug:
print_ln('attribute t=%s p=%s', util.reveal(t), util.reveal(p))
s = p[:].if_else(MIN_VALUE, s)
t = p[:].if_else(MIN_VALUE, t[:])
if debug:
print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t))
if time:
start_timer(3)
s, t = GroupMax(gg, s, t)
if time:
stop_timer(3)
if debug:
print_ln('attribute s=%s t=%s', util.reveal(s), util.reveal(t))
return t, s
def GlobalTestSelection(self, x, y, g):
assert len(y) == len(g)
for xx in x:
assert(len(xx) == len(g))
m = len(x)
n = len(y)
u, t = [get_type(x).Matrix(m, n) for i in range(2)]
v = get_type(y).Matrix(m, n)
s = sfix.Matrix(m, n)
@for_range_multithread(self.n_threads, 1, m)
def _(j):
single = not self.n_threads or self.n_threads == 1
print_ln('run %s', j)
@if_e(self.attr_lengths[j])
def _():
u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y,
n_bits=[util.log2(n), 1], time=single)
@else_
def _():
u[j][:], v[j][:] = Sort((PrefixSum(g), x[j]), x[j], y,
n_bits=[util.log2(n), None],
time=single)
if self.debug_threading:
print_ln('global sort %s %s %s', j, util.reveal(u[j]),
util.reveal(v[j]))
t[j][:], s[j][:] = self.AttributeWiseTestSelection(
g, u[j], v[j], time=single, debug=self.debug_selection)
if self.debug_threading:
print_ln('global attribute %s %s %s', j, util.reveal(t[j]),
util.reveal(s[j]))
n = len(g)
a, tt = [sint.Array(n) for i in range(2)]
if self.debug_threading:
print_ln('global s=%s', util.reveal(s))
if self.debug_gini:
print_ln('Gini indices ' + ' '.join(str(i) + ':%s' for i in range(m)),
*(ss[0].reveal() for ss in s))
start_timer(4)
a[:], tt[:] = VectMax((s[j][:] for j in range(m)), range(m),
(t[j][:] for j in range(m)))
stop_timer(4)
return a[:], tt[:]
def TrainInternalNodes(self, k, x, y, g, NID):
assert len(g) == len(y)
for xx in x:
assert len(xx) == len(g)
AID, Threshold = self.GlobalTestSelection(x, y, g)
s = GroupSame(g[:], y[:])
if debug or debug_split:
print_ln('AID=%s', util.reveal(AID))
print_ln('Threshold=%s', util.reveal(Threshold))
print_ln('GroupSame=%s', util.reveal(s))
AID, Threshold = s.if_else(0, AID), s.if_else(MIN_VALUE, Threshold)
b = self.ApplyTests(x, AID, Threshold)
return FormatLayer_without_crop(g[:], NID, AID, Threshold), b
@method_block
def train_layer(self, k):
x = self.x
y = self.y
g = self.g
NID = self.NID
layer_matrix = self.layer_matrix
self.layer_matrix[k], b = \
self.TrainInternalNodes(k, x, y, g, NID)
if debug:
print_ln('internal %s %s',
util.reveal(layer_matrix[k]), util.reveal(b))
if debug_layers:
print_ln('layer %s:', k)
for name, data in zip(('NID', 'AID', 'Thr'), layer_matrix[k]):
print_ln(' %s: %s', name, data.reveal())
NID[:] = 2 ** k * b + NID
b_not = b.bit_not()
if debug:
print_ln('b_not=%s', b_not.reveal())
g[:] = GroupFirstOne(g, b_not) + GroupFirstOne(g, b)
y[:], g[:], NID[:], *xx = Sort([b], y, g, NID, *x, n_bits=[1])
for i, xxx in enumerate(xx):
x[i] = xxx
def __init__(self, x, y, h, binary=False, attr_lengths=None,
n_threads=None):
assert not (binary and attr_lengths)
if binary:
attr_lengths = [1] * len(x)
else:
attr_lengths = attr_lengths or ([0] * len(x))
for l in attr_lengths:
assert l in (0, 1)
self.attr_lengths = Array.create_from(regint(attr_lengths))
Array.check_indices = False
Matrix.disable_index_checks()
for xx in x:
assert len(xx) == len(y)
n = len(y)
self.g = sint.Array(n)
self.g.assign_all(0)
self.g[0] = 1
self.NID = sint.Array(n)
self.NID.assign_all(1)
self.y = Array.create_from(y)
self.x = Matrix.create_from(x)
self.layer_matrix = sint.Tensor([h, 3, n])
self.n_threads = n_threads
self.debug_selection = False
self.debug_threading = False
self.debug_gini = True
def train(self):
""" Train and return decision tree. """
h = len(self.layer_matrix)
@for_range(h)
def _(k):
self.train_layer(k)
return self.get_tree(h)
def train_with_testing(self, *test_set):
""" Train decision tree and test against test data.
:param y: binary labels (list or sint vector)
:param x: sample data (by attribute, list or
:py:obj:`~Compiler.types.Matrix`)
:returns: tree
"""
for k in range(len(self.layer_matrix)):
self.train_layer(k)
tree = self.get_tree(k + 1)
output_decision_tree(tree)
test_decision_tree('train', tree, self.y, self.x,
n_threads=self.n_threads)
if test_set:
test_decision_tree('test', tree, *test_set,
n_threads=self.n_threads)
return tree
def get_tree(self, h):
Layer = [None] * (h + 1)
for k in range(h):
Layer[k] = CropLayer(k, *self.layer_matrix[k])
Layer[h] = TrainLeafNodes(h, self.g[:], self.y[:], self.NID)
return Layer
def DecisionTreeTraining(x, y, h, binary=False):
return TreeTrainer(x, y, h, binary=binary).train()
def output_decision_tree(layers):
""" Print decision tree output by :py:class:`TreeTrainer`. """
print_ln('full model %s', util.reveal(layers))
for i, layer in enumerate(layers[:-1]):
print_ln('level %s:', i)
for j, x in enumerate(('NID', 'AID', 'Thr')):
print_ln(' %s: %s', x, util.reveal(layer[j]))
print_ln('leaves:')
for j, x in enumerate(('NID', 'result')):
print_ln(' %s: %s', x, util.reveal(layers[-1][j]))
def pick(bits, x):
if len(bits) == 1:
return bits[0] * x[0]
else:
try:
return x[0].dot_product(bits, x)
except:
return sum(aa * bb for aa, bb in zip(bits, x))
def run_decision_tree(layers, data):
""" Run decision tree against sample data.
:param layers: tree output by :py:class:`TreeTrainer`
:param data: sample data (:py:class:`~Compiler.types.Array`)
:returns: binary label
"""
h = len(layers) - 1
index = 1
for k, layer in enumerate(layers[:-1]):
assert len(layer) == 3
for x in layer:
assert len(x) <= 2 ** k
bits = layer[0].equal(index, k)
threshold = pick(bits, layer[2])
key_index = pick(bits, layer[1])
if key_index.is_clear:
key = data[key_index]
else:
key = pick(
oram.demux(key_index.bit_decompose(util.log2(len(data)))), data)
child = 2 * key < threshold
index += child * 2 ** k
bits = layers[h][0].equal(index, h)
return pick(bits, layers[h][1])
def test_decision_tree(name, layers, y, x, n_threads=None):
start_timer(100)
n = len(y)
x = x.transpose().reveal()
y = y.reveal()
guess = regint.Array(n)
truth = regint.Array(n)
correct = regint.Array(2)
parts = regint.Array(2)
layers = [Matrix.create_from(util.reveal(layer)) for layer in layers]
@for_range_multithread(n_threads, 1, n)
def _(i):
guess[i] = run_decision_tree([[part[:] for part in layer]
for layer in layers], x[i]).reveal()
truth[i] = y[i].reveal()
@for_range(n)
def _(i):
parts[truth[i]] += 1
c = (guess[i].bit_xor(truth[i]).bit_not())
correct[truth[i]] += c
print_ln('%s for height %s: %s/%s (%s/%s, %s/%s)', name, len(layers) - 1,
sum(correct), n, correct[0], parts[0], correct[1], parts[1])
stop_timer(100)

View File

@@ -311,6 +311,7 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None):
@instructions_base.ret_cisc
def Pow2(a, l, kappa):
comparison.program.curr_tape.require_bit_length(l - 1)
m = int(ceil(log(l, 2)))
t = BitDec(a, m, m, kappa)
return Pow2_from_bits(t)

View File

@@ -614,6 +614,18 @@ class submr(base.SubBase):
code = base.opcodes['SUBMR']
arg_format = ['sw','c','s']
@base.vectorize
class prefixsums(base.Instruction):
""" Prefix sum.
:param: result (sint)
:param: input (sint)
"""
__slots__ = []
code = base.opcodes['PREFIXSUMS']
arg_format = ['sw','s']
@base.gf2n
@base.vectorize
class mulc(base.MulBase):
@@ -2301,6 +2313,7 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction,
yield 'int'
for i, n in self.bases(args):
yield 's' + field + 'w'
assert n > 2
for j in range(n - 2):
yield 's' + field
yield 'int'

View File

@@ -80,6 +80,7 @@ opcodes = dict(
SUBSI = 0x2A,
SUBCFI = 0x2B,
SUBSFI = 0x2C,
PREFIXSUMS = 0x2D,
# Multiplication/division
MULC = 0x30,
MULM = 0x31,
@@ -702,10 +703,16 @@ class ClearIntAF(RegisterArgFormat):
reg_type = RegType.ClearInt
class IntArgFormat(ArgFormat):
n_bits = 32
@classmethod
def check(cls, arg):
if not isinstance(arg, int) and not arg is None:
raise ArgumentError(arg, 'Expected an integer-valued argument')
if not arg is None:
if not isinstance(arg, int):
raise ArgumentError(arg, 'Expected an integer-valued argument')
if arg >= 2 ** cls.n_bits or arg < -2 ** cls.n_bits:
raise ArgumentError(
arg, 'Immediate value outside of %d-bit range' % cls.n_bits)
@classmethod
def encode(cls, arg):
@@ -718,6 +725,8 @@ class IntArgFormat(ArgFormat):
return str(self.i)
class LongArgFormat(IntArgFormat):
n_bits = 64
@classmethod
def encode(cls, arg):
return list(struct.pack('>Q', arg))
@@ -729,8 +738,6 @@ class ImmediateModpAF(IntArgFormat):
@classmethod
def check(cls, arg):
super(ImmediateModpAF, cls).check(arg)
if arg >= 2**32 or arg < -2**32:
raise ArgumentError(arg, 'Immediate value outside of 32-bit range')
class ImmediateGF2NAF(IntArgFormat):
@classmethod

View File

@@ -139,7 +139,7 @@ def print_str_if(cond, ss, *args):
""" Print string conditionally. See :py:func:`print_ln_if` for details. """
if util.is_constant(cond):
if cond:
print_ln(ss, *args)
print_str(ss, *args)
else:
subs = ss.split('%s')
assert len(subs) == len(args) + 1
@@ -1021,9 +1021,11 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
def f(i):
state = tuplify(initializer())
start_block = get_block()
j = i * n_parallel
one = regint(1)
for k in range(n_parallel):
j = i * n_parallel + k
state = reducer(tuplify(loop_body(j)), state)
j += one
if n_parallel > 1 and start_block != get_block():
print('WARNING: parallelization broken '
'by control flow instruction')

View File

@@ -73,8 +73,13 @@ from functools import reduce
def log_e(x):
return mpc_math.log_fx(x, math.e)
use_mux = False
def exp(x):
return mpc_math.pow_fx(math.e, x)
if use_mux:
return mpc_math.mux_exp(math.e, x)
else:
return mpc_math.pow_fx(math.e, x)
def get_limit(x):
exp_limit = 2 ** (x.k - x.f - 1)
@@ -164,13 +169,16 @@ def softmax(x):
return softmax_from_exp(exp_for_softmax(x)[0])
def exp_for_softmax(x):
m = util.max(x) - get_limit(x[0]) + 1 + math.log(len(x), 2)
m = util.max(x) - get_limit(x[0]) + math.log(len(x))
mv = m.expand_to_vector(len(x))
try:
x = x.get_vector()
except AttributeError:
x = sfix(x)
return (x - mv > -get_limit(x)).if_else(exp(x - mv), 0), m
if use_mux:
return exp(x - mv), m
else:
return (x - mv > -get_limit(x)).if_else(exp(x - mv), 0), m
def softmax_from_exp(x):
return x / sum(x)
@@ -2002,6 +2010,9 @@ class Optimizer:
return res
def __init__(self, report_loss=None):
if get_program().options.binary:
raise CompilerError(
'machine learning code not compatible with binary circuits')
self.tol = 0.000
self.report_loss = report_loss
self.X_by_label = None

View File

@@ -8,6 +8,8 @@ This has to imported explicitly.
import math
import operator
from functools import reduce
from Compiler import floatingpoint
from Compiler import types
from Compiler import comparison
@@ -398,6 +400,36 @@ def exp2_fx(a, zero_output=False, as19=False):
return s.if_else(1 / g, g)
def mux_exp(x, y, block_size=8):
assert util.is_constant_float(x)
from Compiler.GC.types import sbitvec, sbits
bits = sbitvec.from_vec(y.v.bit_decompose(y.k, maybe_mixed=True)).v
sign = bits[-1]
m = math.log(2 ** (y.k - y.f - 1), x)
del bits[int(math.ceil(math.log(m, 2))) + y.f:]
parts = []
for i in range(0, len(bits), block_size):
one_hot = sbitvec.from_vec(bits[i:i + block_size]).demux().v
exp = []
try:
for j in range(len(one_hot)):
exp.append(types.cfix.int_rep(x ** (j * 2 ** (i - y.f)), y.f))
except OverflowError:
pass
exp = list(filter(lambda x: x < 2 ** (y.k - 1), exp))
bin_part = [0] * max(x.bit_length() for x in exp)
for j in range(len(bin_part)):
for k, (a, b) in enumerate(zip(one_hot, exp)):
bin_part[j] ^= a if util.bit_decompose(b, len(bin_part))[j] \
else 0
if util.is_zero(bin_part[j]):
bin_part[j] = sbits.get_type(y.size)(0)
if i == 0:
bin_part[j] = sign.if_else(0, bin_part[j])
parts.append(y._new(y.int_type(sbitvec.from_vec(bin_part))))
return util.tree_reduce(operator.mul, parts)
@types.vectorize
@instructions_base.sfix_cisc
def log2_fx(x, use_division=True):

View File

@@ -32,6 +32,8 @@ class NonLinear:
return shift_two(a, m)
prog = program.Program.prog
if prog.use_trunc_pr:
if not prog.options.ring:
prog.curr_tape.require_bit_length(k + prog.security)
if signed and prog.use_trunc_pr != -1:
a += (1 << (k - 1))
res = sint()

View File

@@ -1034,8 +1034,9 @@ def get_n_threads_for_tree(size):
class TreeORAM(AbstractORAM):
""" Tree ORAM. """
def __init__(self, size, value_type=sint, value_length=1, entry_size=None, \
def __init__(self, size, value_type=None, value_length=1, entry_size=None, \
bucket_oram=TrivialORAM, init_rounds=-1):
value_type = value_type or sint
print('create oram of size', size)
self.bucket_oram = bucket_oram
# heuristic bucket size
@@ -1722,6 +1723,8 @@ def OptimalORAM(size,*args,**kwargs):
:param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` /
:py:class:`sfix`
"""
if not util.is_constant(size):
raise CompilerError('ORAM size has be a compile-time constant')
if get_program().options.binary:
return BinaryORAM(size, *args, **kwargs)
if optimal_threshold is None:
@@ -1772,6 +1775,12 @@ class OptimalPackedORAMWithEmpty(PackedORAMWithEmpty):
def test_oram(oram_type, N, value_type=sint, iterations=100):
stop_grind()
oram = oram_type(N, value_type=value_type, entry_size=32, init_rounds=0)
test_oram_initialized(oram, iterations)
return oram
def test_oram_initialized(oram, iterations=100):
N = oram.size
value_type = oram.value_type
value_type = value_type.get_type(32)
index_type = value_type.get_type(log2(N))
start_grind()

View File

@@ -29,6 +29,7 @@ data_types = dict(
bit=2,
inverse=3,
dabit=4,
mixed=5,
)
field_types = dict(
@@ -45,6 +46,7 @@ class defaults:
ring = 0
field = 0
binary = 0
garbled = False
prime = None
galois = 40
budget = 100000
@@ -150,10 +152,11 @@ class Program(object):
gc.ldmsd,
gc.stmsd,
gc.stmsdci,
gc.xors,
gc.andrs,
gc.ands,
gc.inputb,
gc.inputbvec,
gc.reveal,
]
self.use_trunc_pr = False
""" Setting whether to use special probabilistic truncation. """
@@ -350,7 +353,8 @@ class Program(object):
print("Writing to", sch_filename)
sch_file.write(str(self.max_par_tapes()) + "\n")
sch_file.write(str(len(nonempty_tapes)) + "\n")
sch_file.write(" ".join(tape.name for tape in nonempty_tapes) + "\n")
sch_file.write(" ".join("%s:%d" % (tape.name, len(tape))
for tape in nonempty_tapes) + "\n")
sch_file.write("1 0\n")
sch_file.write("0\n")
sch_file.write(" ".join(sys.argv) + "\n")
@@ -506,7 +510,8 @@ class Program(object):
self.set_security(security)
def optimize_for_gc(self):
pass
import Compiler.GC.instructions as gc
self.to_merge += [gc.xors]
def get_tape_counter(self):
res = self.tape_counter
@@ -686,6 +691,7 @@ class Tape:
self.purged = False
self.n_rounds = 0
self.n_to_merge = 0
self.rounds = Tape.ReqNum()
self.warn_about_mem = parent.program.warn_about_mem[-1]
def __len__(self):
@@ -750,6 +756,7 @@ class Tape:
inst.add_usage(req_node)
req_node.num["all", "round"] += self.n_rounds
req_node.num["all", "inv"] += self.n_to_merge
req_node.num += self.rounds
def expand_cisc(self):
new_instructions = []
@@ -796,7 +803,14 @@ class Tape:
self.name = name
self.outfile = self.program.programs_dir + "/Bytecode/" + self.name + ".bc"
def __len__(self):
if self.purged:
return self.size
else:
return sum(len(block) for block in self.basicblocks)
def purge(self):
self.size = len(self)
for block in self.basicblocks:
block.purge()
self._is_empty = len(self.basicblocks) == 0
@@ -865,6 +879,8 @@ class Tape:
numrounds = merger.longest_paths_merge()
block.n_rounds = numrounds
block.n_to_merge = len(merger.open_nodes)
if options.verbose:
block.rounds = merger.req_num
if merger.counter and self.program.verbose:
print(
"Block requires",
@@ -1113,7 +1129,8 @@ class Tape:
__rmul__ = __mul__
def set_all(self, value):
if value == float("inf") and self["all", "inv"] > 0:
if Program.prog.options.verbose and \
value == float("inf") and self["all", "inv"] > 0:
print("Going to unknown from %s" % self)
res = Tape.ReqNum()
for i in self:
@@ -1142,6 +1159,8 @@ class Tape:
res = []
for req, num in self.items():
domain = t(req[0])
if num < 0:
num = float('inf')
n = "%12.0f" % num
if req[1] == "input":
res += ["%s %s inputs from player %d" % (n, domain, req[2])]

View File

@@ -3,12 +3,7 @@ from Compiler import types, library, instructions
def dest_comp(B):
Bt = B.transpose()
Bt_flat = Bt.get_vector()
St_flat = Bt.value_type.Array(len(Bt_flat))
St_flat.assign(Bt_flat)
@library.for_range(len(St_flat) - 1)
def _(i):
St_flat[i + 1] = St_flat[i + 1] + St_flat[i]
St_flat = Bt.get_vector().prefix_sum()
Tt_flat = Bt.get_vector() * St_flat.get_vector()
Tt = types.Matrix(*Bt.sizes, B.value_type)
Tt.assign_vector(Tt_flat)
@@ -37,8 +32,14 @@ def radix_sort(k, D, n_bits=None, signed=True):
bs = types.Matrix.create_from(k.get_vector().bit_decompose(n_bits))
if signed and len(bs) > 1:
bs[-1][:] = bs[-1][:].bit_not()
B = types.sint.Matrix(len(k), 2)
h = types.Array.create_from(types.sint(types.regint.inc(len(k))))
radix_sort_from_matrix(bs, D)
def radix_sort_from_matrix(bs, D):
n = len(D)
for b in bs:
assert(len(b) == n)
B = types.sint.Matrix(n, 2)
h = types.Array.create_from(types.sint(types.regint.inc(n)))
@library.for_range(len(bs))
def _(i):
b = bs[i]

View File

@@ -10,9 +10,7 @@ from Compiler.GC.types import cbit, sbit, sbitint, sbits
from Compiler.program import Program
from Compiler.types import (Array, MemValue, MultiArray, _clear, _secret, cint,
regint, sint, sintbit)
from oram import demux_array, get_n_threads
program = Program.prog
from Compiler.oram import demux_array, get_n_threads
# Adds messages on completion of heavy computation steps
debug = False
@@ -44,6 +42,13 @@ B = TypeVar("B", sintbit, sbit)
class SqrtOram(Generic[T, B]):
"""Oblivious RAM using the "Square-Root" algorithm.
:param MultiArray data: The data with which to initialize the ORAM. One may provide a MultiArray such that every "block" can hold multiple elements (an Array).
:param sint value_type: The secret type to use, defaults to sint.
:param int k: Leave at 0, this parameter is used to recursively pass down the depth of this ORAM.
:param int period: Leave at None, this parameter is used to recursively pass down the top-level period.
"""
# TODO: Preferably this is an Array of vectors, but this is currently not supported
# One should regard these structures as Arrays where an entry may hold more
# than one value (which is a nice property to have when using the ORAM in
@@ -69,14 +74,6 @@ class SqrtOram(Generic[T, B]):
t: cint
def __init__(self, data: T | MultiArray, entry_length: int = 1, value_type: Type[T] = sint, k: int = 0, period: int | None = None, initialize: bool = True, empty_data=False) -> None:
"""Initialize a new Oblivious RAM using the "Square-Root" algorithm.
Args:
data (MultiArray): The data with which to initialize the ORAM. One may provide a MultiArray such that every "block" can hold multiple elements (an Array).
value_type (sint): The secret type to use, defaults to sint.
k (int): Leave at 0, this parameter is used to recursively pass down the depth of this ORAM.
period (int): Leave at None, this parameter is used to recursively pass down the top-level period.
"""
global debug, allow_memory_allocation
# Correctly initialize the shuffle (memory) depending on the type of data
@@ -103,6 +100,7 @@ class SqrtOram(Generic[T, B]):
self.index_size = util.log2(self.n) + 1 # +1 because signed
self.index_type = value_type.get_type(self.index_size)
self.entry_length = entry_length
self.size = self.n
if debug:
lib.print_ln(
@@ -632,6 +630,7 @@ class RecursivePositionMap(PositionMap[T, B], SqrtOram[T, B]):
# The item at logical_address
# will be in block with index h (block.<h>)
# at position l in block.data (block.data<l>)
program = Program.prog
h = MemValue(self.value_type.bit_compose(sbits.get_type(program.bit_length)(
logical_address).right_shift(pack_log, program.bit_length)))
l = self.value_type.bit_compose(sbits(logical_address) & (pack - 1))

View File

@@ -749,7 +749,14 @@ class _register(Tape.Register, _number, _structure):
self.mov(res[i], self)
return res
class _clear(_register):
class _arithmetic_register(_register):
""" Arithmetic circuit type. """
def __init__(self, *args, **kwargs):
if program.options.garbled:
raise CompilerError('functionality only available in arithmetic circuits')
super(_arithmetic_register, self).__init__(*args, **kwargs)
class _clear(_arithmetic_register):
""" Clear domain-dependent type. """
__slots__ = []
mov = staticmethod(movc)
@@ -1085,6 +1092,8 @@ class cint(_clear, _int):
def __ne__(self, other):
return 1 - (self == other)
equal = lambda self, other, *args, **kwargs: self.__eq__(other)
def __lshift__(self, other):
""" Clear left shift.
@@ -1836,7 +1845,7 @@ class longint:
res += x.bit_decompose(64)
return res[:bit_length]
class _secret(_register, _secret_structure):
class _secret(_arithmetic_register, _secret_structure):
__slots__ = []
mov = staticmethod(set_instruction_type(movs))
@@ -2682,6 +2691,15 @@ class sint(_secret, _int):
comparison.Trunc(res, tmp, 2 * k, k, kappa, True)
return res
@vectorize
def int_mod(self, other, bit_length=None):
""" Secret integer modulo.
:param other: sint
:param bit_length: bit length of input (default: global bit length)
"""
return self - other * self.int_div(other, bit_length=bit_length)
def trunc_zeros(self, n_zeros, bit_length=None, signed=True):
bit_length = bit_length or program.bit_length
return comparison.TruncZeros(self, bit_length, n_zeros, signed)
@@ -2808,6 +2826,13 @@ class sint(_secret, _int):
res = res.get_vector()
return res
@vectorize
def prefix_sum(self):
""" Prefix sum. """
res = sint()
prefixsums(res, self)
return res
class sintbit(sint):
""" :py:class:`sint` holding a bit, supporting binary operations
(``&, |, ^``). """
@@ -3940,6 +3965,8 @@ class _single(_number, _secret_structure):
:param n: number of inputs (int)
:param client_id: regint
:param size: vector size (default 1)
:returns: list of length ``n``
"""
sint_inputs = cls.int_type.receive_from_client(n, client_id,
message_type)
@@ -3977,6 +4004,8 @@ class _single(_number, _secret_structure):
def conv(cls, other):
if isinstance(other, cls):
return other
elif isinstance(other, (list, tuple)):
return type(other)(cls.conv(x) for x in other)
else:
try:
return cls.from_sint(other)
@@ -4216,7 +4245,7 @@ class _fix(_single):
if isinstance(other, _fix) and (cls.k, cls.f) == (other.k, other.f):
return other
else:
return cls(other)
return super(_fix, cls).conv(other)
@classmethod
def _new(cls, other, k=None, f=None):
@@ -4524,6 +4553,9 @@ class sfix(_fix):
return self._new(self.v.secure_permute(*args, **kwargs),
k=self.k, f=self.f)
def prefix_sum(self):
return self._new(self.v.prefix_sum(), k=self.k, f=self.f)
class unreduced_sfix(_single):
int_type = sint
@@ -5271,6 +5303,8 @@ class Array(_vectorizable):
a[:] += b[:]
"""
check_indices = True
@classmethod
def create_from(cls, l):
""" Convert Python iterator or vector to array. Basic type will be taken
@@ -5283,7 +5317,9 @@ class Array(_vectorizable):
"""
if isinstance(l, cls):
return l
res = l.same_shape()
res[:] = l[:]
return res
if isinstance(l, _number):
tmp = l
t = type(l)
@@ -5304,7 +5340,6 @@ class Array(_vectorizable):
self.debug = debug
self.creator_tape = program.curr_tape
self.sink = None
self.check_indices = True
if alloc:
self.alloc()
@@ -5435,7 +5470,10 @@ class Array(_vectorizable):
return self.value_type.load_mem(address)
def _store(self, value, address):
self.value_type.conv(value).store_in_mem(address)
tmp = self.value_type.conv(value)
if not isinstance(tmp, _vec) and tmp.size != self.value_type.mem_size():
raise CompilerError('size mismatch in array assignment')
tmp.store_in_mem(address)
def __len__(self):
return self.length
@@ -5506,6 +5544,12 @@ class Array(_vectorizable):
get_part_vector = get_vector
def get_reverse_vector(self):
""" Return vector with content in reverse order. """
size = self.length
address = regint.inc(size, size - 1, -1)
return self.value_type.load_mem(self.address + address, size=size)
def get_part(self, base, size):
""" Part array.
@@ -5605,7 +5649,6 @@ class Array(_vectorizable):
""" Vector subtraction.
:param other: vector or container of same length and type that supports operations with type of this array """
assert len(self) == len(other)
return self.get_vector() - other
def __mul__(self, value):
@@ -5668,7 +5711,7 @@ class Array(_vectorizable):
""" Reveal the whole array.
:returns: Array of relevant clear type. """
return Array.create_from(x.reveal() for x in self)
return Array.create_from(self.get_vector().reveal())
def reveal_list(self):
""" Reveal as list. """
@@ -6367,13 +6410,15 @@ class SubMultiArray(_vectorizable):
res = Matrix(self.sizes[1], self.sizes[0], self.value_type)
library.break_point()
if self.value_type.n_elements() == 1:
@library.for_range_opt(self.sizes[0])
def _(j):
res.set_column(j, self[j][:])
nr = self.sizes[1]
nc = self.sizes[0]
a = regint.inc(nr * nc, 0, nr, 1, nc)
b = regint.inc(nr * nc, 0, 1, nc)
res[:] = self.value_type.load_mem(self.address + a + b)
else:
@library.for_range_opt(self.sizes[1])
@library.for_range_opt(self.sizes[1], budget=100)
def _(i):
@library.for_range_opt(self.sizes[0])
@library.for_range_opt(self.sizes[0], budget=100)
def _(j):
res[i][j] = self[j][i]
library.break_point()
@@ -6424,7 +6469,7 @@ class SubMultiArray(_vectorizable):
def randomize(self, *args):
""" Randomize according to data type. """
if self.total_size() < program.options.budget:
if self.total_size() < program.budget:
self.assign_vector(
self.value_type.get_random(*args, size=self.total_size()))
else:
@@ -6432,6 +6477,12 @@ class SubMultiArray(_vectorizable):
def _(i):
self[i].randomize(*args)
def reveal(self):
""" Reveal to :py:obj:`MultiArray` of same shape. """
res = MultiArray(self.sizes, self.value_type.clear_type)
res[:] = self.get_vector().reveal()
return res
def reveal_list(self):
""" Reveal as list. """
return list(self.get_vector().reveal())
@@ -6542,7 +6593,7 @@ class Matrix(MultiArray):
@staticmethod
def create_from(rows):
rows = list(rows)
if isinstance(rows[0], (list, tuple)):
if isinstance(rows[0], (list, tuple, Array)):
t = type(rows[0][0])
else:
t = type(rows[0])

View File

@@ -22,4 +22,5 @@ int main()
generate_mac_keys<Share<P256Element::Scalar>>(key, 2, prefix);
make_mult_triples<Share<P256Element::Scalar>>(key, 2, 1000, false, prefix);
make_inverse<Share<P256Element::Scalar>>(key, 2, 1000, false, prefix);
P256Element::finish();
}

View File

@@ -14,7 +14,14 @@ void P256Element::init()
curve = EC_GROUP_new_by_curve_name(NID_secp256k1);
assert(curve != 0);
auto modulus = EC_GROUP_get0_order(curve);
Scalar::init_field(BN_bn2dec(modulus), false);
auto mod = BN_bn2dec(modulus);
Scalar::init_field(mod, false);
free(mod);
}
void P256Element::finish()
{
EC_GROUP_free(curve);
}
P256Element::P256Element()
@@ -42,6 +49,11 @@ P256Element::P256Element(word other) :
BN_free(exp);
}
P256Element::~P256Element()
{
EC_POINT_free(point);
}
P256Element& P256Element::operator =(const P256Element& other)
{
assert(EC_POINT_copy(point, other.point) != 0);
@@ -99,7 +111,7 @@ bool P256Element::operator ==(const P256Element& other) const
return not cmp;
}
void P256Element::pack(octetStream& os) const
void P256Element::pack(octetStream& os, int) const
{
octet* buffer;
size_t length = EC_POINT_point2buf(curve, point,
@@ -107,9 +119,10 @@ void P256Element::pack(octetStream& os) const
assert(length != 0);
os.store_int(length, 8);
os.append(buffer, length);
free(buffer);
}
void P256Element::unpack(octetStream& os)
void P256Element::unpack(octetStream& os, int)
{
size_t length = os.get_int(8);
assert(

View File

@@ -32,11 +32,13 @@ public:
static string type_string() { return "P256"; }
static void init();
static void finish();
P256Element();
P256Element(const P256Element& other);
P256Element(const Scalar& other);
P256Element(word other);
~P256Element();
P256Element& operator=(const P256Element& other);
@@ -58,8 +60,8 @@ public:
bool is_zero() { return *this == P256Element(); }
void add(octetStream& os) { *this += os.get<P256Element>(); }
void pack(octetStream& os) const;
void unpack(octetStream& os);
void pack(octetStream& os, int = -1) const;
void unpack(octetStream& os, int = -1);
octetStream hash(size_t n_bytes) const;

View File

@@ -64,4 +64,5 @@ int main(int argc, const char** argv)
pShare::MAC_Check::teardown();
Share<P256Element>::MAC_Check::teardown();
P256Element::finish();
}

View File

@@ -30,6 +30,8 @@
#include "GC/ThreadMaster.hpp"
#include "GC/Secret.hpp"
#include "Machines/ShamirMachine.hpp"
#include "Machines/MalRep.hpp"
#include "Machines/Rep.hpp"
#include <assert.h>
@@ -69,4 +71,5 @@ void run(int argc, const char** argv)
preprocessing(tuples, n_tuples, sk, proc, opts);
// check(tuples, sk, {}, P);
sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc);
P256Element::finish();
}

View File

@@ -140,4 +140,5 @@ void run(int argc, const char** argv)
pShare::MAC_Check::teardown();
T<P256Element>::MAC_Check::teardown();
P256Element::finish();
}

View File

@@ -15,7 +15,7 @@ make bankers-bonus-client.x
./compile.py bankers_bonus 1
Scripts/setup-ssl.sh <nparties>
Scripts/setup-clients.sh 3
Scripts/<protocol>.sh bankers_bonus-1 &
PLAYERS=<nparties> Scripts/<protocol>.sh bankers_bonus-1 &
./bankers-bonus-client.x 0 <nparties> 100 0 &
./bankers-bonus-client.x 1 <nparties> 200 0 &
./bankers-bonus-client.x 2 <nparties> 50 1

View File

@@ -116,6 +116,14 @@ void secure_init(T& setup, Player& P, U& machine,
ofstream file(filename);
os.output(file);
}
if (OnlineOptions::singleton.verbose)
{
cerr << "Ciphertext length: " << params.p0().numBits();
for (size_t i = 1; i < params.FFTD().size(); i++)
cerr << "+" << params.FFTD()[i].get_prime().numBits();
cerr << endl;
}
}
template <class FD>

View File

@@ -128,6 +128,7 @@ size_t Prover<FD,U>::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl
bool ok=false;
int cnt=0;
(void) cnt;
while (!ok)
{ cnt++;
Stage_1(P,ciphertexts,c,pk);

View File

@@ -44,7 +44,8 @@ void BitAdder::add(vector<vector<T>>& res, const vector<vector<vector<T>>>& summ
&supplies);
BitAdder().add(res, summands, start,
summands[0][0].size(), proc, T::default_length);
queues->wrap_up(job);
if (start)
queues->wrap_up(job);
}
else
add(res, summands, 0, res.size(), proc, length);

View File

@@ -6,12 +6,12 @@
#ifndef GC_BITPREPFILES_H_
#define GC_BITPREPFILES_H_
namespace GC
{
#include "ShiftableTripleBuffer.h"
#include "Processor/Data_Files.h"
namespace GC
{
template<class T>
class BitPrepFiles : public ShiftableTripleBuffer<T>, public Sub_Data_Files<T>
{

View File

@@ -11,11 +11,13 @@
#include "GC/Access.h"
#include "GC/ArgTuples.h"
#include "GC/NoShare.h"
#include "GC/Processor.h"
#include "Math/gf2nlong.h"
#include "Tools/SwitchableOutput.h"
#include "Processor/DummyProtocol.h"
#include "Processor/Instruction.h"
#include "Protocols/FakePrep.h"
#include "Protocols/FakeMC.h"
#include "Protocols/FakeProtocol.h"
@@ -85,6 +87,11 @@ public:
{ processor.andrs(args); }
static void ands(GC::Processor<FakeSecret>& processor, const vector<int>& regs);
template <class T>
static void andrsvec(T&, const vector<int>&)
{ throw runtime_error("andrsvec not implemented"); }
static void andm(GC::Processor<FakeSecret>& processor, const ::Instruction& instruction)
{ processor.andm(instruction); }
template <class T>
static void xors(GC::Processor<T>& processor, const vector<int>& regs)
{ processor.xors(regs); }
template <class T>

View File

@@ -64,6 +64,7 @@ enum
INPUTBVEC = 0x247,
SPLIT = 0x248,
CONVCBIT2S = 0x249,
ANDRSVEC = 0x24a,
// write to clear
CLEAR_WRITE = 0x210,
XORCBI = 0x210,

View File

@@ -47,7 +47,7 @@ public:
~Machine();
void load_schedule(const string& progname);
void load_program(const string& threadname, const string& filename);
size_t load_program(const string& threadname, const string& filename);
template<class U>
void reset(const U& program);

View File

@@ -35,12 +35,14 @@ Machine<T>::~Machine()
}
template<class T>
void Machine<T>::load_program(const string& threadname, const string& filename)
size_t Machine<T>::load_program(const string& threadname,
const string& filename)
{
(void)threadname;
progs.push_back({});
progs.back().parse_file(filename);
reset(progs.back());
return progs.back().size();
}
template<class T>

View File

@@ -18,6 +18,8 @@ using namespace std;
class NoMemory
{
public:
void resize_min(size_t, const char*) {}
};
namespace GC

View File

@@ -154,6 +154,7 @@ public:
static void xors(Processor<NoShare>&, const vector<int>&) { fail(); }
static void ands(Processor<NoShare>&, const vector<int>&) { fail(); }
static void andrs(Processor<NoShare>&, const vector<int>&) { fail(); }
static void andrsvec(Processor<NoShare>&, const vector<int>&) { fail(); }
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }

View File

@@ -8,6 +8,8 @@
#include "PersonalPrep.h"
#include "Protocols/ShuffleSacrifice.hpp"
namespace GC
{
@@ -36,7 +38,8 @@ void PersonalPrep<T>::buffer_personal_triples(size_t batch_size, ThreadQueues* q
PersonalTripleJob job(&triples, input_player);
int start = queues->distribute(job, batch_size);
buffer_personal_triples(triples, start, batch_size);
queues->wrap_up(job);
if (start)
queues->wrap_up(job);
}
else
buffer_personal_triples(triples, 0, batch_size);

View File

@@ -10,6 +10,7 @@
#include "Protocols/Replicated.hpp"
#include "Protocols/MaliciousRepMC.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "Protocols/ReplicatedPrep.hpp"
#include "ShareSecret.hpp"
namespace GC

View File

@@ -91,6 +91,7 @@ public:
void and_(const vector<int>& args, bool repeat);
void andrs(const vector<int>& args) { and_(args, true); }
void ands(const vector<int>& args) { and_(args, false); }
void andrsvec(const vector<int>& args);
void input(const vector<int>& args);
void inputb(typename T::Input& input, ProcessorBase& input_processor,

View File

@@ -15,6 +15,7 @@ using namespace std;
#include "GC/Program.h"
#include "Access.h"
#include "Processor/FixInput.h"
#include "Math/BitVec.h"
#include "GC/Machine.hpp"
#include "Processor/ProcessorBase.hpp"
@@ -205,9 +206,13 @@ template<class U>
void Processor<T>::mem_op(int n, Memory<U>& dest, const Memory<U>& source,
Integer dest_address, Integer source_address)
{
dest.check_index(dest_address + n - 1);
source.check_index(source_address + n - 1);
auto d = &dest[dest_address];
auto s = &source[source_address];
for (int i = 0; i < n; i++)
{
dest[dest_address + i] = source[source_address + i];
*d++ = *s++;
}
}
@@ -302,6 +307,40 @@ void Processor<T>::and_(const vector<int>& args, bool repeat)
}
}
template <class T>
void Processor<T>::andrsvec(const vector<int>& args)
{
int N_BITS = T::default_length;
auto it = args.begin();
while (it < args.end())
{
int n_args = (*it++ - 3) / 2;
int size = *it++;
int base = *(it + n_args);
assert(n_args <= N_BITS);
for (int i = 0; i < size; i += 1)
{
if (i % N_BITS == 0)
for (int j = 0; j < n_args; j++)
S.at(*(it + j) + i / N_BITS).resize_regs(
min(N_BITS, size - i));
T y;
y.get_regs().push_back(S.at(base + i / N_BITS).get_reg(i % N_BITS));
for (int j = 0; j < n_args; j++)
{
T x, tmp;
x.get_regs().push_back(
S.at(*(it + n_args + 1 + j) + i / N_BITS).get_reg(
i % N_BITS));
tmp.and_(1, x, y, false);
S.at(*(it + j) + i / N_BITS).get_reg(i % N_BITS) = tmp.get_reg(0);
}
}
it += 2 * n_args + 1;
}
}
template <class T>
void Processor<T>::input(const vector<int>& args)
{

View File

@@ -40,6 +40,8 @@ class Program
Program();
size_t size() const { return p.size(); }
// Read in a program
void parse_file(const string& filename);
void parse(const string& programe);

View File

@@ -98,6 +98,9 @@ public:
static void ands(Processor<U>& processor, const vector<int>& args)
{ T::ands(processor, args); }
template<class U>
static void andrsvec(Processor<U>& processor, const vector<int>& args)
{ T::andrsvec(processor, args); }
template<class U>
static void xors(Processor<U>& processor, const vector<int>& args)
{ T::xors(processor, args); }
template<class U>

36
GC/Semi.cpp Normal file
View File

@@ -0,0 +1,36 @@
/*
* Semi.cpp
*
*/
#include "Semi.h"
#include "SemiPrep.h"
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/Replicated.hpp"
#include "Protocols/SemiInput.hpp"
#include "Protocols/Beaver.hpp"
namespace GC
{
void Semi::prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
bool repeat)
{
if (repeat and OnlineOptions::singleton.live_prep)
{
this->triples.push_back({{}});
auto& triple = this->triples.back();
triple = dynamic_cast<SemiPrep*>(prep)->get_mixed_triple(n);
for (int i = 0; i < 2; i++)
triple[1 + i] = triple[1 + i].mask(n);
triple[0] = triple[0].extend_bit().mask(n);
shares.push_back(y - triple[0]);
shares.push_back(x - triple[1]);
lengths.push_back(n);
}
else
prepare_mul(x, y, n);
}
} /* namespace GC */

31
GC/Semi.h Normal file
View File

@@ -0,0 +1,31 @@
/*
* Semi.h
*
*/
#ifndef GC_SEMI_H_
#define GC_SEMI_H_
#include "Protocols/Beaver.h"
#include "SemiSecret.h"
namespace GC
{
class Semi : public Beaver<SemiSecret>
{
typedef Beaver<SemiSecret> super;
public:
Semi(Player& P) :
super(P)
{
}
void prepare_mult(const SemiSecret& x, const SemiSecret& y, int n,
bool repeat);
};
} /* namespace GC */
#endif /* GC_SEMI_H_ */

View File

@@ -4,6 +4,7 @@
*/
#include "SemiPrep.h"
#include "Semi.h"
#include "ThreadMaster.h"
#include "OT/NPartyTripleGenerator.h"
#include "OT/BitDiagonal.h"
@@ -21,7 +22,7 @@ SemiPrep::SemiPrep(DataPositions& usage, bool) :
{
}
void SemiPrep::set_protocol(Beaver<SemiSecret>& protocol)
void SemiPrep::set_protocol(SemiSecret::Protocol& protocol)
{
if (triple_generator)
{
@@ -53,6 +54,9 @@ SemiPrep::~SemiPrep()
{
if (triple_generator)
delete triple_generator;
this->print_left("mixed triples", mixed_triples.size(),
SemiSecret::type_string(),
this->usage.files.at(DATA_GF2N).at(DATA_MIXED));
}
void SemiPrep::buffer_bits()
@@ -64,4 +68,25 @@ void SemiPrep::buffer_bits()
}
}
array<SemiSecret, 3> SemiPrep::get_mixed_triple(int n)
{
assert(n < 128);
if (mixed_triples.empty())
{
assert(this->triple_generator);
this->triple_generator->generateMixedTriples();
for (auto& x : this->triple_generator->mixedTriples)
{
this->mixed_triples.push_back({{x[0], x[1], x[2]}});
}
this->triple_generator->unlock();
}
this->count(DATA_MIXED);
auto res = mixed_triples.back();
mixed_triples.pop_back();
return res;
}
} /* namespace GC */

View File

@@ -25,11 +25,13 @@ class SemiPrep : public BufferPrep<SemiSecret>, ShiftableTripleBuffer<SemiSecret
SeededPRNG secure_prng;
vector<array<SemiSecret, 3>> mixed_triples;
public:
SemiPrep(DataPositions& usage, bool = true);
~SemiPrep();
void set_protocol(Beaver<SemiSecret>& protocol);
void set_protocol(SemiSecret::Protocol& protocol);
void buffer_triples();
void buffer_bits();
@@ -37,6 +39,8 @@ public:
void buffer_squares() { throw not_implemented(); }
void buffer_inverses() { throw not_implemented(); }
array<SemiSecret, 3> get_mixed_triple(int n);
void get(Dtype type, SemiSecret* data)
{
BufferPrep<SemiSecret>::get(type, data);

View File

@@ -19,6 +19,7 @@ namespace GC
class SemiPrep;
class DealerPrep;
class Semi;
template<class T, class V>
class SemiSecretBase : public V, public ShareSecret<T>
@@ -88,9 +89,13 @@ public:
typedef MC MAC_Check;
typedef SemiInput<This> Input;
typedef SemiPrep LivePrep;
typedef Semi Protocol;
static MC* new_mc(typename SemiShare<BitVec>::mac_key_type);
static void andrsvec(Processor<SemiSecret>& processor,
const vector<int>& args);
SemiSecret()
{
}

View File

@@ -8,6 +8,7 @@
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/DealerMC.h"
#include "SemiSecret.h"
#include "Semi.h"
namespace GC
{
@@ -60,6 +61,60 @@ void SemiSecretBase<T, V>::trans(Processor<T>& processor, int n_outputs,
}
}
inline
void SemiSecret::andrsvec(Processor<SemiSecret>& processor,
const vector<int>& args)
{
int N_BITS = default_length;
auto protocol = ShareThread<SemiSecret>::s().protocol;
assert(protocol);
protocol->init_mul();
auto it = args.begin();
while (it < args.end())
{
int n_args = (*it++ - 3) / 2;
int size = *it++;
it += n_args;
int base = *it++;
assert(n_args <= N_BITS);
for (int i = 0; i < size; i += N_BITS)
{
square64 square;
for (int j = 0; j < n_args; j++)
square.rows[j] = processor.S.at(*(it + j) + i / N_BITS).get();
int n_ops = min(N_BITS, size - i);
square.transpose(n_args, n_ops);
for (int j = 0; j < n_ops; j++)
{
long bit = processor.S.at(base + i / N_BITS).get_bit(j);
auto y_ext = SemiSecret(bit).extend_bit();
protocol->prepare_mult(square.rows[j], y_ext, n_args, true);
}
}
it += n_args;
}
protocol->exchange();
it = args.begin();
while (it < args.end())
{
int n_args = (*it++ - 3) / 2;
int size = *it++;
for (int i = 0; i < size; i += N_BITS)
{
int n_ops = min(N_BITS, size - i);
square64 square;
for (int j = 0; j < n_ops; j++)
square.rows[j] = protocol->finalize_mul(n_args).get();
square.transpose(n_ops, n_args);
for (int j = 0; j < n_args; j++)
processor.S.at(*(it + j) + i / N_BITS) = square.rows[j];
}
it += 2 * n_args + 1;
}
}
template<class T, class V>
void SemiSecretBase<T, V>::load_clear(int n, const Integer& x)
{

View File

@@ -6,8 +6,6 @@
#ifndef GC_SHAREPARTY_H_
#define GC_SHAREPARTY_H_
#include "Protocols/ReplicatedMC.h"
#include "Protocols/MaliciousRepMC.h"
#include "ShareSecret.h"
#include "Processor.h"
#include "Program.h"

View File

@@ -16,14 +16,12 @@
#include "Protocols/fake-stuff.h"
#include "ShareThread.hpp"
#include "RepPrep.hpp"
#include "ThreadMaster.hpp"
#include "Thread.hpp"
#include "ShareSecret.hpp"
#include "Protocols/Replicated.hpp"
#include "Protocols/ReplicatedPrep.hpp"
#include "Protocols/MaliciousRepMC.hpp"
#include "Protocols/fake-stuff.hpp"
namespace GC

View File

@@ -63,6 +63,7 @@ public:
static void ands(Processor<U>& processor, const vector<int>& args)
{ and_(processor, args, false); }
static void and_(Processor<U>& processor, const vector<int>& args, bool repeat);
static void andrsvec(Processor<U>& processor, const vector<int>& args);
static void xors(Processor<U>& processor, const vector<int>& args);
static void inputb(Processor<U>& processor, const vector<int>& args)
{ inputb(processor, processor, args); }

View File

@@ -8,16 +8,12 @@
#include "ShareSecret.h"
#include "MaliciousRepSecret.h"
#include "Protocols/MaliciousRepMC.h"
#include "ShareThread.h"
#include "Thread.h"
#include "square64.h"
#include "Protocols/Share.h"
#include "Protocols/ReplicatedMC.hpp"
#include "Protocols/Beaver.hpp"
#include "ShareParty.h"
#include "ShareThread.hpp"
#include "Thread.hpp"
@@ -288,6 +284,12 @@ void ShareSecret<U>::and_(
ShareThread<U>::s().and_(processor, args, repeat);
}
template<class U>
void ShareSecret<U>::andrsvec(Processor<U>& processor, const vector<int>& args)
{
ShareThread<U>::s().andrsvec(processor, args);
}
template<class U>
void ShareSecret<U>::xors(Processor<U>& processor, const vector<int>& args)
{

View File

@@ -7,11 +7,7 @@
#define GC_SHARETHREAD_H_
#include "Thread.h"
#include "MaliciousRepSecret.h"
#include "RepPrep.h"
#include "SemiHonestRepPrep.h"
#include "Processor/Data_Files.h"
#include "Protocols/ReplicatedInput.h"
#include <array>
@@ -45,6 +41,7 @@ public:
void check();
void and_(Processor<T>& processor, const vector<int>& args, bool repeat);
void andrsvec(Processor<T>& processor, const vector<int>& args);
void xors(Processor<T>& processor, const vector<int>& args);
};

View File

@@ -107,7 +107,7 @@ void ShareThread<T>::and_(Processor<T>& processor,
else
processor.S[right + j].mask(y_ext, n);
processor.S[left + j].mask(x_ext, n);
protocol->prepare_mul(x_ext, y_ext, n);
protocol->prepare_mult(x_ext, y_ext, n, repeat);
}
}
@@ -127,6 +127,53 @@ void ShareThread<T>::and_(Processor<T>& processor,
}
}
template<class T>
void ShareThread<T>::andrsvec(Processor<T>& processor, const vector<int>& args)
{
int N_BITS = T::default_length;
auto& protocol = this->protocol;
assert(protocol);
protocol->init_mul();
auto it = args.begin();
T x_ext, y_ext;
while (it < args.end())
{
int n_args = (*it++ - 3) / 2;
int size = *it++;
it += n_args;
int base = *it++;
assert(n_args <= N_BITS);
for (int i = 0; i < size; i += N_BITS)
{
int n_ops = min(N_BITS, size - i);
for (int j = 0; j < n_args; j++)
{
processor.S.at(*(it + j) + i / N_BITS).mask(x_ext, n_ops);
processor.S.at(base + i / N_BITS).mask(y_ext, n_ops);
protocol->prepare_mul(x_ext, y_ext, n_ops);
}
}
it += n_args;
}
protocol->exchange();
it = args.begin();
while (it < args.end())
{
int n_args = (*it++ - 3) / 2;
int size = *it++;
for (int i = 0; i < size; i += N_BITS)
{
int n_ops = min(N_BITS, size - i);
for (int j = 0; j < n_args; j++)
protocol->finalize_mul(n_ops).mask(
processor.S.at(*(it + j) + i / N_BITS), n_ops);
}
it += 2 * n_args + 1;
}
}
template<class T>
void ShareThread<T>::xors(Processor<T>& processor, const vector<int>& args)
{

View File

@@ -68,6 +68,7 @@ void ThreadMaster<T>::run()
P = new PlainPlayer(N, "main");
machine.load_schedule(progname);
machine.reset(machine.progs[0], memory);
for (int i = 0; i < machine.nthreads; i++)
threads.push_back(new_thread(i));

View File

@@ -8,7 +8,7 @@
#include "TinierSharePrep.h"
#include "PersonalPrep.h"
#include "PersonalPrep.hpp"
namespace GC
{

View File

@@ -46,7 +46,7 @@ public:
sizes.reserve(n);
}
void prepare_open(const T& secret)
void prepare_open(const T& secret, int = -1)
{
for (auto& part : secret.get_regs())
part_MC.prepare_open(part);

View File

@@ -6,6 +6,8 @@
#include "TinierSharePrep.h"
#include "Protocols/MascotPrep.hpp"
#include "Protocols/ShuffleSacrifice.hpp"
#include "Protocols/MalRepRingPrep.hpp"
namespace GC
{

View File

@@ -45,6 +45,7 @@
X(NOTS, processor.nots(INST)) \
X(NOTCB, processor.notcb(INST)) \
X(ANDRS, T::andrs(PROC, EXTRA)) \
X(ANDRSVEC, T::andrsvec(PROC, EXTRA)) \
X(ANDS, T::ands(PROC, EXTRA)) \
X(ANDM, T::andm(PROC, instruction)) \
X(ADDCB, C0 = PC1 + PC2) \

View File

@@ -1,19 +1,17 @@
CSIRO Open Source Software Licence Agreement (variation of the BSD / MIT License)
Copyright (c) 2022, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230.
All rights reserved. CSIRO is willing to grant you a licence to this MP-SPDZ sofware on the following terms, except where otherwise indicated for third party material.
Redistribution and use of this software in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of CSIRO nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission of CSIRO.
EXCEPT AS EXPRESSLY STATED IN THIS AGREEMENT AND TO THE FULL EXTENT PERMITTED BY APPLICABLE LAW, THE SOFTWARE IS PROVIDED "AS-IS". CSIRO MAKES NO REPRESENTATIONS, WARRANTIES OR CONDITIONS OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY REPRESENTATIONS, WARRANTIES OR CONDITIONS REGARDING THE CONTENTS OR ACCURACY OF THE SOFTWARE, OR OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, THE ABSENCE OF LATENT OR OTHER DEFECTS, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE.
TO THE FULL EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL CSIRO BE LIABLE ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, IN AN ACTION FOR BREACH OF CONTRACT, NEGLIGENCE OR OTHERWISE) FOR ANY CLAIM, LOSS, DAMAGES OR OTHER LIABILITY HOWSOEVER INCURRED. WITHOUT LIMITING THE SCOPE OF THE PREVIOUS SENTENCE THE EXCLUSION OF LIABILITY SHALL INCLUDE: LOSS OF PRODUCTION OR OPERATION TIME, LOSS, DAMAGE OR CORRUPTION OF DATA OR RECORDS; OR LOSS OF ANTICIPATED SAVINGS, OPPORTUNITY, REVENUE, PROFIT OR GOODWILL, OR OTHER ECONOMIC LOSS; OR ANY SPECIAL, INCIDENTAL, INDIRECT, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES, ARISING OUT OF OR IN CONNECTION WITH THIS AGREEMENT, ACCESS OF THE SOFTWARE OR ANY OTHER DEALINGS WITH THE SOFTWARE, EVEN IF CSIRO HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH CLAIM, LOSS, DAMAGES OR OTHER LIABILITY.
APPLICABLE LEGISLATION SUCH AS THE AUSTRALIAN CONSUMER LAW MAY APPLY REPRESENTATIONS, WARRANTIES, OR CONDITIONS, OR IMPOSES OBLIGATIONS OR LIABILITY ON CSIRO THAT CANNOT BE EXCLUDED, RESTRICTED OR MODIFIED TO THE FULL EXTENT SET OUT IN THE EXPRESS TERMS OF THIS CLAUSE ABOVE "CONSUMER GUARANTEES". TO THE EXTENT THAT SUCH CONSUMER GUARANTEES CONTINUE TO APPLY, THEN TO THE FULL EXTENT PERMITTED BY THE APPLICABLE LEGISLATION, THE LIABILITY OF CSIRO UNDER THE RELEVANT CONSUMER GUARANTEE IS LIMITED (WHERE PERMITTED AT CSIRO'S OPTION) TO ONE OF FOLLOWING REMEDIES OR SUBSTANTIALLY EQUIVALENT REMEDIES:
(a) THE REPLACEMENT OF THE SOFTWARE, THE SUPPLY OF EQUIVALENT SOFTWARE, OR SUPPLYING RELEVANT SERVICES AGAIN;
(b) THE REPAIR OF THE SOFTWARE;
(c) THE PAYMENT OF THE COST OF REPLACING THE SOFTWARE, OF ACQUIRING EQUIVALENT SOFTWARE, HAVING THE RELEVANT SERVICES SUPPLIED AGAIN, OR HAVING THE SOFTWARE REPAIRED.
IN THIS CLAUSE, CSIRO INCLUDES ANY THIRD PARTY AUTHOR OR OWNER OF ANY PART OF THE SOFTWARE OR MATERIAL DISTRIBUTED WITH IT. CSIRO MAY ENFORCE ANY RIGHTS ON BEHALF OF THE RELEVANT THIRD PARTY.
Third Party Components
The following third party components are distributed with the Software. You agree to comply with the licence terms for these components as part of accessing the Software. Other third party software may also be identified in separate files distributed with the Software.
The Software is copyright (c) 2022, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230.
CSIRO grants you a licence to the Software on the terms of the BSD 3-Clause Licence.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
The following third party components are distributed with the Software.
___________________________________________________________________
SPDZ-2 [https://github.com/bristolcrypto/SPDZ-2]
Copyright (c) 2018, The University of Bristol

View File

@@ -9,5 +9,7 @@
#include "Protocols/MalRepRingPrep.hpp"
#include "Protocols/MaliciousRepPrep.hpp"
#include "Protocols/MaliciousRepMC.hpp"
#include "Protocols/Beaver.hpp"
#include "Rep.hpp"
#endif /* MACHINES_MALREP_HPP_ */

View File

@@ -4,7 +4,8 @@
*/
#include "Protocols/MalRepRingPrep.h"
#include "Protocols/ReplicatedPrep2k.h"
#include "Protocols/SemiRep3Prep.h"
#include "GC/SemiHonestRepPrep.h"
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"
@@ -12,6 +13,8 @@
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/Beaver.hpp"
#include "Protocols/Spdz2kPrep.hpp"
#include "Protocols/ReplicatedMC.hpp"
#include "Protocols/Rep3Shuffler.hpp"
#include "Math/Z2k.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/RepPrep.hpp"

View File

@@ -15,8 +15,14 @@
#include "Protocols/DealerMC.hpp"
#include "Protocols/DealerMatrixPrep.hpp"
#include "Protocols/Beaver.hpp"
#include "Semi.hpp"
#include "Protocols/SemiInput.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/ReplicatedPrep.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "Protocols/SemiMC.hpp"
#include "GC/DealerPrep.h"
#include "GC/SemiPrep.h"
#include "GC/SemiSecret.hpp"
int main(int argc, const char** argv)
{

View File

@@ -17,6 +17,7 @@
#include "Protocols/ReplicatedPrep.hpp"
#include "Protocols/FakeShare.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "Protocols/MAC_Check_Base.hpp"
int main(int argc, const char** argv)
{

View File

@@ -7,12 +7,14 @@
#include "GC/ShareParty.hpp"
#include "GC/ShareSecret.hpp"
#include "GC/MaliciousRepSecret.h"
#include "GC/RepPrep.h"
#include "GC/Machine.hpp"
#include "GC/Processor.hpp"
#include "GC/Program.hpp"
#include "GC/Thread.hpp"
#include "GC/ThreadMaster.hpp"
#include "GC/RepPrep.hpp"
#include "Processor/Instruction.hpp"
#include "Protocols/MaliciousRepMC.hpp"

View File

@@ -9,6 +9,7 @@
#include "Math/gfp.hpp"
#include "Processor/FieldMachine.hpp"
#include "Processor/OfflineMachine.hpp"
#include "Protocols/MascotPrep.hpp"
int main(int argc, const char** argv)
{

View File

@@ -9,6 +9,8 @@
#include "Processor/Machine.hpp"
#include "Protocols/Replicated.hpp"
#include "Protocols/MalRepRingPrep.hpp"
#include "Protocols/ReplicatedPrep.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Math/gfp.hpp"
#include "Math/Z2k.hpp"

View File

@@ -5,8 +5,11 @@
#include "GC/PostSacriBin.h"
#include "GC/PostSacriSecret.h"
#include "GC/RepPrep.h"
#include "GC/ShareParty.hpp"
#include "GC/RepPrep.hpp"
#include "Protocols/MaliciousRepMC.hpp"
int main(int argc, const char** argv)
{

View File

@@ -7,6 +7,7 @@
#include "BMR/RealProgramParty.hpp"
#include "Machines/SPDZ.hpp"
#include "Protocols/MascotPrep.hpp"
int main(int argc, const char** argv)
{

View File

@@ -4,6 +4,7 @@
*/
#include "GC/ShareParty.h"
#include "GC/SemiHonestRepPrep.h"
#include "GC/ShareParty.hpp"
#include "GC/ShareSecret.hpp"
@@ -12,6 +13,7 @@
#include "GC/Program.hpp"
#include "GC/Thread.hpp"
#include "GC/ThreadMaster.hpp"
#include "GC/RepPrep.hpp"
#include "Processor/Instruction.hpp"
#include "Protocols/MaliciousRepMC.hpp"

View File

@@ -4,7 +4,6 @@
*/
#include "Protocols/Rep3Share2k.h"
#include "Protocols/ReplicatedPrep2k.h"
#include "Processor/RingOptions.h"
#include "Math/Integer.h"
#include "Machines/RepRing.hpp"

View File

@@ -13,10 +13,10 @@
#include "Math/gf2n.h"
#include "Tools/ezOptionParser.h"
#include "GC/MaliciousCcdSecret.h"
#include "GC/SemiHonestRepPrep.h"
#include "Processor/FieldMachine.hpp"
#include "Protocols/Replicated.hpp"
#include "Protocols/MaliciousRepMC.hpp"
#include "Protocols/Share.hpp"
#include "Protocols/fake-stuff.hpp"
#include "Protocols/SpdzWise.hpp"
@@ -30,6 +30,7 @@
#include "GC/RepPrep.hpp"
#include "GC/ThreadMaster.hpp"
#include "Math/gfp.hpp"
#include "MalRep.hpp"
int main(int argc, const char** argv)
{

View File

@@ -11,10 +11,10 @@
#include "Protocols/MalRepRingPrep.h"
#include "Processor/RingOptions.h"
#include "GC/MaliciousCcdSecret.h"
#include "GC/SemiHonestRepPrep.h"
#include "Processor/RingMachine.hpp"
#include "Protocols/Replicated.hpp"
#include "Protocols/MaliciousRepMC.hpp"
#include "Protocols/Share.hpp"
#include "Protocols/fake-stuff.hpp"
#include "Protocols/SpdzWise.hpp"
@@ -32,6 +32,7 @@
#include "GC/ShareSecret.hpp"
#include "GC/RepPrep.hpp"
#include "GC/ThreadMaster.hpp"
#include "MalRep.hpp"
int main(int argc, const char** argv)
{

View File

@@ -12,6 +12,7 @@
#include "Math/gf2n.h"
#include "GC/CcdSecret.h"
#include "GC/MaliciousCcdSecret.h"
#include "GC/SemiHonestRepPrep.h"
#include "Protocols/Share.hpp"
#include "Protocols/SpdzWise.hpp"

View File

@@ -25,6 +25,7 @@
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/Beaver.hpp"
#include "Protocols/MascotPrep.hpp"
#include "Protocols/MalRepRingPrep.hpp"
int main(int argc, const char** argv)
{

View File

@@ -12,7 +12,7 @@ PROCESSOR = $(patsubst %.cpp,%.o,$(wildcard Processor/*.cpp))
FHEOBJS = $(patsubst %.cpp,%.o,$(wildcard FHEOffline/*.cpp FHE/*.cpp)) Protocols/CowGearOptions.o
GC = $(patsubst %.cpp,%.o,$(wildcard GC/*.cpp)) $(PROCESSOR)
GC_SEMI = GC/SemiPrep.o GC/square64.o
GC_SEMI = GC/SemiPrep.o GC/square64.o GC/Semi.o
OT = $(patsubst %.cpp,%.o,$(wildcard OT/*.cpp)) $(LIBSIMPLEOT)
OT_EXE = ot.x ot-offline.x
@@ -40,6 +40,17 @@ LIBSIMPLEOT_ASM = deps/SimpleOT/libsimpleot.a
LIBSIMPLEOT += $(LIBSIMPLEOT_ASM)
endif
STATIC_OTE = local/lib/liblibOTe.a
SHARED_OTE = local/lib/liblibOTe.so
ifeq ($(USE_KOS), 0)
ifeq ($(USE_SHARED_OTE), 1)
OT += $(SHARED_OTE) local/lib/libcryptoTools.so
else
OT += $(STATIC_OTE) local/lib/libcryptoTools.a
endif
endif
# used for dependency generation
OBJS = $(BMR) $(FHEOBJS) $(TINYOTOFFLINE) $(YAO) $(COMPLETE) $(patsubst %.cpp,%.o,$(wildcard Machines/*.cpp Utils/*.cpp))
DEPS := $(wildcard */*.d */*/*.d)
@@ -106,6 +117,7 @@ endif
tldr: libote
$(MAKE) mascot-party.x
mkdir Player-Data 2> /dev/null; true
ifeq ($(ARM), 1)
Tools/intrinsics.h: deps/simde/simde
@@ -130,8 +142,8 @@ $(SHAREDLIB): $(PROCESSOR) $(COMMONOBJS) GC/square64.o GC/Instruction.o
$(FHEOFFLINE): $(FHEOBJS) $(SHAREDLIB)
$(CXX) $(CFLAGS) -shared -o $@ $^ $(LDLIBS)
static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT)
$(CXX) $(CFLAGS) -o $@ $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(LIBRELEASE) $(LIBSIMPLEOT) $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl
static/%.x: Machines/%.o $(LIBRELEASE) $(LIBSIMPLEOT) local/lib/libcryptoTools.a local/lib/liblibOTe.a
$(CXX) -o $@ $(CFLAGS) $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(LIBRELEASE) -llibOTe -lcryptoTools $(LIBSIMPLEOT) $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl
static/%.x: ECDSA/%.o ECDSA/P256Element.o $(VMOBJS) $(OT) $(LIBSIMPLEOT)
$(CXX) $(CFLAGS) -o $@ $^ -Wl,-Map=$<.map -Wl,-Bstatic -static-libgcc -static-libstdc++ $(BOOST) $(LDLIBS) -Wl,-Bdynamic -ldl
@@ -201,13 +213,13 @@ replicated-field-party.x: GC/square64.o
brain-party.x: GC/square64.o
malicious-rep-bin-party.x: GC/square64.o
ps-rep-bin-party.x: GC/PostSacriBin.o
semi-bin-party.x: $(OT) GC/SemiPrep.o GC/square64.o
semi-bin-party.x: $(OT) $(GC_SEMI)
tiny-party.x: $(OT)
tinier-party.x: $(OT)
spdz2k-party.x: $(TINIER) $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp))
static/spdz2k-party.x: $(patsubst %.cpp,%.o,$(wildcard Machines/SPDZ2*.cpp))
semi-party.x: $(OT) GC/SemiPrep.o GC/square64.o
semi2k-party.x: $(OT) GC/SemiPrep.o GC/square64.o
semi-party.x: $(OT) $(GC_SEMI)
semi2k-party.x: $(OT) $(GC_SEMI)
hemi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
temi-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
soho-party.x: $(FHEOFFLINE) $(GC_SEMI) $(OT)
@@ -232,15 +244,15 @@ malicious-rep-ring-party.x: Protocols/MalRepRingOptions.o
sy-rep-ring-party.x: Protocols/MalRepRingOptions.o
rep4-ring-party.x: GC/Rep4Secret.o
no-party.x: Protocols/ShareInterface.o
semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) GC/SemiPrep.o
semi-ecdsa-party.x: $(OT) $(LIBSIMPLEOT) $(GC_SEMI)
mascot-ecdsa-party.x: $(OT) $(LIBSIMPLEOT)
fake-spdz-ecdsa-party.x: $(OT) $(LIBSIMPLEOT)
emulate.x: GC/FakeSecret.o
semi-bmr-party.x: GC/SemiPrep.o $(OT)
semi-bmr-party.x: $(GC_SEMI) $(OT)
real-bmr-party.x: $(OT)
paper-example.x: $(VM) $(OT) $(FHEOFFLINE)
binary-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o
mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o GC/SemiPrep.o GC/AtlasSecret.o Machines/Tinier.o
binary-example.x: $(VM) $(OT) GC/PostSacriBin.o $(GC_SEMI) GC/AtlasSecret.o
mixed-example.x: $(VM) $(OT) GC/PostSacriBin.o $(GC_SEMI) GC/AtlasSecret.o Machines/Tinier.o
l2h-example.x: $(VM) $(OT) Machines/Tinier.o
he-example.x: $(FHEOFFLINE)
mascot-offline.x: $(VM) $(TINIER)
@@ -272,14 +284,15 @@ OT/BaseOT.o: deps/SimplestOT_C/ref10/Makefile
deps/SimplestOT_C/ref10/Makefile:
git submodule update --init deps/SimplestOT_C || git clone https://github.com/mkskeller/SimplestOT_C deps/SimplestOT_C
cd deps/SimplestOT_C/ref10; cmake .
cd deps/SimplestOT_C/ref10; PATH=$(CURDIR)/local/bin:$(PATH) cmake .
.PHONY: Programs/Circuits
Programs/Circuits:
git submodule update --init Programs/Circuits
.PHONY: mpir-setup mpir-global mpir
mpir-setup:
.PHONY: mpir-setup mpir-global
mpir-setup: deps/mpir/Makefile
deps/mpir/Makefile:
git submodule update --init deps/mpir || git clone https://github.com/wbhart/mpir deps/mpir
cd deps/mpir; \
autoreconf -i; \
@@ -292,35 +305,45 @@ mpir-global: mpir-setup
$(MAKE) -C deps/mpir
sudo $(MAKE) -C deps/mpir install
mpir: mpir-setup
mpir: local/lib/libmpirxx.so
local/lib/libmpirxx.so: deps/mpir/Makefile
cd deps/mpir; \
./configure --enable-cxx --prefix=$(CURDIR)/local
$(MAKE) -C deps/mpir install
-echo MY_CFLAGS += -I./local/include >> CONFIG.mine
-echo MY_LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib >> CONFIG.mine
deps/libOTe/libOTe:
git submodule update --init --recursive deps/libOTe
-echo MY_CFLAGS += -I./local/include >> CONFIG.mine
-echo MY_LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib >> CONFIG.mine
git submodule update --init --recursive deps/libOTe || git clone --recurse-submodules https://github.com/mkskeller/softspoken-implementation deps/libOTe
boost: deps/libOTe/libOTe
cd deps/libOTe; \
python3 build.py --setup --boost --install=$(CURDIR)/local
OTE_OPTS = -DENABLE_SOFTSPOKEN_OT=ON -DCMAKE_CXX_COMPILER=$(CXX) -DCMAKE_INSTALL_LIBDIR=lib
ifeq ($(USE_SHARED_OTE), 1)
OTE = $(SHARED_OTE)
else
OTE = $(STATIC_OTE)
endif
libote:
rm $(STATIC_OTE) $(SHARED_OTE)* 2>/dev/null; true
$(MAKE) $(OTE)
local/lib/libcryptoTools.a: $(STATIC_OTE)
local/lib/libcryptoTools.so: $(SHARED_OTE)
OT/OTExtensionWithMatrix.o: $(OTE)
ifeq ($(ARM), 1)
libote: deps/libOTe/libOTe
local/lib/liblibOTe.a: deps/libOTe/libOTe
cd deps/libOTe; \
PATH="$(CURDIR)/local/bin:$(PATH)" python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=0 -DENABLE_AVX=OFF -DENABLE_SSE=OFF $(OTE_OPTS)
else
libote: deps/libOTe/libOTe
local/lib/liblibOTe.a: deps/libOTe/libOTe
cd deps/libOTe; \
PATH="$(CURDIR)/local/bin:$(PATH)" python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=0 $(OTE_OPTS)
endif
libote-shared: deps/libOTe/libOTe
$(SHARED_OTE): deps/libOTe/libOTe
cd deps/libOTe; \
python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=1 $(OTE_OPTS)

View File

@@ -69,6 +69,8 @@ public:
{
if (n == -1)
pack(os);
else if (n == 1)
os.store_int<1>(this->a & 1);
else
os.store_int(super::mask(n).get(), DIV_CEIL(n, 8));
}
@@ -77,6 +79,8 @@ public:
{
if (n == -1)
unpack(os);
else if (n == 1)
this->a = os.get_int<1>();
else
this->a = os.get_int(DIV_CEIL(n, 8));
}

View File

@@ -4,6 +4,7 @@
*/
#include "Math/Square.h"
#include "Math/BitVec.h"
template<class U>
Square<U>& Square<U>::sub(const Square<U>& other)
@@ -40,6 +41,16 @@ void Square<U>::bit_sub(const BitVector& bits, int start)
}
}
template<>
inline
void Square<BitVec>::bit_sub(const BitVector& bits, int start)
{
for (int i = 0; i < BitVec::length(); i++)
{
rows[i] -= bits.get_portion<BitVec>(start + i);
}
}
template<class U>
void Square<U>::conditional_add(BitVector& conditions,
Square<U>& other, int offset)

View File

@@ -20,10 +20,10 @@
using namespace std;
#ifndef MAX_MOD_SZ
#if defined(GFP_MOD_SZ) and GFP_MOD_SZ > 10
#if defined(GFP_MOD_SZ) and GFP_MOD_SZ > 11
#define MAX_MOD_SZ GFP_MOD_SZ
#else
#define MAX_MOD_SZ 10
#define MAX_MOD_SZ 11
#endif
#endif

View File

@@ -16,7 +16,8 @@ enum Dtype
DATA_BIT,
DATA_INVERSE,
DATA_DABIT,
N_DTYPE
DATA_MIXED,
N_DTYPE,
};
#endif /* MATH_FIELD_TYPES_H_ */

View File

@@ -70,20 +70,6 @@ inline void mpn_add_fixed_n<2>(mp_limb_t* res, const mp_limb_t* x, const mp_limb
);
}
template <>
inline void mpn_add_fixed_n<3>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{
memcpy(res, y, 3 * sizeof(mp_limb_t));
__asm__ (
"add %3, %0 \n"
"adc %4, %1 \n"
"adc %5, %2 \n"
: "+&r"(res[0]), "+&r"(res[1]), "+r"(res[2])
: "rm"(x[0]), "rm"(x[1]), "rm"(x[2])
: "cc"
);
}
template <>
inline void mpn_add_fixed_n<4>(mp_limb_t* res, const mp_limb_t* x, const mp_limb_t* y)
{

View File

@@ -26,7 +26,7 @@ inline void short_memcpy(void* out, void* in, size_t n_bytes)
X(1) X(2) X(3) X(4) X(5) X(6) X(7) X(8)
#undef X
default:
throw invalid_length("length outside range");
throw invalid_length("length outside range: " + to_string(n_bytes));
}
}

View File

@@ -68,7 +68,7 @@ public:
void set_receiver_inputs(const BitVector& new_inputs)
{
if ((int)new_inputs.size() != nOT)
throw invalid_length();
throw invalid_length("BaseOT");
receiver_inputs = new_inputs;
}

View File

@@ -127,6 +127,9 @@ public:
vector< U, aligned_allocator<U, 32> > squares;
typename U::RowType& operator[](int i)
{ return squares[i / U::n_rows()].rows[i % U::n_rows()]; }
size_t vertical_size();
void resize_vertical(int length)

View File

@@ -19,7 +19,7 @@ template <class U>
bool Matrix<U>::operator==(Matrix<U>& other)
{
if (squares.size() != other.squares.size())
throw invalid_length();
throw invalid_length("Matrix");
for (size_t i = 0; i < squares.size(); i++)
if (not(squares[i] == other.squares[i]))
return false;
@@ -109,7 +109,7 @@ template <class U>
Slice<U>& Slice<U>::rsub(Slice<U>& other)
{
if (bm.squares.size() < other.end)
throw invalid_length();
throw invalid_length("rsub");
for (size_t i = other.start; i < other.end; i++)
bm.squares[i].rsub(other.bm.squares[i]);
return *this;

View File

@@ -18,6 +18,8 @@ class MamaRectangle
typename T::Square squares[N];
public:
typedef GC::NoValue RowType;
static int n_rows() { return T::Square::n_rows(); }
static int n_columns() { return T::Square::n_columns(); }
static int n_row_bytes() { return T::Square::n_row_bytes(); }

View File

@@ -6,6 +6,7 @@
#include "Tools/random.h"
#include "Tools/time-func.h"
#include "Processor/InputTuple.h"
#include "Protocols/dabit.h"
#include "OT/OTTripleSetup.h"
#include "OT/MascotParams.h"
@@ -98,7 +99,8 @@ public:
vector<PlainTriple<open_type, N_AMPLIFY>> preampTriples;
vector<array<open_type, 3>> plainTriples;
vector<open_type> plainBits;
vector<dabit<T>> plainBits;
vector<array<open_type, 3>> mixedTriples;
typename T::MAC_Check* MC;
@@ -114,6 +116,7 @@ public:
void plainTripleRound(int k = 0);
void generatePlainBits();
void generateMixedTriples();
void run_multipliers(MultJob job);

View File

@@ -489,7 +489,8 @@ void OTTripleGenerator<T>::generatePlainBits()
machine.set_passive();
machine.output = false;
int n = multiple_minimum(nPreampTriplesPerLoop, T::open_type::size_in_bits());
int n = multiple_minimum(100 * nPreampTriplesPerLoop,
T::open_type::size_in_bits());
valueBits.resize(1);
valueBits[0].resize(n);
@@ -500,16 +501,52 @@ void OTTripleGenerator<T>::generatePlainBits()
wait_for_multipliers();
plainBits.clear();
typename T::open_type two = 2;
for (int j = 0; j < n; j++)
{
if (j % T::open_type::size_in_bits() < T::open_type::length())
{
plainBits.push_back(valueBits[0].get_bit(j));
plainBits.back() += ot_multipliers[0]->c_output[j] * 2;
bool b = valueBits[0].get_bit(j);
plainBits.push_back({b, b});
plainBits.back().first += ot_multipliers[0]->c_output[j] * two;
}
}
}
template<class T>
void OTTripleGenerator<T>::generateMixedTriples()
{
assert(ot_multipliers.size() == 1);
machine.set_passive();
machine.output = false;
int n = multiple_minimum(100 * nPreampTriplesPerLoop,
T::open_type::size_in_bits());
valueBits.resize(2);
valueBits[0].resize(n);
valueBits[0].randomize(share_prg);
valueBits[1].resize(n * T::open_type::N_BITS);
valueBits[1].randomize(share_prg);
signal_multipliers(DATA_MIXED);
wait_for_multipliers();
mixedTriples.clear();
for (int j = 0; j < n; j++)
{
auto a = valueBits[0].get_bit(j);
auto b = valueBits[1].template get_portion<typename T::open_type>(j);
auto c = a ? b : typename T::open_type();
for (auto& x : ot_multipliers)
c += x->c_output[j];
mixedTriples.push_back({{a, b, c}});
}
}
template<class U>
void OTTripleGenerator<U>::plainTripleRound(int k)
{

View File

@@ -188,7 +188,7 @@ template <class T>
void OTCorrelator<U>::reduce_squares(unsigned int nTriples, vector<T>& output, int start)
{
if (receiverOutputMatrix.squares.size() < nTriples + start)
throw invalid_length();
throw invalid_length("reduce_squares");
output.resize(nTriples);
for (unsigned int j = 0; j < nTriples; j++)
{

View File

@@ -9,7 +9,10 @@
#ifndef USE_KOS
#include "Networking/PlayerCtSocket.h"
osuCrypto::IOService OTExtensionWithMatrix::ios;
#include <libOTe/TwoChooseOne/SoftSpokenOT/TwoOneMalicious.h>
#include <cryptoTools/Network/IOService.h>
osuCrypto::IOService ot_extension_ios;
#endif
#include "OTCorrelator.hpp"
@@ -112,7 +115,7 @@ void OTExtensionWithMatrix::extend(int nOTs_requested, const BitVector& newRecei
resize(nOTs_requested);
if (not channel)
channel = new osuCrypto::Channel(ios, new PlayerCtSocket(*player));
channel = new osuCrypto::Channel(ot_extension_ios, new PlayerCtSocket(*player));
if (player->my_num())
{

View File

@@ -11,8 +11,9 @@
#include "Math/gf2n.h"
#ifndef USE_KOS
#include <libOTe/TwoChooseOne/SoftSpokenOT/TwoOneMalicious.h>
#include <cryptoTools/Network/IOService.h>
namespace osuCrypto {
class Channel;
}
#endif
template <class U>
@@ -57,7 +58,6 @@ class OTExtensionWithMatrix : public OTCorrelator<BitMatrix>
int nsubloops;
#ifndef USE_KOS
static osuCrypto::IOService ios;
osuCrypto::Channel* channel;
#endif

View File

@@ -59,6 +59,7 @@ protected:
void multiplyForTriples();
virtual void multiplyForBits();
virtual void multiplyForMixed();
virtual void multiplyForInputs(MultJob job) = 0;
virtual void after_correlation() = 0;
@@ -174,6 +175,7 @@ class SemiMultiplier : public OTMultiplier<T>
}
void multiplyForBits();
void multiplyForMixed();
void after_correlation();

View File

@@ -128,6 +128,9 @@ void OTMultiplier<T>::multiply()
case DATA_TRIPLE:
multiplyForTriples();
break;
case DATA_MIXED:
multiplyForMixed();
break;
default:
throw not_implemented();
}
@@ -188,6 +191,55 @@ void SemiMultiplier<T>::multiplyForBits()
this->outbox.push({});
}
template<class T>
void SemiMultiplier<T>::multiplyForMixed()
{
auto& rot_ext = this->rot_ext;
typedef Square<BitVec> X;
OTCorrelator<Matrix<X>> otCorrelator(
this->generator.players[this->thread_num], BOTH, true);
BitVector aBits = this->generator.valueBits[0];
rot_ext.extend_correlated(aBits);
auto& baseSenderOutputs = otCorrelator.matrices;
auto& baseReceiverOutput = otCorrelator.senderOutputMatrices[0];
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput);
if (this->generator.get_player().num_players() == 2)
{
c_output.clear();
for (unsigned j = 0; j < aBits.size(); j++)
{
this->generator.valueBits[1].set_portion(j,
BitVec(baseSenderOutputs[0][j] ^ baseSenderOutputs[1][j]));
c_output.push_back(baseReceiverOutput[j] ^ baseSenderOutputs[0][j]);
}
this->outbox.push({});
return;
}
otCorrelator.setup_for_correlation(aBits, baseSenderOutputs,
baseReceiverOutput);
otCorrelator.correlate(0, otCorrelator.receiverOutputMatrix.squares.size(),
this->generator.valueBits[1], false, -1);
c_output.clear();
for (unsigned j = 0; j < aBits.size(); j++)
{
c_output.push_back(
otCorrelator.receiverOutputMatrix[j]
^ otCorrelator.senderOutputMatrices[0][j]);
}
this->outbox.push({});
}
template<class W>
void OTMultiplier<W>::multiplyForTriples()
{
@@ -592,3 +644,9 @@ void OTMultiplier<T>::multiplyForBits()
{
throw runtime_error("bit generation not implemented in this case");
}
template<class T>
void OTMultiplier<T>::multiplyForMixed()
{
throw runtime_error("mixed generation not implemented in this case");
}

View File

@@ -67,6 +67,14 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode)
string threadname;
for (int i=0; i<nprogs; i++)
{ inpf >> threadname;
size_t split = threadname.find(":");
long expected = -1;
if (split != string::npos)
{
expected = atoi(threadname.substr(split + 1).c_str());
threadname = threadname.substr(0, split);
}
string filename = "Programs/Bytecode/" + threadname + ".bc";
bc_filenames.push_back(filename);
if (load_bytecode)
@@ -74,8 +82,11 @@ void BaseMachine::load_schedule(const string& progname, bool load_bytecode)
#ifdef DEBUG_FILES
cerr << "Loading program " << i << " from " << filename << endl;
#endif
load_program(threadname, filename);
long size = load_program(threadname, filename);
if (expected >= 0 and expected != size)
throw runtime_error("broken bytecode file");
}
}
for (auto i : {1, 0, 0})
@@ -99,7 +110,8 @@ void BaseMachine::print_compiler()
cerr << "Compiler: " << compiler << endl;
}
void BaseMachine::load_program(const string& threadname, const string& filename)
size_t BaseMachine::load_program(const string& threadname,
const string& filename)
{
(void)threadname;
(void)filename;

View File

@@ -31,7 +31,8 @@ protected:
string domain;
string relevant_opts;
virtual void load_program(const string& threadname, const string& filename);
virtual size_t load_program(const string& threadname,
const string& filename);
public:
static thread_local int thread_num;

View File

@@ -7,8 +7,7 @@
#include "Protocols/dabit.h"
#include "Math/Setup.h"
#include "GC/BitPrepFiles.h"
#include "Protocols/MascotPrep.hpp"
#include "Tools/benchmarking.h"
template<class T>
Preprocessing<T>* Preprocessing<T>::get_live_prep(SubProcessor<T>* proc,
@@ -44,6 +43,20 @@ Preprocessing<T>* Preprocessing<T>::get_new(
BaseMachine::thread_num);
}
template<class T>
T Preprocessing<T>::get_random_from_inputs(int nplayers)
{
T res;
for (int j = 0; j < nplayers; j++)
{
T tmp;
typename T::open_type _;
this->get_input_no_count(tmp, _, j);
res += tmp;
}
return res;
}
template<class T>
Sub_Data_Files<T>::Sub_Data_Files(const Names& N, DataPositions& usage,
int thread_num) :

View File

@@ -84,6 +84,7 @@ enum
SUBSI = 0x2A,
SUBCFI = 0x2B,
SUBSFI = 0x2C,
PREFIXSUMS = 0x2D,
// Multiplication/division/other arithmetic
MULC = 0x30,
MULM = 0x31,

Some files were not shown because too many files have changed in this diff Show More