"""
|
|
This module contains machine learning functionality. It is work in
|
|
progress, so you must expect things to change. The only tested
|
|
functionality for training is using consective layers.
|
|
This includes logistic regression. It can be run as
|
|
follows::
|
|
|
|
sgd = ml.SGD([ml.Dense(n_examples, n_features, 1),
|
|
ml.Output(n_examples, approx=True)], n_epochs,
|
|
report_loss=True)
|
|
sgd.layers[0].X.input_from(0)
|
|
sgd.layers[1].Y.input_from(1)
|
|
sgd.reset()
|
|
sgd.run()
|
|
|
|
This loads measurements from party 0 and labels (0/1) from party
|
|
1. After running, the model is stored in :py:obj:`sgd.layers[0].W` and
|
|
:py:obj:`sgd.layers[1].b`. The :py:obj:`approx` parameter determines
|
|
whether to use an approximate sigmoid function. Setting it to 5 uses
|
|
a five-piece approximation instead of a three-piece one.
|
|
|
|
A simple network for MNIST using two dense layers can be trained as
|
|
follows::
|
|
|
|
sgd = ml.SGD([ml.Dense(60000, 784, 128, activation='relu'),
|
|
ml.Dense(60000, 128, 10),
|
|
ml.MultiOutput(60000, 10)], n_epochs,
|
|
report_loss=True)
|
|
sgd.layers[0].X.input_from(0)
|
|
sgd.layers[1].Y.input_from(1)
|
|
sgd.reset()
|
|
sgd.run()
|
|
|
|
See `this repository <https://github.com/csiro-mlai/mnist-mpc>`_
|
|
for scripts importing MNIST training data and further examples.
|
|
|
|
Inference can be run as follows::
|
|
|
|
data = sfix.Matrix(n_test, n_features)
|
|
data.input_from(0)
|
|
res = sgd.eval(data)
|
|
print_ln('Results: %s', [x.reveal() for x in res])
|
|
|
|
For inference/classification, this module offers the layers necessary
|
|
for neural networks such as DenseNet, ResNet, and SqueezeNet. A
|
|
minimal example using input from player 0 and model from player 1
|
|
looks as follows::
|
|
|
|
graph = Optimizer()
|
|
graph.layers = layers
|
|
layers[0].X.input_from(0)
|
|
for layer in layers:
|
|
layer.input_from(1)
|
|
graph.forward(1)
|
|
res = layers[-1].Y
|
|
|
|
See the `readme <https://github.com/data61/MP-SPDZ/#tensorflow-inference>`_ for
|
|
an example of how to run MP-SPDZ on TensorFlow graphs.
|
|
"""
|
|
|
|
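
# Illustrative addition (not from the original docstring): the five-piece
# sigmoid approximation mentioned above is selected by passing 5 instead of
# True, e.g. a sketch of the logistic-regression setup would be:
#
#   sgd = ml.SGD([ml.Dense(n_examples, n_features, 1),
#                 ml.Output(n_examples, approx=5)], n_epochs,
#                report_loss=True)
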
import math
import re

from Compiler import mpc_math, util
from Compiler.types import *
from Compiler.types import _unreduced_squant
from Compiler.library import *
from Compiler.util import is_zero, tree_reduce
from Compiler.comparison import CarryOutRawLE
from Compiler.GC.types import sbitint
from functools import reduce

def log_e(x):
    return mpc_math.log_fx(x, math.e)

def exp(x):
    return mpc_math.pow_fx(math.e, x)

def get_limit(x):
    exp_limit = 2 ** (x.k - x.f - 1)
    return math.log(exp_limit)

def sanitize(x, raw, lower, upper):
    limit = get_limit(x)
    res = (x > limit).if_else(upper, raw)
    return (x < -limit).if_else(lower, res)

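# Note (added for clarity, assuming the default sfix precision of k=31 and
# f=16): get_limit gives log(2 ** (31 - 16 - 1)) ~= 9.7, so sanitize clamps
# sigmoid-style outputs to `upper` for x > 9.7 and to `lower` for x < -9.7,
# where exp(-x) would otherwise overflow the fixed-point range.
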
def sigmoid(x):
    """ Sigmoid function.

    :param x: sfix """
    return sigmoid_from_e_x(x, exp(-x))

def sigmoid_from_e_x(x, e_x):
    return sanitize(x, 1 / (1 + e_x), 0, 1)

def sigmoid_prime(x):
    """ Sigmoid derivative.

    :param x: sfix """
    sx = sigmoid(x)
    return sx * (1 - sx)

@vectorize
def approx_sigmoid(x, n=3):
    """ Piece-wise approximate sigmoid as in
    `Dahl et al. <https://arxiv.org/abs/1810.08130>`_

    :param x: input
    :param n: number of pieces, 3 (default) or 5
    """
    if n == 5:
        cuts = [-5, -2.5, 2.5, 5]
        le = [0] + [x <= cut for cut in cuts] + [1]
        select = [le[i + 1] - le[i] for i in range(5)]
        outputs = [cfix(10 ** -4),
                   0.02776 * x + 0.145,
                   0.17 * x + 0.5,
                   0.02776 * x + 0.85498,
                   cfix(1 - 10 ** -4)]
        return sum(a * b for a, b in zip(select, outputs))
    else:
        a = x < -0.5
        b = x > 0.5
        return a.if_else(0, b.if_else(1, 0.5 + x))

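# Clear-value sketch of the piece-wise approximation above (for illustration
# only; the function itself operates on secret sfix values):
#
#   def approx_sigmoid_clear(x, n=3):
#       if n == 5:
#           if x <= -5:   return 1e-4
#           if x <= -2.5: return 0.02776 * x + 0.145
#           if x <= 2.5:  return 0.17 * x + 0.5
#           if x <= 5:    return 0.02776 * x + 0.85498
#           return 1 - 1e-4
#       return min(max(0.5 + x, 0), 1)
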
def lse_0_from_e_x(x, e_x):
    return sanitize(-x, log_e(1 + e_x), x + 2 ** -x.f, 0)

def lse_0(x):
    return lse_0_from_e_x(x, exp(x))

def approx_lse_0(x, n=3):
    assert n != 5
    a = x < -0.5
    b = x > 0.5
    return a.if_else(0, b.if_else(x, 0.5 * (x + 0.5) ** 2)) - x

def relu_prime(x):
    """ ReLU derivative. """
    return (0 <= x)

def relu(x):
    """ ReLU function (maximum of input and zero). """
    return (0 < x).if_else(x, 0)

def argmax(x):
    """ Compute index of maximum element.

    :param x: iterable
    :returns: sint
    """
    def op(a, b):
        comp = (a[1] > b[1])
        return comp.if_else(a[0], b[0]), comp.if_else(a[1], b[1])
    return tree_reduce(op, enumerate(x))[0]

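# Clear-value sketch of the oblivious argmax above: (index, value) pairs are
# reduced pairwise, keeping the index of the larger value, e.g.
#
#   from functools import reduce
#   def argmax_clear(x):
#       return reduce(lambda a, b: a if a[1] > b[1] else b, enumerate(x))[0]
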
report_progress = False

def progress(x):
    if report_progress:
        print_ln(x)
        time()

def set_n_threads(n_threads):
    Layer.n_threads = n_threads
    Optimizer.n_threads = n_threads

def _no_mem_warnings(function):
    def wrapper(*args, **kwargs):
        get_program().warn_about_mem.append(False)
        res = function(*args, **kwargs)
        get_program().warn_about_mem.pop()
        return res
    return wrapper

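# Usage sketch (assumed, not from the original source): multi-threading for
# all layers and optimizers can be enabled before building the network, e.g.
#
#   ml.set_n_threads(8)
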
class Tensor(MultiArray):
|
|
def __init__(self, *args, **kwargs):
|
|
kwargs['alloc'] = False
|
|
super(Tensor, self).__init__(*args, **kwargs)
|
|
|
|
def input_from(self, *args, **kwargs):
|
|
self.alloc()
|
|
super(Tensor, self).input_from(*args, **kwargs)
|
|
|
|
def __getitem__(self, *args):
|
|
self.alloc()
|
|
return super(Tensor, self).__getitem__(*args)
|
|
|
|
def assign_vector(self, *args):
|
|
self.alloc()
|
|
return super(Tensor, self).assign_vector(*args)
|
|
|
|
def assign_vector_by_indices(self, *args):
|
|
self.alloc()
|
|
return super(Tensor, self).assign_vector_by_indices(*args)
|
|
|
|
class Layer:
|
|
n_threads = 1
|
|
inputs = []
|
|
input_bias = True
|
|
thetas = lambda self: ()
|
|
debug_output = False
|
|
back_batch_size = 128
|
|
|
|
@property
|
|
def shape(self):
|
|
return list(self._Y.sizes)
|
|
|
|
@property
|
|
def X(self):
|
|
self._X.alloc()
|
|
return self._X
|
|
|
|
@X.setter
|
|
def X(self, value):
|
|
self._X = value
|
|
|
|
@property
|
|
def Y(self):
|
|
self._Y.alloc()
|
|
return self._Y
|
|
|
|
@Y.setter
|
|
def Y(self, value):
|
|
self._Y = value
|
|
|
|
def forward(self, batch=None, training=None):
|
|
if batch is None:
|
|
batch = Array.create_from(regint(0))
|
|
self._forward(batch)
|
|
|
|
def __str__(self):
|
|
return type(self).__name__ + str(self._Y.sizes)
|
|
|
|
class NoVariableLayer(Layer):
|
|
input_from = lambda *args, **kwargs: None
|
|
output_weights = lambda *args: None
|
|
|
|
nablas = lambda self: ()
|
|
reset = lambda self: None
|
|
|
|
class Output(NoVariableLayer):
|
|
""" Fixed-point logistic regression output layer.
|
|
|
|
:param N: number of examples
|
|
:param approx: :py:obj:`False` (default) or parameter for :py:obj:`approx_sigmoid`
|
|
"""
|
|
n_outputs = 2
|
|
|
|
@classmethod
|
|
def from_args(cls, N, program):
|
|
res = cls(N, approx='approx' in program.args)
|
|
res.compute_loss = not 'no_loss' in program.args
|
|
return res
|
|
|
|
def __init__(self, N, debug=False, approx=False):
|
|
self.N = N
|
|
self.X = sfix.Array(N)
|
|
self.Y = sfix.Array(N)
|
|
self.nabla_X = sfix.Array(N)
|
|
self.l = MemValue(sfix(-1))
|
|
self.e_x = sfix.Array(N)
|
|
self.debug = debug
|
|
self.weights = None
|
|
self.approx = approx
|
|
self.compute_loss = True
|
|
|
|
def divisor(self, divisor, size):
|
|
return cfix(1.0 / divisor, size=size)
|
|
|
|
def _forward(self, batch):
|
|
if self.approx == 5:
|
|
self.l.write(999)
|
|
return
|
|
N = len(batch)
|
|
lse = sfix.Array(N)
|
|
@multithread(self.n_threads, N)
|
|
def _(base, size):
|
|
x = self.X.get_vector(base, size)
|
|
y = self.Y.get(batch.get_vector(base, size))
|
|
if self.approx:
|
|
if self.compute_loss:
|
|
lse.assign(approx_lse_0(x, self.approx) + x * (1 - y), base)
|
|
return
|
|
e_x = exp(-x)
|
|
self.e_x.assign(e_x, base)
|
|
if self.compute_loss:
|
|
lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base)
|
|
self.l.write(sum(lse) * \
|
|
self.divisor(N, 1))
|
|
|
|
def eval(self, size, base=0):
|
|
if self.approx:
|
|
return approx_sigmoid(self.X.get_vector(base, size), self.approx)
|
|
else:
|
|
return sigmoid_from_e_x(self.X.get_vector(base, size),
|
|
self.e_x.get_vector(base, size))
|
|
|
|
def backward(self, batch):
|
|
N = len(batch)
|
|
@multithread(self.n_threads, N)
|
|
def _(base, size):
|
|
diff = self.eval(size, base) - \
|
|
self.Y.get(batch.get_vector(base, size))
|
|
assert sfix.f == cfix.f
|
|
if self.weights is not None:
|
|
assert N == len(self.weights)
|
|
diff *= self.weights.get_vector(base, size)
|
|
assert self.weight_total == N
|
|
self.nabla_X.assign(diff, base)
|
|
# @for_range_opt(len(diff))
|
|
# def _(i):
|
|
# self.nabla_X[i] = self.nabla_X[i] * self.weights[i]
|
|
if self.debug_output:
|
|
print_ln('sigmoid X %s', self.X.reveal_nested())
|
|
print_ln('sigmoid nabla %s', self.nabla_X.reveal_nested())
|
|
print_ln('batch %s', batch.reveal_nested())
|
|
|
|
def set_weights(self, weights):
|
|
self.weights = cfix.Array(len(weights))
|
|
self.weights.assign(weights)
|
|
self.weight_total = sum(weights)
|
|
|
|
def average_loss(self, N):
|
|
return self.l.reveal()
|
|
|
|
def reveal_correctness(self, n=None, Y=None, debug=False):
|
|
if n is None:
|
|
n = self.X.sizes[0]
|
|
if Y is None:
|
|
Y = self.Y
|
|
n_correct = MemValue(0)
|
|
n_printed = MemValue(0)
|
|
@for_range_opt(n)
|
|
def _(i):
|
|
truth = Y[i].reveal()
|
|
b = self.X[i].reveal()
|
|
if debug:
|
|
nabla = self.nabla_X[i].reveal()
|
|
guess = b > 0
|
|
correct = truth == guess
|
|
n_correct.iadd(correct)
|
|
if debug:
|
|
to_print = (1 - correct) * (n_printed < 10)
|
|
n_printed.iadd(to_print)
|
|
print_ln_if(to_print, '%s: %s %s %s %s',
|
|
i, truth, guess, b, nabla)
|
|
return n_correct
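# Usage sketch (an assumption, not from the original source): after training,
# the output layer can report accuracy on revealed data, e.g.
#
#   n_correct = sgd.layers[-1].reveal_correctness(n_test, Y_test)
#
# where Y_test is an sfix.Array of 0/1 labels and the return value is a
# MemValue counting examples whose revealed score X[i] > 0 matches the label.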
|
|
|
|
class MultiOutputBase(NoVariableLayer):
|
|
def __init__(self, N, d_out, approx=False, debug=False):
|
|
self.X = sfix.Matrix(N, d_out)
|
|
self.Y = sint.Matrix(N, d_out)
|
|
self.nabla_X = sfix.Matrix(N, d_out)
|
|
self.l = MemValue(sfix(-1))
|
|
self.losses = sfix.Array(N)
|
|
self.approx = None
|
|
self.N = N
|
|
self.d_out = d_out
|
|
self.compute_loss = True
|
|
|
|
def eval(self, N):
|
|
d_out = self.X.sizes[1]
|
|
res = sfix.Matrix(N, d_out)
|
|
res.assign_vector(self.X.get_part_vector(0, N))
|
|
return res
|
|
|
|
def average_loss(self, N):
|
|
return sum(self.losses.get_vector(0, N)).reveal() / N
|
|
|
|
def reveal_correctness(self, n=None, Y=None, debug=False):
|
|
if n is None:
|
|
n = self.X.sizes[0]
|
|
if Y is None:
|
|
Y = self.Y
|
|
n_printed = MemValue(0)
|
|
assert n <= len(self.X)
|
|
assert n <= len(Y)
|
|
Y.address = MemValue.if_necessary(Y.address)
|
|
@map_sum(None if debug else self.n_threads, None, n, 1, regint)
|
|
def _(i):
|
|
a = Y[i].reveal_list()
|
|
b = self.X[i].reveal_list()
|
|
if debug:
|
|
loss = self.losses[i].reveal()
|
|
exp = self.get_extra_debugging(i)
|
|
nabla = self.nabla_X[i].reveal_list()
|
|
truth = argmax(a)
|
|
guess = argmax(b)
|
|
correct = truth == guess
|
|
if debug:
|
|
to_print = (1 - correct) * (n_printed < 10)
|
|
n_printed.iadd(to_print)
|
|
print_ln_if(to_print, '%s: %s %s %s %s %s %s',
|
|
i, truth, guess, loss, b, exp, nabla)
|
|
return correct
|
|
return _()
|
|
|
|
@property
|
|
def n_outputs(self):
|
|
return self.d_out
|
|
|
|
def get_extra_debugging(self, i):
|
|
return ''
|
|
|
|
@staticmethod
|
|
def from_args(program, N, n_output):
|
|
if 'relu_out' in program.args:
|
|
res = ReluMultiOutput(N, n_output)
|
|
else:
|
|
res = MultiOutput(N, n_output, approx='approx' in program.args)
|
|
res.cheaper_loss = 'mse' in program.args
|
|
res.compute_loss = not 'no_loss' in program.args
|
|
for arg in program.args:
|
|
m = re.match('approx=(.*)', arg)
|
|
if m:
|
|
res.approx = float(m.group(1))
|
|
return res
|
|
|
|
class MultiOutput(MultiOutputBase):
|
|
"""
|
|
Output layer for multi-class classification with softmax and cross entropy.
|
|
|
|
:param N: number of examples
|
|
:param d_out: number of classes
|
|
:param approx: use ReLU division instead of softmax for the loss
|
|
"""
|
|
def __init__(self, N, d_out, approx=False, debug=False):
|
|
MultiOutputBase.__init__(self, N, d_out)
|
|
self.exp = sfix.Matrix(N, d_out)
|
|
self.approx = approx
|
|
self.positives = sint.Matrix(N, d_out)
|
|
self.relus = sfix.Matrix(N, d_out)
|
|
self.cheaper_loss = False
|
|
self.debug = debug
|
|
self.true_X = sfix.Array(N)
|
|
|
|
def _forward(self, batch):
|
|
N = len(batch)
|
|
d_out = self.X.sizes[1]
|
|
tmp = self.losses
|
|
@for_range_opt_multithread(self.n_threads, N)
|
|
def _(i):
|
|
if self.approx:
|
|
if self.cheaper_loss or isinstance(self.approx, float):
|
|
limit = 0
|
|
else:
|
|
limit = 0.1
|
|
positives = self.X[i].get_vector() > limit
|
|
relus = positives.if_else(self.X[i].get_vector(), 0)
|
|
self.positives[i].assign_vector(positives)
|
|
self.relus[i].assign_vector(relus)
|
|
if self.compute_loss:
|
|
if self.cheaper_loss:
|
|
s = sum(relus)
|
|
tmp[i] = sum((self.Y[batch[i]][j] * s - relus[j]) ** 2
|
|
for j in range(d_out)) / s ** 2 * 0.5
|
|
else:
|
|
div = relus / sum(relus).expand_to_vector(d_out)
|
|
self.losses[i] = -sfix.dot_product(
|
|
self.Y[batch[i]].get_vector(), log_e(div))
|
|
else:
|
|
m = util.max(self.X[i])
|
|
mv = m.expand_to_vector(d_out)
|
|
x = self.X[i].get_vector()
|
|
e = (x - mv > -get_limit(x)).if_else(exp(x - mv), 0)
|
|
self.exp[i].assign_vector(e)
|
|
if self.compute_loss:
|
|
true_X = sfix.dot_product(self.Y[batch[i]], self.X[i])
|
|
tmp[i] = m + log_e(sum(e)) - true_X
|
|
self.true_X[i] = true_X
|
|
self.l.write(sum(tmp.get_vector(0, N)) / N)
|
|
|
|
def eval(self, N):
|
|
d_out = self.X.sizes[1]
|
|
res = sfix.Matrix(N, d_out)
|
|
if self.approx:
|
|
@for_range_opt_multithread(self.n_threads, N)
|
|
def _(i):
|
|
relus = (self.X[i].get_vector() > 0).if_else(
|
|
self.X[i].get_vector(), 0)
|
|
res[i].assign_vector(relus / sum(relus).expand_to_vector(d_out))
|
|
return res
|
|
@for_range_opt_multithread(self.n_threads, N)
|
|
def _(i):
|
|
x = self.X[i].get_vector() - \
|
|
util.max(self.X[i].get_vector()).expand_to_vector(d_out)
|
|
e = exp(x)
|
|
res[i].assign_vector(e / sum(e).expand_to_vector(d_out))
|
|
return res
|
|
|
|
def backward(self, batch):
|
|
d_out = self.X.sizes[1]
|
|
if self.approx:
|
|
@for_range_opt_multithread(self.n_threads, len(batch))
|
|
def _(i):
|
|
if self.cheaper_loss:
|
|
s = sum(self.relus[i])
|
|
ss = s * s * s
|
|
inv = 1 / ss
|
|
@for_range_opt(d_out)
|
|
def _(j):
|
|
res = 0
|
|
for k in range(d_out):
|
|
relu = self.relus[i][k]
|
|
summand = relu - self.Y[batch[i]][k] * s
|
|
summand *= (sfix.from_sint(j == k) - relu)
|
|
res += summand
|
|
fallback = -self.Y[batch[i]][j]
|
|
res *= inv
|
|
self.nabla_X[i][j] = self.positives[i][j].if_else(res, fallback)
|
|
return
|
|
relus = self.relus[i].get_vector()
|
|
if isinstance(self.approx, float):
|
|
relus += self.approx
|
|
positives = self.positives[i].get_vector()
|
|
inv = (1 / sum(relus)).expand_to_vector(d_out)
|
|
truths = self.Y[batch[i]].get_vector()
|
|
raw = truths / relus - inv
|
|
self.nabla_X[i] = -positives.if_else(raw, truths)
|
|
self.maybe_debug_backward(batch)
|
|
return
|
|
@for_range_opt_multithread(self.n_threads, len(batch))
|
|
def _(i):
|
|
for j in range(d_out):
|
|
dividend = self.exp[i][j]
|
|
divisor = sum(self.exp[i])
|
|
div = (divisor > 0.1).if_else(dividend / divisor, 0)
|
|
self.nabla_X[i][j] = (-self.Y[batch[i]][j] + div)
|
|
self.maybe_debug_backward(batch)
|
|
|
|
def maybe_debug_backward(self, batch):
|
|
if self.debug:
|
|
@for_range(len(batch))
|
|
def _(i):
|
|
check = 0
|
|
for j in range(self.X.sizes[1]):
|
|
to_check = self.nabla_X[i][j].reveal()
|
|
check += (to_check > len(batch)) + (to_check < -len(batch))
|
|
print_ln_if(check, 'X %s', self.X[i].reveal_nested())
|
|
print_ln_if(check, 'exp %s', self.exp[i].reveal_nested())
|
|
print_ln_if(check, 'nabla X %s',
|
|
self.nabla_X[i].reveal_nested())
|
|
|
|
def get_extra_debugging(self, i):
|
|
if self.approx:
|
|
return self.relus[i].reveal_list()
|
|
else:
|
|
return self.exp[i].reveal_list()
|
|
|
|
class ReluMultiOutput(MultiOutputBase):
|
|
"""
|
|
Output layer for multi-class classification with back-propagation
|
|
based on ReLU division.
|
|
|
|
:param N: number of examples
|
|
:param d_out: number of classes
|
|
"""
|
|
def forward(self, batch, training=None):
|
|
self.l.write(999)
|
|
|
|
def backward(self, batch):
|
|
N = len(batch)
|
|
d_out = self.X.sizes[1]
|
|
relus = sfix.Matrix(N, d_out)
|
|
@for_range_opt_multithread(self.n_threads, len(batch))
|
|
def _(i):
|
|
positives = self.X[i].get_vector() > 0
|
|
relus = positives.if_else(self.X[i].get_vector(), 0)
|
|
s = sum(relus)
|
|
inv = 1 / s
|
|
prod = relus * inv
|
|
res = prod - self.Y[batch[i]].get_vector()
|
|
self.nabla_X[i].assign_vector(res)
|
|
|
|
class DenseBase(Layer):
|
|
thetas = lambda self: (self.W, self.b)
|
|
nablas = lambda self: (self.nabla_W, self.nabla_b)
|
|
|
|
def output_weights(self):
|
|
print_ln('%s', self.W.reveal_nested())
|
|
print_ln('%s', self.b.reveal_nested())
|
|
|
|
def backward_params(self, f_schur_Y, batch):
|
|
N = len(batch)
|
|
tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
|
|
|
|
@multithread(self.n_threads, self.d_in)
|
|
def _(base, size):
|
|
A = sfix.Matrix(self.N, self.d_out, address=f_schur_Y.address)
|
|
B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
|
|
mp = B.direct_trans_mul(A, reduce=False,
|
|
indices=(regint.inc(size, base),
|
|
batch.get_vector(),
|
|
regint.inc(N),
|
|
regint.inc(self.d_out)))
|
|
tmp.assign_part_vector(mp, base)
|
|
|
|
progress('nabla W (matmul)')
|
|
|
|
if self.d_in * self.d_out < 200000:
|
|
print('reduce at once')
|
|
@multithread(self.n_threads, self.d_in * self.d_out)
|
|
def _(base, size):
|
|
self.nabla_W.assign_vector(
|
|
tmp.get_vector(base, size).reduce_after_mul(), base=base)
|
|
else:
|
|
@for_range_opt(self.d_in)
|
|
def _(i):
|
|
self.nabla_W[i] = tmp[i].get_vector().reduce_after_mul()
|
|
|
|
progress('nabla W')
|
|
|
|
self.nabla_b.assign_vector(sum(sum(f_schur_Y[k][j].get_vector()
|
|
for k in range(N))
|
|
for j in range(self.d)))
|
|
|
|
progress('nabla b')
|
|
|
|
if self.debug_output:
|
|
print_ln('dense nabla Y %s', self.nabla_Y.reveal_nested())
|
|
print_ln('dense W %s', self.W.reveal_nested())
|
|
print_ln('dense nabla X %s', self.nabla_X.reveal_nested())
|
|
if self.debug:
|
|
limit = N * self.debug
|
|
@for_range_opt(self.d_in)
|
|
def _(i):
|
|
@for_range_opt(self.d_out)
|
|
def _(j):
|
|
to_check = self.nabla_W[i][j].reveal()
|
|
check = sum(to_check > limit) + sum(to_check < -limit)
|
|
@if_(check)
|
|
def _():
|
|
print_ln('nabla W %s %s %s: %s', i, j, self.W.sizes, to_check)
|
|
print_ln('Y %s', [f_schur_Y[k][0][j].reveal()
|
|
for k in range(N)])
|
|
print_ln('X %s', [self.X[k][0][i].reveal()
|
|
for k in range(N)])
|
|
@for_range_opt(self.d_out)
|
|
def _(j):
|
|
to_check = self.nabla_b[j].reveal()
|
|
check = sum(to_check > limit) + sum(to_check < -limit)
|
|
@if_(check)
|
|
def _():
|
|
print_ln('nabla b %s %s: %s', j, len(self.b), to_check)
|
|
print_ln('Y %s', [f_schur_Y[k][0][j].reveal()
|
|
for k in range(N)])
|
|
@for_range_opt(len(batch))
|
|
def _(i):
|
|
to_check = self.nabla_X[i].get_vector().reveal()
|
|
check = sum(to_check > limit) + sum(to_check < -limit)
|
|
@if_(check)
|
|
def _():
|
|
print_ln('X %s %s', i, self.X[i].reveal_nested())
|
|
print_ln('Y %s %s', i, f_schur_Y[i].reveal_nested())
|
|
|
|
class Dense(DenseBase):
|
|
""" Fixed-point dense (matrix multiplication) layer.
|
|
|
|
:param N: number of examples
|
|
:param d_in: input dimension
|
|
:param d_out: output dimension
|
|
"""
|
|
def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):
|
|
if activation == 'id':
|
|
self.activation_layer = None
|
|
elif activation == 'relu':
|
|
self.activation_layer = Relu([N, d, d_out])
|
|
elif activation == 'square':
|
|
self.activation_layer = Square([N, d, d_out])
|
|
else:
|
|
raise CompilerError('activation not supported: %s', activation)
|
|
|
|
self.N = N
|
|
self.d_in = d_in
|
|
self.d_out = d_out
|
|
self.d = d
|
|
|
|
self.X = MultiArray([N, d, d_in], sfix)
|
|
self.Y = MultiArray([N, d, d_out], sfix)
|
|
self.W = Tensor([d_in, d_out], sfix)
|
|
self.b = sfix.Array(d_out)
|
|
|
|
back_N = min(N, self.back_batch_size)
|
|
self.nabla_Y = MultiArray([back_N, d, d_out], sfix)
|
|
self.nabla_X = MultiArray([back_N, d, d_in], sfix)
|
|
self.nabla_W = sfix.Matrix(d_in, d_out)
|
|
self.nabla_b = sfix.Array(d_out)
|
|
|
|
self.debug = debug
|
|
|
|
l = self.activation_layer
|
|
if l:
|
|
self.f_input = l.X
|
|
l.Y = self.Y
|
|
l.nabla_Y = self.nabla_Y
|
|
else:
|
|
self.f_input = self.Y
|
|
|
|
def reset(self):
|
|
d_in = self.d_in
|
|
d_out = self.d_out
|
|
r = math.sqrt(6.0 / (d_in + d_out))
|
|
print('Initializing dense weights in [%f,%f]' % (-r, r))
|
|
self.W.assign_vector(sfix.get_random(-r, r, size=self.W.total_size()))
|
|
self.b.assign_all(0)
|
|
|
|
def input_from(self, player, raw=False):
|
|
self.W.input_from(player, raw=raw)
|
|
if self.input_bias:
|
|
self.b.input_from(player, raw=raw)
|
|
|
|
def compute_f_input(self, batch):
|
|
N = len(batch)
|
|
assert self.d == 1
|
|
if self.input_bias:
|
|
prod = MultiArray([N, self.d, self.d_out], sfix)
|
|
else:
|
|
prod = self.f_input
|
|
max_size = program.Program.prog.budget // self.d_out
|
|
@multithread(self.n_threads, N, max_size)
|
|
def _(base, size):
|
|
X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address)
|
|
prod.assign_part_vector(
|
|
X_sub.direct_mul(self.W, indices=(
|
|
batch.get_vector(base, size), regint.inc(self.d_in),
|
|
regint.inc(self.d_in), regint.inc(self.d_out))), base)
|
|
|
|
if self.input_bias:
|
|
if self.d_out == 1:
|
|
@multithread(self.n_threads, N)
|
|
def _(base, size):
|
|
v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size)
|
|
self.f_input.assign_vector(v, base)
|
|
else:
|
|
@for_range_multithread(self.n_threads, 100, N)
|
|
def _(i):
|
|
v = prod[i].get_vector() + self.b.get_vector()
|
|
self.f_input[i].assign_vector(v)
|
|
progress('f input')
|
|
|
|
def _forward(self, batch=None):
|
|
if batch is None:
|
|
batch = regint.Array(self.N)
|
|
batch.assign(regint.inc(self.N))
|
|
self.compute_f_input(batch=batch)
|
|
if self.activation_layer:
|
|
self.activation_layer.forward(batch)
|
|
if self.debug_output:
|
|
print_ln('dense X %s', self.X.reveal_nested())
|
|
print_ln('dense W %s', self.W.reveal_nested())
|
|
print_ln('dense b %s', self.b.reveal_nested())
|
|
print_ln('dense Y %s', self.Y.reveal_nested())
|
|
if self.debug:
|
|
limit = self.debug
|
|
@for_range_opt(len(batch))
|
|
def _(i):
|
|
@for_range_opt(self.d_out)
|
|
def _(j):
|
|
to_check = self.Y[i][0][j].reveal()
|
|
check = to_check > limit
|
|
@if_(check)
|
|
def _():
|
|
print_ln('dense Y %s %s %s %s', i, j, self.W.sizes, to_check)
|
|
print_ln('X %s', self.X[i].reveal_nested())
|
|
print_ln('W %s',
|
|
[self.W[k][j].reveal() for k in range(self.d_in)])
|
|
|
|
def backward(self, compute_nabla_X=True, batch=None):
|
|
N = len(batch)
|
|
d = self.d
|
|
d_out = self.d_out
|
|
X = self.X
|
|
Y = self.Y
|
|
W = self.W
|
|
b = self.b
|
|
nabla_X = self.nabla_X
|
|
nabla_Y = self.nabla_Y
|
|
nabla_W = self.nabla_W
|
|
nabla_b = self.nabla_b
|
|
|
|
if self.activation_layer:
|
|
self.activation_layer.backward(batch)
|
|
f_schur_Y = self.activation_layer.nabla_X
|
|
else:
|
|
f_schur_Y = nabla_Y
|
|
|
|
if compute_nabla_X:
|
|
@multithread(self.n_threads, N)
|
|
def _(base, size):
|
|
B = sfix.Matrix(N, d_out, address=f_schur_Y.address)
|
|
nabla_X.assign_part_vector(
|
|
B.direct_mul_trans(W, indices=(regint.inc(size, base),
|
|
regint.inc(self.d_out),
|
|
regint.inc(self.d_out),
|
|
regint.inc(self.d_in))),
|
|
base)
|
|
|
|
progress('nabla X')
|
|
|
|
self.backward_params(f_schur_Y, batch=batch)
|
|
|
|
class QuantizedDense(DenseBase):
|
|
def __init__(self, N, d_in, d_out):
|
|
self.N = N
|
|
self.d_in = d_in
|
|
self.d_out = d_out
|
|
self.d = 1
|
|
self.H = math.sqrt(1.5 / (d_in + d_out))
|
|
|
|
self.W = sfix.Matrix(d_in, d_out)
|
|
self.nabla_W = self.W.same_shape()
|
|
self.T = sint.Matrix(d_in, d_out)
|
|
self.b = sfix.Array(d_out)
|
|
self.nabla_b = self.b.same_shape()
|
|
|
|
self.X = MultiArray([N, 1, d_in], sfix)
|
|
self.Y = MultiArray([N, 1, d_out], sfix)
|
|
self.nabla_Y = self.Y.same_shape()
|
|
|
|
def reset(self):
|
|
@for_range(self.d_in)
|
|
def _(i):
|
|
@for_range(self.d_out)
|
|
def _(j):
|
|
self.W[i][j] = sfix.get_random(-1, 1)
|
|
self.b.assign_all(0)
|
|
|
|
def _forward(self):
|
|
@for_range_opt(self.d_in)
|
|
def _(i):
|
|
@for_range_opt(self.d_out)
|
|
def _(j):
|
|
over = self.W[i][j] > 0.5
|
|
under = self.W[i][j] < -0.5
|
|
self.T[i][j] = over.if_else(1, under.if_else(-1, 0))
|
|
over = self.W[i][j] > 1
|
|
under = self.W[i][j] < -1
|
|
self.W[i][j] = over.if_else(1, under.if_else(-1, self.W[i][j]))
|
|
@for_range_opt(self.N)
|
|
def _(i):
|
|
assert self.d_out == 1
|
|
self.Y[i][0][0] = self.b[0] + self.H * sfix._new(
|
|
sint.dot_product([self.T[j][0] for j in range(self.d_in)],
|
|
[self.X[i][0][j].v for j in range(self.d_in)]))
|
|
|
|
def backward(self, compute_nabla_X=False):
|
|
assert not compute_nabla_X
|
|
self.backward_params(self.nabla_Y)
|
|
|
|
class Dropout(NoVariableLayer):
|
|
""" Dropout layer.
|
|
|
|
:param N: number of examples
|
|
:param d1: total dimension
|
|
:param alpha: dropout probability (must be a power of two, default 0.5)
|
|
"""
|
|
def __init__(self, N, d1, d2=1, alpha=0.5):
|
|
self.N = N
|
|
self.d1 = d1
|
|
self.d2 = d2
|
|
self.X = MultiArray([N, d1, d2], sfix)
|
|
self.Y = MultiArray([N, d1, d2], sfix)
|
|
self.nabla_Y = MultiArray([N, d1, d2], sfix)
|
|
self.nabla_X = MultiArray([N, d1, d2], sfix)
|
|
self.alpha = alpha
|
|
self.B = MultiArray([N, d1, d2], sint)
|
|
|
|
def forward(self, batch, training=False):
|
|
if training:
|
|
n_bits = -math.log(self.alpha, 2)
|
|
assert n_bits == int(n_bits)
|
|
n_bits = int(n_bits)
|
|
@for_range_opt_multithread(self.n_threads, len(batch))
|
|
def _(i):
|
|
size = self.d1 * self.d2
|
|
self.B[i].assign_vector(util.tree_reduce(
|
|
util.or_op, (sint.get_random_bit(size=size)
|
|
for i in range(n_bits))))
|
|
@for_range_opt_multithread(self.n_threads, len(batch))
|
|
def _(i):
|
|
self.Y[i].assign_vector(1 / (1 - self.alpha) *
|
|
self.X[batch[i]].get_vector() * self.B[i].get_vector())
|
|
else:
|
|
@for_range(len(batch))
|
|
def _(i):
|
|
self.Y[i] = self.X[batch[i]]
|
|
if self.debug_output:
|
|
print_ln('dropout X %s', self.X.reveal_nested())
|
|
print_ln('dropout Y %s', self.Y.reveal_nested())
|
|
|
|
def backward(self, compute_nabla_X=True, batch=None):
|
|
if compute_nabla_X:
|
|
@for_range_opt_multithread(self.n_threads, len(batch))
|
|
def _(i):
|
|
self.nabla_X[batch[i]].assign_vector(
|
|
self.nabla_Y[i].get_vector() * self.B[i].get_vector())
|
|
if self.debug_output:
|
|
print_ln('dropout nabla_Y %s', self.nabla_Y.reveal_nested())
|
|
print_ln('dropout nabla_X %s', self.nabla_X.reveal_nested())
|
|
|
|
class ElementWiseLayer(NoVariableLayer):
|
|
def __init__(self, shape, inputs=None):
|
|
self.X = Tensor(shape, sfix)
|
|
self.Y = Tensor(shape, sfix)
|
|
backward_shape = list(shape)
|
|
backward_shape[0] = min(shape[0], self.back_batch_size)
|
|
self.nabla_X = Tensor(backward_shape, sfix)
|
|
self.nabla_Y = Tensor(backward_shape, sfix)
|
|
self.inputs = inputs
|
|
|
|
def _forward(self, batch=[0]):
|
|
n_per_item = reduce(operator.mul, self.X.sizes[1:])
|
|
@multithread(self.n_threads, len(batch), max(1, 1000 // n_per_item))
|
|
def _(base, size):
|
|
self.Y.assign_part_vector(self.f_part(base, size), base)
|
|
|
|
if self.debug_output:
|
|
name = self
|
|
@for_range(len(batch))
|
|
def _(i):
|
|
print_ln('%s X %s %s', name, i, self.X[i].reveal_nested())
|
|
print_ln('%s Y %s %s', name, i, self.Y[i].reveal_nested())
|
|
|
|
def backward(self, batch):
|
|
f_prime_bit = MultiArray(self.X.sizes, self.prime_type)
|
|
n_elements = len(batch) * reduce(operator.mul, f_prime_bit.sizes[1:])
|
|
|
|
@multithread(self.n_threads, n_elements)
|
|
def _(base, size):
|
|
f_prime_bit.assign_vector(self.f_prime_part(base, size), base)
|
|
|
|
progress('f prime')
|
|
|
|
@multithread(self.n_threads, n_elements)
|
|
def _(base, size):
|
|
self.nabla_X.assign_vector(self.nabla_Y.get_vector(base, size) *
|
|
f_prime_bit.get_vector(base, size),
|
|
base)
|
|
|
|
progress('f prime schur Y')
|
|
|
|
if self.debug_output:
|
|
name = self
|
|
@for_range(len(batch))
|
|
def _(i):
|
|
print_ln('%s X %s %s', name, i, self.X[i].reveal_nested())
|
|
print_ln('%s f_prime %s %s', name, i, f_prime_bit[i].reveal_nested())
|
|
print_ln('%s nabla Y %s %s', name, i, self.nabla_Y[i].reveal_nested())
|
|
print_ln('%s nabla X %s %s', name, i, self.nabla_X[i].reveal_nested())
|
|
|
|
class Relu(ElementWiseLayer):
|
|
""" Fixed-point ReLU layer.
|
|
|
|
:param shape: input/output shape (tuple/list of int)
|
|
"""
|
|
f = staticmethod(relu)
|
|
f_prime = staticmethod(relu_prime)
|
|
prime_type = sint
|
|
comparisons = None
|
|
|
|
def __init__(self, shape, inputs=None):
|
|
super(Relu, self).__init__(shape)
|
|
self.comparisons = MultiArray(shape, sint)
|
|
|
|
def f_part(self, base, size):
|
|
x = self.X.get_part_vector(base, size)
|
|
c = x > 0
|
|
self.comparisons.assign_part_vector(c, base)
|
|
return c.if_else(x, 0)
|
|
|
|
def f_prime_part(self, base, size):
|
|
return self.comparisons.get_vector(base, size)
|
|
|
|
class Square(ElementWiseLayer):
|
|
""" Fixed-point square layer.
|
|
|
|
:param shape: input/output shape (tuple/list of int)
|
|
"""
|
|
f = staticmethod(lambda x: x ** 2)
|
|
f_prime = staticmethod(lambda x: cfix(2, size=x.size) * x)
|
|
prime_type = sfix
|
|
|
|
class MaxPool(NoVariableLayer):
|
|
""" Fixed-point MaxPool layer.
|
|
|
|
:param shape: input shape (tuple/list of four int)
|
|
:param strides: strides (tuple/list of four int, first and last must be 1)
|
|
:param ksize: kernel size (tuple/list of four int, first and last must be 1)
|
|
:param padding: :py:obj:`'VALID'` (default) or :py:obj:`'SAME'`
|
|
"""
|
|
def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1),
|
|
padding='VALID'):
|
|
assert len(shape) == 4
|
|
for x in strides, ksize:
|
|
for i in 0, 3:
|
|
assert x[i] == 1
|
|
self.X = Tensor(shape, sfix)
|
|
if padding == 'SAME':
|
|
output_shape = [int(math.ceil(shape[i] / strides[i])) for i in range(4)]
|
|
else:
|
|
output_shape = [(shape[i] - ksize[i]) // strides[i] + 1 for i in range(4)]
|
|
self.Y = Tensor(output_shape, sfix)
|
|
self.strides = strides
|
|
self.ksize = ksize
|
|
self.nabla_X = Tensor(shape, sfix)
|
|
self.nabla_Y = Tensor(output_shape, sfix)
|
|
self.N = shape[0]
|
|
self.comparisons = MultiArray([self.N, self.X.sizes[3],
|
|
ksize[1] * ksize[2]], sint)
|
|
|
|
def _forward(self, batch):
|
|
def process(pool, bi, k, i, j):
|
|
def m(a, b):
|
|
c = a[0] > b[0]
|
|
l = [c * x for x in a[1]]
|
|
l += [(1 - c) * x for x in b[1]]
|
|
return c.if_else(a[0], b[0]), l
|
|
red = util.tree_reduce(m, [(x[0], [1]) for x in pool])
|
|
self.Y[bi][i][j][k] = red[0]
|
|
for i, x in enumerate(red[1]):
|
|
self.comparisons[bi][k][i] = x
|
|
self.traverse(batch, process)
|
|
|
|
def backward(self, compute_nabla_X=True, batch=None):
|
|
if compute_nabla_X:
|
|
self.nabla_X.alloc()
|
|
def process(pool, bi, k, i, j):
|
|
for (x, h_in, w_in, h, w), c in zip(pool,
|
|
self.comparisons[bi][k]):
|
|
hh = h * h_in
|
|
ww = w * w_in
|
|
self.nabla_X[bi][hh][ww][k] = \
|
|
util.if_else(h_in * w_in, c * self.nabla_Y[bi][i][j][k],
|
|
self.nabla_X[bi][hh][ww][k])
|
|
self.traverse(batch, process)
|
|
|
|
def traverse(self, batch, process):
|
|
need_padding = [self.strides[i] * (self.Y.sizes[i] - 1) + self.ksize[i] >
|
|
self.X.sizes[i] for i in range(4)]
|
|
@for_range_opt_multithread(self.n_threads,
|
|
[len(batch), self.X.sizes[3]])
|
|
def _(l, k):
|
|
bi = batch[l]
|
|
@for_range_opt(self.Y.sizes[1])
|
|
def _(i):
|
|
h_base = self.strides[1] * i
|
|
@for_range_opt(self.Y.sizes[2])
|
|
def _(j):
|
|
w_base = self.strides[2] * j
|
|
pool = []
|
|
for ii in range(self.ksize[1]):
|
|
h = h_base + ii
|
|
if need_padding[1]:
|
|
h_in = h < self.X.sizes[1]
|
|
else:
|
|
h_in = True
|
|
for jj in range(self.ksize[2]):
|
|
w = w_base + jj
|
|
if need_padding[2]:
|
|
w_in = w < self.X.sizes[2]
|
|
else:
|
|
w_in = True
|
|
if not is_zero(h_in * w_in):
|
|
pool.append([h_in * w_in * self.X[bi][h_in * h]
|
|
[w_in * w][k], h_in, w_in, h, w])
|
|
process(pool, bi, k, i, j)
|
|
|
|
|
|
class Argmax(NoVariableLayer):
|
|
""" Fixed-point Argmax layer.
|
|
|
|
:param shape: input shape (tuple/list of two int)
|
|
"""
|
|
def __init__(self, shape):
|
|
assert len(shape) == 2
|
|
self.X = MultiArray(shape, sfix)
|
|
self.Y = Array(shape[0], sint)
|
|
|
|
def _forward(self, batch=[0]):
|
|
assert len(batch) == 1
|
|
self.Y[batch[0]] = argmax(self.X[batch[0]])
|
|
|
|
class Concat(NoVariableLayer):
|
|
""" Fixed-point concatentation layer.
|
|
|
|
:param inputs: two input layers (tuple/list)
|
|
:param dimension: dimension for concatenation (must be 3)
|
|
"""
|
|
def __init__(self, inputs, dimension):
|
|
self.inputs = inputs
|
|
self.dimension = dimension
|
|
shapes = [inp.shape for inp in inputs]
|
|
assert dimension == 3
|
|
assert len(shapes) == 2
|
|
assert len(shapes[0]) == len(shapes[1])
|
|
shape = []
|
|
for i in range(len(shapes[0])):
|
|
if i == dimension:
|
|
shape.append(shapes[0][i] + shapes[1][i])
|
|
else:
|
|
assert shapes[0][i] == shapes[1][i]
|
|
shape.append(shapes[0][i])
|
|
self.Y = Tensor(shape, sfix)
|
|
|
|
def _forward(self, batch=[0]):
|
|
assert len(batch) == 1
|
|
@for_range_multithread(self.n_threads, 1, self.Y.sizes[1:3])
|
|
def _(i, j):
|
|
X = [x.Y[batch[0]] for x in self.inputs]
|
|
self.Y[batch[0]][i][j].assign_vector(X[0][i][j].get_vector())
|
|
self.Y[batch[0]][i][j].assign_part_vector(
|
|
X[1][i][j].get_vector(),
|
|
len(X[0][i][j]))
|
|
|
|
class Add(NoVariableLayer):
|
|
""" Fixed-point addition layer.
|
|
|
|
:param inputs: input layers with the same shape (tuple/list of at least two)
|
|
"""
|
|
def __init__(self, inputs):
|
|
assert len(inputs) > 1
|
|
shape = inputs[0].shape
|
|
for inp in inputs:
|
|
assert inp.shape == shape
|
|
self.Y = Tensor(shape, sfix)
|
|
self.inputs = inputs
|
|
|
|
def _forward(self, batch=[0]):
|
|
assert len(batch) == 1
|
|
@multithread(self.n_threads, self.Y[0].total_size())
|
|
def _(base, size):
|
|
tmp = sum(inp.Y[batch[0]].get_vector(base, size)
|
|
for inp in self.inputs)
|
|
self.Y[batch[0]].assign_vector(tmp, base)
|
|
|
|
class FusedBatchNorm(Layer):
|
|
""" Fixed-point fused batch normalization layer.
|
|
|
|
:param shape: input/output shape (tuple/list of four int)
|
|
"""
|
|
def __init__(self, shape, inputs=None):
|
|
assert len(shape) == 4
|
|
self.X = Tensor(shape, sfix)
|
|
self.Y = Tensor(shape, sfix)
|
|
self.weights = sfix.Array(shape[3])
|
|
self.bias = sfix.Array(shape[3])
|
|
self.inputs = inputs
|
|
|
|
def input_from(self, player, raw=False):
|
|
self.weights.input_from(player, raw=raw)
|
|
self.bias.input_from(player, raw=raw)
|
|
tmp = sfix.Array(len(self.bias))
|
|
tmp.input_from(player, raw=raw)
|
|
tmp.input_from(player, raw=raw)
|
|
|
|
def _forward(self, batch=[0]):
|
|
assert len(batch) == 1
|
|
@for_range_opt_multithread(self.n_threads, self.X.sizes[1:3])
|
|
def _(i, j):
|
|
self.Y[batch[0]][i][j].assign_vector(
|
|
self.X[batch[0]][i][j].get_vector() * self.weights.get_vector()
|
|
+ self.bias.get_vector())
|
|
|
|
class QuantBase(object):
|
|
bias_before_reduction = True
|
|
|
|
@staticmethod
|
|
def new_squant():
|
|
class _(squant):
|
|
@classmethod
|
|
def get_params_from(cls, player):
|
|
cls.set_params(sfloat.get_input_from(player),
|
|
sint.get_input_from(player))
|
|
@classmethod
|
|
def get_input_from(cls, player, size=None):
|
|
return cls._new(sint.get_input_from(player, size=size))
|
|
return _
|
|
|
|
def const_div(self, acc, n):
|
|
logn = int(math.log(n, 2))
|
|
acc = (acc + n // 2)
|
|
if 2 ** logn == n:
|
|
acc = acc.round(self.output_squant.params.k + logn, logn, nearest=True)
|
|
else:
|
|
acc = acc.int_div(sint(n), self.output_squant.params.k + logn)
|
|
return acc
|
|
|
|
class FixBase:
|
|
bias_before_reduction = False
|
|
|
|
@staticmethod
|
|
def new_squant():
|
|
class _(sfix):
|
|
params = None
|
|
return _
|
|
|
|
def input_params_from(self, player):
|
|
pass
|
|
|
|
def const_div(self, acc, n):
|
|
return (sfix._new(acc) * self.output_squant(1 / n)).v
|
|
|
|
class BaseLayer(Layer):
|
|
def __init__(self, input_shape, output_shape, inputs=None):
|
|
self.input_shape = input_shape
|
|
self.output_shape = output_shape
|
|
|
|
self.input_squant = self.new_squant()
|
|
self.output_squant = self.new_squant()
|
|
|
|
self.X = Tensor(input_shape, self.input_squant)
|
|
self.Y = Tensor(output_shape, self.output_squant)
|
|
|
|
back_shapes = list(input_shape), list(output_shape)
|
|
for x in back_shapes:
|
|
x[0] = min(x[0], self.back_batch_size)
|
|
|
|
self.nabla_X = MultiArray(back_shapes[0], self.input_squant)
|
|
self.nabla_Y = MultiArray(back_shapes[1], self.output_squant)
|
|
self.inputs = inputs
|
|
|
|
def temp_shape(self):
|
|
return [0]
|
|
|
|
@property
|
|
def N(self):
|
|
return self.input_shape[0]
|
|
|
|
class ConvBase(BaseLayer):
|
|
fewer_rounds = True
|
|
use_conv2ds = True
|
|
temp_weights = None
|
|
temp_inputs = None
|
|
thetas = lambda self: (self.weights, self.bias)
|
|
nablas = lambda self: (self.nabla_weights, self.nabla_bias)
|
|
|
|
@classmethod
|
|
def init_temp(cls, layers):
|
|
size = 0
|
|
for layer in layers:
|
|
size = max(size, reduce(operator.mul, layer.temp_shape()))
|
|
cls.temp_weights = sfix.Array(size)
|
|
cls.temp_inputs = sfix.Array(size)
|
|
|
|
def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride,
|
|
padding='SAME', tf_weight_format=False, inputs=None):
|
|
super(ConvBase, self).__init__(input_shape, output_shape, inputs=inputs)
|
|
|
|
self.weight_shape = weight_shape
|
|
self.bias_shape = bias_shape
|
|
self.stride = stride
|
|
self.tf_weight_format = tf_weight_format
|
|
if padding == 'SAME':
|
|
# https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn
|
|
self.padding = []
|
|
for i in 1, 2:
|
|
s = stride[i - 1]
|
|
assert output_shape[i] >= input_shape[i] // s
|
|
if tf_weight_format:
|
|
w = weight_shape[i - 1]
|
|
else:
|
|
w = weight_shape[i]
|
|
if (input_shape[i] % stride[1] == 0):
|
|
pad_total = max(w - s, 0)
|
|
else:
|
|
pad_total = max(w - (input_shape[i] % s), 0)
|
|
self.padding.append(pad_total // 2)
|
|
elif padding == 'VALID':
|
|
self.padding = [0, 0]
|
|
else:
|
|
self.padding = padding
|
|
|
|
self.weight_squant = self.new_squant()
|
|
self.bias_squant = self.new_squant()
|
|
|
|
self.weights = Tensor(weight_shape, self.weight_squant)
|
|
self.bias = Array(output_shape[-1], self.bias_squant)
|
|
|
|
self.nabla_weights = Tensor(weight_shape, self.weight_squant)
|
|
self.nabla_bias = Array(output_shape[-1], self.bias_squant)
|
|
|
|
self.unreduced = Tensor(self.output_shape, sint, address=self.Y.address)
|
|
|
|
if tf_weight_format:
|
|
weight_in = weight_shape[2]
|
|
else:
|
|
weight_in = weight_shape[3]
|
|
assert(weight_in == input_shape[-1])
|
|
assert(bias_shape[0] == output_shape[-1])
|
|
assert(len(bias_shape) == 1)
|
|
assert(len(input_shape) == 4)
|
|
assert(len(output_shape) == 4)
|
|
assert(len(weight_shape) == 4)
|
|
|
|
def input_from(self, player, raw=False):
|
|
self.input_params_from(player)
|
|
self.weights.input_from(player, budget=100000, raw=raw)
|
|
if self.input_bias:
|
|
self.bias.input_from(player, raw=raw)
|
|
|
|
def output_weights(self):
|
|
print_ln('%s', self.weights.reveal_nested())
|
|
print_ln('%s', self.bias.reveal_nested())
|
|
|
|
def dot_product(self, iv, wv, out_y, out_x, out_c):
|
|
bias = self.bias[out_c]
|
|
acc = self.output_squant.unreduced_dot_product(iv, wv)
|
|
acc.v += bias.v
|
|
acc.res_params = self.output_squant.params
|
|
#self.Y[0][out_y][out_x][out_c] = acc.reduce_after_mul()
|
|
self.unreduced[0][out_y][out_x][out_c] = acc.v
|
|
|
|
def reduction(self, batch_length=1):
|
|
unreduced = self.unreduced
|
|
n_summands = self.n_summands()
|
|
#start_timer(2)
|
|
n_outputs = batch_length * reduce(operator.mul, self.output_shape[1:])
|
|
@multithread(self.n_threads, n_outputs,
|
|
1000 if sfix.round_nearest else 10 ** 6)
|
|
def _(base, n_per_thread):
|
|
res = self.input_squant().unreduced(
|
|
sint.load_mem(unreduced.address + base,
|
|
size=n_per_thread),
|
|
self.weight_squant(),
|
|
self.output_squant.params,
|
|
n_summands).reduce_after_mul()
|
|
res.store_in_mem(self.Y.address + base)
|
|
#stop_timer(2)
|
|
|
|
def temp_shape(self):
|
|
return list(self.output_shape[1:]) + [self.n_summands()]
|
|
|
|
def prepare_temp(self):
|
|
shape = self.temp_shape()
|
|
inputs = MultiArray(shape, self.input_squant,
|
|
address=self.temp_inputs)
|
|
weights = MultiArray(shape, self.weight_squant,
|
|
address=self.temp_weights)
|
|
return inputs, weights
|
|
|
|
class Conv2d(ConvBase):
|
|
def n_summands(self):
|
|
_, weights_h, weights_w, _ = self.weight_shape
|
|
_, inputs_h, inputs_w, n_channels_in = self.input_shape
|
|
return weights_h * weights_w * n_channels_in
|
|
|
|
def _forward(self, batch):
|
|
if self.tf_weight_format:
|
|
assert(self.weight_shape[3] == self.output_shape[-1])
|
|
weights_h, weights_w, _, _ = self.weight_shape
|
|
else:
|
|
assert(self.weight_shape[0] == self.output_shape[-1])
|
|
_, weights_h, weights_w, _ = self.weight_shape
|
|
_, inputs_h, inputs_w, n_channels_in = self.input_shape
|
|
_, output_h, output_w, n_channels_out = self.output_shape
|
|
|
|
stride_h, stride_w = self.stride
|
|
padding_h, padding_w = self.padding
|
|
|
|
if self.use_conv2ds:
|
|
n_parts = max(1, round(self.n_threads / n_channels_out))
|
|
while len(batch) % n_parts != 0:
|
|
n_parts -= 1
|
|
print('Convolution in %d parts' % n_parts)
|
|
part_size = len(batch) // n_parts
|
|
@for_range_multithread(self.n_threads, 1, [n_parts, n_channels_out])
|
|
def _(i, j):
|
|
inputs = self.X.get_slice_vector(
|
|
batch.get_part(i * part_size, part_size))
|
|
if self.tf_weight_format:
|
|
weights = self.weights.get_vector_by_indices(None, None, None, j)
|
|
else:
|
|
weights = self.weights.get_part_vector(j)
|
|
inputs = inputs.pre_mul()
|
|
weights = weights.pre_mul()
|
|
res = sint(size = output_h * output_w * part_size)
|
|
conv2ds(res, inputs, weights, output_h, output_w,
|
|
inputs_h, inputs_w, weights_h, weights_w,
|
|
stride_h, stride_w, n_channels_in, padding_h, padding_w,
|
|
part_size)
|
|
if self.bias_before_reduction:
|
|
res += self.bias.expand_to_vector(j, res.size).v
|
|
else:
|
|
res += self.bias.expand_to_vector(j, res.size).v << \
|
|
self.input_squant.f
|
|
addresses = regint.inc(res.size,
|
|
self.unreduced[i * part_size].address + j,
|
|
n_channels_out)
|
|
res.store_in_mem(addresses)
|
|
self.reduction(len(batch))
|
|
if self.debug_output:
|
|
print_ln('%s weights %s', self, self.weights.reveal_nested())
|
|
print_ln('%s bias %s', self, self.bias.reveal_nested())
|
|
@for_range(len(batch))
|
|
def _(i):
|
|
print_ln('%s X %s %s', self, i, self.X[batch[i]].reveal_nested())
|
|
print_ln('%s Y %s %s', self, i, self.Y[i].reveal_nested())
|
|
return
|
|
else:
|
|
assert len(batch) == 1
|
|
if self.fewer_rounds:
|
|
inputs, weights = self.prepare_temp()
|
|
|
|
@for_range_opt_multithread(self.n_threads,
|
|
[output_h, output_w, n_channels_out])
|
|
def _(out_y, out_x, out_c):
|
|
in_x_origin = (out_x * stride_w) - padding_w
|
|
in_y_origin = (out_y * stride_h) - padding_h
|
|
iv = []
|
|
wv = []
|
|
for filter_y in range(weights_h):
|
|
in_y = in_y_origin + filter_y
|
|
inside_y = (0 <= in_y) * (in_y < inputs_h)
|
|
for filter_x in range(weights_w):
|
|
in_x = in_x_origin + filter_x
|
|
inside_x = (0 <= in_x) * (in_x < inputs_w)
|
|
inside = inside_y * inside_x
|
|
if is_zero(inside):
|
|
continue
|
|
for in_c in range(n_channels_in):
|
|
iv += [self.X[0][in_y * inside_y]
|
|
[in_x * inside_x][in_c]]
|
|
wv += [self.weights[out_c][filter_y][filter_x][in_c]]
|
|
wv[-1] *= inside
|
|
if self.fewer_rounds:
|
|
inputs[out_y][out_x][out_c].assign(iv)
|
|
weights[out_y][out_x][out_c].assign(wv)
|
|
else:
|
|
self.dot_product(iv, wv, out_y, out_x, out_c)
|
|
|
|
if self.fewer_rounds:
|
|
@for_range_opt_multithread(self.n_threads,
|
|
list(self.output_shape[1:]))
|
|
def _(out_y, out_x, out_c):
|
|
self.dot_product(inputs[out_y][out_x][out_c],
|
|
weights[out_y][out_x][out_c],
|
|
out_y, out_x, out_c)
|
|
|
|
self.reduction()
|
|
|
|
class QuantConvBase(QuantBase):
|
|
def input_params_from(self, player):
|
|
for s in self.input_squant, self.weight_squant, self.bias_squant, self.output_squant:
|
|
s.get_params_from(player)
|
|
print('WARNING: assuming that bias quantization parameters are correct')
|
|
self.output_squant.params.precompute(self.input_squant.params, self.weight_squant.params)
|
|
|
|
class QuantConv2d(QuantConvBase, Conv2d):
|
|
pass
|
|
|
|
class FixConv2d(Conv2d, FixBase):
|
|
""" Fixed-point 2D convolution layer.
|
|
|
|
:param input_shape: input shape (tuple/list of four int)
|
|
:param weight_shape: weight shape (tuple/list of four int)
|
|
:param bias_shape: bias shape (tuple/list of one int)
|
|
:param output_shape: output shape (tuple/list of four int)
|
|
:param stride: stride (tuple/list of two int)
|
|
:param padding: :py:obj:`'SAME'` (default), :py:obj:`'VALID'`, or tuple/list of two int
|
|
:param tf_weight_format: weight shape format is (height, width, input channels, output channels) instead of the default (output channels, height, width, input channels)
|
|
"""
|
|
|
|
def reset(self):
|
|
assert not self.tf_weight_format
|
|
kernel_size = self.weight_shape[1] * self.weight_shape[2]
|
|
r = math.sqrt(6.0 / (kernel_size * sum(self.weight_shape[::3])))
|
|
print('Initializing convolution weights in [%f,%f]' % (-r, r))
|
|
self.weights.assign_vector(
|
|
sfix.get_random(-r, r, size=self.weights.total_size()))
|
|
self.bias.assign_all(0)
|
|
|
|
def backward(self, compute_nabla_X=True, batch=None):
|
|
assert self.use_conv2ds
|
|
|
|
assert not self.tf_weight_format
|
|
_, weights_h, weights_w, _ = self.weight_shape
|
|
_, inputs_h, inputs_w, n_channels_in = self.input_shape
|
|
_, output_h, output_w, n_channels_out = self.output_shape
|
|
|
|
stride_h, stride_w = self.stride
|
|
padding_h, padding_w = self.padding
|
|
|
|
N = len(batch)
|
|
|
|
self.nabla_bias.assign_all(0)
|
|
|
|
@for_range(N)
|
|
def _(i):
|
|
self.nabla_bias.assign_vector(
|
|
self.nabla_bias.get_vector() + sum(sum(
|
|
self.nabla_Y[i][j][k].get_vector() for k in range(output_w))
|
|
for j in range(output_h)))
|
|
|
|
input_size = inputs_h * inputs_w * N
|
|
batch_repeat = regint.Matrix(N, inputs_h * inputs_w)
|
|
batch_repeat.assign_vector(batch.get(
|
|
regint.inc(input_size, 0, 1, 1, N)) *
|
|
reduce(operator.mul, self.input_shape[1:]))
|
|
|
|
@for_range_opt_multithread(self.n_threads, [n_channels_in, n_channels_out])
|
|
def _(i, j):
|
|
a = regint.inc(input_size, self.X.address + i, n_channels_in, N,
|
|
inputs_h * inputs_w)
|
|
inputs = sfix.load_mem(batch_repeat.get_vector() + a).pre_mul()
|
|
b = regint.inc(N * output_w * output_h, self.nabla_Y.address + j, n_channels_out, N)
|
|
rep_out = regint.inc(output_h * output_w * N, 0, 1, 1, N) * \
|
|
reduce(operator.mul, self.output_shape[1:])
|
|
nabla_outputs = sfix.load_mem(rep_out + b).pre_mul()
|
|
res = sint(size = weights_h * weights_w)
|
|
conv2ds(res, inputs, nabla_outputs, weights_h, weights_w, inputs_h,
|
|
inputs_w, output_h, output_w, -stride_h, -stride_w, N,
|
|
padding_h, padding_w, 1)
|
|
reduced = unreduced_sfix._new(res).reduce_after_mul()
|
|
self.nabla_weights.assign_vector_by_indices(reduced, j, None, None, i)
|
|
|
|
if compute_nabla_X:
|
|
assert tuple(self.padding) == (0, 0)
|
|
assert tuple(self.stride) == (1, 1)
|
|
reverse_weights = MultiArray(
|
|
[n_channels_in, weights_h, weights_w, n_channels_out], sfix)
|
|
@for_range(n_channels_out)
|
|
def _(i):
|
|
@for_range(weights_h)
|
|
def _(j):
|
|
@for_range(weights_w)
|
|
def _(k):
|
|
@for_range(n_channels_in)
|
|
def _(l):
|
|
reverse_weights[l][weights_h-j-1][k][i] = \
|
|
self.weights[i][j][weights_w-k-1][l]
|
|
padded_w = inputs_w + 2 * padding_w
|
|
padded_h = inputs_h + 2 * padding_h
|
|
if padding_h or padding_w:
|
|
output = MultiArray(
|
|
[N, padded_h, padded_w, n_channels_in], sfix)
|
|
else:
|
|
output = self.nabla_X
|
|
@for_range_opt_multithread(self.n_threads,
|
|
[N, n_channels_in])
|
|
def _(i, j):
|
|
res = sint(size = (padded_w * padded_h))
|
|
conv2ds(res, self.nabla_Y[i].get_vector().pre_mul(),
|
|
reverse_weights[j].get_vector().pre_mul(),
|
|
padded_h, padded_w, output_h, output_w,
|
|
weights_h, weights_w, 1, 1, n_channels_out,
|
|
weights_h - 1, weights_w - 1, 1)
|
|
output.assign_vector_by_indices(
|
|
unreduced_sfix._new(res).reduce_after_mul(),
|
|
i, None, None, j)
|
|
if padding_h or padding_w:
|
|
@for_range(N)
|
|
def _(i):
|
|
@for_range(inputs_h)
|
|
def _(j):
|
|
@for_range(inputs_w)
|
|
def _(k):
|
|
self.nabla_X[i][j][k].assign_vector(
|
|
output[i][j][k].get_vector())
|
|
|
|
if self.debug_output:
|
|
@for_range(len(batch))
|
|
def _(i):
|
|
print_ln('%s X %s %s', self, i, list(self.X[i].reveal_nested()))
|
|
print_ln('%s nabla Y %s %s', self, i, list(self.nabla_Y[i].reveal_nested()))
|
|
if compute_nabla_X:
|
|
print_ln('%s nabla X %s %s', self, i, self.nabla_X[batch[i]].reveal_nested())
|
|
print_ln('%s nabla weights %s', self,
|
|
(self.nabla_weights.reveal_nested()))
|
|
print_ln('%s weights %s', self, (self.weights.reveal_nested()))
|
|
print_ln('%s nabla b %s', self, (self.nabla_bias.reveal_nested()))
|
|
print_ln('%s bias %s', self, (self.bias.reveal_nested()))
|
|
|
|
class QuantDepthwiseConv2d(QuantConvBase, Conv2d):
|
|
def n_summands(self):
|
|
_, weights_h, weights_w, _ = self.weight_shape
|
|
return weights_h * weights_w
|
|
|
|
def _forward(self, batch):
|
|
assert len(batch) == 1
|
|
assert(self.weight_shape[-1] == self.output_shape[-1])
|
|
assert(self.input_shape[-1] == self.output_shape[-1])
|
|
|
|
_, weights_h, weights_w, _ = self.weight_shape
|
|
_, inputs_h, inputs_w, n_channels_in = self.input_shape
|
|
_, output_h, output_w, n_channels_out = self.output_shape
|
|
|
|
stride_h, stride_w = self.stride
|
|
padding_h, padding_w = self.padding
|
|
|
|
depth_multiplier = 1
|
|
|
|
if self.use_conv2ds:
|
|
assert depth_multiplier == 1
|
|
assert self.weight_shape[0] == 1
|
|
@for_range_opt_multithread(self.n_threads, n_channels_in)
|
|
def _(j):
|
|
inputs = self.X.get_vector_by_indices(0, None, None, j)
|
|
assert not self.tf_weight_format
|
|
weights = self.weights.get_vector_by_indices(0, None, None,
|
|
j)
|
|
inputs = inputs.pre_mul()
|
|
weights = weights.pre_mul()
|
|
res = sint(size = output_h * output_w)
|
|
conv2ds(res, inputs, weights, output_h, output_w,
|
|
inputs_h, inputs_w, weights_h, weights_w,
|
|
stride_h, stride_w, 1, padding_h, padding_w, 1)
|
|
res += self.bias.expand_to_vector(j, res.size).v
|
|
self.unreduced.assign_vector_by_indices(res, 0, None, None, j)
|
|
self.reduction()
|
|
return
|
|
else:
|
|
if self.fewer_rounds:
|
|
inputs, weights = self.prepare_temp()
|
|
|
|
@for_range_opt_multithread(self.n_threads,
|
|
[output_h, output_w, n_channels_in])
|
|
def _(out_y, out_x, in_c):
|
|
for m in range(depth_multiplier):
|
|
oc = m + in_c * depth_multiplier
|
|
in_x_origin = (out_x * stride_w) - padding_w
|
|
in_y_origin = (out_y * stride_h) - padding_h
|
|
iv = []
|
|
wv = []
|
|
for filter_y in range(weights_h):
|
|
for filter_x in range(weights_w):
|
|
in_x = in_x_origin + filter_x
|
|
in_y = in_y_origin + filter_y
|
|
inside = (0 <= in_x) * (in_x < inputs_w) * \
|
|
(0 <= in_y) * (in_y < inputs_h)
|
|
if is_zero(inside):
|
|
continue
|
|
iv += [self.X[0][in_y][in_x][in_c]]
|
|
wv += [self.weights[0][filter_y][filter_x][oc]]
|
|
wv[-1] *= inside
|
|
if self.fewer_rounds:
|
|
inputs[out_y][out_x][oc].assign(iv)
|
|
weights[out_y][out_x][oc].assign(wv)
|
|
else:
|
|
self.dot_product(iv, wv, out_y, out_x, oc)
|
|
|
|
if self.fewer_rounds:
|
|
@for_range_opt_multithread(self.n_threads,
|
|
list(self.output_shape[1:]))
|
|
def _(out_y, out_x, out_c):
|
|
self.dot_product(inputs[out_y][out_x][out_c],
|
|
weights[out_y][out_x][out_c],
|
|
out_y, out_x, out_c)
|
|
|
|
self.reduction()
|
|
|
|
class AveragePool2d(BaseLayer):
|
|
def __init__(self, input_shape, output_shape, filter_size, strides=(1, 1)):
|
|
super(AveragePool2d, self).__init__(input_shape, output_shape)
|
|
self.filter_size = filter_size
|
|
self.strides = strides
|
|
for i in (0, 1):
|
|
if strides[i] == 1:
|
|
assert output_shape[1+i] == 1
|
|
assert filter_size[i] == input_shape[1+i]
|
|
else:
|
|
assert strides[i] == filter_size[i]
|
|
assert output_shape[1+i] * strides[i] == input_shape[1+i]
|
|
|
|
def input_from(self, player, raw=False):
|
|
self.input_params_from(player)
|
|
|
|
def _forward(self, batch=[0]):
|
|
assert len(batch) == 1
|
|
|
|
_, input_h, input_w, n_channels_in = self.input_shape
|
|
_, output_h, output_w, n_channels_out = self.output_shape
|
|
|
|
assert n_channels_in == n_channels_out
|
|
|
|
padding_h, padding_w = (0, 0)
|
|
stride_h, stride_w = self.strides
|
|
filter_h, filter_w = self.filter_size
|
|
n = filter_h * filter_w
|
|
print('divisor: ', n)
|
|
|
|
@for_range_opt_multithread(self.n_threads,
|
|
[output_h, output_w, n_channels_in])
|
|
def _(out_y, out_x, c):
|
|
in_x_origin = (out_x * stride_w) - padding_w
|
|
in_y_origin = (out_y * stride_h) - padding_h
|
|
fxs = util.max(-in_x_origin, 0)
|
|
#fxe = min(filter_w, input_w - in_x_origin)
|
|
fys = util.max(-in_y_origin, 0)
|
|
#fye = min(filter_h, input_h - in_y_origin)
|
|
acc = 0
|
|
#fc = 0
|
|
for i in range(filter_h):
|
|
filter_y = fys + i
|
|
for j in range(filter_w):
|
|
filter_x = fxs + j
|
|
in_x = in_x_origin + filter_x
|
|
in_y = in_y_origin + filter_y
|
|
acc += self.X[0][in_y][in_x][c].v
|
|
#fc += 1
|
|
acc = self.const_div(acc, n)
|
|
self.Y[0][out_y][out_x][c] = self.output_squant._new(acc)
|
|
|
|
class QuantAveragePool2d(QuantBase, AveragePool2d):
|
|
def input_params_from(self, player):
|
|
print('WARNING: assuming that input and output quantization parameters are the same')
|
|
for s in self.input_squant, self.output_squant:
|
|
s.get_params_from(player)
|
|
|
|
class FixAveragePool2d(FixBase, AveragePool2d):
|
|
""" Fixed-point 2D AvgPool layer.
|
|
|
|
:param input_shape: input shape (tuple/list of four int)
|
|
:param output_shape: output shape (tuple/list of four int)
|
|
:param filter_size: filter size (tuple/list of two int)
|
|
:param strides: strides (tuple/list of two int)
|
|
"""
|
|
|
|
class QuantReshape(QuantBase, BaseLayer):
|
|
def __init__(self, input_shape, _, output_shape):
|
|
super(QuantReshape, self).__init__(input_shape, output_shape)
|
|
|
|
def input_from(self, player):
|
|
print('WARNING: assuming that input and output quantization parameters are the same')
|
|
_ = self.new_squant()
|
|
for s in self.input_squant, _, self.output_squant:
|
|
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
|
|
for i in range(2):
|
|
sint.get_input_from(player)
|
|
|
|
def _forward(self, batch):
|
|
assert len(batch) == 1
|
|
# reshaping is implicit
|
|
self.Y.assign(self.X)
|
|
|
|
class QuantSoftmax(QuantBase, BaseLayer):
|
|
def input_from(self, player):
|
|
print('WARNING: assuming that input and output quantization parameters are the same')
|
|
for s in self.input_squant, self.output_squant:
|
|
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
|
|
|
|
def _forward(self, batch):
|
|
assert len(batch) == 1
|
|
assert(len(self.input_shape) == 2)
|
|
|
|
# just print the best
|
|
def comp(left, right):
|
|
c = left[1].v.greater_than(right[1].v, self.input_squant.params.k)
|
|
#print_ln('comp %s %s %s', c.reveal(), left[1].v.reveal(), right[1].v.reveal())
|
|
return [c.if_else(x, y) for x, y in zip(left, right)]
|
|
print_ln('guess: %s', util.tree_reduce(comp, list(enumerate(self.X[0])))[0].reveal())
|
|
|
|
class Optimizer:
|
|
""" Base class for graphs of layers. """
|
|
n_threads = Layer.n_threads
|
|
always_shuffle = True
|
|
time_layers = False
|
|
revealing_correctness = False
|
|
|
|
@staticmethod
|
|
def from_args(program, layers):
|
|
if 'adam' in program.args or 'adamapprox' in program.args:
|
|
return Adam(layers, 1, approx='adamapprox' in program.args)
|
|
elif 'amsgrad' in program.args:
|
|
return Adam(layers, approx=True, amsgrad=True)
|
|
elif 'quotient' in program.args:
|
|
return Adam(layers, approx=True, amsgrad=True, normalize=True)
|
|
else:
|
|
return SGD(layers, 1)

    def __init__(self, report_loss=None):
        self.tol = 0.000
        if report_loss is None:
            self.report_loss = self.layers[-1].compute_loss
        else:
            self.report_loss = report_loss
        self.X_by_label = None
        self.print_update_average = False
        self.print_losses = False
        self.print_loss_reduction = False
        self.i_epoch = MemValue(0)
        self.stopped_on_loss = MemValue(0)

    @property
    def layers(self):
        """ Get all layers. """
        return self._layers

    @layers.setter
    def layers(self, layers):
        """ Construct linear graph from list of layers. """
        self._layers = layers
        prev = None
        for layer in layers:
            if not layer.inputs and prev is not None:
                layer.inputs = [prev]
            prev = layer

    def set_layers_with_inputs(self, layers):
        """ Construct graph from :py:obj:`inputs` members of list of layers. """
        self._layers = layers
        used = set([None])
        for layer in reversed(layers):
            layer.last_used = list(filter(lambda x: x not in used, layer.inputs))
            used.update(layer.inputs)
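
        # Illustrative sketch (layer names are made up): for a branching
        # graph, set the inputs members before calling this method, e.g.
        #
        #   branch_a.inputs = [stem]
        #   branch_b.inputs = [stem]
        #   merge.inputs = [branch_a, branch_b]
        #   graph.set_layers_with_inputs([stem, branch_a, branch_b, merge])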

    def reset(self):
        """ Initialize weights. """
        for layer in self.layers:
            layer.reset()
        self.i_epoch.write(0)
        self.stopped_on_loss.write(0)

    def batch_for(self, layer, batch):
        if layer in (self.layers[0], self.layers[-1]):
            return batch
        else:
            batch = regint.Array(len(batch))
            batch.assign(regint.inc(len(batch)))
            return batch

    @_no_mem_warnings
    def forward(self, N=None, batch=None, keep_intermediate=True,
                model_from=None, training=False):
        """ Compute graph.

        :param N: batch size (used if batch not given)
        :param batch: indices for computation (:py:class:`~Compiler.types.Array` or list)
        :param keep_intermediate: do not free memory of intermediate results after use
        """
        if batch is None:
            batch = regint.Array(N)
            batch.assign(regint.inc(N))
        for i, layer in enumerate(self.layers):
            if layer.inputs and len(layer.inputs) == 1 and layer.inputs[0] is not None:
                layer._X.address = layer.inputs[0].Y.address
            layer.Y.alloc()
            if model_from is not None:
                layer.input_from(model_from)
            break_point()
            if self.time_layers:
                start_timer(100 + i)
            layer.forward(batch=self.batch_for(layer, batch), training=training)
            if self.time_layers:
                stop_timer(100 + i)
            break_point()
            if not keep_intermediate:
                for l in layer.last_used:
                    l.Y.delete()
                for theta in layer.thetas():
                    theta.delete()
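
            # Illustrative usage note: for one-off inference, memory can be
            # reduced by calling forward(N, keep_intermediate=False), which
            # frees intermediate outputs and parameters once they are no
            # longer needed.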

    @_no_mem_warnings
    def eval(self, data):
        """ Compute evaluation after training. """
        N = len(data)
        self.layers[0].X.assign(data)
        self.forward(N)
        return self.layers[-1].eval(N)

    @_no_mem_warnings
    def backward(self, batch):
        """ Compute backward propagation. """
        for i, layer in reversed(list(enumerate(self.layers))):
            assert len(batch) <= layer.back_batch_size
            if self.time_layers:
                start_timer(200 + i)
            if not layer.inputs:
                layer.backward(compute_nabla_X=False,
                               batch=self.batch_for(layer, batch))
            else:
                layer.backward(batch=self.batch_for(layer, batch))
                if len(layer.inputs) == 1:
                    layer.inputs[0].nabla_Y.address = \
                        layer.nabla_X.address
            if self.time_layers:
                stop_timer(200 + i)

    @_no_mem_warnings
    def run(self, batch_size=None, stop_on_loss=0):
        """ Run training.

        :param batch_size: batch size (defaults to example size of first layer)
        """
        if self.n_epochs == 0:
            return
        if batch_size is not None:
            N = batch_size
        else:
            N = self.layers[0].N
        i = self.i_epoch
        n_iterations = MemValue(0)
        self.n_correct = MemValue(0)
        @for_range(self.n_epochs)
        def _(_):
            if self.X_by_label is None:
                self.X_by_label = [[None] * self.layers[0].N]
            assert len(self.X_by_label) in (1, 2)
            assert N % len(self.X_by_label) == 0
            n = N // len(self.X_by_label)
            n_per_epoch = int(math.ceil(1. * max(len(X) for X in
                                                 self.X_by_label) / n))
            print('%d runs per epoch' % n_per_epoch)
            indices_by_label = []
            for label, X in enumerate(self.X_by_label):
                indices = regint.Array(n * n_per_epoch)
                indices_by_label.append(indices)
                indices.assign(regint.inc(len(indices), 0, 1, 1, len(X)))
                if self.always_shuffle or n_per_epoch > 1:
                    indices.shuffle()
            loss_sum = MemValue(sfix(0))
            self.n_correct.write(0)
            @for_range(n_per_epoch)
            def _(j):
                n_iterations.iadd(1)
                batch = regint.Array(N)
                for label, X in enumerate(self.X_by_label):
                    indices = indices_by_label[label]
                    batch.assign(indices.get_vector(j * n, n) +
                                 regint(label * len(self.X_by_label[0]), size=n),
                                 label * n)
                self.forward(batch=batch, training=True)
                self.backward(batch=batch)
                self.update(i, batch=batch)
                loss_sum.iadd(self.layers[-1].l)
                if self.print_loss_reduction:
                    before = self.layers[-1].average_loss(N)
                    self.forward(batch=batch)
                    after = self.layers[-1].average_loss(N)
                    print_ln('loss reduction in batch %s: %s (%s - %s)', j,
                             before - after, before, after)
                elif self.print_losses:
                    print_str('\rloss in batch %s: %s/%s', j,
                              self.layers[-1].average_loss(N),
                              loss_sum.reveal() / (j + 1))
                if self.revealing_correctness:
                    part_truth = self.layers[-1].Y.same_shape()
                    part_truth.assign_vector(
                        self.layers[-1].Y.get_slice_vector(batch))
                    self.n_correct.iadd(
                        self.layers[-1].reveal_correctness(batch_size, part_truth))
                if stop_on_loss:
                    loss = self.layers[-1].average_loss(N)
                    res = (loss < stop_on_loss) * (loss >= -1)
                    self.stopped_on_loss.write(1 - res)
                    return res
            if self.print_losses:
                print_ln()
            if self.report_loss and self.layers[-1].compute_loss and self.layers[-1].approx != 5:
                print_ln('loss in epoch %s: %s', i,
                         (loss_sum.reveal() * cfix(1 / n_per_epoch)))
            else:
                print_ln('done with epoch %s', i)
            time()
            i.iadd(1)
            res = True
            if self.tol > 0:
                res *= (1 - (loss >= 0) * (loss < self.tol)).reveal()
            return res

    def reveal_correctness(self, data, truth, batch_size):
        training_data = self.layers[0].X.address
        training_truth = self.layers[-1].Y.address
        self.layers[0].X.address = data.address
        self.layers[-1].Y.address = truth.address
        N = data.sizes[0]
        batch = regint.Array(batch_size)
        n_correct = MemValue(0)
        loss = MemValue(sfix(0))
        def f(start, batch_size):
            batch.assign_vector(regint.inc(batch_size, start))
            self.forward(batch=batch)
            part_truth = truth.get_part(start, batch_size)
            n_correct.iadd(
                self.layers[-1].reveal_correctness(batch_size, part_truth))
            loss.iadd(self.layers[-1].l * batch_size)
        @for_range(N // batch_size)
        def _(i):
            start = i * batch_size
            f(start, batch_size)
        batch_size = N % batch_size
        if batch_size:
            start = N - batch_size
            f(start, batch_size)
        self.layers[0].X.address = training_data
        self.layers[-1].Y.address = training_truth
        loss = loss.reveal()
        if cfix.f < 31:
            loss = cfix._new(loss.v << (31 - cfix.f), k=63, f=31)
        return n_correct, loss / N
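
        # Illustrative usage (variable names are made up): after training,
        # accuracy and average loss on a separate test set can be obtained as
        #
        #   n_correct, avg_loss = optimizer.reveal_correctness(test_X, test_Y, 128)
        #   print_ln('acc: %s/%s, loss: %s', n_correct, len(test_Y), avg_loss)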

    @_no_mem_warnings
    def run_by_args(self, program, n_runs, batch_size, test_X, test_Y,
                    acc_batch_size=None):
        if acc_batch_size is None:
            acc_batch_size = batch_size
        depreciation = None
        for arg in program.args:
            m = re.match('rate(.*)', arg)
            if m:
                self.gamma = MemValue(cfix(float(m.group(1))))
            m = re.match('dep(.*)', arg)
            if m:
                depreciation = float(m.group(1))
        if 'nomom' in program.args:
            self.momentum = 0
        self.print_losses = 'print_losses' in program.args
        self.time_layers = 'time_layers' in program.args
        self.revealing_correctness = not 'no_acc' in program.args
        self.layers[-1].compute_loss = not 'no_loss' in program.args
        model_input = 'model_input' in program.args
        acc_first = model_input and not 'train_first' in program.args
        if model_input:
            for layer in self.layers:
                layer.input_from(0)
        else:
            self.reset()
        if 'one_iter' in program.args:
            self.output_weights()
            print_ln('loss')
            print_ln('%s', self.eval(
                self.layers[0].X.get_part(0, batch_size)).reveal_nested())
            for layer in self.layers:
                print_ln('%s', layer.X.get_part(0, batch_size).reveal_nested())
            print_ln('%s', self.layers[-1].Y.get_part(0, batch_size).reveal_nested())
            batch = Array.create_from(regint.inc(batch_size))
            self.forward(batch=batch, training=True)
            self.backward(batch=batch)
            self.update(0, batch=batch)
            print_ln('loss %s', self.layers[-1].l.reveal())
            self.output_weights()
            return
        @for_range(n_runs)
        def _(i):
            if not acc_first:
                start_timer(1)
                self.run(batch_size,
                         stop_on_loss=0 if 'no_loss' in program.args else 100)
                stop_timer(1)
            if 'no_acc' in program.args:
                return
            N = self.layers[0].X.sizes[0]
            n_trained = (N + batch_size - 1) // batch_size * batch_size
            print_ln('train_acc: %s (%s/%s)',
                     cfix(self.n_correct, k=63, f=31) / n_trained,
                     self.n_correct, n_trained)
            n_test = len(test_Y)
            n_correct, loss = self.reveal_correctness(test_X, test_Y, acc_batch_size)
            print_ln('test loss: %s', loss)
            print_ln('acc: %s (%s/%s)', cfix(n_correct, k=63, f=31) / n_test,
                     n_correct, n_test)
            if acc_first:
                start_timer(1)
                self.run(batch_size)
                stop_timer(1)
            else:
                @if_(util.or_op(self.stopped_on_loss, n_correct <
                                int(n_test // self.layers[-1].n_outputs * 1.2)))
                def _():
                    self.gamma.imul(.5)
                    self.reset()
                    print_ln('reset after reducing learning rate to %s',
                             self.gamma)
            if depreciation:
                self.gamma.imul(depreciation)
                print_ln('reducing learning rate to %s', self.gamma)
        if 'model_output' in program.args:
            self.output_weights()

    def output_weights(self):
        print_float_precision(max(6, sfix.f // 3))
        for layer in self.layers:
            layer.output_weights()
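
    # Illustrative usage (argument values are made up): a typical training
    # script calls
    #
    #   optimizer.run_by_args(program, n_runs=1, batch_size=128,
    #                         test_X=test_X, test_Y=test_Y)
    #
    # and command-line arguments such as 'rate.01', 'dep.9', 'nomom',
    # 'no_acc', 'no_loss', 'model_input', or 'model_output' switch the
    # behaviour as parsed above.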

class Adam(Optimizer):
    """ Adam/AMSgrad optimizer.

    :param layers: layers of linear graph
    :param approx: use approximation for inverse square root (bool)
    :param amsgrad: use AMSgrad (bool)
    """
    def __init__(self, layers, n_epochs=1, approx=False, amsgrad=False,
                 normalize=False):
        self.gamma = MemValue(cfix(.001))
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.beta1_power = MemValue(cfix(1))
        self.beta2_power = MemValue(cfix(1))
        self.epsilon = max(2 ** -((sfix.k - sfix.f - 8) / (1 + approx)), 10 ** -8)
        self.n_epochs = n_epochs
        self.approx = approx
        self.amsgrad = amsgrad
        self.normalize = normalize
        if amsgrad:
            print_str('Using AMSgrad ')
        else:
            print_str('Using Adam ')
        if approx:
            print_ln('with inverse square root approximation')
        else:
            print_ln('with more precise inverse square root')
        if normalize:
            print_ln('Normalize gradient')

        self.layers = layers
        self.ms = []
        self.vs = []
        self.gs = []
        self.thetas = []
        self.vhats = []
        for layer in layers:
            for nabla in layer.nablas():
                self.gs.append(nabla)
                for x in self.ms, self.vs:
                    x.append(nabla.same_shape())
                if amsgrad:
                    self.vhats.append(nabla.same_shape())
            for theta in layer.thetas():
                self.thetas.append(theta)

        super(Adam, self).__init__()

    def update(self, i_epoch, batch):
        self.beta1_power *= self.beta1
        self.beta2_power *= self.beta2
        m_factor = MemValue(1 / (1 - self.beta1_power))
        v_factor = MemValue(1 / (1 - self.beta2_power))
        for i_layer, (m, v, g, theta) in enumerate(zip(self.ms, self.vs,
                                                       self.gs, self.thetas)):
            if self.normalize:
                abs_g = g.same_shape()
                @multithread(self.n_threads, g.total_size())
                def _(base, size):
                    abs_g.assign_vector(abs(g.get_vector(base, size)), base)
                max_g = tree_reduce_multithread(self.n_threads,
                                                util.max, abs_g.get_vector())
                scale = MemValue(sfix._new(library.AppRcr(
                    max_g.v, max_g.k, max_g.f, simplex_flag=True)))
            @multithread(self.n_threads, m.total_size())
            def _(base, size):
                m_part = m.get_vector(base, size)
                v_part = v.get_vector(base, size)
                g_part = g.get_vector(base, size)
                if self.normalize:
                    g_part *= scale.expand_to_vector(size)
                m_part = self.beta1 * m_part + (1 - self.beta1) * g_part
                v_part = self.beta2 * v_part + (1 - self.beta2) * g_part ** 2
                m.assign_vector(m_part, base)
                v.assign_vector(v_part, base)
                if self.amsgrad:
                    vhat = self.vhats[i_layer].get_vector(base, size)
                    vhat = util.max(vhat, v_part)
                    self.vhats[i_layer].assign_vector(vhat, base)
                    diff = self.gamma.expand_to_vector(size) * m_part
                else:
                    mhat = m_part * m_factor.expand_to_vector(size)
                    vhat = v_part * v_factor.expand_to_vector(size)
                    diff = self.gamma.expand_to_vector(size) * mhat
                if self.approx:
                    diff *= mpc_math.InvertSqrt(vhat + self.epsilon ** 2)
                else:
                    diff /= mpc_math.sqrt(vhat) + self.epsilon
                theta.assign_vector(theta.get_vector(base, size) - diff, base)
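
        # For reference, per parameter this implements the usual Adam/AMSgrad
        # recursion in fixed point:
        #
        #   m = beta1 * m + (1 - beta1) * g
        #   v = beta2 * v + (1 - beta2) * g**2
        #   mhat = m / (1 - beta1**t);  vhat = v / (1 - beta2**t)
        #   theta -= gamma * mhat / (sqrt(vhat) + epsilon)
        #
        # With amsgrad=True, vhat is instead a running element-wise maximum of
        # v and m is used without bias correction, matching the branch above.
        # With approx=True, 1/sqrt(vhat + epsilon**2) is computed via
        # mpc_math.InvertSqrt instead of a division by mpc_math.sqrt.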

class SGD(Optimizer):
    """ Stochastic gradient descent.

    :param layers: layers of linear graph
    :param n_epochs: number of epochs for training
    :param report_loss: disclose and print loss
    """
    def __init__(self, layers, n_epochs, debug=False, report_loss=None):
        self.momentum = 0.9
        self.layers = layers
        self.n_epochs = n_epochs
        self.thetas = []
        self.nablas = []
        self.delta_thetas = []
        for layer in layers:
            self.nablas.extend(layer.nablas())
            self.thetas.extend(layer.thetas())
            for theta in layer.thetas():
                self.delta_thetas.append(theta.same_shape())
        self.gamma = MemValue(cfix(0.01))
        self.debug = debug
        super(SGD, self).__init__(report_loss)

    @_no_mem_warnings
    def reset(self, X_by_label=None):
        """ Reset layer parameters.

        :param X_by_label: if given, set training data by public labels for balancing
        """
        self.X_by_label = X_by_label
        if X_by_label is not None:
            for label, X in enumerate(X_by_label):
                @for_range_multithread(self.n_threads, 1, len(X))
                def _(i):
                    j = i + label * len(X_by_label[0])
                    self.layers[0].X[j] = X[i]
                    self.layers[-1].Y[j] = label
        for y in self.delta_thetas:
            y.assign_all(0)
        super(SGD, self).reset()

    def update(self, i_epoch, batch):
        for nabla, theta, delta_theta in zip(self.nablas, self.thetas,
                                             self.delta_thetas):
            @multithread(self.n_threads, nabla.total_size())
            def _(base, size):
                old = delta_theta.get_vector(base, size)
                red_old = self.momentum * old
                rate = self.gamma.expand_to_vector(size)
                nabla_vector = nabla.get_vector(base, size)
                log_batch_size = math.log(len(batch), 2)
                # divide by len(batch) by truncation
                # increased rate if len(batch) is not a power of two
                pre_trunc = nabla_vector.v * rate.v
                k = nabla_vector.k + rate.k
                m = rate.f + int(log_batch_size)
                v = pre_trunc.round(k, m, signed=True,
                                    nearest=sfix.round_nearest)
                new = nabla_vector._new(v)
                diff = red_old - new
                delta_theta.assign_vector(diff, base)
                theta.assign_vector(theta.get_vector(base, size) +
                                    delta_theta.get_vector(base, size), base)
            if self.print_update_average:
                vec = abs(delta_theta.get_vector().reveal())
                print_ln('update average: %s (%s)',
                         sum(vec) * cfix(1 / len(vec)), len(vec))
            if self.debug:
                limit = int(self.debug)
                d = delta_theta.get_vector().reveal()
                aa = [cfix.Array(len(d.v)) for i in range(3)]
                a = aa[0]
                a.assign(d)
                @for_range(len(a))
                def _(i):
                    x = a[i]
                    print_ln_if((x > limit) + (x < -limit),
                                'update epoch=%s %s index=%s %s',
                                i_epoch.read(), str(delta_theta), i, x)
                a = aa[1]
                a.assign(nabla.get_vector().reveal())
                @for_range(len(a))
                def _(i):
                    x = a[i]
                    print_ln_if((x > len(batch) * limit) + (x < -len(batch) * limit),
                                'nabla epoch=%s %s index=%s %s',
                                i_epoch.read(), str(nabla), i, x)
                a = aa[2]
                a.assign(theta.get_vector().reveal())
                @for_range(len(a))
                def _(i):
                    x = a[i]
                    print_ln_if((x > limit) + (x < -limit),
                                'theta epoch=%s %s index=%s %s',
                                i_epoch.read(), str(theta), i, x)
                index = regint.get_random(64) % len(a)
                print_ln('%s at %s: nabla=%s update=%s theta=%s', str(theta), index,
                         aa[1][index], aa[0][index], aa[2][index])
        self.gamma.imul(1 - 10 ** -6)
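
        # For reference, per parameter this implements momentum SGD in the
        # clear:
        #
        #   delta = momentum * delta - gamma * nabla / len(batch)
        #   theta += delta
        #
        # where the division by the batch size is folded into the fixed-point
        # truncation above (rounded down to a power of two, which slightly
        # increases the effective rate when len(batch) is not a power of two),
        # and gamma decays by a factor of (1 - 10**-6) per update.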