From 7b72a46fe289ca518c2f13d8f526c6c93c45f174 Mon Sep 17 00:00:00 2001 From: Hidde L Date: Wed, 17 Dec 2025 08:22:18 +0100 Subject: [PATCH] - Add layers like Gelu and LayerNorm explicitly to layers_from_torch so they are more generally usable (beyond BertLayer) - Fix Dense.__repr__ to include the missing d (sequence length) parameter - Default LayerNorm and BertLayer to use approximate rsqrt (approx=True) for better MPC performance - Remove debug prints - Fix SubMultiArray.__add__ to handle addition with plain arrays (check for sizes/size attributes) - Fix SubMultiArray.__str__ to use self.address instead of self.array._address to avoid attribute errors - Add LAYER_COMPARISON flag to bert_inference.mpc to skip the expensive layer-by-layer comparison section (~95% of compile time) by default --- Compiler/ml.py | 118 +++++++------------- Compiler/types.py | 7 +- Programs/Source/bert_inference.mpc | 168 +++++++++++++++-------------- 3 files changed, 132 insertions(+), 161 deletions(-) diff --git a/Compiler/ml.py b/Compiler/ml.py index 30059d5b..fe67a921 100644 --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -866,8 +866,8 @@ class Dense(DenseBase): self.f_input = self.Y def __repr__(self): - return '%s(%s, %s, %s, activation=%s)' % \ - (type(self).__name__, self.N, self.d_in, + return '%s(%s, %s, %s, %s, activation=%s)' % \ + (type(self).__name__, self.N, self.d_in, self.d, self.d_out, repr(self.activation)) def reset(self): @@ -1750,7 +1750,7 @@ class LayerNorm(Layer): # Changed class name thetas = lambda self: (self.weights, self.bias) nablas = lambda self: (self.nabla_weights, self.nabla_bias) - def __init__(self, shape, approx=False, layernorm_eps=None, args=None): + def __init__(self, shape, approx=True, layernorm_eps=None, args=None): if len(shape) == 2: shape = [shape[0], 1, shape[1]] # Not sure why this extra dimension is added tensors = (Tensor(shape, sfix) for i in range(4)) @@ -1805,6 +1805,7 @@ class LayerNorm(Layer): # Changed class name tmp = self.weights[:] * (sel_X[:] - mu_sel) * fac_sel # Removed self.mu reference sel_Y[:] = self.bias[:] + tmp + @_layer_method_call_tape def forward(self, batch, training=False): d = self.X.sizes[1] d_in = self.X.sizes[2] @@ -2660,7 +2661,7 @@ class BertBase(BaseLayer, FixBase): class BertPooler(BertBase): thetas = lambda self: self.dense.thetas() - nablas = lambda self: self.dense.nablas() # refer to downstream layers? + nablas = lambda self: self.dense.nablas() def __init__(self, n_examples, seq_len, hidden_state): input_shape = [n_examples, seq_len, hidden_state] @@ -2669,28 +2670,19 @@ class BertPooler(BertBase): self.dense = Dense(n_examples, hidden_state, hidden_state) self.activation = Tanh(output_shape) - self.d_out = hidden_state - - - def _forward(self, batch): - # self.dense.X.address = self.X.address self.activation.X.address = self.dense.Y.address self.activation.Y.address = self.Y.address - # grab the first repr? + self.d_out = hidden_state + + def _forward(self, batch): # batch contains [n_batch, n_heads, n_dim] @for_range(len(batch)) def _(j): self.dense.X[j][:] = self.X[batch[j]][0][:] - # if self.debug_output: - # print_ln("forward layer pooler.dense X %s", self.dense.X.reveal_nested()) - self.dense.forward(batch) - # print_ln("LINEAR Layer weights after bertpooler.dense: %s", self.opt.layers[-2].W.reveal_nested()) - - self.activation._forward(batch) - # print_ln("LINEAR Layer weights after bertpooler.activation: %s", self.opt.layers[-2].W.reveal_nested()) + self.activation.forward(batch) def reset(self): self.dense.reset() @@ -2757,36 +2749,21 @@ class BertLayer(BertBase): self.intermediate = BertIntermediate(internal_shape, hidden_state, intermediate_size, seq_len) self.output = BertOutput(internal_shape, intermediate_size, hidden_state, seq_len, dropout, layernorm_eps, rsqrt_approx) - self.hidden_state = sfix.Tensor(input_shape) # TODO: Could also make this smaller - # self.nabla_hidden_state = sfix.Tensor(input_shape) - # self.nabla_hidden_state.alloc() - - # self.X.address = self.multi_head_attention.X.address - # self.Y.address = self.output.Y.address - self.d_out = hidden_state - print("Init BertLayer", input_shape, output_shape) - + @_layer_method_call_tape def forward(self, batch, training=False): if batch is None: batch = Array.create_from(regint(0)) self.multi_head_attention._X.address = self.X.address self.output.Y.address = self.Y.address - self.hidden_state.address = self.X.address - # self.multi_head_attention.Y.address = self.Y.address - - self.multi_head_attention.forward(batch, self.hidden_state, training) - # if self.debug_output: - # print_ln("our layer X %s %s", self.X[0][0][0].reveal(), self.output.X[0][0][0].reveal()) + self.multi_head_attention.forward(batch, self.X, training) if self.debug_output: print_ln("forward layer multi_head_attention %s %s", self.multi_head_attention.Y[0][1][0].reveal(), sum(sum(self.multi_head_attention.Y[0].reveal()))) # print_ln("forward layer multi_head_attention full %s", self.multi_head_attention.Y.reveal()) - print("Forward Attention") - batch_inc = regint.Array(len(batch)) batch_inc.assign(regint.inc(len(batch))) self.intermediate.X.address = self.multi_head_attention.Y.address @@ -2794,7 +2771,6 @@ class BertLayer(BertBase): if self.debug_output: print_ln("forward layer intermediate %s %s %s", self.intermediate.Y.shape, self.intermediate.Y[0][1][0:20].reveal(), sum(sum(self.intermediate.Y[0].reveal()))) - print_ln(" ") self.output.X.address = self.intermediate.Y.address @@ -2803,14 +2779,7 @@ class BertLayer(BertBase): if self.debug_output: print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal()))) - # print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal()))) - # print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal()))) - print_ln("our layer output %s %s %s %s", self.output.Y.address, len(self.Y[0].reveal()), self.output.Y[0][0][0:20].reveal(), sum(sum(self.output.Y[0].reveal()))) - # print_ln("shapes %s %s", self.Y.sizes, self.output.Y.sizes) - # print_ln("types %s %s %s %s %s %s", self.Y.value_type, self.output.Y.value_type, type(self.Y), type(self.output.Y), self, self.output) - - print("Forward BertLayer") def reset(self): self.multi_head_attention.reset() @@ -2854,10 +2823,8 @@ class BertLayer(BertBase): self.output.nabla_Y.address = self.nabla_Y.address self.intermediate.nabla_Y.address = self.output.nabla_X.address self.multi_head_attention.nabla_Y.address = self.intermediate.nabla_X.address - # self.multi_head_attention.nabla_X.address = self.nabla_X.address nabla_y_multi_head_attention_from_layernorm = self.output.backward(True, batch) - # print_ln("Backward BertLayer.output.nabla_X %s", self.output.nabla_X.reveal_nested()[:8]) self.intermediate.backward(True, batch) # residual, add it to Y because it gave the output of multihadattention to output @@ -2894,7 +2861,11 @@ class BertIntermediate(BertBase): self.dense = Dense(n_examples, hidden_size, intermediate_size, seq_len) self.activation = Gelu([n_examples, seq_len, intermediate_size]) + self.dense.X.address = self.X.address + self.activation.X.address = self.dense.Y.address + self.activation.Y.address = self.Y.address + @_layer_method_call_tape def forward(self, batch=None, training=None): self.dense.X.address = self.X.address self.activation.X.address = self.dense.Y.address @@ -2904,7 +2875,7 @@ class BertIntermediate(BertBase): if self.debug_output: print_ln("forward layer intermediate.dense %s", self.dense.Y[0][0][0:20].reveal()) - self.activation._forward(batch) + self.activation.forward(batch) def reset(self): self.dense.reset() @@ -2912,8 +2883,6 @@ class BertIntermediate(BertBase): def backward(self, compute_nabla_X=True, batch=None): self.activation.nabla_X.alloc() - # print_ln("Backward BertIntermediate.nabla_X %s", self.nabla_X.reveal_nested()[:8]) - self.activation.nabla_Y.address = self.nabla_Y.address self.dense.nabla_Y.address = self.activation.nabla_X.address self.dense.nabla_X.address = self.nabla_X.address @@ -2931,13 +2900,13 @@ class BertOutput(BertBase): input_shape = [n_examples, seq_len, intermediate_size] output_shape = [n_examples, seq_len, hidden_size] self.input_shape = input_shape - print("INSTANTIATING BERTOUTPUT with ", input_shape, output_shape, intermediate_size, hidden_size, rsqrt_approx) super(BertOutput, self).__init__(input_shape, output_shape) self.dense = Dense(n_examples, intermediate_size, hidden_size, seq_len) self.layer_norm = LayerNorm(output_shape, layernorm_eps=layernorm_eps, approx=rsqrt_approx) self.dropout = Dropout([n_examples, seq_len, hidden_size], alpha=dropout) + @_layer_method_call_tape def forward(self, batch, input_tensor, training=False, input_tensor_batch=None): # Because input_tensor might be the full training data shape self.dense.X.address = self.X.address @@ -2965,18 +2934,12 @@ class BertOutput(BertBase): self.layer_norm.X.assign_part_vector( self.layer_norm.X.get_part_vector(base, size) + input_tensor.get_part_vector(base, size), base) - # if self.debug_output: - # print_ln("input tensor %s", input_tensor.reveal()) - - # self.layer_norm.X[:] += input_tensor[:] # TODO: is it maybe this addition since we take the last value? would be strange if self.debug_output: print_ln("forward layer layer_norm_add %s", self.layer_norm.X[0][0][0:20].reveal()) print_ln("") self.layer_norm.forward(batch) - - def reset(self): self.dense.reset() @@ -3039,6 +3002,7 @@ class MultiHeadAttention(BertBase): self.nabla_attention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix) self.nabla_preattention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix) + @_layer_method_call_tape def forward(self, batch=None, hidden_state=None, training=None): N = len(batch) @@ -3058,32 +3022,24 @@ class MultiHeadAttention(BertBase): inc_batch.assign(regint.inc(N)) if self.debug_output: - # print_ln('forward layer wq full %s', self.wq.X.reveal()) print_ln('forward layer wv %s %s', self.wv.Y[0][0][0:10].reveal(), sum(self.wv.Y[0][0].reveal())) print_ln('forward layer hidden_state %s', hidden_state[0][1][0:10].reveal()) - # print_ln('forward layer wv full %s', self.wv.Y.reveal()) - # max_size = program.budget // self.attention_head_size @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads]) def _(i, j): - # for j in range(self.num_attention_heads): query_sub = sfix.Matrix(self.seq_len, self.attention_head_size) # this is mem inefficient? key_sub = sfix.Matrix(self.seq_len, self.attention_head_size) - # print(self.wq.Y.shape, "wk Y shape", i, self.attention_head_size, j, self.wq.Y[i], self.wq.Y[i][:]) @for_range_opt(self.seq_len) def _(k): - # for k in range(self.seq_len): query_sub[k] = self.wq.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) key_sub[k] = self.wk.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) - # print_ln("query_sub %s %s", i, j) res = query_sub.direct_mul_trans(key_sub) self.attention_scores[i].assign_part_vector(res, j) if self.debug_output: print_ln('forward layer attention_scores %s', self.attention_scores[0][0].reveal()) - # print_ln('forward layer attention_scores full %s', self.attention_scores.reveal()) @for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads, self.seq_len]) def _(i, j, k): @@ -3103,7 +3059,6 @@ class MultiHeadAttention(BertBase): @for_range_opt([self.seq_len]) def _(k): value_sub[k] = self.wv.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size) - # value_sub[k] = self.wv.Y[i][k][j * self.attention_head_size:(j + 1) * self.attention_head_size] res = sfix.Matrix(self.seq_len, self.attention_head_size) res.assign_vector(self.dropout.Y[i][j].direct_mul(value_sub)) @@ -3114,13 +3069,7 @@ class MultiHeadAttention(BertBase): self.context[i][k].assign_part_vector(res[k], j * self.attention_head_size ) - # for k in range(self.seq_len): - # self.context[i][k][j * self.attention_head_size:(j + 1) * self.attention_head_size] = res[k * self.attention_head_size:(k + 1) * self.attention_head_size] - # How to transfer to forward? - - # missing half of the values ? - # print_ln('forward layer old_context %s', self.old_context[0].get_vector().reveal()) if self.debug_output: print_ln('forward layer multiheadattention before internal output %s', self.context[0][0][0:20].get_vector().reveal()) @@ -3132,8 +3081,6 @@ class MultiHeadAttention(BertBase): print_ln('forward multiheadattention output %s', self.output.Y[0][0][0:20].reveal()) print_ln("") - # return context - def reset(self): self.wq.reset() self.wk.reset() @@ -3189,8 +3136,6 @@ class MultiHeadAttention(BertBase): nabla_value_sub[k], j * self.attention_head_size) - print("RES MULTI BACK", self.dropout.Y, res, self.num_attention_heads, self.attention_head_size) - self.dropout.nabla_X.alloc() self.dropout.backward(True, batch) @@ -4402,10 +4347,20 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, raise CompilerError('multi-input layer %s not supported' % item) name = type(item).__name__ if name == 'Linear': - assert mul(input_shape[1:]) == item.in_features + # Precondition: the item assert item.bias is not None - layers.append(Dense(input_shape[0], item.in_features, - item.out_features)) + if mul(input_shape[1:]) == item.in_features: + layers.append(Dense(input_shape[0], item.in_features, + item.out_features)) + elif input_shape[-1] == item.in_features: + # we loop over all but last dimension + assert len(input_shape) == 3, "Dense only supports one extra dimension to loop over" + d = input_shape[1] + layers.append(Dense(input_shape[0], item.in_features, + item.out_features, d)) + else: + assert False, f"input shape {input_shape} incompatible with in_features {item.in_features}" + if input_via is not None: shapes = [x.shape for x in (layers[-1].W, layers[-1].b)] import numpy @@ -4476,6 +4431,15 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, input_shape = layers[-1].shape elif name == 'ReLU' or item == torch.nn.functional.relu: layers.append(Relu(input_shape)) + elif name == 'GeLU' or item == torch.nn.functional.gelu: + layers.append(Gelu(input_shape)) + elif name == 'LayerNorm': + layers.append(LayerNorm(input_shape, True, item.eps)) + if input_via is not None: + layers[-1].weights = sfix.input_tensor_via( + input_via, item.weight.detach()) + layers[-1].beta = sfix.input_tensor_via( + input_via, item.bias.detach()) elif name == 'Flatten': return elif name == 'BatchNorm2d' or name == 'BatchNorm1d': @@ -4528,7 +4492,7 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None, num_attention_heads = config.num_attention_heads layernorm_eps = config.layer_norm_eps seq_len = input_shape[1] - rsqrt_approx = False + rsqrt_approx = True layer = BertLayer(input_shape[0], seq_len, hidden_state, intermediate_size, num_attention_heads, layernorm_eps, 0.125, rsqrt_approx, batch_size=batch_size) if input_via is not None: diff --git a/Compiler/types.py b/Compiler/types.py index b1f15517..a5ba14ea 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -6984,7 +6984,10 @@ class SubMultiArray(_vectorizable): :return: container of same shape and type as :py:obj:`self` """ if is_zero(other): return self - assert self.sizes == other.sizes + if hasattr(other, 'sizes'): + assert self.sizes == other.sizes + if hasattr(other, 'size'): + assert self.total_size() == other.size return self.from_vector( self.sizes, self.get_vector() + other.get_vector()) @@ -7563,7 +7566,7 @@ class SubMultiArray(_vectorizable): def __str__(self): return '%s multi-array of lengths %s at %s' % ( self.value_type, self.sizes, - '' if self.array._address is None else self.address) + '' if self.address is None else self.address) __repr__ = __str__ class MultiArray(SubMultiArray): diff --git a/Programs/Source/bert_inference.mpc b/Programs/Source/bert_inference.mpc index 39eec8f1..e24524ac 100644 --- a/Programs/Source/bert_inference.mpc +++ b/Programs/Source/bert_inference.mpc @@ -20,10 +20,12 @@ from datasets import load_dataset # Configuration # ============================================================================ -MODEL_NAME = 'M-FAC/bert-tiny-finetuned-qnli' # BERT-tiny (2 layers, 128 hidden) +# Model sizes: M-FAC/bert-tiny-finetuned-qnli, M-FAC/bert-mini-finetuned-qnli, yoshitomo-matsubara/bert-large-uncased-qnli +MODEL_NAME = 'M-FAC/bert-tiny-finetuned-qnli' # BERT-mini (4 layers, 256 hidden) MAX_LENGTH = 64 # Maximum sequence length N_SAMPLES = 25 # Number of samples to evaluate BATCH_SIZE = 1 # Batch size for MPC inference (increase for better performance) +LAYER_COMPARISON = True # Set to False to skip layer-by-layer comparison (saves ~95% compile time) # GLUE task configuration TASK_NAME = 'qnli' @@ -269,110 +271,112 @@ print_ln("MPC-PyTorch Match: %s/%s = %s", # ============================================================================ # Layer-by-Layer Comparison using Forward Hooks +# (Generates ~95% of compile-time instructions. Disable for faster builds.) # ============================================================================ -print_ln("\n=== Layer-by-Layer Comparison ===") +if LAYER_COMPARISON: + print_ln("\n=== Layer-by-Layer Comparison ===") -# Map to store PyTorch activations -activation_map = {} + # Map to store PyTorch activations + activation_map = {} -def get_activation(name): - """Create a forward hook to capture layer outputs.""" - def hook(model, input, output): - if isinstance(output, tuple): - actual_output = output[0] - else: - actual_output = output - activation_map[name] = actual_output.detach() - return hook + def get_activation(name): + """Create a forward hook to capture layer outputs.""" + def hook(model, input, output): + if isinstance(output, tuple): + actual_output = output[0] + else: + actual_output = output + activation_map[name] = actual_output.detach() + return hook -# Build layer comparison list -def layers_for_bertlayer(bert_layer_mpc, bert_layer_pt): - """Map MPC BertLayer components to PyTorch components.""" - return [ - (bert_layer_mpc.multi_head_attention, bert_layer_pt.attention), - (bert_layer_mpc.intermediate, bert_layer_pt.intermediate), - (bert_layer_mpc.output, bert_layer_pt.output), - (bert_layer_mpc, bert_layer_pt), - ] + # Build layer comparison list + def layers_for_bertlayer(bert_layer_mpc, bert_layer_pt): + """Map MPC BertLayer components to PyTorch components.""" + return [ + (bert_layer_mpc.multi_head_attention, bert_layer_pt.attention), + (bert_layer_mpc.intermediate, bert_layer_pt.intermediate), + (bert_layer_mpc.output, bert_layer_pt.output), + (bert_layer_mpc, bert_layer_pt), + ] -# Build complete layer comparison list -layers_to_compare = [layers_for_bertlayer(l1, l2) for l1, l2 in - zip(mpc_layers[:-4], model.bert.encoder.layer)] -layers_to_compare = [x for xs in layers_to_compare for x in xs] -layers_to_compare.append((mpc_layers[-4], model.bert.pooler)) -layers_to_compare.append((mpc_layers[-3], model.dropout)) -layers_to_compare.append((mpc_layers[-2], model.classifier)) + # Build complete layer comparison list + layers_to_compare = [layers_for_bertlayer(l1, l2) for l1, l2 in + zip(mpc_layers[:-4], model.bert.encoder.layer)] + layers_to_compare = [x for xs in layers_to_compare for x in xs] + layers_to_compare.append((mpc_layers[-4], model.bert.pooler)) + layers_to_compare.append((mpc_layers[-3], model.dropout)) + layers_to_compare.append((mpc_layers[-2], model.classifier)) -# Register forward hooks -for layer_id, (_, pt_layer) in enumerate(layers_to_compare): - pt_layer.register_forward_hook(get_activation(f'{layer_id}.{type(pt_layer).__name__}')) + # Register forward hooks + for layer_id, (_, pt_layer) in enumerate(layers_to_compare): + pt_layer.register_forward_hook(get_activation(f'{layer_id}.{type(pt_layer).__name__}')) -# Run PyTorch forward pass to populate activation_map -print("Capturing PyTorch layer outputs...") -with torch.no_grad(): - for i in range(N_SAMPLES): - activation_map.clear() # Clear for each sample + # Run PyTorch forward pass to populate activation_map + print("Capturing PyTorch layer outputs...") + with torch.no_grad(): + for i in range(N_SAMPLES): + activation_map.clear() # Clear for each sample - # Get sample embedding - with embedded_data.formatted_as("torch", ["embedding"]): - sample_embedding = embedded_data[i]['embedding'].unsqueeze(0) + # Get sample embedding + with embedded_data.formatted_as("torch", ["embedding"]): + sample_embedding = embedded_data[i]['embedding'].unsqueeze(0) - # Run forward through wrapped model - _ = bert_wrapped(sample_embedding) + # Run forward through wrapped model + _ = bert_wrapped(sample_embedding) - # Store activations for this sample - if i == 0: # Only compare first sample to save time - break + # Store activations for this sample + if i == 0: # Only compare first sample to save time + break -print(f"Captured {len(activation_map)} layer outputs from PyTorch") + print(f"Captured {len(activation_map)} layer outputs from PyTorch") -# Run MPC forward pass using reveal_correctness -import numpy -pt_probs_tensor = numpy.array(numpy.concatenate([p.numpy() for p in pt_probabilities])) -pt_probabilities_sfix = sfix.input_tensor_via(0, pt_probs_tensor) + # Run MPC forward pass using reveal_correctness + import numpy + pt_probs_tensor = numpy.array(numpy.concatenate([p.numpy() for p in pt_probabilities])) + pt_probabilities_sfix = sfix.input_tensor_via(0, pt_probs_tensor) -test_embeddings_one = sfix.Tensor([1] + list(test_embeddings.sizes[1:])) -test_embeddings_one.assign(test_embeddings.get_part_vector(0)) + test_embeddings_one = sfix.Tensor([1] + list(test_embeddings.sizes[1:])) + test_embeddings_one.assign(test_embeddings.get_part_vector(0)) -pt_probabilities_sfix_one = sfix.Tensor([1] + list(pt_probabilities_sfix.sizes[1:])) -pt_probabilities_sfix_one.assign(pt_probabilities_sfix.get_part_vector(0)) + pt_probabilities_sfix_one = sfix.Tensor([1] + list(pt_probabilities_sfix.sizes[1:])) + pt_probabilities_sfix_one.assign(pt_probabilities_sfix.get_part_vector(0)) -print_ln("Running MPC forward pass for layer comparison...") -_ = optimizer.reveal_correctness(test_embeddings_one, pt_probabilities_sfix_one, batch_size=BATCH_SIZE) + print_ln("Running MPC forward pass for layer comparison...") + _ = optimizer.reveal_correctness(test_embeddings_one, pt_probabilities_sfix_one, batch_size=BATCH_SIZE) -# Compare layers -print_ln("\nLayer-by-layer comparison (Sample 0 only):") -print_ln("=" * 100) + # Compare layers + print_ln("\nLayer-by-layer comparison (Sample 0 only):") + print_ln("=" * 100) -for idx, (mpc_layer, pt_layer) in enumerate(layers_to_compare): - layer_id = f"{idx}.{type(pt_layer).__name__}" + for idx, (mpc_layer, pt_layer) in enumerate(layers_to_compare): + layer_id = f"{idx}.{type(pt_layer).__name__}" - if layer_id not in activation_map: - continue + if layer_id not in activation_map: + continue - # Skip dropout layers since they use different random masks - if 'Dropout' in type(pt_layer).__name__: - print_ln("%s | Skipped (dropout)", layer_id) - continue + # Skip dropout layers since they use different random masks + if 'Dropout' in type(pt_layer).__name__: + print_ln("%s | Skipped (dropout)", layer_id) + continue - # Get PyTorch values - pt_values = activation_map[layer_id] - pt_at_runtime = sfix.input_tensor_via(0, pt_values.numpy()).get_vector().reveal() + # Get PyTorch values + pt_values = activation_map[layer_id] + pt_at_runtime = sfix.input_tensor_via(0, pt_values.numpy()).get_vector().reveal() - # Get MPC values - mpc_output = mpc_layer.Y[0].get_vector().reveal() + # Get MPC values + mpc_output = mpc_layer.Y[0].get_vector().reveal() - # Compute detailed statistics - total_abs_diff = sum(abs(pt_at_runtime - mpc_output)) - pt_magnitude = sum(abs(pt_at_runtime)) + # Compute detailed statistics + total_abs_diff = sum(abs(pt_at_runtime - mpc_output)) + pt_magnitude = sum(abs(pt_at_runtime)) - # Print layer comparison - print_ln("\n%s", layer_id) - print_ln(" Shape: %s, Elements: %s", pt_values.shape, len(pt_at_runtime)) - print_ln(" Total Abs Diff: %s", total_abs_diff) - print_ln(" PT Total Magnitude: %s", pt_magnitude) - print_ln(" First 8 PT: %s", pt_at_runtime[:8]) - print_ln(" First 8 MPC: %s", mpc_output[:8]) + # Print layer comparison + print_ln("\n%s", layer_id) + print_ln(" Shape: %s, Elements: %s", pt_values.shape, len(pt_at_runtime)) + print_ln(" Total Abs Diff: %s", total_abs_diff) + print_ln(" PT Total Magnitude: %s", pt_magnitude) + print_ln(" First 8 PT: %s", pt_at_runtime[:8]) + print_ln(" First 8 MPC: %s", mpc_output[:8]) print_ln("\n=== Inference Complete ===")