mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
model parallel llama (#11588)
MP=8 GRADIENT_ACC_STEPS=3 BS=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=70B SEQLEN=512 PYTHONPATH=. MODEL=llama3 python3 examples/mlperf/model_train.py
This commit is contained in:
@@ -1323,24 +1323,22 @@ def train_llama3():
|
||||
for v in get_parameters(model):
|
||||
v.shard_(device, axis=None)
|
||||
|
||||
# TODO: MP
|
||||
# if (GPUS := getenv("GPUS", 1)) > 1:
|
||||
# device = tuple(f"{Device.DEFAULT}:{i}" for i in range(GPUS))
|
||||
# for k,v in get_state_dict(model).items():
|
||||
# if 'scale' in k: v.shard_(device, axis=None) # from quantized
|
||||
# # elif '.attention.wq' in k: v.shard_(device, axis=0)
|
||||
# # elif '.attention.wk' in k: v.shard_(device, axis=0)
|
||||
# # elif '.attention.wv' in k: v.shard_(device, axis=0)
|
||||
# # elif '.attention.wo' in k: v.shard_(device, axis=1)
|
||||
# # elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
|
||||
# # elif '.feed_forward.w2.' in k: v.shard_(device, axis=1)
|
||||
# # elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
|
||||
# # elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
|
||||
# elif 'output.weight' in k: v.shard_(device, axis=0) # 243.32
|
||||
# else:
|
||||
# # print(k)
|
||||
# # attention_norm, ffn_norm, norm
|
||||
# v.shard_(device, axis=None)
|
||||
if (MP := getenv("MP", 1)) > 1:
|
||||
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
|
||||
for k,v in get_state_dict(model).items():
|
||||
if 'scale' in k: v.shard_(device, axis=None) # from quantized
|
||||
elif '.attention.wq' in k: v.shard_(device, axis=0)
|
||||
elif '.attention.wk' in k: v.shard_(device, axis=0)
|
||||
elif '.attention.wv' in k: v.shard_(device, axis=0)
|
||||
elif '.attention.wo' in k: v.shard_(device, axis=1)
|
||||
elif '.feed_forward.w1.' in k: v.shard_(device, axis=0)
|
||||
elif '.feed_forward.w2.' in k: v.shard_(device, axis=1)
|
||||
elif '.feed_forward.w3.' in k: v.shard_(device, axis=0)
|
||||
elif 'tok_embeddings.weight' in k: v.shard_(device, axis=0)
|
||||
elif 'output.weight' in k: v.shard_(device, axis=0)
|
||||
else:
|
||||
# attention_norm, ffn_norm, norm
|
||||
v.shard_(device, axis=None)
|
||||
|
||||
optim = AdamW(get_parameters(model), lr=0.0,
|
||||
b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
|
||||
@@ -1355,6 +1353,9 @@ def train_llama3():
|
||||
if (DP := getenv("DP", 1)) > 1:
|
||||
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
|
||||
batch = batch.shard(device, 0)
|
||||
if (MP := getenv("MP", 1)) > 1:
|
||||
device = tuple(f"{Device.DEFAULT}:{i}" for i in range(MP))
|
||||
batch = batch.shard(device)
|
||||
logits:Tensor = model(batch[:, :-1], start_pos=0, temperature=math.nan)
|
||||
loss = logits.sparse_categorical_crossentropy(batch[:, 1:])
|
||||
loss.backward()
|
||||
|
||||
Reference in New Issue
Block a user