verify dtype of llama model params (#13719)
@@ -1314,12 +1314,14 @@ def train_llama3():
   opt_base_learning_rate = getenv("LR", 8e-5 * GBS / 1152) # NOTE: cannot change for benchmark
   opt_end_learning_rate = getenv("END_LR", 8e-7)

-  # TODO: confirm weights are in bf16
+  model_params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]
   # vocab_size from the mixtral tokenizer
-  params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]
-  params = params | {"vocab_size": 32000} if not SMALL else params
-  if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
-  model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
+  if not SMALL: model_params |= {"vocab_size": 32000}
+  if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: model_params['n_layers'] = llama_layers
+  model = Transformer(**model_params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
+  params = get_parameters(model)
+  # weights are all bfloat16 for now
+  assert params and all(p.dtype == dtypes.bfloat16 for p in params)

   if getenv("FAKEDATA"):
     for v in get_parameters(model):
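
The point of this hunk is that the old "TODO: confirm weights are in bf16" comment becomes an actual runtime check: get_parameters(model) must be non-empty and every parameter must report dtypes.bfloat16. A minimal standalone sketch of the same check, using a toy module in place of the MLPerf Transformer (TinyMLP and its shapes are illustrative assumptions, not code from this commit):

from tinygrad import Tensor, dtypes
from tinygrad.nn.state import get_parameters

class TinyMLP:
  # toy stand-in for the real Transformer; weights are created directly in bfloat16
  def __init__(self):
    self.w1 = Tensor.randn(16, 32, dtype=dtypes.bfloat16)
    self.w2 = Tensor.randn(32, 8, dtype=dtypes.bfloat16)
  def __call__(self, x: Tensor) -> Tensor:
    return (x @ self.w1).relu() @ self.w2

model = TinyMLP()

# the check added in this commit: the parameter list is non-empty and every
# parameter has the expected dtype
params = get_parameters(model)
assert params and all(p.dtype == dtypes.bfloat16 for p in params)
print(f"verified {len(params)} parameters are {params[0].dtype}")
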
@@ -1409,7 +1411,7 @@ def train_llama3():
   # ** data iters **
   def fake_data(bs, samples):
     for _ in range(samples // bs):
-      yield Tensor.randint(bs, SEQLEN + 1, low=0, high=params["vocab_size"], dtype=dtypes.int32, device=Device.DEFAULT)
+      yield Tensor.randint(bs, SEQLEN + 1, low=0, high=model_params["vocab_size"], dtype=dtypes.int32, device=Device.DEFAULT)

   def get_train_iter():
     if getenv("FAKEDATA", 0):
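
The second hunk only repoints the fake-data generator at the renamed model_params dict, so high= pulls the vocab size from the same place the model was built from. A standalone sketch of that generator pattern, with BS, SEQLEN and VOCAB_SIZE as illustrative stand-ins for the values the real script reads from the environment and model_params (the input/target split below is an assumption about how the batch is consumed):

from tinygrad import Tensor, dtypes, Device

BS, SEQLEN, VOCAB_SIZE = 4, 128, 32000  # illustrative values, not the benchmark config

def fake_data(bs, samples):
  # each batch row carries SEQLEN + 1 tokens so it can be split into inputs and shifted targets
  for _ in range(samples // bs):
    yield Tensor.randint(bs, SEQLEN + 1, low=0, high=VOCAB_SIZE, dtype=dtypes.int32, device=Device.DEFAULT)

for batch in fake_data(BS, samples=8):
  x, y = batch[:, :-1], batch[:, 1:]  # next-token prediction style split (assumption)
  assert x.shape == (BS, SEQLEN) and batch.dtype == dtypes.int32
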