mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
data parallel train llama (#11466)
This commit is contained in:
@@ -1318,6 +1318,30 @@ def train_llama3():
|
||||
if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
|
||||
model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
|
||||
|
||||
if (DP := getenv("DP", 1)) > 1:
|
||||
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
|
||||
for v in get_parameters(model):
|
||||
v.shard_(device, axis=None)
|
||||
|
||||
# TODO: MP
|
||||
# if (GPUS := getenv("GPUS", 1)) > 1:
|
||||
# device = tuple(f"{Device.DEFAULT}:{i}" for i in range(GPUS))
|
||||
# for k,v in get_state_dict(model).items():
|
||||
# if 'scale' in k: v.shard_(device, axis=None) # from quantized
|
||||
# # elif '.attention.wq' in k: v.shard_(device, axis=0)
|
||||
# # elif '.attention.wk' in k: v.shard_(device, axis=0)
|
||||
# # elif '.attention.wv' in k: v.shard_(device, axis=0)
|
||||
# # elif '.attention.wo' in k: v.shard_(device, axis=1)
|
||||
# # elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
|
||||
# # elif '.feed_forward.w2.' in k: v.shard_(device, axis=1)
|
||||
# # elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
|
||||
# # elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
|
||||
# elif 'output.weight' in k: v.shard_(device, axis=0) # 243.32
|
||||
# else:
|
||||
# # print(k)
|
||||
# # attention_norm, ffn_norm, norm
|
||||
# v.shard_(device, axis=None)
|
||||
|
||||
optim = AdamW(get_parameters(model), lr=0.0,
|
||||
b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
|
||||
scheduler = CosineAnnealingLRWithWarmup(optim, opt_base_learning_rate, opt_end_learning_rate, opt_learning_rate_warmup_steps, opt_learning_rate_decay_steps)
|
||||
@@ -1326,6 +1350,9 @@ def train_llama3():
|
||||
@Tensor.train()
|
||||
def train_step(model, tokens):
|
||||
optim.zero_grad()
|
||||
if (DP := getenv("DP", 1)) > 1:
|
||||
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
|
||||
tokens = tokens.shard(device, 0)
|
||||
logits:Tensor = model(tokens[:, :-1], start_pos=0, temperature=math.nan)
|
||||
loss = logits.sparse_categorical_crossentropy(tokens[:, 1:])
|
||||
loss.backward()
|
||||
|
||||
Reference in New Issue
Block a user