Mirror of https://github.com/data61/MP-SPDZ.git (synced 2026-01-09 13:37:58 -05:00)
fix flexdense
@@ -978,31 +978,28 @@ class FlexDense(Dense):
         # flattened_array version
         result_matrix = sfix.Matrix(N * self.d, self.d_out, address=prod.address)
-        max_size = get_program().budget // self.d_out
+        max_size = get_program().budget

         # for now we assume that batch size is total size
         # assert N == self.N
-        # batch contains the indices of the batches in self.N, we want to expand to have self.d too
-        # batch_with_d =
-        # X is stored at full dataset indices, batch specifies which samples to use
-        X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address)
-
-        # Precompute batch_d_indices for all N*d elements
-        # For each sample in batch, expand to d consecutive indices
-        batch_d_indices = regint.Array(N * self.d)
-        @for_range(N)
-        def _(i):
-            actual_sample = batch[i]
-            batch_d_indices.assign(regint.inc(self.d, actual_sample * self.d), i * self.d)
-            # @for_range(self.d)
-            # def _(d_idx):
-            #     batch_d_indices[i * self.d + d_idx] = actual_sample * self.d + d_idx

+        # we are going to assume the batch is continuous
+        batch_0 = MemValue(batch[0])
         @multithread(self.n_threads, N * self.d, max_size)
         def _(base, size):
+            batch_offset = batch_0 * self.d
+            X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address)
+            offset = regint.inc(size, base=base + batch_offset)
+            # array_offset = regint.Array(size)
+            # array_offset.assign_all(batch_offset)
+            # print_ln("array offset %s", array_offset.reveal())
+            # offset[:] += array_offset[:]
+            # print_ln("total offset %s %s", batch_offset, offset)
+
             result_matrix.assign_part_vector(
                 X_sub.direct_mul(self.W, indices=(
-                    batch_d_indices.get_vector(base, size), regint.inc(self.d_in),
+                    offset, regint.inc(self.d_in),
                     regint.inc(self.d_in), regint.inc(self.d_out))), base)
             # print_ln("result matrix done")

         if self.input_bias:
             if self.d_out == 1:
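The forward pass hinges on direct_mul's indices argument: a 4-tuple of regint vectors giving row and column indices for each operand, so a contiguous batch can be multiplied straight out of the flattened X without copying it first. Below is a self-contained sketch of the same pattern; the dimensions, player numbers and names such as X_flat and batch_start are illustrative and not part of the commit.

    # Sketch: multiply a contiguous slice of a flattened matrix by W via index
    # vectors, mirroring the forward pass above (illustrative values only).
    X_flat = sfix.Matrix(100 * 4, 8)   # whole dataset, flattened to (N_total*d) x d_in
    W = sfix.Matrix(8, 16)             # d_in x d_out
    out = sfix.Matrix(32 * 4, 16)      # one batch of 32 samples, flattened
    X_flat.input_from(0)
    W.input_from(1)

    batch_start = regint(12)           # first sample of the (contiguous) batch
    rows = regint.inc(32 * 4, base=batch_start * 4)   # rows of X_flat to read
    out.assign_vector(
        X_flat.direct_mul(W, indices=(rows, regint.inc(8),
                                      regint.inc(8), regint.inc(16))))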
@@ -1044,14 +1041,12 @@ class FlexDense(Dense):
         nabla_X.alloc()

         # flattened matrix version
-        max_size = get_program().budget // self.d_in
         result_matrix = sfix.Matrix(N * self.d, self.d_in, address=nabla_X.address)
-        batch_0 = MemValue(batch[0])
-        @multithread(self.n_threads, N * self.d, max_size)
+        # Note: f_schur_Y is stored at indices [0, 1, ..., N-1] not at actual batch indices
+        @multithread(self.n_threads, N * self.d)
         def _(base, size):
-            batch_offset = batch_0 * self.d
-            X_sub = sfix.Matrix(self.N * self.d, self.d_out, address=f_schur_Y.address)
-            offset = regint.inc(size, base=base + batch_offset)
+            X_sub = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address)
+            offset = regint.inc(size, base=base)

             result_matrix.assign_part_vector(
                 X_sub.direct_mul_trans(self.W, indices=(
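The fix in this hunk is purely about addressing: X is laid out for the whole dataset, so its row indices need the batch offset, whereas f_schur_Y only holds the current batch and starts at row 0. A minimal sketch of the two index vectors, with made-up sizes and the same placeholder names (batch_0, base, size) as in the code above:

    # Illustrative only: row indices for a full-dataset tensor vs. a batch-sized one.
    d, size = 4, 8
    batch_0 = MemValue(regint(5))   # first sample of a contiguous batch
    base = regint(0)                # chunk start as handed in by @multithread

    # X is stored for the whole dataset, so shift by the batch start:
    x_rows = regint.inc(size, base=base + batch_0 * d)
    # f_schur_Y only holds the current batch, so sequential indices suffice:
    y_rows = regint.inc(size, base=base)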
@@ -1074,14 +1069,28 @@ class FlexDense(Dense):
         tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
         # tmp.assign_all(0)

+        # A (f_schur_Y/nabla_Y) is stored at sequential batch indices [0, 1, ..., N-1]
         A = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address)
+        # B (X) is stored at the full dataset size, not just the batch
         B = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address)

         @multithread(self.n_threads, self.d_in)
         def _(base, size):
+            # For A: use sequential indices [0, 1, ..., N*d-1]
+            # For B: use actual batch indices expanded for d dimension
+            batch_d_indices = regint.Array(N * self.d)
+            @for_range(N)
+            def _(i):
+                # batch[i] gives the actual sample index in the full dataset
+                # For FlexDense with d>1, we need to map to flattened indices
+                actual_sample_idx = batch[i]
+                @for_range(self.d)
+                def _(d_idx):
+                    batch_d_indices[i * self.d + d_idx] = actual_sample_idx * self.d + d_idx
+
             mp = B.direct_trans_mul(A, reduce=False,
                                     indices=(regint.inc(size, base),
-                                             regint.inc(N * self.d), # Not sure
+                                             batch_d_indices.get_vector(),
                                              regint.inc(N * self.d),
                                              regint.inc(self.d_out)))

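For the weight gradient the batch is not assumed contiguous, so B (the full, flattened X) has to be indexed through batch_d_indices: sample i of the batch owns the d consecutive flattened rows starting at batch[i] * d. A standalone sketch of that expansion with made-up values, using the vectorised per-sample assign that the forward pass used before this commit:

    # Illustrative only: expand per-sample batch indices to flattened row indices.
    N, d = 3, 4
    batch = regint.Array(N)
    batch[0], batch[1], batch[2] = 7, 2, 5   # sample indices into the full dataset

    batch_d_indices = regint.Array(N * d)
    @for_range(N)
    def _(i):
        # rows [batch[i]*d, batch[i]*d + d) of the flattened X belong to sample batch[i]
        batch_d_indices.assign(regint.inc(d, batch[i] * d), i * d)
    print_ln('flattened rows %s', batch_d_indices.reveal())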
@@ -4612,7 +4621,7 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None,
         if name == 'Linear':
             assert mul(input_shape[1:]) == item.in_features
             assert item.bias is not None
-            layers.append(Dense(input_shape[0], item.in_features,
+            layers.append(FlexDense(input_shape[0], item.in_features,
                                     item.out_features))
             if input_via is not None:
                 shapes = [x.shape for x in (layers[-1].W, layers[-1].b)]
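With this change an imported torch.nn.Linear becomes a FlexDense layer. A minimal usage sketch, assuming the documented Compiler.ml interface (layers_from_torch plus the SGD optimizer and its fit method); the model, shapes and hyperparameters below are made up, and PyTorch must be installed where the program is compiled.

    # Illustrative only: import a small PyTorch model and train it in MP-SPDZ.
    import torch.nn as nn
    from Compiler import ml

    net = nn.Sequential(
        nn.Linear(64, 32),
        nn.ReLU(),
        nn.Linear(32, 10),
    )

    X = sfix.Matrix(1000, 64)   # training samples, secret-shared by player 0
    y = sint.Matrix(1000, 10)   # labels, assumed one-hot here
    X.input_from(0)
    y.input_from(0)

    layers = ml.layers_from_torch(net, X.sizes, 128)   # Linear layers now map to FlexDense
    optimizer = ml.SGD(layers)
    optimizer.fit(X, y, epochs=1, batch_size=128)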