Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
model parallel llama (#11588)
MP=8 GRADIENT_ACC_STEPS=3 BS=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=70B SEQLEN=512 PYTHONPATH=. MODEL=llama3 python3 examples/mlperf/model_train.py
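For reference, MP is read in the diff below with tinygrad's getenv helper (getenv("MP", 1)); the other knobs in the command are presumably consumed the same way elsewhere in examples/mlperf/model_train.py. A minimal sketch of that pattern follows, with illustrative defaults that are assumptions rather than the script's real values:

# sketch only: reading env-var knobs like the ones in the command above with
# tinygrad's getenv helper; the defaults here are illustrative assumptions
from tinygrad.helpers import getenv

MP     = getenv("MP", 1)                  # model-parallel shards (8 in the command above)
ACC    = getenv("GRADIENT_ACC_STEPS", 1)  # micro-batches accumulated per optimizer step
BS     = getenv("BS", 1)                  # micro-batch size
SEQLEN = getenv("SEQLEN", 512)            # tokens per training sample
SIZE   = getenv("LLAMA3_SIZE", "8B")      # string default, so getenv returns a string

# with the command above: 1 * 3 * 512 = 1536 tokens feed each optimizer step
# (ignoring the one-token shift for next-token targets)
print(f"MP={MP} size={SIZE} tokens/step={BS * ACC * SEQLEN}")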
@@ -1323,24 +1323,22 @@ def train_llama3():
   for v in get_parameters(model):
     v.shard_(device, axis=None)
 
-  # TODO: MP
-  # if (GPUS := getenv("GPUS", 1)) > 1:
-  #   device = tuple(f"{Device.DEFAULT}:{i}" for i in range(GPUS))
-  #   for k,v in get_state_dict(model).items():
-  #     if 'scale' in k: v.shard_(device, axis=None) # from quantized
-  #     # elif '.attention.wq' in k: v.shard_(device, axis=0)
-  #     # elif '.attention.wk' in k: v.shard_(device, axis=0)
-  #     # elif '.attention.wv' in k: v.shard_(device, axis=0)
-  #     # elif '.attention.wo' in k: v.shard_(device, axis=1)
-  #     # elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
-  #     # elif '.feed_forward.w2.' in k: v.shard_(device, axis=1)
-  #     # elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
-  #     # elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
-  #     elif 'output.weight' in k: v.shard_(device, axis=0) # 243.32
-  #     else:
-  #       # print(k)
-  #       # attention_norm, ffn_norm, norm
-  #       v.shard_(device, axis=None)
+  if (MP := getenv("MP", 1)) > 1:
+    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
+    for k,v in get_state_dict(model).items():
+      if 'scale' in k: v.shard_(device, axis=None) # from quantized
+      elif '.attention.wq' in k: v.shard_(device, axis=0)
+      elif '.attention.wk' in k: v.shard_(device, axis=0)
+      elif '.attention.wv' in k: v.shard_(device, axis=0)
+      elif '.attention.wo' in k: v.shard_(device, axis=1)
+      elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
+      elif '.feed_forward.w2.' in k: v.shard_(device, axis=1)
+      elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
+      elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
+      elif 'output.weight' in k: v.shard_(device, axis=0)
+      else:
+        # attention_norm, ffn_norm, norm
+        v.shard_(device, axis=None)
 
   optim = AdamW(get_parameters(model), lr=0.0,
     b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
@@ -1355,6 +1353,9 @@ def train_llama3():
   if (DP := getenv("DP", 1)) > 1:
     device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
     batch = batch.shard(device, 0)
+  if (MP := getenv("MP", 1)) > 1:
+    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
+    batch = batch.shard(device)
   logits:Tensor = model(batch[:, :-1], start_pos=0, temperature=math.nan)
   loss = logits.sparse_categorical_crossentropy(batch[:, 1:])
   loss.backward()
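The sharding scheme in the first hunk is standard Megatron-style tensor parallelism: wq/wk/wv, w1/w3, tok_embeddings and output are split along the output dimension of their weight (axis 0), wo and w2 along the input dimension (axis 1), and the norms plus the input batch are replicated (axis=None above, and batch.shard(device) with no axis in the second hunk), so each device holds only a slice of the large matmuls while tinygrad's multi-device tensors handle the cross-device reductions. Below is a minimal, self-contained sketch of the same pattern on a toy two-layer block; the toy model, its sizes, and MP=2 are illustrative assumptions, not part of the commit.

# sketch only: the column-/row-parallel sharding pattern above, on a toy block
# instead of llama; the model, sizes, and MP=2 are illustrative assumptions
from tinygrad import Tensor, Device, nn
from tinygrad.nn.state import get_state_dict

class ToyBlock:
  def __init__(self, dim=64, hidden=128):
    self.w1 = nn.Linear(dim, hidden, bias=False)  # like feed_forward.w1/w3: shard the output dim
    self.w2 = nn.Linear(hidden, dim, bias=False)  # like feed_forward.w2/attention.wo: shard the input dim
    self.norm = nn.LayerNorm(dim)                 # small, kept whole on every device
  def __call__(self, x:Tensor) -> Tensor:
    return self.norm(self.w2(self.w1(x).relu()))

MP = 2
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
model = ToyBlock()
for k,v in get_state_dict(model).items():
  if 'w1.' in k:   v.shard_(device, axis=0)    # weight is (hidden, dim); axis 0 is the output dim
  elif 'w2.' in k: v.shard_(device, axis=1)    # weight is (dim, hidden); axis 1 is the input dim
  else:            v.shard_(device, axis=None) # replicate the norm, like the else branch above

batch = Tensor.randn(4, 64).shard(device)      # no axis: replicate the batch, as in the second hunk
out = model(batch)
print(out.shape, out.device)                   # (4, 64), spread over the two devices

The contrast with the DP branch right above the new MP branch is the design point: data parallelism shards the batch along axis 0 (batch.shard(device, 0)) and keeps full weight copies, while model parallelism replicates the batch and shards the weights, which is what lets the 70B run in the commit message spread across eight devices at BS=1.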