# mirror of https://github.com/data61/MP-SPDZ.git
import math
import operator

from Compiler import mpc_math, util
from Compiler.types import *
from Compiler.types import _unreduced_squant
from Compiler.library import *
from functools import reduce

def log_e(x):
    return mpc_math.log_fx(x, math.e)

def exp(x):
    return mpc_math.pow_fx(math.e, x)

def sanitize(x, raw, lower, upper):
    exp_limit = 2 ** (x.k - x.f - 1)
    limit = math.log(exp_limit)
    if get_program().options.ring:
        res = raw
    else:
        res = (x > limit).if_else(upper, raw)
    return (x < -limit).if_else(lower, res)

def sigmoid(x):
    return sigmoid_from_e_x(x, exp(-x))

def sigmoid_from_e_x(x, e_x):
    return sanitize(x, 1 / (1 + e_x), 0, 1)

def sigmoid_prime(x):
    sx = sigmoid(x)
    return sx * (1 - sx)

def lse_0_from_e_x(x, e_x):
    return sanitize(-x, log_e(1 + e_x), x + 2 ** -x.f, 0)

def lse_0(x):
    return lse_0_from_e_x(x, exp(x))

def relu_prime(x):
    return (0 <= x)

def relu(x):
    return (0 < x).if_else(x, 0)
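
# Plain-float reference (illustrative only, not used by the module): the helpers
# above are secret-shared versions of these identities; sanitize() additionally
# clips the result where exp() would leave the representable sfix range.
def _plain_sigmoid_reference(x):
    """Plain-float reference for sigmoid, its derivative, and lse_0 (softplus)."""
    s = 1 / (1 + math.exp(-x))             # sigmoid(x)
    s_prime = s * (1 - s)                  # sigmoid_prime(x)
    softplus = math.log(1 + math.exp(x))   # lse_0(x) = log(1 + e^x)
    return s, s_prime, softplus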

def progress(x):
    # Progress printing is intentionally disabled; remove the early return
    # to re-enable the output and timing below.
    return
    print_ln(x)
    time()

def set_n_threads(n_threads):
    Layer.n_threads = n_threads
    Optimizer.n_threads = n_threads

class Layer:
    n_threads = 1

class Output(Layer):
    def __init__(self, N, debug=False):
        self.N = N
        self.X = sfix.Array(N)
        self.Y = sfix.Array(N)
        self.nabla_X = sfix.Array(N)
        self.l = MemValue(sfix(-1))
        self.e_x = sfix.Array(N)
        self.debug = debug
        self.weights = cint.Array(N)
        self.weights.assign_all(1)
        self.weight_total = N

    nablas = lambda self: ()
    thetas = lambda self: ()
    reset = lambda self: None

    def divisor(self, divisor, size):
        return cfix(1.0 / divisor, size=size)

    def forward(self, N=None):
        N = N or self.N
        lse = sfix.Array(N)
        @multithread(self.n_threads, N)
        def _(base, size):
            x = self.X.get_vector(base, size)
            y = self.Y.get_vector(base, size)
            e_x = exp(-x)
            self.e_x.assign(e_x, base)
            lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base)
        e_x = self.e_x.get_vector(0, N)
        self.l.write(sum(lse) * self.divisor(self.N, 1))

    def backward(self):
        @multithread(self.n_threads, self.N)
        def _(base, size):
            diff = sigmoid_from_e_x(self.X.get_vector(base, size),
                                    self.e_x.get_vector(base, size)) - \
                self.Y.get_vector(base, size)
            assert sfix.f == cfix.f
            diff *= self.weights.get_vector(base, size)
            self.nabla_X.assign(diff * self.divisor(self.weight_total, size),
                                base)
            # @for_range_opt(len(diff))
            # def _(i):
            #     self.nabla_X[i] = self.nabla_X[i] * self.weights[i]
            if self.debug:
                a = cfix.Array(len(diff))
                a.assign(diff.reveal())
                @for_range(len(diff))
                def _(i):
                    x = a[i]
                    print_ln_if((x < -1.001) + (x > 1.001), 'sigmoid')
                    #print_ln('%s', x)

    def set_weights(self, weights):
        self.weights.assign(weights)
        self.weight_total = sum(weights)
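
# Illustrative plain-float check (not used by the layers): the per-example value
# accumulated in Output.forward, log(1 + exp(-x)) + x * (1 - y), equals the binary
# cross-entropy with logits; Output.backward uses its gradient sigmoid(x) - y,
# optionally scaled by the per-example weights.
def _plain_output_loss_reference(x, y):
    s = 1 / (1 + math.exp(-x))
    lse_form = math.log(1 + math.exp(-x)) + x * (1 - y)
    cross_entropy = -y * math.log(s) - (1 - y) * math.log(1 - s)
    return lse_form, cross_entropy  # equal up to floating-point error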

class DenseBase(Layer):
    thetas = lambda self: (self.W, self.b)
    nablas = lambda self: (self.nabla_W, self.nabla_b)

    def backward_params(self, f_schur_Y):
        N = self.N
        tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)

        @for_range_opt_multithread(self.n_threads, [self.d_in, self.d_out])
        def _(j, k):
            assert self.d == 1
            a = [f_schur_Y[i][0][k] for i in range(N)]
            b = [self.X[i][0][j] for i in range(N)]
            tmp[j][k] = sfix.unreduced_dot_product(a, b)

        if self.d_in * self.d_out < 100000:
            print('reduce at once')
            @multithread(self.n_threads, self.d_in * self.d_out)
            def _(base, size):
                self.nabla_W.assign_vector(
                    tmp.get_vector(base, size).reduce_after_mul(), base=base)
        else:
            @for_range_opt(self.d_in)
            def _(i):
                self.nabla_W[i] = tmp[i].get_vector().reduce_after_mul()

        self.nabla_b.assign(sum(sum(f_schur_Y[k][j][i] for k in range(N))
                                for j in range(self.d)) for i in range(self.d_out))

        progress('nabla W/b')
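
# For reference (illustrative, plain notation): with one row per example (d == 1),
# backward_params computes
#     nabla_W[j][k] = sum_i X[i][0][j] * f_schur_Y[i][0][k]   (i.e. X^T applied to
#                                                              the incoming gradient)
#     nabla_b[k]    = sum_i f_schur_Y[i][0][k]
# where f_schur_Y is the incoming gradient, already multiplied elementwise by the
# activation derivative where applicable.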

class Dense(DenseBase):
    def __init__(self, N, d_in, d_out, d=1, activation='id'):
        self.activation = activation
        if activation == 'id':
            self.f = lambda x: x
        elif activation == 'relu':
            self.f = relu
            self.f_prime = relu_prime
        elif activation == 'sigmoid':
            self.f = sigmoid
            self.f_prime = sigmoid_prime

        self.N = N
        self.d_in = d_in
        self.d_out = d_out
        self.d = d

        self.X = MultiArray([N, d, d_in], sfix)
        self.Y = MultiArray([N, d, d_out], sfix)
        self.W = sfix.Matrix(d_in, d_out)
        self.b = sfix.Array(d_out)

        self.reset()

        self.nabla_Y = MultiArray([N, d, d_out], sfix)
        self.nabla_X = MultiArray([N, d, d_in], sfix)
        self.nabla_W = sfix.Matrix(d_in, d_out)
        self.nabla_W.assign_all(0)
        self.nabla_b = sfix.Array(d_out)

        self.f_input = MultiArray([N, d, d_out], sfix)

    def reset(self):
        d_in = self.d_in
        d_out = self.d_out
        r = math.sqrt(6.0 / (d_in + d_out))
        @for_range(d_in)
        def _(i):
            @for_range(d_out)
            def _(j):
                self.W[i][j] = sfix.get_random(-r, r)
        self.b.assign_all(0)

    def compute_f_input(self):
        prod = MultiArray([self.N, self.d, self.d_out], sfix)
        @for_range_opt_multithread(self.n_threads, self.N)
        def _(i):
            self.X[i].plain_mul(self.W, res=prod[i])

        @for_range_opt_multithread(self.n_threads, self.N)
        def _(i):
            @for_range_opt(self.d)
            def _(j):
                v = prod[i][j].get_vector() + self.b.get_vector()
                self.f_input[i][j].assign(v)
        progress('f input')

    def forward(self):
        self.compute_f_input()
        self.Y.assign_vector(self.f(self.f_input.get_vector()))

    def backward(self, compute_nabla_X=True):
        N = self.N
        d = self.d
        d_out = self.d_out
        X = self.X
        Y = self.Y
        W = self.W
        b = self.b
        nabla_X = self.nabla_X
        nabla_Y = self.nabla_Y
        nabla_W = self.nabla_W
        nabla_b = self.nabla_b

        if self.activation == 'id':
            f_schur_Y = nabla_Y
        else:
            f_prime_bit = MultiArray([N, d, d_out], sint)
            f_schur_Y = MultiArray([N, d, d_out], sfix)

            self.compute_f_input()
            f_prime_bit.assign_vector(self.f_prime(self.f_input.get_vector()))

            progress('f prime')

            @for_range_opt(N)
            def _(i):
                f_schur_Y[i] = nabla_Y[i].schur(f_prime_bit[i])

            progress('f prime schur Y')

        if compute_nabla_X:
            @for_range_opt(N)
            def _(i):
                if self.activation == 'id':
                    nabla_X[i] = nabla_Y[i].mul_trans(W)
                else:
                    nabla_X[i] = nabla_Y[i].schur(f_prime_bit[i]).mul_trans(W)

            progress('nabla X')

        self.backward_params(f_schur_Y)
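
# Illustrative summary (plain notation) of the Dense layer above:
#     forward:  Y = f(X @ W + b)
#     backward: nabla_X = (nabla_Y * f'(X @ W + b)) @ W^T
# with W initialised uniformly in [-r, r], r = sqrt(6 / (d_in + d_out))
# (Glorot/Xavier-style initialisation). A hypothetical two-layer setup might
# look like this (sizes are made up for illustration):
#
#     hidden = Dense(1000, 64, 10, activation='relu')
#     final = Dense(1000, 10, 1)
#     out = Output(1000)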

class QuantizedDense(DenseBase):
    def __init__(self, N, d_in, d_out):
        self.N = N
        self.d_in = d_in
        self.d_out = d_out
        self.d = 1
        self.H = math.sqrt(1.5 / (d_in + d_out))

        self.W = sfix.Matrix(d_in, d_out)
        self.nabla_W = self.W.same_shape()
        self.T = sint.Matrix(d_in, d_out)
        self.b = sfix.Array(d_out)
        self.nabla_b = self.b.same_shape()

        self.X = MultiArray([N, 1, d_in], sfix)
        self.Y = MultiArray([N, 1, d_out], sfix)
        self.nabla_Y = self.Y.same_shape()

    def reset(self):
        @for_range(self.d_in)
        def _(i):
            @for_range(self.d_out)
            def _(j):
                self.W[i][j] = sfix.get_random(-1, 1)
        self.b.assign_all(0)

    def forward(self):
        @for_range_opt(self.d_in)
        def _(i):
            @for_range_opt(self.d_out)
            def _(j):
                over = self.W[i][j] > 0.5
                under = self.W[i][j] < -0.5
                self.T[i][j] = over.if_else(1, under.if_else(-1, 0))
                over = self.W[i][j] > 1
                under = self.W[i][j] < -1
                self.W[i][j] = over.if_else(1, under.if_else(-1, self.W[i][j]))
        @for_range_opt(self.N)
        def _(i):
            assert self.d_out == 1
            self.Y[i][0][0] = self.b[0] + self.H * sfix._new(
                sint.dot_product([self.T[j][0] for j in range(self.d_in)],
                                 [self.X[i][0][j].v for j in range(self.d_in)]))

    def backward(self, compute_nabla_X=False):
        assert not compute_nabla_X
        self.backward_params(self.nabla_Y)

class Dropout:
    def __init__(self, N, d1, d2=1):
        self.N = N
        self.d1 = d1
        self.d2 = d2
        self.X = MultiArray([N, d1, d2], sfix)
        self.Y = MultiArray([N, d1, d2], sfix)
        self.nabla_Y = MultiArray([N, d1, d2], sfix)
        self.nabla_X = MultiArray([N, d1, d2], sfix)
        self.alpha = 0.5
        self.B = MultiArray([N, d1, d2], sint)

    def forward(self):
        assert self.alpha == 0.5
        @for_range(self.N)
        def _(i):
            @for_range(self.d1)
            def _(j):
                @for_range(self.d2)
                def _(k):
                    self.B[i][j][k] = sint.get_random_bit()
        self.Y = self.X.schur(self.B)

    def backward(self):
        self.nabla_X = self.nabla_Y.schur(self.B)
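
# Note on the Dropout layer above: forward() samples a fresh secret random bit per
# entry (keep probability 1/2, matching alpha == 0.5) and multiplies the
# activations elementwise by that mask; backward() applies the same mask to the
# incoming gradients.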

class QuantBase(object):
    n_threads = 1

    @staticmethod
    def new_squant():
        class _(squant):
            @classmethod
            def get_input_from(cls, player, size=None):
                return cls._new(sint.get_input_from(player, size=size))
        return _

    def __init__(self, input_shape, output_shape):
        self.input_shape = input_shape
        self.output_shape = output_shape

        self.input_squant = self.new_squant()
        self.output_squant = self.new_squant()

        self.X = MultiArray(input_shape, self.input_squant)
        self.Y = MultiArray(output_shape, self.output_squant)

    def temp_shape(self):
        return [0]

class QuantConvBase(QuantBase):
    fewer_rounds = True
    temp_weights = None
    temp_inputs = None

    @classmethod
    def init_temp(cls, layers):
        size = 0
        for layer in layers:
            size = max(size, reduce(operator.mul, layer.temp_shape()))
        cls.temp_weights = sfix.Array(size)
        cls.temp_inputs = sfix.Array(size)

    def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride):
        super(QuantConvBase, self).__init__(input_shape, output_shape)

        self.weight_shape = weight_shape
        self.bias_shape = bias_shape
        self.stride = stride

        self.weight_squant = self.new_squant()
        self.bias_squant = self.new_squant()

        self.weights = MultiArray(weight_shape, self.weight_squant)
        self.bias = Array(output_shape[-1], self.bias_squant)

        self.unreduced = MultiArray(self.output_shape, sint,
                                    address=self.Y.address)

        assert(weight_shape[-1] == input_shape[-1])
        assert(bias_shape[0] == output_shape[-1])
        assert(len(bias_shape) == 1)
        assert(len(input_shape) == 4)
        assert(len(output_shape) == 4)
        assert(len(weight_shape) == 4)

    def input_from(self, player):
        for s in self.input_squant, self.weight_squant, self.bias_squant, self.output_squant:
            s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
        self.weights.input_from(player, budget=100000)
        self.bias.input_from(player)
        print('WARNING: assuming that bias quantization parameters are correct')

        self.output_squant.params.precompute(self.input_squant.params, self.weight_squant.params)

    def dot_product(self, iv, wv, out_y, out_x, out_c):
        bias = self.bias[out_c]
        acc = squant.unreduced_dot_product(iv, wv)
        acc.v += bias.v
        acc.res_params = self.output_squant.params
        #self.Y[0][out_y][out_x][out_c] = acc.reduce_after_mul()
        self.unreduced[0][out_y][out_x][out_c] = acc.v

    def reduction(self):
        unreduced = self.unreduced
        n_summands = self.n_summands()
        start_timer(2)
        n_outputs = reduce(operator.mul, self.output_shape)
        if n_outputs % self.n_threads == 0:
            n_per_thread = n_outputs // self.n_threads
            @for_range_opt_multithread(self.n_threads, self.n_threads)
            def _(i):
                res = _unreduced_squant(
                    sint.load_mem(unreduced.address + i * n_per_thread,
                                  size=n_per_thread),
                    (self.input_squant.params, self.weight_squant.params),
                    self.output_squant.params,
                    n_summands).reduce_after_mul()
                res.store_in_mem(self.Y.address + i * n_per_thread)
        else:
            @for_range_opt_multithread(self.n_threads, self.output_shape[1])
            def _(out_y):
                self.Y[0][out_y].assign_vector(_unreduced_squant(
                    unreduced[0][out_y].get_vector(),
                    (self.input_squant.params, self.weight_squant.params),
                    self.output_squant.params,
                    n_summands).reduce_after_mul())
        stop_timer(2)

    def temp_shape(self):
        return list(self.output_shape[1:]) + [self.n_summands()]

    def prepare_temp(self):
        shape = self.temp_shape()
        inputs = MultiArray(shape, self.input_squant,
                            address=self.temp_inputs)
        weights = MultiArray(shape, self.weight_squant,
                             address=self.temp_weights)
        return inputs, weights
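
# Background for the quantized layers (illustrative; the squant parameters, a
# floating-point scale and an integer zero point, appear to follow the usual
# TensorFlow Lite / gemmlowp convention): a real value is represented as
#     real = scale * (q - zero_point)
# so a convolution accumulates integer products plus the bias, and reduction()
# rescales the accumulator to the output's scale and zero point in one step,
# using the parameters precomputed in input_from().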

class QuantConv2d(QuantConvBase):
    def n_summands(self):
        _, weights_h, weights_w, _ = self.weight_shape
        _, inputs_h, inputs_w, n_channels_in = self.input_shape
        return weights_h * weights_w * n_channels_in

    def forward(self, N=1):
        assert(N == 1)
        assert(self.weight_shape[0] == self.output_shape[-1])

        _, weights_h, weights_w, _ = self.weight_shape
        _, inputs_h, inputs_w, n_channels_in = self.input_shape
        _, output_h, output_w, n_channels_out = self.output_shape

        stride_h, stride_w = self.stride
        padding_h, padding_w = (weights_h // 2, weights_w // 2)

        if self.fewer_rounds:
            inputs, weights = self.prepare_temp()

        @for_range_opt_multithread(self.n_threads,
                                   [output_h, output_w, n_channels_out])
        def _(out_y, out_x, out_c):
            in_x_origin = (out_x * stride_w) - padding_w
            in_y_origin = (out_y * stride_h) - padding_h
            iv = []
            wv = []
            for filter_y in range(weights_h):
                in_y = in_y_origin + filter_y
                inside_y = (0 <= in_y) * (in_y < inputs_h)
                for filter_x in range(weights_w):
                    in_x = in_x_origin + filter_x
                    inside_x = (0 <= in_x) * (in_x < inputs_w)
                    inside = inside_y * inside_x
                    # skip terms that are statically known to fall outside the input
                    if isinstance(inside, int) and inside == 0:
                        continue
                    for in_c in range(n_channels_in):
                        iv += [self.X[0][in_y * inside_y]
                               [in_x * inside_x][in_c]]
                        wv += [self.weights[out_c][filter_y][filter_x][in_c]]
                        wv[-1] *= inside
            if self.fewer_rounds:
                inputs[out_y][out_x][out_c].assign(iv)
                weights[out_y][out_x][out_c].assign(wv)
            else:
                self.dot_product(iv, wv, out_y, out_x, out_c)

        if self.fewer_rounds:
            @for_range_opt_multithread(self.n_threads,
                                       list(self.output_shape[1:]))
            def _(out_y, out_x, out_c):
                self.dot_product(inputs[out_y][out_x][out_c],
                                 weights[out_y][out_x][out_c],
                                 out_y, out_x, out_c)

        self.reduction()
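
# Worked example of the index arithmetic above (illustrative): with a 3x3 kernel,
# padding_h = padding_w = 1 ("same" padding) and stride 1, output position
# (out_y, out_x) = (0, 0) reads inputs at in_y, in_x in {-1, 0, 1}. Positions
# outside the input are not skipped when the indices are only known at runtime;
# instead the index is clamped to 0 and the corresponding weight is multiplied by
# the `inside` flag, so those terms contribute zero to the dot product.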

class QuantDepthwiseConv2d(QuantConvBase):
    def n_summands(self):
        _, weights_h, weights_w, _ = self.weight_shape
        return weights_h * weights_w

    def forward(self, N=1):
        assert(N == 1)
        assert(self.weight_shape[-1] == self.output_shape[-1])
        assert(self.input_shape[-1] == self.output_shape[-1])

        _, weights_h, weights_w, _ = self.weight_shape
        _, inputs_h, inputs_w, n_channels_in = self.input_shape
        _, output_h, output_w, n_channels_out = self.output_shape

        stride_h, stride_w = self.stride
        padding_h, padding_w = (weights_h // 2, weights_w // 2)

        depth_multiplier = 1

        if self.fewer_rounds:
            inputs, weights = self.prepare_temp()

        @for_range_opt_multithread(self.n_threads,
                                   [output_h, output_w, n_channels_in])
        def _(out_y, out_x, in_c):
            for m in range(depth_multiplier):
                oc = m + in_c * depth_multiplier
                in_x_origin = (out_x * stride_w) - padding_w
                in_y_origin = (out_y * stride_h) - padding_h
                iv = []
                wv = []
                for filter_y in range(weights_h):
                    for filter_x in range(weights_w):
                        in_x = in_x_origin + filter_x
                        in_y = in_y_origin + filter_y
                        inside = (0 <= in_x) * (in_x < inputs_w) * \
                                 (0 <= in_y) * (in_y < inputs_h)
                        # skip terms that are statically known to fall outside the input
                        if isinstance(inside, int) and inside == 0:
                            continue
                        iv += [self.X[0][in_y][in_x][in_c]]
                        wv += [self.weights[0][filter_y][filter_x][oc]]
                        wv[-1] *= inside
                if self.fewer_rounds:
                    inputs[out_y][out_x][oc].assign(iv)
                    weights[out_y][out_x][oc].assign(wv)
                else:
                    self.dot_product(iv, wv, out_y, out_x, oc)

        if self.fewer_rounds:
            @for_range_opt_multithread(self.n_threads,
                                       list(self.output_shape[1:]))
            def _(out_y, out_x, out_c):
                self.dot_product(inputs[out_y][out_x][out_c],
                                 weights[out_y][out_x][out_c],
                                 out_y, out_x, out_c)

        self.reduction()

class QuantAveragePool2d(QuantBase):
    def __init__(self, input_shape, output_shape, filter_size):
        super(QuantAveragePool2d, self).__init__(input_shape, output_shape)
        self.filter_size = filter_size

    def input_from(self, player):
        print('WARNING: assuming that input and output quantization parameters are the same')
        for s in self.input_squant, self.output_squant:
            s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))

    def forward(self, N=1):
        assert(N == 1)

        _, input_h, input_w, n_channels_in = self.input_shape
        _, output_h, output_w, n_channels_out = self.output_shape

        n = input_h * input_w
        print('divisor: ', n)

        assert output_h == output_w == 1
        assert n_channels_in == n_channels_out

        padding_h, padding_w = (0, 0)
        stride_h, stride_w = (2, 2)
        filter_h, filter_w = self.filter_size

        @for_range_opt(output_h)
        def _(out_y):
            @for_range_opt(output_w)
            def _(out_x):
                @for_range_opt(n_channels_in)
                def _(c):
                    in_x_origin = (out_x * stride_w) - padding_w
                    in_y_origin = (out_y * stride_h) - padding_h
                    fxs = (-in_x_origin).max(0)
                    #fxe = min(filter_w, input_w - in_x_origin)
                    fys = (-in_y_origin).max(0)
                    #fye = min(filter_h, input_h - in_y_origin)
                    acc = 0
                    #fc = 0
                    for i in range(filter_h):
                        filter_y = fys + i
                        for j in range(filter_w):
                            filter_x = fxs + j
                            in_x = in_x_origin + filter_x
                            in_y = in_y_origin + filter_y
                            acc += self.X[0][in_y][in_x][c].v
                            #fc += 1
                    logn = int(math.log(n, 2))
                    acc = (acc + n // 2)
                    if 2 ** logn == n:
                        acc = acc.round(self.output_squant.params.k + logn,
                                        logn, nearest=True)
                    else:
                        acc = acc.int_div(sint(n),
                                          self.output_squant.params.k + logn)
                    #acc = min(255, max(0, acc))
                    self.Y[0][out_y][out_x][c] = self.output_squant._new(acc)
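
# Worked example of the rounding above (illustrative): the pool averages
# n = input_h * input_w accumulator values by computing round(sum / n) as
# (sum + n // 2) / n; when n is a power of two the division is a truncation by
# log2(n) bits, otherwise it falls back to integer division by sint(n).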

class QuantReshape(QuantBase):
    def __init__(self, input_shape, _, output_shape):
        super(QuantReshape, self).__init__(input_shape, output_shape)

    def input_from(self, player):
        print('WARNING: assuming that input and output quantization parameters are the same')
        _ = self.new_squant()
        for s in self.input_squant, _, self.output_squant:
            s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
        for i in range(2):
            sint.get_input_from(player)

    def forward(self, N=1):
        assert(N == 1)
        # reshaping is implicit
        self.Y.assign(self.X)

class QuantSoftmax(QuantBase):
    def input_from(self, player):
        print('WARNING: assuming that input and output quantization parameters are the same')
        for s in self.input_squant, self.output_squant:
            s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))

    def forward(self, N=1):
        assert(N == 1)
        assert(len(self.input_shape) == 2)

        # just print the best
        def comp(left, right):
            c = left[1].v.greater_than(right[1].v, self.input_squant.params.k)
            #print_ln('comp %s %s %s', c.reveal(), left[1].v.reveal(), right[1].v.reveal())
            return [c.if_else(x, y) for x, y in zip(left, right)]
        print_ln('guess: %s', util.tree_reduce(comp, list(enumerate(self.X[0])))[0].reveal())

class Optimizer:
    n_threads = Layer.n_threads

    def forward(self, N):
        for j in range(len(self.layers) - 1):
            self.layers[j].forward()
            self.layers[j + 1].X.assign(self.layers[j].Y)
        self.layers[-1].forward(N)

    def backward(self):
        for j in range(1, len(self.layers)):
            self.layers[-j].backward()
            self.layers[-j - 1].nabla_Y.assign(self.layers[-j].nabla_X)
        self.layers[0].backward(compute_nabla_X=False)

    def run(self):
        i = MemValue(0)
        @do_while
        def _():
            if self.X_by_label is not None:
                N = self.layers[0].N
                assert self.layers[-1].N == N
                assert N % 2 == 0
                n = N // 2
                @for_range(n)
                def _(i):
                    self.layers[-1].Y[i] = 0
                    self.layers[-1].Y[i + n] = 1
                n_per_epoch = int(math.ceil(1. * max(len(X) for X in
                                                     self.X_by_label) / n))
                print('%d runs per epoch' % n_per_epoch)
                indices_by_label = []
                for label, X in enumerate(self.X_by_label):
                    indices = regint.Array(n * n_per_epoch)
                    indices_by_label.append(indices)
                    indices.assign(i % len(X) for i in range(len(indices)))
                    indices.shuffle()
                @for_range(n_per_epoch)
                def _(j):
                    j = MemValue(j)
                    for label, X in enumerate(self.X_by_label):
                        indices = indices_by_label[label]
                        @for_range_multithread(self.n_threads, 1, n)
                        def _(i):
                            idx = indices[i + j * n_per_epoch]
                            self.layers[0].X[i + label * n] = X[idx]
                    self.forward(None)
                    self.backward()
                    self.update(i)
            else:
                self.forward(None)
                self.backward()
                self.update(i)
            loss = self.layers[-1].l
            if self.report_loss:
                print_ln('loss after epoch %s: %s', i, loss.reveal())
            else:
                print_ln('done with epoch %s', i)
            time()
            i.iadd(1)
            res = (i < self.n_epochs)
            if self.tol > 0:
                res *= (1 - (loss >= 0) * (loss < self.tol)).reveal()
            return res
        print_ln('finished after %s epochs', i)
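
# Note on Optimizer.run() above (illustrative): when reset() is given X_by_label
# (one array of examples per label), every run fills the first n = N/2 input slots
# with label-0 examples and the last n with label-1 examples, drawing from
# per-label index arrays that are shuffled once per epoch and wrap around the
# smaller class, so each batch the layers see is balanced.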

class Adam(Optimizer):
    def __init__(self, layers, n_epochs):
        self.alpha = .001
        self.beta1 = 0.9
        self.beta2 = 0.999
        self.epsilon = 10 ** -8
        self.n_epochs = n_epochs

        self.layers = layers
        self.ms = []
        self.vs = []
        self.gs = []
        self.thetas = []
        for layer in layers:
            for nabla in layer.nablas():
                self.gs.append(nabla)
                for x in self.ms, self.vs:
                    x.append(nabla.same_shape())
            for theta in layer.thetas():
                self.thetas.append(theta)

        self.mhat_factors = Array(n_epochs, sfix)
        self.vhat_factors = Array(n_epochs, sfix)

        for i in range(n_epochs):
            for factors, beta in ((self.mhat_factors, self.beta1),
                                  (self.vhat_factors, self.beta2)):
                factors[i] = 1. / (1 - beta ** (i + 1))

    def update(self, i_epoch):
        for m, v, g, theta in zip(self.ms, self.vs, self.gs, self.thetas):
            @for_range_opt(len(m))
            def _(k):
                m[k] = self.beta1 * m[k] + (1 - self.beta1) * g[k]
                v[k] = self.beta2 * v[k] + (1 - self.beta2) * g[k] ** 2
                mhat = m[k] * self.mhat_factors[i_epoch]
                vhat = v[k] * self.vhat_factors[i_epoch]
                # epsilon belongs in the denominator of the Adam step
                theta[k] = theta[k] - self.alpha * mhat / \
                    (mpc_math.sqrt(vhat) + self.epsilon)
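
# For reference, the update above follows the standard Adam recurrences
# (illustrative plain notation, per parameter, at epoch t):
#     m <- beta1 * m + (1 - beta1) * g
#     v <- beta2 * v + (1 - beta2) * g^2
#     m_hat = m / (1 - beta1^t),  v_hat = v / (1 - beta2^t)
#     theta <- theta - alpha * m_hat / (sqrt(v_hat) + epsilon)
# with the 1 / (1 - beta^t) factors precomputed per epoch in __init__.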

class SGD(Optimizer):
    def __init__(self, layers, n_epochs, debug=False, report_loss=False):
        self.momentum = 0.9
        self.layers = layers
        self.n_epochs = n_epochs
        self.thetas = []
        self.nablas = []
        self.delta_thetas = []
        for layer in layers:
            self.nablas.extend(layer.nablas())
            self.thetas.extend(layer.thetas())
            for theta in layer.thetas():
                self.delta_thetas.append(theta.same_shape())
        self.gamma = MemValue(sfix(0.01))
        self.debug = debug
        self.report_loss = report_loss
        self.tol = 0.000
        self.X_by_label = None

    def reset(self, X_by_label=None):
        self.X_by_label = X_by_label
        for y in self.delta_thetas:
            y.assign_all(0)
        for layer in self.layers:
            layer.reset()

    def update(self, i_epoch):
        for nabla, theta, delta_theta in zip(self.nablas, self.thetas,
                                             self.delta_thetas):
            @for_range_opt_multithread(self.n_threads, len(nabla))
            def _(k):
                old = delta_theta[k]
                if isinstance(old, Array):
                    old = old.get_vector()
                red_old = self.momentum * old
                new = self.gamma * nabla[k]
                diff = red_old - new
                delta_theta[k] = diff
                theta[k] = theta[k] + delta_theta[k]
                if self.debug:
                    for x, name in (old, 'old'), (red_old, 'red_old'), \
                                   (new, 'new'), (diff, 'diff'):
                        x = x.reveal()
                        print_ln_if((x > 1000) + (x < -1000),
                                    name + ': %s %s %s %s',
                                    *[y.v.reveal() for y in (old, red_old,
                                                             new, diff)])
            if self.debug:
                d = delta_theta.get_vector().reveal()
                a = cfix.Array(len(d.v))
                a.assign(d)
                @for_range(len(a))
                def _(i):
                    x = a[i]
                    print_ln_if((x > 1000) + (x < -1000),
                                'update len=%d' % len(nabla))
                a.assign(nabla.get_vector().reveal())
                @for_range(len(a))
                def _(i):
                    x = a[i]
                    print_ln_if((x > 1000) + (x < -1000),
                                'nabla len=%d' % len(nabla))
        self.gamma.imul(1 - 10 ** -6)
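
# Minimal usage sketch (hypothetical, for illustration only; names and sizes are
# made up): logistic regression on N examples with d features, trained for 10
# epochs. The data would have to be filled into layers[0].X and layers[-1].Y
# (or passed via reset(X_by_label)) before calling run().
#
#     from Compiler import ml
#     N, d = 1000, 28
#     layers = [ml.Dense(N, d, 1), ml.Output(N)]
#     sgd = ml.SGD(layers, 10, report_loss=True)
#     sgd.reset()
#     sgd.run()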