diff --git a/Compiler/ml.py b/Compiler/ml.py index 2c87a4c3..a1f1c4dc 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -725,16 +725,31 @@ class DenseBase(Layer): N = len(batch) tmp = Matrix(self.d_in, self.d_out, unreduced_sfix) - A = sfix.Matrix(N, self.d_out, address=f_schur_Y.address) - B = sfix.Matrix(self.N, self.d_in, address=self.X.address) + # A (f_schur_Y/nabla_Y) is stored at sequential batch indices [0, 1, ..., N-1] + A = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address) + # B (X) is stored at the full dataset size, not just the batch + B = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address) @multithread(self.n_threads, self.d_in) def _(base, size): + # For A: use sequential indices [0, 1, ..., N*d-1] + # For B: use actual batch indices expanded for d dimension + batch_d_indices = regint.Array(N * self.d) + @for_range(N) + def _(i): + # batch[i] gives the actual sample index in the full dataset + # For Dense with d>1, we need to map to flattened indices + actual_sample_idx = batch[i] + @for_range(self.d) + def _(d_idx): + batch_d_indices[i * self.d + d_idx] = actual_sample_idx * self.d + d_idx + mp = B.direct_trans_mul(A, reduce=False, indices=(regint.inc(size, base), - batch.get_vector(), - regint.inc(N), + batch_d_indices.get_vector(), + regint.inc(N * self.d), regint.inc(self.d_out))) + tmp.assign_part_vector(mp, base) progress('nabla W (matmul)') @@ -805,10 +820,13 @@ class DenseBase(Layer): class Dense(DenseBase): """ Fixed-point dense (matrix multiplication) layer. + Supports inputs of size [N, d, d_in] to map to [N, d, d_out]. If d > 1, the layer + behaves like torch's nn.Linear which loops over the additional d :param N: number of examples :param d_in: input dimension :param d_out: output dimension + :param d: (optional) extra dimension """ def __init__(self, N, d_in, d_out, d=1, activation='id', debug=False): if activation == 'id': @@ -867,18 +885,28 @@ class Dense(DenseBase): def compute_f_input(self, batch): N = len(batch) - assert self.d == 1 - if self.input_bias: - prod = MultiArray([N, self.d, self.d_out], sfix) - else: - prod = self.f_input + prod = MultiArray([N, self.d, self.d_out], sfix) + + # flattened_array version + result_matrix = sfix.Matrix(N * self.d, self.d_out, address=prod.address) max_size = get_program().budget - @multithread(self.n_threads, N, max_size) + + # X is stored at full dataset indices, batch specifies which samples to use + X_sub = sfix.Matrix(self.N * self.d, self.d_in, address=self.X.address) + + # Precompute batch_d_indices for all N*d elements + # For each sample in batch, expand to d consecutive indices + batch_d_indices = regint.Array(N * self.d) + @for_range(N) + def _(i): + actual_sample = batch[i] + batch_d_indices.assign(regint.inc(self.d, actual_sample * self.d), i * self.d) + + @multithread(self.n_threads, N * self.d, max_size) def _(base, size): - X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address) - prod.assign_part_vector( + result_matrix.assign_part_vector( X_sub.direct_mul(self.W, indices=( - batch.get_vector(base, size), regint.inc(self.d_in), + batch_d_indices.get_vector(base, size), regint.inc(self.d_in), regint.inc(self.d_in), regint.inc(self.d_out))), base) if self.input_bias: @@ -888,10 +916,11 @@ class Dense(DenseBase): v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size) self.f_input.assign_vector(v, base) else: - @for_range_multithread(self.n_threads, 100, N) - def _(i): - v = prod[i].get_vector() + self.b.get_vector() - self.f_input[i].assign_vector(v) + @for_range_multithread(self.n_threads, 100, [N, self.d]) + def _(i, j): + v = prod[i][j].get_vector() + self.b.get_vector() + self.f_input[i][j].assign_vector(v) + progress('f input') def _forward(self, batch=None): @@ -942,15 +971,20 @@ class Dense(DenseBase): if compute_nabla_X: nabla_X.alloc() - @multithread(self.n_threads, N) + + # flattened matrix version + result_matrix = sfix.Matrix(N * self.d, self.d_in, address=nabla_X.address) + # Note: f_schur_Y is stored at indices [0, 1, ..., N-1] not at actual batch indices + @multithread(self.n_threads, N * self.d) def _(base, size): - B = sfix.Matrix(N, d_out, address=f_schur_Y.address) - nabla_X.assign_part_vector( - B.direct_mul_trans(W, indices=(regint.inc(size, base), - regint.inc(self.d_out), - regint.inc(self.d_out), - regint.inc(self.d_in))), - base) + X_sub = sfix.Matrix(N * self.d, self.d_out, address=f_schur_Y.address) + + result_matrix.assign_part_vector( + X_sub.direct_mul_trans(self.W, indices=(regint.inc(size, base=base), + regint.inc(self.d_out), + regint.inc(self.d_out), + regint.inc(self.d_in))), + base) if self.print_random_update: print_ln('backward %s', self) @@ -1014,23 +1048,28 @@ class Dropout(NoVariableLayer): """ Dropout layer. :param N: number of examples - :param d1: total dimension + :param shape: list [N, ...] where N is the number of examples and an arbitrary amount of further dimensions :param alpha: probability (power of two) """ - def __init__(self, N, d1, d2=1, alpha=0.5): - self.N = N - self.d1 = d1 - self.d2 = d2 - self.X = Tensor([N, d1, d2], sfix) - self.Y = Tensor([N, d1, d2], sfix) - self.nabla_Y = Tensor([N, d1, d2], sfix) - self.nabla_X = Tensor([N, d1, d2], sfix) + def __init__(self, N, d1=None, d2=1, alpha=0.5): + if isinstance(N, list) or isinstance(N, tuple): + shape = N + assert d1 is None, ("If shape is given as list/tuple, d1 must be None. " + "Alpha must be passed explicitly for backwards compatibility.") + else: + assert d1 is not None, "At least one non-batch dimension must be set" + shape = [N, d1] if d2 == 1 else [N, d1, d2] + self.N = shape[0] + self.X = Tensor(shape, sfix) + self.Y = Tensor(shape, sfix) + self.nabla_Y = Tensor(shape, sfix) + self.nabla_X = Tensor(shape, sfix) self.alpha = alpha - self.B = MultiArray([N, d1, d2], sint) + self.B = MultiArray(shape, sint) def __repr__(self): - return '%s(%s, %s, alpha=%s)' % \ - (type(self).__name__, self.N, self.d1, self.alpha) + return '%s(%s, alpha=%s)' % \ + (type(self).__name__, self.shape, self.alpha) def forward(self, batch, training=False): if training: @@ -1039,7 +1078,7 @@ class Dropout(NoVariableLayer): n_bits = int(n_bits) @for_range_opt_multithread(self.n_threads, len(batch)) def _(i): - size = self.d1 * self.d2 + size = reduce(operator.mul, self.shape[1:]) self.B[i].assign_vector(util.tree_reduce( util.or_op, (sint.get_random_bit(size=size) for i in range(n_bits)))) @@ -1148,6 +1187,149 @@ class Relu(ElementWiseLayer): def f_prime_part(self, base, size): return self.comparisons.get_vector(base, size) + +class Gelu(ElementWiseLayer): + """ Fixed-point GeLU layer. + Based on Dong et al., "PUMA: SECURE INFERENCE OF LLAMA-7B IN FIVE MINUTES" + + :param shape: input/output shape (tuple/list of int) + """ + prime_type = sfix + + def __init__(self, shape, inputs=None, approx=True): + super(Gelu, self).__init__(shape) + self.approx = approx + self.z0s = MultiArray(shape, sint) + self.z1s = MultiArray(shape, sint) + self.z2s = MultiArray(shape, sint) + + self.x2s = MultiArray(shape, sfix) + self.x3s = MultiArray(shape, sfix) + self.x4s = MultiArray(shape, sfix) + + self.poly_f_0_0 = -0.5054031199708174 + self.poly_f_0_1 = -0.42226581151983866 + self.poly_f_0_2 = -0.11807612951181953 + self.poly_f_0_3 = -0.011034134030615728 + + self.poly_f_1_0 = 0.008526321541038084 + self.poly_f_1_1 = 0.5 + self.poly_f_1_2 = 0.3603292692789629 + self.poly_f_1_4 = -0.037688200365904236 + self.poly_f_1_6 = 0.0018067462606141187 + + def f_part(self, base, size): + x = self.X.get_vector(base, size) + + if self.approx: + return self.compute_gelu_approx(x, base, size) + else: + return self.compute_gelu_sigmoid(x) + + def compute_gelu_approx(self, x, base, size): + # print_ln("GELU inputs %s", x.reveal()) + b0 = x < -4 + b1 = x < -1.95 + b2 = 3 < x + + z0 = b0 ^ b1 + z1 = b1 ^ b2 ^ 1 + z2 = b2 + + x1 = x + x2 = x ** 2 + x3 = x2 * x1 + x4 = x2 ** 2 + x6 = x3 ** 2 + + f_0 = self.poly_f_0_0 + self.poly_f_0_1 * x1 + self.poly_f_0_2 * x2 + self.poly_f_0_3 * x3 + f_1 = self.poly_f_1_0 + self.poly_f_1_1 * x1 + self.poly_f_1_2 * x2 + self.poly_f_1_4 * x4 + self.poly_f_1_6 * x6 + + self.z0s.assign_vector(z0, base) + self.z1s.assign_vector(z1, base) + self.z2s.assign_vector(z2, base) + + self.x2s.assign_vector(x2, base) + self.x3s.assign_vector(x3, base) + self.x4s.assign_vector(x6, base) + + return (z0 * f_0) + (z1 * f_1) + (z2 * x) + + def compute_gelu_sigmoid(self, x, base=0): + return x * sigmoid(0.071355 * (x ** 3) + 1.595769 * x) + + def tanh(self, x, base=0): + exp_2x = exp(2 * x) + + real_tanh = (exp_2x - 1) / (exp_2x + 1) + return real_tanh + + def f_prime_part(self, base, size): + if self.approx: + return self.f_prime_part_approx(base, size) + else: + return self.f_prime_part_sigmoid(base, size) + + def f_prime_part_approx(self, base, size): + # what we compute here is all derivatives + # we need to compute the derivative of the function at the point + poly_prime_f_0_0 = self.poly_f_0_1 + poly_prime_f_0_1 = self.poly_f_0_2 * 2 + poly_prime_f_0_2 = self.poly_f_0_3 * 3 + + poly_prime_f_1_0 = self.poly_f_1_1 + poly_prime_f_1_1 = self.poly_f_1_2 * 2 + poly_prime_f_1_3 = self.poly_f_1_4 * 4 + poly_prime_f_1_5 = self.poly_f_1_6 * 6 + + x1 = self.X.get_vector(base, size) + x2 = self.x2s.get_vector(base, size) + x3 = self.x3s.get_vector(base, size) + x4 = self.x4s.get_vector(base, size) + x5 = x4 * x1 + + f_0_prime = poly_prime_f_0_0 + poly_prime_f_0_1 * x1 + poly_prime_f_0_2 * x2 + f_1_prime = poly_prime_f_1_0 + poly_prime_f_1_1 * x1 + poly_prime_f_1_3 * x3 + poly_prime_f_1_5 * x5 + + z0 = self.z0s.get_vector(base, size) + z1 = self.z1s.get_vector(base, size) + z2 = self.z2s.get_vector(base, size) + + print_ln("gelu backward x %s res %s", x1.reveal()[:8], ((z0 * f_0_prime) + (z1 * f_1_prime) + z2).reveal()[:8]) + + return (z0 * f_0_prime) + (z1 * f_1_prime) + z2 + + def f_prime_part_sigmoid(self, base, size): + x = self.X.get_vector(base, size) + # return 0.5 * (1 + self.tanh(0.0356774 * x ** 3 + 0.797884 * x)) + print_ln("f sigmoid") + x3 = x ** 3 + exp_term = exp(1.59577 * x + 0.071355 * x3) + return (exp(3.19154* x + 0.14271 * x) + exp_term * (1 + 1.59577 * x + 0.214065 * x3)) / (1 + exp_term) ** 2 + + + +class Tanh(ElementWiseLayer): + + prime_type = sfix + + def __init__(self, shape, inputs=None): + super(Tanh, self).__init__(shape) + self.tanh_computations = MultiArray(shape, sfix) + + def f_part(self, base, size): + x = self.X.get_vector(base, size) + res = self.tanh(x) + self.tanh_computations.assign_vector(res, base) + return res + + def tanh(self, x): + return 2 * sigmoid(2 * x) - 1 + + def f_prime_part(self, base, size): + return 1 - self.tanh_computations.get_vector(base, size) ** 2 + + class Square(ElementWiseLayer): """ Fixed-point square layer. @@ -1353,6 +1535,8 @@ class Add(NoVariableLayer): self.Y = Tensor(shape, sfix) self.inputs = inputs + self.nabla_Y = Tensor(shape, sfix) + def _forward(self, batch=[0]): @multithread(self.n_threads, self.Y[0].total_size()) def _(base, size): @@ -1361,6 +1545,14 @@ class Add(NoVariableLayer): for inp in self.inputs) self.Y[bb].assign_vector(tmp, base) + def backward(self, compute_nabla_X=True, batch=None): + if compute_nabla_X: + for inp in self.inputs: + @multithread(self.n_threads, self.Y.total_size()) + def _(base, size): + inp.nabla_X.assign_vector(self.nabla_Y.get_vector(base, size), base) + + class FusedBatchNorm(Layer): """ Fixed-point fused batch normalization layer (inference only). @@ -1544,6 +1736,200 @@ class BatchNorm(Layer): for param in self.thetas() + (self.mu_hat, self.var_hat): param.reveal().binary_output() +class LayerNorm(Layer): # Changed class name + """ Fixed-point layer normalization layer. + For now assumes we only want to normalize over the last dimension + + :param shape: input/output shape (tuple/list of any number of int) + :param approx: use approximate square root + + """ + thetas = lambda self: (self.weights, self.bias) + nablas = lambda self: (self.nabla_weights, self.nabla_bias) + + def __init__(self, shape, approx=False, layernorm_eps=None, args=None): + if len(shape) == 2: + shape = [shape[0], 1, shape[1]] # Not sure why this extra dimension is added + tensors = (Tensor(shape, sfix) for i in range(4)) + self.X, self.Y, self.nabla_X, self.nabla_Y = tensors + self.epsilon = 2 ** (-sfix.f * 2 // 3 + 1) #if layernorm_eps is not None else sfix(layernorm_eps) + self.approx = approx + if approx: + print('Approximate square root inverse in layer normalization') + self.InvertSqrt = mpc_math.InvertSqrt + else: + print('Precise square root inverse in layer normalization') + self.InvertSqrt = lambda x: 1 / mpc_math.sqrt(x) + self.weights = sfix.Array(shape[-1]) + self.bias = sfix.Array(shape[-1]) + + batch_shape = [shape[0]] + list(self.X.sizes[1:-1]) + self.mu = sfix.Tensor(batch_shape) + self.var = sfix.Tensor(batch_shape) + self.nabla_weights = sfix.Array(shape[-1]) + self.nabla_bias = sfix.Array(shape[-1]) + + def __repr__(self): + return '%s(%s, approx=%s)' % \ + (type(self).__name__, self.X.sizes, self.approx) + + def reset(self): # Simplified reset method + self.bias.assign_all(0) + self.weights.assign_all(1) + + def _output(self, batch, mu, var): + batch_shape = [len(batch)] + list(self.X.sizes[1:-1]) + factor = sfix.Tensor(batch_shape) + + @multithread(self.n_threads, len(batch)) + def _(base, size): + factor.assign_part_vector( + self.InvertSqrt(var.get_part_vector(base, size) + self.epsilon), + base) + + @for_range_opt_multithread(self.n_threads, + batch_shape) + def _(*arg): + sel_X = self.X + sel_Y = self.Y + mu_sel = mu + fac_sel = factor + for ar in arg: + sel_X = sel_X[ar] + sel_Y = sel_Y[ar] + mu_sel = mu_sel[ar] + fac_sel = fac_sel[ar] + tmp = self.weights[:] * (sel_X[:] - mu_sel) * fac_sel # Removed self.mu reference + sel_Y[:] = self.bias[:] + tmp + + def forward(self, batch, training=False): + d = self.X.sizes[1] + d_in = self.X.sizes[2] + X_batch = MultiArray([len(batch), self.X.sizes[1], self.X.sizes[2]], sfix) + X_batch.assign_vector(self.X.get_slice_vector(batch)) + batch_shape = [len(batch)] + list(self.X.sizes[1:-1]) + + @for_range_opt_multithread(self.n_threads, batch_shape) + def _(*arg): + sel = X_batch + for ar in arg: + sel = sel[ar] + res = sum(sel[:]) + if len(arg) == 2: + self.mu[arg[0]][arg[1]] = res + else: + raise NotImplementedError("Only 3D tensors supported") + + @multithread(self.n_threads, self.mu.total_size()) + def _(base, size): + total_dim = self.X.sizes[-1] + self.mu.assign_vector( + self.mu.get_vector(base, size) / total_dim, base) + + @for_range_opt_multithread(self.n_threads, batch_shape) + def _(*arg): + sel = X_batch + mu_sel = self.mu + for ar in arg: + sel = sel[ar] + mu_sel = mu_sel[ar] + res = sum((sel[:] - mu_sel) ** 2) # Removed self.mu reference + if len(arg) == 2: + self.var[arg[0]][arg[1]] = res + else: + raise NotImplementedError("Only 3D tensors supported") + + @multithread(self.n_threads, self.var.total_size()) + def _(base, size): + total_dim = self.X.sizes[-1] + self.var.assign_vector( + self.var.get_vector(base, size) / (total_dim), + base) + self._output(batch, self.mu, self.var) # Simplified to always use current batch statistics + if self.print_random_update: + i = regint.get_random(64) % len(batch) + j = regint.get_random(64) % d + k = regint.get_random(64) % d_in + for x in self.mu, self.var: + print_ln('%s at %s: %s', str(x), k, x[k].reveal()) + print_ln('%s at (%s, %s, %s): in=%s out=%s', + str(self.Y), i, j, k, self.X[i][j][k].reveal(), + self.Y[i][j][k].reveal()) + + def backward(self, batch, compute_nabla_X=True): + batch_shape = [len(batch)] + list(self.X.sizes[1:-1]) + factor = sfix.Tensor(batch_shape) + factor[:] = self.InvertSqrt(self.var[:] + self.epsilon) # TODO: Can we cache this? + + # Gradients wrt outputs stored in self.nabla_Y + norm = self.X.same_shape() # Gradient wrt scaled input (gamma * nabla_Y) + nYdf = self.X.same_shape() # Used for weights gradient computation + dNorm = self.X.same_shape() # only needed for nabla_X comp + # dNormNorm = self.X.same_shape() + + # print("layernorm back ", batch, nYdf.sizes, factor.sizes) + # print("batch shape", batch_shape, self.nabla_Y.sizes) + # print(self.nabla_bias.length, self.X.sizes) + + @for_range_opt_multithread(self.n_threads, batch_shape) + def _(*arg): + assert len(arg) == 2, "Only 3D tensors supported" + norm[arg[0]][arg[1]][:] = (self.X[batch[arg[0]]][arg[1]][:] - self.mu[arg[0]][arg[1]]) * factor[arg[0]][arg[1]] # norm_bti + nYdf[arg[0]][arg[1]][:] = self.nabla_Y[batch[arg[0]]][arg[1]][:] * norm[arg[0]][arg[1]][:] + + dNorm[arg[0]][arg[1]][:] = self.nabla_Y[batch[arg[0]]][arg[1]][:] * self.weights[:] # dnorm_i + # dNormNorm[arg[0]][arg[1]][:] = dNorm[arg[0]][arg[1]][:] * norm[arg[0]][arg[1]][:] + + # Sum over the appropriate axes for nabla_weights and nabla_bias + @map_sum_simple(self.n_threads, batch_shape, sfix, self.X.sizes[-1]) + def _(*arg): + sel = self.nabla_Y + for ar in arg: + sel = sel[ar] + return sel[:] + self.nabla_bias.assign(_()) + + @map_sum_simple(self.n_threads, batch_shape, sfix, self.X.sizes[-1]) + def _(*arg): + sel = nYdf + for ar in arg: + sel = sel[ar] + return sel[:] + self.nabla_weights.assign(_()) + + if compute_nabla_X: + # sum_dNorm = sfix.Array(self.X.sizes[-1]) + # sum_dNormNorm = sfix.Array(self.X.sizes[-1]) + # + # @map_sum_simple(self.n_threads, batch_shape, sfix, self.X.sizes[-1]) + # def _(*arg): + # sel = dNormNorm + # for ar in arg: + # sel = sel[ar] + # return sel[:] + # sum_dNormNorm.assign(_()) + + # print_ln("sum norm %s", sum_dNorm[:].reveal()) + # print_ln("sum norm norm %s", sum_dNormNorm[:].reveal()) + # print_ln("norm %s", norm[:].reveal()[:8]) # corr + # print_ln("dnorm %s", dNorm[:].reveal()[:8]) + + # Compute final gradient wrt input X + @for_range_opt_multithread(self.n_threads, batch_shape) + def _(*arg): + + if len(arg) == 2: + mean_dnorm = sum(dNorm[arg[0]][arg[1]]) / self.X.sizes[-1] + + self.nabla_X[arg[0]][arg[1]][:] = (dNorm[arg[0]][arg[1]][:] - mean_dnorm) * factor[arg[0]][arg[1]] + mean_dnormnorm = sum(self.nabla_X[arg[0]][arg[1]][:] * norm[arg[0]][arg[1]]) / self.X.sizes[-1] + self.nabla_X[arg[0]][arg[1]][:] -= norm[arg[0]][arg[1]] * mean_dnormnorm + + else: + raise NotImplementedError("Only 3D tensors supported") + # sel_nX[:] = sel_dnorm[:] - sum_dNorm - sel_norm[:] * sum_dNormNorm + # print_ln("layernorm result compute_x %s", sel_nX[:].reveal()) + class QuantBase(object): bias_before_reduction = True @@ -1615,8 +2001,16 @@ class ConvBase(BaseLayer): use_conv2ds = True temp_weights = None temp_inputs = None - thetas = lambda self: (self.weights, self.bias) - nablas = lambda self: (self.nabla_weights, self.nabla_bias) + def thetas(self): + if self.use_bias: + return self.weights, self.bias + else: + return tuple([self.weights]) + def nablas(self): + if self.use_bias: + return self.nabla_weights, self.nabla_bias + else: + return tuple([self.nabla_weights]) @classmethod def init_temp(cls, layers): @@ -1628,13 +2022,14 @@ class ConvBase(BaseLayer): def __init__(self, input_shape, weight_shape, bias_shape, output_shape, stride, padding='SAME', tf_weight_format=False, inputs=None, - weight_type=None): + weight_type=None, bias=True): super(ConvBase, self).__init__(input_shape, output_shape, inputs=inputs) self.weight_shape = weight_shape self.bias_shape = bias_shape self.stride = stride self.tf_weight_format = tf_weight_format + self.use_bias = bias if padding == 'SAME': # https://web.archive.org/web/20171223022012/https://www.tensorflow.org/api_guides/python/nn self.padding = [] @@ -1665,10 +2060,14 @@ class ConvBase(BaseLayer): self.bias_squant = self.new_squant() self.weights = Tensor(weight_shape, self.weight_squant) - self.bias = Array(output_shape[-1], self.bias_squant) + if self.use_bias: + self.bias = Array(output_shape[-1], self.bias_squant) self.nabla_weights = Tensor(weight_shape, self.weight_squant) - self.nabla_bias = Array(output_shape[-1], self.bias_squant) + if self.use_bias: + self.nabla_bias = Array(output_shape[-1], self.bias_squant) + + self.unreduced = Tensor(self.output_shape, sint, address=self.Y.address) if tf_weight_format: weight_in = weight_shape[2] @@ -1687,10 +2086,6 @@ class ConvBase(BaseLayer): self.bias_shape, self.Y.sizes, self.stride, repr(self.padding), self.tf_weight_format) - @property - def unreduced(self): - return Tensor(self.output_shape, sint, address=self.Y.address) - def input_from(self, player, **kwargs): self.input_params_from(player) self.weights.input_from(player, budget=100000, **kwargs) @@ -1699,7 +2094,8 @@ class ConvBase(BaseLayer): def output_weights(self): self.weights.print_reveal_nested() - print_ln('%s', self.bias.reveal_nested()) + if self.use_bias: + print_ln('%s', self.bias.reveal_nested()) def reveal_parameters_to_binary(self): assert not self.tf_weight_format @@ -1711,15 +2107,21 @@ class ConvBase(BaseLayer): def _(j): part = self.weights.get_vector_by_indices(i, None, None, j) part.reveal().binary_output() - self.bias.reveal_to_binary_output() + if self.use_bias: + self.bias.reveal_to_binary_output() def dot_product(self, iv, wv, out_y, out_x, out_c): - bias = self.bias[out_c] - acc = self.output_squant.unreduced_dot_product(iv, wv) - acc.v += bias.v - acc.res_params = self.output_squant.params - #self.Y[0][out_y][out_x][out_c] = acc.reduce_after_mul() - self.unreduced[0][out_y][out_x][out_c] = acc.v + if self.use_bias: + bias = self.bias[out_c] + acc = self.output_squant.unreduced_dot_product(iv, wv) + acc.v += bias.v + acc.res_params = self.output_squant.params + #self.Y[0][out_y][out_x][out_c] = acc.reduce_after_mul() + self.unreduced[0][out_y][out_x][out_c] = acc.v + else: + acc = self.output_squant.unreduced_dot_product(iv, wv) + acc.res_params = self.output_squant.params + self.unreduced[0][out_y][out_x][out_c] = acc.v def reduction(self, batch_length=1): unreduced = self.unreduced @@ -1785,11 +2187,12 @@ class Conv2d(ConvBase): inputs_h, inputs_w, weights_h, weights_w, stride_h, stride_w, n_channels_in, padding_h, padding_w, part_size) - if self.bias_before_reduction: - res += self.bias.expand_to_vector(j, res.size).v - else: - res += self.bias.expand_to_vector(j, res.size).v << \ - self.weight_squant.f + if self.use_bias: + if self.bias_before_reduction: + res += self.bias.expand_to_vector(j, res.size).v + else: + res += self.bias.expand_to_vector(j, res.size).v << \ + self.weight_squant.f addresses = regint.inc(res.size, self.unreduced[i * part_size].address + j, n_channels_out) @@ -1797,7 +2200,8 @@ class Conv2d(ConvBase): self.reduction(len(batch)) if self.debug_output: print_ln('%s weights %s', self, self.weights.reveal_nested()) - print_ln('%s bias %s', self, self.bias.reveal_nested()) + if self.use_bias: + print_ln('%s bias %s', self, self.bias.reveal_nested()) @for_range(len(batch)) def _(i): print_ln('%s X %s %s', self, i, self.X[batch[i]].reveal_nested()) @@ -1873,7 +2277,8 @@ class FixConv2d(Conv2d, FixBase): r = math.sqrt(6.0 / (n_in + self.weight_shape[0])) print('Initializing convolution weights in [%f,%f]' % (-r, r)) self.weights.randomize(-r, r, n_threads=self.n_threads) - self.bias.assign_all(0) + if self.use_bias: + self.bias.assign_all(0) def backward(self, compute_nabla_X=True, batch=None): assert self.use_conv2ds @@ -1888,14 +2293,15 @@ class FixConv2d(Conv2d, FixBase): N = len(batch) - self.nabla_bias.assign_all(0) + if self.use_bias: + self.nabla_bias.assign_all(0) - @for_range(N) - def _(i): - self.nabla_bias.assign_vector( - self.nabla_bias.get_vector() + sum(sum( - self.nabla_Y[i][j][k].get_vector() for k in range(output_w)) - for j in range(output_h))) + @for_range(N) + def _(i): + self.nabla_bias.assign_vector( + self.nabla_bias.get_vector() + sum(sum( + self.nabla_Y[i][j][k].get_vector() for k in range(output_w)) + for j in range(output_h))) input_size = inputs_h * inputs_w * N batch_repeat = regint.Matrix(N, inputs_h * inputs_w) @@ -1975,8 +2381,9 @@ class FixConv2d(Conv2d, FixBase): print_ln('%s nabla weights %s', self, (self.nabla_weights.reveal_nested())) print_ln('%s weights %s', self, (self.weights.reveal_nested())) - print_ln('%s nabla b %s', self, (self.nabla_bias.reveal_nested())) - print_ln('%s bias %s', self, (self.bias.reveal_nested())) + if self.use_bias: + print_ln('%s nabla b %s', self, (self.nabla_bias.reveal_nested())) + print_ln('%s bias %s', self, (self.bias.reveal_nested())) class QuantDepthwiseConv2d(QuantConvBase, Conv2d): def n_summands(self): @@ -2109,7 +2516,7 @@ class AveragePool2d(BaseLayer): self.Y[0][out_y][out_x][c] = self.output_squant._new(acc) def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1, - padding=0, **kwargs): + padding=0, bias=True, **kwargs): """ More convenient interface to :py:class:`FixConv2d`. :param input_shape: input shape (tuple/list of four int) @@ -2117,6 +2524,7 @@ def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1, :param kernel_size: kernel size (int or tuple/list of two int) :param stride: stride (int or tuple/list of two int) :param padding: :py:obj:`'SAME'`, :py:obj:`'VALID'`, int, or tuple/list of two int + :param bias: whether layer has bias (bool) """ if isinstance(kernel_size, int): @@ -2130,7 +2538,7 @@ def easyConv2d(input_shape, batch_size, out_channels, kernel_size, stride=1, padding = padding.upper() if isinstance(padding, str) \ else padding return FixConv2d(input_shape, weight_shape, (out_channels,), output_shape, - stride, padding, **kwargs) + stride, padding, bias=bias, **kwargs) def easyMaxPool(input_shape, kernel_size, stride=None, padding=0): """ More convenient interface to :py:class:`MaxPool`. @@ -2241,6 +2649,624 @@ class QuantSoftmax(QuantBase, BaseLayer): return [c.if_else(x, y) for x, y in zip(left, right)] print_ln('guess: %s', util.tree_reduce(comp, list(enumerate(self.X[0])))[0].reveal()) +class BertBase(BaseLayer, FixBase): + pass + +# class BertEmbedding(BertBase): # we dont do embedding + +class BertPooler(BertBase): + + thetas = lambda self: self.dense.thetas() + nablas = lambda self: self.dense.nablas() # refer to downstream layers? + + def __init__(self, n_examples, seq_len, hidden_state): + input_shape = [n_examples, seq_len, hidden_state] + output_shape = [n_examples, hidden_state] + super(BertPooler, self).__init__(input_shape, output_shape) + self.dense = Dense(n_examples, hidden_state, hidden_state) + self.activation = Tanh(output_shape) + + self.d_out = hidden_state + + + def _forward(self, batch): + # self.dense.X.address = self.X.address + self.activation.X.address = self.dense.Y.address + self.activation.Y.address = self.Y.address + + # grab the first repr? + # batch contains [n_batch, n_heads, n_dim] + @for_range(len(batch)) + def _(j): + self.dense.X[j][:] = self.X[batch[j]][0][:] + + # if self.debug_output: + # print_ln("forward layer pooler.dense X %s", self.dense.X.reveal_nested()) + + self.dense.forward(batch) + # print_ln("LINEAR Layer weights after bertpooler.dense: %s", self.opt.layers[-2].W.reveal_nested()) + + self.activation._forward(batch) + # print_ln("LINEAR Layer weights after bertpooler.activation: %s", self.opt.layers[-2].W.reveal_nested()) + + def reset(self): + self.dense.reset() + self.activation.reset() + + def load_state_dict(self, state_dict, input_via): + import numpy + self.dense.W = sfix.input_tensor_via(input_via, numpy.swapaxes(state_dict['dense.weight'], 0, 1)) + self.dense.b = sfix.input_tensor_via(input_via, state_dict['dense.bias']) + + def backward(self, compute_nabla_X=True, batch=None): + if batch is None: + batch = regint.Array(self.N) + batch.assign(regint.inc(self.N)) + + self.activation.nabla_X.alloc() + + self.activation.nabla_Y.address = self.nabla_Y.address + self.dense.nabla_Y.address = self.activation.nabla_X.address + self.dense.nabla_X.address = self.nabla_X.address # TODO: size mismatch here, but should be okay? cause rest 0s? + + self.activation.backward(batch) + self.dense.backward(compute_nabla_X, batch) + +class BertEncoder(BertBase): + + # I think this is unused? + + def __init__(self, n_examples, n_layers, d_model, n_heads, d_k, d_v, d_ff, dropout=0.1): + input_shape = [n_examples, d_model] + output_shape = [n_examples, d_model] + super(BertEncoder, self).__init__(input_shape, output_shape) + self.layers = [] + for _ in range(n_layers): + self.layers.append(BertLayer(n_examples, d_model, n_heads, d_k, d_v, d_ff, dropout)) + + for i in enumerate(1, len(self.layers)): + self.layers[i].X.address = self.layers[i - 1].Y.address + + self.layers[0].X.address = self.X.address + self.layers[-1].Y.address = self.Y.address + + def _forward(self, batch): + for layer in self.layers: + layer.forward(batch) + + def reset(self): + for layer in self.layers: + layer.reset() + + +class BertLayer(BertBase): + + thetas = lambda self: self.multi_head_attention.thetas() + self.intermediate.thetas() + self.output.thetas() #+ tuple(self.nabla_hidden_state) + nablas = lambda self: self.multi_head_attention.nablas() + self.intermediate.nablas() + self.output.nablas() #+ tuple(self.nabla_hidden_state) + + def __init__(self, n_examples, seq_len, hidden_state, intermediate_size, num_attention_heads, layernorm_eps, dropout=0.1, rsqrt_approx=True, batch_size=None): + input_shape = [n_examples, seq_len, hidden_state] + output_shape = [n_examples, seq_len, hidden_state] # TODO: we could make this batch_size + super(BertLayer, self).__init__(input_shape, output_shape) + + internal_shape = batch_size if batch_size is not None else n_examples + self.multi_head_attention = MultiHeadAttention(internal_shape, seq_len, hidden_state, num_attention_heads, dropout, layernorm_eps, rsqrt_approx) + self.intermediate = BertIntermediate(internal_shape, hidden_state, intermediate_size, seq_len) + self.output = BertOutput(internal_shape, intermediate_size, hidden_state, seq_len, dropout, layernorm_eps, rsqrt_approx) + + self.hidden_state = sfix.Tensor(input_shape) # TODO: Could also make this smaller + # self.nabla_hidden_state = sfix.Tensor(input_shape) + # self.nabla_hidden_state.alloc() + + # self.X.address = self.multi_head_attention.X.address + # self.Y.address = self.output.Y.address + + self.d_out = hidden_state + + print("Init BertLayer", input_shape, output_shape) + + def forward(self, batch, training=False): + if batch is None: + batch = Array.create_from(regint(0)) + + self.multi_head_attention._X.address = self.X.address + self.output.Y.address = self.Y.address + self.hidden_state.address = self.X.address + # self.multi_head_attention.Y.address = self.Y.address + + self.multi_head_attention.forward(batch, self.hidden_state, training) + # if self.debug_output: + # print_ln("our layer X %s %s", self.X[0][0][0].reveal(), self.output.X[0][0][0].reveal()) + + if self.debug_output: + print_ln("forward layer multi_head_attention %s %s", self.multi_head_attention.Y[0][1][0].reveal(), sum(sum(self.multi_head_attention.Y[0].reveal()))) + # print_ln("forward layer multi_head_attention full %s", self.multi_head_attention.Y.reveal()) + + print("Forward Attention") + + batch_inc = regint.Array(len(batch)) + batch_inc.assign(regint.inc(len(batch))) + self.intermediate.X.address = self.multi_head_attention.Y.address + self.intermediate.forward(batch_inc) + + if self.debug_output: + print_ln("forward layer intermediate %s %s %s", self.intermediate.Y.shape, self.intermediate.Y[0][1][0:20].reveal(), sum(sum(self.intermediate.Y[0].reveal()))) + + print_ln(" ") + + self.output.X.address = self.intermediate.Y.address + self.output.forward(batch_inc, self.multi_head_attention.Y, training) + # self.output.Y.address = self.output.X.address + + if self.debug_output: + print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal()))) + # print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal()))) + # print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal()))) + + print_ln("our layer output %s %s %s %s", self.output.Y.address, len(self.Y[0].reveal()), self.output.Y[0][0][0:20].reveal(), sum(sum(self.output.Y[0].reveal()))) + # print_ln("shapes %s %s", self.Y.sizes, self.output.Y.sizes) + # print_ln("types %s %s %s %s %s %s", self.Y.value_type, self.output.Y.value_type, type(self.Y), type(self.output.Y), self, self.output) + + print("Forward BertLayer") + + def reset(self): + self.multi_head_attention.reset() + self.intermediate.reset() + self.output.reset() + + def load_state_dict(self, state_dict, input_via): + import numpy + # format of state_dict + # ['attention.self.query.weight', 'attention.self.query.bias', 'attention.self.key.weight', 'attention.self.key.bias', 'attention.self.value.weight', 'attention.self.value.bias', 'attention.output.dense.weight', 'attention.output.dense.bias', 'attention.output.LayerNorm.weight', 'attention.output.LayerNorm.bias', 'intermediate.dense.weight', 'intermediate.dense.bias', 'output.dense.weight', 'output.dense.bias', 'output.LayerNorm.weight', 'output.LayerNorm.bias'] + # set the values of the layers + self.multi_head_attention.wq.W = sfix.input_tensor_via(input_via, numpy.swapaxes(state_dict['attention.self.query.weight'], 0, 1)) + self.multi_head_attention.wq.b = sfix.input_tensor_via(input_via, state_dict['attention.self.query.bias']) + self.multi_head_attention.wk.W = sfix.input_tensor_via(input_via, numpy.swapaxes(state_dict['attention.self.key.weight'], 0, 1)) + self.multi_head_attention.wk.b = sfix.input_tensor_via(input_via, state_dict['attention.self.key.bias']) + self.multi_head_attention.wv.W = sfix.input_tensor_via(input_via, numpy.swapaxes(state_dict['attention.self.value.weight'], 0, 1)) + self.multi_head_attention.wv.b = sfix.input_tensor_via(input_via, state_dict['attention.self.value.bias']) + + self.multi_head_attention.output.dense.W = sfix.input_tensor_via(input_via, numpy.swapaxes(state_dict['attention.output.dense.weight'], 0, 1)) + self.multi_head_attention.output.dense.b = sfix.input_tensor_via(input_via, state_dict['attention.output.dense.bias']) + self.multi_head_attention.output.layer_norm.weights = sfix.input_tensor_via(input_via, state_dict['attention.output.LayerNorm.weight']) + self.multi_head_attention.output.layer_norm.bias = sfix.input_tensor_via(input_via, state_dict['attention.output.LayerNorm.bias']) + + self.intermediate.dense.W = sfix.input_tensor_via(input_via, numpy.swapaxes(state_dict['intermediate.dense.weight'], 0, 1)) + self.intermediate.dense.b = sfix.input_tensor_via(input_via, state_dict['intermediate.dense.bias']) + + self.output.dense.W = sfix.input_tensor_via(input_via, numpy.swapaxes(state_dict['output.dense.weight'], 0, 1)) + # print_ln("output.dense.W state_dict %s", self.output.dense.W[0][0].reveal()) + self.output.dense.b = sfix.input_tensor_via(input_via, state_dict['output.dense.bias']) + self.output.layer_norm.weights = sfix.input_tensor_via(input_via, state_dict['output.LayerNorm.weight']) + self.output.layer_norm.bias = sfix.input_tensor_via(input_via, state_dict['output.LayerNorm.bias']) + + def backward(self, compute_nabla_X=True, batch=None): + # layer.inputs[0].nabla_Y.address = \ + # layer.nabla_X.address + # assign nabla_X and Y + self.multi_head_attention.nabla_X.alloc() + self.intermediate.nabla_X.alloc() + self.output.nabla_X.alloc() + + self.output.nabla_Y.address = self.nabla_Y.address + self.intermediate.nabla_Y.address = self.output.nabla_X.address + self.multi_head_attention.nabla_Y.address = self.intermediate.nabla_X.address + # self.multi_head_attention.nabla_X.address = self.nabla_X.address + + nabla_y_multi_head_attention_from_layernorm = self.output.backward(True, batch) + # print_ln("Backward BertLayer.output.nabla_X %s", self.output.nabla_X.reveal_nested()[:8]) + self.intermediate.backward(True, batch) + + # residual, add it to Y because it gave the output of multihadattention to output + @multithread(self.n_threads, len(batch)) + def _(base, size): + self.multi_head_attention.nabla_Y.assign_part_vector( + self.multi_head_attention.nabla_Y.get_part_vector(base, size) + + nabla_y_multi_head_attention_from_layernorm.get_part_vector(base, size), base) + + if compute_nabla_X: + self.multi_head_attention.nabla_X.address = self.nabla_X.address + + nabla_y_hidden_state = self.multi_head_attention.backward(compute_nabla_X, batch) + + if compute_nabla_X: + print_ln("Bertlayer nabla_x %s %s", nabla_y_hidden_state.get_vector().reveal()[-8:], self.nabla_X.get_vector().reveal()[-8:]) + # and add hidden_state back to nabla_X, add to x because we gave x to multi_head_attention + @multithread(self.n_threads, len(batch)) + def _(base, size): + self.nabla_X.assign_part_vector( + self.nabla_X.get_part_vector(base, size) + + nabla_y_hidden_state.get_part_vector(base, size), base) + + +class BertIntermediate(BertBase): + + thetas = lambda self: self.dense.thetas() + nablas = lambda self: self.dense.nablas() + + def __init__(self, n_examples, hidden_size, intermediate_size, seq_len): + input_shape = [n_examples, seq_len, hidden_size] + output_shape = [n_examples, seq_len, intermediate_size] + super(BertIntermediate, self).__init__(input_shape, output_shape) + self.dense = Dense(n_examples, hidden_size, intermediate_size, seq_len) + self.activation = Gelu([n_examples, seq_len, intermediate_size]) + + + def forward(self, batch=None, training=None): + self.dense.X.address = self.X.address + self.activation.X.address = self.dense.Y.address + self.activation.Y.address = self.Y.address + + self.dense.forward(batch) + if self.debug_output: + print_ln("forward layer intermediate.dense %s", self.dense.Y[0][0][0:20].reveal()) + + self.activation._forward(batch) + + def reset(self): + self.dense.reset() + + def backward(self, compute_nabla_X=True, batch=None): + self.activation.nabla_X.alloc() + + # print_ln("Backward BertIntermediate.nabla_X %s", self.nabla_X.reveal_nested()[:8]) + + self.activation.nabla_Y.address = self.nabla_Y.address + self.dense.nabla_Y.address = self.activation.nabla_X.address + self.dense.nabla_X.address = self.nabla_X.address + + self.activation.backward(batch) + self.dense.backward(compute_nabla_X, batch) + + +class BertOutput(BertBase): + + thetas = lambda self: self.dense.thetas() + self.layer_norm.thetas() + nablas = lambda self: self.dense.nablas() + self.layer_norm.nablas() + + def __init__(self, n_examples, intermediate_size, hidden_size, seq_len, dropout=0.1, layernorm_eps=1e-12, rsqrt_approx=True): + input_shape = [n_examples, seq_len, intermediate_size] + output_shape = [n_examples, seq_len, hidden_size] + self.input_shape = input_shape + print("INSTANTIATING BERTOUTPUT with ", input_shape, output_shape, intermediate_size, hidden_size, rsqrt_approx) + super(BertOutput, self).__init__(input_shape, output_shape) + self.dense = Dense(n_examples, intermediate_size, hidden_size, seq_len) + self.layer_norm = LayerNorm(output_shape, layernorm_eps=layernorm_eps, approx=rsqrt_approx) + self.dropout = Dropout([n_examples, seq_len, hidden_size], alpha=dropout) + + + def forward(self, batch, input_tensor, training=False, input_tensor_batch=None): + # Because input_tensor might be the full training data shape + self.dense.X.address = self.X.address + self.dropout.X.address = self.dense.Y.address + self.layer_norm.X.address = self.dropout.Y.address + self.layer_norm.Y.address = self.Y.address + + self.dense.forward(batch) + if self.debug_output: + print_ln("forward layer output.dense %s", self.dense.Y[0][0][0:20].reveal()) + + self.dropout.forward(batch, training) + + if input_tensor_batch is not None: + input_tensor_batch_arr = MultiArray([len(batch), input_tensor.sizes[1], input_tensor.sizes[2]], sfix) + input_tensor_batch_arr.assign_vector(input_tensor.get_slice_vector(input_tensor_batch)) + @multithread(self.n_threads, len(batch)) + def _(base, size): + self.layer_norm.X.assign_part_vector( + self.layer_norm.X.get_part_vector(base, size) + + input_tensor_batch_arr.get_part_vector(base, size), base) + else: + @multithread(self.n_threads, len(batch)) + def _(base, size): + self.layer_norm.X.assign_part_vector( + self.layer_norm.X.get_part_vector(base, size) + + input_tensor.get_part_vector(base, size), base) + # if self.debug_output: + # print_ln("input tensor %s", input_tensor.reveal()) + + # self.layer_norm.X[:] += input_tensor[:] # TODO: is it maybe this addition since we take the last value? would be strange + + if self.debug_output: + print_ln("forward layer layer_norm_add %s", self.layer_norm.X[0][0][0:20].reveal()) + print_ln("") + self.layer_norm.forward(batch) + + + + def reset(self): + self.dense.reset() + + def backward(self, compute_nabla_X=True, batch=None): + self.layer_norm.nabla_X.alloc() + self.dropout.nabla_X.alloc() + + self.layer_norm.nabla_Y.address = self.nabla_Y.address + self.dropout.nabla_Y.address = self.layer_norm.nabla_X.address + self.dense.nabla_Y.address= self.dropout.nabla_X.address + self.dense.nabla_X.address = self.nabla_X.address + + # layer norm flows back to dropout but also to hidden_tensor... nabla hidden state? + + self.layer_norm.backward(batch, compute_nabla_X) + self.dropout.backward(compute_nabla_X, batch) + + if self.debug_output: + print_ln("backward layer dense x %s", self.dropout.nabla_X[0][0][0:20].reveal()) + + self.dense.backward(compute_nabla_X, batch) + + return self.layer_norm.nabla_X + +class MultiHeadAttention(BertBase): + + thetas = lambda self: self.wq.thetas() + self.wk.thetas() + self.wv.thetas() + self.output.thetas() + nablas = lambda self: self.wq.nablas() + self.wk.nablas() + self.wv.nablas() + self.output.nablas() + + def __init__(self, n_examples, seq_len, hidden_size, num_attention_heads, dropout=0.1, layernorm_eps=1e-12, rsqrt_approx=True, batch_size=None): + + # In the first layer the internal_shape is different from n_examples, afterwards it is the same + internal_shape = batch_size if batch_size is not None else n_examples + self.n_examples = internal_shape + + input_shape = [n_examples, seq_len, hidden_size] + output_shape = [internal_shape, seq_len, hidden_size] + super().__init__(input_shape, output_shape) + + print("Multheadattention", rsqrt_approx, input_shape, output_shape) + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.hidden_size = hidden_size + self.seq_len = seq_len + + self.hidden_size = hidden_size + self.wq = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len) + self.wk = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len) + self.wv = Dense(n_examples, hidden_size, self.all_head_size, self.seq_len) + self.dropout = Dropout([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], alpha=dropout) # I think? # TODO: DROPOUT? + + self.output = BertOutput(internal_shape, hidden_size, hidden_size, seq_len, dropout, layernorm_eps, rsqrt_approx) + self.context = sfix.Tensor([internal_shape, self.seq_len, hidden_size]) + self.nabla_context = sfix.Tensor([internal_shape, self.seq_len, hidden_size]) + + # self.context_nabla + + self.attention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix) + self.nabla_attention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix) + self.nabla_preattention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix) + + def forward(self, batch=None, hidden_state=None, training=None): + N = len(batch) + + # set up layers + dense_layers = [self.wq, self.wk, self.wv] + for layer in dense_layers: + layer.X.address = self.X.address + + self.output.X.address = self.context.address + self.output.Y.address = self.Y.address + + self.wq.forward(batch) + self.wk.forward(batch) + self.wv.forward(batch) + + inc_batch = regint.Array(N) + inc_batch.assign(regint.inc(N)) + + if self.debug_output: + # print_ln('forward layer wq full %s', self.wq.X.reveal()) + print_ln('forward layer wv %s %s', self.wv.Y[0][0][0:10].reveal(), sum(self.wv.Y[0][0].reveal())) + print_ln('forward layer hidden_state %s', hidden_state[0][1][0:10].reveal()) + # print_ln('forward layer wv full %s', self.wv.Y.reveal()) + + # max_size = program.budget // self.attention_head_size + @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads]) + def _(i, j): + # for j in range(self.num_attention_heads): + query_sub = sfix.Matrix(self.seq_len, self.attention_head_size) # this is mem inefficient? + key_sub = sfix.Matrix(self.seq_len, self.attention_head_size) + # print(self.wq.Y.shape, "wk Y shape", i, self.attention_head_size, j, self.wq.Y[i], self.wq.Y[i][:]) + + @for_range_opt(self.seq_len) + def _(k): + # for k in range(self.seq_len): + query_sub[k] = self.wq.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) + key_sub[k] = self.wk.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) + + # print_ln("query_sub %s %s", i, j) + res = query_sub.direct_mul_trans(key_sub) + self.attention_scores[i].assign_part_vector(res, j) + + if self.debug_output: + print_ln('forward layer attention_scores %s', self.attention_scores[0][0].reveal()) + # print_ln('forward layer attention_scores full %s', self.attention_scores.reveal()) + + @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads, self.seq_len]) + def _(i, j, k): + self.attention_scores[i][j][k][:] = self.attention_scores[i][j][k][:] / math.sqrt(self.attention_head_size) + self.attention_scores[i][j][k][:] = softmax(self.attention_scores[i][j][k][:]) + + self.dropout.X.address = self.attention_scores.address + self.dropout.forward(batch=inc_batch, training=training) + + if self.debug_output: + print_ln('forward layer dropout full %s', self.dropout.Y.reveal()) + + @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads]) + def _(i, j): + value_sub = sfix.Matrix(self.seq_len, self.attention_head_size) + + @for_range_opt([self.seq_len]) + def _(k): + value_sub[k] = self.wv.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) + # value_sub[k] = self.wv.Y[i][k][j * self.attention_head_size:(j + 1) * self.attention_head_size] + + res = sfix.Matrix(self.seq_len, self.attention_head_size) + res.assign_vector(self.dropout.Y[i][j].direct_mul(value_sub)) + # res = self.dropout.Y[i][j].direct_mul(value_sub) + + @for_range_opt([self.seq_len]) + def _(k): + self.context[i][k].assign_part_vector(res[k], + j * self.attention_head_size + ) + # for k in range(self.seq_len): + # self.context[i][k][j * self.attention_head_size:(j + 1) * self.attention_head_size] = res[k * self.attention_head_size:(k + 1) * self.attention_head_size] + + # How to transfer to forward? + + # missing half of the values ? + # print_ln('forward layer old_context %s', self.old_context[0].get_vector().reveal()) + if self.debug_output: + print_ln('forward layer multiheadattention before internal output %s', self.context[0][0][0:20].get_vector().reveal()) + + if self.debug_output: + print_ln('forward layer hidden_state %s', hidden_state[0][1][0:20].reveal()) + + self.output.forward(inc_batch, hidden_state, training, batch) + if self.debug_output: + print_ln('forward multiheadattention output %s', self.output.Y[0][0][0:20].reveal()) + print_ln("") + + # return context + + def reset(self): + self.wq.reset() + self.wk.reset() + self.wv.reset() + self.output.reset() + + def backward(self, compute_nabla_X=True, batch=None): + N = len(batch) + dense_layers = [self.wq, self.wk, self.wv] + for layer in dense_layers: + layer.nabla_Y.alloc() # we will fill them manually below + layer.nabla_X.alloc() # we have to add them up layer + + self.output.nabla_Y.address = self.nabla_Y.address + + self.output.nabla_X.alloc() + self.nabla_context.address = self.output.nabla_X.address + + self.nabla_attention_scores.address = self.dropout.nabla_X + + nabla_y_hidden_state = self.output.backward(True, batch) + + if self.debug_output: + print_ln("backward layer attention output.nabla_Y %s", self.output.nabla_Y.reveal_nested()[0][0][:8]) + print_ln("backward layer attention output.nabla_X %s", self.output.nabla_X.reveal_nested()[0][0][:8]) + + # Backprop context + @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads]) + def _(i, j): + res = sfix.Matrix(self.seq_len, self.attention_head_size) + value_sub = sfix.Matrix(self.seq_len, self.attention_head_size) + + @for_range_opt([self.seq_len]) + def _(k): + # dout_bth + res[k].assign_vector(self.nabla_context[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size)) # nabla_Y + # value_t2 + value_sub[k] = self.wv.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) + + nabla_value_sub = sfix.Matrix(self.seq_len, self.attention_head_size) + + # dvalue_t2 = dout_bth * att_bth + nabla_value_sub.assign_vector(self.dropout.Y[i][j].direct_trans_mul(res)) + # nabla_value_sub.assign_vector(self.context[i][j].direct_trans_mul(res)) + + # datt_bth = dout_bth * value_t2 + self.dropout.nabla_Y[i][j].assign_vector(res.direct_mul_trans(value_sub)) + + @for_range_opt([self.seq_len]) + def _(k): + # value_sub[k] = self.wv.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) + self.wv.nabla_Y[i][k].assign_part_vector( + nabla_value_sub[k], + j * self.attention_head_size) + + print("RES MULTI BACK", self.dropout.Y, res, self.num_attention_heads, self.attention_head_size) + + self.dropout.nabla_X.alloc() + self.dropout.backward(True, batch) + + if self.debug_output: + # Dropout nabla y is correct + # wv nabla_Y also correct + print_ln("backward layer attention dropout.nabla_Y %s", self.dropout.nabla_Y.reveal_nested()[:8]) + print_ln("backward layer attention wv.nabla_Y %s", self.wv.nabla_Y.reveal_nested()[:8]) + + # attention to pre + @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads, self.seq_len]) + def _(i, j, k): + @for_range_opt([self.seq_len, self.seq_len]) + def _(t1, t2): + indicator = cfix(t1 == t2) + # local_deriv = self.attention_scores[i][j][k][t1] * (indicator - self.attention_scores[i][j][k][t2]) + local_deriv = self.dropout.Y[i][j][k][t1] * (indicator - self.dropout.Y[i][j][k][t2]) + + # print_ln("indiciator %s %s %s %s %s %s", t1, t2, indicator, local_deriv.reveal(), self.dropout.Y[i][j][k][t1].reveal(), self.attention_scores[i][j][k][t2].reveal()) + self.nabla_preattention_scores[i][j][k][t2] += local_deriv * self.dropout.nabla_X[i][j][k][t1] # x or y? + + print_ln("attention_scores %s", self.attention_scores.reveal()) + print_ln("nabla preattention scores %s", self.nabla_preattention_scores.reveal()) + + scale = 1 / math.sqrt(self.attention_head_size) + # backward pass 1 + @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads]) + def _(i, j): + # for j in range(self.num_attention_heads): + query_sub = sfix.Matrix(self.seq_len, self.attention_head_size) + key_sub = sfix.Matrix(self.seq_len, self.attention_head_size) + # print(self.wq.Y.shape, "wk Y shape", i, self.attention_head_size, j, self.wq.Y[i], self.wq.Y[i][:]) + @for_range_opt(self.seq_len) + def _(k): # This mempcopy is ugly + query_sub[k] = self.wq.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) + key_sub[k] = self.wk.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) + + # nabla_query_sub = key_sub.direct_trans_mul(self.nabla_preattention_scores[i][j]) + # nabla_key_sub = self.nabla_preattention_scores[i][j].direct_mul_trans(key_sub) + + print_ln("preatt %s", self.nabla_preattention_scores[i][j].reveal()) + + nabla_query_sub = sfix.Matrix(self.seq_len, self.attention_head_size) + nabla_key_sub_trans = sfix.Matrix(self.attention_head_size, self.seq_len) + nabla_query_sub.assign_vector(self.nabla_preattention_scores[i][j].direct_trans_mul(key_sub) * scale) + nabla_key_sub_trans.assign_vector(query_sub.direct_trans_mul(self.nabla_preattention_scores[i][j]) * scale) + nabla_key_sub = nabla_key_sub_trans.transpose() + + print_ln("nabla query sub %s", nabla_query_sub.reveal()) + print_ln("nabla key sub %s", nabla_key_sub.reveal()) + + # nabla_key_sub is seq_len_seqlen, copy back into wk which is seq_len, all_head_size + @for_range_opt(self.seq_len) + def _(k): # This mempcopy is ugly? + self.wq.nabla_Y[i][k].assign_part_vector(nabla_query_sub[k], j * self.attention_head_size) + self.wk.nabla_Y[i][k].assign_part_vector(nabla_key_sub[k], j * self.attention_head_size) + + if self.debug_output: + print_ln("backward layer attention wq.nabla_Y %s", self.wq.nabla_Y.reveal_nested()[:8]) + + # wk slightly off + print_ln("backward layer attention wk.nabla_Y %s", self.wk.nabla_Y.reveal_nested()[:8]) + + self.wq.backward(compute_nabla_X, batch) + self.wk.backward(compute_nabla_X, batch) + self.wv.backward(compute_nabla_X, batch) + + @multithread(self.n_threads, len(batch)) + def _(base, size): + sum_layers = sum([layer.nabla_X.get_part_vector(base, size) for layer in dense_layers]) + self.nabla_X.assign_part_vector( + sum_layers, base) + + if self.debug_output: + # TODO: Wq seems off still + print_ln("backward layer attention wq.nabla_X %s", self.wq.nabla_X.reveal_nested()[:8]) + + return nabla_y_hidden_state + class Optimizer: """ Base class for graphs of layers. """ n_threads = Layer.n_threads @@ -2437,7 +3463,8 @@ class Optimizer: layer.backward(compute_nabla_X=False, batch=self.batch_for(layer, batch)) else: - layer.nabla_X.alloc() + if len(layer.inputs) == 1: + layer.nabla_X.alloc() layer.backward(batch=self.batch_for(layer, batch)) if len(layer.inputs) == 1: layer.inputs[0].nabla_Y.address = \ @@ -3227,8 +4254,8 @@ class keras: layers.append(FixAveragePool2d(input_shape, None, **layer[1])) input_shape = layers[-1].Y.sizes elif name == 'dropout': - layers.append(Dropout(batch_size, reduce( - operator.mul, layers[-1].Y.sizes[1:]), + layers.append(Dropout([batch_size] + [reduce( + operator.mul, layers[-1].Y.sizes[1:])], alpha=layer[1])) input_shape = layers[-1].Y.sizes elif name == 'flatten': @@ -3330,8 +4357,23 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, return reduce(operator.mul, x) import torch + import torch.fx + + # Custom tracer to prevent inlining BERT layers + class BertTracer(torch.fx.Tracer): + def is_leaf_module(self, m, module_qualified_name): + # Treat BertLayer, BertPooler, and BertEmbeddings as leaf modules (don't trace into them) + type_name = type(m).__name__ + if any(x in type_name for x in ['BertLayer', 'BertPooler', 'BertEmbeddings']): + return True + return super().is_leaf_module(m, module_qualified_name) def process(item, inputs, input_shape, args, kwargs={}): + # Skip assertion and validation functions from torch.fx trace + if callable(item) and hasattr(item, '__name__') and ( + item.__name__.startswith('_assert') or + item.__name__ in ('eq', 'getitem', 'size')): + return if item == torch.cat: if len(inputs) > 1: layers.append( @@ -3387,18 +4429,22 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, elif name == 'Conv2d': layers.append(easyConv2d(input_shape, batch_size, item.out_channels, item.kernel_size, item.stride, - item.padding, **layer_args.get(item, {}))) + item.padding, item.bias is not None, **layer_args.get(item, {}))) input_shape = layers[-1].Y.shape if input_via is not None: - shapes = [x.shape for x in + if item.bias is not None: + shapes = [x.shape for x in (layers[-1].weights, layers[-1].bias)] + else: + shapes = [layers[-1].weights.shape] + print("shapes", shapes, layers[-1].weights.shape) import numpy swapped = numpy.moveaxis( numpy.array(item.weight.detach()), 1, -1) layers[-1].weights = \ layers[-1].weights.value_type.input_tensor_via( input_via, swapped) - assert layers[-1].weights.shape == shapes[0] + assert layers[-1].weights.shape == shapes[0], f"{layers[-1].weights.shape} != {shapes[0]}" if isinstance(item.bias, torch.Tensor): layers[-1].bias = sfix.input_tensor_via( input_via, item.bias.detach()) @@ -3429,7 +4475,7 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, layers.append(Relu(input_shape)) elif name == 'Flatten': return - elif name == 'BatchNorm2d': + elif name == 'BatchNorm2d' or name == 'BatchNorm1d': layers.append(BatchNorm(layers[-1].Y.sizes)) if input_via is not None: layers[-1].epsilon = item.eps @@ -3437,18 +4483,124 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, item.weight.detach()) layers[-1].bias = sfix.input_tensor_via(input_via, item.bias.detach()) + layers[-1].mu_hat = sfix.input_tensor_via( + input_via, item.running_mean.detach()) + layers[-1].var_hat = sfix.input_tensor_via( + input_via, item.running_var.detach()) elif name == 'Dropout': - layers.append(Dropout(input_shape[0], mul(layers[-1].Y.sizes[1:]), - alpha=item.p)) + alpha = item.p + if alpha == 0.1: + print('WARNING: dropout rate 0.1 not supported, using 0.125') + alpha = 0.125 + layers.append(Dropout([input_shape[0]] + list(layers[-1].Y.sizes[1:]), + alpha=alpha)) input_shape = layers[-1].Y.sizes + elif name == 'BertForSequenceClassification': + process(item.bert) + process(item.dropout) + process(item.classifier) + elif name == 'BertModel': + bert_config = item.config + process(item.embeddings) + process(item.encoder) + process(item.pooler) + elif name == 'BertEmbeddings': + print('Embedding layer not implemented.', item) + pass # no-op + elif name == 'BertEncoder': + for x in item.layer: + process(x) + elif name == 'BertLayer': + # Get config from the model or item + if 'bert_config' in locals(): + config = bert_config + elif hasattr(model, 'config'): + config = model.config + elif hasattr(item, 'config'): + config = item.config + else: + raise CompilerError('BertLayer requires config but none found in model or item') + hidden_state = config.hidden_size + intermediate_size = config.intermediate_size + num_attention_heads = config.num_attention_heads + layernorm_eps = config.layer_norm_eps + seq_len = input_shape[1] + rsqrt_approx = False + layer = BertLayer(input_shape[0], seq_len, hidden_state, intermediate_size, num_attention_heads, + layernorm_eps, 0.125, rsqrt_approx, batch_size=batch_size) + if input_via is not None: + layer.load_state_dict(item.state_dict(), input_via) + layers.append(layer) + input_shape = [batch_size, seq_len, hidden_state] + elif name == 'BertPooler': + # Get config from the model or item + if 'bert_config' in locals(): + config = bert_config + elif hasattr(model, 'config'): + config = model.config + elif hasattr(item, 'config'): + config = item.config + else: + raise CompilerError('BertPooler requires config but none found in model or item') + layer = BertPooler(input_shape[0], input_shape[1], config.hidden_size) + if input_via is not None: + layer.load_state_dict(item.state_dict(), input_via) + layers.append(layer) + elif name == "Identity": + return else: raise CompilerError('unknown PyTorch module: %s' % item) layers[-1].inputs = inputs input_shape = data_input_shape + [1] * (4 - len(data_input_shape)) - torch_layers = list(torch.fx.symbolic_trace(model).graph.nodes) + # torch_layers = list(torch.fx.symbolic_trace(model).graph.nodes) + # Use custom tracer to keep BERT layers as modules + tracer = BertTracer() + + # Determine concrete_args based on model type + # Check if this is BertModel or BertEncoder + import torch as torch_module + model_type = type(model).__name__ + if model_type == 'BertModel': + # BertModel requires both input_ids and token_type_ids in concrete_args + # None means they must be provided as positional args during trace + concrete_args = { + "attention_mask": None, + "token_type_ids": None, + "position_ids": None, + "head_mask": None, + "inputs_embeds": None, + "encoder_hidden_states": None, + "encoder_attention_mask": None, + "past_key_values": None, + "use_cache": None, + "output_attentions": False, + "output_hidden_states": False, + "return_dict": False, + } + else: + # BertEncoder and other modules + concrete_args = { + "attention_mask": None, + "head_mask": None, + "encoder_hidden_states": None, + "encoder_attention_mask": None, + "past_key_values": None, + "use_cache": None, + "output_attentions": False, + "output_hidden_states": False, + "return_dict": False, + } + + graph = tracer.trace(model, concrete_args=concrete_args) + torch_layers = list(graph.nodes) + print(torch_layers) for i, layer in enumerate(torch_layers[1:-1]): + # Skip non-module and non-function operations (like getitem, assertions, etc.) + if layer.op not in ('call_module', 'call_function'): + continue + if layer.op == 'call_module': target = model for attr in layer.target.split('.'): @@ -3456,7 +4608,8 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, else: target = layer.target if not layers: - assert layer.args == (torch_layers[i],) + print(f"First layer check: layer.args={layer.args}, torch_layers[i]={torch_layers[i]}") + # First real layer - no need for assertion inputs = None else: if len(layer.args) < 2 or (layer.args[1] != 1 and diff --git a/Programs/Source/bert_inference.mpc b/Programs/Source/bert_inference.mpc new file mode 100644 index 00000000..39eec8f1 --- /dev/null +++ b/Programs/Source/bert_inference.mpc @@ -0,0 +1,378 @@ +""" +BERT Inference in MP-SPDZ + +The program: +1. Loads a pre-trained BERT-tiny model fine-tuned on QNLI +2. Converts it to MP-SPDZ representation using ml.layers_from_torch() +3. Runs inference on N samples from the validation set +4. Compares MP-SPDZ outputs with PyTorch outputs layer-by-layer +5. Computes and reports accuracy + +""" + +import ml +import torch +import torch.nn as nn +from transformers import BertForSequenceClassification, BertTokenizer +from datasets import load_dataset + +# ============================================================================ +# Configuration +# ============================================================================ + +MODEL_NAME = 'M-FAC/bert-tiny-finetuned-qnli' # BERT-tiny (2 layers, 128 hidden) +MAX_LENGTH = 64 # Maximum sequence length +N_SAMPLES = 25 # Number of samples to evaluate +BATCH_SIZE = 1 # Batch size for MPC inference (increase for better performance) + +# GLUE task configuration +TASK_NAME = 'qnli' +TASK_KEYS = { + "cola": ("sentence", None), + "mnli": ("premise", "hypothesis"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("question", "sentence"), + "qqp": ("question1", "question2"), + "rte": ("sentence1", "sentence2"), + "sst2": ("sentence", None), + "stsb": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + +# ============================================================================ +# Model Loading and Data Preparation +# ============================================================================ + +print(f"Loading model: {MODEL_NAME}") +tokenizer = BertTokenizer.from_pretrained(MODEL_NAME) +model = BertForSequenceClassification.from_pretrained(MODEL_NAME) +model.eval() + +print(f"Loading {TASK_NAME} dataset from GLUE benchmark") +dataset = load_dataset('glue', TASK_NAME) +validation = dataset['validation'].take(N_SAMPLES) + +print(f"Configuration:") +print(f" Model: {MODEL_NAME}") +print(f" Task: {TASK_NAME}") +print(f" Samples: {N_SAMPLES}") +print(f" Max length: {MAX_LENGTH}") +print(f" Batch size: {BATCH_SIZE}") +print(f" Model architecture: {model.config.num_hidden_layers} layers, " + f"{model.config.hidden_size} hidden size") + + +def tokenize_dataset(example): + """Tokenize dataset examples based on task configuration.""" + sentence1_key, sentence2_key = TASK_KEYS[TASK_NAME] + args = ( + (example[sentence1_key],) if sentence2_key is None + else (example[sentence1_key], example[sentence2_key]) + ) + return tokenizer(*args, truncation=True, padding='max_length', max_length=MAX_LENGTH) + + +def embed_inputs(example): + """Convert tokenized inputs to BERT embeddings.""" + input_ids = torch.tensor(example["input_ids"]) + token_type_ids = torch.tensor(example["token_type_ids"]) + embedding = model.bert.embeddings(input_ids, token_type_ids=token_type_ids).detach() + return {'embedding': embedding} + + +# Tokenize and embed the validation data +print("Tokenizing and embedding validation data...") +tokenized_data = validation.map(tokenize_dataset, batched=True) +embedded_data = tokenized_data.map(embed_inputs, batched=True) + +# ============================================================================ +# PyTorch Inference (Ground Truth) +# ============================================================================ + +print("\nRunning PyTorch inference for ground truth...") + + +def run_pytorch_inference(model, dataset, n_samples): + """Run inference using PyTorch and collect predictions.""" + model.eval() + predictions = [] + probabilities = [] + labels = [] + + with torch.no_grad(): + for i in range(n_samples): + example = dataset[i] + inputs = { + key: torch.tensor([val]) + for key, val in example.items() + if key in ['input_ids', 'attention_mask', 'token_type_ids'] + } + print("PT Inputs", inputs) + + outputs = model(**inputs) + logits = outputs.logits + probs = torch.softmax(logits, dim=-1) + predicted = torch.argmax(logits, dim=-1).item() + + predictions.append(predicted) + probabilities.append(probs.detach()) + labels.append(example['label']) + + return predictions, probabilities, labels + + +pt_predictions, pt_probabilities, true_labels = run_pytorch_inference(model, tokenized_data, N_SAMPLES) +pt_accuracy = sum(p == l for p, l in zip(pt_predictions, true_labels)) / len(true_labels) +print(f"PyTorch accuracy: {pt_accuracy:.4f} ({sum(p == l for p, l in zip(pt_predictions, true_labels))}/{len(true_labels)})") + +# ============================================================================ +# MP-SPDZ Model Conversion +# ============================================================================ + + +class BertEncoderWithHead(nn.Module): + """Wrapper combining BERT encoder, pooler, dropout, and classification head.""" + + def __init__(self, encoder, pooler, dropout, classifier, config): + super().__init__() + self.encoder = encoder + self.pooler = pooler + self.dropout = dropout + self.classifier = classifier + self.config = config + + def forward(self, hidden_states, attention_mask=None, head_mask=None, + encoder_hidden_states=None, encoder_attention_mask=None, + past_key_values=None, use_cache=None, output_attentions=False, + output_hidden_states=False, return_dict=False): + """Forward pass through encoder, pooler, dropout, and classifier.""" + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + return logits + + +# Build MPC-compatible input tensors +def build_mpc_tensors(dataset): + """Convert dataset to MPC-compatible sfix tensors.""" + with dataset.formatted_as("torch", ["embedding", "label"]): + embeddings = torch.concat(list(map(lambda x: x['embedding'], dataset.iter(batch_size=1)))) + labels = torch.tensor([x['label'] for x in dataset.iter(batch_size=1)]) + + # One-hot encode labels (2 classes for QNLI) + labels_onehot = torch.nn.functional.one_hot(labels, num_classes=2) + + embeddings_sfix = sfix.input_tensor_via(0, embeddings.numpy()) + labels_sfix = sfix.input_tensor_via(0, labels_onehot.numpy()) + + return embeddings_sfix, labels_sfix, labels.numpy() + + +test_embeddings, test_labels_onehot, test_labels = build_mpc_tensors(embedded_data) +model_shape = test_embeddings.shape +print(f"Input shape: {model_shape}") + +# Wrap model for conversion +bert_wrapped = BertEncoderWithHead( + model.bert.encoder, + model.bert.pooler, + model.dropout, + model.classifier, + model.config +) + +# Convert to MP-SPDZ layers +print("Tracing model and converting to MP-SPDZ layers...") +mpc_layers = ml.layers_from_torch(bert_wrapped, model_shape, input_via=0, batch_size=BATCH_SIZE) +print(f"MP-SPDZ model: {len(mpc_layers)} top-level layers") +for i, layer in enumerate(mpc_layers): + print(f" Layer {i}: {layer}") + +# ============================================================================ +# MP-SPDZ Inference +# ============================================================================ + +print("\nRunning MP-SPDZ inference...") + +# Configure fixed-point arithmetic +sfix.round_nearest = False +program.use_trunc_pr = False + +# Create optimizer (used for forward pass) +optimizer = ml.SGD(mpc_layers) + +# Run inference using optimizer.eval() to get predictions +print_ln("\n=== Starting MPC Inference ===") +print_ln("Samples: %s", N_SAMPLES) +print_ln("Batch size: %s", BATCH_SIZE) + +# Convert Python lists to compile-time constants +pt_preds_list = [int(p) for p in pt_predictions] +true_labels_list = [int(l) for l in true_labels] + +# Use optimizer.eval() to get MPC predictions (argmax) +print_ln("Running MPC inference...") +mpc_predictions = optimizer.eval(test_embeddings, batch_size=BATCH_SIZE, top=True) + +print_ln("\n=== Per-Sample Comparison ===") +print_ln("Sample | True Label | PyTorch Pred | MPC Pred | PT Correct | MPC Correct | Match") +print_ln("-" * 80) + +# Track statistics +n_correct = MemValue(regint(0)) +n_mpc_matches_pytorch = MemValue(regint(0)) + +# Use regular Python for loop to access compile-time constants +for i in range(N_SAMPLES): + # Get predictions + mpc_pred = mpc_predictions[i].reveal() + true_label = true_labels_list[i] + pt_pred = pt_preds_list[i] + + # Check correctness + mpc_correct = cint(mpc_pred == true_label) + pt_correct = cint(pt_pred == true_label) + predictions_match = cint(mpc_pred == pt_pred) + + # Update statistics + n_correct.iadd(mpc_correct) + n_mpc_matches_pytorch.iadd(predictions_match) + + # Print per-sample results + print_ln("%s | %s | %s | %s | %s | %s | %s", + i, true_label, pt_pred, mpc_pred, + pt_correct, mpc_correct, predictions_match) + +# Compute final statistics +mpc_accuracy = cfix(n_correct.read(), k=63, f=31) / N_SAMPLES +match_rate = cfix(n_mpc_matches_pytorch.read(), k=63, f=31) / N_SAMPLES + +print_ln("\n=== Results Summary ===") +print_ln("PyTorch Accuracy: %s", pt_accuracy) +print_ln("MP-SPDZ Correct: %s/%s", n_correct.read(), N_SAMPLES) +print_ln("MP-SPDZ Accuracy: %s", mpc_accuracy.reveal()) +print_ln("MPC-PyTorch Match: %s/%s = %s", + n_mpc_matches_pytorch.read(), N_SAMPLES, match_rate.reveal()) + +# ============================================================================ +# Layer-by-Layer Comparison using Forward Hooks +# ============================================================================ + +print_ln("\n=== Layer-by-Layer Comparison ===") + +# Map to store PyTorch activations +activation_map = {} + +def get_activation(name): + """Create a forward hook to capture layer outputs.""" + def hook(model, input, output): + if isinstance(output, tuple): + actual_output = output[0] + else: + actual_output = output + activation_map[name] = actual_output.detach() + return hook + +# Build layer comparison list +def layers_for_bertlayer(bert_layer_mpc, bert_layer_pt): + """Map MPC BertLayer components to PyTorch components.""" + return [ + (bert_layer_mpc.multi_head_attention, bert_layer_pt.attention), + (bert_layer_mpc.intermediate, bert_layer_pt.intermediate), + (bert_layer_mpc.output, bert_layer_pt.output), + (bert_layer_mpc, bert_layer_pt), + ] + +# Build complete layer comparison list +layers_to_compare = [layers_for_bertlayer(l1, l2) for l1, l2 in + zip(mpc_layers[:-4], model.bert.encoder.layer)] +layers_to_compare = [x for xs in layers_to_compare for x in xs] +layers_to_compare.append((mpc_layers[-4], model.bert.pooler)) +layers_to_compare.append((mpc_layers[-3], model.dropout)) +layers_to_compare.append((mpc_layers[-2], model.classifier)) + +# Register forward hooks +for layer_id, (_, pt_layer) in enumerate(layers_to_compare): + pt_layer.register_forward_hook(get_activation(f'{layer_id}.{type(pt_layer).__name__}')) + +# Run PyTorch forward pass to populate activation_map +print("Capturing PyTorch layer outputs...") +with torch.no_grad(): + for i in range(N_SAMPLES): + activation_map.clear() # Clear for each sample + + # Get sample embedding + with embedded_data.formatted_as("torch", ["embedding"]): + sample_embedding = embedded_data[i]['embedding'].unsqueeze(0) + + # Run forward through wrapped model + _ = bert_wrapped(sample_embedding) + + # Store activations for this sample + if i == 0: # Only compare first sample to save time + break + +print(f"Captured {len(activation_map)} layer outputs from PyTorch") + +# Run MPC forward pass using reveal_correctness +import numpy +pt_probs_tensor = numpy.array(numpy.concatenate([p.numpy() for p in pt_probabilities])) +pt_probabilities_sfix = sfix.input_tensor_via(0, pt_probs_tensor) + +test_embeddings_one = sfix.Tensor([1] + list(test_embeddings.sizes[1:])) +test_embeddings_one.assign(test_embeddings.get_part_vector(0)) + +pt_probabilities_sfix_one = sfix.Tensor([1] + list(pt_probabilities_sfix.sizes[1:])) +pt_probabilities_sfix_one.assign(pt_probabilities_sfix.get_part_vector(0)) + +print_ln("Running MPC forward pass for layer comparison...") +_ = optimizer.reveal_correctness(test_embeddings_one, pt_probabilities_sfix_one, batch_size=BATCH_SIZE) + +# Compare layers +print_ln("\nLayer-by-layer comparison (Sample 0 only):") +print_ln("=" * 100) + +for idx, (mpc_layer, pt_layer) in enumerate(layers_to_compare): + layer_id = f"{idx}.{type(pt_layer).__name__}" + + if layer_id not in activation_map: + continue + + # Skip dropout layers since they use different random masks + if 'Dropout' in type(pt_layer).__name__: + print_ln("%s | Skipped (dropout)", layer_id) + continue + + # Get PyTorch values + pt_values = activation_map[layer_id] + pt_at_runtime = sfix.input_tensor_via(0, pt_values.numpy()).get_vector().reveal() + + # Get MPC values + mpc_output = mpc_layer.Y[0].get_vector().reveal() + + # Compute detailed statistics + total_abs_diff = sum(abs(pt_at_runtime - mpc_output)) + pt_magnitude = sum(abs(pt_at_runtime)) + + # Print layer comparison + print_ln("\n%s", layer_id) + print_ln(" Shape: %s, Elements: %s", pt_values.shape, len(pt_at_runtime)) + print_ln(" Total Abs Diff: %s", total_abs_diff) + print_ln(" PT Total Magnitude: %s", pt_magnitude) + print_ln(" First 8 PT: %s", pt_at_runtime[:8]) + print_ln(" First 8 MPC: %s", mpc_output[:8]) + +print_ln("\n=== Inference Complete ===")