Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-07 22:23:55 -05:00
remove old llama grad_acc (#13611)

* remove old llama grad_acc
* GRADIENT_ACC_STEPS=1
.github/workflows/test.yml (vendored, 2 lines changed)
@@ -465,7 +465,7 @@ jobs:
     - name: Test Bert training
       run: NULL=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=24 GPUS=4 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py
     - name: Test llama 3 training
-      run: NULL=1 SAMPLES=300 BS=8 SEQLEN=512 GRADIENT_ACC_STEPS=8 FAKEDATA=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B MODEL=llama3 python3 examples/mlperf/model_train.py
+      run: NULL=1 SAMPLES=300 BS=8 SEQLEN=512 GRADIENT_ACC_STEPS=1 FAKEDATA=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B MODEL=llama3 python3 examples/mlperf/model_train.py
     - name: Run process replay tests
       uses: ./.github/actions/process-replay
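The CI flag flips from GRADIENT_ACC_STEPS=8 to GRADIENT_ACC_STEPS=1 because the training script now rejects any other value (see the assert added in the next hunk). A rough stand-in for that guard, using os.environ in place of tinygrad's getenv helper; this is not the actual code path in model_train.py:

import os

# approximation of the new guard: any GRADIENT_ACC_STEPS other than 1 now aborts the run
grad_acc = int(os.environ.get("GRADIENT_ACC_STEPS", "1"))
assert grad_acc == 1, f"{grad_acc=} is not supported"  # the old CI value of 8 would fail here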
examples/mlperf/model_train.py
@@ -1294,6 +1294,7 @@ def train_llama3():
   BASEDIR = config["BASEDIR"] = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
   BS = config["BS"] = getenv("BS", 16)
   grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1)
+  assert grad_acc == 1, f"{grad_acc=} is not supported"
   GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc
   SEED = config["SEED"] = getenv("SEED", 5760)
   SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
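With grad_acc pinned to 1 by the new assert, GLOBAL_BATCH_SIZE collapses to BS. A small worked example under the CI settings from the workflow hunk above (BS=8, SAMPLES=300); the numbers mirror those config lines, not the script itself:

BS, grad_acc, SAMPLES = 8, 1, 300   # CI settings from the workflow hunk
GBS = BS * grad_acc                 # GLOBAL_BATCH_SIZE == BS once grad_acc == 1
steps = SAMPLES // GBS              # tqdm total in the training loop further down
print(GBS, steps)                   # 8 37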
@@ -1368,20 +1369,17 @@ def train_llama3():
 
   @TinyJit
   @Tensor.train()
-  def train_step(model, tokens:Tensor, grad_acc:int):
+  def train_step(model, tokens:Tensor):
     optim.zero_grad()
-    # grad acc
-    for batch in tokens.split(tokens.shape[0]//grad_acc):
-      if (DP := getenv("DP", 1)) > 1:
-        device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
-        batch = batch.shard(device, 0)
-      if (MP := getenv("MP", 1)) > 1:
-        device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
-        batch = batch.shard(device)
-      logits:Tensor = model(batch[:, :-1], start_pos=0, temperature=math.nan)
-      loss = logits.sparse_categorical_crossentropy(batch[:, 1:])
-      loss.backward()
-      Tensor.realize(*[p.grad for p in optim.params])
+    if (DP := getenv("DP", 1)) > 1:
+      device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
+      tokens = tokens.shard(device, 0)
+    if (MP := getenv("MP", 1)) > 1:
+      device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
+      tokens = tokens.shard(device)
+    logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
+    loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
+    loss.backward()
     # L2 norm grad clip
     # https://github.com/NVIDIA/NeMo/blob/3368c3fc0b4a186ab33a1d68a504315100c0b2a6/nemo/collections/nlp/modules/common/megatron/clip_grads.py#L57
     # https://docs.pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
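For reference, the removed loop was plain gradient accumulation: the global batch was split into grad_acc micro-batches, each backward() summed into p.grad, and the optimizer stepped once afterwards. A minimal, self-contained sketch of that pattern using tinygrad's public Tensor/nn API; the toy model, shapes, and data are invented for illustration and this is not the MLPerf script:

from tinygrad import Tensor, nn

model = nn.Linear(16, 4)                                      # hypothetical toy model
optim = nn.optim.SGD(nn.state.get_parameters(model), lr=0.1)
grad_acc, GBS = 4, 32                                         # 4 micro-batches of 8

x = Tensor.randn(GBS, 16)                                     # fake inputs
y = Tensor.randint(GBS, high=4)                               # fake integer labels

with Tensor.train():
  optim.zero_grad()
  # like the removed loop: backward() accumulates into p.grad across micro-batches
  for xb, yb in zip(x.split(GBS // grad_acc), y.split(GBS // grad_acc)):
    loss = model(xb).sparse_categorical_crossentropy(yb)
    loss.backward()
  optim.step()                                                # one step for the whole global batch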
@@ -1445,7 +1443,7 @@ def train_llama3():
   for tokens in tqdm(iter, total=SAMPLES//GBS):
     t = time.perf_counter()
     GlobalCounters.reset()
-    loss, lr = train_step(model, tokens, grad_acc)
+    loss, lr = train_step(model, tokens)
     loss = loss.float().item()
 
     i += 1
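The "# L2 norm grad clip" comment kept in the train_step hunk above points at global-norm clipping in the style of the linked NeMo and PyTorch clip_grad_norm_ implementations. A generic sketch of that technique in tinygrad-style Python; the function name and epsilon are assumptions, and this is not the code that follows the comment in model_train.py:

from tinygrad import Tensor

def clip_grad_norm_(params: list[Tensor], max_norm: float) -> Tensor:
  # global L2 norm over all gradients, then one shared scale factor
  grads = [p.grad for p in params if p.grad is not None]
  total_norm = sum(g.square().sum() for g in grads).sqrt()
  scale = (max_norm / (total_norm + 1e-6)).minimum(1.0)  # no-op when already within max_norm
  for p in params:
    if p.grad is not None: p.grad = p.grad * scale
  return total_norm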