From 70c797b1077e73ba00bc440a5bdfd3ff70cf4223 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Sun, 11 May 2025 08:42:08 -0400
Subject: [PATCH] train bert tests (#10248)

added a working bert tiny test, and a failed bert FUSE_ARANGE test
---
 test/models/test_real_world.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py
index 08280a28d4..979448103b 100644
--- a/test/models/test_real_world.py
+++ b/test/models/test_real_world.py
@@ -14,6 +14,7 @@ from examples.hlb_cifar10 import SpeedyResNet, hyp
 from examples.llama import Transformer as LLaMaTransformer
 from examples.stable_diffusion import UNetModel, unet_params
 from extra.models.unet import ResBlock
+from extra.models.bert import BertForPretraining
 
 global_mem_used = 0
 def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jitted=False):
@@ -143,5 +144,33 @@ class TestRealWorld(unittest.TestCase):
                                 final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=4)
       assert not np.isnan(lr_scheduler.min_lr), "lr too small or initial_div_facotr too big for half"
 
+  def test_bert(self):
+    with Tensor.train():
+      args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2,
+                   "max_position_embeddings": 512, "hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2}
+      model = BertForPretraining(**args_tiny)
+      optimizer = optim.LAMB(get_parameters(model))
+
+      @TinyJit
+      def train(input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
+                masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
+        lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
+        loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+      from examples.mlperf.helpers import get_fake_data_bert
+      data = get_fake_data_bert(BS=4)
+      for v in data.values(): v.to_(Device.DEFAULT)
+
+      helper_test("train_bert", lambda: (data["input_ids"], data["segment_ids"], data["input_mask"], data["masked_lm_positions"], \
+                                         data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.25, 346)
+
+  @unittest.expectedFailure  # TODO: fix FUSE_ARANGE
+  def test_bert_fuse_arange(self):
+    with Context(FUSE_ARANGE=1):
+      self.test_bert()
+
 if __name__ == '__main__':
   unittest.main()