grad acc train llama (#11467)

* grad acc train llama

* log step time
Author: chenyu
Date: 2025-08-01 12:54:50 -07:00
Committed by: GitHub
Parent: 7ad7329257
Commit: 9e8e6b45ab


@@ -1348,15 +1348,17 @@ def train_llama3():
   @TinyJit
   @Tensor.train()
-  def train_step(model, tokens):
+  def train_step(model, tokens:Tensor, grad_acc:int):
     optim.zero_grad()
-    if (DP := getenv("DP", 1)) > 1:
-      device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
-      tokens = tokens.shard(device, 0)
-    logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
-    loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
-    loss.backward()
+    # grad acc
+    for batch in tokens.split(tokens.shape[0]//grad_acc):
+      if (DP := getenv("DP", 1)) > 1:
+        device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
+        batch = batch.shard(device, 0)
+      logits:Tensor = model(batch[:, :-1], start_pos=0, temperature=math.nan)
+      loss = logits.sparse_categorical_crossentropy(batch[:, 1:])
+      loss.backward()
     Tensor.realize(*[p.grad for p in optim.params])
     # L2 norm grad clip
     # https://github.com/NVIDIA/NeMo/blob/3368c3fc0b4a186ab33a1d68a504315100c0b2a6/nemo/collections/nlp/modules/common/megatron/clip_grads.py#L57
     # https://docs.pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
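
In words: the new code splits the global batch into grad_acc micro-batches, runs forward and backward on each (successive loss.backward() calls accumulate into the parameters' .grad), and only then realizes the gradients for the clipping and optimizer step that follow. Below is a minimal sketch of the same idea in plain Python with entirely hypothetical names and a toy least-squares loss; each micro-batch gradient is scaled by its share of the global batch so the accumulated value equals the full-batch gradient (how the real training script normalizes the loss is outside this hunk).

def accumulated_grad(w: float, batch: list[float], grad_acc: int) -> float:
  # toy loss(w) = mean over the global batch of 0.5*(w - x)**2
  micro = len(batch) // grad_acc                 # rows per micro-batch, assumes an even split
  grad = 0.0
  for k in range(grad_acc):
    xs = batch[k*micro:(k+1)*micro]              # one micro-batch, like tokens.split(...) above
    grad += sum(w - x for x in xs) / len(batch)  # accumulate this micro-batch's gradient
  return grad

if __name__ == "__main__":
  data, w = [1.0, 2.0, 3.0, 4.0], 0.0
  print(accumulated_grad(w, data, grad_acc=1))   # -2.5: full-batch gradient
  print(accumulated_grad(w, data, grad_acc=2))   # -2.5: identical with accumulation
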
@@ -1385,11 +1387,12 @@ def train_llama3():
   iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL))
   i = 0
-  for tokens in tqdm(iter, total=SAMPLES//BS):
+  for tokens in tqdm(iter, total=SAMPLES//GBS):
+    t = time.perf_counter()
     GlobalCounters.reset()
-    loss, lr = train_step(model, tokens)
+    loss, lr = train_step(model, tokens, grad_acc)
     # above as tqdm.write f-string
-    tqdm.write(f"{loss.item():.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used")
+    tqdm.write(f"{loss.item():.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
     if (fname:=getenv("LOSS_FILE", "")):
       with open(fname, "a") as f:
         f.write(f"{i} {loss.item():.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")