diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index 98994d21a0..d962508c0a 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -1318,6 +1318,10 @@ def train_llama3():
   if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
   model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
 
+  if getenv("FAKEDATA"):
+    for v in get_parameters(model):
+      v = v.assign(Tensor.empty(v.shape))
+
   if (DP := getenv("DP", 1)) > 1:
     device = tuple(f"{Device.DEFAULT}:{i}" for i in range(DP))
     for v in get_parameters(model):
@@ -1339,6 +1343,8 @@ def train_llama3():
       else:
         # attention_norm, ffn_norm, norm
         v.shard_(device, axis=None)
+      # prevents memory spike on device 0
+      v.realize()
 
   optim = AdamW(get_parameters(model), lr=0.0, b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
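
For context, here is a minimal standalone sketch (not the PR's code) of the two patterns the diff adds: replacing weights with uninitialized buffers under a FAKEDATA flag, and replicating parameters across devices with an immediate realize per tensor. The NDEV knob and the Linear stand-in model are illustrative assumptions, not part of model_train.py.

from tinygrad import Device, Tensor
from tinygrad.helpers import getenv
from tinygrad.nn import Linear
from tinygrad.nn.state import get_parameters

NDEV = getenv("NDEV", 2)   # hypothetical device count, stands in for DP
model = Linear(64, 64)     # stand-in for the Llama Transformer

# FAKEDATA-style init: swap every weight for an uninitialized buffer of the same
# shape, so a smoke test does not need to load a real checkpoint.
if getenv("FAKEDATA"):
  for v in get_parameters(model):
    v.assign(Tensor.empty(*v.shape))

# Replicate every parameter across the devices (axis=None means no split),
# realizing each tensor as we go rather than leaving the copies scheduled lazily.
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(NDEV))
for v in get_parameters(model):
  v.shard_(devices, axis=None)
  v.realize()

The per-tensor realize() mirrors the diff's "prevents memory spike on device 0" comment: presumably, when all shards are left unrealized, the source buffers accumulate on the default device before any transfer happens, whereas realizing inside the loop materializes and distributes each parameter immediately.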