mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
revert bert grad accumulation (#13596)
prep for the new split jit style
This commit is contained in:
@@ -919,24 +919,16 @@ def train_rnnt():
|
||||
pass
|
||||
|
||||
@TinyJit
|
||||
def train_step_bert(model, optimizer, scheduler, loss_scaler:float, GPUS, grad_acc:int, **kwargs):
|
||||
def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
|
||||
masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
|
||||
for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
|
||||
if len(GPUS) > 1: t.shard_(GPUS, axis=0)
|
||||
else: t.to_(GPUS[0])
|
||||
optimizer.zero_grad()
|
||||
|
||||
for i in range(grad_acc):
|
||||
input_ids, segment_ids = kwargs[f"input_ids{i}"], kwargs[f"segment_ids{i}"]
|
||||
# NOTE: these two have different names
|
||||
attention_mask, masked_positions = kwargs[f"input_mask{i}"], kwargs[f"masked_lm_positions{i}"]
|
||||
masked_lm_ids, masked_lm_weights, next_sentence_labels = kwargs[f"masked_lm_ids{i}"], kwargs[f"masked_lm_weights{i}"], kwargs[f"next_sentence_labels{i}"]
|
||||
|
||||
for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
|
||||
if len(GPUS) > 1: t.shard_(GPUS, axis=0)
|
||||
else: t.to_(GPUS[0])
|
||||
|
||||
lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
|
||||
loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
|
||||
(loss * loss_scaler).backward()
|
||||
# TODO: OOM without this realize with large grad_acc
|
||||
Tensor.realize(*[p.grad for p in optimizer.params])
|
||||
lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
|
||||
loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
|
||||
(loss * loss_scaler).backward()
|
||||
|
||||
global_norm = Tensor(0.0, dtype=dtypes.float32, device=optimizer[0].device)
|
||||
for p in optimizer.params:
|
||||
@@ -1014,7 +1006,8 @@ def train_bert():
|
||||
# ** hyperparameters **
|
||||
BS = config["BS"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
|
||||
grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1)
|
||||
# TODO: mlperf logging
|
||||
# TODO: implement grad accumulation + mlperf logging
|
||||
assert grad_acc == 1
|
||||
GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc
|
||||
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
|
||||
max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.000175 * math.sqrt(GBS/96))
|
||||
@@ -1131,7 +1124,7 @@ def train_bert():
|
||||
# ** train loop **
|
||||
wc_start = time.perf_counter()
|
||||
|
||||
i, train_data = start_step, [next(train_it) for _ in range(grad_acc)]
|
||||
i, train_data = start_step, next(train_it)
|
||||
|
||||
if RUNMLPERF:
|
||||
if MLLOGGER:
|
||||
@@ -1144,13 +1137,14 @@ def train_bert():
|
||||
st = time.perf_counter()
|
||||
GlobalCounters.reset()
|
||||
with WallTimeEvent(BenchEvent.STEP):
|
||||
data = {f"{k}{i}":v for i,d in enumerate(train_data) for k,v in d.items()}
|
||||
loss, global_norm, lr = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler, GPUS, grad_acc, **data)
|
||||
loss, global_norm, lr = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
|
||||
train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
|
||||
train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)
|
||||
|
||||
pt = time.perf_counter()
|
||||
|
||||
try:
|
||||
next_data = [next(train_it) for _ in range(grad_acc)]
|
||||
next_data = next(train_it)
|
||||
except StopIteration:
|
||||
next_data = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user