mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-05-13 03:00:24 -04:00
- Add layers like Gelu and LayerNorm explicitly to layers_from_torch so they are more generally usable (beyond BertLayer)
- Fix Dense.__repr__ to include the missing d (sequence length) parameter - Default LayerNorm and BertLayer to use approximate rsqrt (approx=True) for better MPC performance - Remove debug prints - Fix SubMultiArray.__add__ to handle addition with plain arrays (check for sizes/size attributes) - Fix SubMultiArray.__str__ to use self.address instead of self.array._address to avoid attribute errors - Add LAYER_COMPARISON flag to bert_inference.mpc to skip the expensive layer-by-layer comparison section (~95% of compile time) by default
This commit is contained in:
118
Compiler/ml.py
118
Compiler/ml.py
@@ -866,8 +866,8 @@ class Dense(DenseBase):
|
||||
self.f_input = self.Y
|
||||
|
||||
def __repr__(self):
|
||||
return '%s(%s, %s, %s, activation=%s)' % \
|
||||
(type(self).__name__, self.N, self.d_in,
|
||||
return '%s(%s, %s, %s, %s, activation=%s)' % \
|
||||
(type(self).__name__, self.N, self.d_in, self.d,
|
||||
self.d_out, repr(self.activation))
|
||||
|
||||
def reset(self):
|
||||
@@ -1750,7 +1750,7 @@ class LayerNorm(Layer): # Changed class name
|
||||
thetas = lambda self: (self.weights, self.bias)
|
||||
nablas = lambda self: (self.nabla_weights, self.nabla_bias)
|
||||
|
||||
def __init__(self, shape, approx=False, layernorm_eps=None, args=None):
|
||||
def __init__(self, shape, approx=True, layernorm_eps=None, args=None):
|
||||
if len(shape) == 2:
|
||||
shape = [shape[0], 1, shape[1]] # Not sure why this extra dimension is added
|
||||
tensors = (Tensor(shape, sfix) for i in range(4))
|
||||
@@ -1805,6 +1805,7 @@ class LayerNorm(Layer): # Changed class name
|
||||
tmp = self.weights[:] * (sel_X[:] - mu_sel) * fac_sel # Removed self.mu reference
|
||||
sel_Y[:] = self.bias[:] + tmp
|
||||
|
||||
@_layer_method_call_tape
|
||||
def forward(self, batch, training=False):
|
||||
d = self.X.sizes[1]
|
||||
d_in = self.X.sizes[2]
|
||||
@@ -2660,7 +2661,7 @@ class BertBase(BaseLayer, FixBase):
|
||||
class BertPooler(BertBase):
|
||||
|
||||
thetas = lambda self: self.dense.thetas()
|
||||
nablas = lambda self: self.dense.nablas() # refer to downstream layers?
|
||||
nablas = lambda self: self.dense.nablas()
|
||||
|
||||
def __init__(self, n_examples, seq_len, hidden_state):
|
||||
input_shape = [n_examples, seq_len, hidden_state]
|
||||
@@ -2669,28 +2670,19 @@ class BertPooler(BertBase):
|
||||
self.dense = Dense(n_examples, hidden_state, hidden_state)
|
||||
self.activation = Tanh(output_shape)
|
||||
|
||||
self.d_out = hidden_state
|
||||
|
||||
|
||||
def _forward(self, batch):
|
||||
# self.dense.X.address = self.X.address
|
||||
self.activation.X.address = self.dense.Y.address
|
||||
self.activation.Y.address = self.Y.address
|
||||
|
||||
# grab the first repr?
|
||||
self.d_out = hidden_state
|
||||
|
||||
def _forward(self, batch):
|
||||
# batch contains [n_batch, n_heads, n_dim]
|
||||
@for_range(len(batch))
|
||||
def _(j):
|
||||
self.dense.X[j][:] = self.X[batch[j]][0][:]
|
||||
|
||||
# if self.debug_output:
|
||||
# print_ln("forward layer pooler.dense X %s", self.dense.X.reveal_nested())
|
||||
|
||||
self.dense.forward(batch)
|
||||
# print_ln("LINEAR Layer weights after bertpooler.dense: %s", self.opt.layers[-2].W.reveal_nested())
|
||||
|
||||
self.activation._forward(batch)
|
||||
# print_ln("LINEAR Layer weights after bertpooler.activation: %s", self.opt.layers[-2].W.reveal_nested())
|
||||
self.activation.forward(batch)
|
||||
|
||||
def reset(self):
|
||||
self.dense.reset()
|
||||
@@ -2757,36 +2749,21 @@ class BertLayer(BertBase):
|
||||
self.intermediate = BertIntermediate(internal_shape, hidden_state, intermediate_size, seq_len)
|
||||
self.output = BertOutput(internal_shape, intermediate_size, hidden_state, seq_len, dropout, layernorm_eps, rsqrt_approx)
|
||||
|
||||
self.hidden_state = sfix.Tensor(input_shape) # TODO: Could also make this smaller
|
||||
# self.nabla_hidden_state = sfix.Tensor(input_shape)
|
||||
# self.nabla_hidden_state.alloc()
|
||||
|
||||
# self.X.address = self.multi_head_attention.X.address
|
||||
# self.Y.address = self.output.Y.address
|
||||
|
||||
self.d_out = hidden_state
|
||||
|
||||
print("Init BertLayer", input_shape, output_shape)
|
||||
|
||||
@_layer_method_call_tape
|
||||
def forward(self, batch, training=False):
|
||||
if batch is None:
|
||||
batch = Array.create_from(regint(0))
|
||||
|
||||
self.multi_head_attention._X.address = self.X.address
|
||||
self.output.Y.address = self.Y.address
|
||||
self.hidden_state.address = self.X.address
|
||||
# self.multi_head_attention.Y.address = self.Y.address
|
||||
|
||||
self.multi_head_attention.forward(batch, self.hidden_state, training)
|
||||
# if self.debug_output:
|
||||
# print_ln("our layer X %s %s", self.X[0][0][0].reveal(), self.output.X[0][0][0].reveal())
|
||||
|
||||
self.multi_head_attention.forward(batch, self.X, training)
|
||||
if self.debug_output:
|
||||
print_ln("forward layer multi_head_attention %s %s", self.multi_head_attention.Y[0][1][0].reveal(), sum(sum(self.multi_head_attention.Y[0].reveal())))
|
||||
# print_ln("forward layer multi_head_attention full %s", self.multi_head_attention.Y.reveal())
|
||||
|
||||
print("Forward Attention")
|
||||
|
||||
batch_inc = regint.Array(len(batch))
|
||||
batch_inc.assign(regint.inc(len(batch)))
|
||||
self.intermediate.X.address = self.multi_head_attention.Y.address
|
||||
@@ -2794,7 +2771,6 @@ class BertLayer(BertBase):
|
||||
|
||||
if self.debug_output:
|
||||
print_ln("forward layer intermediate %s %s %s", self.intermediate.Y.shape, self.intermediate.Y[0][1][0:20].reveal(), sum(sum(self.intermediate.Y[0].reveal())))
|
||||
|
||||
print_ln(" ")
|
||||
|
||||
self.output.X.address = self.intermediate.Y.address
|
||||
@@ -2803,14 +2779,7 @@ class BertLayer(BertBase):
|
||||
|
||||
if self.debug_output:
|
||||
print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal())))
|
||||
# print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal())))
|
||||
# print_ln("our output %s %s %s %s", self.Y.address, len(self.Y[0].reveal()), self.Y[0][0][0:20].reveal(), sum(sum(self.Y[0].reveal())))
|
||||
|
||||
print_ln("our layer output %s %s %s %s", self.output.Y.address, len(self.Y[0].reveal()), self.output.Y[0][0][0:20].reveal(), sum(sum(self.output.Y[0].reveal())))
|
||||
# print_ln("shapes %s %s", self.Y.sizes, self.output.Y.sizes)
|
||||
# print_ln("types %s %s %s %s %s %s", self.Y.value_type, self.output.Y.value_type, type(self.Y), type(self.output.Y), self, self.output)
|
||||
|
||||
print("Forward BertLayer")
|
||||
|
||||
def reset(self):
|
||||
self.multi_head_attention.reset()
|
||||
@@ -2854,10 +2823,8 @@ class BertLayer(BertBase):
|
||||
self.output.nabla_Y.address = self.nabla_Y.address
|
||||
self.intermediate.nabla_Y.address = self.output.nabla_X.address
|
||||
self.multi_head_attention.nabla_Y.address = self.intermediate.nabla_X.address
|
||||
# self.multi_head_attention.nabla_X.address = self.nabla_X.address
|
||||
|
||||
nabla_y_multi_head_attention_from_layernorm = self.output.backward(True, batch)
|
||||
# print_ln("Backward BertLayer.output.nabla_X %s", self.output.nabla_X.reveal_nested()[:8])
|
||||
self.intermediate.backward(True, batch)
|
||||
|
||||
# residual, add it to Y because it gave the output of multihadattention to output
|
||||
@@ -2894,7 +2861,11 @@ class BertIntermediate(BertBase):
|
||||
self.dense = Dense(n_examples, hidden_size, intermediate_size, seq_len)
|
||||
self.activation = Gelu([n_examples, seq_len, intermediate_size])
|
||||
|
||||
self.dense.X.address = self.X.address
|
||||
self.activation.X.address = self.dense.Y.address
|
||||
self.activation.Y.address = self.Y.address
|
||||
|
||||
@_layer_method_call_tape
|
||||
def forward(self, batch=None, training=None):
|
||||
self.dense.X.address = self.X.address
|
||||
self.activation.X.address = self.dense.Y.address
|
||||
@@ -2904,7 +2875,7 @@ class BertIntermediate(BertBase):
|
||||
if self.debug_output:
|
||||
print_ln("forward layer intermediate.dense %s", self.dense.Y[0][0][0:20].reveal())
|
||||
|
||||
self.activation._forward(batch)
|
||||
self.activation.forward(batch)
|
||||
|
||||
def reset(self):
|
||||
self.dense.reset()
|
||||
@@ -2912,8 +2883,6 @@ class BertIntermediate(BertBase):
|
||||
def backward(self, compute_nabla_X=True, batch=None):
|
||||
self.activation.nabla_X.alloc()
|
||||
|
||||
# print_ln("Backward BertIntermediate.nabla_X %s", self.nabla_X.reveal_nested()[:8])
|
||||
|
||||
self.activation.nabla_Y.address = self.nabla_Y.address
|
||||
self.dense.nabla_Y.address = self.activation.nabla_X.address
|
||||
self.dense.nabla_X.address = self.nabla_X.address
|
||||
@@ -2931,13 +2900,13 @@ class BertOutput(BertBase):
|
||||
input_shape = [n_examples, seq_len, intermediate_size]
|
||||
output_shape = [n_examples, seq_len, hidden_size]
|
||||
self.input_shape = input_shape
|
||||
print("INSTANTIATING BERTOUTPUT with ", input_shape, output_shape, intermediate_size, hidden_size, rsqrt_approx)
|
||||
super(BertOutput, self).__init__(input_shape, output_shape)
|
||||
self.dense = Dense(n_examples, intermediate_size, hidden_size, seq_len)
|
||||
self.layer_norm = LayerNorm(output_shape, layernorm_eps=layernorm_eps, approx=rsqrt_approx)
|
||||
self.dropout = Dropout([n_examples, seq_len, hidden_size], alpha=dropout)
|
||||
|
||||
|
||||
@_layer_method_call_tape
|
||||
def forward(self, batch, input_tensor, training=False, input_tensor_batch=None):
|
||||
# Because input_tensor might be the full training data shape
|
||||
self.dense.X.address = self.X.address
|
||||
@@ -2965,18 +2934,12 @@ class BertOutput(BertBase):
|
||||
self.layer_norm.X.assign_part_vector(
|
||||
self.layer_norm.X.get_part_vector(base, size) +
|
||||
input_tensor.get_part_vector(base, size), base)
|
||||
# if self.debug_output:
|
||||
# print_ln("input tensor %s", input_tensor.reveal())
|
||||
|
||||
# self.layer_norm.X[:] += input_tensor[:] # TODO: is it maybe this addition since we take the last value? would be strange
|
||||
|
||||
if self.debug_output:
|
||||
print_ln("forward layer layer_norm_add %s", self.layer_norm.X[0][0][0:20].reveal())
|
||||
print_ln("")
|
||||
self.layer_norm.forward(batch)
|
||||
|
||||
|
||||
|
||||
def reset(self):
|
||||
self.dense.reset()
|
||||
|
||||
@@ -3039,6 +3002,7 @@ class MultiHeadAttention(BertBase):
|
||||
self.nabla_attention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix)
|
||||
self.nabla_preattention_scores = MultiArray([internal_shape, self.num_attention_heads, self.seq_len, self.seq_len], sfix)
|
||||
|
||||
@_layer_method_call_tape
|
||||
def forward(self, batch=None, hidden_state=None, training=None):
|
||||
N = len(batch)
|
||||
|
||||
@@ -3058,32 +3022,24 @@ class MultiHeadAttention(BertBase):
|
||||
inc_batch.assign(regint.inc(N))
|
||||
|
||||
if self.debug_output:
|
||||
# print_ln('forward layer wq full %s', self.wq.X.reveal())
|
||||
print_ln('forward layer wv %s %s', self.wv.Y[0][0][0:10].reveal(), sum(self.wv.Y[0][0].reveal()))
|
||||
print_ln('forward layer hidden_state %s', hidden_state[0][1][0:10].reveal())
|
||||
# print_ln('forward layer wv full %s', self.wv.Y.reveal())
|
||||
|
||||
# max_size = program.budget // self.attention_head_size
|
||||
@for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads])
|
||||
def _(i, j):
|
||||
# for j in range(self.num_attention_heads):
|
||||
query_sub = sfix.Matrix(self.seq_len, self.attention_head_size) # this is mem inefficient?
|
||||
key_sub = sfix.Matrix(self.seq_len, self.attention_head_size)
|
||||
# print(self.wq.Y.shape, "wk Y shape", i, self.attention_head_size, j, self.wq.Y[i], self.wq.Y[i][:])
|
||||
|
||||
@for_range_opt(self.seq_len)
|
||||
def _(k):
|
||||
# for k in range(self.seq_len):
|
||||
query_sub[k] = self.wq.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size)
|
||||
key_sub[k] = self.wk.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size)
|
||||
|
||||
# print_ln("query_sub %s %s", i, j)
|
||||
res = query_sub.direct_mul_trans(key_sub)
|
||||
self.attention_scores[i].assign_part_vector(res, j)
|
||||
|
||||
if self.debug_output:
|
||||
print_ln('forward layer attention_scores %s', self.attention_scores[0][0].reveal())
|
||||
# print_ln('forward layer attention_scores full %s', self.attention_scores.reveal())
|
||||
|
||||
@for_range_opt_multithread(self.n_threads, [N, self.num_attention_heads, self.seq_len])
|
||||
def _(i, j, k):
|
||||
@@ -3103,7 +3059,6 @@ class MultiHeadAttention(BertBase):
|
||||
@for_range_opt([self.seq_len])
|
||||
def _(k):
|
||||
value_sub[k] = self.wv.Y[i][k].get_part_vector(j * self.attention_head_size, self.attention_head_size)
|
||||
# value_sub[k] = self.wv.Y[i][k][j * self.attention_head_size:(j + 1) * self.attention_head_size]
|
||||
|
||||
res = sfix.Matrix(self.seq_len, self.attention_head_size)
|
||||
res.assign_vector(self.dropout.Y[i][j].direct_mul(value_sub))
|
||||
@@ -3114,13 +3069,7 @@ class MultiHeadAttention(BertBase):
|
||||
self.context[i][k].assign_part_vector(res[k],
|
||||
j * self.attention_head_size
|
||||
)
|
||||
# for k in range(self.seq_len):
|
||||
# self.context[i][k][j * self.attention_head_size:(j + 1) * self.attention_head_size] = res[k * self.attention_head_size:(k + 1) * self.attention_head_size]
|
||||
|
||||
# How to transfer to forward?
|
||||
|
||||
# missing half of the values ?
|
||||
# print_ln('forward layer old_context %s', self.old_context[0].get_vector().reveal())
|
||||
if self.debug_output:
|
||||
print_ln('forward layer multiheadattention before internal output %s', self.context[0][0][0:20].get_vector().reveal())
|
||||
|
||||
@@ -3132,8 +3081,6 @@ class MultiHeadAttention(BertBase):
|
||||
print_ln('forward multiheadattention output %s', self.output.Y[0][0][0:20].reveal())
|
||||
print_ln("")
|
||||
|
||||
# return context
|
||||
|
||||
def reset(self):
|
||||
self.wq.reset()
|
||||
self.wk.reset()
|
||||
@@ -3189,8 +3136,6 @@ class MultiHeadAttention(BertBase):
|
||||
nabla_value_sub[k],
|
||||
j * self.attention_head_size)
|
||||
|
||||
print("RES MULTI BACK", self.dropout.Y, res, self.num_attention_heads, self.attention_head_size)
|
||||
|
||||
self.dropout.nabla_X.alloc()
|
||||
self.dropout.backward(True, batch)
|
||||
|
||||
@@ -4402,10 +4347,20 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None,
|
||||
raise CompilerError('multi-input layer %s not supported' % item)
|
||||
name = type(item).__name__
|
||||
if name == 'Linear':
|
||||
assert mul(input_shape[1:]) == item.in_features
|
||||
# Precondition: the item
|
||||
assert item.bias is not None
|
||||
layers.append(Dense(input_shape[0], item.in_features,
|
||||
item.out_features))
|
||||
if mul(input_shape[1:]) == item.in_features:
|
||||
layers.append(Dense(input_shape[0], item.in_features,
|
||||
item.out_features))
|
||||
elif input_shape[-1] == item.in_features:
|
||||
# we loop over all but last dimension
|
||||
assert len(input_shape) == 3, "Dense only supports one extra dimension to loop over"
|
||||
d = input_shape[1]
|
||||
layers.append(Dense(input_shape[0], item.in_features,
|
||||
item.out_features, d))
|
||||
else:
|
||||
assert False, f"input shape {input_shape} incompatible with in_features {item.in_features}"
|
||||
|
||||
if input_via is not None:
|
||||
shapes = [x.shape for x in (layers[-1].W, layers[-1].b)]
|
||||
import numpy
|
||||
@@ -4476,6 +4431,15 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None,
|
||||
input_shape = layers[-1].shape
|
||||
elif name == 'ReLU' or item == torch.nn.functional.relu:
|
||||
layers.append(Relu(input_shape))
|
||||
elif name == 'GeLU' or item == torch.nn.functional.gelu:
|
||||
layers.append(Gelu(input_shape))
|
||||
elif name == 'LayerNorm':
|
||||
layers.append(LayerNorm(input_shape, True, item.eps))
|
||||
if input_via is not None:
|
||||
layers[-1].weights = sfix.input_tensor_via(
|
||||
input_via, item.weight.detach())
|
||||
layers[-1].beta = sfix.input_tensor_via(
|
||||
input_via, item.bias.detach())
|
||||
elif name == 'Flatten':
|
||||
return
|
||||
elif name == 'BatchNorm2d' or name == 'BatchNorm1d':
|
||||
@@ -4528,7 +4492,7 @@ def layers_from_torch(model, data_input_shape, batch_size, input_via=None,
|
||||
num_attention_heads = config.num_attention_heads
|
||||
layernorm_eps = config.layer_norm_eps
|
||||
seq_len = input_shape[1]
|
||||
rsqrt_approx = False
|
||||
rsqrt_approx = True
|
||||
layer = BertLayer(input_shape[0], seq_len, hidden_state, intermediate_size, num_attention_heads,
|
||||
layernorm_eps, 0.125, rsqrt_approx, batch_size=batch_size)
|
||||
if input_via is not None:
|
||||
|
||||
@@ -6984,7 +6984,10 @@ class SubMultiArray(_vectorizable):
|
||||
:return: container of same shape and type as :py:obj:`self` """
|
||||
if is_zero(other):
|
||||
return self
|
||||
assert self.sizes == other.sizes
|
||||
if hasattr(other, 'sizes'):
|
||||
assert self.sizes == other.sizes
|
||||
if hasattr(other, 'size'):
|
||||
assert self.total_size() == other.size
|
||||
return self.from_vector(
|
||||
self.sizes, self.get_vector() + other.get_vector())
|
||||
|
||||
@@ -7563,7 +7566,7 @@ class SubMultiArray(_vectorizable):
|
||||
def __str__(self):
|
||||
return '%s multi-array of lengths %s at %s' % (
|
||||
self.value_type, self.sizes,
|
||||
'<unallocated>' if self.array._address is None else self.address)
|
||||
'<unallocated>' if self.address is None else self.address)
|
||||
__repr__ = __str__
|
||||
|
||||
class MultiArray(SubMultiArray):
|
||||
|
||||
@@ -20,10 +20,12 @@ from datasets import load_dataset
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
|
||||
MODEL_NAME = 'M-FAC/bert-tiny-finetuned-qnli' # BERT-tiny (2 layers, 128 hidden)
|
||||
# Model sizes: M-FAC/bert-tiny-finetuned-qnli, M-FAC/bert-mini-finetuned-qnli, yoshitomo-matsubara/bert-large-uncased-qnli
|
||||
MODEL_NAME = 'M-FAC/bert-tiny-finetuned-qnli' # BERT-mini (4 layers, 256 hidden)
|
||||
MAX_LENGTH = 64 # Maximum sequence length
|
||||
N_SAMPLES = 25 # Number of samples to evaluate
|
||||
BATCH_SIZE = 1 # Batch size for MPC inference (increase for better performance)
|
||||
LAYER_COMPARISON = True # Set to False to skip layer-by-layer comparison (saves ~95% compile time)
|
||||
|
||||
# GLUE task configuration
|
||||
TASK_NAME = 'qnli'
|
||||
@@ -269,110 +271,112 @@ print_ln("MPC-PyTorch Match: %s/%s = %s",
|
||||
|
||||
# ============================================================================
|
||||
# Layer-by-Layer Comparison using Forward Hooks
|
||||
# (Generates ~95% of compile-time instructions. Disable for faster builds.)
|
||||
# ============================================================================
|
||||
|
||||
print_ln("\n=== Layer-by-Layer Comparison ===")
|
||||
if LAYER_COMPARISON:
|
||||
print_ln("\n=== Layer-by-Layer Comparison ===")
|
||||
|
||||
# Map to store PyTorch activations
|
||||
activation_map = {}
|
||||
# Map to store PyTorch activations
|
||||
activation_map = {}
|
||||
|
||||
def get_activation(name):
|
||||
"""Create a forward hook to capture layer outputs."""
|
||||
def hook(model, input, output):
|
||||
if isinstance(output, tuple):
|
||||
actual_output = output[0]
|
||||
else:
|
||||
actual_output = output
|
||||
activation_map[name] = actual_output.detach()
|
||||
return hook
|
||||
def get_activation(name):
|
||||
"""Create a forward hook to capture layer outputs."""
|
||||
def hook(model, input, output):
|
||||
if isinstance(output, tuple):
|
||||
actual_output = output[0]
|
||||
else:
|
||||
actual_output = output
|
||||
activation_map[name] = actual_output.detach()
|
||||
return hook
|
||||
|
||||
# Build layer comparison list
|
||||
def layers_for_bertlayer(bert_layer_mpc, bert_layer_pt):
|
||||
"""Map MPC BertLayer components to PyTorch components."""
|
||||
return [
|
||||
(bert_layer_mpc.multi_head_attention, bert_layer_pt.attention),
|
||||
(bert_layer_mpc.intermediate, bert_layer_pt.intermediate),
|
||||
(bert_layer_mpc.output, bert_layer_pt.output),
|
||||
(bert_layer_mpc, bert_layer_pt),
|
||||
]
|
||||
# Build layer comparison list
|
||||
def layers_for_bertlayer(bert_layer_mpc, bert_layer_pt):
|
||||
"""Map MPC BertLayer components to PyTorch components."""
|
||||
return [
|
||||
(bert_layer_mpc.multi_head_attention, bert_layer_pt.attention),
|
||||
(bert_layer_mpc.intermediate, bert_layer_pt.intermediate),
|
||||
(bert_layer_mpc.output, bert_layer_pt.output),
|
||||
(bert_layer_mpc, bert_layer_pt),
|
||||
]
|
||||
|
||||
# Build complete layer comparison list
|
||||
layers_to_compare = [layers_for_bertlayer(l1, l2) for l1, l2 in
|
||||
zip(mpc_layers[:-4], model.bert.encoder.layer)]
|
||||
layers_to_compare = [x for xs in layers_to_compare for x in xs]
|
||||
layers_to_compare.append((mpc_layers[-4], model.bert.pooler))
|
||||
layers_to_compare.append((mpc_layers[-3], model.dropout))
|
||||
layers_to_compare.append((mpc_layers[-2], model.classifier))
|
||||
# Build complete layer comparison list
|
||||
layers_to_compare = [layers_for_bertlayer(l1, l2) for l1, l2 in
|
||||
zip(mpc_layers[:-4], model.bert.encoder.layer)]
|
||||
layers_to_compare = [x for xs in layers_to_compare for x in xs]
|
||||
layers_to_compare.append((mpc_layers[-4], model.bert.pooler))
|
||||
layers_to_compare.append((mpc_layers[-3], model.dropout))
|
||||
layers_to_compare.append((mpc_layers[-2], model.classifier))
|
||||
|
||||
# Register forward hooks
|
||||
for layer_id, (_, pt_layer) in enumerate(layers_to_compare):
|
||||
pt_layer.register_forward_hook(get_activation(f'{layer_id}.{type(pt_layer).__name__}'))
|
||||
# Register forward hooks
|
||||
for layer_id, (_, pt_layer) in enumerate(layers_to_compare):
|
||||
pt_layer.register_forward_hook(get_activation(f'{layer_id}.{type(pt_layer).__name__}'))
|
||||
|
||||
# Run PyTorch forward pass to populate activation_map
|
||||
print("Capturing PyTorch layer outputs...")
|
||||
with torch.no_grad():
|
||||
for i in range(N_SAMPLES):
|
||||
activation_map.clear() # Clear for each sample
|
||||
# Run PyTorch forward pass to populate activation_map
|
||||
print("Capturing PyTorch layer outputs...")
|
||||
with torch.no_grad():
|
||||
for i in range(N_SAMPLES):
|
||||
activation_map.clear() # Clear for each sample
|
||||
|
||||
# Get sample embedding
|
||||
with embedded_data.formatted_as("torch", ["embedding"]):
|
||||
sample_embedding = embedded_data[i]['embedding'].unsqueeze(0)
|
||||
# Get sample embedding
|
||||
with embedded_data.formatted_as("torch", ["embedding"]):
|
||||
sample_embedding = embedded_data[i]['embedding'].unsqueeze(0)
|
||||
|
||||
# Run forward through wrapped model
|
||||
_ = bert_wrapped(sample_embedding)
|
||||
# Run forward through wrapped model
|
||||
_ = bert_wrapped(sample_embedding)
|
||||
|
||||
# Store activations for this sample
|
||||
if i == 0: # Only compare first sample to save time
|
||||
break
|
||||
# Store activations for this sample
|
||||
if i == 0: # Only compare first sample to save time
|
||||
break
|
||||
|
||||
print(f"Captured {len(activation_map)} layer outputs from PyTorch")
|
||||
print(f"Captured {len(activation_map)} layer outputs from PyTorch")
|
||||
|
||||
# Run MPC forward pass using reveal_correctness
|
||||
import numpy
|
||||
pt_probs_tensor = numpy.array(numpy.concatenate([p.numpy() for p in pt_probabilities]))
|
||||
pt_probabilities_sfix = sfix.input_tensor_via(0, pt_probs_tensor)
|
||||
# Run MPC forward pass using reveal_correctness
|
||||
import numpy
|
||||
pt_probs_tensor = numpy.array(numpy.concatenate([p.numpy() for p in pt_probabilities]))
|
||||
pt_probabilities_sfix = sfix.input_tensor_via(0, pt_probs_tensor)
|
||||
|
||||
test_embeddings_one = sfix.Tensor([1] + list(test_embeddings.sizes[1:]))
|
||||
test_embeddings_one.assign(test_embeddings.get_part_vector(0))
|
||||
test_embeddings_one = sfix.Tensor([1] + list(test_embeddings.sizes[1:]))
|
||||
test_embeddings_one.assign(test_embeddings.get_part_vector(0))
|
||||
|
||||
pt_probabilities_sfix_one = sfix.Tensor([1] + list(pt_probabilities_sfix.sizes[1:]))
|
||||
pt_probabilities_sfix_one.assign(pt_probabilities_sfix.get_part_vector(0))
|
||||
pt_probabilities_sfix_one = sfix.Tensor([1] + list(pt_probabilities_sfix.sizes[1:]))
|
||||
pt_probabilities_sfix_one.assign(pt_probabilities_sfix.get_part_vector(0))
|
||||
|
||||
print_ln("Running MPC forward pass for layer comparison...")
|
||||
_ = optimizer.reveal_correctness(test_embeddings_one, pt_probabilities_sfix_one, batch_size=BATCH_SIZE)
|
||||
print_ln("Running MPC forward pass for layer comparison...")
|
||||
_ = optimizer.reveal_correctness(test_embeddings_one, pt_probabilities_sfix_one, batch_size=BATCH_SIZE)
|
||||
|
||||
# Compare layers
|
||||
print_ln("\nLayer-by-layer comparison (Sample 0 only):")
|
||||
print_ln("=" * 100)
|
||||
# Compare layers
|
||||
print_ln("\nLayer-by-layer comparison (Sample 0 only):")
|
||||
print_ln("=" * 100)
|
||||
|
||||
for idx, (mpc_layer, pt_layer) in enumerate(layers_to_compare):
|
||||
layer_id = f"{idx}.{type(pt_layer).__name__}"
|
||||
for idx, (mpc_layer, pt_layer) in enumerate(layers_to_compare):
|
||||
layer_id = f"{idx}.{type(pt_layer).__name__}"
|
||||
|
||||
if layer_id not in activation_map:
|
||||
continue
|
||||
if layer_id not in activation_map:
|
||||
continue
|
||||
|
||||
# Skip dropout layers since they use different random masks
|
||||
if 'Dropout' in type(pt_layer).__name__:
|
||||
print_ln("%s | Skipped (dropout)", layer_id)
|
||||
continue
|
||||
# Skip dropout layers since they use different random masks
|
||||
if 'Dropout' in type(pt_layer).__name__:
|
||||
print_ln("%s | Skipped (dropout)", layer_id)
|
||||
continue
|
||||
|
||||
# Get PyTorch values
|
||||
pt_values = activation_map[layer_id]
|
||||
pt_at_runtime = sfix.input_tensor_via(0, pt_values.numpy()).get_vector().reveal()
|
||||
# Get PyTorch values
|
||||
pt_values = activation_map[layer_id]
|
||||
pt_at_runtime = sfix.input_tensor_via(0, pt_values.numpy()).get_vector().reveal()
|
||||
|
||||
# Get MPC values
|
||||
mpc_output = mpc_layer.Y[0].get_vector().reveal()
|
||||
# Get MPC values
|
||||
mpc_output = mpc_layer.Y[0].get_vector().reveal()
|
||||
|
||||
# Compute detailed statistics
|
||||
total_abs_diff = sum(abs(pt_at_runtime - mpc_output))
|
||||
pt_magnitude = sum(abs(pt_at_runtime))
|
||||
# Compute detailed statistics
|
||||
total_abs_diff = sum(abs(pt_at_runtime - mpc_output))
|
||||
pt_magnitude = sum(abs(pt_at_runtime))
|
||||
|
||||
# Print layer comparison
|
||||
print_ln("\n%s", layer_id)
|
||||
print_ln(" Shape: %s, Elements: %s", pt_values.shape, len(pt_at_runtime))
|
||||
print_ln(" Total Abs Diff: %s", total_abs_diff)
|
||||
print_ln(" PT Total Magnitude: %s", pt_magnitude)
|
||||
print_ln(" First 8 PT: %s", pt_at_runtime[:8])
|
||||
print_ln(" First 8 MPC: %s", mpc_output[:8])
|
||||
# Print layer comparison
|
||||
print_ln("\n%s", layer_id)
|
||||
print_ln(" Shape: %s, Elements: %s", pt_values.shape, len(pt_at_runtime))
|
||||
print_ln(" Total Abs Diff: %s", total_abs_diff)
|
||||
print_ln(" PT Total Magnitude: %s", pt_magnitude)
|
||||
print_ln(" First 8 PT: %s", pt_at_runtime[:8])
|
||||
print_ln(" First 8 MPC: %s", mpc_output[:8])
|
||||
|
||||
print_ln("\n=== Inference Complete ===")
|
||||
|
||||
Reference in New Issue
Block a user