mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
llama MP realize weight after shard (#11672)
* llama MP realize weight after shard, prevents memory spike on device 0
* empty weight for FAKEDATA
@@ -1318,6 +1318,10 @@ def train_llama3():
   if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
   model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
 
+  if getenv("FAKEDATA"):
+    for v in get_parameters(model):
+      v = v.assign(Tensor.empty(v.shape))
+
   if (DP := getenv("DP", 1)) > 1:
     device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
     for v in get_parameters(model):
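The FAKEDATA path above swaps every parameter for an uninitialized buffer of the same shape, so realizing the model only allocates memory instead of computing real weight values or loading a checkpoint. A minimal sketch of the same pattern on a toy module, assuming a default tinygrad install (SmallBlock is a hypothetical stand-in for the llama Transformer, not part of the commit):

# hypothetical toy module; only get_parameters / assign / Tensor.empty mirror the commit
from tinygrad import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.state import get_parameters

class SmallBlock:
  def __init__(self): self.proj = Linear(512, 512)
  def __call__(self, x): return self.proj(x)

model = SmallBlock()
for v in get_parameters(model):
  # Tensor.empty only allocates a buffer: no data is initialized or copied into the weight
  v = v.assign(Tensor.empty(v.shape))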
@@ -1339,6 +1343,8 @@ def train_llama3():
       else:
         # attention_norm, ffn_norm, norm
         v.shard_(device, axis=None)
+      # prevents memory spike on device 0
+      v.realize()
 
   optim = AdamW(get_parameters(model), lr=0.0,
                 b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
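The second hunk is the fix named in the title: tinygrad tensors are lazy, so without an explicit realize() the sharded weights would presumably only be materialized later, all at once, with unsharded source buffers sitting on device 0 in the meantime. Calling realize() on each weight right after shard_ distributes it immediately. A minimal sketch of the shard-then-realize pattern, assuming a backend with four devices (the GPUS tuple and the weight shape are illustrative, not from the commit):

# shard-then-realize on a single stand-in weight; GPUS is a hypothetical 4-device tuple
from tinygrad import Tensor, Device

GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(4))

w = Tensor.empty(4096, 4096)   # stand-in for one model weight
w.shard_(GPUS, axis=None)      # axis=None replicates; an integer axis splits along that dim
w.realize()                    # materialize the sharded copies now rather than at the first training step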