From e428fbfab6778aad9884f3967b3fae002623f77f Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 16 Dec 2025 12:32:02 -0500 Subject: [PATCH] verify dtype of llama model params (#13719) --- examples/mlperf/model_train.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 2cf190998c..884bf338ca 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1314,12 +1314,14 @@ def train_llama3(): opt_base_learning_rate = getenv("LR", 8e-5 * GBS / 1152) # NOTE: cannot change for benchmark opt_end_learning_rate = getenv("END_LR", 8e-7) - # TODO: confirm weights are in bf16 + model_params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"] # vocab_size from the mixtral tokenizer - params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"] - params = params | {"vocab_size": 32000} if not SMALL else params - if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers - model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True) + if not SMALL: model_params |= {"vocab_size": 32000} + if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: model_params['n_layers'] = llama_layers + model = Transformer(**model_params, max_context=SEQLEN, jit=False, disable_kv_cache=True) + params = get_parameters(model) + # weights are all bfloat16 for now + assert params and all(p.dtype == dtypes.bfloat16 for p in params) if getenv("FAKEDATA"): for v in get_parameters(model): @@ -1409,7 +1411,7 @@ def train_llama3(): # ** data iters ** def fake_data(bs, samples): for _ in range(samples // bs): - yield Tensor.randint(bs, SEQLEN + 1, low=0, high=params["vocab_size"], dtype=dtypes.int32, device=Device.DEFAULT) + yield Tensor.randint(bs, SEQLEN + 1, low=0, high=model_params["vocab_size"], dtype=dtypes.int32, device=Device.DEFAULT) def get_train_iter(): if getenv("FAKEDATA", 0):