diff --git a/extra/models/bert.py b/extra/models/bert.py index d97aae7030..dfb2f34d89 100644 --- a/extra/models/bert.py +++ b/extra/models/bert.py @@ -59,8 +59,11 @@ class BertForMLPerf: output = self.model(input_ids, attention_mask, segment_ids) clsf_logits = self.clsf_output(self.clsf_pooling_activation(self.clsf_pooler(output[:, 0]))).cast(dtypes.float32) - masked_positions = masked_positions[:, :, None].expand(-1, -1, output.shape[-1]) - h_masked = Tensor.gather(output, masked_positions, 1) + # gather only the masked_positions we care about + counter = Tensor.arange(output.shape[1], requires_grad=False, device=output.device).reshape(1, 1, output.shape[1]).expand(*masked_positions.shape, output.shape[1]) + onehot = counter == masked_positions.unsqueeze(2).expand(*masked_positions.shape, output.shape[1]) + h_masked = onehot @ output + h_masked = self.lm_norm(self.lm_transform_activation(self.lm_transform(h_masked))) lm_logits = self.lm_output(h_masked) + self.lm_output_bias