mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
train bert tests (#10248)
Added a working BERT tiny training test and an expected-failure BERT test that runs with FUSE_ARANGE enabled.
This commit is contained in:
@@ -14,6 +14,7 @@ from examples.hlb_cifar10 import SpeedyResNet, hyp
|
||||
from examples.llama import Transformer as LLaMaTransformer
|
||||
from examples.stable_diffusion import UNetModel, unet_params
|
||||
from extra.models.unet import ResBlock
|
||||
from extra.models.bert import BertForPretraining
|
||||
|
||||
global_mem_used = 0
|
||||
def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jitted=False):
|
||||
@@ -143,5 +144,33 @@ class TestRealWorld(unittest.TestCase):
|
||||
final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=4)
|
||||
assert not np.isnan(lr_scheduler.min_lr), "lr too small or initial_div_facotr too big for half"
|
||||
|
||||
def test_bert(self):
  # Runs a jitted training step (forward, loss, backward, LAMB update) of a tiny
  # BertForPretraining on fake data through helper_test, which is given a
  # max_memory_allowed of 0.25 and a max_kernels_allowed of 346.

  # Tiny configuration: two layers, two heads, hidden size 128, dropout disabled.
  tiny_config = {
    "attention_probs_dropout_prob": 0.0,
    "hidden_dropout_prob": 0.0,
    "vocab_size": 30522,
    "type_vocab_size": 2,
    "max_position_embeddings": 512,
    "hidden_size": 128,
    "intermediate_size": 512,
    "num_attention_heads": 2,
    "num_hidden_layers": 2,
  }
  # Keys of the fake-data dict, listed in the positional order `train` expects.
  batch_keys = (
    "input_ids", "segment_ids", "input_mask", "masked_lm_positions",
    "masked_lm_ids", "masked_lm_weights", "next_sentence_labels",
  )

  with Tensor.train():
    model = BertForPretraining(**tiny_config)
    optimizer = optim.LAMB(get_parameters(model))

    @TinyJit
    def train(input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
              masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
      # Note the model takes (input_ids, attention_mask, masked_positions, segment_ids),
      # which differs from the argument order of this function.
      lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
      loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    from examples.mlperf.helpers import get_fake_data_bert
    data = get_fake_data_bert(BS=4)
    # Move every fake-data tensor to the default device in place.
    for tensor in data.values():
      tensor.to_(Device.DEFAULT)

    def make_batch():
      # Hands helper_test the same tensors on every call, as a tuple.
      return tuple(data[key] for key in batch_keys)

    helper_test("train_bert", make_batch, train, max_memory_allowed=0.25, max_kernels_allowed=346)
|
||||
|
||||
@unittest.expectedFailure # TODO: fix FUSE_ARANGE
def test_bert_fuse_arange(self):
  # Re-runs test_bert with the FUSE_ARANGE context variable set to 1.
  # Marked as an expected failure until FUSE_ARANGE is fixed (see TODO above).
  fuse_arange_on = Context(FUSE_ARANGE=1)
  with fuse_arange_on:
    self.test_bert()
|
||||
|
||||
# Run the unittest runner only when this file is executed as a script.
if __name__ == '__main__':
  unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user