diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index deaa4271f8..e49f3071ee 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -1314,7 +1314,9 @@ def train_llama3():
 
   # TODO: confirm weights are in bf16
   # vocab_size from the mixtral tokenizer
-  model = Transformer(**(MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]|{"vocab_size": 32000}), max_context=SEQLEN, jit=False, disable_kv_cache=True)
+  params = MODEL_PARAMS[getenv("LLAMA3_SIZE", "8B")]["args"]|{"vocab_size": 32000}
+  if (llama_layers:=getenv("LLAMA_LAYERS")) != 0: params['n_layers'] = llama_layers
+  model = Transformer(**params, max_context=SEQLEN, jit=False, disable_kv_cache=True)
 
   optim = AdamW(get_parameters(model), lr=0.0, b1=opt_adamw_beta_1, b2=opt_adamw_beta_2, eps=opt_adamw_epsilon, weight_decay=opt_adamw_weight_decay)
 
diff --git a/extra/models/llama.py b/extra/models/llama.py
index 09907b964e..b942115816 100644
--- a/extra/models/llama.py
+++ b/extra/models/llama.py
@@ -99,7 +99,9 @@ class FeedForward:
     self.w3 = linear(dim, hidden_dim, bias=False) # the gate in Gated Linear Unit
 
   def __call__(self, x:Tensor) -> Tensor:
-    return self.w2(self.w1(x).silu() * self.w3(x)) # SwiGLU [arxiv/2002.05202, eq (5)]
+    w1 = self.w1(x).silu()
+    w3 = self.w3(x.contiguous_backward()) # this fixes a strange fusion that makes tensor cores miss
+    return self.w2(w1 * w3)
 
 class TransformerBlock:
   def __init__(self, dim:int, hidden_dim:int, n_heads:int, n_kv_heads:int, norm_eps:float, max_context:int, linear=nn.Linear,
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index fd343d0e70..3c233092fa 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -156,7 +156,7 @@ class ExecItem:
       lds_est = sym_infer(self.prg.estimates.lds, var_vals)
       mem_est = min(mem_est, lds_est)  # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
       ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
-      print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(41-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " +  # noqa: E501
+      print(f"{colored(f'*** {self.prg.device[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(44-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " +  # noqa: E501
             (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" +  # noqa: E501
              f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
     self.prg.first_run = False
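
Note on the model_train.py hunk: tinygrad's getenv returns the integer 0 when a variable is unset, so the walrus check only rewrites n_layers when LLAMA_LAYERS is explicitly set, which lets you shrink the model for quick debug runs. A minimal sketch of the same pattern (the params dict below is a stand-in for illustration, not the real MODEL_PARAMS entry):

from tinygrad.helpers import getenv

# stand-in args; the real values come from MODEL_PARAMS[...]["args"]
params = {"dim": 4096, "n_layers": 32} | {"vocab_size": 32000}
if (llama_layers := getenv("LLAMA_LAYERS")) != 0:
  params["n_layers"] = llama_layers  # e.g. LLAMA_LAYERS=4 for a smoke test
print(params["n_layers"])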
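Note on the llama.py hunk: Tensor.contiguous_backward is an identity in the forward pass and only forces the incoming gradient to be contiguous in the backward pass, so the SwiGLU math is unchanged while the w3 branch's backward is kept out of the fusion that was missing the tensor cores. A standalone sketch (my example, not a repo test; shapes are arbitrary) exercising both properties:

from tinygrad import Tensor

x = Tensor.randn(4, 16, requires_grad=True)
fused = (x.silu() * x).sum()                        # original fused form
split = (x.silu() * x.contiguous_backward()).sum()  # form used in the diff
assert abs(fused.item() - split.item()) < 1e-4      # forward values match
split.backward()                                    # gradient flows through the contiguous copy
print(x.grad.shape)                                 # (4, 16)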