train bert tests (#10248)

Added a working tiny BERT training test, and a failing BERT FUSE_ARANGE test.
This commit is contained in:
chenyu
2025-05-11 08:42:08 -04:00
committed by GitHub
parent b2df4cb696
commit 70c797b107

View File

@@ -14,6 +14,7 @@ from examples.hlb_cifar10 import SpeedyResNet, hyp
from examples.llama import Transformer as LLaMaTransformer
from examples.stable_diffusion import UNetModel, unet_params
from extra.models.unet import ResBlock
from extra.models.bert import BertForPretraining
global_mem_used = 0
def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jitted=False):
@@ -143,5 +144,33 @@ class TestRealWorld(unittest.TestCase):
final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=4)
assert not np.isnan(lr_scheduler.min_lr), "lr too small or initial_div_facotr too big for half"
def test_bert(self):
with Tensor.train():
args_tiny = {"attention_probs_dropout_prob": 0.0, "hidden_dropout_prob": 0.0, "vocab_size": 30522, "type_vocab_size": 2,
"max_position_embeddings": 512, "hidden_size": 128, "intermediate_size": 512, "num_attention_heads": 2, "num_hidden_layers": 2}
model = BertForPretraining(**args_tiny)
optimizer = optim.LAMB(get_parameters(model))
@TinyJit
def train(input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
from examples.mlperf.helpers import get_fake_data_bert
data = get_fake_data_bert(BS=4)
for v in data.values(): v.to_(Device.DEFAULT)
helper_test("train_bert", lambda: (data["input_ids"], data["segment_ids"], data["input_mask"], data["masked_lm_positions"], \
data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.25, 346)
@unittest.expectedFailure # TODO: fix FUSE_ARANGE
def test_bert_fuse_arange(self):
with Context(FUSE_ARANGE=1):
self.test_bert()
if __name__ == '__main__':
unittest.main()