@@ -593,7 +593,8 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
   optimizer.step()
   scheduler.step()
   # TODO: no to("CPU") here because it blocks and messes the python time
-  return loss.realize(), global_norm.realize(), optimizer.optimizers[0].lr.realize()
+  Tensor.realize(loss, global_norm, optimizer.optimizers[0].lr)
+  return loss, global_norm, optimizer.optimizers[0].lr
 
 @TinyJit
 def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor,
@@ -604,8 +605,10 @@ def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:T
   lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
   masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \
     model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
-  return masked_lm_accuracy.to("CPU").realize(), seq_relationship_accuracy.realize().to("CPU"), \
-    masked_lm_loss.to("CPU").realize(), next_sentence_loss.to("CPU").realize()
+  for t in [masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss]:
+    t.to_("CPU")
+  Tensor.realize(masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss)
+  return masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss
 
 def train_bert():
   # NOTE: pip install tensorflow, wandb required
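
Why the change helps (a minimal sketch, not part of the diff; the toy tensors below are illustrative): chaining .realize() on each output, as the old train_step_bert return did, schedules and dispatches every tensor separately, while the batched form Tensor.realize(a, b, c) builds one schedule covering all of them and runs it once.

from tinygrad import Tensor

# three lazy results; nothing has executed yet
a = Tensor.rand(16).sum()
b = Tensor.rand(16).max()
c = Tensor.rand(16).mean()

# old pattern: one schedule/dispatch round-trip per tensor
# a.realize(); b.realize(); c.realize()

# new pattern, as in train_step_bert: a single batched realize
Tensor.realize(a, b, c)
print(a.item(), b.item(), c.item())

In eval_step_bert the batched realize is paired with the in-place device move to_("CPU"): each metric is retargeted to the CPU before the single Tensor.realize call, so (as I read the diff) the host copies become part of that one scheduled run instead of four separate blocking .to("CPU") transfers on already-realized tensors.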