data parallel train llama (#11466)

Author: chenyu
Date: 2025-08-01 09:13:51 -07:00
Committed by: GitHub
Parent: 9f2182f92f
Commit: 7ad7329257


@@ -1318,6 +1318,30 @@ def train_llama3():
  if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
  model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
  if (DP := getenv("DP", 1)) > 1:
    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
    for v in get_parameters(model):
      v.shard_(device, axis=None)
  # TODO: MP
  # if (GPUS := getenv("GPUS", 1)) > 1:
  #   device = tuple(f"{Device.DEFAULT}:{i}" for i in range(GPUS))
  #   for k,v in get_state_dict(model).items():
  #     if 'scale' in k: v.shard_(device, axis=None)  # from quantized
  #     # elif '.attention.wq' in k: v.shard_(device, axis=0)
  #     # elif '.attention.wk' in k: v.shard_(device, axis=0)
  #     # elif '.attention.wv' in k: v.shard_(device, axis=0)
  #     # elif '.attention.wo' in k: v.shard_(device, axis=1)
  #     # elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
  #     # elif '.feed_forward.w2.' in k: v.shard_(device, axis=1)
  #     # elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
  #     # elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
  #     elif 'output.weight' in k: v.shard_(device, axis=0)  # 243.32
  #     else:
  #       # print(k)
  #       # attention_norm, ffn_norm, norm
  #       v.shard_(device, axis=None)
  optim = AdamW(get_parameters(model), lr=0.0,
                b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
  scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps)
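
A note on the sharding semantics used above (not part of the commit itself): in tinygrad, Tensor.shard_ / Tensor.shard with axis=None keeps a full copy of the tensor on every device, while axis=k splits the tensor along dimension k. Data parallelism therefore replicates all weights (axis=None) and only splits the batch, which is exactly what the commented-out MP block would change for the attention and feed-forward weights. A minimal sketch, assuming at least two devices of Device.DEFAULT are available; the shapes and device count are illustrative:

from tinygrad import Tensor, Device

devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

# weights: axis=None -> every device keeps a full copy (the DP path above)
w = Tensor.ones(4, 4).contiguous().realize()
w.shard_(devices, axis=None)

# inputs: axis=0 -> the batch dimension is split, 4 rows per device
x = Tensor.rand(8, 4).shard(devices, axis=0)

out = x.matmul(w)             # each device multiplies its batch slice against its weight copy
print(out.shape, out.device)  # (8, 4), and a tuple listing both devices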
@@ -1326,6 +1350,9 @@ def train_llama3():
  @Tensor.train()
  def train_step(model, tokens):
    optim.zero_grad()
    if (DP := getenv("DP", 1)) > 1:
      device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
      tokens = tokens.shard(device, 0)
    logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
    loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
    loss.backward()
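
To see the same data-parallel step end to end outside the mlperf harness, here is a self-contained toy sketch that mirrors the pattern above: replicate the parameters with axis=None, shard the batch on axis 0, then let the loss, backward, and optimizer step run across the device tuple. ToyModel, the shapes, the learning rate, and the device count are illustrative assumptions, not taken from this commit:

from tinygrad import Tensor, Device
from tinygrad.nn import Linear
from tinygrad.nn.optim import AdamW
from tinygrad.nn.state import get_parameters

class ToyModel:
  def __init__(self): self.proj = Linear(16, 8)
  def __call__(self, x): return self.proj(x)

DP = 2
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))

model = ToyModel()
for p in get_parameters(model):
  p.shard_(devices, axis=None)       # replicate every weight on all devices

optim = AdamW(get_parameters(model), lr=1e-3)

@Tensor.train()
def train_step(x, y):
  optim.zero_grad()
  x, y = x.shard(devices, axis=0), y.shard(devices, axis=0)  # split the batch across devices
  loss = model(x).sparse_categorical_crossentropy(y)
  loss.backward()
  optim.step()
  return loss

x, y = Tensor.randn(8, 16), Tensor.randint(8, high=8)
print(train_step(x, y).numpy())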