- Add layers like Gelu and LayerNorm explicitly to layers_from_torch so they are more generally usable (beyond BertLayer)

- Fix Dense.__repr__ to include the missing d (sequence length) parameter
- Default LayerNorm and BertLayer to use approximate rsqrt (approx=True) for better MPC performance
- Remove debug prints
- Fix SubMultiArray.__add__ to handle addition with plain arrays (check for sizes/size attributes)
- Fix SubMultiArray.__str__ to use self.address instead of self.array._address to avoid attribute errors
- Add LAYER_COMPARISON flag to bert_inference.mpc to skip the expensive layer-by-layer comparison section (~95% of compile time) by default
This commit is contained in:
Hidde L
2025-12-17 08:22:18 +01:00
parent b47c9bb6f8
commit 7b72a46fe2
3 changed files with 132 additions and 161 deletions

View File

@@ -866,8 +866,8 @@ class Dense(DenseBase):
self.f_input = self.Y
def __repr__(self):
return '%s(%s, %s, %s, activation=%s)' % \
(type(self).__name__, self.N, self.d_in,
return '%s(%s, %s, %s, %s, activation=%s)' % \
(type(self).__name__, self.N, self.d_in, self.d,
self.d_out, repr(self.activation))
def reset(self):
@@ -1750,7 +1750,7 @@ class LayerNorm(Layer): # Changed class name
thetas = lambda self: (self.weights, self.bias)
nablas = lambda self: (self.nabla_weights, self.nabla_bias)
def __init__(self, shape, approx=False, layernorm_eps=None, args=None):
def __init__(self, shape, approx=True, layernorm_eps=None, args=None):
if len(shape) == 2:
shape = [shape[0], 1, shape[1]] # Not sure why this extra dimension is added
tensors = (Tensor(shape, sfix) for i in range(4))
@@ -1805,6 +1805,7 @@ class LayerNorm(Layer): # Changed class name
tmp = self.weights[:] * (sel_X[:] - mu_sel) * fac_sel # Removed self.mu reference
sel_Y[:] = self.bias[:] + tmp
@_layer_method_call_tape
def forward(self, batch, training=False):
d = self.X.sizes[1]
d_in = self.X.sizes[2]
@@ -2660,7 +2661,7 @@ class BertBase(BaseLayer, FixBase):
class BertPooler(BertBase):
thetas = lambda self: self.dense.thetas()
nablas = lambda self: self.dense.nablas() # refer to downstream layers?
nablas = lambda self: self.dense.nablas()
def __init__(self, n_examples, seq_len, hidden_state):
input_shape = [n_examples, seq_len, hidden_state]
@@ -2669,28 +2670,19 @@ class BertPooler(BertBase):
self.dense = Dense(n_examples, hidden_state, hidden_state)
self.activation = Tanh(output_shape)
self.d_out = hidden_state
def _forward(self, batch):
# self.dense.X.address = self.X.address
self.activation.X.address = self.dense.Y.address
self.activation.Y.address = self.Y.address
# grab the first repr?
self.d_out = hidden_state
def _forward(self, batch):
# batch contains [n_batch, n_heads, n_dim]
@for_range(len(batch))
def _(j):
self.dense.X[j][:] = self.X[batch[j]][0][:]
# if self.debug_output:
# print_ln("forward layer pooler.dense X %s", self.dense.X.reveal_nested())
self.dense.forward(batch)
# print_ln("LINEAR Layer weights after bertpooler.dense: %s", self.opt.layers[-2].W.reveal_nested())
self.activation._forward(batch)
# print_ln("LINEAR Layer weights after bertpooler.activation: %s", self.opt.layers[-2].W.reveal_nested())
self.activation.forward(batch)
def reset(self):
self.dense.reset()
@@ -2757,36 +2749,21 @@ class BertLayer(BertBase):
self.intermediate = BertIntermediate(internal_shape, hidden_state, intermediate_size, seq_len)
self.output = BertOutput(internal_shape, intermediate_size, hidden_state, seq_len, dropout, layernorm_eps, rsqrt_approx)
self.hidden_state = sfix.Tensor(input_shape) # TODO: Could also make this smaller
# self.nabla_hidden_state = sfix.Tensor(input_shape)
# self.nabla_hidden_state.alloc()
# self.X.address = self.multi_head_attention.X.address
# self.Y.address = self.output.Y.address
self.d_out = hidden_state
print("Init BertLayer", input_shape, output_shape)
@_layer_method_call_tape
def forward(self, batch, training=False):
if batch is None:
batch = Array.create_from(regint(0))
self.multi_head_attention._X.address = self.X.address
self.output.Y.address = self.Y.address
self.hidden_state.address = self.X.address
# self.multi_head_attention.Y.address = self.Y.address
self.multi_head_attention.forward(batch, self.hidden_state, training)
# if self.debug_output:
# print_ln("our layer X %s %s", self.X[0][0][0].reveal(), self.output.X[0][0][0].reveal())
self.multi_head_attention.forward(batch, self.X, training)
if self.debug_output:
print_ln("forward layer multi_head_attention %s %s", self.multi_head_attention.Y[0][1][0].reveal(), sum(sum(self.multi_head_attention.Y[0].reveal())))
# print_ln("forward layer multi_head_attention full %s", self.multi_head_attention.Y.reveal())
print("Forward Attention")
batch_inc = regint.Array(len(batch))
batch_inc.assign(regint.inc(len(batch)))
self.intermediate.X.address = self.multi_head_attention.Y.address
@@ -2794,7 +2771,6 @@ class BertLayer(BertBase):
if self.debug_output:
print_ln("forward layer intermediate %s %s %s", self.intermediate.Y.shape, self.intermediate.Y[0][1][0:20].reveal(), sum(sum(self.intermediate.Y[0].reveal())))
print_ln(" ")
self.output.X.address = self.intermediate.Y.address
@@ -2803,14 +2779,7 @@ class BertLayer(BertBase):
if self.debug_output:
print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal())))
# print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal())))
# print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal())))
print_ln("our layer output %s %s %s %s", self.output.Y.address, len(self.Y[0].reveal()), self.output.Y[0][0][0:20].reveal(), sum(sum(self.output.Y[0].reveal())))
# print_ln("shapes %s %s", self.Y.sizes, self.output.Y.sizes)
# print_ln("types %s %s %s %s %s %s", self.Y.value_type, self.output.Y.value_type, type(self.Y), type(self.output.Y), self, self.output)
print("Forward BertLayer")
def reset(self):
self.multi_head_attention.reset()
@@ -2854,10 +2823,8 @@ class BertLayer(BertBase):
self.output.nabla_Y.address = self.nabla_Y.address
self.intermediate.nabla_Y.address = self.output.nabla_X.address
self.multi_head_attention.nabla_Y.address = self.intermediate.nabla_X.address
# self.multi_head_attention.nabla_X.address = self.nabla_X.address
nabla_y_multi_head_attention_from_layernorm = self.output.backward(True, batch)
# print_ln("Backward BertLayer.output.nabla_X %s", self.output.nabla_X.reveal_nested()[:8])
self.intermediate.backward(True, batch)
# residual, add it to Y because it gave the output of multihadattention to output
@@ -2894,7 +2861,11 @@ class BertIntermediate(BertBase):
self.dense = Dense(n_examples, hidden_size, intermediate_size, seq_len)
self.activation = Gelu([n_examples, seq_len, intermediate_size])
self.dense.X.address = self.X.address
self.activation.X.address = self.dense.Y.address
self.activation.Y.address = self.Y.address
@_layer_method_call_tape
def forward(self, batch=None, training=None):
self.dense.X.address = self.X.address
self.activation.X.address = self.dense.Y.address
@@ -2904,7 +2875,7 @@ class BertIntermediate(BertBase):
if self.debug_output:
print_ln("forward layer intermediate.dense %s", self.dense.Y[0][0][0:20].reveal())
self.activation._forward(batch)
self.activation.forward(batch)
def reset(self):
self.dense.reset()
@@ -2912,8 +2883,6 @@ class BertIntermediate(BertBase):
def backward(self, compute_nabla_X=True, batch=None):
self.activation.nabla_X.alloc()
# print_ln("Backward BertIntermediate.nabla_X %s", self.nabla_X.reveal_nested()[:8])
self.activation.nabla_Y.address = self.nabla_Y.address
self.dense.nabla_Y.address = self.activation.nabla_X.address
self.dense.nabla_X.address = self.nabla_X.address
@@ -2931,13 +2900,13 @@ class BertOutput(BertBase):
input_shape = [n_examples, seq_len, intermediate_size]
output_shape = [n_examples, seq_len, hidden_size]
self.input_shape = input_shape
print("INSTANTIATING BERTOUTPUT with ", input_shape, output_shape, intermediate_size, hidden_size, rsqrt_approx)
super(BertOutput, self).__init__(input_shape, output_shape)
self.dense = Dense(n_examples, intermediate_size, hidden_size, seq_len)
self.layer_norm = LayerNorm(output_shape, layernorm_eps=layernorm_eps, approx=rsqrt_approx)
self.dropout = Dropout([n_examples, seq_len, hidden_size], alpha=dropout)
@_layer_method_call_tape
def forward(self, batch, input_tensor, training=False, input_tensor_batch=None):
# Because input_tensor might be the full training data shape
self.dense.X.address = self.X.address
@@ -2965,18 +2934,12 @@ class BertOutput(BertBase):
self.layer_norm.X.assign_part_vector(
self.layer_norm.X.get_part_vector(base, size) +
input_tensor.get_part_vector(base, size), base)
# if self.debug_output:
# print_ln("input tensor %s", input_tensor.reveal())
# self.layer_norm.X[:] += input_tensor[:] # TODO: is it maybe this addition since we take the last value? would be strange
if self.debug_output:
print_ln("forward layer layer_norm_add %s", self.layer_norm.X[0][0][0:20].reveal())
print_ln("")
self.layer_norm.forward(batch)
def reset(self):
self.dense.reset()
@@ -3039,6 +3002,7 @@ class MultiHeadAttention(BertBase):
self.nabla_attention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix)
self.nabla_preattention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix)
@_layer_method_call_tape
def forward(self, batch=None, hidden_state=None, training=None):
N = len(batch)
@@ -3058,32 +3022,24 @@ class MultiHeadAttention(BertBase):
inc_batch.assign(regint.inc(N))
if self.debug_output:
# print_ln('forward layer wq full %s', self.wq.X.reveal())
print_ln('forward layer wv %s %s', self.wv.Y[0][0][0:10].reveal(), sum(self.wv.Y[0][0].reveal()))
print_ln('forward layer hidden_state %s', hidden_state[0][1][0:10].reveal())
# print_ln('forward layer wv full %s', self.wv.Y.reveal())
# max_size = program.budget // self.attention_head_size
@for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads])
def _(i, j):
# for j in range(self.num_attention_heads):
query_sub = sfix.Matrix(self.seq_len, self.attention_head_size) # this is mem inefficient?
key_sub = sfix.Matrix(self.seq_len, self.attention_head_size)
# print(self.wq.Y.shape, "wk Y shape", i, self.attention_head_size, j, self.wq.Y[i], self.wq.Y[i][:])
@for_range_opt(self.seq_len)
def _(k):
# for k in range(self.seq_len):
query_sub[k] = self.wq.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size)
key_sub[k] = self.wk.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size)
# print_ln("query_sub %s %s", i, j)
res = query_sub.direct_mul_trans(key_sub)
self.attention_scores[i].assign_part_vector(res, j)
if self.debug_output:
print_ln('forward layer attention_scores %s', self.attention_scores[0][0].reveal())
# print_ln('forward layer attention_scores full %s', self.attention_scores.reveal())
@for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads, self.seq_len])
def _(i, j, k):
@@ -3103,7 +3059,6 @@ class MultiHeadAttention(BertBase):
@for_range_opt([self.seq_len])
def _(k):
value_sub[k] = self.wv.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size)
# value_sub[k] = self.wv.Y[i][k][j * self.attention_head_size:(j + 1) * self.attention_head_size]
res = sfix.Matrix(self.seq_len, self.attention_head_size)
res.assign_vector(self.dropout.Y[i][j].direct_mul(value_sub))
@@ -3114,13 +3069,7 @@ class MultiHeadAttention(BertBase):
self.context[i][k].assign_part_vector(res[k],
j * self.attention_head_size
)
# for k in range(self.seq_len):
# self.context[i][k][j * self.attention_head_size:(j + 1) * self.attention_head_size] = res[k * self.attention_head_size:(k + 1) * self.attention_head_size]
# How to transfer to forward?
# missing half of the values ?
# print_ln('forward layer old_context %s', self.old_context[0].get_vector().reveal())
if self.debug_output:
print_ln('forward layer multiheadattention before internal output %s', self.context[0][0][0:20].get_vector().reveal())
@@ -3132,8 +3081,6 @@ class MultiHeadAttention(BertBase):
print_ln('forward multiheadattention output %s', self.output.Y[0][0][0:20].reveal())
print_ln("")
# return context
def reset(self):
self.wq.reset()
self.wk.reset()
@@ -3189,8 +3136,6 @@ class MultiHeadAttention(BertBase):
nabla_value_sub[k],
j * self.attention_head_size)
print("RES MULTI BACK", self.dropout.Y, res, self.num_attention_heads, self.attention_head_size)
self.dropout.nabla_X.alloc()
self.dropout.backward(True, batch)
@@ -4402,10 +4347,20 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None,
raise CompilerError('multi-input layer %s not supported' % item)
name = type(item).__name__
if name == 'Linear':
assert mul(input_shape[1:]) == item.in_features
# Precondition: the item
assert item.bias is not None
layers.append(Dense(input_shape[0], item.in_features,
item.out_features))
if mul(input_shape[1:]) == item.in_features:
layers.append(Dense(input_shape[0], item.in_features,
item.out_features))
elif input_shape[-1] == item.in_features:
# we loop over all but last dimension
assert len(input_shape) == 3, "Dense only supports one extra dimension to loop over"
d = input_shape[1]
layers.append(Dense(input_shape[0], item.in_features,
item.out_features, d))
else:
assert False, f"input shape {input_shape} incompatible with in_features {item.in_features}"
if input_via is not None:
shapes = [x.shape for x in (layers[-1].W, layers[-1].b)]
import numpy
@@ -4476,6 +4431,15 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None,
input_shape = layers[-1].shape
elif name == 'ReLU' or item == torch.nn.functional.relu:
layers.append(Relu(input_shape))
elif name == 'GeLU' or item == torch.nn.functional.gelu:
layers.append(Gelu(input_shape))
elif name == 'LayerNorm':
layers.append(LayerNorm(input_shape, True, item.eps))
if input_via is not None:
layers[-1].weights = sfix.input_tensor_via(
input_via, item.weight.detach())
layers[-1].beta = sfix.input_tensor_via(
input_via, item.bias.detach())
elif name == 'Flatten':
return
elif name == 'BatchNorm2d' or name == 'BatchNorm1d':
@@ -4528,7 +4492,7 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None,
num_attention_heads = config.num_attention_heads
layernorm_eps = config.layer_norm_eps
seq_len = input_shape[1]
rsqrt_approx = False
rsqrt_approx = True
layer = BertLayer(input_shape[0], seq_len, hidden_state, intermediate_size, num_attention_heads,
layernorm_eps, 0.125, rsqrt_approx, batch_size=batch_size)
if input_via is not None:

View File

@@ -6984,7 +6984,10 @@ class SubMultiArray(_vectorizable):
:return: container of same shape and type as :py:obj:`self` """
if is_zero(other):
return self
assert self.sizes == other.sizes
if hasattr(other, 'sizes'):
assert self.sizes == other.sizes
if hasattr(other, 'size'):
assert self.total_size() == other.size
return self.from_vector(
self.sizes, self.get_vector() + other.get_vector())
@@ -7563,7 +7566,7 @@ class SubMultiArray(_vectorizable):
def __str__(self):
return '%s multi-array of lengths %s at %s' % (
self.value_type, self.sizes,
'<unallocated>' if self.array._address is None else self.address)
'<unallocated>' if self.address is None else self.address)
__repr__ = __str__
class MultiArray(SubMultiArray):

View File

@@ -20,10 +20,12 @@ from datasets import load_dataset
# Configuration
# ============================================================================
MODEL_NAME = 'M-FAC/bert-tiny-finetuned-qnli' # BERT-tiny (2 layers, 128 hidden)
# Model sizes: M-FAC/bert-tiny-finetuned-qnli, M-FAC/bert-mini-finetuned-qnli, yoshitomo-matsubara/bert-large-uncased-qnli
MODEL_NAME = 'M-FAC/bert-tiny-finetuned-qnli' # BERT-mini (4 layers, 256 hidden)
MAX_LENGTH = 64 # Maximum sequence length
N_SAMPLES = 25 # Number of samples to evaluate
BATCH_SIZE = 1 # Batch size for MPC inference (increase for better performance)
LAYER_COMPARISON = True # Set to False to skip layer-by-layer comparison (saves ~95% compile time)
# GLUE task configuration
TASK_NAME = 'qnli'
@@ -269,110 +271,112 @@ print_ln("MPC-PyTorch Match: %s/%s = %s",
# ============================================================================
# Layer-by-Layer Comparison using Forward Hooks
# (Generates ~95% of compile-time instructions. Disable for faster builds.)
# ============================================================================
print_ln("\n=== Layer-by-Layer Comparison ===")
if LAYER_COMPARISON:
print_ln("\n=== Layer-by-Layer Comparison ===")
# Map to store PyTorch activations
activation_map = {}
# Map to store PyTorch activations
activation_map = {}
def get_activation(name):
"""Create a forward hook to capture layer outputs."""
def hook(model, input, output):
if isinstance(output, tuple):
actual_output = output[0]
else:
actual_output = output
activation_map[name] = actual_output.detach()
return hook
def get_activation(name):
"""Create a forward hook to capture layer outputs."""
def hook(model, input, output):
if isinstance(output, tuple):
actual_output = output[0]
else:
actual_output = output
activation_map[name] = actual_output.detach()
return hook
# Build layer comparison list
def layers_for_bertlayer(bert_layer_mpc, bert_layer_pt):
"""Map MPC BertLayer components to PyTorch components."""
return [
(bert_layer_mpc.multi_head_attention, bert_layer_pt.attention),
(bert_layer_mpc.intermediate, bert_layer_pt.intermediate),
(bert_layer_mpc.output, bert_layer_pt.output),
(bert_layer_mpc, bert_layer_pt),
]
# Build layer comparison list
def layers_for_bertlayer(bert_layer_mpc, bert_layer_pt):
"""Map MPC BertLayer components to PyTorch components."""
return [
(bert_layer_mpc.multi_head_attention, bert_layer_pt.attention),
(bert_layer_mpc.intermediate, bert_layer_pt.intermediate),
(bert_layer_mpc.output, bert_layer_pt.output),
(bert_layer_mpc, bert_layer_pt),
]
# Build complete layer comparison list
layers_to_compare = [layers_for_bertlayer(l1, l2) for l1, l2 in
zip(mpc_layers[:-4], model.bert.encoder.layer)]
layers_to_compare = [x for xs in layers_to_compare for x in xs]
layers_to_compare.append((mpc_layers[-4], model.bert.pooler))
layers_to_compare.append((mpc_layers[-3], model.dropout))
layers_to_compare.append((mpc_layers[-2], model.classifier))
# Build complete layer comparison list
layers_to_compare = [layers_for_bertlayer(l1, l2) for l1, l2 in
zip(mpc_layers[:-4], model.bert.encoder.layer)]
layers_to_compare = [x for xs in layers_to_compare for x in xs]
layers_to_compare.append((mpc_layers[-4], model.bert.pooler))
layers_to_compare.append((mpc_layers[-3], model.dropout))
layers_to_compare.append((mpc_layers[-2], model.classifier))
# Register forward hooks
for layer_id, (_, pt_layer) in enumerate(layers_to_compare):
pt_layer.register_forward_hook(get_activation(f'{layer_id}.{type(pt_layer).__name__}'))
# Register forward hooks
for layer_id, (_, pt_layer) in enumerate(layers_to_compare):
pt_layer.register_forward_hook(get_activation(f'{layer_id}.{type(pt_layer).__name__}'))
# Run PyTorch forward pass to populate activation_map
print("Capturing PyTorch layer outputs...")
with torch.no_grad():
for i in range(N_SAMPLES):
activation_map.clear() # Clear for each sample
# Run PyTorch forward pass to populate activation_map
print("Capturing PyTorch layer outputs...")
with torch.no_grad():
for i in range(N_SAMPLES):
activation_map.clear() # Clear for each sample
# Get sample embedding
with embedded_data.formatted_as("torch", ["embedding"]):
sample_embedding = embedded_data[i]['embedding'].unsqueeze(0)
# Get sample embedding
with embedded_data.formatted_as("torch", ["embedding"]):
sample_embedding = embedded_data[i]['embedding'].unsqueeze(0)
# Run forward through wrapped model
_ = bert_wrapped(sample_embedding)
# Run forward through wrapped model
_ = bert_wrapped(sample_embedding)
# Store activations for this sample
if i == 0: # Only compare first sample to save time
break
# Store activations for this sample
if i == 0: # Only compare first sample to save time
break
print(f"Captured {len(activation_map)} layer outputs from PyTorch")
print(f"Captured {len(activation_map)} layer outputs from PyTorch")
# Run MPC forward pass using reveal_correctness
import numpy
pt_probs_tensor = numpy.array(numpy.concatenate([p.numpy() for p in pt_probabilities]))
pt_probabilities_sfix = sfix.input_tensor_via(0, pt_probs_tensor)
# Run MPC forward pass using reveal_correctness
import numpy
pt_probs_tensor = numpy.array(numpy.concatenate([p.numpy() for p in pt_probabilities]))
pt_probabilities_sfix = sfix.input_tensor_via(0, pt_probs_tensor)
test_embeddings_one = sfix.Tensor([1] + list(test_embeddings.sizes[1:]))
test_embeddings_one.assign(test_embeddings.get_part_vector(0))
test_embeddings_one = sfix.Tensor([1] + list(test_embeddings.sizes[1:]))
test_embeddings_one.assign(test_embeddings.get_part_vector(0))
pt_probabilities_sfix_one = sfix.Tensor([1] + list(pt_probabilities_sfix.sizes[1:]))
pt_probabilities_sfix_one.assign(pt_probabilities_sfix.get_part_vector(0))
pt_probabilities_sfix_one = sfix.Tensor([1] + list(pt_probabilities_sfix.sizes[1:]))
pt_probabilities_sfix_one.assign(pt_probabilities_sfix.get_part_vector(0))
print_ln("Running MPC forward pass for layer comparison...")
_ = optimizer.reveal_correctness(test_embeddings_one, pt_probabilities_sfix_one, batch_size=BATCH_SIZE)
print_ln("Running MPC forward pass for layer comparison...")
_ = optimizer.reveal_correctness(test_embeddings_one, pt_probabilities_sfix_one, batch_size=BATCH_SIZE)
# Compare layers
print_ln("\nLayer-by-layer comparison (Sample 0 only):")
print_ln("=" * 100)
# Compare layers
print_ln("\nLayer-by-layer comparison (Sample 0 only):")
print_ln("=" * 100)
for idx, (mpc_layer, pt_layer) in enumerate(layers_to_compare):
layer_id = f"{idx}.{type(pt_layer).__name__}"
for idx, (mpc_layer, pt_layer) in enumerate(layers_to_compare):
layer_id = f"{idx}.{type(pt_layer).__name__}"
if layer_id not in activation_map:
continue
if layer_id not in activation_map:
continue
# Skip dropout layers since they use different random masks
if 'Dropout' in type(pt_layer).__name__:
print_ln("%s | Skipped (dropout)", layer_id)
continue
# Skip dropout layers since they use different random masks
if 'Dropout' in type(pt_layer).__name__:
print_ln("%s | Skipped (dropout)", layer_id)
continue
# Get PyTorch values
pt_values = activation_map[layer_id]
pt_at_runtime = sfix.input_tensor_via(0, pt_values.numpy()).get_vector().reveal()
# Get PyTorch values
pt_values = activation_map[layer_id]
pt_at_runtime = sfix.input_tensor_via(0, pt_values.numpy()).get_vector().reveal()
# Get MPC values
mpc_output = mpc_layer.Y[0].get_vector().reveal()
# Get MPC values
mpc_output = mpc_layer.Y[0].get_vector().reveal()
# Compute detailed statistics
total_abs_diff = sum(abs(pt_at_runtime - mpc_output))
pt_magnitude = sum(abs(pt_at_runtime))
# Compute detailed statistics
total_abs_diff = sum(abs(pt_at_runtime - mpc_output))
pt_magnitude = sum(abs(pt_at_runtime))
# Print layer comparison
print_ln("\n%s", layer_id)
print_ln(" Shape: %s, Elements: %s", pt_values.shape, len(pt_at_runtime))
print_ln(" Total Abs Diff: %s", total_abs_diff)
print_ln(" PT Total Magnitude: %s", pt_magnitude)
print_ln(" First 8 PT: %s", pt_at_runtime[:8])
print_ln(" First 8 MPC: %s", mpc_output[:8])
# Print layer comparison
print_ln("\n%s", layer_id)
print_ln(" Shape: %s, Elements: %s", pt_values.shape, len(pt_at_runtime))
print_ln(" Total Abs Diff: %s", total_abs_diff)
print_ln(" PT Total Magnitude: %s", pt_magnitude)
print_ln(" First 8 PT: %s", pt_at_runtime[:8])
print_ln(" First 8 MPC: %s", mpc_output[:8])
print_ln("\n=== Inference Complete ===")