verify dtype of llama model params (#13719)

chenyu authored on 2025-12-16 12:32:02 -05:00; committed by GitHub
parent e5a66ace80
commit e428fbfab6


@@ -1314,12 +1314,14 @@ def train_llama3():
   opt_base_learning_rate = getenv("LR", 8e-5 * GBS / 1152) # NOTE: cannot change for benchmark
   opt_end_learning_rate = getenv("END_LR", 8e-7)
-  # TODO: confirm weights are in bf16
+  model_params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]
   # vocab_size from the mixtral tokenizer
-  params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]
-  params = params | {"vocab_size": 32000} if not SMALL else params
-  if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
-  model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
+  if not SMALL: model_params |= {"vocab_size": 32000}
+  if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: model_params['n_layers'] = llama_layers
+  model = Transformer(**model_params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
+  params = get_parameters(model)
+  # weights are all bfloat16 for now
+  assert params and all(p.dtype == dtypes.bfloat16 for p in params)
   if getenv("FAKEDATA"):
     for v in get_parameters(model):
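
For context, here is a minimal, self-contained sketch (not part of the commit) of the same dtype check the new assert performs, run against a tiny stand-in model. The Toy class and the use of dtypes.default_float to obtain bf16 weights are assumptions for illustration only, not code from model_train.py.

# sketch: verify parameter dtypes the same way the new assert does
from tinygrad import Tensor, nn, dtypes
from tinygrad.nn.state import get_parameters

# assumption for this sketch: make newly created float tensors default to bfloat16,
# standing in for a model whose weights were created/loaded in bf16
dtypes.default_float = dtypes.bfloat16

class Toy:  # hypothetical stand-in for the real Transformer
  def __init__(self): self.proj = nn.Linear(16, 16, bias=False)
  def __call__(self, x: Tensor) -> Tensor: return self.proj(x)

params = get_parameters(Toy())
# same style of check the commit adds after constructing the model
assert params and all(p.dtype == dtypes.bfloat16 for p in params)
print(f"{len(params)} params, dtype {params[0].dtype}")
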
@@ -1409,7 +1411,7 @@ def train_llama3():
   # ** data iters **
   def fake_data(bs, samples):
     for _ in range(samples // bs):
-      yield Tensor.randint(bs, SEQLEN + 1, low=0, high=params["vocab_size"], dtype=dtypes.int32, device=Device.DEFAULT)
+      yield Tensor.randint(bs, SEQLEN + 1, low=0, high=model_params["vocab_size"], dtype=dtypes.int32, device=Device.DEFAULT)
   def get_train_iter():
     if getenv("FAKEDATA", 0):