diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py index 274703eb61..dc709344e7 100644 --- a/tinygrad/apps/llm.py +++ b/tinygrad/apps/llm.py @@ -340,6 +340,10 @@ if __name__ == "__main__": # do benchmark if args.benchmark: param_bytes = sum(x.nbytes() for x in nn.state.get_parameters(model)) + for b in model.blk: + if hasattr(b, 'ffn_gate_exps'): + expert_bytes = b.ffn_gate_exps.weight.nbytes() + b.ffn_up_exps.weight.nbytes() + b.ffn_down_exps.weight.nbytes() + param_bytes -= int(expert_bytes * (1 - b.num_experts_per_tok / b.ffn_gate_exps.weight.shape[0])) gen = model.generate([0], 0) for _ in range(args.benchmark): GlobalCounters.reset()