model parallel llama (#11588)

MP=8 GRADIENT_ACC_STEPS=3 BS=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=70B SEQLEN=512 PYTHONPATH=. MODEL=llama3 python3 examples/mlperf/model_train.py
chenyu authored 2025-08-09 13:54:27 -07:00 (committed by GitHub)
parent 09bc377da3, commit 45baec1aab


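The run command in the commit message configures everything through environment variables. Below is a minimal sketch of how the MP=8 prefix turns into a device tuple, assuming tinygrad's getenv and Device (both used in the diff that follows); the os.environ line only stands in for the `MP=8 ...` prefix on the command line.

import os
os.environ["MP"] = "8"  # stands in for the `MP=8 ...` prefix on the command line above

from tinygrad import Device
from tinygrad.helpers import getenv  # coerces the env var to the type of its default (int here)

if (MP := getenv("MP", 1)) > 1:
  device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
  print(device)  # e.g. ('NV:0', 'NV:1', ..., 'NV:7') on an 8-GPU box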
@@ -1323,24 +1323,22 @@ def train_llama3():
     for v in get_parameters(model):
       v.shard_(device, axis=None)
-  # TODO: MP
-  # if (GPUS := getenv("GPUS", 1)) > 1:
-  #   device = tuple(f"{Device.DEFAULT}:{i}" for i in range(GPUS))
-  #   for k,v in get_state_dict(model).items():
-  #     if 'scale' in k: v.shard_(device, axis=None) # from quantized
-  #     # elif '.attention.wq' in k: v.shard_(device, axis=0)
-  #     # elif '.attention.wk' in k: v.shard_(device, axis=0)
-  #     # elif '.attention.wv' in k: v.shard_(device, axis=0)
-  #     # elif '.attention.wo' in k: v.shard_(device, axis=1)
-  #     # elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
-  #     # elif '.feed_forward.w2.' in k: v.shard_(device, axis=1)
-  #     # elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
-  #     # elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
-  #     elif 'output.weight' in k: v.shard_(device, axis=0) # 243.32
-  #     else:
-  #       # print(k)
-  #       # attention_norm, ffn_norm, norm
-  #       v.shard_(device, axis=None)
+  if (MP := getenv("MP", 1)) > 1:
+    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
+    for k,v in get_state_dict(model).items():
+      if 'scale' in k: v.shard_(device, axis=None) # from quantized
+      elif '.attention.wq' in k: v.shard_(device, axis=0)
+      elif '.attention.wk' in k: v.shard_(device, axis=0)
+      elif '.attention.wv' in k: v.shard_(device, axis=0)
+      elif '.attention.wo' in k: v.shard_(device, axis=1)
+      elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
+      elif '.feed_forward.w2.' in k: v.shard_(device, axis=1)
+      elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
+      elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
+      elif 'output.weight' in k: v.shard_(device, axis=0)
+      else:
+        # attention_norm, ffn_norm, norm
+        v.shard_(device, axis=None)
   optim = AdamW(get_parameters(model), lr=0.0,
                 b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
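The axis choices above are the usual tensor-parallel split: most projections are split along axis 0 of their (out, in) weight matrices so each shard produces its own slice of the output, the two projections that close a block (wo, w2) are split along axis 1 so the partial results can be reduced, and the norm scales stay replicated. Here is a minimal sketch of the same pattern on a toy feed-forward block; the TinyFFN class, its shapes, and the two-way device tuple are made up for illustration, while the get_state_dict/shard_ calls mirror the diff.

from tinygrad import Tensor, Device, nn
from tinygrad.nn.state import get_state_dict

MP = 2  # two-way split, purely for illustration
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))

class TinyFFN:  # hypothetical stand-in for one transformer feed_forward block
  def __init__(self, dim=8, hidden=16):
    self.w1 = nn.Linear(dim, hidden, bias=False)  # like feed_forward.w1: split output features
    self.w2 = nn.Linear(hidden, dim, bias=False)  # like feed_forward.w2: split input features
  def __call__(self, x: Tensor) -> Tensor:
    return self.w2(self.w1(x).relu())

model = TinyFFN()
for k, v in get_state_dict(model).items():
  if 'w1' in k: v.shard_(device, axis=0)    # nn.Linear weight is (out, in), so axis 0 splits outputs
  elif 'w2' in k: v.shard_(device, axis=1)  # axis 1 splits inputs; partial products get reduced
  else: v.shard_(device, axis=None)         # everything else stays replicated

x = Tensor.rand(1, 8).shard(device)  # replicate the input, like `batch.shard(device)` below
print(model(x).numpy())              # (1, 8) output, matching the unsharded result up to rounding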
@@ -1355,6 +1353,9 @@ def train_llama3():
   if (DP := getenv("DP", 1)) > 1:
     device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
     batch = batch.shard(device, 0)
+  if (MP := getenv("MP", 1)) > 1:
+    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
+    batch = batch.shard(device)
   logits:Tensor = model(batch[:, :-1], start_pos=0, temperature=math.nan)
   loss = logits.sparse_categorical_crossentropy(batch[:, 1:])
   loss.backward()
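The loss above is plain next-token prediction: logits computed from tokens 0..N-2 are scored against tokens 1..N-1. A minimal sketch of just that shape contract follows; the toy vocab/sequence sizes and the random stand-in logits are made up, and only the slicing and the sparse_categorical_crossentropy call mirror the code.

from tinygrad import Tensor

VOCAB, SEQLEN, BS = 32, 8, 1                    # toy sizes, not the real 70B config
batch = Tensor.randint(BS, SEQLEN, high=VOCAB)  # token ids, shape (BS, SEQLEN)
logits = Tensor.rand(BS, SEQLEN - 1, VOCAB)     # stand-in for model(batch[:, :-1], ...)
loss = logits.sparse_categorical_crossentropy(batch[:, 1:])  # predict token t+1 from tokens <= t
print(loss.item())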