bert grad accumulation (#10863)
* bert grad accumulation
* realize grad
@@ -914,16 +914,24 @@ def train_rnnt():
     pass
 
 @TinyJit
-def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
-                    masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor, GPUS):
-  for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
-    if len(GPUS) > 1: t.shard_(GPUS, axis=0)
-    else: t.to_(GPUS[0])
+def train_step_bert(model, optimizer, scheduler, loss_scaler:float, GPUS, grad_acc:int, **kwargs):
   optimizer.zero_grad()
 
-  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
-  loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
-  (loss * loss_scaler).backward()
+  for i in range(grad_acc):
+    input_ids, segment_ids = kwargs[f"input_ids{i}"], kwargs[f"segment_ids{i}"]
+    # NOTE: these two have different names
+    attention_mask, masked_positions = kwargs[f"input_mask{i}"], kwargs[f"masked_lm_positions{i}"]
+    masked_lm_ids, masked_lm_weights, next_sentence_labels = kwargs[f"masked_lm_ids{i}"], kwargs[f"masked_lm_weights{i}"], kwargs[f"next_sentence_labels{i}"]
+
+    for t in [input_ids, segment_ids, attention_mask, masked_positions, masked_lm_ids, masked_lm_weights, next_sentence_labels]:
+      if len(GPUS) > 1: t.shard_(GPUS, axis=0)
+      else: t.to_(GPUS[0])
+
+    lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
+    loss = model.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
+    (loss * loss_scaler).backward()
+    # TODO: OOM without this realize with large grad_acc
+    Tensor.realize(*[p.grad for p in optimizer.params])
 
   global_norm = Tensor([0.0], dtype=dtypes.float32, device=optimizer[0].device)
   for p in optimizer.params:
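For context, the new train_step_bert follows the standard gradient-accumulation pattern: zero the grads once, run backward over grad_acc micro-batches so the gradients sum, realize the accumulated grads so the pending compute graph does not grow with grad_acc, then take a single optimizer step. A minimal sketch of that pattern with a toy linear model (model, optimizer, and batch shapes here are illustrative, not the MLPerf BERT setup):

# Minimal sketch of the accumulation pattern; names and shapes are illustrative.
from tinygrad import Tensor, nn

model = nn.Linear(8, 2)
opt = nn.optim.SGD(nn.state.get_parameters(model), lr=0.1)
grad_acc = 4

with Tensor.train():
  opt.zero_grad()
  for i in range(grad_acc):
    x, y = Tensor.randn(16, 8), Tensor.randn(16, 2)
    loss = (model(x) - y).square().mean()
    loss.backward()   # grads from each micro-batch accumulate into p.grad
    # realizing the grads here keeps the pending graph small as grad_acc grows
    Tensor.realize(*[p.grad for p in opt.params])
  opt.step()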
@@ -1001,8 +1009,7 @@ def train_bert():
   # ** hyperparameters **
   BS = config["BS"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
   grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1)
-  # TODO: implement grad accumulation + mlperf logging
-  assert grad_acc == 1
+  # TODO: mlperf logging
   GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc
   EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
   max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.000175 * math.sqrt(GBS/96))
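Note that the learning-rate default scales with the accumulated global batch, not the per-step batch. With illustrative numbers only (6 GPUs at 11 per GPU in fp16, GRADIENT_ACC_STEPS=2):

# Illustrative numbers, not a recommended configuration.
import math
BS = 11 * 6                                # 66
grad_acc = 2
GBS = BS * grad_acc                        # 132
max_lr = 0.000175 * math.sqrt(GBS / 96)    # ~2.05e-4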
@@ -1119,7 +1126,7 @@ def train_bert():
   # ** train loop **
   wc_start = time.perf_counter()
 
-  i, train_data = start_step, next(train_it)
+  i, train_data = start_step, [next(train_it) for _ in range(grad_acc)]
 
   if RUNMLPERF:
     if MLLOGGER:
@@ -1132,14 +1139,13 @@ def train_bert():
     st = time.perf_counter()
     GlobalCounters.reset()
     with WallTimeEvent(BenchEvent.STEP):
-      loss, global_norm, lr = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
-                                              train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
-                                              train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)
+      data = {f"{k}{i}":v for i,d in enumerate(train_data) for k,v in d.items()}
+      loss, global_norm, lr = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler, GPUS, grad_acc, **data)
 
     pt = time.perf_counter()
 
     try:
-      next_data = next(train_it)
+      next_data = [next(train_it) for _ in range(grad_acc)]
     except StopIteration:
       next_data = None
 
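Because the jitted train_step_bert now takes **kwargs rather than a list of dicts, the loop flattens the grad_acc micro-batches into one dict whose keys carry the micro-batch index, presumably so TinyJit sees a flat, fixed set of named Tensor arguments each step. A small standalone sketch of that flattening (dummy strings standing in for the Tensors yielded by train_it):

# Dummy stand-ins for the dicts yielded by train_it; real values would be Tensors.
train_data = [
  {"input_ids": "ids0", "input_mask": "mask0"},
  {"input_ids": "ids1", "input_mask": "mask1"},
]
data = {f"{k}{i}": v for i, d in enumerate(train_data) for k, v in d.items()}
# data == {"input_ids0": "ids0", "input_mask0": "mask0",
#          "input_ids1": "ids1", "input_mask1": "mask1"}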