mirror of https://github.com/tinygrad/tinygrad.git
remove old llama grad_acc (#13611)
* remove old llama grad_acc

* GRADIENT_ACC_STEPS=1
.github/workflows/test.yml (vendored)
@@ -465,7 +465,7 @@ jobs:
     - name: Test Bert training
       run: NULL=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=24 GPUS=4 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py
     - name: Test llama 3 training
-      run: NULL=1 SAMPLES=300 BS=8 SEQLEN=512 GRADIENT_ACC_STEPS=8 FAKEDATA=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B MODEL=llama3 python3 examples/mlperf/model_train.py
+      run: NULL=1 SAMPLES=300 BS=8 SEQLEN=512 GRADIENT_ACC_STEPS=1 FAKEDATA=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B MODEL=llama3 python3 examples/mlperf/model_train.py
     - name: Run process replay tests
       uses: ./.github/actions/process-replay
examples/mlperf/model_train.py
@@ -1294,6 +1294,7 @@ def train_llama3():
   BASEDIR = config["BASEDIR"] = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
   BS = config["BS"] = getenv("BS", 16)
   grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1)
+  assert grad_acc == 1, f"{grad_acc=} is not supported"
   GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc
   SEED = config["SEED"] = getenv("SEED", 5760)
   SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
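For context on the grad_acc being dropped: when GRADIENT_ACC_STEPS was greater than 1, the global batch of GBS = BS * grad_acc sequences was split into grad_acc micro-batches, backward() ran once per micro-batch so gradients summed in place, and only then was a single optimizer step taken. A minimal sketch of that pattern, using only calls that appear in this diff (the helper name and the explicit optim.step() are illustrative, not the repo's code):

def accumulated_step(model, optim, tokens, grad_acc: int):  # hypothetical helper, for illustration
  optim.zero_grad()
  # tokens holds GBS = BS * grad_acc sequences; split yields grad_acc micro-batches of BS each
  for batch in tokens.split(tokens.shape[0] // grad_acc):
    logits = model(batch[:, :-1])  # the real call also passes start_pos and temperature
    loss = logits.sparse_categorical_crossentropy(batch[:, 1:])
    loss.backward()  # gradients from each micro-batch accumulate into p.grad
  optim.step()  # one weight update for the whole global batch

With grad_acc now forced to 1 by the new assert, the split is a no-op and GBS == BS, which is why the loop disappears in the next hunk.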
@@ -1368,20 +1369,17 @@ def train_llama3():
 
   @TinyJit
   @Tensor.train()
-  def train_step(model, tokens:Tensor, grad_acc:int):
+  def train_step(model, tokens:Tensor):
     optim.zero_grad()
-    # grad acc
-    for batch in tokens.split(tokens.shape[0]//grad_acc):
-      if (DP := getenv("DP", 1)) > 1:
-        device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
-        batch = batch.shard(device, 0)
-      if (MP := getenv("MP", 1)) > 1:
-        device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
-        batch = batch.shard(device)
-      logits:Tensor = model(batch[:, :-1], start_pos=0, temperature=math.nan)
-      loss = logits.sparse_categorical_crossentropy(batch[:, 1:])
-      loss.backward()
+    if (DP := getenv("DP", 1)) > 1:
+      device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
+      tokens = tokens.shard(device, 0)
+    if (MP := getenv("MP", 1)) > 1:
+      device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
+      tokens = tokens.shard(device)
+    logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
+    loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
+    loss.backward()
 
     Tensor.realize(*[p.grad for p in optim.params])
     # L2 norm grad clip
     # https://github.com/NVIDIA/NeMo/blob/3368c3fc0b4a186ab33a1d68a504315100c0b2a6/nemo/collections/nlp/modules/common/megatron/clip_grads.py#L57
     # https://docs.pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
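The two links kept in the context above describe global L2-norm gradient clipping: take the L2 norm over all parameter gradients together and, if it exceeds a threshold, scale every gradient by threshold / total_norm. A rough sketch of that idea, assuming tinygrad Tensor gradients in p.grad (illustrative only, not the code that presumably follows these comments in model_train.py):

def clip_grad_l2_sketch(params, max_norm: float = 1.0):
  # global L2 norm across every gradient, as in torch.nn.utils.clip_grad_norm_
  total_norm = sum((p.grad.float() ** 2).sum() for p in params) ** 0.5
  # shrink all gradients by the same factor when the global norm is too large
  scale = (max_norm / (total_norm + 1e-6)).minimum(1.0)
  for p in params:
    p.grad = p.grad * scale
  return total_norm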
@@ -1445,7 +1443,7 @@ def train_llama3():
   for tokens in tqdm(iter, total=SAMPLES//GBS):
     t = time.perf_counter()
     GlobalCounters.reset()
-    loss, lr = train_step(model, tokens, grad_acc)
+    loss, lr = train_step(model, tokens)
     loss = loss.float().item()
 
     i += 1
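One closing note on the reworked train_step: with DP > 1 the tokens are sharded along axis 0, so each device gets a slice of the batch, while with MP > 1 they are sharded with no axis, i.e. replicated to every device (the model weights are presumably sharded elsewhere in model_train.py). A small standalone illustration of the two calls (device count and tensor shape are made up):

from tinygrad import Tensor, Device

devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))  # e.g. 4 devices, built as in train_step
tokens = Tensor.zeros(8, 512)         # hypothetical (BS=8, SEQLEN=512) batch of token ids

dp_tokens = tokens.shard(devices, 0)  # data parallel: axis 0 is split, 2 rows land on each device
mp_tokens = tokens.shard(devices)     # model parallel: no axis, the full tensor is copied to each device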