mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
avg diff, dropout
@@ -1285,14 +1285,14 @@ class FlexDropout(NoVariableLayer):
         n_bits = -math.log(self.alpha, 2)
         assert n_bits == int(n_bits)
         n_bits = int(n_bits)
-        self.B.assign_all(1)
-        self.alpha = 0.0 # TODO: temp disable for reproducibility
-        # @for_range_opt_multithread(self.n_threads, len(batch))
-        # def _(i):
-        #     size = reduce(operator.mul, self.shape[1:])
-        #     self.B[i].assign_vector(util.tree_reduce(
-        #         util.or_op, (sint.get_random_bit(size=size)
-        #                      for i in range(n_bits))))
+        # self.B.assign_all(1)
+        # self.alpha = 0.0 # TODO: temp disable for reproducibility
+        @for_range_opt_multithread(self.n_threads, len(batch))
+        def _(i):
+            size = reduce(operator.mul, self.shape[1:])
+            self.B[i].assign_vector(util.tree_reduce(
+                util.or_op, (sint.get_random_bit(size=size)
+                             for i in range(n_bits))))
         @for_range_opt_multithread(self.n_threads, len(batch))
         def _(i):
             self.Y[i].assign_vector(1 / (1 - self.alpha) *
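The re-enabled block samples the dropout mask B from secret randomness: each mask position is the OR of n_bits independent random bits, so it is 0 (drop) with probability 2**-n_bits = alpha, which is why the assert requires alpha to be a negative power of two; the following loop then rescales kept activations by 1/(1 - alpha) (inverted dropout). A plaintext sketch of the same construction, with NumPy standing in for sint.get_random_bit:

import math
import numpy as np

def dropout_mask(alpha, size, rng=np.random.default_rng(0)):
    # OR of n_bits fair coins is 0 with probability 2**-n_bits == alpha.
    n_bits = -math.log(alpha, 2)
    assert n_bits == int(n_bits), "alpha must be a power of 1/2"
    bits = rng.integers(0, 2, size=(int(n_bits), size))
    return np.bitwise_or.reduce(bits, axis=0)  # 1 = keep, 0 = drop

mask = dropout_mask(alpha=0.25, size=100_000)
print(1 - mask.mean())  # empirical drop rate, ~0.25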
@@ -2889,7 +2889,6 @@ class BertPooler(BertBase):
         # batch contains [n_batch, n_heads, n_dim]
         @for_range(len(batch))
         def _(j):
-            print_ln("Pooling %s %s", j, batch[j])
             self.dense.X[j][:] = self.X[batch[j]][0][:]

         # if self.debug_output:
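For context, the copied value self.X[batch[j]][0] is the hidden state of the first ([CLS]) token of each sequence. A plaintext PyTorch sketch of the standard BERT pooler this layer mirrors (dense plus tanh over the first token; hidden size 128 matches BERT-tiny, everything else is an assumption, not code from this commit):

import torch
import torch.nn as nn

class Pooler(nn.Module):
    # Standard BERT pooler: dense layer with tanh over the [CLS] token.
    def __init__(self, hidden=128):
        super().__init__()
        self.dense = nn.Linear(hidden, hidden)

    def forward(self, hidden_states):   # [n_batch, seq_len, hidden]
        cls = hidden_states[:, 0]       # first token, as in X[batch[j]][0]
        return torch.tanh(self.dense(cls))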
@@ -3266,8 +3265,6 @@ class MultiHeadAttention(BertBase):
         inc_batch = regint.Array(N)
         inc_batch.assign(regint.inc(N))

-        print_ln("post forward")
-
         if self.debug_output:
             # print_ln('forward layer wq full %s', self.wq.X.reveal())
             print_ln('forward layer wv %s %s', self.wv.Y[0][0][0:10].reveal(), sum(self.wv.Y[0][0].reveal()))
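The surviving debug branch prints the first ten entries of wv's output together with their sum; since these values are secret-shared, they must be revealed before printing. A minimal sketch of this reveal-then-print pattern (MP-SPDZ program code; the array x is a hypothetical stand-in for self.wv.Y[0][0]):

x = sfix.Array(128)
x.assign_all(0.5)
# Revealing leaks the values to all parties, so this is for debugging only.
print_ln('first ten: %s', x[0:10].reveal())
print_ln('sum: %s', sum(x.reveal()))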
@@ -1,10 +1,6 @@
 """
 BERT Inference in MP-SPDZ

-This program demonstrates secure multi-party computation (MPC) inference using a
-pre-trained BERT model for sequence classification. It compares PyTorch and MP-SPDZ
-implementations layer-by-layer and computes accuracy on the QNLI task from GLUE benchmark.
-
 The program:
 1. Loads a pre-trained BERT-tiny model fine-tuned on QNLI
 2. Converts it to MP-SPDZ representation using ml.layers_from_torch()
@@ -12,14 +8,6 @@ The program:
 4. Compares MP-SPDZ outputs with PyTorch outputs layer-by-layer
 5. Computes and reports accuracy

-Usage:
-    ./Scripts/compile-run.py -E replicated-ring bert_inference
-
-Configuration:
-- MODEL_NAME: HuggingFace model identifier
-- MAX_LENGTH: Maximum sequence length for tokenization
-- N_SAMPLES: Number of validation samples to run
-- BATCH_SIZE: Batch size for MPC inference
 """

 import ml
@@ -34,7 +22,7 @@ from datasets import load_dataset

 MODEL_NAME = 'M-FAC/bert-tiny-finetuned-qnli' # BERT-tiny (2 layers, 128 hidden)
 MAX_LENGTH = 64 # Maximum sequence length
-N_SAMPLES = 10 # Number of samples to evaluate
+N_SAMPLES = 1 # Number of samples to evaluate
 BATCH_SIZE = 1 # Batch size for MPC inference (increase for better performance)

 # GLUE task configuration
@@ -141,8 +129,6 @@ print(f"PyTorch accuracy: {pt_accuracy:.4f} ({sum(p == l for p, l in zip(pt_pred
 # MP-SPDZ Model Conversion
 # ============================================================================

-print("\nConverting BERT model to MP-SPDZ...")
-

 class BertEncoderWithHead(nn.Module):
     """Wrapper combining BERT encoder, pooler, dropout, and classification head."""
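The diff shows only the wrapper's docstring. A hedged sketch of what such a wrapper might look like for a HuggingFace BertForSequenceClassification model (the attribute names follow the standard HuggingFace layout; the actual composition in this script is not visible in the diff):

import torch.nn as nn

class BertEncoderWithHead(nn.Module):
    """Wrapper combining BERT encoder, pooler, dropout, and classification head."""

    def __init__(self, model):
        super().__init__()
        self.encoder = model.bert.encoder     # transformer layers
        self.pooler = model.bert.pooler       # dense + tanh over [CLS]
        self.dropout = model.dropout
        self.classifier = model.classifier    # linear head on pooled output

    def forward(self, hidden_states):
        x = self.encoder(hidden_states).last_hidden_state
        return self.classifier(self.dropout(self.pooler(x)))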
@@ -278,7 +264,7 @@ print_ln("\n=== Results Summary ===")
 print_ln("PyTorch Accuracy: %s", pt_accuracy)
 print_ln("MP-SPDZ Correct: %s/%s", n_correct.read(), N_SAMPLES)
 print_ln("MP-SPDZ Accuracy: %s", mpc_accuracy.reveal())
-print_ln("MPC-PyTorch Agreement: %s/%s = %s",
+print_ln("MPC-PyTorch Match: %s/%s = %s",
          n_mpc_matches_pytorch.read(), N_SAMPLES, match_rate.reveal())

 # ============================================================================
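The renamed line reports how often the MPC prediction equals the PyTorch one. A hedged sketch of how such counters are typically accumulated in an MP-SPDZ program (only the printed names appear in the diff; the arrays mpc_preds and pt_preds and the clear-side arithmetic are assumptions, and the diff's match_rate.reveal() suggests the original computes the rate on secret values instead):

n_mpc_matches_pytorch = MemValue(regint(0))

@for_range(N_SAMPLES)
def _(i):
    # mpc_preds and pt_preds are hypothetical regint arrays of predictions.
    n_mpc_matches_pytorch.iadd(mpc_preds[i] == pt_preds[i])

match_rate = cfix(n_mpc_matches_pytorch.read()) / N_SAMPLES
print_ln("MPC-PyTorch Match: %s/%s = %s",
         n_mpc_matches_pytorch.read(), N_SAMPLES, match_rate)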
@@ -382,7 +368,7 @@ for idx, (mpc_layer, pt_layer) in enumerate(layers_to_compare):
     diff = sum(abs(pt_at_runtime - mpc_output))

     # Print layer comparison with first 8 values
-    print_ln("%s | Diff: %s", layer_id, diff)
+    print_ln("%s | Avg. Diff: %s", layer_id, diff / sum(pt_values.shape))
     print_ln(" PyTorch: %s", pt_at_runtime[:8])
     print_ln(" MP-SPDZ: %s", mpc_output[:8])

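The changed line turns the total absolute difference into an average. Note the normalization: sum(pt_values.shape) adds the dimensions rather than multiplying them, so for a (4, 128) tensor it divides by 132, not 512. A plaintext NumPy sketch of the comparison (names mirror the diff; the shapes are made up):

import numpy as np

def layer_avg_diff(pt_values, mpc_output):
    diff = np.abs(pt_values - mpc_output).sum()
    return diff / sum(pt_values.shape)   # sum of dims, as in the commit

pt = np.ones((4, 128))
mpc = pt + 0.01
print(layer_avg_diff(pt, mpc))  # 5.12 / 132 ≈ 0.0388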