mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
@@ -1348,15 +1348,17 @@ def train_llama3():
   @TinyJit
   @Tensor.train()
-  def train_step(model, tokens):
+  def train_step(model, tokens:Tensor, grad_acc:int):
     optim.zero_grad()
-    if (DP := getenv("DP", 1)) > 1:
-      device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
-      tokens = tokens.shard(device, 0)
-    logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
-    loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
-    loss.backward()
-
+    # grad acc
+    for batch in tokens.split(tokens.shape[0]//grad_acc):
+      if (DP := getenv("DP", 1)) > 1:
+        device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
+        batch = batch.shard(device, 0)
+      logits:Tensor = model(batch[:, :-1], start_pos=0, temperature=math.nan)
+      loss = logits.sparse_categorical_crossentropy(batch[:, 1:])
+      loss.backward()
+    Tensor.realize(*[p.grad for p in optim.params])

     # L2 norm grad clip
     # https://github.com/NVIDIA/NeMo/blob/3368c3fc0b4a186ab33a1d68a504315100c0b2a6/nemo/collections/nlp/modules/common/megatron/clip_grads.py#L57
     # https://docs.pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
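
Note on the hunk above: the new train_step accumulates gradients across grad_acc micro-batches (tokens.split(tokens.shape[0]//grad_acc) yields grad_acc equal chunks along the batch axis, and each backward() adds into the existing .grad), then realizes all grads before the clipping step the hunk truncates. Since the clip code itself isn't shown, here is a minimal sketch, assuming tinygrad's Tensor API, of accumulation plus global L2-norm clipping in the style of the NeMo / torch.nn.utils.clip_grad_norm_ references; params, micro_batches, loss_fn, and max_norm are illustrative names, not identifiers from the diff.

from tinygrad import Tensor

# sketch: gradient accumulation over micro-batches, then global L2-norm
# clipping as in the NeMo / clip_grad_norm_ references cited in the hunk.
# params/micro_batches/loss_fn/max_norm are assumed names, not from the diff.
def accumulate_and_clip(params: list[Tensor], micro_batches, loss_fn, max_norm: float = 1.0):
  for p in params: p.grad = None          # fresh accumulation, like optim.zero_grad()
  for mb in micro_batches:
    loss_fn(mb).backward()                # backward() adds into any existing .grad
  grads = [p.grad for p in params if p.grad is not None]
  # norm of all gradients treated as one flattened vector
  total_norm = Tensor.stack(*[g.square().sum() for g in grads]).sum().sqrt()
  clip_coef = (max_norm / (total_norm + 1e-6)).minimum(1.0)  # shrink only, never grow
  for g in grads: g.assign(g * clip_coef)
  return total_norm

Clipping scales every gradient by min(1, max_norm/total_norm), so the update direction is preserved and only its magnitude is capped.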
@@ -1385,11 +1387,12 @@ def train_llama3():
   iter = batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL))

   i = 0
-  for tokens in tqdm(iter, total=SAMPLES//BS):
+  for tokens in tqdm(iter, total=SAMPLES//GBS):
+    t = time.perf_counter()
     GlobalCounters.reset()
-    loss, lr = train_step(model, tokens)
+    loss, lr = train_step(model, tokens, grad_acc)
     # above as tqdm.write f-string
-    tqdm.write(f"{loss.item():.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used")
+    tqdm.write(f"{loss.item():.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
     if (fname:=getenv("LOSS_FILE", "")):
       with open(fname, "a") as f:
         f.write(f"{i} {loss.item():.4f} {lr.item():.12f} {GlobalCounters.mem_used / 1e9:.2f}\n")
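
The LOSS_FILE branch appends one space-separated record per step: step index, loss, learning rate, and memory in GB. A minimal sketch of reading those records back, e.g. for plotting; this helper is hypothetical, not part of the repo.

# hypothetical helper: parse the space-separated "i loss lr mem_gb" records
# that the LOSS_FILE branch above appends, one line per step
def read_loss_file(fname: str):
  steps, losses, lrs, mems = [], [], [], []
  with open(fname) as f:
    for line in f:
      i, loss, lr, mem = line.split()
      steps.append(int(i)); losses.append(float(loss))
      lrs.append(float(lr)); mems.append(float(mem))
  return steps, losses, lrs, mems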