Merge pull request #1736 from pps-lab/bert

Add BERT to ML Library
Marcel Keller
2025-11-07 15:06:37 +11:00
committed by GitHub
2 changed files with 1615 additions and 84 deletions



@@ -0,0 +1,378 @@
"""
BERT Inference in MP-SPDZ
The program:
1. Loads a pre-trained BERT-tiny model fine-tuned on QNLI
2. Converts it to MP-SPDZ representation using ml.layers_from_torch()
3. Runs inference on N samples from the validation set
4. Compares MP-SPDZ outputs with PyTorch outputs layer-by-layer
5. Computes and reports accuracy
"""
import ml
import numpy
import torch
import torch.nn as nn
from transformers import BertForSequenceClassification, BertTokenizer
from datasets import load_dataset
# ============================================================================
# Configuration
# ============================================================================
MODEL_NAME = 'M-FAC/bert-tiny-finetuned-qnli' # BERT-tiny (2 layers, 128 hidden)
MAX_LENGTH = 64 # Maximum sequence length
N_SAMPLES = 25 # Number of samples to evaluate
BATCH_SIZE = 1 # Batch size for MPC inference (increase for better performance)
# GLUE task configuration
TASK_NAME = 'qnli'
TASK_KEYS = {
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
"mrpc": ("sentence1", "sentence2"),
"qnli": ("question", "sentence"),
"qqp": ("question1", "question2"),
"rte": ("sentence1", "sentence2"),
"sst2": ("sentence", None),
"stsb": ("sentence1", "sentence2"),
"wnli": ("sentence1", "sentence2"),
}
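# Illustrative QNLI record (field names as in TASK_KEYS; the text is made up,
# label semantics per the GLUE dataset card):
#   {"question": "What is QNLI?",
#    "sentence": "QNLI is a sentence-pair task derived from SQuAD.",
#    "label": 0}   # 0 = entailment, 1 = not_entailment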
# ============================================================================
# Model Loading and Data Preparation
# ============================================================================
print(f"Loading model: {MODEL_NAME}")
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()
print(f"Loading {TASK_NAME} dataset from GLUE benchmark")
dataset = load_dataset('glue', TASK_NAME)
validation = dataset['validation'].take(N_SAMPLES)
print(f"Configuration:")
print(f" Model: {MODEL_NAME}")
print(f" Task: {TASK_NAME}")
print(f" Samples: {N_SAMPLES}")
print(f" Max length: {MAX_LENGTH}")
print(f" Batch size: {BATCH_SIZE}")
print(f" Model architecture: {model.config.num_hidden_layers} layers, "
f"{model.config.hidden_size} hidden size")
def tokenize_dataset(example):
"""Tokenize dataset examples based on task configuration."""
sentence1_key, sentence2_key = TASK_KEYS[TASK_NAME]
args = (
(example[sentence1_key],) if sentence2_key is None
else (example[sentence1_key], example[sentence2_key])
)
return tokenizer(*args, truncation=True, padding='max_length', max_length=MAX_LENGTH)
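# Padding every sequence to MAX_LENGTH keeps all inputs the same shape, which
# matters here: MP-SPDZ programs are compiled for fixed tensor sizes.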
def embed_inputs(example):
"""Convert tokenized inputs to BERT embeddings."""
input_ids = torch.tensor(example["input_ids"])
token_type_ids = torch.tensor(example["token_type_ids"])
embedding = model.bert.embeddings(input_ids, token_type_ids=token_type_ids).detach()
return {'embedding': embedding}
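# The embedding lookup runs in plaintext PyTorch rather than in MPC. Since
# party 0 owns the inputs anyway, it can compute the embeddings locally and
# secret-share the result; a likely further reason is that table lookups with
# secret indices are expensive under MPC, so the secure computation starts at
# the encoder (see BertEncoderWithHead below).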
# Tokenize and embed the validation data
print("Tokenizing and embedding validation data...")
tokenized_data = validation.map(tokenize_dataset, batched=True)
embedded_data = tokenized_data.map(embed_inputs, batched=True)
# ============================================================================
# PyTorch Inference (Ground Truth)
# ============================================================================
print("\nRunning PyTorch inference for ground truth...")
def run_pytorch_inference(model, dataset, n_samples):
"""Run inference using PyTorch and collect predictions."""
model.eval()
predictions = []
probabilities = []
labels = []
with torch.no_grad():
for i in range(n_samples):
example = dataset[i]
inputs = {
key: torch.tensor([val])
for key, val in example.items()
if key in ['input_ids', 'attention_mask', 'token_type_ids']
}
print("PT Inputs", inputs)
outputs = model(**inputs)
logits = outputs.logits
probs = torch.softmax(logits, dim=-1)
predicted = torch.argmax(logits, dim=-1).item()
predictions.append(predicted)
probabilities.append(probs.detach())
labels.append(example['label'])
return predictions, probabilities, labels
pt_predictions, pt_probabilities, true_labels = run_pytorch_inference(model, tokenized_data, N_SAMPLES)
pt_accuracy = sum(p == l for p, l in zip(pt_predictions, true_labels)) / len(true_labels)
print(f"PyTorch accuracy: {pt_accuracy:.4f} ({sum(p == l for p, l in zip(pt_predictions, true_labels))}/{len(true_labels)})")
# ============================================================================
# MP-SPDZ Model Conversion
# ============================================================================
class BertEncoderWithHead(nn.Module):
"""Wrapper combining BERT encoder, pooler, dropout, and classification head."""
def __init__(self, encoder, pooler, dropout, classifier, config):
super().__init__()
self.encoder = encoder
self.pooler = pooler
self.dropout = dropout
self.classifier = classifier
self.config = config
def forward(self, hidden_states, attention_mask=None, head_mask=None,
encoder_hidden_states=None, encoder_attention_mask=None,
past_key_values=None, use_cache=None, output_attentions=False,
output_hidden_states=False, return_dict=False):
"""Forward pass through encoder, pooler, dropout, and classifier."""
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return logits
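# Design note: the wrapper exposes a plain tensor-in/logits-out module, so
# ml.layers_from_torch() only has to trace encoder -> pooler -> dropout ->
# classifier. The embedding layer is deliberately left out: its output is
# computed in plaintext above and fed in as the secret input.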
# Build MPC-compatible input tensors
def build_mpc_tensors(dataset):
"""Convert dataset to MPC-compatible sfix tensors."""
with dataset.formatted_as("torch", ["embedding", "label"]):
embeddings = torch.concat(list(map(lambda x: x['embedding'], dataset.iter(batch_size=1))))
labels = torch.tensor([x['label'] for x in dataset.iter(batch_size=1)])
# One-hot encode labels (2 classes for QNLI)
labels_onehot = torch.nn.functional.one_hot(labels, num_classes=2)
embeddings_sfix = sfix.input_tensor_via(0, embeddings.numpy())
labels_sfix = sfix.input_tensor_via(0, labels_onehot.numpy())
return embeddings_sfix, labels_sfix, labels.numpy()
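# sfix.input_tensor_via(0, x) reads x as party 0's private input and
# secret-shares it among all parties. A minimal sketch of the same idiom on
# toy data (shape is illustrative):
#   toy = sfix.input_tensor_via(0, numpy.zeros((2, 3)))   # secret 2x3 tensor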
test_embeddings, test_labels_onehot, test_labels = build_mpc_tensors(embedded_data)
model_shape = test_embeddings.shape
print(f"Input shape: {model_shape}")
# Wrap model for conversion
bert_wrapped = BertEncoderWithHead(
model.bert.encoder,
model.bert.pooler,
model.dropout,
model.classifier,
model.config
)
# Convert to MP-SPDZ layers
print("Tracing model and converting to MP-SPDZ layers...")
mpc_layers = ml.layers_from_torch(bert_wrapped, model_shape, input_via=0, batch_size=BATCH_SIZE)
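# layers_from_torch() traces the PyTorch module and builds equivalent MP-SPDZ
# layers; input_via=0 means the model weights are likewise supplied as party
# 0's private input, so the other parties never see them.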
print(f"MP-SPDZ model: {len(mpc_layers)} top-level layers")
for i, layer in enumerate(mpc_layers):
print(f" Layer {i}: {layer}")
# ============================================================================
# MP-SPDZ Inference
# ============================================================================
print("\nRunning MP-SPDZ inference...")
# Configure fixed-point arithmetic
sfix.round_nearest = False
program.use_trunc_pr = False
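# Both flags trade accuracy against cost: sfix.round_nearest enables
# round-to-nearest truncation (more accurate, more expensive), and
# program.use_trunc_pr switches to probabilistic truncation (cheaper, noisier).
# Disabling both gives deterministic floor truncation. If more fractional
# precision were needed, something like the following should work (values are
# illustrative, not taken from this PR):
#   sfix.set_precision(24, 48)   # f=24 fractional bits, k=48 total bits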
# Create optimizer (used for forward pass)
optimizer = ml.SGD(mpc_layers)
# Run inference using optimizer.eval() to get predictions
print_ln("\n=== Starting MPC Inference ===")
print_ln("Samples: %s", N_SAMPLES)
print_ln("Batch size: %s", BATCH_SIZE)
# Convert Python lists to compile-time constants
pt_preds_list = [int(p) for p in pt_predictions]
true_labels_list = [int(l) for l in true_labels]
# Use optimizer.eval() to get MPC predictions (argmax)
print_ln("Running MPC inference...")
mpc_predictions = optimizer.eval(test_embeddings, batch_size=BATCH_SIZE, top=True)
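# Per the ml module, top=True makes eval() return the argmax over the final
# layer's outputs (one class index per sample) instead of raw scores.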
print_ln("\n=== Per-Sample Comparison ===")
print_ln("Sample | True Label | PyTorch Pred | MPC Pred | PT Correct | MPC Correct | Match")
print_ln("-" * 80)
# Track statistics
n_correct = MemValue(regint(0))
n_mpc_matches_pytorch = MemValue(regint(0))
# Use regular Python for loop to access compile-time constants
for i in range(N_SAMPLES):
# Get predictions
mpc_pred = mpc_predictions[i].reveal()
true_label = true_labels_list[i]
pt_pred = pt_preds_list[i]
# Check correctness
mpc_correct = cint(mpc_pred == true_label)
pt_correct = cint(pt_pred == true_label)
predictions_match = cint(mpc_pred == pt_pred)
# Update statistics
n_correct.iadd(mpc_correct)
n_mpc_matches_pytorch.iadd(predictions_match)
# Print per-sample results
print_ln("%s | %s | %s | %s | %s | %s | %s",
i, true_label, pt_pred, mpc_pred,
pt_correct, mpc_correct, predictions_match)
# Compute final statistics
mpc_accuracy = cfix(n_correct.read(), k=63, f=31) / N_SAMPLES
match_rate = cfix(n_mpc_matches_pytorch.read(), k=63, f=31) / N_SAMPLES
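# The counters are regint, so they are converted to clear fixed-point with a
# wide format (k=63 total bits, f=31 fractional bits) before dividing, which
# yields a fractional rate rather than integer division.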
print_ln("\n=== Results Summary ===")
print_ln("PyTorch Accuracy: %s", pt_accuracy)
print_ln("MP-SPDZ Correct: %s/%s", n_correct.read(), N_SAMPLES)
print_ln("MP-SPDZ Accuracy: %s", mpc_accuracy.reveal())
print_ln("MPC-PyTorch Match: %s/%s = %s",
n_mpc_matches_pytorch.read(), N_SAMPLES, match_rate.reveal())
# ============================================================================
# Layer-by-Layer Comparison using Forward Hooks
# ============================================================================
print_ln("\n=== Layer-by-Layer Comparison ===")
# Map to store PyTorch activations
activation_map = {}
def get_activation(name):
"""Create a forward hook to capture layer outputs."""
def hook(model, input, output):
if isinstance(output, tuple):
actual_output = output[0]
else:
actual_output = output
activation_map[name] = actual_output.detach()
return hook
# Build layer comparison list
def layers_for_bertlayer(bert_layer_mpc, bert_layer_pt):
"""Map MPC BertLayer components to PyTorch components."""
return [
(bert_layer_mpc.multi_head_attention, bert_layer_pt.attention),
(bert_layer_mpc.intermediate, bert_layer_pt.intermediate),
(bert_layer_mpc.output, bert_layer_pt.output),
(bert_layer_mpc, bert_layer_pt),
]
# Build complete layer comparison list
layers_to_compare = [layers_for_bertlayer(l1, l2) for l1, l2 in
zip(mpc_layers[:-4], model.bert.encoder.layer)]
layers_to_compare = [x for xs in layers_to_compare for x in xs]
layers_to_compare.append((mpc_layers[-4], model.bert.pooler))
layers_to_compare.append((mpc_layers[-3], model.dropout))
layers_to_compare.append((mpc_layers[-2], model.classifier))
# Register forward hooks
for layer_id, (_, pt_layer) in enumerate(layers_to_compare):
pt_layer.register_forward_hook(get_activation(f'{layer_id}.{type(pt_layer).__name__}'))
# Run PyTorch forward pass to populate activation_map
print("Capturing PyTorch layer outputs...")
with torch.no_grad():
    # Only sample 0 is compared below, so a single forward pass suffices
    activation_map.clear()
    with embedded_data.formatted_as("torch", ["embedding"]):
        sample_embedding = embedded_data[0]['embedding'].unsqueeze(0)
    # Run forward through the wrapped model to trigger the hooks
    _ = bert_wrapped(sample_embedding)
print(f"Captured {len(activation_map)} layer outputs from PyTorch")
# Run MPC forward pass using reveal_correctness
pt_probs_tensor = numpy.concatenate([p.numpy() for p in pt_probabilities])
pt_probabilities_sfix = sfix.input_tensor_via(0, pt_probs_tensor)
test_embeddings_one = sfix.Tensor([1] + list(test_embeddings.sizes[1:]))
test_embeddings_one.assign(test_embeddings.get_part_vector(0))
pt_probabilities_sfix_one = sfix.Tensor([1] + list(pt_probabilities_sfix.sizes[1:]))
pt_probabilities_sfix_one.assign(pt_probabilities_sfix.get_part_vector(0))
print_ln("Running MPC forward pass for layer comparison...")
_ = optimizer.reveal_correctness(test_embeddings_one, pt_probabilities_sfix_one, batch_size=BATCH_SIZE)
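# reveal_correctness() runs a forward pass and compares the revealed outputs
# against the expected values; it is used here mainly for its side effect of
# populating each layer's output tensor Y, which the comparison below reads.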
# Compare layers
print_ln("\nLayer-by-layer comparison (Sample 0 only):")
print_ln("=" * 100)
for idx, (mpc_layer, pt_layer) in enumerate(layers_to_compare):
layer_id = f"{idx}.{type(pt_layer).__name__}"
if layer_id not in activation_map:
continue
# Skip dropout layers since they use different random masks
if 'Dropout' in type(pt_layer).__name__:
print_ln("%s | Skipped (dropout)", layer_id)
continue
# Get PyTorch values
pt_values = activation_map[layer_id]
pt_at_runtime = sfix.input_tensor_via(0, pt_values.numpy()).get_vector().reveal()
# Get MPC values
mpc_output = mpc_layer.Y[0].get_vector().reveal()
# Compute detailed statistics
total_abs_diff = sum(abs(pt_at_runtime - mpc_output))
pt_magnitude = sum(abs(pt_at_runtime))
# Print layer comparison
print_ln("\n%s", layer_id)
print_ln(" Shape: %s, Elements: %s", pt_values.shape, len(pt_at_runtime))
print_ln(" Total Abs Diff: %s", total_abs_diff)
print_ln(" PT Total Magnitude: %s", pt_magnitude)
print_ln(" First 8 PT: %s", pt_at_runtime[:8])
print_ln(" First 8 MPC: %s", mpc_output[:8])
print_ln("\n=== Inference Complete ===")