feat: small llama3 training (#11829)
@@ -758,6 +758,27 @@ def batch_load_llama3(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0
       batch.append(tokens)
     yield Tensor.stack(batch, dim=0)

+def batch_load_llama3_small(bs:int, samples:int, seqlen:int, base_dir:Path, seed:int=0, val:bool=True):
+  if val:
+    dataset = BlendedGPTDataset([
+      base_dir / "c4-validation-91205-samples.en_text_document",
+    ], [
+      1.0
+    ], samples, seqlen, seed, False)
+  else:
+    dataset = BlendedGPTDataset([
+      base_dir / "c4-train.en_6_text_document",
+    ], [
+      1.0
+    ], samples, seqlen, seed, True)
+
+  for b in range(math.ceil(samples / bs)):
+    batch = []
+    for i in range(bs):
+      tokens = dataset.get(b * bs + i)
+      batch.append(tokens)
+    yield Tensor.stack(batch, dim=0)
+
 if __name__ == "__main__":
   def load_unet3d(val):
     assert not val, "validation set is not supported due to different sizes on inputs"
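Note: a minimal usage sketch of the new small loader, for orientation. It assumes the module layout shown in this hunk and that each yielded batch is a (bs, seqlen + 1) tensor of token ids, as the fake_data generator later in this commit suggests; the batch size, sample count, and sequence length below are illustrative.

from pathlib import Path
from examples.mlperf.dataloader import batch_load_llama3_small

# illustrative settings; BASEDIR must point at preprocessed C4 *_text_document shards
BASEDIR = Path("/raid/datasets/c4/")
for tokens in batch_load_llama3_small(bs=4, samples=64, seqlen=512, base_dir=BASEDIR, val=True):
  # tokens[:, :-1] is the model input, tokens[:, 1:] the next-token target
  print(tokens.shape)  # expected (4, 513)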
@@ -243,31 +243,49 @@ def eval_mrcnn():

 def eval_llama3():
   from extra.models.llama import Transformer
-  from examples.llama3 import MODEL_PARAMS
+  from examples.llama3 import MODEL_PARAMS, load, convert_from_huggingface
   from tinygrad.helpers import tqdm

-  bs = 4
-  sequence_length = 512
+  BASEDIR = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
+  BS = getenv("BS", 4)
+  SMALL = getenv("SMALL", 0)
+  SEQLEN = getenv("SEQLEN", 8192)
+  MODEL_PATH = Path(getenv("MODEL_PATH", "/raid/weights/llama31_8b/"))

-  model = Transformer(**(MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]|{"vocab_size": 32000}), max_context=sequence_length, jit=False, disable_kv_cache=True)
+  params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]
+  params = params | {"vocab_size": 32000} if not SMALL else params
+  if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
+  model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
+
+  # load weights
+  weights = load(str(MODEL_PATH / "model.safetensors.index.json"))
+  if "model.embed_tokens.weight" in weights:
+    print("converting from huggingface format")
+    weights = convert_from_huggingface(weights, params["n_layers"], params["n_heads"], params["n_kv_heads"])
+
+  load_state_dict(model, weights, strict=False, consume=True)

   @TinyJit
   def eval_step(model, tokens):
     logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
     loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
-    return loss.flatten()
+    return loss.flatten().float()

-  from examples.mlperf.dataloader import batch_load_llama3
-  iter = batch_load_llama3(bs, 5760, sequence_length, Path(getenv("BASEDIR", "/raid/datasets/c4/")), True)
+  if SMALL:
+    from examples.mlperf.dataloader import batch_load_llama3_small
+    iter = batch_load_llama3_small(BS, 5760, SEQLEN, BASEDIR, val=True)
+  else:
+    from examples.mlperf.dataloader import batch_load_llama3
+    iter = batch_load_llama3(BS, 5760, SEQLEN, BASEDIR, val=True)

   losses = []
-  for tokens in tqdm(iter, total=5760//bs):
+  for tokens in tqdm(iter, total=5760//BS):
     GlobalCounters.reset()
     losses += eval_step(model, tokens).tolist()
     tqdm.write(f"loss: {np.mean(losses)}")

-  log_perplexity = Tensor(losses).mean()
-  print(f"Log Perplexity: {log_perplexity.item()}")
+  log_perplexity = np.mean(losses)
+  print(f"Log Perplexity: {log_perplexity}")

 if __name__ == "__main__":
   # inference only
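Note: the eval loop above reports the mean per-token cross-entropy as "Log Perplexity"; perplexity itself is just its exponential. A small self-contained illustration with made-up loss values:

import math

# made-up per-token losses, standing in for eval_step(model, tokens).tolist()
losses = [2.31, 2.18, 2.44, 2.27]
log_perplexity = sum(losses) / len(losses)   # the value the script prints
perplexity = math.exp(log_perplexity)        # not printed by the script
print(f"Log Perplexity: {log_perplexity:.4f}  (perplexity ~ {perplexity:.2f})")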
@@ -1290,12 +1290,14 @@ def train_llama3():
   from examples.mlperf.lr_schedulers import CosineAnnealingLRWithWarmup

   config = {}
+  BASEDIR = config["BASEDIR"] = Path(getenv("BASEDIR", "/raid/datasets/c4/"))
   BS = config["BS"] = getenv("BS", 16)
   grad_acc = config["GRADIENT_ACC_STEPS"] = getenv("GRADIENT_ACC_STEPS", 1)
   GBS = config["GLOBAL_BATCH_SIZE"] = BS * grad_acc
   SEED = config["SEED"] = getenv("SEED", 5760)
   SEQLEN = config["SEQLEN"] = getenv("SEQLEN", 8192)
   TRAIN_ON_VAL = config["TRAIN_ON_VAL"] = getenv("TRAIN_ON_VAL", 0)
+  SMALL = config["SMALL"] = getenv("SMALL", 0)
   SAMPLES = config["SAMPLES"] = getenv("SAMPLES", 5_760 if TRAIN_ON_VAL else 1_200_000 * 1152)
   EVAL_FREQ = config["EVAL_FREQ"] = getenv("EVAL_FREQ", 46080)
   EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 16)
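Note: every knob in this config block goes through tinygrad's getenv helper, which reads an environment variable and casts it to the type of its default. A simplified sketch of that idea (a rough approximation, not the exact tinygrad.helpers implementation):

import os

def getenv_sketch(key: str, default=0):
  # cast to the type of the default: getenv("BS", 16) -> int, getenv("BASEDIR", "/raid/...") -> str
  return type(default)(os.getenv(key, default))

# e.g. setting SMALL=1 SEQLEN=2048 in the environment flips the small-training path
SMALL = getenv_sketch("SMALL", 0)
SEQLEN = getenv_sketch("SEQLEN", 8192)
print(SMALL, SEQLEN)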
@@ -1317,7 +1319,8 @@ def train_llama3():

   # TODO: confirm weights are in bf16
   # vocab_size from the mixtral tokenizer
-  params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]|{"vocab_size": 32000}
+  params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]
+  params = params | {"vocab_size": 32000} if not SMALL else params
   if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
   model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)

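Note on the new params line: the dict union | binds tighter than the conditional expression, so it reads as (params | {"vocab_size": 32000}) if not SMALL else params, meaning the Mixtral-tokenizer vocab override is skipped entirely in SMALL mode. A tiny illustration with a placeholder dict:

# placeholder params; only shows the precedence of | vs. the conditional expression
params = {"vocab_size": 128256}
for SMALL in (0, 1):
  out = params | {"vocab_size": 32000} if not SMALL else params
  print(SMALL, out["vocab_size"])  # 0 -> 32000, 1 -> 128256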
@@ -1403,21 +1406,29 @@ def train_llama3():
   # ** data iters **
   def fake_data(bs, samples):
     for _ in range(samples // bs):
-      yield Tensor.randint(bs, SEQLEN + 1, low=0, high=32000, dtype=dtypes.int32, device=Device.DEFAULT)
+      yield Tensor.randint(bs, SEQLEN + 1, low=0, high=params["vocab_size"], dtype=dtypes.int32, device=Device.DEFAULT)

   def get_train_iter():
     if getenv("FAKEDATA", 0):
       return fake_data(GBS, SAMPLES)
     else:
-      from examples.mlperf.dataloader import batch_load_llama3
-      return batch_load_llama3(GBS, SAMPLES, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=bool(TRAIN_ON_VAL))
+      if SMALL:
+        from examples.mlperf.dataloader import batch_load_llama3_small
+        return batch_load_llama3_small(GBS, SAMPLES, SEQLEN, BASEDIR, seed=SEED, val=bool(TRAIN_ON_VAL))
+      else:
+        from examples.mlperf.dataloader import batch_load_llama3
+        return batch_load_llama3(GBS, SAMPLES, SEQLEN, BASEDIR, seed=SEED, val=bool(TRAIN_ON_VAL))

   def get_eval_iter():
     if getenv("FAKEDATA", 0):
       return fake_data(EVAL_BS, 5760)
     else:
-      from examples.mlperf.dataloader import batch_load_llama3
-      return batch_load_llama3(EVAL_BS, 5760, SEQLEN, Path(getenv("BASEDIR", "/raid/datasets/c4/")), seed=SEED, val=True)
+      if SMALL:
+        from examples.mlperf.dataloader import batch_load_llama3_small
+        return batch_load_llama3_small(EVAL_BS, 5760, SEQLEN, BASEDIR, val=True)
+      else:
+        from examples.mlperf.dataloader import batch_load_llama3
+        return batch_load_llama3(EVAL_BS, 5760, SEQLEN, BASEDIR, val=True)

   iter = get_train_iter()
   i, sequences_seen = 0, 0
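Note: fake_data streams random token ids sized to the model's vocab so the training loop can be smoke-tested without the C4 shards; the change here just makes it follow params["vocab_size"] instead of a hard-coded 32000. A self-contained sketch of the same generator (names are illustrative):

from tinygrad import Tensor, dtypes

def fake_data_sketch(bs: int, samples: int, seqlen: int, vocab_size: int):
  # (bs, seqlen + 1) int32 batches: one extra token so tokens[:, :-1] / tokens[:, 1:] stay aligned
  for _ in range(samples // bs):
    yield Tensor.randint(bs, seqlen + 1, low=0, high=vocab_size, dtype=dtypes.int32)

for tokens in fake_data_sketch(bs=2, samples=4, seqlen=8, vocab_size=32000):
  print(tokens.shape)  # (2, 9)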
@@ -1426,7 +1437,7 @@ def train_llama3():
     GlobalCounters.reset()
     loss, lr = train_step(model, tokens, grad_acc)
     loss = loss.float().item()
-    # above as tqdm.write f-string
     tqdm.write(f"{loss:.4f} loss, {lr.item():.12f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {time.perf_counter()-t:.2f} s")
     if (fname:=getenv("LOSS_FILE", "")):
       with open(fname, "a") as f:
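Note: the LOSS_FILE check uses an assignment expression so the optional log path is bound and tested in one line; the body of the with block is cut off in this hunk. An expanded equivalent with a placeholder payload:

import os

# expanded form of: if (fname:=getenv("LOSS_FILE", "")):
fname = os.getenv("LOSS_FILE", "")
if fname:
  with open(fname, "a") as f:
    f.write("...\n")  # placeholder; the real per-step line is not visible in this hunk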