mirror of https://github.com/data61/MP-SPDZ.git, synced 2026-01-09 13:37:58 -05:00

Merge FlexDense with Dense

Compiler/ml.py | 534
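This change folds FlexDense into Dense: the merged Dense accepts inputs of shape [N, d, d_in] and maps them to [N, d, d_out], looping over the extra dimension d like torch's nn.Linear, so former FlexDense call sites now pass d as the fourth constructor argument (see the BERT and layers_from_torch hunks below). A minimal usage sketch follows, assuming the MP-SPDZ compile-time environment (code placed in an .mpc program) and only the constructor signature shown in this diff; it is an editor's illustration, not part of the commit:

    # Editor's sketch: constructing the merged layer inside an .mpc program.
    from Compiler import ml

    n_examples, seq_len, hidden_size, intermediate_size = 8, 16, 64, 256

    # Before this commit: ml.FlexDense(n_examples, hidden_size, intermediate_size, seq_len)
    # After it, the same call goes through Dense, with d as the fourth argument:
    dense = ml.Dense(n_examples, hidden_size, intermediate_size, seq_len)

    # The layer allocates X with sizes [N, d, d_in] and Y with sizes [N, d, d_out].
    print(dense.X.sizes, dense.Y.sizes)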
@@ -725,350 +725,6 @@ class DenseBase(Layer):
         N = len(batch)
         tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)

-        A = sfix.Matrix(N, self.d_out, address=f_schur_Y.address)
-        B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
-
-        @multithread(self.n_threads, self.d_in)
-        def _(base, size):
-            mp = B.direct_trans_mul(A, reduce=False,
-                                    indices=(regint.inc(size, base),
-                                             batch.get_vector(),
-                                             regint.inc(N),
-                                             regint.inc(self.d_out)))
-            tmp.assign_part_vector(mp, base)
-
-        progress('nabla W (matmul)')
-
-        @multithread(self.n_threads, self.d_in * self.d_out,
-                     max_size=get_program().budget)
-        def _(base, size):
-            self.nabla_W.assign_vector(
-                tmp.get_vector(base, size).reduce_after_mul(), base=base)
-
-        if self.print_random_update:
-            print_ln('backward %s', self)
-            i = regint.get_random(64) % self.d_in
-            j = regint.get_random(64) % self.d_out
-            print_ln('%s at (%s, %s): before=%s after=%s A=%s B=%s',
-                     str(self.nabla_W), i, j, tmp[i][j].v.reveal(),
-                     self.nabla_W[i][j].reveal(),
-                     A.get_column(j).reveal(),
-                     B.get_column_by_row_indices(
-                         batch.get_vector(), i).reveal())
-            print_ln('batch=%s B=%s', batch,
-                     [self.X[bi][0][i].reveal() for bi in batch])
-
-        progress('nabla W')
-
-        self.nabla_b.assign_vector(sum(sum(f_schur_Y[k][j].get_vector()
-                                           for k in range(N))
-                                       for j in range(self.d)))
-
-        progress('nabla b')
-
-        if self.debug_output:
-            print_ln('dense nabla Y %s', self.nabla_Y.reveal_nested())
-            print_ln('dense W %s', self.W.reveal_nested())
-            print_ln('dense nabla X %s', self.nabla_X.reveal_nested())
-        if self.debug:
-            limit = N * self.debug
-            @for_range_opt(self.d_in)
-            def _(i):
-                @for_range_opt(self.d_out)
-                def _(j):
-                    to_check = self.nabla_W[i][j].reveal()
-                    check = sum(to_check > limit) + sum(to_check < -limit)
-                    @if_(check)
-                    def _():
-                        print_ln('nabla W %s %s %s: %s', i, j, self.W.sizes, to_check)
-                        print_ln('Y %s', [f_schur_Y[k][0][j].reveal()
-                                          for k in range(N)])
-                        print_ln('X %s', [self.X[k][0][i].reveal()
-                                          for k in range(N)])
-            @for_range_opt(self.d_out)
-            def _(j):
-                to_check = self.nabla_b[j].reveal()
-                check = sum(to_check > limit) + sum(to_check < -limit)
-                @if_(check)
-                def _():
-                    print_ln('nabla b %s %s: %s', j, len(self.b), to_check)
-                    print_ln('Y %s', [f_schur_Y[k][0][j].reveal()
-                                      for k in range(N)])
-            @for_range_opt(len(batch))
-            def _(i):
-                to_check = self.nabla_X[i].get_vector().reveal()
-                check = sum(to_check > limit) + sum(to_check < -limit)
-                @if_(check)
-                def _():
-                    print_ln('X %s %s', i, self.X[i].reveal_nested())
-                    print_ln('Y %s %s', i, f_schur_Y[i].reveal_nested())
-
-class Dense(DenseBase):
-    """ Fixed-point dense (matrix multiplication) layer.
-
-    :param N: number of examples
-    :param d_in: input dimension
-    :param d_out: output dimension
-    """
-    def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):
-        if activation == 'id':
-            self.activation_layer = None
-        elif activation == 'relu':
-            self.activation_layer = Relu([N, d, d_out])
-        elif activation == 'square':
-            self.activation_layer = Square([N, d, d_out])
-        else:
-            raise CompilerError('activation not supported: %s', activation)
-
-        self.N = N
-        self.d_in = d_in
-        self.d_out = d_out
-        self.d = d
-        self.activation = activation
-
-        self.X = Tensor([N, d, d_in], sfix)
-        self.Y = Tensor([N, d, d_out], sfix)
-        self.W = Tensor([d_in, d_out], sfix)
-        self.b = sfix.Array(d_out)
-
-        back_N = min(N, self.back_batch_size)
-        self.nabla_Y = Tensor([back_N, d, d_out], sfix)
-        self.nabla_X = Tensor([back_N, d, d_in], sfix)
-        self.nabla_W = Tensor([d_in, d_out], sfix)
-        self.nabla_b = sfix.Array(d_out)
-
-        self.debug = debug
-
-        l = self.activation_layer
-        if l:
-            self.f_input = l.X
-            l.Y = self.Y
-            l.nabla_Y = self.nabla_Y
-        else:
-            self.f_input = self.Y
-
-    def __repr__(self):
-        return '%s(%s, %s, %s, activation=%s)' % \
-            (type(self).__name__, self.N, self.d_in,
-             self.d_out, repr(self.activation))
-
-    def reset(self):
-        d_in = self.d_in
-        d_out = self.d_out
-        r = math.sqrt(6.0 / (d_in + d_out))
-        print('Initializing dense weights in [%f,%f]' % (-r, r))
-        self.W.randomize(-r, r, n_threads=self.n_threads)
-        self.b.assign_all(0)
-
-    def input_from(self, player, **kwargs):
-        self.W.input_from(player, **kwargs)
-        if self.input_bias:
-            self.b.input_from(player, **kwargs)
-
-    def compute_f_input(self, batch):
-        N = len(batch)
-        assert self.d == 1
-        if self.input_bias:
-            prod = MultiArray([N, self.d, self.d_out], sfix)
-        else:
-            prod = self.f_input
-        max_size = get_program().budget
-        @multithread(self.n_threads, N, max_size)
-        def _(base, size):
-            X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address)
-            prod.assign_part_vector(
-                X_sub.direct_mul(self.W, indices=(
-                    batch.get_vector(base, size), regint.inc(self.d_in),
-                    regint.inc(self.d_in), regint.inc(self.d_out))), base)
-
-        if self.input_bias:
-            if self.d_out == 1:
-                @multithread(self.n_threads, N)
-                def _(base, size):
-                    v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size)
-                    self.f_input.assign_vector(v, base)
-            else:
-                @for_range_multithread(self.n_threads, 100, N)
-                def _(i):
-                    v = prod[i].get_vector() + self.b.get_vector()
-                    self.f_input[i].assign_vector(v)
-        progress('f input')
-
-    def _forward(self, batch=None):
-        if batch is None:
-            batch = regint.Array(self.N)
-            batch.assign(regint.inc(self.N))
-        self.compute_f_input(batch=batch)
-        if self.activation_layer:
-            self.activation_layer.forward(batch, training=True)
-        if self.debug_output:
-            print_ln('dense X %s', self.X.reveal_nested())
-            print_ln('dense W %s', self.W.reveal_nested())
-            print_ln('dense b %s', self.b.reveal_nested())
-            print_ln('dense Y %s', self.Y.reveal_nested())
-        if self.debug:
-            limit = self.debug
-            @for_range_opt(len(batch))
-            def _(i):
-                @for_range_opt(self.d_out)
-                def _(j):
-                    to_check = self.Y[i][0][j].reveal()
-                    check = to_check > limit
-                    @if_(check)
-                    def _():
-                        print_ln('dense Y %s %s %s %s', i, j, self.W.sizes, to_check)
-                        print_ln('X %s', self.X[i].reveal_nested())
-                        print_ln('W %s',
-                                 [self.W[k][j].reveal() for k in range(self.d_in)])
-
-    def backward(self, compute_nabla_X=True, batch=None):
-        N = len(batch)
-        d = self.d
-        d_out = self.d_out
-        X = self.X
-        Y = self.Y
-        W = self.W
-        b = self.b
-        nabla_X = self.nabla_X
-        nabla_Y = self.nabla_Y
-        nabla_W = self.nabla_W
-        nabla_b = self.nabla_b
-
-        if self.activation_layer:
-            self.activation_layer.backward(batch)
-            f_schur_Y = self.activation_layer.nabla_X
-        else:
-            f_schur_Y = nabla_Y
-
-        if compute_nabla_X:
-            nabla_X.alloc()
-            @multithread(self.n_threads, N)
-            def _(base, size):
-                B = sfix.Matrix(N, d_out, address=f_schur_Y.address)
-                nabla_X.assign_part_vector(
-                    B.direct_mul_trans(W, indices=(regint.inc(size, base),
-                                                   regint.inc(self.d_out),
-                                                   regint.inc(self.d_out),
-                                                   regint.inc(self.d_in))),
-                    base)
-
-            if self.print_random_update:
-                print_ln('backward %s', self)
-                index = regint.get_random(64) % self.nabla_X.total_size()
-                print_ln('%s nabla_X at %s: %s', str(self.nabla_X),
-                         index, self.nabla_X.to_array()[index].reveal())
-
-        progress('nabla X')
-
-        self.backward_params(f_schur_Y, batch=batch)
-
-
-class FlexDense(Dense):
-    """ Fixed-point dense (matrix multiplication) layer with flexible number of dimensions.
-    Behaves like torch's nn.Linear which loops over the additional dimensions.
-
-    :param N: number of examples
-    :param d_in: input dimension
-    :param d_out: output dimension
-    """
-
-    def compute_f_input(self, batch):
-        N = len(batch)
-        prod = MultiArray([N, self.d, self.d_out], sfix)
-
-        # flattened_array version
-        result_matrix = sfix.Matrix(N * self.d, self.d_out, address=prod.address)
-        max_size = get_program().budget
-
-        # X is stored at full dataset indices, batch specifies which samples to use
-        X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address)
-
-        # Precompute batch_d_indices for all N*d elements
-        # For each sample in batch, expand to d consecutive indices
-        batch_d_indices = regint.Array(N * self.d)
-        @for_range(N)
-        def _(i):
-            actual_sample = batch[i]
-            batch_d_indices.assign(regint.inc(self.d, actual_sample * self.d), i * self.d)
-            # @for_range(self.d)
-            # def _(d_idx):
-            #     batch_d_indices[i * self.d + d_idx] = actual_sample * self.d + d_idx
-
-        @multithread(self.n_threads, N * self.d, max_size)
-        def _(base, size):
-            result_matrix.assign_part_vector(
-                X_sub.direct_mul(self.W, indices=(
-                    batch_d_indices.get_vector(base, size), regint.inc(self.d_in),
-                    regint.inc(self.d_in), regint.inc(self.d_out))), base)
-
-        if self.input_bias:
-            if self.d_out == 1:
-                @multithread(self.n_threads, N)
-                def _(base, size):
-                    v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size)
-                    self.f_input.assign_vector(v, base)
-            else:
-                @for_range_multithread(self.n_threads, 100, [N, self.d])
-                def _(i, j):
-                    v = prod[i][j].get_vector() + self.b.get_vector()
-                    # print_ln("running bias %s %s %s", i, j, v.reveal())
-                    self.f_input[i][j].assign_vector(v)
-
-        # print_ln("FlexDense f_inpu full %s", self.f_input.reveal())
-
-        progress('f input')
-
-    def backward(self, compute_nabla_X=True, batch=None):
-        N = len(batch)
-        d = self.d
-        d_out = self.d_out
-        X = self.X
-        Y = self.Y
-        W = self.W
-        b = self.b
-        nabla_X = self.nabla_X
-        nabla_Y = self.nabla_Y
-        nabla_W = self.nabla_W
-        nabla_b = self.nabla_b
-
-        if self.activation_layer:
-            self.activation_layer.backward(batch)
-            f_schur_Y = self.activation_layer.nabla_X
-        else:
-            f_schur_Y = nabla_Y
-
-        if compute_nabla_X:
-            nabla_X.alloc()
-
-            # flattened matrix version
-            result_matrix = sfix.Matrix(N * self.d, self.d_in, address=nabla_X.address)
-            # Note: f_schur_Y is stored at indices [0, 1, ..., N-1] not at actual batch indices
-            @multithread(self.n_threads, N * self.d)
-            def _(base, size):
-                X_sub = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address)
-                offset = regint.inc(size, base=base)
-
-                result_matrix.assign_part_vector(
-                    X_sub.direct_mul_trans(self.W, indices=(
-                        offset, regint.inc(self.d_out),
-                        regint.inc(self.d_out), regint.inc(self.d_in))), base)
-
-            if self.print_random_update:
-                print_ln('backward %s', self)
-                index = regint.get_random(64) % self.nabla_X.total_size()
-                print_ln('%s nabla_X at %s: %s', str(self.nabla_X),
-                         index, self.nabla_X.to_array()[index].reveal())
-
-        progress('nabla X')
-
-        self.backward_params(f_schur_Y, batch=batch)
-
-    def backward_params(self, f_schur_Y, batch):
-        print("backward params flexdense")
-        N = len(batch)
-        tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
-        # tmp.assign_all(0)
-
         # A (f_schur_Y/nabla_Y) is stored at sequential batch indices [0, 1, ..., N-1]
         A = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address)
         # B (X) is stored at the full dataset size, not just the batch
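Both the removed DenseBase.backward_params above and the FlexDense-style flattened version that takes its place (its opening lines are kept as context above and it continues in the hunk below) produce the same parameter gradients. In cleartext terms this amounts to the following, shown here as an editor's NumPy sketch for the d == 1 case (not part of the diff; names and sizes are illustrative):

    import numpy as np

    # nabla_W and nabla_b as assembled by backward_params, for d == 1.
    N_total, N, d_in, d_out = 10, 4, 3, 2
    rng = np.random.default_rng(0)
    X = rng.standard_normal((N_total, d_in))        # full dataset (self.X)
    f_schur_Y = rng.standard_normal((N, d_out))     # gradient w.r.t. the layer output
    batch = np.array([7, 2, 5, 0])                  # rows of X used in this step

    nabla_W = X[batch].T @ f_schur_Y                # what direct_trans_mul accumulates
    nabla_b = f_schur_Y.sum(axis=0)                 # what the nabla_b assignment sums

    assert nabla_W.shape == (d_in, d_out) and nabla_b.shape == (d_out,)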
@@ -1082,7 +738,7 @@ class FlexDense(Dense):
         @for_range(N)
         def _(i):
             # batch[i] gives the actual sample index in the full dataset
-            # For FlexDense with d>1, we need to map to flattened indices
+            # For Dense with d>1, we need to map to flattened indices
             actual_sample_idx = batch[i]
             @for_range(self.d)
             def _(d_idx):
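The loop above, like regint.inc(self.d, batch[i] * self.d) in compute_f_input, maps each batch sample to d consecutive rows of the flattened [N * d, d_in] view of X, i.e. flat_index = sample * d + d_idx. An editor's NumPy sketch of that mapping (not part of the diff; values are illustrative):

    import numpy as np

    d = 3
    batch = np.array([4, 1])                        # sample indices in the full dataset
    flat = (batch[:, None] * d + np.arange(d)).ravel()

    # sample 4 -> rows 12, 13, 14; sample 1 -> rows 3, 4, 5 of the flattened X
    assert list(flat) == [12, 13, 14, 3, 4, 5]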
@@ -1162,6 +818,182 @@ class FlexDense(Dense):
                     print_ln('X %s %s', i, self.X[i].reveal_nested())
                     print_ln('Y %s %s', i, f_schur_Y[i].reveal_nested())

+class Dense(DenseBase):
+    """ Fixed-point dense (matrix multiplication) layer.
+    Supports inputs of size [N, d, d_in] to map to [N, d, d_out]. If d > 1, the layer
+    behaves like torch's nn.Linear which loops over the additional d
+
+    :param N: number of examples
+    :param d_in: input dimension
+    :param d_out: output dimension
+    :param d: (optional) extra dimension
+    """
+    def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):
+        if activation == 'id':
+            self.activation_layer = None
+        elif activation == 'relu':
+            self.activation_layer = Relu([N, d, d_out])
+        elif activation == 'square':
+            self.activation_layer = Square([N, d, d_out])
+        else:
+            raise CompilerError('activation not supported: %s', activation)
+
+        self.N = N
+        self.d_in = d_in
+        self.d_out = d_out
+        self.d = d
+        self.activation = activation
+
+        self.X = Tensor([N, d, d_in], sfix)
+        self.Y = Tensor([N, d, d_out], sfix)
+        self.W = Tensor([d_in, d_out], sfix)
+        self.b = sfix.Array(d_out)
+
+        back_N = min(N, self.back_batch_size)
+        self.nabla_Y = Tensor([back_N, d, d_out], sfix)
+        self.nabla_X = Tensor([back_N, d, d_in], sfix)
+        self.nabla_W = Tensor([d_in, d_out], sfix)
+        self.nabla_b = sfix.Array(d_out)
+
+        self.debug = debug
+
+        l = self.activation_layer
+        if l:
+            self.f_input = l.X
+            l.Y = self.Y
+            l.nabla_Y = self.nabla_Y
+        else:
+            self.f_input = self.Y
+
+    def __repr__(self):
+        return '%s(%s, %s, %s, activation=%s)' % \
+            (type(self).__name__, self.N, self.d_in,
+             self.d_out, repr(self.activation))
+
+    def reset(self):
+        d_in = self.d_in
+        d_out = self.d_out
+        r = math.sqrt(6.0 / (d_in + d_out))
+        print('Initializing dense weights in [%f,%f]' % (-r, r))
+        self.W.randomize(-r, r, n_threads=self.n_threads)
+        self.b.assign_all(0)
+
+    def input_from(self, player, **kwargs):
+        self.W.input_from(player, **kwargs)
+        if self.input_bias:
+            self.b.input_from(player, **kwargs)
+
+    def compute_f_input(self, batch):
+        N = len(batch)
+        prod = MultiArray([N, self.d, self.d_out], sfix)
+
+        # flattened_array version
+        result_matrix = sfix.Matrix(N * self.d, self.d_out, address=prod.address)
+        max_size = get_program().budget
+
+        # X is stored at full dataset indices, batch specifies which samples to use
+        X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address)
+
+        # Precompute batch_d_indices for all N*d elements
+        # For each sample in batch, expand to d consecutive indices
+        batch_d_indices = regint.Array(N * self.d)
+        @for_range(N)
+        def _(i):
+            actual_sample = batch[i]
+            batch_d_indices.assign(regint.inc(self.d, actual_sample * self.d), i * self.d)
+
+        @multithread(self.n_threads, N * self.d, max_size)
+        def _(base, size):
+            result_matrix.assign_part_vector(
+                X_sub.direct_mul(self.W, indices=(
+                    batch_d_indices.get_vector(base, size), regint.inc(self.d_in),
+                    regint.inc(self.d_in), regint.inc(self.d_out))), base)
+
+        if self.input_bias:
+            if self.d_out == 1:
+                @multithread(self.n_threads, N)
+                def _(base, size):
+                    v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size)
+                    self.f_input.assign_vector(v, base)
+            else:
+                @for_range_multithread(self.n_threads, 100, [N, self.d])
+                def _(i, j):
+                    v = prod[i][j].get_vector() + self.b.get_vector()
+                    self.f_input[i][j].assign_vector(v)
+
+        progress('f input')
+
+    def _forward(self, batch=None):
+        if batch is None:
+            batch = regint.Array(self.N)
+            batch.assign(regint.inc(self.N))
+        self.compute_f_input(batch=batch)
+        if self.activation_layer:
+            self.activation_layer.forward(batch, training=True)
+        if self.debug_output:
+            print_ln('dense X %s', self.X.reveal_nested())
+            print_ln('dense W %s', self.W.reveal_nested())
+            print_ln('dense b %s', self.b.reveal_nested())
+            print_ln('dense Y %s', self.Y.reveal_nested())
+        if self.debug:
+            limit = self.debug
+            @for_range_opt(len(batch))
+            def _(i):
+                @for_range_opt(self.d_out)
+                def _(j):
+                    to_check = self.Y[i][0][j].reveal()
+                    check = to_check > limit
+                    @if_(check)
+                    def _():
+                        print_ln('dense Y %s %s %s %s', i, j, self.W.sizes, to_check)
+                        print_ln('X %s', self.X[i].reveal_nested())
+                        print_ln('W %s',
+                                 [self.W[k][j].reveal() for k in range(self.d_in)])
+
+    def backward(self, compute_nabla_X=True, batch=None):
+        N = len(batch)
+        d = self.d
+        d_out = self.d_out
+        X = self.X
+        Y = self.Y
+        W = self.W
+        b = self.b
+        nabla_X = self.nabla_X
+        nabla_Y = self.nabla_Y
+        nabla_W = self.nabla_W
+        nabla_b = self.nabla_b
+
+        if self.activation_layer:
+            self.activation_layer.backward(batch)
+            f_schur_Y = self.activation_layer.nabla_X
+        else:
+            f_schur_Y = nabla_Y
+
+        if compute_nabla_X:
+            nabla_X.alloc()
+
+            # flattened matrix version
+            result_matrix = sfix.Matrix(N * self.d, self.d_in, address=nabla_X.address)
+            # Note: f_schur_Y is stored at indices [0, 1, ..., N-1] not at actual batch indices
+            @multithread(self.n_threads, N * self.d)
+            def _(base, size):
+                X_sub = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address)
+                offset = regint.inc(size, base=base)
+
+                result_matrix.assign_part_vector(
+                    X_sub.direct_mul_trans(self.W, indices=(
+                        offset, regint.inc(self.d_out),
+                        regint.inc(self.d_out), regint.inc(self.d_in))), base)
+
+            if self.print_random_update:
+                print_ln('backward %s', self)
+                index = regint.get_random(64) % self.nabla_X.total_size()
+                print_ln('%s nabla_X at %s: %s', str(self.nabla_X),
+                         index, self.nabla_X.to_array()[index].reveal())
+
+        progress('nabla X')
+
+        self.backward_params(f_schur_Y, batch=batch)


 class QuantizedDense(DenseBase):
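The merged Dense.compute_f_input added above performs one flattened (N * d, d_in) x (d_in, d_out) product over the selected batch rows and then adds the bias, which for d > 1 is the same as applying the linear map separately along the extra dimension, the nn.Linear behaviour its docstring refers to. An editor's NumPy sketch of that equivalence (not part of the diff; shapes and values are illustrative):

    import numpy as np

    N_total, N, d, d_in, d_out = 6, 2, 3, 4, 5
    rng = np.random.default_rng(0)
    X = rng.standard_normal((N_total, d, d_in))     # full dataset (self.X)
    W = rng.standard_normal((d_in, d_out))
    b = rng.standard_normal(d_out)
    batch = np.array([4, 1])

    # batch_d_indices: each sample expanded to d consecutive flattened rows
    batch_d_indices = (batch[:, None] * d + np.arange(d)).ravel()
    prod = (X.reshape(N_total * d, d_in)[batch_d_indices] @ W).reshape(N, d, d_out)
    f_input = prod + b

    # Same result as applying the linear map per extra dimension, nn.Linear-style.
    assert np.allclose(f_input, X[batch] @ W + b)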
@@ -3108,7 +2940,7 @@ class BertIntermediate(BertBase):
         input_shape = [n_examples, seq_len, hidden_size]
         output_shape = [n_examples, seq_len, intermediate_size]
         super(BertIntermediate, self).__init__(input_shape, output_shape)
-        self.dense = FlexDense(n_examples, hidden_size, intermediate_size, seq_len)
+        self.dense = Dense(n_examples, hidden_size, intermediate_size, seq_len)
         self.activation = Gelu([n_examples, seq_len, intermediate_size])


@@ -3150,7 +2982,7 @@ class BertOutput(BertBase):
         self.input_shape = input_shape
         print("INSTANTIATING BERTOUTPUT with ", input_shape, output_shape, intermediate_size, hidden_size, rsqrt_approx)
         super(BertOutput, self).__init__(input_shape, output_shape)
-        self.dense = FlexDense(n_examples, intermediate_size, hidden_size, seq_len)
+        self.dense = Dense(n_examples, intermediate_size, hidden_size, seq_len)
         self.layer_norm = LayerNorm(output_shape, layernorm_eps=layernorm_eps, approx=rsqrt_approx)
         self.dropout = FlexDropout([n_examples, seq_len, hidden_size], alpha=dropout)

@@ -3241,9 +3073,9 @@ class MultiHeadAttention(BertBase):
         self.seq_len = seq_len

         self.hidden_size = hidden_size
-        self.wq = FlexDense(n_examples, hidden_size, self.all_head_size, self.seq_len)
-        self.wk = FlexDense(n_examples, hidden_size, self.all_head_size, self.seq_len)
-        self.wv = FlexDense(n_examples, hidden_size, self.all_head_size, self.seq_len)
+        self.wq = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len)
+        self.wk = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len)
+        self.wv = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len)
         self.dropout = FlexDropout([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], alpha=dropout) # I think? # TODO: DROPOUT?

         self.output = BertOutput(internal_shape, hidden_size, hidden_size, seq_len, dropout, layernorm_eps, rsqrt_approx)
@@ -4621,7 +4453,7 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None,
         if name == 'Linear':
             assert mul(input_shape[1:]) == item.in_features
             assert item.bias is not None
-            layers.append(FlexDense(input_shape[0], item.in_features,
+            layers.append(Dense(input_shape[0], item.in_features,
                                 item.out_features))
             if input_via is not None:
                 shapes = [x.shape for x in (layers[-1].W, layers[-1].b)]
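layers_from_torch now maps torch's nn.Linear onto the merged Dense. For reference, nn.Linear acts on the last axis of its input and preserves any leading dimensions, which is the shape contract that Dense with d > 1 mirrors. An editor's sketch (assumes PyTorch is installed; not part of the diff):

    import torch
    import torch.nn as nn

    N, d, d_in, d_out = 8, 16, 64, 32
    linear = nn.Linear(d_in, d_out)
    y = linear(torch.randn(N, d, d_in))   # applied to the last axis only

    assert y.shape == (N, d, d_out)       # what Dense(N, d_in, d_out, d) produces as Y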