From f8139db805eb88bc14685f563d118271cba8b95d Mon Sep 17 00:00:00 2001 From: Hidde L Date: Tue, 14 Oct 2025 19:58:44 +0200 Subject: [PATCH 1/8] Add BERT layers and test script --- Compiler/ml.py | 1457 +++++++++++++++++++++++++++- Programs/Source/bert_inference.mpc | 389 ++++++++ 2 files changed, 1801 insertions(+), 45 deletions(-) create mode 100644 Programs/Source/bert_inference.mpc diff --git a/Compiler/ml.py b/Compiler/ml.py index 2c87a4c3..808721ee 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -290,7 +290,7 @@ class Layer: return type(self).__name__ + str(self._Y.shape) def __repr__(self): - return '%s(%s)' % (type(self).__name__, self._Y.shape) + return '%s(%s)' % (type(self).__name__, self.Y.shape) class NoVariableLayer(Layer): input_from = lambda *args, **kwargs: None @@ -467,7 +467,7 @@ class LinearOutput(OutputBase): class MultiOutputBase(NoVariableLayer): def __init__(self, N, d_out, approx=False, debug=False): self.X = sfix.Matrix(N, d_out) - self.Y = sint.Matrix(N, d_out) + self.Y = sfix.Matrix(N, d_out) self.nabla_X = sfix.Matrix(N, d_out) self.l = MemValue(sfix(-1)) self.losses = sfix.Array(N) @@ -962,6 +962,199 @@ class Dense(DenseBase): self.backward_params(f_schur_Y, batch=batch) + +class FlexDense(Dense): + """ Fixed-point dense (matrix multiplication) layer with flexible number of dimensions. + Behaves like torch's nn.Linear which loops over the additional dimensions. + + :param N: number of examples + :param d_in: input dimension + :param d_out: output dimension + """ + + def compute_f_input(self, batch): + N = len(batch) + prod = MultiArray([N, self.d, self.d_out], sfix) + + # flattened_array version + result_matrix = sfix.Matrix(N * self.d, self.d_out, address=prod.address) + max_size = get_program().budget // self.d_out + + # for now we assume that batch size is total size + # assert N == self.N + # batch contains the indices of the batches in self.N, we want to expand to have self.d too + # batch_with_d = + + # we are going to assume the batch is continuous + batch_0 = MemValue(batch[0]) + @multithread(self.n_threads, N * self.d, max_size) + def _(base, size): + batch_offset = batch_0 * self.d + X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address) + offset = regint.inc(size, base=base + batch_offset) + # array_offset = regint.Array(size) + # array_offset.assign_all(batch_offset) + # print_ln("array offset %s", array_offset.reveal()) + # offset[:] += array_offset[:] + # print_ln("total offset %s %s", batch_offset, offset) + + result_matrix.assign_part_vector( + X_sub.direct_mul(self.W, indices=( + offset, regint.inc(self.d_in), + regint.inc(self.d_in), regint.inc(self.d_out))), base) + # print_ln("result matrix done") + + if self.input_bias: + if self.d_out == 1: + @multithread(self.n_threads, N) + def _(base, size): + v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size) + self.f_input.assign_vector(v, base) + else: + @for_range_multithread(self.n_threads, 100, [N, self.d]) + def _(i, j): + v = prod[i][j].get_vector() + self.b.get_vector() + # print_ln("running bias %s %s %s", i, j, v.reveal()) + self.f_input[i][j].assign_vector(v) + + # print_ln("FlexDense f_inpu full %s", self.f_input.reveal()) + + progress('f input') + + def backward(self, compute_nabla_X=True, batch=None): + N = len(batch) + d = self.d + d_out = self.d_out + X = self.X + Y = self.Y + W = self.W + b = self.b + nabla_X = self.nabla_X + nabla_Y = self.nabla_Y + nabla_W = self.nabla_W + nabla_b = self.nabla_b + + if self.activation_layer: + self.activation_layer.backward(batch) + f_schur_Y = self.activation_layer.nabla_X + else: + f_schur_Y = nabla_Y + + if compute_nabla_X: + nabla_X.alloc() + + # flattened matrix version + max_size = get_program().budget // self.d_in + result_matrix = sfix.Matrix(N * self.d, self.d_in, address=nabla_X.address) + batch_0 = MemValue(batch[0]) + @multithread(self.n_threads, N * self.d, max_size) + def _(base, size): + batch_offset = batch_0 * self.d + X_sub = sfix.Matrix(self.N * self.d, self.d_out, address=f_schur_Y.address) + offset = regint.inc(size, base=base + batch_offset) + + result_matrix.assign_part_vector( + X_sub.direct_mul_trans(self.W, indices=( + offset, regint.inc(self.d_out), + regint.inc(self.d_out), regint.inc(self.d_in))), base) + + if self.print_random_update: + print_ln('backward %s', self) + index = regint.get_random(64) % self.nabla_X.total_size() + print_ln('%s nabla_X at %s: %s', str(self.nabla_X), + index, self.nabla_X.to_array()[index].reveal()) + + progress('nabla X') + + self.backward_params(f_schur_Y, batch=batch) + + def backward_params(self, f_schur_Y, batch): + print("backward params flexdense") + N = len(batch) + tmp = Matrix(self.d_in, self.d_out, unreduced_sfix) + # tmp.assign_all(0) + + A = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address) + B = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address) + + @multithread(self.n_threads, self.d_in) + def _(base, size): + mp = B.direct_trans_mul(A, reduce=False, + indices=(regint.inc(size, base), + regint.inc(N * self.d), # Not sure + regint.inc(N * self.d), + regint.inc(self.d_out))) + + tmp.assign_part_vector(mp, base) + + progress('nabla W (matmul)') + + @multithread(self.n_threads, self.d_in * self.d_out, + max_size=get_program().budget) + def _(base, size): + self.nabla_W.assign_vector( + tmp.get_vector(base, size).reduce_after_mul(), base=base) + + if self.print_random_update: + print_ln('backward %s', self) + i = regint.get_random(64) % self.d_in + j = regint.get_random(64) % self.d_out + print_ln('%s at (%s, %s): before=%s after=%s A=%s B=%s', + str(self.nabla_W), i, j, tmp[i][j].v.reveal(), + self.nabla_W[i][j].reveal(), + A.get_column(j).reveal(), + B.get_column_by_row_indices( + batch.get_vector(), i).reveal()) + print_ln('batch=%s B=%s', batch, + [self.X[bi][0][i].reveal() for bi in batch]) + + progress('nabla W') + + self.nabla_b.assign_vector(sum(sum(f_schur_Y[k][j].get_vector() + for k in range(N)) + for j in range(self.d))) + + progress('nabla b') + + if self.debug_output: + print_ln('dense nabla Y %s', self.nabla_Y.reveal_nested()) + print_ln('dense W %s', self.W.reveal_nested()) + print_ln('dense nabla X %s', self.nabla_X.reveal_nested()) + if self.debug: + limit = N * self.debug + @for_range_opt(self.d_in) + def _(i): + @for_range_opt(self.d_out) + def _(j): + to_check = self.nabla_W[i][j].reveal() + check = sum(to_check > limit) + sum(to_check < -limit) + @if_(check) + def _(): + print_ln('nabla W %s %s %s: %s', i, j, self.W.sizes, to_check) + print_ln('Y %s', [f_schur_Y[k][0][j].reveal() + for k in range(N)]) + print_ln('X %s', [self.X[k][0][i].reveal() + for k in range(N)]) + @for_range_opt(self.d_out) + def _(j): + to_check = self.nabla_b[j].reveal() + check = sum(to_check > limit) + sum(to_check < -limit) + @if_(check) + def _(): + print_ln('nabla b %s %s: %s', j, len(self.b), to_check) + print_ln('Y %s', [f_schur_Y[k][0][j].reveal() + for k in range(N)]) + @for_range_opt(len(batch)) + def _(i): + to_check = self.nabla_X[i].get_vector().reveal() + check = sum(to_check > limit) + sum(to_check < -limit) + @if_(check) + def _(): + print_ln('X %s %s', i, self.X[i].reveal_nested()) + print_ln('Y %s %s', i, f_schur_Y[i].reveal_nested()) + + + class QuantizedDense(DenseBase): def __init__(self, N, d_in, d_out): self.N = N @@ -1037,6 +1230,8 @@ class Dropout(NoVariableLayer): n_bits = -math.log(self.alpha, 2) assert n_bits == int(n_bits) n_bits = int(n_bits) + # self.B.assign_all(1) # TODO: temp disable for reproducibility + # self.alpha = 0.0 # TODO: temp disable for reproducibility @for_range_opt_multithread(self.n_threads, len(batch)) def _(i): size = self.d1 * self.d2 @@ -1065,6 +1260,61 @@ class Dropout(NoVariableLayer): print_ln('dropout nabla_Y %s', self.nabla_Y.reveal_nested()) print_ln('dropout nabla_X %s', self.nabla_X.reveal_nested()) +class FlexDropout(NoVariableLayer): + """ Dropout layer. + + :param N: number of examples + :param d1: total dimension + :param alpha: probability (power of two) + """ + def __init__(self, shape, alpha=0.5): + self.N = shape[0] + self.X = Tensor(shape, sfix) + self.Y = Tensor(shape, sfix) + self.nabla_Y = Tensor(shape, sfix) + self.nabla_X = Tensor(shape, sfix) + self.alpha = alpha + self.B = MultiArray(shape, sint) + + def __repr__(self): + return '%s(%s, alpha=%s)' % \ + (type(self).__name__, self.shape, self.alpha) + + def forward(self, batch, training=False): + if training: + n_bits = -math.log(self.alpha, 2) + assert n_bits == int(n_bits) + n_bits = int(n_bits) + self.B.assign_all(1) + self.alpha = 0.0 # TODO: temp disable for reproducibility + # @for_range_opt_multithread(self.n_threads, len(batch)) + # def _(i): + # size = reduce(operator.mul, self.shape[1:]) + # self.B[i].assign_vector(util.tree_reduce( + # util.or_op, (sint.get_random_bit(size=size) + # for i in range(n_bits)))) + @for_range_opt_multithread(self.n_threads, len(batch)) + def _(i): + self.Y[i].assign_vector(1 / (1 - self.alpha) * + self.X[batch[i]].get_vector() * self.B[i].get_vector()) + else: + @for_range(len(batch)) + def _(i): + self.Y[i] = self.X[batch[i]] + if self.debug_output: + print_ln('dropout X %s', self.X.reveal_nested()) + print_ln('dropout Y %s', self.Y.reveal_nested()) + + def backward(self, compute_nabla_X=True, batch=None): + if compute_nabla_X: + @for_range_opt_multithread(self.n_threads, len(batch)) + def _(i): + self.nabla_X[batch[i]].assign_vector( + self.nabla_Y[i].get_vector() * self.B[i].get_vector()) + if self.debug_output: + print_ln('dropout nabla_Y %s', self.nabla_Y.reveal_nested()) + print_ln('dropout nabla_X %s', self.nabla_X.reveal_nested()) + class ElementWiseLayer(NoVariableLayer): def __init__(self, shape, inputs=None): self.X = Tensor(shape, sfix) @@ -1148,6 +1398,149 @@ class Relu(ElementWiseLayer): def f_prime_part(self, base, size): return self.comparisons.get_vector(base, size) + +class Gelu(ElementWiseLayer): + """ Fixed-point GeLU layer. + Based on Dong et al., "PUMA: SECURE INFERENCE OF LLAMA-7B IN FIVE MINUTES" + + :param shape: input/output shape (tuple/list of int) + """ + prime_type = sfix + + def __init__(self, shape, inputs=None, approx=True): + super(Gelu, self).__init__(shape) + self.approx = approx + self.z0s = MultiArray(shape, sint) + self.z1s = MultiArray(shape, sint) + self.z2s = MultiArray(shape, sint) + + self.x2s = MultiArray(shape, sfix) + self.x3s = MultiArray(shape, sfix) + self.x4s = MultiArray(shape, sfix) + + self.poly_f_0_0 = -0.5054031199708174 + self.poly_f_0_1 = -0.42226581151983866 + self.poly_f_0_2 = -0.11807612951181953 + self.poly_f_0_3 = -0.011034134030615728 + + self.poly_f_1_0 = 0.008526321541038084 + self.poly_f_1_1 = 0.5 + self.poly_f_1_2 = 0.3603292692789629 + self.poly_f_1_4 = -0.037688200365904236 + self.poly_f_1_6 = 0.0018067462606141187 + + def f_part(self, base, size): + x = self.X.get_vector(base, size) + + if self.approx: + return self.compute_gelu_approx(x, base, size) + else: + return self.compute_gelu_sigmoid(x) + + def compute_gelu_approx(self, x, base, size): + # print_ln("GELU inputs %s", x.reveal()) + b0 = x < -4 + b1 = x < -1.95 + b2 = 3 < x + + z0 = b0 ^ b1 + z1 = b1 ^ b2 ^ 1 + z2 = b2 + + x1 = x + x2 = x ** 2 + x3 = x2 * x1 + x4 = x2 ** 2 + x6 = x3 ** 2 + + f_0 = self.poly_f_0_0 + self.poly_f_0_1 * x1 + self.poly_f_0_2 * x2 + self.poly_f_0_3 * x3 + f_1 = self.poly_f_1_0 + self.poly_f_1_1 * x1 + self.poly_f_1_2 * x2 + self.poly_f_1_4 * x4 + self.poly_f_1_6 * x6 + + self.z0s.assign_vector(z0, base) + self.z1s.assign_vector(z1, base) + self.z2s.assign_vector(z2, base) + + self.x2s.assign_vector(x2, base) + self.x3s.assign_vector(x3, base) + self.x4s.assign_vector(x6, base) + + return (z0 * f_0) + (z1 * f_1) + (z2 * x) + + def compute_gelu_sigmoid(self, x, base=0): + return x * sigmoid(0.071355 * (x ** 3) + 1.595769 * x) + + def tanh(self, x, base=0): + exp_2x = exp(2 * x) + + real_tanh = (exp_2x - 1) / (exp_2x + 1) + return real_tanh + + def f_prime_part(self, base, size): + if self.approx: + return self.f_prime_part_approx(base, size) + else: + return self.f_prime_part_sigmoid(base, size) + + def f_prime_part_approx(self, base, size): + # what we compute here is all derivatives + # we need to compute the derivative of the function at the point + poly_prime_f_0_0 = self.poly_f_0_1 + poly_prime_f_0_1 = self.poly_f_0_2 * 2 + poly_prime_f_0_2 = self.poly_f_0_3 * 3 + + poly_prime_f_1_0 = self.poly_f_1_1 + poly_prime_f_1_1 = self.poly_f_1_2 * 2 + poly_prime_f_1_3 = self.poly_f_1_4 * 4 + poly_prime_f_1_5 = self.poly_f_1_6 * 6 + + x1 = self.X.get_vector(base, size) + x2 = self.x2s.get_vector(base, size) + x3 = self.x3s.get_vector(base, size) + x4 = self.x4s.get_vector(base, size) + x5 = x4 * x1 + + f_0_prime = poly_prime_f_0_0 + poly_prime_f_0_1 * x1 + poly_prime_f_0_2 * x2 + f_1_prime = poly_prime_f_1_0 + poly_prime_f_1_1 * x1 + poly_prime_f_1_3 * x3 + poly_prime_f_1_5 * x5 + + z0 = self.z0s.get_vector(base, size) + z1 = self.z1s.get_vector(base, size) + z2 = self.z2s.get_vector(base, size) + + print_ln("gelu backward x %s res %s", x1.reveal()[:8], ((z0 * f_0_prime) + (z1 * f_1_prime) + z2).reveal()[:8]) + + return (z0 * f_0_prime) + (z1 * f_1_prime) + z2 + + def f_prime_part_sigmoid(self, base, size): + x = self.X.get_vector(base, size) + # return 0.5 * (1 + self.tanh(0.0356774 * x ** 3 + 0.797884 * x)) + print_ln("f sigmoid") + x3 = x ** 3 + exp_term = exp(1.59577 * x + 0.071355 * x3) + return (exp(3.19154* x + 0.14271 * x) + exp_term * (1 + 1.59577 * x + 0.214065 * x3)) / (1 + exp_term) ** 2 + + + +class Tanh(ElementWiseLayer): + + prime_type = sfix + + def __init__(self, shape, inputs=None): + super(Tanh, self).__init__(shape) + self.tanh_computations = MultiArray(shape, sfix) + + def f_part(self, base, size): + x = self.X.get_vector(base, size) + res = self.tanh(x) + self.tanh_computations.assign_vector(res, base) + return res + + def tanh(self, x): + return 2 * sigmoid(2 * x) - 1 + + def f_prime_part(self, base, size): + return 1 - self.tanh_computations.get_vector(base, size) ** 2 + + class Square(ElementWiseLayer): """ Fixed-point square layer. @@ -1353,6 +1746,8 @@ class Add(NoVariableLayer): self.Y = Tensor(shape, sfix) self.inputs = inputs + self.nabla_Y = Tensor(shape, sfix) + def _forward(self, batch=[0]): @multithread(self.n_threads, self.Y[0].total_size()) def _(base, size): @@ -1361,6 +1756,14 @@ class Add(NoVariableLayer): for inp in self.inputs) self.Y[bb].assign_vector(tmp, base) + def backward(self, compute_nabla_X=True, batch=None): + if compute_nabla_X: + for inp in self.inputs: + @multithread(self.n_threads, self.Y.total_size()) + def _(base, size): + inp.nabla_X.assign_vector(self.nabla_Y.get_vector(base, size), base) + + class FusedBatchNorm(Layer): """ Fixed-point fused batch normalization layer (inference only). @@ -1544,6 +1947,200 @@ class BatchNorm(Layer): for param in self.thetas() + (self.mu_hat, self.var_hat): param.reveal().binary_output() +class LayerNorm(Layer): # Changed class name + """ Fixed-point layer normalization layer. + For now assumes we only want to normalize over the last dimension + + :param shape: input/output shape (tuple/list of any number of int) + :param approx: use approximate square root + + """ + thetas = lambda self: (self.weights, self.bias) + nablas = lambda self: (self.nabla_weights, self.nabla_bias) + + def __init__(self, shape, approx=False, layernorm_eps=None, args=None): + if len(shape) == 2: + shape = [shape[0], 1, shape[1]] # Not sure why this extra dimension is added + tensors = (Tensor(shape, sfix) for i in range(4)) + self.X, self.Y, self.nabla_X, self.nabla_Y = tensors + self.epsilon = 2 ** (-sfix.f * 2 // 3 + 1) #if layernorm_eps is not None else sfix(layernorm_eps) + self.approx = approx + if approx: + print('Approximate square root inverse in layer normalization') + self.InvertSqrt = mpc_math.InvertSqrt + else: + print('Precise square root inverse in layer normalization') + self.InvertSqrt = lambda x: 1 / mpc_math.sqrt(x) + self.weights = sfix.Array(shape[-1]) + self.bias = sfix.Array(shape[-1]) + + batch_shape = [shape[0]] + list(self.X.sizes[1:-1]) + self.mu = sfix.Tensor(batch_shape) + self.var = sfix.Tensor(batch_shape) + self.nabla_weights = sfix.Array(shape[-1]) + self.nabla_bias = sfix.Array(shape[-1]) + + def __repr__(self): + return '%s(%s, approx=%s)' % \ + (type(self).__name__, self.X.sizes, self.approx) + + def reset(self): # Simplified reset method + self.bias.assign_all(0) + self.weights.assign_all(1) + + def _output(self, batch, mu, var): + batch_shape = [len(batch)] + list(self.X.sizes[1:-1]) + factor = sfix.Tensor(batch_shape) + + @multithread(self.n_threads, len(batch)) + def _(base, size): + factor.assign_part_vector( + self.InvertSqrt(var.get_part_vector(base, size) + self.epsilon), + base) + + @for_range_opt_multithread(self.n_threads, + batch_shape) + def _(*arg): + sel_X = self.X + sel_Y = self.Y + mu_sel = mu + fac_sel = factor + for ar in arg: + sel_X = sel_X[ar] + sel_Y = sel_Y[ar] + mu_sel = mu_sel[ar] + fac_sel = fac_sel[ar] + tmp = self.weights[:] * (sel_X[:] - mu_sel) * fac_sel # Removed self.mu reference + sel_Y[:] = self.bias[:] + tmp + + def forward(self, batch, training=False): + d = self.X.sizes[1] + d_in = self.X.sizes[2] + X_batch = MultiArray([len(batch), self.X.sizes[1], self.X.sizes[2]], sfix) + X_batch.assign_vector(self.X.get_slice_vector(batch)) + batch_shape = [len(batch)] + list(self.X.sizes[1:-1]) + + @for_range_opt_multithread(self.n_threads, batch_shape) + def _(*arg): + sel = X_batch + for ar in arg: + sel = sel[ar] + res = sum(sel[:]) + if len(arg) == 2: + self.mu[arg[0]][arg[1]] = res + else: + raise NotImplementedError("Only 3D tensors supported") + + @multithread(self.n_threads, self.mu.total_size()) + def _(base, size): + total_dim = self.X.sizes[-1] + self.mu.assign_vector( + self.mu.get_vector(base, size) / total_dim, base) + + @for_range_opt_multithread(self.n_threads, batch_shape) + def _(*arg): + sel = X_batch + mu_sel = self.mu + for ar in arg: + sel = sel[ar] + mu_sel = mu_sel[ar] + res = sum((sel[:] - mu_sel) ** 2) # Removed self.mu reference + if len(arg) == 2: + self.var[arg[0]][arg[1]] = res + else: + raise NotImplementedError("Only 3D tensors supported") + + @multithread(self.n_threads, self.var.total_size()) + def _(base, size): + total_dim = self.X.sizes[-1] + self.var.assign_vector( + self.var.get_vector(base, size) / (total_dim), + base) + self._output(batch, self.mu, self.var) # Simplified to always use current batch statistics + if self.print_random_update: + i = regint.get_random(64) % len(batch) + j = regint.get_random(64) % d + k = regint.get_random(64) % d_in + for x in self.mu, self.var: + print_ln('%s at %s: %s', str(x), k, x[k].reveal()) + print_ln('%s at (%s, %s, %s): in=%s out=%s', + str(self.Y), i, j, k, self.X[i][j][k].reveal(), + self.Y[i][j][k].reveal()) + + def backward(self, batch, compute_nabla_X=True): + batch_shape = [len(batch)] + list(self.X.sizes[1:-1]) + factor = sfix.Tensor(batch_shape) + factor[:] = self.InvertSqrt(self.var[:] + self.epsilon) # TODO: Can we cache this? + + # Gradients wrt outputs stored in self.nabla_Y + norm = self.X.same_shape() # Gradient wrt scaled input (gamma * nabla_Y) + nYdf = self.X.same_shape() # Used for weights gradient computation + dNorm = self.X.same_shape() # only needed for nabla_X comp + # dNormNorm = self.X.same_shape() + + # print("layernorm back ", batch, nYdf.sizes, factor.sizes) + # print("batch shape", batch_shape, self.nabla_Y.sizes) + # print(self.nabla_bias.length, self.X.sizes) + + @for_range_opt_multithread(self.n_threads, batch_shape) + def _(*arg): + assert len(arg) == 2, "Only 3D tensors supported" + norm[arg[0]][arg[1]][:] = (self.X[batch[arg[0]]][arg[1]][:] - self.mu[arg[0]][arg[1]]) * factor[arg[0]][arg[1]] # norm_bti + nYdf[arg[0]][arg[1]][:] = self.nabla_Y[batch[arg[0]]][arg[1]][:] * norm[arg[0]][arg[1]][:] + + dNorm[arg[0]][arg[1]][:] = self.nabla_Y[batch[arg[0]]][arg[1]][:] * self.weights[:] # dnorm_i + # dNormNorm[arg[0]][arg[1]][:] = dNorm[arg[0]][arg[1]][:] * norm[arg[0]][arg[1]][:] + + # Sum over the appropriate axes for nabla_weights and nabla_bias + @map_sum_simple(self.n_threads, batch_shape, sfix, self.X.sizes[-1]) + def _(*arg): + sel = self.nabla_Y + for ar in arg: + sel = sel[ar] + return sel[:] + self.nabla_bias.assign(_()) + + @map_sum_simple(self.n_threads, batch_shape, sfix, self.X.sizes[-1]) + def _(*arg): + sel = nYdf + for ar in arg: + sel = sel[ar] + return sel[:] + self.nabla_weights.assign(_()) + + if compute_nabla_X: + # sum_dNorm = sfix.Array(self.X.sizes[-1]) + # sum_dNormNorm = sfix.Array(self.X.sizes[-1]) + # + # @map_sum_simple(self.n_threads, batch_shape, sfix, self.X.sizes[-1]) + # def _(*arg): + # sel = dNormNorm + # for ar in arg: + # sel = sel[ar] + # return sel[:] + # sum_dNormNorm.assign(_()) + + # print_ln("sum norm %s", sum_dNorm[:].reveal()) + # print_ln("sum norm norm %s", sum_dNormNorm[:].reveal()) + # print_ln("norm %s", norm[:].reveal()[:8]) # corr + # print_ln("dnorm %s", dNorm[:].reveal()[:8]) + + # Compute final gradient wrt input X + @for_range_opt_multithread(self.n_threads, batch_shape) + def _(*arg): + + if len(arg) == 2: + mean_dnorm = sum(dNorm[arg[0]][arg[1]]) / self.X.sizes[-1] + + self.nabla_X[arg[0]][arg[1]][:] = (dNorm[arg[0]][arg[1]][:] - mean_dnorm) * factor[arg[0]][arg[1]] + mean_dnormnorm = sum(self.nabla_X[arg[0]][arg[1]][:] * norm[arg[0]][arg[1]]) / self.X.sizes[-1] + self.nabla_X[arg[0]][arg[1]][:] -= norm[arg[0]][arg[1]] * mean_dnormnorm + + else: + raise NotImplementedError("Only 3D tensors supported") + # sel_nX[:] = sel_dnorm[:] - sum_dNorm - sel_norm[:] * sum_dNormNorm + # print_ln("layernorm result compute_x %s", sel_nX[:].reveal()) + class QuantBase(object): bias_before_reduction = True @@ -1615,8 +2212,16 @@ class ConvBase(BaseLayer): use_conv2ds = True temp_weights = None temp_inputs = None - thetas = lambda self: (self.weights, self.bias) - nablas = lambda self: (self.nabla_weights, self.nabla_bias) + def thetas(self): + if self.use_bias: + return self.weights, self.bias + else: + return tuple([self.weights]) + def nablas(self): + if self.use_bias: + return self.nabla_weights, self.nabla_bias + else: + return tuple([self.nabla_weights]) @classmethod def init_temp(cls, layers): @@ -1628,13 +2233,14 @@ class ConvBase(BaseLayer): def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride, padding='SAME', tf_weight_format=False, inputs=None, - weight_type=None): + weight_type=None, bias=True): super(ConvBase, self).__init__(input_shape, output_shape, inputs=inputs) self.weight_shape = weight_shape self.bias_shape = bias_shape self.stride = stride self.tf_weight_format = tf_weight_format + self.use_bias = bias if padding == 'SAME': # https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn self.padding = [] @@ -1665,10 +2271,14 @@ class ConvBase(BaseLayer): self.bias_squant = self.new_squant() self.weights = Tensor(weight_shape, self.weight_squant) - self.bias = Array(output_shape[-1], self.bias_squant) + if self.use_bias: + self.bias = Array(output_shape[-1], self.bias_squant) self.nabla_weights = Tensor(weight_shape, self.weight_squant) - self.nabla_bias = Array(output_shape[-1], self.bias_squant) + if self.use_bias: + self.nabla_bias = Array(output_shape[-1], self.bias_squant) + + self.unreduced = Tensor(self.output_shape, sint, address=self.Y.address) if tf_weight_format: weight_in = weight_shape[2] @@ -1687,10 +2297,6 @@ class ConvBase(BaseLayer): self.bias_shape, self.Y.sizes, self.stride, repr(self.padding), self.tf_weight_format) - @property - def unreduced(self): - return Tensor(self.output_shape, sint, address=self.Y.address) - def input_from(self, player, **kwargs): self.input_params_from(player) self.weights.input_from(player, budget=100000, **kwargs) @@ -1699,7 +2305,8 @@ class ConvBase(BaseLayer): def output_weights(self): self.weights.print_reveal_nested() - print_ln('%s', self.bias.reveal_nested()) + if self.use_bias: + print_ln('%s', self.bias.reveal_nested()) def reveal_parameters_to_binary(self): assert not self.tf_weight_format @@ -1711,15 +2318,21 @@ class ConvBase(BaseLayer): def _(j): part = self.weights.get_vector_by_indices(i, None, None, j) part.reveal().binary_output() - self.bias.reveal_to_binary_output() + if self.use_bias: + self.bias.reveal_to_binary_output() def dot_product(self, iv, wv, out_y, out_x, out_c): - bias = self.bias[out_c] - acc = self.output_squant.unreduced_dot_product(iv, wv) - acc.v += bias.v - acc.res_params = self.output_squant.params - #self.Y[0][out_y][out_x][out_c] = acc.reduce_after_mul() - self.unreduced[0][out_y][out_x][out_c] = acc.v + if self.use_bias: + bias = self.bias[out_c] + acc = self.output_squant.unreduced_dot_product(iv, wv) + acc.v += bias.v + acc.res_params = self.output_squant.params + #self.Y[0][out_y][out_x][out_c] = acc.reduce_after_mul() + self.unreduced[0][out_y][out_x][out_c] = acc.v + else: + acc = self.output_squant.unreduced_dot_product(iv, wv) + acc.res_params = self.output_squant.params + self.unreduced[0][out_y][out_x][out_c] = acc.v def reduction(self, batch_length=1): unreduced = self.unreduced @@ -1785,11 +2398,12 @@ class Conv2d(ConvBase): inputs_h, inputs_w, weights_h, weights_w, stride_h, stride_w, n_channels_in, padding_h, padding_w, part_size) - if self.bias_before_reduction: - res += self.bias.expand_to_vector(j, res.size).v - else: - res += self.bias.expand_to_vector(j, res.size).v << \ - self.weight_squant.f + if self.use_bias: + if self.bias_before_reduction: + res += self.bias.expand_to_vector(j, res.size).v + else: + res += self.bias.expand_to_vector(j, res.size).v << \ + self.weight_squant.f addresses = regint.inc(res.size, self.unreduced[i * part_size].address + j, n_channels_out) @@ -1797,7 +2411,8 @@ class Conv2d(ConvBase): self.reduction(len(batch)) if self.debug_output: print_ln('%s weights %s', self, self.weights.reveal_nested()) - print_ln('%s bias %s', self, self.bias.reveal_nested()) + if self.use_bias: + print_ln('%s bias %s', self, self.bias.reveal_nested()) @for_range(len(batch)) def _(i): print_ln('%s X %s %s', self, i, self.X[batch[i]].reveal_nested()) @@ -1873,7 +2488,8 @@ class FixConv2d(Conv2d, FixBase): r = math.sqrt(6.0 / (n_in + self.weight_shape[0])) print('Initializing convolution weights in [%f,%f]' % (-r, r)) self.weights.randomize(-r, r, n_threads=self.n_threads) - self.bias.assign_all(0) + if self.use_bias: + self.bias.assign_all(0) def backward(self, compute_nabla_X=True, batch=None): assert self.use_conv2ds @@ -1888,14 +2504,15 @@ class FixConv2d(Conv2d, FixBase): N = len(batch) - self.nabla_bias.assign_all(0) + if self.use_bias: + self.nabla_bias.assign_all(0) - @for_range(N) - def _(i): - self.nabla_bias.assign_vector( - self.nabla_bias.get_vector() + sum(sum( - self.nabla_Y[i][j][k].get_vector() for k in range(output_w)) - for j in range(output_h))) + @for_range(N) + def _(i): + self.nabla_bias.assign_vector( + self.nabla_bias.get_vector() + sum(sum( + self.nabla_Y[i][j][k].get_vector() for k in range(output_w)) + for j in range(output_h))) input_size = inputs_h * inputs_w * N batch_repeat = regint.Matrix(N, inputs_h * inputs_w) @@ -1975,8 +2592,9 @@ class FixConv2d(Conv2d, FixBase): print_ln('%s nabla weights %s', self, (self.nabla_weights.reveal_nested())) print_ln('%s weights %s', self, (self.weights.reveal_nested())) - print_ln('%s nabla b %s', self, (self.nabla_bias.reveal_nested())) - print_ln('%s bias %s', self, (self.bias.reveal_nested())) + if self.use_bias: + print_ln('%s nabla b %s', self, (self.nabla_bias.reveal_nested())) + print_ln('%s bias %s', self, (self.bias.reveal_nested())) class QuantDepthwiseConv2d(QuantConvBase, Conv2d): def n_summands(self): @@ -2109,7 +2727,7 @@ class AveragePool2d(BaseLayer): self.Y[0][out_y][out_x][c] = self.output_squant._new(acc) def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1, - padding=0, **kwargs): + padding=0, bias=True, **kwargs): """ More convenient interface to :py:class:`FixConv2d`. :param input_shape: input shape (tuple/list of four int) @@ -2117,6 +2735,7 @@ def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1, :param kernel_size: kernel size (int or tuple/list of two int) :param stride: stride (int or tuple/list of two int) :param padding: :py:obj:`'SAME'`, :py:obj:`'VALID'`, int, or tuple/list of two int + :param bias: whether layer has bias (bool) """ if isinstance(kernel_size, int): @@ -2130,7 +2749,7 @@ def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1, padding = padding.upper() if isinstance(padding, str) \ else padding return FixConv2d(input_shape, weight_shape, (out_channels,), output_shape, - stride, padding, **kwargs) + stride, padding, bias=bias, **kwargs) def easyMaxPool(input_shape, kernel_size, stride=None, padding=0): """ More convenient interface to :py:class:`MaxPool`. @@ -2241,6 +2860,627 @@ class QuantSoftmax(QuantBase, BaseLayer): return [c.if_else(x, y) for x, y in zip(left, right)] print_ln('guess: %s', util.tree_reduce(comp, list(enumerate(self.X[0])))[0].reveal()) +class BertBase(BaseLayer, FixBase): + pass + +# class BertEmbedding(BertBase): # we dont do embedding + +class BertPooler(BertBase): + + thetas = lambda self: self.dense.thetas() + nablas = lambda self: self.dense.nablas() # refer to downstream layers? + + def __init__(self, n_examples, seq_len, hidden_state): + input_shape = [n_examples, seq_len, hidden_state] + output_shape = [n_examples, hidden_state] + super(BertPooler, self).__init__(input_shape, output_shape) + self.dense = Dense(n_examples, hidden_state, hidden_state) + self.activation = Tanh(output_shape) + + self.d_out = hidden_state + + + def _forward(self, batch): + # self.dense.X.address = self.X.address + self.activation.X.address = self.dense.Y.address + self.activation.Y.address = self.Y.address + + # grab the first repr? + # batch contains [n_batch, n_heads, n_dim] + @for_range(len(batch)) + def _(j): + print_ln("Pooling %s %s", j, batch[j]) + self.dense.X[j][:] = self.X[batch[j]][0][:] + + # if self.debug_output: + # print_ln("forward layer pooler.dense X %s", self.dense.X.reveal_nested()) + + self.dense.forward(batch) + # print_ln("LINEAR Layer weights after bertpooler.dense: %s", self.opt.layers[-2].W.reveal_nested()) + + self.activation._forward(batch) + # print_ln("LINEAR Layer weights after bertpooler.activation: %s", self.opt.layers[-2].W.reveal_nested()) + + def reset(self): + self.dense.reset() + self.activation.reset() + + def load_state_dict(self, state_dict, input_via): + import numpy + self.dense.W = sfix.input_tensor_via(input_via, numpy.swapaxes(state_dict['dense.weight'], 0, 1)) + self.dense.b = sfix.input_tensor_via(input_via, state_dict['dense.bias']) + + def backward(self, compute_nabla_X=True, batch=None): + if batch is None: + batch = regint.Array(self.N) + batch.assign(regint.inc(self.N)) + + self.activation.nabla_X.alloc() + + self.activation.nabla_Y.address = self.nabla_Y.address + self.dense.nabla_Y.address = self.activation.nabla_X.address + self.dense.nabla_X.address = self.nabla_X.address # TODO: size mismatch here, but should be okay? cause rest 0s? + + self.activation.backward(batch) + self.dense.backward(compute_nabla_X, batch) + +class BertEncoder(BertBase): + + # I think this is unused? + + def __init__(self, n_examples, n_layers, d_model, n_heads, d_k, d_v, d_ff, dropout=0.1): + input_shape = [n_examples, d_model] + output_shape = [n_examples, d_model] + super(BertEncoder, self).__init__(input_shape, output_shape) + self.layers = [] + for _ in range(n_layers): + self.layers.append(BertLayer(n_examples, d_model, n_heads, d_k, d_v, d_ff, dropout)) + + for i in enumerate(1, len(self.layers)): + self.layers[i].X.address = self.layers[i - 1].Y.address + + self.layers[0].X.address = self.X.address + self.layers[-1].Y.address = self.Y.address + + def _forward(self, batch): + for layer in self.layers: + layer.forward(batch) + + def reset(self): + for layer in self.layers: + layer.reset() + + +class BertLayer(BertBase): + + thetas = lambda self: self.multi_head_attention.thetas() + self.intermediate.thetas() + self.output.thetas() #+ tuple(self.nabla_hidden_state) + nablas = lambda self: self.multi_head_attention.nablas() + self.intermediate.nablas() + self.output.nablas() #+ tuple(self.nabla_hidden_state) + + def __init__(self, n_examples, seq_len, hidden_state, intermediate_size, num_attention_heads, layernorm_eps, dropout=0.1, rsqrt_approx=True, batch_size=None): + input_shape = [n_examples, seq_len, hidden_state] + output_shape = [n_examples, seq_len, hidden_state] # TODO: we could make this batch_size + super(BertLayer, self).__init__(input_shape, output_shape) + + internal_shape = batch_size if batch_size is not None else n_examples + self.multi_head_attention = MultiHeadAttention(internal_shape, seq_len, hidden_state, num_attention_heads, dropout, layernorm_eps, rsqrt_approx) + self.intermediate = BertIntermediate(internal_shape, hidden_state, intermediate_size, seq_len) + self.output = BertOutput(internal_shape, intermediate_size, hidden_state, seq_len, dropout, layernorm_eps, rsqrt_approx) + + self.hidden_state = sfix.Tensor(input_shape) # TODO: Could also make this smaller + # self.nabla_hidden_state = sfix.Tensor(input_shape) + # self.nabla_hidden_state.alloc() + + # self.X.address = self.multi_head_attention.X.address + # self.Y.address = self.output.Y.address + + self.d_out = hidden_state + + print("Init BertLayer", input_shape, output_shape) + + def forward(self, batch, training=False): + if batch is None: + batch = Array.create_from(regint(0)) + + self.multi_head_attention._X.address = self.X.address + self.output.Y.address = self.Y.address + self.hidden_state.address = self.X.address + # self.multi_head_attention.Y.address = self.Y.address + + self.multi_head_attention.forward(batch, self.hidden_state, training) + # if self.debug_output: + # print_ln("our layer X %s %s", self.X[0][0][0].reveal(), self.output.X[0][0][0].reveal()) + + if self.debug_output: + print_ln("forward layer multi_head_attention %s %s", self.multi_head_attention.Y[0][1][0].reveal(), sum(sum(self.multi_head_attention.Y[0].reveal()))) + # print_ln("forward layer multi_head_attention full %s", self.multi_head_attention.Y.reveal()) + + print("Forward Attention") + + batch_inc = regint.Array(len(batch)) + batch_inc.assign(regint.inc(len(batch))) + self.intermediate.X.address = self.multi_head_attention.Y.address + self.intermediate.forward(batch_inc) + + if self.debug_output: + print_ln("forward layer intermediate %s %s %s", self.intermediate.Y.shape, self.intermediate.Y[0][1][0:20].reveal(), sum(sum(self.intermediate.Y[0].reveal()))) + + print_ln(" ") + + self.output.X.address = self.intermediate.Y.address + self.output.forward(batch_inc, self.multi_head_attention.Y, training) + # self.output.Y.address = self.output.X.address + + if self.debug_output: + print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal()))) + # print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal()))) + # print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal()))) + + print_ln("our layer output %s %s %s %s", self.output.Y.address, len(self.Y[0].reveal()), self.output.Y[0][0][0:20].reveal(), sum(sum(self.output.Y[0].reveal()))) + # print_ln("shapes %s %s", self.Y.sizes, self.output.Y.sizes) + # print_ln("types %s %s %s %s %s %s", self.Y.value_type, self.output.Y.value_type, type(self.Y), type(self.output.Y), self, self.output) + + print("Forward BertLayer") + + def reset(self): + self.multi_head_attention.reset() + self.intermediate.reset() + self.output.reset() + + def load_state_dict(self, state_dict, input_via): + import numpy + # format of state_dict + # ['attention.self.query.weight', 'attention.self.query.bias', 'attention.self.key.weight', 'attention.self.key.bias', 'attention.self.value.weight', 'attention.self.value.bias', 'attention.output.dense.weight', 'attention.output.dense.bias', 'attention.output.LayerNorm.weight', 'attention.output.LayerNorm.bias', 'intermediate.dense.weight', 'intermediate.dense.bias', 'output.dense.weight', 'output.dense.bias', 'output.LayerNorm.weight', 'output.LayerNorm.bias'] + # set the values of the layers + self.multi_head_attention.wq.W = sfix.input_tensor_via(input_via, numpy.swapaxes(state_dict['attention.self.query.weight'], 0, 1)) + self.multi_head_attention.wq.b = sfix.input_tensor_via(input_via, state_dict['attention.self.query.bias']) + self.multi_head_attention.wk.W = sfix.input_tensor_via(input_via, numpy.swapaxes(state_dict['attention.self.key.weight'], 0, 1)) + self.multi_head_attention.wk.b = sfix.input_tensor_via(input_via, state_dict['attention.self.key.bias']) + self.multi_head_attention.wv.W = sfix.input_tensor_via(input_via, numpy.swapaxes(state_dict['attention.self.value.weight'], 0, 1)) + self.multi_head_attention.wv.b = sfix.input_tensor_via(input_via, state_dict['attention.self.value.bias']) + + self.multi_head_attention.output.dense.W = sfix.input_tensor_via(input_via, numpy.swapaxes(state_dict['attention.output.dense.weight'], 0, 1)) + self.multi_head_attention.output.dense.b = sfix.input_tensor_via(input_via, state_dict['attention.output.dense.bias']) + self.multi_head_attention.output.layer_norm.weights = sfix.input_tensor_via(input_via, state_dict['attention.output.LayerNorm.weight']) + self.multi_head_attention.output.layer_norm.bias = sfix.input_tensor_via(input_via, state_dict['attention.output.LayerNorm.bias']) + + self.intermediate.dense.W = sfix.input_tensor_via(input_via, numpy.swapaxes(state_dict['intermediate.dense.weight'], 0, 1)) + self.intermediate.dense.b = sfix.input_tensor_via(input_via, state_dict['intermediate.dense.bias']) + + self.output.dense.W = sfix.input_tensor_via(input_via, numpy.swapaxes(state_dict['output.dense.weight'], 0, 1)) + # print_ln("output.dense.W state_dict %s", self.output.dense.W[0][0].reveal()) + self.output.dense.b = sfix.input_tensor_via(input_via, state_dict['output.dense.bias']) + self.output.layer_norm.weights = sfix.input_tensor_via(input_via, state_dict['output.LayerNorm.weight']) + self.output.layer_norm.bias = sfix.input_tensor_via(input_via, state_dict['output.LayerNorm.bias']) + + def backward(self, compute_nabla_X=True, batch=None): + # layer.inputs[0].nabla_Y.address = \ + # layer.nabla_X.address + # assign nabla_X and Y + self.multi_head_attention.nabla_X.alloc() + self.intermediate.nabla_X.alloc() + self.output.nabla_X.alloc() + + self.output.nabla_Y.address = self.nabla_Y.address + self.intermediate.nabla_Y.address = self.output.nabla_X.address + self.multi_head_attention.nabla_Y.address = self.intermediate.nabla_X.address + # self.multi_head_attention.nabla_X.address = self.nabla_X.address + + nabla_y_multi_head_attention_from_layernorm = self.output.backward(True, batch) + # print_ln("Backward BertLayer.output.nabla_X %s", self.output.nabla_X.reveal_nested()[:8]) + self.intermediate.backward(True, batch) + + # residual, add it to Y because it gave the output of multihadattention to output + @multithread(self.n_threads, len(batch)) + def _(base, size): + self.multi_head_attention.nabla_Y.assign_part_vector( + self.multi_head_attention.nabla_Y.get_part_vector(base, size) + + nabla_y_multi_head_attention_from_layernorm.get_part_vector(base, size), base) + + if compute_nabla_X: + self.multi_head_attention.nabla_X.address = self.nabla_X.address + + nabla_y_hidden_state = self.multi_head_attention.backward(compute_nabla_X, batch) + + if compute_nabla_X: + print_ln("Bertlayer nabla_x %s %s", nabla_y_hidden_state.get_vector().reveal()[-8:], self.nabla_X.get_vector().reveal()[-8:]) + # and add hidden_state back to nabla_X, add to x because we gave x to multi_head_attention + @multithread(self.n_threads, len(batch)) + def _(base, size): + self.nabla_X.assign_part_vector( + self.nabla_X.get_part_vector(base, size) + + nabla_y_hidden_state.get_part_vector(base, size), base) + + +class BertIntermediate(BertBase): + + thetas = lambda self: self.dense.thetas() + nablas = lambda self: self.dense.nablas() + + def __init__(self, n_examples, hidden_size, intermediate_size, seq_len): + input_shape = [n_examples, seq_len, hidden_size] + output_shape = [n_examples, seq_len, intermediate_size] + super(BertIntermediate, self).__init__(input_shape, output_shape) + self.dense = FlexDense(n_examples, hidden_size, intermediate_size, seq_len) + self.activation = Gelu([n_examples, seq_len, intermediate_size]) + + + def forward(self, batch=None, training=None): + self.dense.X.address = self.X.address + self.activation.X.address = self.dense.Y.address + self.activation.Y.address = self.Y.address + + self.dense.forward(batch) + if self.debug_output: + print_ln("forward layer intermediate.dense %s", self.dense.Y[0][0][0:20].reveal()) + + self.activation._forward(batch) + + def reset(self): + self.dense.reset() + + def backward(self, compute_nabla_X=True, batch=None): + self.activation.nabla_X.alloc() + + # print_ln("Backward BertIntermediate.nabla_X %s", self.nabla_X.reveal_nested()[:8]) + + self.activation.nabla_Y.address = self.nabla_Y.address + self.dense.nabla_Y.address = self.activation.nabla_X.address + self.dense.nabla_X.address = self.nabla_X.address + + self.activation.backward(batch) + self.dense.backward(compute_nabla_X, batch) + + +class BertOutput(BertBase): + + thetas = lambda self: self.dense.thetas() + self.layer_norm.thetas() + nablas = lambda self: self.dense.nablas() + self.layer_norm.nablas() + + def __init__(self, n_examples, intermediate_size, hidden_size, seq_len, dropout=0.1, layernorm_eps=1e-12, rsqrt_approx=True): + input_shape = [n_examples, seq_len, intermediate_size] + output_shape = [n_examples, seq_len, hidden_size] + self.input_shape = input_shape + print("INSTANTIATING BERTOUTPUT with ", input_shape, output_shape, intermediate_size, hidden_size, rsqrt_approx) + super(BertOutput, self).__init__(input_shape, output_shape) + self.dense = FlexDense(n_examples, intermediate_size, hidden_size, seq_len) + self.layer_norm = LayerNorm(output_shape, layernorm_eps=layernorm_eps, approx=rsqrt_approx) + self.dropout = FlexDropout([n_examples, seq_len, hidden_size], alpha=dropout) + + + def forward(self, batch, input_tensor, training=False, input_tensor_batch=None): + # Because input_tensor might be the full training data shape + self.dense.X.address = self.X.address + self.dropout.X.address = self.dense.Y.address + self.layer_norm.X.address = self.dropout.Y.address + self.layer_norm.Y.address = self.Y.address + + self.dense.forward(batch) + if self.debug_output: + print_ln("forward layer output.dense %s", self.dense.Y[0][0][0:20].reveal()) + + self.dropout.forward(batch, training) + + if input_tensor_batch is not None: + input_tensor_batch_arr = MultiArray([len(batch), input_tensor.sizes[1], input_tensor.sizes[2]], sfix) + input_tensor_batch_arr.assign_vector(input_tensor.get_slice_vector(input_tensor_batch)) + @multithread(self.n_threads, len(batch)) + def _(base, size): + self.layer_norm.X.assign_part_vector( + self.layer_norm.X.get_part_vector(base, size) + + input_tensor_batch_arr.get_part_vector(base, size), base) + else: + @multithread(self.n_threads, len(batch)) + def _(base, size): + self.layer_norm.X.assign_part_vector( + self.layer_norm.X.get_part_vector(base, size) + + input_tensor.get_part_vector(base, size), base) + # if self.debug_output: + # print_ln("input tensor %s", input_tensor.reveal()) + + # self.layer_norm.X[:] += input_tensor[:] # TODO: is it maybe this addition since we take the last value? would be strange + + if self.debug_output: + print_ln("forward layer layer_norm_add %s", self.layer_norm.X[0][0][0:20].reveal()) + print_ln("") + self.layer_norm.forward(batch) + + + + def reset(self): + self.dense.reset() + + def backward(self, compute_nabla_X=True, batch=None): + self.layer_norm.nabla_X.alloc() + self.dropout.nabla_X.alloc() + + self.layer_norm.nabla_Y.address = self.nabla_Y.address + self.dropout.nabla_Y.address = self.layer_norm.nabla_X.address + self.dense.nabla_Y.address= self.dropout.nabla_X.address + self.dense.nabla_X.address = self.nabla_X.address + + # layer norm flows back to dropout but also to hidden_tensor... nabla hidden state? + + self.layer_norm.backward(batch, compute_nabla_X) + self.dropout.backward(compute_nabla_X, batch) + + if self.debug_output: + print_ln("backward layer dense x %s", self.dropout.nabla_X[0][0][0:20].reveal()) + + self.dense.backward(compute_nabla_X, batch) + + return self.layer_norm.nabla_X + +class MultiHeadAttention(BertBase): + + thetas = lambda self: self.wq.thetas() + self.wk.thetas() + self.wv.thetas() + self.output.thetas() + nablas = lambda self: self.wq.nablas() + self.wk.nablas() + self.wv.nablas() + self.output.nablas() + + def __init__(self, n_examples, seq_len, hidden_size, num_attention_heads, dropout=0.1, layernorm_eps=1e-12, rsqrt_approx=True, batch_size=None): + + # In the first layer the internal_shape is different from n_examples, afterwards it is the same + internal_shape = batch_size if batch_size is not None else n_examples + self.n_examples = internal_shape + + input_shape = [n_examples, seq_len, hidden_size] + output_shape = [internal_shape, seq_len, hidden_size] + super().__init__(input_shape, output_shape) + + print("Multheadattention", rsqrt_approx, input_shape, output_shape) + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.hidden_size = hidden_size + self.seq_len = seq_len + + self.hidden_size = hidden_size + self.wq = FlexDense(n_examples, hidden_size, self.all_head_size, self.seq_len) + self.wk = FlexDense(n_examples, hidden_size, self.all_head_size, self.seq_len) + self.wv = FlexDense(n_examples, hidden_size, self.all_head_size, self.seq_len) + self.dropout = FlexDropout([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], alpha=dropout) # I think? # TODO: DROPOUT? + + self.output = BertOutput(internal_shape, hidden_size, hidden_size, seq_len, dropout, layernorm_eps, rsqrt_approx) + self.context = sfix.Tensor([internal_shape, self.seq_len, hidden_size]) + self.nabla_context = sfix.Tensor([internal_shape, self.seq_len, hidden_size]) + + # self.context_nabla + + self.attention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix) + self.nabla_attention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix) + self.nabla_preattention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix) + + def forward(self, batch=None, hidden_state=None, training=None): + N = len(batch) + + # set up layers + dense_layers = [self.wq, self.wk, self.wv] + for layer in dense_layers: + layer.X.address = self.X.address + + self.output.X.address = self.context.address + self.output.Y.address = self.Y.address + + self.wq.forward(batch) + self.wk.forward(batch) + self.wv.forward(batch) + + inc_batch = regint.Array(N) + inc_batch.assign(regint.inc(N)) + + print_ln("post forward") + + if self.debug_output: + # print_ln('forward layer wq full %s', self.wq.X.reveal()) + print_ln('forward layer wv %s %s', self.wv.Y[0][0][0:10].reveal(), sum(self.wv.Y[0][0].reveal())) + print_ln('forward layer hidden_state %s', hidden_state[0][1][0:10].reveal()) + # print_ln('forward layer wv full %s', self.wv.Y.reveal()) + + # max_size = program.budget // self.attention_head_size + @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads]) + def _(i, j): + # for j in range(self.num_attention_heads): + query_sub = sfix.Matrix(self.seq_len, self.attention_head_size) # this is mem inefficient? + key_sub = sfix.Matrix(self.seq_len, self.attention_head_size) + # print(self.wq.Y.shape, "wk Y shape", i, self.attention_head_size, j, self.wq.Y[i], self.wq.Y[i][:]) + + @for_range_opt(self.seq_len) + def _(k): + # for k in range(self.seq_len): + query_sub[k] = self.wq.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) + key_sub[k] = self.wk.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) + + # print_ln("query_sub %s %s", i, j) + res = query_sub.direct_mul_trans(key_sub) + self.attention_scores[i].assign_part_vector(res, j) + + if self.debug_output: + print_ln('forward layer attention_scores %s', self.attention_scores[0][0].reveal()) + # print_ln('forward layer attention_scores full %s', self.attention_scores.reveal()) + + @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads, self.seq_len]) + def _(i, j, k): + self.attention_scores[i][j][k][:] = self.attention_scores[i][j][k][:] / math.sqrt(self.attention_head_size) + self.attention_scores[i][j][k][:] = softmax(self.attention_scores[i][j][k][:]) + + self.dropout.X.address = self.attention_scores.address + self.dropout.forward(batch=inc_batch, training=training) + + if self.debug_output: + print_ln('forward layer dropout full %s', self.dropout.Y.reveal()) + + @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads]) + def _(i, j): + value_sub = sfix.Matrix(self.seq_len, self.attention_head_size) + + @for_range_opt([self.seq_len]) + def _(k): + value_sub[k] = self.wv.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) + # value_sub[k] = self.wv.Y[i][k][j * self.attention_head_size:(j + 1) * self.attention_head_size] + + res = sfix.Matrix(self.seq_len, self.attention_head_size) + res.assign_vector(self.dropout.Y[i][j].direct_mul(value_sub)) + # res = self.dropout.Y[i][j].direct_mul(value_sub) + + @for_range_opt([self.seq_len]) + def _(k): + self.context[i][k].assign_part_vector(res[k], + j * self.attention_head_size + ) + # for k in range(self.seq_len): + # self.context[i][k][j * self.attention_head_size:(j + 1) * self.attention_head_size] = res[k * self.attention_head_size:(k + 1) * self.attention_head_size] + + # How to transfer to forward? + + # missing half of the values ? + # print_ln('forward layer old_context %s', self.old_context[0].get_vector().reveal()) + if self.debug_output: + print_ln('forward layer multiheadattention before internal output %s', self.context[0][0][0:20].get_vector().reveal()) + + if self.debug_output: + print_ln('forward layer hidden_state %s', hidden_state[0][1][0:20].reveal()) + + self.output.forward(inc_batch, hidden_state, training, batch) + if self.debug_output: + print_ln('forward multiheadattention output %s', self.output.Y[0][0][0:20].reveal()) + print_ln("") + + # return context + + def reset(self): + self.wq.reset() + self.wk.reset() + self.wv.reset() + self.output.reset() + + def backward(self, compute_nabla_X=True, batch=None): + N = len(batch) + dense_layers = [self.wq, self.wk, self.wv] + for layer in dense_layers: + layer.nabla_Y.alloc() # we will fill them manually below + layer.nabla_X.alloc() # we have to add them up layer + + self.output.nabla_Y.address = self.nabla_Y.address + + self.output.nabla_X.alloc() + self.nabla_context.address = self.output.nabla_X.address + + self.nabla_attention_scores.address = self.dropout.nabla_X + + nabla_y_hidden_state = self.output.backward(True, batch) + + if self.debug_output: + print_ln("backward layer attention output.nabla_Y %s", self.output.nabla_Y.reveal_nested()[0][0][:8]) + print_ln("backward layer attention output.nabla_X %s", self.output.nabla_X.reveal_nested()[0][0][:8]) + + # Backprop context + @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads]) + def _(i, j): + res = sfix.Matrix(self.seq_len, self.attention_head_size) + value_sub = sfix.Matrix(self.seq_len, self.attention_head_size) + + @for_range_opt([self.seq_len]) + def _(k): + # dout_bth + res[k].assign_vector(self.nabla_context[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size)) # nabla_Y + # value_t2 + value_sub[k] = self.wv.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) + + nabla_value_sub = sfix.Matrix(self.seq_len, self.attention_head_size) + + # dvalue_t2 = dout_bth * att_bth + nabla_value_sub.assign_vector(self.dropout.Y[i][j].direct_trans_mul(res)) + # nabla_value_sub.assign_vector(self.context[i][j].direct_trans_mul(res)) + + # datt_bth = dout_bth * value_t2 + self.dropout.nabla_Y[i][j].assign_vector(res.direct_mul_trans(value_sub)) + + @for_range_opt([self.seq_len]) + def _(k): + # value_sub[k] = self.wv.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) + self.wv.nabla_Y[i][k].assign_part_vector( + nabla_value_sub[k], + j * self.attention_head_size) + + print("RES MULTI BACK", self.dropout.Y, res, self.num_attention_heads, self.attention_head_size) + + self.dropout.nabla_X.alloc() + self.dropout.backward(True, batch) + + if self.debug_output: + # Dropout nabla y is correct + # wv nabla_Y also correct + print_ln("backward layer attention dropout.nabla_Y %s", self.dropout.nabla_Y.reveal_nested()[:8]) + print_ln("backward layer attention wv.nabla_Y %s", self.wv.nabla_Y.reveal_nested()[:8]) + + # attention to pre + @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads, self.seq_len]) + def _(i, j, k): + @for_range_opt([self.seq_len, self.seq_len]) + def _(t1, t2): + indicator = cfix(t1 == t2) + # local_deriv = self.attention_scores[i][j][k][t1] * (indicator - self.attention_scores[i][j][k][t2]) + local_deriv = self.dropout.Y[i][j][k][t1] * (indicator - self.dropout.Y[i][j][k][t2]) + + # print_ln("indiciator %s %s %s %s %s %s", t1, t2, indicator, local_deriv.reveal(), self.dropout.Y[i][j][k][t1].reveal(), self.attention_scores[i][j][k][t2].reveal()) + self.nabla_preattention_scores[i][j][k][t2] += local_deriv * self.dropout.nabla_X[i][j][k][t1] # x or y? + + print_ln("attention_scores %s", self.attention_scores.reveal()) + print_ln("nabla preattention scores %s", self.nabla_preattention_scores.reveal()) + + scale = 1 / math.sqrt(self.attention_head_size) + # backward pass 1 + @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads]) + def _(i, j): + # for j in range(self.num_attention_heads): + query_sub = sfix.Matrix(self.seq_len, self.attention_head_size) + key_sub = sfix.Matrix(self.seq_len, self.attention_head_size) + # print(self.wq.Y.shape, "wk Y shape", i, self.attention_head_size, j, self.wq.Y[i], self.wq.Y[i][:]) + @for_range_opt(self.seq_len) + def _(k): # This mempcopy is ugly + query_sub[k] = self.wq.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) + key_sub[k] = self.wk.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) + + # nabla_query_sub = key_sub.direct_trans_mul(self.nabla_preattention_scores[i][j]) + # nabla_key_sub = self.nabla_preattention_scores[i][j].direct_mul_trans(key_sub) + + print_ln("preatt %s", self.nabla_preattention_scores[i][j].reveal()) + + nabla_query_sub = sfix.Matrix(self.seq_len, self.attention_head_size) + nabla_key_sub_trans = sfix.Matrix(self.attention_head_size, self.seq_len) + nabla_query_sub.assign_vector(self.nabla_preattention_scores[i][j].direct_trans_mul(key_sub) * scale) + nabla_key_sub_trans.assign_vector(query_sub.direct_trans_mul(self.nabla_preattention_scores[i][j]) * scale) + nabla_key_sub = nabla_key_sub_trans.transpose() + + print_ln("nabla query sub %s", nabla_query_sub.reveal()) + print_ln("nabla key sub %s", nabla_key_sub.reveal()) + + # nabla_key_sub is seq_len_seqlen, copy back into wk which is seq_len, all_head_size + @for_range_opt(self.seq_len) + def _(k): # This mempcopy is ugly? + self.wq.nabla_Y[i][k].assign_part_vector(nabla_query_sub[k], j * self.attention_head_size) + self.wk.nabla_Y[i][k].assign_part_vector(nabla_key_sub[k], j * self.attention_head_size) + + if self.debug_output: + print_ln("backward layer attention wq.nabla_Y %s", self.wq.nabla_Y.reveal_nested()[:8]) + + # wk slightly off + print_ln("backward layer attention wk.nabla_Y %s", self.wk.nabla_Y.reveal_nested()[:8]) + + self.wq.backward(compute_nabla_X, batch) + self.wk.backward(compute_nabla_X, batch) + self.wv.backward(compute_nabla_X, batch) + + @multithread(self.n_threads, len(batch)) + def _(base, size): + sum_layers = sum([layer.nabla_X.get_part_vector(base, size) for layer in dense_layers]) + self.nabla_X.assign_part_vector( + sum_layers, base) + + if self.debug_output: + # TODO: Wq seems off still + print_ln("backward layer attention wq.nabla_X %s", self.wq.nabla_X.reveal_nested()[:8]) + + return nabla_y_hidden_state + class Optimizer: """ Base class for graphs of layers. """ n_threads = Layer.n_threads @@ -2437,7 +3677,8 @@ class Optimizer: layer.backward(compute_nabla_X=False, batch=self.batch_for(layer, batch)) else: - layer.nabla_X.alloc() + if len(layer.inputs) == 1: + layer.nabla_X.alloc() layer.backward(batch=self.batch_for(layer, batch)) if len(layer.inputs) == 1: layer.inputs[0].nabla_Y.address = \ @@ -3330,8 +4571,23 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, return reduce(operator.mul, x) import torch + import torch.fx + + # Custom tracer to prevent inlining BERT layers + class BertTracer(torch.fx.Tracer): + def is_leaf_module(self, m, module_qualified_name): + # Treat BertLayer, BertPooler, and BertEmbeddings as leaf modules (don't trace into them) + type_name = type(m).__name__ + if any(x in type_name for x in ['BertLayer', 'BertPooler', 'BertEmbeddings']): + return True + return super().is_leaf_module(m, module_qualified_name) def process(item, inputs, input_shape, args, kwargs={}): + # Skip assertion and validation functions from torch.fx trace + if callable(item) and hasattr(item, '__name__') and ( + item.__name__.startswith('_assert') or + item.__name__ in ('eq', 'getitem', 'size')): + return if item == torch.cat: if len(inputs) > 1: layers.append( @@ -3387,18 +4643,22 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, elif name == 'Conv2d': layers.append(easyConv2d(input_shape, batch_size, item.out_channels, item.kernel_size, item.stride, - item.padding, **layer_args.get(item, {}))) + item.padding, item.bias is not None, **layer_args.get(item, {}))) input_shape = layers[-1].Y.shape if input_via is not None: - shapes = [x.shape for x in + if item.bias is not None: + shapes = [x.shape for x in (layers[-1].weights, layers[-1].bias)] + else: + shapes = [layers[-1].weights.shape] + print("shapes", shapes, layers[-1].weights.shape) import numpy swapped = numpy.moveaxis( numpy.array(item.weight.detach()), 1, -1) layers[-1].weights = \ layers[-1].weights.value_type.input_tensor_via( input_via, swapped) - assert layers[-1].weights.shape == shapes[0] + assert layers[-1].weights.shape == shapes[0], f"{layers[-1].weights.shape} != {shapes[0]}" if isinstance(item.bias, torch.Tensor): layers[-1].bias = sfix.input_tensor_via( input_via, item.bias.detach()) @@ -3429,7 +4689,7 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, layers.append(Relu(input_shape)) elif name == 'Flatten': return - elif name == 'BatchNorm2d': + elif name == 'BatchNorm2d' or name == 'BatchNorm1d': layers.append(BatchNorm(layers[-1].Y.sizes)) if input_via is not None: layers[-1].epsilon = item.eps @@ -3437,18 +4697,124 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, item.weight.detach()) layers[-1].bias = sfix.input_tensor_via(input_via, item.bias.detach()) + layers[-1].mu_hat = sfix.input_tensor_via( + input_via, item.running_mean.detach()) + layers[-1].var_hat = sfix.input_tensor_via( + input_via, item.running_var.detach()) elif name == 'Dropout': + alpha = item.p + if alpha == 0.1: + print('WARNING: dropout rate 0.1 not supported, using 0.125') + alpha = 0.125 layers.append(Dropout(input_shape[0], mul(layers[-1].Y.sizes[1:]), - alpha=item.p)) + alpha=alpha)) input_shape = layers[-1].Y.sizes + elif name == 'BertForSequenceClassification': + process(item.bert) + process(item.dropout) + process(item.classifier) + elif name == 'BertModel': + bert_config = item.config + process(item.embeddings) + process(item.encoder) + process(item.pooler) + elif name == 'BertEmbeddings': + print('Embedding layer not implemented.', item) + pass # no-op + elif name == 'BertEncoder': + for x in item.layer: + process(x) + elif name == 'BertLayer': + # Get config from the model or item + if 'bert_config' in locals(): + config = bert_config + elif hasattr(model, 'config'): + config = model.config + elif hasattr(item, 'config'): + config = item.config + else: + raise CompilerError('BertLayer requires config but none found in model or item') + hidden_state = config.hidden_size + intermediate_size = config.intermediate_size + num_attention_heads = config.num_attention_heads + layernorm_eps = config.layer_norm_eps + seq_len = input_shape[1] + rsqrt_approx = False + layer = BertLayer(input_shape[0], seq_len, hidden_state, intermediate_size, num_attention_heads, + layernorm_eps, 0.125, rsqrt_approx, batch_size=batch_size) + if input_via is not None: + layer.load_state_dict(item.state_dict(), input_via) + layers.append(layer) + input_shape = [batch_size, seq_len, hidden_state] + elif name == 'BertPooler': + # Get config from the model or item + if 'bert_config' in locals(): + config = bert_config + elif hasattr(model, 'config'): + config = model.config + elif hasattr(item, 'config'): + config = item.config + else: + raise CompilerError('BertPooler requires config but none found in model or item') + layer = BertPooler(input_shape[0], input_shape[1], config.hidden_size) + if input_via is not None: + layer.load_state_dict(item.state_dict(), input_via) + layers.append(layer) + elif name == "Identity": + return else: raise CompilerError('unknown PyTorch module: %s' % item) layers[-1].inputs = inputs input_shape = data_input_shape + [1] * (4 - len(data_input_shape)) - torch_layers = list(torch.fx.symbolic_trace(model).graph.nodes) + # torch_layers = list(torch.fx.symbolic_trace(model).graph.nodes) + # Use custom tracer to keep BERT layers as modules + tracer = BertTracer() + + # Determine concrete_args based on model type + # Check if this is BertModel or BertEncoder + import torch as torch_module + model_type = type(model).__name__ + if model_type == 'BertModel': + # BertModel requires both input_ids and token_type_ids in concrete_args + # None means they must be provided as positional args during trace + concrete_args = { + "attention_mask": None, + "token_type_ids": None, + "position_ids": None, + "head_mask": None, + "inputs_embeds": None, + "encoder_hidden_states": None, + "encoder_attention_mask": None, + "past_key_values": None, + "use_cache": None, + "output_attentions": False, + "output_hidden_states": False, + "return_dict": False, + } + else: + # BertEncoder and other modules + concrete_args = { + "attention_mask": None, + "head_mask": None, + "encoder_hidden_states": None, + "encoder_attention_mask": None, + "past_key_values": None, + "use_cache": None, + "output_attentions": False, + "output_hidden_states": False, + "return_dict": False, + } + + graph = tracer.trace(model, concrete_args=concrete_args) + torch_layers = list(graph.nodes) + print(torch_layers) for i, layer in enumerate(torch_layers[1:-1]): + # Skip non-module and non-function operations (like getitem, assertions, etc.) + if layer.op not in ('call_module', 'call_function'): + continue + if layer.op == 'call_module': target = model for attr in layer.target.split('.'): @@ -3456,7 +4822,8 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, else: target = layer.target if not layers: - assert layer.args == (torch_layers[i],) + print(f"First layer check: layer.args={layer.args}, torch_layers[i]={torch_layers[i]}") + # First real layer - no need for assertion inputs = None else: if len(layer.args) < 2 or (layer.args[1] != 1 and diff --git a/Programs/Source/bert_inference.mpc b/Programs/Source/bert_inference.mpc new file mode 100644 index 00000000..0e891322 --- /dev/null +++ b/Programs/Source/bert_inference.mpc @@ -0,0 +1,389 @@ +""" +BERT Inference in MP-SPDZ + +This program demonstrates secure multi-party computation (MPC) inference using a +pre-trained BERT model for sequence classification. It compares PyTorch and MP-SPDZ +implementations layer-by-layer and computes accuracy on the QNLI task from GLUE benchmark. + +The program: +1. Loads a pre-trained BERT-tiny model fine-tuned on QNLI +2. Converts it to MP-SPDZ representation using ml.layers_from_torch() +3. Runs inference on N samples from the validation set +4. Compares MP-SPDZ outputs with PyTorch outputs layer-by-layer +5. Computes and reports accuracy + +Usage: + ./Scripts/compile-run.py -E replicated-ring bert_inference + +Configuration: + - MODEL_NAME: HuggingFace model identifier + - MAX_LENGTH: Maximum sequence length for tokenization + - N_SAMPLES: Number of validation samples to run + - BATCH_SIZE: Batch size for MP-SPDZ inference +""" + +import ml +import torch +import torch.nn as nn +from transformers import BertForSequenceClassification, BertTokenizer +from datasets import load_dataset + +# ============================================================================ +# Configuration +# ============================================================================ + +MODEL_NAME = 'M-FAC/bert-tiny-finetuned-qnli' # BERT-tiny (2 layers, 128 hidden) +MAX_LENGTH = 64 # Maximum sequence length +N_SAMPLES = 10 # Number of samples to evaluate +BATCH_SIZE = 1 # Batch size for MPC inference (increase for better performance) + +# GLUE task configuration +TASK_NAME = 'qnli' +TASK_KEYS = { + "cola": ("sentence", None), + "mnli": ("premise", "hypothesis"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("question", "sentence"), + "qqp": ("question1", "question2"), + "rte": ("sentence1", "sentence2"), + "sst2": ("sentence", None), + "stsb": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + +# ============================================================================ +# Model Loading and Data Preparation +# ============================================================================ + +print(f"Loading model: {MODEL_NAME}") +tokenizer = BertTokenizer.from_pretrained(MODEL_NAME) +model = BertForSequenceClassification.from_pretrained(MODEL_NAME) +model.eval() + +print(f"Loading {TASK_NAME} dataset from GLUE benchmark") +dataset = load_dataset('glue', TASK_NAME) +validation = dataset['validation'].take(N_SAMPLES) + +print(f"Configuration:") +print(f" Model: {MODEL_NAME}") +print(f" Task: {TASK_NAME}") +print(f" Samples: {N_SAMPLES}") +print(f" Max length: {MAX_LENGTH}") +print(f" Batch size: {BATCH_SIZE}") +print(f" Model architecture: {model.config.num_hidden_layers} layers, " + f"{model.config.hidden_size} hidden size") + + +def tokenize_dataset(example): + """Tokenize dataset examples based on task configuration.""" + sentence1_key, sentence2_key = TASK_KEYS[TASK_NAME] + args = ( + (example[sentence1_key],) if sentence2_key is None + else (example[sentence1_key], example[sentence2_key]) + ) + return tokenizer(*args, truncation=True, padding='max_length', max_length=MAX_LENGTH) + + +def embed_inputs(example): + """Convert tokenized inputs to BERT embeddings.""" + input_ids = torch.tensor(example["input_ids"]) + token_type_ids = torch.tensor(example["token_type_ids"]) + embedding = model.bert.embeddings(input_ids, token_type_ids=token_type_ids).detach() + return {'embedding': embedding} + + +# Tokenize and embed the validation data +print("Tokenizing and embedding validation data...") +tokenized_data = validation.map(tokenize_dataset, batched=True) +embedded_data = tokenized_data.map(embed_inputs, batched=True) + +# ============================================================================ +# PyTorch Inference (Ground Truth) +# ============================================================================ + +print("\nRunning PyTorch inference for ground truth...") + + +def run_pytorch_inference(model, dataset, n_samples): + """Run inference using PyTorch and collect predictions.""" + model.eval() + predictions = [] + probabilities = [] + labels = [] + + with torch.no_grad(): + for i in range(n_samples): + example = dataset[i] + inputs = { + key: torch.tensor([val]) + for key, val in example.items() + if key in ['input_ids', 'attention_mask', 'token_type_ids'] + } + print("PT Inputs", inputs) + + outputs = model(**inputs) + logits = outputs.logits + probs = torch.softmax(logits, dim=-1) + predicted = torch.argmax(logits, dim=-1).item() + + predictions.append(predicted) + probabilities.append(probs.detach()) + labels.append(example['label']) + + return predictions, probabilities, labels + + +pt_predictions, pt_probabilities, true_labels = run_pytorch_inference(model, tokenized_data, N_SAMPLES) +pt_accuracy = sum(p == l for p, l in zip(pt_predictions, true_labels)) / len(true_labels) +print(f"PyTorch accuracy: {pt_accuracy:.4f} ({sum(p == l for p, l in zip(pt_predictions, true_labels))}/{len(true_labels)})") + +# ============================================================================ +# MP-SPDZ Model Conversion +# ============================================================================ + +print("\nConverting BERT model to MP-SPDZ...") + + +class BertEncoderWithHead(nn.Module): + """Wrapper combining BERT encoder, pooler, dropout, and classification head.""" + + def __init__(self, encoder, pooler, dropout, classifier, config): + super().__init__() + self.encoder = encoder + self.pooler = pooler + self.dropout = dropout + self.classifier = classifier + self.config = config + + def forward(self, hidden_states, attention_mask=None, head_mask=None, + encoder_hidden_states=None, encoder_attention_mask=None, + past_key_values=None, use_cache=None, output_attentions=False, + output_hidden_states=False, return_dict=False): + """Forward pass through encoder, pooler, dropout, and classifier.""" + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + return logits + + +# Build MPC-compatible input tensors +def build_mpc_tensors(dataset): + """Convert dataset to MPC-compatible sfix tensors.""" + with dataset.formatted_as("torch", ["embedding", "label"]): + embeddings = torch.concat(list(map(lambda x: x['embedding'], dataset.iter(batch_size=1)))) + labels = torch.tensor([x['label'] for x in dataset.iter(batch_size=1)]) + + # One-hot encode labels (2 classes for QNLI) + labels_onehot = torch.nn.functional.one_hot(labels, num_classes=2) + + embeddings_sfix = sfix.input_tensor_via(0, embeddings.numpy()) + labels_sfix = sfix.input_tensor_via(0, labels_onehot.numpy()) + + return embeddings_sfix, labels_sfix, labels.numpy() + + +test_embeddings, test_labels_onehot, test_labels = build_mpc_tensors(embedded_data) +model_shape = test_embeddings.shape +print(f"Input shape: {model_shape}") + +# Wrap model for conversion +bert_wrapped = BertEncoderWithHead( + model.bert.encoder, + model.bert.pooler, + model.dropout, + model.classifier, + model.config +) + +# Convert to MP-SPDZ layers +print("Tracing model and converting to MP-SPDZ layers...") +mpc_layers = ml.layers_from_torch(bert_wrapped, model_shape, input_via=0, batch_size=BATCH_SIZE) +print(f"MP-SPDZ model: {len(mpc_layers)} top-level layers") +for i, layer in enumerate(mpc_layers): + print(f" Layer {i}: {layer}") + +# ============================================================================ +# MP-SPDZ Inference +# ============================================================================ + +print("\nRunning MP-SPDZ inference...") + +# Configure fixed-point arithmetic +sfix.round_nearest = False +program.use_trunc_pr = False + +# Create optimizer (used for forward pass) +optimizer = ml.SGD(mpc_layers) + +# Run inference using optimizer.eval() to get predictions +print_ln("\n=== Starting MPC Inference ===") +print_ln("Samples: %s", N_SAMPLES) +print_ln("Batch size: %s", BATCH_SIZE) + +# Convert Python lists to compile-time constants +pt_preds_list = [int(p) for p in pt_predictions] +true_labels_list = [int(l) for l in true_labels] + +# Use optimizer.eval() to get MPC predictions (argmax) +print_ln("Running MPC inference...") +mpc_predictions = optimizer.eval(test_embeddings, batch_size=BATCH_SIZE, top=True) + +print_ln("\n=== Per-Sample Comparison ===") +print_ln("Sample | True Label | PyTorch Pred | MPC Pred | PT Correct | MPC Correct | Match") +print_ln("-" * 80) + +# Track statistics +n_correct = MemValue(regint(0)) +n_mpc_matches_pytorch = MemValue(regint(0)) + +# Use regular Python for loop to access compile-time constants +for i in range(N_SAMPLES): + # Get predictions + mpc_pred = mpc_predictions[i].reveal() + true_label = true_labels_list[i] + pt_pred = pt_preds_list[i] + + # Check correctness + mpc_correct = cint(mpc_pred == true_label) + pt_correct = cint(pt_pred == true_label) + predictions_match = cint(mpc_pred == pt_pred) + + # Update statistics + n_correct.iadd(mpc_correct) + n_mpc_matches_pytorch.iadd(predictions_match) + + # Print per-sample results + print_ln("%s | %s | %s | %s | %s | %s | %s", + i, true_label, pt_pred, mpc_pred, + pt_correct, mpc_correct, predictions_match) + +# Compute final statistics +mpc_accuracy = cfix(n_correct.read(), k=63, f=31) / N_SAMPLES +match_rate = cfix(n_mpc_matches_pytorch.read(), k=63, f=31) / N_SAMPLES + +print_ln("\n=== Results Summary ===") +print_ln("PyTorch Accuracy: %s", pt_accuracy) +print_ln("MP-SPDZ Correct: %s/%s", n_correct.read(), N_SAMPLES) +print_ln("MP-SPDZ Accuracy: %s", mpc_accuracy.reveal()) +print_ln("MPC-PyTorch Agreement: %s/%s = %s", + n_mpc_matches_pytorch.read(), N_SAMPLES, match_rate.reveal()) + +# ============================================================================ +# Layer-by-Layer Comparison using Forward Hooks +# ============================================================================ + +print_ln("\n=== Layer-by-Layer Comparison ===") + +# Map to store PyTorch activations +activation_map = {} + +def get_activation(name): + """Create a forward hook to capture layer outputs.""" + def hook(model, input, output): + if isinstance(output, tuple): + actual_output = output[0] + else: + actual_output = output + activation_map[name] = actual_output.detach() + return hook + +# Build layer comparison list +def layers_for_bertlayer(bert_layer_mpc, bert_layer_pt): + """Map MPC BertLayer components to PyTorch components.""" + return [ + (bert_layer_mpc.multi_head_attention, bert_layer_pt.attention), + (bert_layer_mpc.intermediate, bert_layer_pt.intermediate), + (bert_layer_mpc.output, bert_layer_pt.output), + (bert_layer_mpc, bert_layer_pt), + ] + +# Build complete layer comparison list +layers_to_compare = [layers_for_bertlayer(l1, l2) for l1, l2 in + zip(mpc_layers[:-4], model.bert.encoder.layer)] +layers_to_compare = [x for xs in layers_to_compare for x in xs] +layers_to_compare.append((mpc_layers[-4], model.bert.pooler)) +layers_to_compare.append((mpc_layers[-3], model.dropout)) +layers_to_compare.append((mpc_layers[-2], model.classifier)) + +# Register forward hooks +for layer_id, (_, pt_layer) in enumerate(layers_to_compare): + pt_layer.register_forward_hook(get_activation(f'{layer_id}.{type(pt_layer).__name__}')) + +# Run PyTorch forward pass to populate activation_map +print("Capturing PyTorch layer outputs...") +with torch.no_grad(): + for i in range(N_SAMPLES): + activation_map.clear() # Clear for each sample + + # Get sample embedding + with embedded_data.formatted_as("torch", ["embedding"]): + sample_embedding = embedded_data[i]['embedding'].unsqueeze(0) + + # Run forward through wrapped model + _ = bert_wrapped(sample_embedding) + + # Store activations for this sample + if i == 0: # Only compare first sample to save time + break + +print(f"Captured {len(activation_map)} layer outputs from PyTorch") + +# Run MPC forward pass using reveal_correctness +import numpy +pt_probs_tensor = numpy.array(numpy.concatenate([p.numpy() for p in pt_probabilities])) +pt_probabilities_sfix = sfix.input_tensor_via(0, pt_probs_tensor) + +test_embeddings_one = sfix.Tensor([1] + list(test_embeddings.sizes[1:])) +test_embeddings_one.assign(test_embeddings.get_part_vector(0)) + +pt_probabilities_sfix_one = sfix.Tensor([1] + list(pt_probabilities_sfix.sizes[1:])) +pt_probabilities_sfix_one.assign(pt_probabilities_sfix.get_part_vector(0)) + +print_ln("Running MPC forward pass for layer comparison...") +_ = optimizer.reveal_correctness(test_embeddings_one, pt_probabilities_sfix_one, batch_size=BATCH_SIZE) + +# Compare layers +print_ln("\nLayer-by-layer comparison (Sample 0 only):") +print_ln("Layer | Total Absolute Difference | First 8 Values") +print_ln("-" * 80) + +for idx, (mpc_layer, pt_layer) in enumerate(layers_to_compare): + layer_id = f"{idx}.{type(pt_layer).__name__}" + + if layer_id not in activation_map: + continue + + # Skip dropout layers since they use different random masks + if 'Dropout' in type(pt_layer).__name__: + print_ln("%s | Skipped (dropout)", layer_id) + continue + + # Get PyTorch values + pt_values = activation_map[layer_id] + pt_at_runtime = sfix.input_tensor_via(0, pt_values.numpy()).get_vector().reveal() + + # Get MPC values + mpc_output = mpc_layer.Y[0].get_vector().reveal() + + # Compute sum of absolute differences + diff = sum(abs(pt_at_runtime - mpc_output)) + + # Print layer comparison with first 8 values + print_ln("%s | Diff: %s", layer_id, diff) + print_ln(" PyTorch: %s", pt_at_runtime[:8]) + print_ln(" MP-SPDZ: %s", mpc_output[:8]) + +print_ln("\n=== Inference Complete ===") From d0f955c2d3a98f9e12589a73b4c112a014fab488 Mon Sep 17 00:00:00 2001 From: Hidde L Date: Tue, 14 Oct 2025 21:43:18 +0200 Subject: [PATCH 2/8] avg diff, dropout --- Compiler/ml.py | 19 ++++++++----------- Programs/Source/bert_inference.mpc | 20 +++----------------- 2 files changed, 11 insertions(+), 28 deletions(-) diff --git a/Compiler/ml.py b/Compiler/ml.py index 808721ee..58756439 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -1285,14 +1285,14 @@ class FlexDropout(NoVariableLayer): n_bits = -math.log(self.alpha, 2) assert n_bits == int(n_bits) n_bits = int(n_bits) - self.B.assign_all(1) - self.alpha = 0.0 # TODO: temp disable for reproducibility - # @for_range_opt_multithread(self.n_threads, len(batch)) - # def _(i): - # size = reduce(operator.mul, self.shape[1:]) - # self.B[i].assign_vector(util.tree_reduce( - # util.or_op, (sint.get_random_bit(size=size) - # for i in range(n_bits)))) + # self.B.assign_all(1) + # self.alpha = 0.0 # TODO: temp disable for reproducibility + @for_range_opt_multithread(self.n_threads, len(batch)) + def _(i): + size = reduce(operator.mul, self.shape[1:]) + self.B[i].assign_vector(util.tree_reduce( + util.or_op, (sint.get_random_bit(size=size) + for i in range(n_bits)))) @for_range_opt_multithread(self.n_threads, len(batch)) def _(i): self.Y[i].assign_vector(1 / (1 - self.alpha) * @@ -2889,7 +2889,6 @@ class BertPooler(BertBase): # batch contains [n_batch, n_heads, n_dim] @for_range(len(batch)) def _(j): - print_ln("Pooling %s %s", j, batch[j]) self.dense.X[j][:] = self.X[batch[j]][0][:] # if self.debug_output: @@ -3266,8 +3265,6 @@ class MultiHeadAttention(BertBase): inc_batch = regint.Array(N) inc_batch.assign(regint.inc(N)) - print_ln("post forward") - if self.debug_output: # print_ln('forward layer wq full %s', self.wq.X.reveal()) print_ln('forward layer wv %s %s', self.wv.Y[0][0][0:10].reveal(), sum(self.wv.Y[0][0].reveal())) diff --git a/Programs/Source/bert_inference.mpc b/Programs/Source/bert_inference.mpc index 0e891322..170272d3 100644 --- a/Programs/Source/bert_inference.mpc +++ b/Programs/Source/bert_inference.mpc @@ -1,10 +1,6 @@ """ BERT Inference in MP-SPDZ -This program demonstrates secure multi-party computation (MPC) inference using a -pre-trained BERT model for sequence classification. It compares PyTorch and MP-SPDZ -implementations layer-by-layer and computes accuracy on the QNLI task from GLUE benchmark. - The program: 1. Loads a pre-trained BERT-tiny model fine-tuned on QNLI 2. Converts it to MP-SPDZ representation using ml.layers_from_torch() @@ -12,14 +8,6 @@ The program: 4. Compares MP-SPDZ outputs with PyTorch outputs layer-by-layer 5. Computes and reports accuracy -Usage: - ./Scripts/compile-run.py -E replicated-ring bert_inference - -Configuration: - - MODEL_NAME: HuggingFace model identifier - - MAX_LENGTH: Maximum sequence length for tokenization - - N_SAMPLES: Number of validation samples to run - - BATCH_SIZE: Batch size for MP-SPDZ inference """ import ml @@ -34,7 +22,7 @@ from datasets import load_dataset MODEL_NAME = 'M-FAC/bert-tiny-finetuned-qnli' # BERT-tiny (2 layers, 128 hidden) MAX_LENGTH = 64 # Maximum sequence length -N_SAMPLES = 10 # Number of samples to evaluate +N_SAMPLES = 1 # Number of samples to evaluate BATCH_SIZE = 1 # Batch size for MPC inference (increase for better performance) # GLUE task configuration @@ -141,8 +129,6 @@ print(f"PyTorch accuracy: {pt_accuracy:.4f} ({sum(p == l for p, l in zip(pt_pred # MP-SPDZ Model Conversion # ============================================================================ -print("\nConverting BERT model to MP-SPDZ...") - class BertEncoderWithHead(nn.Module): """Wrapper combining BERT encoder, pooler, dropout, and classification head.""" @@ -278,7 +264,7 @@ print_ln("\n=== Results Summary ===") print_ln("PyTorch Accuracy: %s", pt_accuracy) print_ln("MP-SPDZ Correct: %s/%s", n_correct.read(), N_SAMPLES) print_ln("MP-SPDZ Accuracy: %s", mpc_accuracy.reveal()) -print_ln("MPC-PyTorch Agreement: %s/%s = %s", +print_ln("MPC-PyTorch Match: %s/%s = %s", n_mpc_matches_pytorch.read(), N_SAMPLES, match_rate.reveal()) # ============================================================================ @@ -382,7 +368,7 @@ for idx, (mpc_layer, pt_layer) in enumerate(layers_to_compare): diff = sum(abs(pt_at_runtime - mpc_output)) # Print layer comparison with first 8 values - print_ln("%s | Diff: %s", layer_id, diff) + print_ln("%s | Avg. Diff: %s", layer_id, diff / sum(pt_values.shape)) print_ln(" PyTorch: %s", pt_at_runtime[:8]) print_ln(" MP-SPDZ: %s", mpc_output[:8]) From da0e4420d1c2c6da50a1575663f2d678922ca329 Mon Sep 17 00:00:00 2001 From: Hidde L Date: Tue, 14 Oct 2025 22:38:20 +0200 Subject: [PATCH 3/8] Update n_samples and output in test script --- Programs/Source/bert_inference.mpc | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/Programs/Source/bert_inference.mpc b/Programs/Source/bert_inference.mpc index 170272d3..39eec8f1 100644 --- a/Programs/Source/bert_inference.mpc +++ b/Programs/Source/bert_inference.mpc @@ -22,7 +22,7 @@ from datasets import load_dataset MODEL_NAME = 'M-FAC/bert-tiny-finetuned-qnli' # BERT-tiny (2 layers, 128 hidden) MAX_LENGTH = 64 # Maximum sequence length -N_SAMPLES = 1 # Number of samples to evaluate +N_SAMPLES = 25 # Number of samples to evaluate BATCH_SIZE = 1 # Batch size for MPC inference (increase for better performance) # GLUE task configuration @@ -343,8 +343,7 @@ _ = optimizer.reveal_correctness(test_embeddings_one, pt_probabilities_sfix_one, # Compare layers print_ln("\nLayer-by-layer comparison (Sample 0 only):") -print_ln("Layer | Total Absolute Difference | First 8 Values") -print_ln("-" * 80) +print_ln("=" * 100) for idx, (mpc_layer, pt_layer) in enumerate(layers_to_compare): layer_id = f"{idx}.{type(pt_layer).__name__}" @@ -364,12 +363,16 @@ for idx, (mpc_layer, pt_layer) in enumerate(layers_to_compare): # Get MPC values mpc_output = mpc_layer.Y[0].get_vector().reveal() - # Compute sum of absolute differences - diff = sum(abs(pt_at_runtime - mpc_output)) + # Compute detailed statistics + total_abs_diff = sum(abs(pt_at_runtime - mpc_output)) + pt_magnitude = sum(abs(pt_at_runtime)) - # Print layer comparison with first 8 values - print_ln("%s | Avg. Diff: %s", layer_id, diff / sum(pt_values.shape)) - print_ln(" PyTorch: %s", pt_at_runtime[:8]) - print_ln(" MP-SPDZ: %s", mpc_output[:8]) + # Print layer comparison + print_ln("\n%s", layer_id) + print_ln(" Shape: %s, Elements: %s", pt_values.shape, len(pt_at_runtime)) + print_ln(" Total Abs Diff: %s", total_abs_diff) + print_ln(" PT Total Magnitude: %s", pt_magnitude) + print_ln(" First 8 PT: %s", pt_at_runtime[:8]) + print_ln(" First 8 MPC: %s", mpc_output[:8]) print_ln("\n=== Inference Complete ===") From 2162b79b7315a141fe91e13c7578540e78465eed Mon Sep 17 00:00:00 2001 From: Hidde L Date: Mon, 20 Oct 2025 11:14:38 +0200 Subject: [PATCH 4/8] Revert unrelated changes --- Compiler/ml.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Compiler/ml.py b/Compiler/ml.py index 58756439..26f95f82 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -290,7 +290,7 @@ class Layer: return type(self).__name__ + str(self._Y.shape) def __repr__(self): - return '%s(%s)' % (type(self).__name__, self.Y.shape) + return '%s(%s)' % (type(self).__name__, self._Y.shape) class NoVariableLayer(Layer): input_from = lambda *args, **kwargs: None @@ -467,7 +467,7 @@ class LinearOutput(OutputBase): class MultiOutputBase(NoVariableLayer): def __init__(self, N, d_out, approx=False, debug=False): self.X = sfix.Matrix(N, d_out) - self.Y = sfix.Matrix(N, d_out) + self.Y = sint.Matrix(N, d_out) self.nabla_X = sfix.Matrix(N, d_out) self.l = MemValue(sfix(-1)) self.losses = sfix.Array(N) From 9ac22e3650f006336645f298eeae8b2114d1d75a Mon Sep 17 00:00:00 2001 From: Hidde L Date: Tue, 21 Oct 2025 15:45:07 +0200 Subject: [PATCH 5/8] fix flexdense --- Compiler/ml.py | 61 +++++++++++++++++++++++++++++--------------------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/Compiler/ml.py b/Compiler/ml.py index 26f95f82..9101a57c 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -978,31 +978,28 @@ class FlexDense(Dense): # flattened_array version result_matrix = sfix.Matrix(N * self.d, self.d_out, address=prod.address) - max_size = get_program().budget // self.d_out + max_size = get_program().budget - # for now we assume that batch size is total size - # assert N == self.N - # batch contains the indices of the batches in self.N, we want to expand to have self.d too - # batch_with_d = + # X is stored at full dataset indices, batch specifies which samples to use + X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address) + + # Precompute batch_d_indices for all N*d elements + # For each sample in batch, expand to d consecutive indices + batch_d_indices = regint.Array(N * self.d) + @for_range(N) + def _(i): + actual_sample = batch[i] + batch_d_indices.assign(regint.inc(self.d, actual_sample * self.d), i * self.d) + # @for_range(self.d) + # def _(d_idx): + # batch_d_indices[i * self.d + d_idx] = actual_sample * self.d + d_idx - # we are going to assume the batch is continuous - batch_0 = MemValue(batch[0]) @multithread(self.n_threads, N * self.d, max_size) def _(base, size): - batch_offset = batch_0 * self.d - X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address) - offset = regint.inc(size, base=base + batch_offset) - # array_offset = regint.Array(size) - # array_offset.assign_all(batch_offset) - # print_ln("array offset %s", array_offset.reveal()) - # offset[:] += array_offset[:] - # print_ln("total offset %s %s", batch_offset, offset) - result_matrix.assign_part_vector( X_sub.direct_mul(self.W, indices=( - offset, regint.inc(self.d_in), + batch_d_indices.get_vector(base, size), regint.inc(self.d_in), regint.inc(self.d_in), regint.inc(self.d_out))), base) - # print_ln("result matrix done") if self.input_bias: if self.d_out == 1: @@ -1044,14 +1041,12 @@ class FlexDense(Dense): nabla_X.alloc() # flattened matrix version - max_size = get_program().budget // self.d_in result_matrix = sfix.Matrix(N * self.d, self.d_in, address=nabla_X.address) - batch_0 = MemValue(batch[0]) - @multithread(self.n_threads, N * self.d, max_size) + # Note: f_schur_Y is stored at indices [0, 1, ..., N-1] not at actual batch indices + @multithread(self.n_threads, N * self.d) def _(base, size): - batch_offset = batch_0 * self.d - X_sub = sfix.Matrix(self.N * self.d, self.d_out, address=f_schur_Y.address) - offset = regint.inc(size, base=base + batch_offset) + X_sub = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address) + offset = regint.inc(size, base=base) result_matrix.assign_part_vector( X_sub.direct_mul_trans(self.W, indices=( @@ -1074,14 +1069,28 @@ class FlexDense(Dense): tmp = Matrix(self.d_in, self.d_out, unreduced_sfix) # tmp.assign_all(0) + # A (f_schur_Y/nabla_Y) is stored at sequential batch indices [0, 1, ..., N-1] A = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address) + # B (X) is stored at the full dataset size, not just the batch B = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address) @multithread(self.n_threads, self.d_in) def _(base, size): + # For A: use sequential indices [0, 1, ..., N*d-1] + # For B: use actual batch indices expanded for d dimension + batch_d_indices = regint.Array(N * self.d) + @for_range(N) + def _(i): + # batch[i] gives the actual sample index in the full dataset + # For FlexDense with d>1, we need to map to flattened indices + actual_sample_idx = batch[i] + @for_range(self.d) + def _(d_idx): + batch_d_indices[i * self.d + d_idx] = actual_sample_idx * self.d + d_idx + mp = B.direct_trans_mul(A, reduce=False, indices=(regint.inc(size, base), - regint.inc(N * self.d), # Not sure + batch_d_indices.get_vector(), regint.inc(N * self.d), regint.inc(self.d_out))) @@ -4612,7 +4621,7 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, if name == 'Linear': assert mul(input_shape[1:]) == item.in_features assert item.bias is not None - layers.append(Dense(input_shape[0], item.in_features, + layers.append(FlexDense(input_shape[0], item.in_features, item.out_features)) if input_via is not None: shapes = [x.shape for x in (layers[-1].W, layers[-1].b)] From ba2b4407122c69ff5d7480313e700793150d00dd Mon Sep 17 00:00:00 2001 From: Hidde L Date: Tue, 21 Oct 2025 15:55:20 +0200 Subject: [PATCH 6/8] Merge FlexDense with Dense --- Compiler/ml.py | 534 +++++++++++++++++-------------------------------- 1 file changed, 183 insertions(+), 351 deletions(-) diff --git a/Compiler/ml.py b/Compiler/ml.py index 9101a57c..5ac85e03 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -725,350 +725,6 @@ class DenseBase(Layer): N = len(batch) tmp = Matrix(self.d_in, self.d_out, unreduced_sfix) - A = sfix.Matrix(N, self.d_out, address=f_schur_Y.address) - B = sfix.Matrix(self.N, self.d_in, address=self.X.address) - - @multithread(self.n_threads, self.d_in) - def _(base, size): - mp = B.direct_trans_mul(A, reduce=False, - indices=(regint.inc(size, base), - batch.get_vector(), - regint.inc(N), - regint.inc(self.d_out))) - tmp.assign_part_vector(mp, base) - - progress('nabla W (matmul)') - - @multithread(self.n_threads, self.d_in * self.d_out, - max_size=get_program().budget) - def _(base, size): - self.nabla_W.assign_vector( - tmp.get_vector(base, size).reduce_after_mul(), base=base) - - if self.print_random_update: - print_ln('backward %s', self) - i = regint.get_random(64) % self.d_in - j = regint.get_random(64) % self.d_out - print_ln('%s at (%s, %s): before=%s after=%s A=%s B=%s', - str(self.nabla_W), i, j, tmp[i][j].v.reveal(), - self.nabla_W[i][j].reveal(), - A.get_column(j).reveal(), - B.get_column_by_row_indices( - batch.get_vector(), i).reveal()) - print_ln('batch=%s B=%s', batch, - [self.X[bi][0][i].reveal() for bi in batch]) - - progress('nabla W') - - self.nabla_b.assign_vector(sum(sum(f_schur_Y[k][j].get_vector() - for k in range(N)) - for j in range(self.d))) - - progress('nabla b') - - if self.debug_output: - print_ln('dense nabla Y %s', self.nabla_Y.reveal_nested()) - print_ln('dense W %s', self.W.reveal_nested()) - print_ln('dense nabla X %s', self.nabla_X.reveal_nested()) - if self.debug: - limit = N * self.debug - @for_range_opt(self.d_in) - def _(i): - @for_range_opt(self.d_out) - def _(j): - to_check = self.nabla_W[i][j].reveal() - check = sum(to_check > limit) + sum(to_check < -limit) - @if_(check) - def _(): - print_ln('nabla W %s %s %s: %s', i, j, self.W.sizes, to_check) - print_ln('Y %s', [f_schur_Y[k][0][j].reveal() - for k in range(N)]) - print_ln('X %s', [self.X[k][0][i].reveal() - for k in range(N)]) - @for_range_opt(self.d_out) - def _(j): - to_check = self.nabla_b[j].reveal() - check = sum(to_check > limit) + sum(to_check < -limit) - @if_(check) - def _(): - print_ln('nabla b %s %s: %s', j, len(self.b), to_check) - print_ln('Y %s', [f_schur_Y[k][0][j].reveal() - for k in range(N)]) - @for_range_opt(len(batch)) - def _(i): - to_check = self.nabla_X[i].get_vector().reveal() - check = sum(to_check > limit) + sum(to_check < -limit) - @if_(check) - def _(): - print_ln('X %s %s', i, self.X[i].reveal_nested()) - print_ln('Y %s %s', i, f_schur_Y[i].reveal_nested()) - -class Dense(DenseBase): - """ Fixed-point dense (matrix multiplication) layer. - - :param N: number of examples - :param d_in: input dimension - :param d_out: output dimension - """ - def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False): - if activation == 'id': - self.activation_layer = None - elif activation == 'relu': - self.activation_layer = Relu([N, d, d_out]) - elif activation == 'square': - self.activation_layer = Square([N, d, d_out]) - else: - raise CompilerError('activation not supported: %s', activation) - - self.N = N - self.d_in = d_in - self.d_out = d_out - self.d = d - self.activation = activation - - self.X = Tensor([N, d, d_in], sfix) - self.Y = Tensor([N, d, d_out], sfix) - self.W = Tensor([d_in, d_out], sfix) - self.b = sfix.Array(d_out) - - back_N = min(N, self.back_batch_size) - self.nabla_Y = Tensor([back_N, d, d_out], sfix) - self.nabla_X = Tensor([back_N, d, d_in], sfix) - self.nabla_W = Tensor([d_in, d_out], sfix) - self.nabla_b = sfix.Array(d_out) - - self.debug = debug - - l = self.activation_layer - if l: - self.f_input = l.X - l.Y = self.Y - l.nabla_Y = self.nabla_Y - else: - self.f_input = self.Y - - def __repr__(self): - return '%s(%s, %s, %s, activation=%s)' % \ - (type(self).__name__, self.N, self.d_in, - self.d_out, repr(self.activation)) - - def reset(self): - d_in = self.d_in - d_out = self.d_out - r = math.sqrt(6.0 / (d_in + d_out)) - print('Initializing dense weights in [%f,%f]' % (-r, r)) - self.W.randomize(-r, r, n_threads=self.n_threads) - self.b.assign_all(0) - - def input_from(self, player, **kwargs): - self.W.input_from(player, **kwargs) - if self.input_bias: - self.b.input_from(player, **kwargs) - - def compute_f_input(self, batch): - N = len(batch) - assert self.d == 1 - if self.input_bias: - prod = MultiArray([N, self.d, self.d_out], sfix) - else: - prod = self.f_input - max_size = get_program().budget - @multithread(self.n_threads, N, max_size) - def _(base, size): - X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address) - prod.assign_part_vector( - X_sub.direct_mul(self.W, indices=( - batch.get_vector(base, size), regint.inc(self.d_in), - regint.inc(self.d_in), regint.inc(self.d_out))), base) - - if self.input_bias: - if self.d_out == 1: - @multithread(self.n_threads, N) - def _(base, size): - v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size) - self.f_input.assign_vector(v, base) - else: - @for_range_multithread(self.n_threads, 100, N) - def _(i): - v = prod[i].get_vector() + self.b.get_vector() - self.f_input[i].assign_vector(v) - progress('f input') - - def _forward(self, batch=None): - if batch is None: - batch = regint.Array(self.N) - batch.assign(regint.inc(self.N)) - self.compute_f_input(batch=batch) - if self.activation_layer: - self.activation_layer.forward(batch, training=True) - if self.debug_output: - print_ln('dense X %s', self.X.reveal_nested()) - print_ln('dense W %s', self.W.reveal_nested()) - print_ln('dense b %s', self.b.reveal_nested()) - print_ln('dense Y %s', self.Y.reveal_nested()) - if self.debug: - limit = self.debug - @for_range_opt(len(batch)) - def _(i): - @for_range_opt(self.d_out) - def _(j): - to_check = self.Y[i][0][j].reveal() - check = to_check > limit - @if_(check) - def _(): - print_ln('dense Y %s %s %s %s', i, j, self.W.sizes, to_check) - print_ln('X %s', self.X[i].reveal_nested()) - print_ln('W %s', - [self.W[k][j].reveal() for k in range(self.d_in)]) - - def backward(self, compute_nabla_X=True, batch=None): - N = len(batch) - d = self.d - d_out = self.d_out - X = self.X - Y = self.Y - W = self.W - b = self.b - nabla_X = self.nabla_X - nabla_Y = self.nabla_Y - nabla_W = self.nabla_W - nabla_b = self.nabla_b - - if self.activation_layer: - self.activation_layer.backward(batch) - f_schur_Y = self.activation_layer.nabla_X - else: - f_schur_Y = nabla_Y - - if compute_nabla_X: - nabla_X.alloc() - @multithread(self.n_threads, N) - def _(base, size): - B = sfix.Matrix(N, d_out, address=f_schur_Y.address) - nabla_X.assign_part_vector( - B.direct_mul_trans(W, indices=(regint.inc(size, base), - regint.inc(self.d_out), - regint.inc(self.d_out), - regint.inc(self.d_in))), - base) - - if self.print_random_update: - print_ln('backward %s', self) - index = regint.get_random(64) % self.nabla_X.total_size() - print_ln('%s nabla_X at %s: %s', str(self.nabla_X), - index, self.nabla_X.to_array()[index].reveal()) - - progress('nabla X') - - self.backward_params(f_schur_Y, batch=batch) - - -class FlexDense(Dense): - """ Fixed-point dense (matrix multiplication) layer with flexible number of dimensions. - Behaves like torch's nn.Linear which loops over the additional dimensions. - - :param N: number of examples - :param d_in: input dimension - :param d_out: output dimension - """ - - def compute_f_input(self, batch): - N = len(batch) - prod = MultiArray([N, self.d, self.d_out], sfix) - - # flattened_array version - result_matrix = sfix.Matrix(N * self.d, self.d_out, address=prod.address) - max_size = get_program().budget - - # X is stored at full dataset indices, batch specifies which samples to use - X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address) - - # Precompute batch_d_indices for all N*d elements - # For each sample in batch, expand to d consecutive indices - batch_d_indices = regint.Array(N * self.d) - @for_range(N) - def _(i): - actual_sample = batch[i] - batch_d_indices.assign(regint.inc(self.d, actual_sample * self.d), i * self.d) - # @for_range(self.d) - # def _(d_idx): - # batch_d_indices[i * self.d + d_idx] = actual_sample * self.d + d_idx - - @multithread(self.n_threads, N * self.d, max_size) - def _(base, size): - result_matrix.assign_part_vector( - X_sub.direct_mul(self.W, indices=( - batch_d_indices.get_vector(base, size), regint.inc(self.d_in), - regint.inc(self.d_in), regint.inc(self.d_out))), base) - - if self.input_bias: - if self.d_out == 1: - @multithread(self.n_threads, N) - def _(base, size): - v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size) - self.f_input.assign_vector(v, base) - else: - @for_range_multithread(self.n_threads, 100, [N, self.d]) - def _(i, j): - v = prod[i][j].get_vector() + self.b.get_vector() - # print_ln("running bias %s %s %s", i, j, v.reveal()) - self.f_input[i][j].assign_vector(v) - - # print_ln("FlexDense f_inpu full %s", self.f_input.reveal()) - - progress('f input') - - def backward(self, compute_nabla_X=True, batch=None): - N = len(batch) - d = self.d - d_out = self.d_out - X = self.X - Y = self.Y - W = self.W - b = self.b - nabla_X = self.nabla_X - nabla_Y = self.nabla_Y - nabla_W = self.nabla_W - nabla_b = self.nabla_b - - if self.activation_layer: - self.activation_layer.backward(batch) - f_schur_Y = self.activation_layer.nabla_X - else: - f_schur_Y = nabla_Y - - if compute_nabla_X: - nabla_X.alloc() - - # flattened matrix version - result_matrix = sfix.Matrix(N * self.d, self.d_in, address=nabla_X.address) - # Note: f_schur_Y is stored at indices [0, 1, ..., N-1] not at actual batch indices - @multithread(self.n_threads, N * self.d) - def _(base, size): - X_sub = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address) - offset = regint.inc(size, base=base) - - result_matrix.assign_part_vector( - X_sub.direct_mul_trans(self.W, indices=( - offset, regint.inc(self.d_out), - regint.inc(self.d_out), regint.inc(self.d_in))), base) - - if self.print_random_update: - print_ln('backward %s', self) - index = regint.get_random(64) % self.nabla_X.total_size() - print_ln('%s nabla_X at %s: %s', str(self.nabla_X), - index, self.nabla_X.to_array()[index].reveal()) - - progress('nabla X') - - self.backward_params(f_schur_Y, batch=batch) - - def backward_params(self, f_schur_Y, batch): - print("backward params flexdense") - N = len(batch) - tmp = Matrix(self.d_in, self.d_out, unreduced_sfix) - # tmp.assign_all(0) - # A (f_schur_Y/nabla_Y) is stored at sequential batch indices [0, 1, ..., N-1] A = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address) # B (X) is stored at the full dataset size, not just the batch @@ -1082,7 +738,7 @@ class FlexDense(Dense): @for_range(N) def _(i): # batch[i] gives the actual sample index in the full dataset - # For FlexDense with d>1, we need to map to flattened indices + # For Dense with d>1, we need to map to flattened indices actual_sample_idx = batch[i] @for_range(self.d) def _(d_idx): @@ -1162,6 +818,182 @@ class FlexDense(Dense): print_ln('X %s %s', i, self.X[i].reveal_nested()) print_ln('Y %s %s', i, f_schur_Y[i].reveal_nested()) +class Dense(DenseBase): + """ Fixed-point dense (matrix multiplication) layer. + Supports inputs of size [N, d, d_in] to map to [N, d, d_out]. If d > 1, the layer + behaves like torch's nn.Linear which loops over the additional d + + :param N: number of examples + :param d_in: input dimension + :param d_out: output dimension + :param d: (optional) extra dimension + """ + def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False): + if activation == 'id': + self.activation_layer = None + elif activation == 'relu': + self.activation_layer = Relu([N, d, d_out]) + elif activation == 'square': + self.activation_layer = Square([N, d, d_out]) + else: + raise CompilerError('activation not supported: %s', activation) + + self.N = N + self.d_in = d_in + self.d_out = d_out + self.d = d + self.activation = activation + + self.X = Tensor([N, d, d_in], sfix) + self.Y = Tensor([N, d, d_out], sfix) + self.W = Tensor([d_in, d_out], sfix) + self.b = sfix.Array(d_out) + + back_N = min(N, self.back_batch_size) + self.nabla_Y = Tensor([back_N, d, d_out], sfix) + self.nabla_X = Tensor([back_N, d, d_in], sfix) + self.nabla_W = Tensor([d_in, d_out], sfix) + self.nabla_b = sfix.Array(d_out) + + self.debug = debug + + l = self.activation_layer + if l: + self.f_input = l.X + l.Y = self.Y + l.nabla_Y = self.nabla_Y + else: + self.f_input = self.Y + + def __repr__(self): + return '%s(%s, %s, %s, activation=%s)' % \ + (type(self).__name__, self.N, self.d_in, + self.d_out, repr(self.activation)) + + def reset(self): + d_in = self.d_in + d_out = self.d_out + r = math.sqrt(6.0 / (d_in + d_out)) + print('Initializing dense weights in [%f,%f]' % (-r, r)) + self.W.randomize(-r, r, n_threads=self.n_threads) + self.b.assign_all(0) + + def input_from(self, player, **kwargs): + self.W.input_from(player, **kwargs) + if self.input_bias: + self.b.input_from(player, **kwargs) + + def compute_f_input(self, batch): + N = len(batch) + prod = MultiArray([N, self.d, self.d_out], sfix) + + # flattened_array version + result_matrix = sfix.Matrix(N * self.d, self.d_out, address=prod.address) + max_size = get_program().budget + + # X is stored at full dataset indices, batch specifies which samples to use + X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address) + + # Precompute batch_d_indices for all N*d elements + # For each sample in batch, expand to d consecutive indices + batch_d_indices = regint.Array(N * self.d) + @for_range(N) + def _(i): + actual_sample = batch[i] + batch_d_indices.assign(regint.inc(self.d, actual_sample * self.d), i * self.d) + + @multithread(self.n_threads, N * self.d, max_size) + def _(base, size): + result_matrix.assign_part_vector( + X_sub.direct_mul(self.W, indices=( + batch_d_indices.get_vector(base, size), regint.inc(self.d_in), + regint.inc(self.d_in), regint.inc(self.d_out))), base) + + if self.input_bias: + if self.d_out == 1: + @multithread(self.n_threads, N) + def _(base, size): + v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size) + self.f_input.assign_vector(v, base) + else: + @for_range_multithread(self.n_threads, 100, [N, self.d]) + def _(i, j): + v = prod[i][j].get_vector() + self.b.get_vector() + self.f_input[i][j].assign_vector(v) + + progress('f input') + + def _forward(self, batch=None): + if batch is None: + batch = regint.Array(self.N) + batch.assign(regint.inc(self.N)) + self.compute_f_input(batch=batch) + if self.activation_layer: + self.activation_layer.forward(batch, training=True) + if self.debug_output: + print_ln('dense X %s', self.X.reveal_nested()) + print_ln('dense W %s', self.W.reveal_nested()) + print_ln('dense b %s', self.b.reveal_nested()) + print_ln('dense Y %s', self.Y.reveal_nested()) + if self.debug: + limit = self.debug + @for_range_opt(len(batch)) + def _(i): + @for_range_opt(self.d_out) + def _(j): + to_check = self.Y[i][0][j].reveal() + check = to_check > limit + @if_(check) + def _(): + print_ln('dense Y %s %s %s %s', i, j, self.W.sizes, to_check) + print_ln('X %s', self.X[i].reveal_nested()) + print_ln('W %s', + [self.W[k][j].reveal() for k in range(self.d_in)]) + + def backward(self, compute_nabla_X=True, batch=None): + N = len(batch) + d = self.d + d_out = self.d_out + X = self.X + Y = self.Y + W = self.W + b = self.b + nabla_X = self.nabla_X + nabla_Y = self.nabla_Y + nabla_W = self.nabla_W + nabla_b = self.nabla_b + + if self.activation_layer: + self.activation_layer.backward(batch) + f_schur_Y = self.activation_layer.nabla_X + else: + f_schur_Y = nabla_Y + + if compute_nabla_X: + nabla_X.alloc() + + # flattened matrix version + result_matrix = sfix.Matrix(N * self.d, self.d_in, address=nabla_X.address) + # Note: f_schur_Y is stored at indices [0, 1, ..., N-1] not at actual batch indices + @multithread(self.n_threads, N * self.d) + def _(base, size): + X_sub = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address) + offset = regint.inc(size, base=base) + + result_matrix.assign_part_vector( + X_sub.direct_mul_trans(self.W, indices=( + offset, regint.inc(self.d_out), + regint.inc(self.d_out), regint.inc(self.d_in))), base) + + if self.print_random_update: + print_ln('backward %s', self) + index = regint.get_random(64) % self.nabla_X.total_size() + print_ln('%s nabla_X at %s: %s', str(self.nabla_X), + index, self.nabla_X.to_array()[index].reveal()) + + progress('nabla X') + + self.backward_params(f_schur_Y, batch=batch) class QuantizedDense(DenseBase): @@ -3108,7 +2940,7 @@ class BertIntermediate(BertBase): input_shape = [n_examples, seq_len, hidden_size] output_shape = [n_examples, seq_len, intermediate_size] super(BertIntermediate, self).__init__(input_shape, output_shape) - self.dense = FlexDense(n_examples, hidden_size, intermediate_size, seq_len) + self.dense = Dense(n_examples, hidden_size, intermediate_size, seq_len) self.activation = Gelu([n_examples, seq_len, intermediate_size]) @@ -3150,7 +2982,7 @@ class BertOutput(BertBase): self.input_shape = input_shape print("INSTANTIATING BERTOUTPUT with ", input_shape, output_shape, intermediate_size, hidden_size, rsqrt_approx) super(BertOutput, self).__init__(input_shape, output_shape) - self.dense = FlexDense(n_examples, intermediate_size, hidden_size, seq_len) + self.dense = Dense(n_examples, intermediate_size, hidden_size, seq_len) self.layer_norm = LayerNorm(output_shape, layernorm_eps=layernorm_eps, approx=rsqrt_approx) self.dropout = FlexDropout([n_examples, seq_len, hidden_size], alpha=dropout) @@ -3241,9 +3073,9 @@ class MultiHeadAttention(BertBase): self.seq_len = seq_len self.hidden_size = hidden_size - self.wq = FlexDense(n_examples, hidden_size, self.all_head_size, self.seq_len) - self.wk = FlexDense(n_examples, hidden_size, self.all_head_size, self.seq_len) - self.wv = FlexDense(n_examples, hidden_size, self.all_head_size, self.seq_len) + self.wq = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len) + self.wk = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len) + self.wv = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len) self.dropout = FlexDropout([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], alpha=dropout) # I think? # TODO: DROPOUT? self.output = BertOutput(internal_shape, hidden_size, hidden_size, seq_len, dropout, layernorm_eps, rsqrt_approx) @@ -4621,7 +4453,7 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, if name == 'Linear': assert mul(input_shape[1:]) == item.in_features assert item.bias is not None - layers.append(FlexDense(input_shape[0], item.in_features, + layers.append(Dense(input_shape[0], item.in_features, item.out_features)) if input_via is not None: shapes = [x.shape for x in (layers[-1].W, layers[-1].b)] From e17e175b16826a4b9da1e927fd940edbc546e1bf Mon Sep 17 00:00:00 2001 From: Hidde L Date: Tue, 21 Oct 2025 16:26:15 +0200 Subject: [PATCH 7/8] Merge FlexDropout into Dropout Update mnist_full examples using Dropout Clean up code style --- Compiler/ml.py | 81 +++++--------------------------- Programs/Source/mnist_full_A.mpc | 8 ++-- Programs/Source/mnist_full_C.mpc | 10 ++-- Programs/Source/mnist_full_D.mpc | 6 +-- 4 files changed, 23 insertions(+), 82 deletions(-) diff --git a/Compiler/ml.py b/Compiler/ml.py index 5ac85e03..32a9eb0a 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -978,12 +978,13 @@ class Dense(DenseBase): @multithread(self.n_threads, N * self.d) def _(base, size): X_sub = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address) - offset = regint.inc(size, base=base) result_matrix.assign_part_vector( - X_sub.direct_mul_trans(self.W, indices=( - offset, regint.inc(self.d_out), - regint.inc(self.d_out), regint.inc(self.d_in))), base) + X_sub.direct_mul_trans(self.W, indices=(regint.inc(size, base=base), + regint.inc(self.d_out), + regint.inc(self.d_out), + regint.inc(self.d_in))), + base) if self.print_random_update: print_ln('backward %s', self) @@ -995,7 +996,6 @@ class Dense(DenseBase): self.backward_params(f_schur_Y, batch=batch) - class QuantizedDense(DenseBase): def __init__(self, N, d_in, d_out): self.N = N @@ -1048,64 +1048,7 @@ class Dropout(NoVariableLayer): """ Dropout layer. :param N: number of examples - :param d1: total dimension - :param alpha: probability (power of two) - """ - def __init__(self, N, d1, d2=1, alpha=0.5): - self.N = N - self.d1 = d1 - self.d2 = d2 - self.X = Tensor([N, d1, d2], sfix) - self.Y = Tensor([N, d1, d2], sfix) - self.nabla_Y = Tensor([N, d1, d2], sfix) - self.nabla_X = Tensor([N, d1, d2], sfix) - self.alpha = alpha - self.B = MultiArray([N, d1, d2], sint) - - def __repr__(self): - return '%s(%s, %s, alpha=%s)' % \ - (type(self).__name__, self.N, self.d1, self.alpha) - - def forward(self, batch, training=False): - if training: - n_bits = -math.log(self.alpha, 2) - assert n_bits == int(n_bits) - n_bits = int(n_bits) - # self.B.assign_all(1) # TODO: temp disable for reproducibility - # self.alpha = 0.0 # TODO: temp disable for reproducibility - @for_range_opt_multithread(self.n_threads, len(batch)) - def _(i): - size = self.d1 * self.d2 - self.B[i].assign_vector(util.tree_reduce( - util.or_op, (sint.get_random_bit(size=size) - for i in range(n_bits)))) - @for_range_opt_multithread(self.n_threads, len(batch)) - def _(i): - self.Y[i].assign_vector(1 / (1 - self.alpha) * - self.X[batch[i]].get_vector() * self.B[i].get_vector()) - else: - @for_range(len(batch)) - def _(i): - self.Y[i] = self.X[batch[i]] - if self.debug_output: - print_ln('dropout X %s', self.X.reveal_nested()) - print_ln('dropout Y %s', self.Y.reveal_nested()) - - def backward(self, compute_nabla_X=True, batch=None): - if compute_nabla_X: - @for_range_opt_multithread(self.n_threads, len(batch)) - def _(i): - self.nabla_X[batch[i]].assign_vector( - self.nabla_Y[i].get_vector() * self.B[i].get_vector()) - if self.debug_output: - print_ln('dropout nabla_Y %s', self.nabla_Y.reveal_nested()) - print_ln('dropout nabla_X %s', self.nabla_X.reveal_nested()) - -class FlexDropout(NoVariableLayer): - """ Dropout layer. - - :param N: number of examples - :param d1: total dimension + :param shape: list [N, ...] where N is the number of examples and an arbitrary amount of further dimensions :param alpha: probability (power of two) """ def __init__(self, shape, alpha=0.5): @@ -1126,8 +1069,6 @@ class FlexDropout(NoVariableLayer): n_bits = -math.log(self.alpha, 2) assert n_bits == int(n_bits) n_bits = int(n_bits) - # self.B.assign_all(1) - # self.alpha = 0.0 # TODO: temp disable for reproducibility @for_range_opt_multithread(self.n_threads, len(batch)) def _(i): size = reduce(operator.mul, self.shape[1:]) @@ -2984,7 +2925,7 @@ class BertOutput(BertBase): super(BertOutput, self).__init__(input_shape, output_shape) self.dense = Dense(n_examples, intermediate_size, hidden_size, seq_len) self.layer_norm = LayerNorm(output_shape, layernorm_eps=layernorm_eps, approx=rsqrt_approx) - self.dropout = FlexDropout([n_examples, seq_len, hidden_size], alpha=dropout) + self.dropout = Dropout([n_examples, seq_len, hidden_size], alpha=dropout) def forward(self, batch, input_tensor, training=False, input_tensor_batch=None): @@ -3076,7 +3017,7 @@ class MultiHeadAttention(BertBase): self.wq = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len) self.wk = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len) self.wv = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len) - self.dropout = FlexDropout([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], alpha=dropout) # I think? # TODO: DROPOUT? + self.dropout = Dropout([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], alpha=dropout) # I think? # TODO: DROPOUT? self.output = BertOutput(internal_shape, hidden_size, hidden_size, seq_len, dropout, layernorm_eps, rsqrt_approx) self.context = sfix.Tensor([internal_shape, self.seq_len, hidden_size]) @@ -4306,8 +4247,8 @@ class keras: layers.append(FixAveragePool2d(input_shape, None, **layer[1])) input_shape = layers[-1].Y.sizes elif name == 'dropout': - layers.append(Dropout(batch_size, reduce( - operator.mul, layers[-1].Y.sizes[1:]), + layers.append(Dropout([batch_size] + [reduce( + operator.mul, layers[-1].Y.sizes[1:])], alpha=layer[1])) input_shape = layers[-1].Y.sizes elif name == 'flatten': @@ -4544,7 +4485,7 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, if alpha == 0.1: print('WARNING: dropout rate 0.1 not supported, using 0.125') alpha = 0.125 - layers.append(Dropout(input_shape[0], mul(layers[-1].Y.sizes[1:]), + layers.append(Dropout([input_shape[0]] + list(layers[-1].Y.sizes[1:]), alpha=alpha)) input_shape = layers[-1].Y.sizes elif name == 'BertForSequenceClassification': diff --git a/Programs/Source/mnist_full_A.mpc b/Programs/Source/mnist_full_A.mpc index 308465e3..dbb7bf73 100644 --- a/Programs/Source/mnist_full_A.mpc +++ b/Programs/Source/mnist_full_A.mpc @@ -77,16 +77,16 @@ if 'batchnorm' in program.args: if 'dropout' in program.args: for i in range(len(layers) - 1, 0, -1): - layers.insert(i, ml.Dropout(N, n_inner)) + layers.insert(i, ml.Dropout([N, n_inner])) if 'dropout-late' in program.args: - layers.insert(-1, ml.Dropout(N, n_inner)) + layers.insert(-1, ml.Dropout([N, n_inner])) if 'dropout-early' in program.args: - layers.insert(0, ml.Dropout(n_examples, n_features)) + layers.insert(0, ml.Dropout([n_examples, n_features])) if 'dropout-early.25' in program.args: - layers.insert(0, ml.Dropout(n_examples, n_features, alpha=.25)) + layers.insert(0, ml.Dropout([n_examples, n_features], alpha=.25)) layers += [ml.MultiOutput.from_args(program, n_examples, 10)] diff --git a/Programs/Source/mnist_full_C.mpc b/Programs/Source/mnist_full_C.mpc index 2933d7fe..2aacc576 100644 --- a/Programs/Source/mnist_full_C.mpc +++ b/Programs/Source/mnist_full_C.mpc @@ -73,16 +73,16 @@ if 'batchnorm' in program.args: layers.insert(1, ml.BatchNorm([N, 24, 24, 20], args=program.args)) if 'dropout' in program.args or 'dropout2' in program.args: - layers.insert(8, ml.Dropout(N, 500)) + layers.insert(8, ml.Dropout([N, 500])) elif 'dropout.25' in program.args: - layers.insert(8, ml.Dropout(N, 500, alpha=0.25)) + layers.insert(8, ml.Dropout([N, 500], alpha=0.25)) elif 'dropout.125' in program.args: - layers.insert(8, ml.Dropout(N, 500, alpha=0.125)) + layers.insert(8, ml.Dropout([N, 500], alpha=0.125)) if 'dropout2' in program.args: - layers.insert(6, ml.Dropout(N, 800, alpha=0.125)) + layers.insert(6, ml.Dropout([N, 800], alpha=0.125)) elif 'dropout1' in program.args: - layers.insert(6, ml.Dropout(N, 800, alpha=0.5)) + layers.insert(6, ml.Dropout([N, 800], alpha=0.5)) if 'no_relu' in program.args: for x in layers: diff --git a/Programs/Source/mnist_full_D.mpc b/Programs/Source/mnist_full_D.mpc index 13ad1398..6e1c7d0d 100644 --- a/Programs/Source/mnist_full_D.mpc +++ b/Programs/Source/mnist_full_D.mpc @@ -79,18 +79,18 @@ dropout = 'dropout' in program.args if '1dense' in program.args: if dropout: - layers += [ml.Dropout(N, n_inner)] + layers += [ml.Dropout([N, n_inner])] layers += [ml.Dense(N, n_inner, 10),] elif '2dense' in program.args: if dropout: - layers += [ml.Dropout(N, n_inner)] + layers += [ml.Dropout([N, n_inner])] layers += [ ml.Dense(N, n_inner, 100), ml.Relu([N, 100]), ml.Dense(N, 100, 10), ] if dropout or 'dropout1' in program.args: - layers.insert(-1, ml.Dropout(N, 100)) + layers.insert(-1, ml.Dropout([N, 100])) else: raise Exception('need to specify number of dense layers') From bfe3e7e038de32d0104cccf89c868ae7b389f9da Mon Sep 17 00:00:00 2001 From: Hidde L Date: Thu, 6 Nov 2025 16:16:21 -0500 Subject: [PATCH 8/8] Make interface of Dropout backwards compatible --- Compiler/ml.py | 9 ++++++++- Programs/Source/mnist_full_A.mpc | 8 ++++---- Programs/Source/mnist_full_C.mpc | 10 +++++----- Programs/Source/mnist_full_D.mpc | 6 +++--- 4 files changed, 20 insertions(+), 13 deletions(-) diff --git a/Compiler/ml.py b/Compiler/ml.py index 32a9eb0a..a1f1c4dc 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -1051,7 +1051,14 @@ class Dropout(NoVariableLayer): :param shape: list [N, ...] where N is the number of examples and an arbitrary amount of further dimensions :param alpha: probability (power of two) """ - def __init__(self, shape, alpha=0.5): + def __init__(self, N, d1=None, d2=1, alpha=0.5): + if isinstance(N, list) or isinstance(N, tuple): + shape = N + assert d1 is None, ("If shape is given as list/tuple, d1 must be None. " + "Alpha must be passed explicitly for backwards compatibility.") + else: + assert d1 is not None, "At least one non-batch dimension must be set" + shape = [N, d1] if d2 == 1 else [N, d1, d2] self.N = shape[0] self.X = Tensor(shape, sfix) self.Y = Tensor(shape, sfix) diff --git a/Programs/Source/mnist_full_A.mpc b/Programs/Source/mnist_full_A.mpc index dbb7bf73..308465e3 100644 --- a/Programs/Source/mnist_full_A.mpc +++ b/Programs/Source/mnist_full_A.mpc @@ -77,16 +77,16 @@ if 'batchnorm' in program.args: if 'dropout' in program.args: for i in range(len(layers) - 1, 0, -1): - layers.insert(i, ml.Dropout([N, n_inner])) + layers.insert(i, ml.Dropout(N, n_inner)) if 'dropout-late' in program.args: - layers.insert(-1, ml.Dropout([N, n_inner])) + layers.insert(-1, ml.Dropout(N, n_inner)) if 'dropout-early' in program.args: - layers.insert(0, ml.Dropout([n_examples, n_features])) + layers.insert(0, ml.Dropout(n_examples, n_features)) if 'dropout-early.25' in program.args: - layers.insert(0, ml.Dropout([n_examples, n_features], alpha=.25)) + layers.insert(0, ml.Dropout(n_examples, n_features, alpha=.25)) layers += [ml.MultiOutput.from_args(program, n_examples, 10)] diff --git a/Programs/Source/mnist_full_C.mpc b/Programs/Source/mnist_full_C.mpc index 2aacc576..2933d7fe 100644 --- a/Programs/Source/mnist_full_C.mpc +++ b/Programs/Source/mnist_full_C.mpc @@ -73,16 +73,16 @@ if 'batchnorm' in program.args: layers.insert(1, ml.BatchNorm([N, 24, 24, 20], args=program.args)) if 'dropout' in program.args or 'dropout2' in program.args: - layers.insert(8, ml.Dropout([N, 500])) + layers.insert(8, ml.Dropout(N, 500)) elif 'dropout.25' in program.args: - layers.insert(8, ml.Dropout([N, 500], alpha=0.25)) + layers.insert(8, ml.Dropout(N, 500, alpha=0.25)) elif 'dropout.125' in program.args: - layers.insert(8, ml.Dropout([N, 500], alpha=0.125)) + layers.insert(8, ml.Dropout(N, 500, alpha=0.125)) if 'dropout2' in program.args: - layers.insert(6, ml.Dropout([N, 800], alpha=0.125)) + layers.insert(6, ml.Dropout(N, 800, alpha=0.125)) elif 'dropout1' in program.args: - layers.insert(6, ml.Dropout([N, 800], alpha=0.5)) + layers.insert(6, ml.Dropout(N, 800, alpha=0.5)) if 'no_relu' in program.args: for x in layers: diff --git a/Programs/Source/mnist_full_D.mpc b/Programs/Source/mnist_full_D.mpc index 6e1c7d0d..13ad1398 100644 --- a/Programs/Source/mnist_full_D.mpc +++ b/Programs/Source/mnist_full_D.mpc @@ -79,18 +79,18 @@ dropout = 'dropout' in program.args if '1dense' in program.args: if dropout: - layers += [ml.Dropout([N, n_inner])] + layers += [ml.Dropout(N, n_inner)] layers += [ml.Dense(N, n_inner, 10),] elif '2dense' in program.args: if dropout: - layers += [ml.Dropout([N, n_inner])] + layers += [ml.Dropout(N, n_inner)] layers += [ ml.Dense(N, n_inner, 100), ml.Relu([N, 100]), ml.Dense(N, 100, 10), ] if dropout or 'dropout1' in program.args: - layers.insert(-1, ml.Dropout([N, 100])) + layers.insert(-1, ml.Dropout(N, 100)) else: raise Exception('need to specify number of dense layers')