llama MP realize weight after shard (#11672)

* llama MP realize weight after shard

prevents a memory spike on device 0 (see the sketch below)

* empty weight for FAKEDATA
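
A minimal sketch of the pattern this commit adopts, using bare tensors rather than the MLPerf llama3 model (the device count, shapes, and shard axis below are illustrative assumptions, not values from the diff):

from tinygrad import Tensor, Device

# hypothetical data-parallel layout: DP copies of the default device
DP = 4
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))

# stand-in weights: uninitialized buffers, as in the FAKEDATA path
weights = [Tensor.empty(1024, 1024) for _ in range(8)]

for w in weights:
  w.shard_(devices, axis=0)  # in-place shard; axis=None would replicate instead
  w.realize()                # materialize this weight's shards immediately

Realizing each tensor right after shard_ materializes one parameter at a time; per the diff comment, deferring all realization until later (for example to optimizer setup) is what apparently let buffers pile up on device 0.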
Author: chenyu
Committed by: GitHub
Date: 2025-08-14 13:17:46 -07:00
Parent: 4176b24264
Commit: e9d0027591

@@ -1318,6 +1318,10 @@ def train_llama3():
  if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
  model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
  if getenv("FAKEDATA"):
    for v in get_parameters(model):
      v = v.assign(Tensor.empty(v.shape))
  if (DP := getenv("DP", 1)) > 1:
    device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
    for v in get_parameters(model):
@@ -1339,6 +1343,8 @@ def train_llama3():
      else:
        # attention_norm, ffn_norm, norm
        v.shard_(device, axis=None)
      # prevents memory spike on device 0
      v.realize()
  optim = AdamW(get_parameters(model), lr=0.0,
                b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
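
For completeness, a self-contained sketch of the FAKEDATA empty-weight replacement from the first hunk, on a toy module (the Toy class and its sizes are made up for illustration, not part of the MLPerf model):

from tinygrad import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.state import get_parameters

class Toy:
  def __init__(self):
    self.l1, self.l2 = Linear(16, 16), Linear(16, 16)
  def __call__(self, x: Tensor) -> Tensor:
    return self.l2(self.l1(x).relu())

model = Toy()
# swap every parameter for an uninitialized buffer of the same shape, so no
# checkpoint load or random-init compute is ever scheduled
for v in get_parameters(model):
  v.assign(Tensor.empty(v.shape))

Since Tensor.empty only allocates and never writes, this keeps FAKEDATA startup cheap while the rest of the training loop runs unchanged.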