remove old llama grad_acc (#13611)

* remove old llama grad_acc

* GRADIENT_ACC_STEPS=1
Authored by chenyu on 2025-12-07 13:03:47 -05:00; committed by GitHub
parent 94d7646bdc
commit b981b6f89e
2 changed files with 13 additions and 15 deletions

@@ -465,7 +465,7 @@ jobs:
       - name: Test Bert training
         run: NULL=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=24 GPUS=4 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py
       - name: Test llama 3 training
-        run: NULL=1 SAMPLES=300 BS=8 SEQLEN=512 GRADIENT_ACC_STEPS=8 FAKEDATA=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B MODEL=llama3 python3 examples/mlperf/model_train.py
+        run: NULL=1 SAMPLES=300 BS=8 SEQLEN=512 GRADIENT_ACC_STEPS=1 FAKEDATA=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B MODEL=llama3 python3 examples/mlperf/model_train.py
       - name: Run process replay tests
         uses: ./.github/actions/process-replay
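
The only change to this workflow step is GRADIENT_ACC_STEPS=8 -> 1: the diff below adds an assert to train_llama3 that rejects any other value, so the old setting would abort the smoke test before training starts. A minimal sketch of that fail-fast behavior; the getenv helper here is a simplified stand-in for tinygrad's env-var reader, not its exact implementation:

  import os

  def getenv(key: str, default: int = 0) -> int:
    # simplified stand-in: an environment variable overrides the default, coerced to int
    return int(os.getenv(key, default))

  os.environ["GRADIENT_ACC_STEPS"] = "8"     # the old CI setting
  grad_acc = getenv("GRADIENT_ACC_STEPS", 1)
  try:
    assert grad_acc == 1, f"{grad_acc=} is not supported"
  except AssertionError as e:
    print(e)                                 # grad_acc=8 is not supported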

@@ -1294,6 +1294,7 @@ def train_llama3():
   BASEDIR = config["BASEDIR"] = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
   BS = config["BS"] = getenv("BS", 16)
   grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1)
+  assert grad_acc == 1, f"{grad_acc=} is not supported"
   GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc
   SEED = config["SEED"] = getenv("SEED", 5760)
   SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
@@ -1368,20 +1369,17 @@ def train_llama3():
   @TinyJit
   @Tensor.train()
-  def train_step(model, tokens:Tensor, grad_acc:int):
+  def train_step(model, tokens:Tensor):
     optim.zero_grad()
-    # grad acc
-    for batch in tokens.split(tokens.shape[0]//grad_acc):
-      if (DP := getenv("DP", 1)) > 1:
-        device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
-        batch = batch.shard(device, 0)
-      if (MP := getenv("MP", 1)) > 1:
-        device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
-        batch = batch.shard(device)
-      logits:Tensor = model(batch[:, :-1], start_pos=0, temperature=math.nan)
-      loss = logits.sparse_categorical_crossentropy(batch[:, 1:])
-      loss.backward()
-      Tensor.realize(*[p.grad for p in optim.params])
+    if (DP := getenv("DP", 1)) > 1:
+      device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
+      tokens = tokens.shard(device, 0)
+    if (MP := getenv("MP", 1)) > 1:
+      device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
+      tokens = tokens.shard(device)
+    logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
+    loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
+    loss.backward()
     # L2 norm grad clip
     # https://github.com/NVIDIA/NeMo/blob/3368c3fc0b4a186ab33a1d68a504315100c0b2a6/nemo/collections/nlp/modules/common/megatron/clip_grads.py#L57
     # https://docs.pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
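
For reference, this is what the removed loop implemented and why a single pass can replace it once the whole batch fits in memory: gradient accumulation runs forward/backward over grad_acc microbatches and lets backward() accumulate into the parameter gradients before one optimizer step; for a mean-reduced loss, averaging equal-sized microbatch gradients reproduces the full-batch gradient, while a bare sum (what unscaled accumulation gives) is larger by a factor of grad_acc. Whether the original loop rescaled this elsewhere is not visible in this hunk. A self-contained arithmetic sketch, plain Python with an illustrative quadratic loss rather than tinygrad:

  # gradient of mean((w - x)^2) over a batch, evaluated at w, is mean(2*(w - x))
  def grad_mean_sq(xs, w=0.0):
    return sum(2.0 * (w - x) for x in xs) / len(xs)

  batch = [float(x) for x in range(1, 9)]          # a "global batch" of 8 samples
  grad_acc = 4
  mbs = len(batch) // grad_acc                     # microbatch size; the old loop split tokens into grad_acc equal chunks
  micro = [batch[i:i + mbs] for i in range(0, len(batch), mbs)]

  full = grad_mean_sq(batch)                       # one pass over the whole batch (the new code path)
  summed = sum(grad_mean_sq(mb) for mb in micro)   # what plain backward() accumulation yields
  averaged = summed / grad_acc                     # accumulation rescaled by 1/grad_acc

  assert abs(averaged - full) < 1e-12              # equal-sized microbatches: exact match
  assert abs(summed - grad_acc * full) < 1e-12     # unscaled sum is grad_acc times larger
  print(full, summed, averaged)                    # -9.0 -36.0 -9.0
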
@@ -1445,7 +1443,7 @@ def train_llama3():
   for tokens in tqdm(iter, total=SAMPLES//GBS):
     t = time.perf_counter()
     GlobalCounters.reset()
-    loss, lr = train_step(model, tokens, grad_acc)
+    loss, lr = train_step(model, tokens)
     loss = loss.float().item()
     i += 1
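
One practical effect on the CI smoke test, assuming each iteration of the loop above consumes GLOBAL_BATCH_SIZE samples as total=SAMPLES//GBS suggests (arithmetic only, using the env values from the workflow line):

  SAMPLES, BS = 300, 8                      # from the "Test llama 3 training" CI step
  for grad_acc in (8, 1):                   # old vs new GRADIENT_ACC_STEPS
    GBS = BS * grad_acc                     # GLOBAL_BATCH_SIZE, as computed in train_llama3
    print(f"{grad_acc=}: {GBS=}, optimizer steps={SAMPLES // GBS}")   # 8 -> 4 steps, 1 -> 37 steps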