fix flexdense

Author: Hidde L
Date:   2025-10-21 15:45:07 +02:00
Parent: 2162b79b73
Commit: 9ac22e3650


@@ -978,31 +978,28 @@ class FlexDense(Dense):
# flattened_array version
result_matrix = sfix.Matrix(N * self.d, self.d_out, address=prod.address)
max_size = get_program().budget // self.d_out
max_size = get_program().budget
# for now we assume that the batch size equals the total size
# assert N == self.N
# batch contains the indices of the batch samples within self.N; we want to expand each to cover self.d rows too
# batch_with_d =
# X is stored at full dataset indices, batch specifies which samples to use
X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address)
# Precompute batch_d_indices for all N*d elements
# For each sample in batch, expand to d consecutive indices
batch_d_indices = regint.Array(N * self.d)
@for_range(N)
def _(i):
actual_sample = batch[i]
batch_d_indices.assign(regint.inc(self.d, actual_sample * self.d), i * self.d)
# @for_range(self.d)
# def _(d_idx):
# batch_d_indices[i * self.d + d_idx] = actual_sample * self.d + d_idx
# we are going to assume the batch is contiguous
batch_0 = MemValue(batch[0])
@multithread(self.n_threads, N * self.d, max_size)
def _(base, size):
batch_offset = batch_0 * self.d
X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address)
offset = regint.inc(size, base=base + batch_offset)
# array_offset = regint.Array(size)
# array_offset.assign_all(batch_offset)
# print_ln("array offset %s", array_offset.reveal())
# offset[:] += array_offset[:]
# print_ln("total offset %s %s", batch_offset, offset)
result_matrix.assign_part_vector(
X_sub.direct_mul(self.W, indices=(
offset, regint.inc(self.d_in),
batch_d_indices.get_vector(base, size), regint.inc(self.d_in),
regint.inc(self.d_in), regint.inc(self.d_out))), base)
# print_ln("result matrix done")
if self.input_bias:
if self.d_out == 1:
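
For reference, a minimal NumPy sketch of the index expansion the new forward path performs (illustration only, not MP-SPDZ code; all shapes are made up): each sample index in `batch` is widened to its `d` consecutive rows of the flattened activation matrix before the masked matrix product.

```python
import numpy as np

# Sketch of the batch_d_indices expansion above: sample i of the batch
# covers rows [batch[i]*d, batch[i]*d + d) of the flattened
# (N_total*d, d_in) activation matrix.
N_total, d, d_in, d_out = 6, 3, 4, 2
X = np.arange(N_total * d * d_in, dtype=float).reshape(N_total * d, d_in)
W = np.ones((d_in, d_out))

batch = np.array([4, 1])  # sample indices into the full dataset
N = len(batch)
batch_d_indices = (batch[:, None] * d + np.arange(d)).ravel()

# Rough equivalent of X_sub.direct_mul(self.W, indices=...): multiply only
# the selected rows and write them to the start of the result buffer.
result = X[batch_d_indices] @ W
assert result.shape == (N * d, d_out)
```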
@@ -1044,14 +1041,12 @@ class FlexDense(Dense):
nabla_X.alloc()
# flattened matrix version
max_size = get_program().budget // self.d_in
result_matrix = sfix.Matrix(N * self.d, self.d_in, address=nabla_X.address)
batch_0 = MemValue(batch[0])
@multithread(self.n_threads, N * self.d, max_size)
# Note: f_schur_Y is stored at indices [0, 1, ..., N-1], not at the actual batch indices
@multithread(self.n_threads, N * self.d)
def _(base, size):
batch_offset = batch_0 * self.d
X_sub = sfix.Matrix(self.N * self.d, self.d_out, address=f_schur_Y.address)
offset = regint.inc(size, base=base + batch_offset)
X_sub = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address)
offset = regint.inc(size, base=base)
result_matrix.assign_part_vector(
X_sub.direct_mul_trans(self.W, indices=(
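
The backward-input hunk relies on `f_schur_Y` living at sequential rows, so no index expansion is needed there. A NumPy sketch of that shape bookkeeping (illustration only, shapes made up):

```python
import numpy as np

# f_schur_Y (the incoming gradient) is stored at sequential rows [0, N*d),
# so plain sequential offsets suffice; only the transpose of W is applied.
N, d, d_in, d_out = 2, 3, 4, 2
f_schur_Y = np.random.rand(N * d, d_out)
W = np.random.rand(d_in, d_out)

# Rough equivalent of X_sub.direct_mul_trans(self.W, indices=...).
nabla_X = f_schur_Y @ W.T
assert nabla_X.shape == (N * d, d_in)
```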
@@ -1074,14 +1069,28 @@ class FlexDense(Dense):
tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
# tmp.assign_all(0)
# A (f_schur_Y/nabla_Y) is stored at sequential batch indices [0, 1, ..., N-1]
A = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address)
# B (X) is stored at the full dataset size, not just the batch
B = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address)
@multithread(self.n_threads, self.d_in)
def _(base, size):
# For A: use sequential indices [0, 1, ..., N*d-1]
# For B: use actual batch indices expanded for d dimension
batch_d_indices = regint.Array(N * self.d)
@for_range(N)
def _(i):
# batch[i] gives the actual sample index in the full dataset
# For FlexDense with d>1, we need to map to flattened indices
actual_sample_idx = batch[i]
@for_range(self.d)
def _(d_idx):
batch_d_indices[i * self.d + d_idx] = actual_sample_idx * self.d + d_idx
mp = B.direct_trans_mul(A, reduce=False,
indices=(regint.inc(size, base),
regint.inc(N * self.d), # Not sure
batch_d_indices.get_vector(),
regint.inc(N * self.d),
regint.inc(self.d_out)))
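
The weight-gradient hunk mixes the two addressing schemes: A (`f_schur_Y`) is read sequentially while B (`self.X`) is read at the expanded batch indices. A NumPy sketch (illustration only, shapes made up):

```python
import numpy as np

# A (f_schur_Y) lives at sequential rows [0, N*d); B (X) lives at
# full-dataset rows, so its rows are selected via batch_d_indices.
N_total, d, d_in, d_out = 6, 3, 4, 2
X = np.random.rand(N_total * d, d_in)
batch = np.array([4, 1])
N = len(batch)
f_schur_Y = np.random.rand(N * d, d_out)

batch_d_indices = (batch[:, None] * d + np.arange(d)).ravel()

# Rough equivalent of B.direct_trans_mul(A, indices=...): B^T @ A over the
# selected rows of B and all rows of A, giving the weight gradient.
nabla_W = X[batch_d_indices].T @ f_schur_Y
assert nabla_W.shape == (d_in, d_out)
```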
@@ -4612,7 +4621,7 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None,
if name == 'Linear':
assert mul(input_shape[1:]) == item.in_features
assert item.bias is not None
layers.append(Dense(input_shape[0], item.in_features,
layers.append(FlexDense(input_shape[0], item.in_features,
item.out_features))
if input_via is not None:
shapes = [x.shape for x in (layers[-1].W, layers[-1].b)]
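
With the last hunk, importing a `torch.nn.Linear` through `layers_from_torch` now yields a `FlexDense` layer instead of a `Dense` one. A hedged usage sketch (shapes and batch size are invented; assumes MP-SPDZ's `Compiler.ml` module and the `layers_from_torch` signature shown above):

```python
import torch.nn as nn
from Compiler import ml

# A one-layer torch model; bias=True by default, satisfying the assert above.
model = nn.Sequential(nn.Linear(784, 10))

# data_input_shape[1:] must multiply out to in_features (784 here).
layers = ml.layers_from_torch(model, (128, 784), 128)
# layers[0] is now a FlexDense rather than a Dense layer.
```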