Merge FlexDense with Dense

Author: Hidde L
Date:   2025-10-21 15:55:20 +02:00
Parent: 9ac22e3650
Commit: ba2b440712

@@ -725,350 +725,6 @@ class DenseBase(Layer):
N = len(batch)
tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
A = sfix.Matrix(N, self.d_out, address=f_schur_Y.address)
B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
@multithread(self.n_threads, self.d_in)
def _(base, size):
mp = B.direct_trans_mul(A, reduce=False,
indices=(regint.inc(size, base),
batch.get_vector(),
regint.inc(N),
regint.inc(self.d_out)))
tmp.assign_part_vector(mp, base)
progress('nabla W (matmul)')
@multithread(self.n_threads, self.d_in * self.d_out,
max_size=get_program().budget)
def _(base, size):
self.nabla_W.assign_vector(
tmp.get_vector(base, size).reduce_after_mul(), base=base)
if self.print_random_update:
print_ln('backward %s', self)
i = regint.get_random(64) % self.d_in
j = regint.get_random(64) % self.d_out
print_ln('%s at (%s, %s): before=%s after=%s A=%s B=%s',
str(self.nabla_W), i, j, tmp[i][j].v.reveal(),
self.nabla_W[i][j].reveal(),
A.get_column(j).reveal(),
B.get_column_by_row_indices(
batch.get_vector(), i).reveal())
print_ln('batch=%s B=%s', batch,
[self.X[bi][0][i].reveal() for bi in batch])
progress('nabla W')
self.nabla_b.assign_vector(sum(sum(f_schur_Y[k][j].get_vector()
for k in range(N))
for j in range(self.d)))
progress('nabla b')
if self.debug_output:
print_ln('dense nabla Y %s', self.nabla_Y.reveal_nested())
print_ln('dense W %s', self.W.reveal_nested())
print_ln('dense nabla X %s', self.nabla_X.reveal_nested())
if self.debug:
limit = N * self.debug
@for_range_opt(self.d_in)
def _(i):
@for_range_opt(self.d_out)
def _(j):
to_check = self.nabla_W[i][j].reveal()
check = sum(to_check > limit) + sum(to_check < -limit)
@if_(check)
def _():
print_ln('nabla W %s %s %s: %s', i, j, self.W.sizes, to_check)
print_ln('Y %s', [f_schur_Y[k][0][j].reveal()
for k in range(N)])
print_ln('X %s', [self.X[k][0][i].reveal()
for k in range(N)])
@for_range_opt(self.d_out)
def _(j):
to_check = self.nabla_b[j].reveal()
check = sum(to_check > limit) + sum(to_check < -limit)
@if_(check)
def _():
print_ln('nabla b %s %s: %s', j, len(self.b), to_check)
print_ln('Y %s', [f_schur_Y[k][0][j].reveal()
for k in range(N)])
@for_range_opt(len(batch))
def _(i):
to_check = self.nabla_X[i].get_vector().reveal()
check = sum(to_check > limit) + sum(to_check < -limit)
@if_(check)
def _():
print_ln('X %s %s', i, self.X[i].reveal_nested())
print_ln('Y %s %s', i, f_schur_Y[i].reveal_nested())
class Dense(DenseBase):
""" Fixed-point dense (matrix multiplication) layer.
:param N: number of examples
:param d_in: input dimension
:param d_out: output dimension
"""
def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):
if activation == 'id':
self.activation_layer = None
elif activation == 'relu':
self.activation_layer = Relu([N, d, d_out])
elif activation == 'square':
self.activation_layer = Square([N, d, d_out])
else:
raise CompilerError('activation not supported: %s', activation)
self.N = N
self.d_in = d_in
self.d_out = d_out
self.d = d
self.activation = activation
self.X = Tensor([N, d, d_in], sfix)
self.Y = Tensor([N, d, d_out], sfix)
self.W = Tensor([d_in, d_out], sfix)
self.b = sfix.Array(d_out)
back_N = min(N, self.back_batch_size)
self.nabla_Y = Tensor([back_N, d, d_out], sfix)
self.nabla_X = Tensor([back_N, d, d_in], sfix)
self.nabla_W = Tensor([d_in, d_out], sfix)
self.nabla_b = sfix.Array(d_out)
self.debug = debug
l = self.activation_layer
if l:
self.f_input = l.X
l.Y = self.Y
l.nabla_Y = self.nabla_Y
else:
self.f_input = self.Y
def __repr__(self):
return '%s(%s, %s, %s, activation=%s)' % \
(type(self).__name__, self.N, self.d_in,
self.d_out, repr(self.activation))
def reset(self):
d_in = self.d_in
d_out = self.d_out
r = math.sqrt(6.0 / (d_in + d_out))
print('Initializing dense weights in [%f,%f]' % (-r, r))
self.W.randomize(-r, r, n_threads=self.n_threads)
self.b.assign_all(0)
def input_from(self, player, **kwargs):
self.W.input_from(player, **kwargs)
if self.input_bias:
self.b.input_from(player, **kwargs)
def compute_f_input(self, batch):
N = len(batch)
assert self.d == 1
if self.input_bias:
prod = MultiArray([N, self.d, self.d_out], sfix)
else:
prod = self.f_input
max_size = get_program().budget
@multithread(self.n_threads, N, max_size)
def _(base, size):
X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address)
prod.assign_part_vector(
X_sub.direct_mul(self.W, indices=(
batch.get_vector(base, size), regint.inc(self.d_in),
regint.inc(self.d_in), regint.inc(self.d_out))), base)
if self.input_bias:
if self.d_out == 1:
@multithread(self.n_threads, N)
def _(base, size):
v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size)
self.f_input.assign_vector(v, base)
else:
@for_range_multithread(self.n_threads, 100, N)
def _(i):
v = prod[i].get_vector() + self.b.get_vector()
self.f_input[i].assign_vector(v)
progress('f input')
def _forward(self, batch=None):
if batch is None:
batch = regint.Array(self.N)
batch.assign(regint.inc(self.N))
self.compute_f_input(batch=batch)
if self.activation_layer:
self.activation_layer.forward(batch, training=True)
if self.debug_output:
print_ln('dense X %s', self.X.reveal_nested())
print_ln('dense W %s', self.W.reveal_nested())
print_ln('dense b %s', self.b.reveal_nested())
print_ln('dense Y %s', self.Y.reveal_nested())
if self.debug:
limit = self.debug
@for_range_opt(len(batch))
def _(i):
@for_range_opt(self.d_out)
def _(j):
to_check = self.Y[i][0][j].reveal()
check = to_check > limit
@if_(check)
def _():
print_ln('dense Y %s %s %s %s', i, j, self.W.sizes, to_check)
print_ln('X %s', self.X[i].reveal_nested())
print_ln('W %s',
[self.W[k][j].reveal() for k in range(self.d_in)])
def backward(self, compute_nabla_X=True, batch=None):
N = len(batch)
d = self.d
d_out = self.d_out
X = self.X
Y = self.Y
W = self.W
b = self.b
nabla_X = self.nabla_X
nabla_Y = self.nabla_Y
nabla_W = self.nabla_W
nabla_b = self.nabla_b
if self.activation_layer:
self.activation_layer.backward(batch)
f_schur_Y = self.activation_layer.nabla_X
else:
f_schur_Y = nabla_Y
if compute_nabla_X:
nabla_X.alloc()
@multithread(self.n_threads, N)
def _(base, size):
B = sfix.Matrix(N, d_out, address=f_schur_Y.address)
nabla_X.assign_part_vector(
B.direct_mul_trans(W, indices=(regint.inc(size, base),
regint.inc(self.d_out),
regint.inc(self.d_out),
regint.inc(self.d_in))),
base)
if self.print_random_update:
print_ln('backward %s', self)
index = regint.get_random(64) % self.nabla_X.total_size()
print_ln('%s nabla_X at %s: %s', str(self.nabla_X),
index, self.nabla_X.to_array()[index].reveal())
progress('nabla X')
self.backward_params(f_schur_Y, batch=batch)
class FlexDense(Dense):
""" Fixed-point dense (matrix multiplication) layer with flexible number of dimensions.
Behaves like torch's nn.Linear which loops over the additional dimensions.
:param N: number of examples
:param d_in: input dimension
:param d_out: output dimension
"""
def compute_f_input(self, batch):
N = len(batch)
prod = MultiArray([N, self.d, self.d_out], sfix)
# flattened_array version
result_matrix = sfix.Matrix(N * self.d, self.d_out, address=prod.address)
max_size = get_program().budget
# X is stored at full dataset indices, batch specifies which samples to use
X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address)
# Precompute batch_d_indices for all N*d elements
# For each sample in batch, expand to d consecutive indices
batch_d_indices = regint.Array(N * self.d)
@for_range(N)
def _(i):
actual_sample = batch[i]
batch_d_indices.assign(regint.inc(self.d, actual_sample * self.d), i * self.d)
# @for_range(self.d)
# def _(d_idx):
# batch_d_indices[i * self.d + d_idx] = actual_sample * self.d + d_idx
@multithread(self.n_threads, N * self.d, max_size)
def _(base, size):
result_matrix.assign_part_vector(
X_sub.direct_mul(self.W, indices=(
batch_d_indices.get_vector(base, size), regint.inc(self.d_in),
regint.inc(self.d_in), regint.inc(self.d_out))), base)
if self.input_bias:
if self.d_out == 1:
@multithread(self.n_threads, N)
def _(base, size):
v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size)
self.f_input.assign_vector(v, base)
else:
@for_range_multithread(self.n_threads, 100, [N, self.d])
def _(i, j):
v = prod[i][j].get_vector() + self.b.get_vector()
# print_ln("running bias %s %s %s", i, j, v.reveal())
self.f_input[i][j].assign_vector(v)
# print_ln("FlexDense f_inpu full %s", self.f_input.reveal())
progress('f input')
def backward(self, compute_nabla_X=True, batch=None):
N = len(batch)
d = self.d
d_out = self.d_out
X = self.X
Y = self.Y
W = self.W
b = self.b
nabla_X = self.nabla_X
nabla_Y = self.nabla_Y
nabla_W = self.nabla_W
nabla_b = self.nabla_b
if self.activation_layer:
self.activation_layer.backward(batch)
f_schur_Y = self.activation_layer.nabla_X
else:
f_schur_Y = nabla_Y
if compute_nabla_X:
nabla_X.alloc()
# flattened matrix version
result_matrix = sfix.Matrix(N * self.d, self.d_in, address=nabla_X.address)
# Note: f_schur_Y is stored at indices [0, 1, ..., N-1] not at actual batch indices
@multithread(self.n_threads, N * self.d)
def _(base, size):
X_sub = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address)
offset = regint.inc(size, base=base)
result_matrix.assign_part_vector(
X_sub.direct_mul_trans(self.W, indices=(
offset, regint.inc(self.d_out),
regint.inc(self.d_out), regint.inc(self.d_in))), base)
if self.print_random_update:
print_ln('backward %s', self)
index = regint.get_random(64) % self.nabla_X.total_size()
print_ln('%s nabla_X at %s: %s', str(self.nabla_X),
index, self.nabla_X.to_array()[index].reveal())
progress('nabla X')
self.backward_params(f_schur_Y, batch=batch)
def backward_params(self, f_schur_Y, batch):
print("backward params flexdense")
N = len(batch)
tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
# tmp.assign_all(0)
# A (f_schur_Y/nabla_Y) is stored at sequential batch indices [0, 1, ..., N-1]
A = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address)
# B (X) is stored at the full dataset size, not just the batch
@@ -1082,7 +738,7 @@ class FlexDense(Dense):
@for_range(N)
def _(i):
# batch[i] gives the actual sample index in the full dataset
# For FlexDense with d>1, we need to map to flattened indices
# For Dense with d>1, we need to map to flattened indices
actual_sample_idx = batch[i]
@for_range(self.d)
def _(d_idx):
@@ -1162,6 +818,182 @@ class FlexDense(Dense):
print_ln('X %s %s', i, self.X[i].reveal_nested())
print_ln('Y %s %s', i, f_schur_Y[i].reveal_nested())
class Dense(DenseBase):
""" Fixed-point dense (matrix multiplication) layer.
Supports inputs of shape [N, d, d_in], mapping them to [N, d, d_out]. If d > 1, the layer
behaves like torch's nn.Linear and loops over the additional dimension d.
:param N: number of examples
:param d_in: input dimension
:param d_out: output dimension
:param d: (optional) extra dimension
"""
def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False):
if activation == 'id':
self.activation_layer = None
elif activation == 'relu':
self.activation_layer = Relu([N, d, d_out])
elif activation == 'square':
self.activation_layer = Square([N, d, d_out])
else:
raise CompilerError('activation not supported: %s', activation)
self.N = N
self.d_in = d_in
self.d_out = d_out
self.d = d
self.activation = activation
self.X = Tensor([N, d, d_in], sfix)
self.Y = Tensor([N, d, d_out], sfix)
self.W = Tensor([d_in, d_out], sfix)
self.b = sfix.Array(d_out)
back_N = min(N, self.back_batch_size)
self.nabla_Y = Tensor([back_N, d, d_out], sfix)
self.nabla_X = Tensor([back_N, d, d_in], sfix)
self.nabla_W = Tensor([d_in, d_out], sfix)
self.nabla_b = sfix.Array(d_out)
self.debug = debug
l = self.activation_layer
if l:
self.f_input = l.X
l.Y = self.Y
l.nabla_Y = self.nabla_Y
else:
self.f_input = self.Y
def __repr__(self):
return '%s(%s, %s, %s, activation=%s)' % \
(type(self).__name__, self.N, self.d_in,
self.d_out, repr(self.activation))
def reset(self):
d_in = self.d_in
d_out = self.d_out
r = math.sqrt(6.0 / (d_in + d_out))
print('Initializing dense weights in [%f,%f]' % (-r, r))
self.W.randomize(-r, r, n_threads=self.n_threads)
self.b.assign_all(0)
def input_from(self, player, **kwargs):
self.W.input_from(player, **kwargs)
if self.input_bias:
self.b.input_from(player, **kwargs)
def compute_f_input(self, batch):
N = len(batch)
prod = MultiArray([N, self.d, self.d_out], sfix)
# flattened_array version
result_matrix = sfix.Matrix(N * self.d, self.d_out, address=prod.address)
max_size = get_program().budget
# X is stored at full dataset indices, batch specifies which samples to use
X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address)
# Precompute batch_d_indices for all N*d elements
# For each sample in batch, expand to d consecutive indices
batch_d_indices = regint.Array(N * self.d)
@for_range(N)
def _(i):
actual_sample = batch[i]
batch_d_indices.assign(regint.inc(self.d, actual_sample * self.d), i * self.d)
@multithread(self.n_threads, N * self.d, max_size)
def _(base, size):
result_matrix.assign_part_vector(
X_sub.direct_mul(self.W, indices=(
batch_d_indices.get_vector(base, size), regint.inc(self.d_in),
regint.inc(self.d_in), regint.inc(self.d_out))), base)
if self.input_bias:
if self.d_out == 1:
@multithread(self.n_threads, N)
def _(base, size):
v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size)
self.f_input.assign_vector(v, base)
else:
@for_range_multithread(self.n_threads, 100, [N, self.d])
def _(i, j):
v = prod[i][j].get_vector() + self.b.get_vector()
self.f_input[i][j].assign_vector(v)
progress('f input')
def _forward(self, batch=None):
if batch is None:
batch = regint.Array(self.N)
batch.assign(regint.inc(self.N))
self.compute_f_input(batch=batch)
if self.activation_layer:
self.activation_layer.forward(batch, training=True)
if self.debug_output:
print_ln('dense X %s', self.X.reveal_nested())
print_ln('dense W %s', self.W.reveal_nested())
print_ln('dense b %s', self.b.reveal_nested())
print_ln('dense Y %s', self.Y.reveal_nested())
if self.debug:
limit = self.debug
@for_range_opt(len(batch))
def _(i):
@for_range_opt(self.d_out)
def _(j):
to_check = self.Y[i][0][j].reveal()
check = to_check > limit
@if_(check)
def _():
print_ln('dense Y %s %s %s %s', i, j, self.W.sizes, to_check)
print_ln('X %s', self.X[i].reveal_nested())
print_ln('W %s',
[self.W[k][j].reveal() for k in range(self.d_in)])
def backward(self, compute_nabla_X=True, batch=None):
N = len(batch)
d = self.d
d_out = self.d_out
X = self.X
Y = self.Y
W = self.W
b = self.b
nabla_X = self.nabla_X
nabla_Y = self.nabla_Y
nabla_W = self.nabla_W
nabla_b = self.nabla_b
if self.activation_layer:
self.activation_layer.backward(batch)
f_schur_Y = self.activation_layer.nabla_X
else:
f_schur_Y = nabla_Y
if compute_nabla_X:
nabla_X.alloc()
# flattened matrix version
result_matrix = sfix.Matrix(N * self.d, self.d_in, address=nabla_X.address)
# Note: f_schur_Y is stored at indices [0, 1, ..., N-1] not at actual batch indices
@multithread(self.n_threads, N * self.d)
def _(base, size):
X_sub = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address)
offset = regint.inc(size, base=base)
result_matrix.assign_part_vector(
X_sub.direct_mul_trans(self.W, indices=(
offset, regint.inc(self.d_out),
regint.inc(self.d_out), regint.inc(self.d_in))), base)
if self.print_random_update:
print_ln('backward %s', self)
index = regint.get_random(64) % self.nabla_X.total_size()
print_ln('%s nabla_X at %s: %s', str(self.nabla_X),
index, self.nabla_X.to_array()[index].reveal())
progress('nabla X')
self.backward_params(f_schur_Y, batch=batch)
class QuantizedDense(DenseBase):
@@ -3108,7 +2940,7 @@ class BertIntermediate(BertBase):
input_shape = [n_examples, seq_len, hidden_size]
output_shape = [n_examples, seq_len, intermediate_size]
super(BertIntermediate, self).__init__(input_shape, output_shape)
self.dense = FlexDense(n_examples, hidden_size, intermediate_size, seq_len)
self.dense = Dense(n_examples, hidden_size, intermediate_size, seq_len)
self.activation = Gelu([n_examples, seq_len, intermediate_size])
@@ -3150,7 +2982,7 @@ class BertOutput(BertBase):
self.input_shape = input_shape
print("INSTANTIATING BERTOUTPUT with ", input_shape, output_shape, intermediate_size, hidden_size, rsqrt_approx)
super(BertOutput, self).__init__(input_shape, output_shape)
self.dense = FlexDense(n_examples, intermediate_size, hidden_size, seq_len)
self.dense = Dense(n_examples, intermediate_size, hidden_size, seq_len)
self.layer_norm = LayerNorm(output_shape, layernorm_eps=layernorm_eps, approx=rsqrt_approx)
self.dropout = FlexDropout([n_examples, seq_len, hidden_size], alpha=dropout)
@@ -3241,9 +3073,9 @@ class MultiHeadAttention(BertBase):
self.seq_len = seq_len
self.hidden_size = hidden_size
self.wq = FlexDense(n_examples, hidden_size, self.all_head_size, self.seq_len)
self.wk = FlexDense(n_examples, hidden_size, self.all_head_size, self.seq_len)
self.wv = FlexDense(n_examples, hidden_size, self.all_head_size, self.seq_len)
self.wq = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len)
self.wk = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len)
self.wv = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len)
self.dropout = FlexDropout([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], alpha=dropout) # I think? # TODO: DROPOUT?
self.output = BertOutput(internal_shape, hidden_size, hidden_size, seq_len, dropout, layernorm_eps, rsqrt_approx)
@@ -4621,7 +4453,7 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None,
if name == 'Linear':
assert mul(input_shape[1:]) == item.in_features
assert item.bias is not None
layers.append(FlexDense(input_shape[0], item.in_features,
layers.append(Dense(input_shape[0], item.in_features,
item.out_features))
if input_via is not None:
shapes = [x.shape for x in (layers[-1].W, layers[-1].b)]
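
For reference, a minimal plain-NumPy sketch of the semantics the merged Dense layer documents: an input of shape [N, d, d_in] is mapped to [N, d, d_out] with a single weight matrix W [d_in, d_out] and bias b [d_out], applied to every (n, k) row, which is the nn.Linear-style looping the docstring describes and matches call sites such as Dense(n_examples, hidden_size, intermediate_size, seq_len). This is not MP-SPDZ code; the concrete sizes below are made up for illustration, only the shape/parameter names (N, d, d_in, d_out, W, b) come from the layer.

    import numpy as np

    # Illustrative sizes, not taken from the commit.
    N, d, d_in, d_out = 4, 3, 8, 5
    X = np.random.rand(N, d, d_in)
    W = np.random.rand(d_in, d_out)
    b = np.random.rand(d_out)

    # Flattened form, mirroring the N * d row indexing used in compute_f_input.
    Y = (X.reshape(N * d, d_in) @ W + b).reshape(N, d, d_out)

    # Equivalent explicit loop over the extra dimension d (what "loops over the
    # additional dimension" means in the docstring).
    Y_loop = np.stack([[X[n, k] @ W + b for k in range(d)] for n in range(N)])
    assert np.allclose(Y, Y_loop)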