avg diff, dropout

Author: Hidde L
Date: 2025-10-14 21:43:18 +02:00
Parent: f8139db805
Commit: d0f955c2d3
2 changed files with 11 additions and 28 deletions

File 1 of 2: ml module (FlexDropout, BertPooler, MultiHeadAttention)

@@ -1285,14 +1285,14 @@ class FlexDropout(NoVariableLayer):
         n_bits = -math.log(self.alpha, 2)
         assert n_bits == int(n_bits)
         n_bits = int(n_bits)
-        self.B.assign_all(1)
-        self.alpha = 0.0 # TODO: temp disable for reproducibility
-        # @for_range_opt_multithread(self.n_threads, len(batch))
-        # def _(i):
-        #     size = reduce(operator.mul, self.shape[1:])
-        #     self.B[i].assign_vector(util.tree_reduce(
-        #         util.or_op, (sint.get_random_bit(size=size)
-        #                      for i in range(n_bits))))
+        # self.B.assign_all(1)
+        # self.alpha = 0.0 # TODO: temp disable for reproducibility
+        @for_range_opt_multithread(self.n_threads, len(batch))
+        def _(i):
+            size = reduce(operator.mul, self.shape[1:])
+            self.B[i].assign_vector(util.tree_reduce(
+                util.or_op, (sint.get_random_bit(size=size)
+                             for i in range(n_bits))))
         @for_range_opt_multithread(self.n_threads, len(batch))
         def _(i):
             self.Y[i].assign_vector(1 / (1 - self.alpha) *
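Note on the re-enabled mask sampling: OR-ing n_bits uniform random bits leaves an entry at 0 with probability 2**-n_bits, which by construction equals alpha, and the forward pass then rescales kept activations by 1 / (1 - alpha) (inverted dropout). A minimal plain-numpy sketch of the same sampling logic, for illustration only (not MP-SPDZ code):

    import numpy as np

    def sample_dropout_mask(shape, alpha, rng=np.random.default_rng()):
        """Keep-mask where each entry is 0 with probability alpha = 2**-n_bits."""
        n_bits = -np.log2(alpha)
        assert n_bits == int(n_bits), "alpha must be a negative power of two"
        # OR of n_bits independent fair bits: zero only if all bits are zero.
        bits = rng.integers(0, 2, size=(int(n_bits),) + tuple(shape))
        return np.bitwise_or.reduce(bits, axis=0)

    mask = sample_dropout_mask((4, 8), alpha=0.125)   # ~12.5% of entries are 0
    y = mask * np.ones((4, 8)) / (1 - 0.125)          # inverted-dropout scaling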
@@ -2889,7 +2889,6 @@ class BertPooler(BertBase):
         # batch contains [n_batch, n_heads, n_dim]
         @for_range(len(batch))
         def _(j):
-            print_ln("Pooling %s %s", j, batch[j])
             self.dense.X[j][:] = self.X[batch[j]][0][:]
         # if self.debug_output:
@@ -3266,8 +3265,6 @@ class MultiHeadAttention(BertBase):
         inc_batch = regint.Array(N)
         inc_batch.assign(regint.inc(N))
-        print_ln("post forward")
         if self.debug_output:
-            # print_ln('forward layer wq full %s', self.wq.X.reveal())
             print_ln('forward layer wv %s %s', self.wv.Y[0][0][0:10].reveal(), sum(self.wv.Y[0][0].reveal()))

File 2 of 2: bert_inference program

@@ -1,10 +1,6 @@
"""
BERT Inference in MP-SPDZ
This program demonstrates secure multi-party computation (MPC) inference using a
pre-trained BERT model for sequence classification. It compares PyTorch and MP-SPDZ
implementations layer-by-layer and computes accuracy on the QNLI task from GLUE benchmark.
The program:
1. Loads a pre-trained BERT-tiny model fine-tuned on QNLI
2. Converts it to MP-SPDZ representation using ml.layers_from_torch()
@@ -12,14 +8,6 @@ The program:
 4. Compares MP-SPDZ outputs with PyTorch outputs layer by layer
 5. Computes and reports accuracy
-Usage:
-    ./Scripts/compile-run.py -E replicated-ring bert_inference
-
-Configuration:
-- MODEL_NAME: HuggingFace model identifier
-- MAX_LENGTH: Maximum sequence length for tokenization
-- N_SAMPLES: Number of validation samples to run
-- BATCH_SIZE: Batch size for MP-SPDZ inference
 """
 import ml
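To make steps 1 and 2 concrete, here is a hedged sketch of the load-and-convert phase. The transformers calls are standard, but the ml.layers_from_torch() argument layout and the input_shape value are assumptions for illustration, not taken from this file:

    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    import ml

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

    # Assumed shape: batch x sequence length x hidden size (128 for BERT-tiny).
    input_shape = (BATCH_SIZE, MAX_LENGTH, 128)
    layers = ml.layers_from_torch(model, input_shape, BATCH_SIZE)  # hypothetical arguments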
@@ -34,7 +22,7 @@ from datasets import load_dataset
 MODEL_NAME = 'M-FAC/bert-tiny-finetuned-qnli'  # BERT-tiny (2 layers, 128 hidden)
 MAX_LENGTH = 64  # Maximum sequence length
-N_SAMPLES = 10  # Number of samples to evaluate
+N_SAMPLES = 1  # Number of samples to evaluate
 BATCH_SIZE = 1  # Batch size for MPC inference (increase for better performance)
 
 # GLUE task configuration
@@ -141,8 +129,6 @@ print(f"PyTorch accuracy: {pt_accuracy:.4f} ({sum(p == l for p, l in zip(pt_pred
 # ============================================================================
 # MP-SPDZ Model Conversion
 # ============================================================================
 print("\nConverting BERT model to MP-SPDZ...")
-
-
 class BertEncoderWithHead(nn.Module):
     """Wrapper combining BERT encoder, pooler, dropout, and classification head."""
@@ -278,7 +264,7 @@ print_ln("\n=== Results Summary ===")
print_ln("PyTorch Accuracy: %s", pt_accuracy)
print_ln("MP-SPDZ Correct: %s/%s", n_correct.read(), N_SAMPLES)
print_ln("MP-SPDZ Accuracy: %s", mpc_accuracy.reveal())
print_ln("MPC-PyTorch Agreement: %s/%s = %s",
print_ln("MPC-PyTorch Match: %s/%s = %s",
n_mpc_matches_pytorch.read(), N_SAMPLES, match_rate.reveal())
# ============================================================================
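For reference, the three summary metrics compose as follows. This is a plain-Python restatement; the names mpc_preds, pt_preds, and labels are assumptions, and the actual program accumulates these with MemValue counters over secret-shared values inside the inference loop:

    n_correct = sum(int(m == gold) for m, gold in zip(mpc_preds, labels))
    n_mpc_matches_pytorch = sum(int(m == p) for m, p in zip(mpc_preds, pt_preds))
    mpc_accuracy = n_correct / N_SAMPLES
    match_rate = n_mpc_matches_pytorch / N_SAMPLES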
@@ -382,7 +368,7 @@ for idx, (mpc_layer, pt_layer) in enumerate(layers_to_compare):
     diff = sum(abs(pt_at_runtime - mpc_output))
 
     # Print layer comparison with first 8 values
-    print_ln("%s | Diff: %s", layer_id, diff)
+    print_ln("%s | Avg. Diff: %s", layer_id, diff / sum(pt_values.shape))
     print_ln("  PyTorch: %s", pt_at_runtime[:8])
     print_ln("  MP-SPDZ: %s", mpc_output[:8])