From 59ab3675a3cfedcd0d443d6290653ff8c16b8d77 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sun, 10 Dec 2023 19:04:58 -0800
Subject: [PATCH] faster mixtral + green for new kernels (#2701)

* green for new kernels

* track ram
---
 examples/mixtral.py | 10 +++++-----
 tinygrad/device.py  | 15 ++++++++-------
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/examples/mixtral.py b/examples/mixtral.py
index 8593cf4a4e..2dd226aa33 100644
--- a/examples/mixtral.py
+++ b/examples/mixtral.py
@@ -1,6 +1,6 @@
 import functools, argparse, os
 from tqdm import tqdm
-from tinygrad import Tensor, nn, Device, GlobalCounters
+from tinygrad import Tensor, nn, Device, GlobalCounters, Variable
 from tinygrad.helpers import Timing
 from tinygrad.nn.state import torch_load, get_state_dict
 from extra.models.llama import FeedForward, Transformer
@@ -16,8 +16,8 @@ class MixtureFeedForward:
     top = sorted(enumerate(choice), key=lambda x: -x[1])
     norm = top[0][1] + top[1][1]
     e1, e2 = self.experts[top[0][0]], self.experts[top[1][0]]
-    ret = e1(x.to(e1.w1.weight.device)).to(x.device) * (top[0][1]/norm) + \
-          e2(x.to(e2.w1.weight.device)).to(x.device) * (top[1][1]/norm)
+    ret = e1(x.to(e1.w1.weight.device)).to(x.device) * Tensor([top[0][1]/norm]) + \
+          e2(x.to(e2.w1.weight.device)).to(x.device) * Tensor([top[1][1]/norm])
     return ret
 
 if __name__ == "__main__":
@@ -33,12 +33,12 @@ if __name__ == "__main__":
   model_state_dict = get_state_dict(model)
 
   for k in (t := tqdm(state)):
-    t.set_description(f"loading {k}")
     if 'feed_forward.experts.' in k:
       expert_no = int(k.split('feed_forward.experts.')[1].split('.')[0])
       device = Device.DEFAULT + ":" + str((expert_no//2)+1)
     else:
       device = Device.DEFAULT
+    t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, loading {k} to {device}")
     # NOTE: we have to copy through CLANG to avoid the HIP hang bug when copying directly from the DISK
     model_state_dict[k].assign(state[k].to("CLANG").contiguous().to(device).half()).realize()
 
@@ -50,7 +50,7 @@ if __name__ == "__main__":
   for i in range(args.count):
     GlobalCounters.reset()
     with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/sec"):
-      tok = model(Tensor([toks[start_pos:]]), start_pos, args.temperature).multinomial().item()
+      tok = model(Tensor([toks[start_pos:]]), 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos), args.temperature).multinomial().item()
     toks.append(tok)
     start_pos += 1
   print(spp.decode(toks))
diff --git a/tinygrad/device.py b/tinygrad/device.py
index 56a7de6848..c4078b740e 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -48,7 +48,7 @@ class JITRunner:
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
     raise NotImplementedError("override this")
 
-def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str=""):
+def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str="", first_run=False):
   if var_vals is None: var_vals = {}
   op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
   GlobalCounters.kernel_count += num_kernels
@@ -56,7 +56,7 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option
   GlobalCounters.global_mem += mem_estimate
   if et is not None: GlobalCounters.time_sum_s += et
   if DEBUG >= 2:
-    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} dev {device[:10]:10s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else ('green' if first_run else None))} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} dev {device[:10]:10s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
           (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
 
 # **************** Buffer / Allocator ****************
@@ -69,9 +69,9 @@ class Buffer:
     # TODO: image hack shouldn't be here. where should it be?
     self._buf = opaque if opaque is not None else self.allocator.alloc(dtype if isinstance(dtype, ImageDType) else size * dtype.itemsize)
     # TODO: mem_used for all devices
-    if self.device == Device.DEFAULT: GlobalCounters.mem_used += self.size * self.dtype.itemsize
+    if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.size * self.dtype.itemsize
   def __del__(self):
-    if self.device == Device.DEFAULT: GlobalCounters.mem_used -= self.size * self.dtype.itemsize
+    if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.size * self.dtype.itemsize
     if isinstance(self.dtype, ImageDType): self.allocator.free(self._buf, self.dtype)
     else: self.allocator.free(self._buf, self.size * self.dtype.itemsize)
   def __repr__(self): return f"<buf device:{self.device} size:{self.size} dtype:{self.dtype}>"
@@ -236,8 +236,8 @@ class CompiledASTRunner(JITRunner):
     if DEBUG >= 4: print(prg)
     if global_size is not None: global_size = global_size + [1]*(3-len(global_size))
     if local_size is not None: local_size = local_size + [1]*(3-len(local_size))
-    self.name, self.display_name, self.prg, self.global_size, self.local_size, self.runtime_args = \
-      to_function_name(name), name, prg, global_size, local_size, runtime_args if runtime_args is not None else {}
+    self.name, self.display_name, self.prg, self.global_size, self.local_size, self.runtime_args, self.first_run = \
+      to_function_name(name), name, prg, global_size, local_size, runtime_args if runtime_args is not None else {}, True
     self.vars: List[Variable] = []
     if ast:
       info = get_lazyop_info(ast)
@@ -266,7 +266,8 @@ class CompiledASTRunner(JITRunner):
     if global_size: lra['global_size'] = global_size
     if local_size and 'local_size' not in lra: lra['local_size'] = local_size
     et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
-    update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra, device=rawbufs[0].device)
+    update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra, device=rawbufs[0].device, first_run=self.first_run)
+    self.first_run = False
     return et
 
 class Compiled:
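
Why the mixtral side is faster, sketched standalone below. This is a minimal illustration assuming only the public tinygrad API the diff itself uses (Tensor, Variable); the tensors and values are toy stand-ins, not the actual mixtral weights or model call.

from tinygrad import Tensor, Variable

# 1) Gate weights as Tensors: multiplying by a Python float bakes the value
#    into the generated kernel as a compile-time constant, so every distinct
#    gate weight would trigger a fresh kernel compile. A one-element Tensor
#    is passed to the kernel as a buffer instead, so one compiled kernel
#    serves every gate value.
x = Tensor.ones(4)             # toy stand-in for an expert's output
gate = 0.7                     # toy stand-in for top[0][1]/norm
slow = x * gate                # constant folded in: one kernel per gate value
fast = x * Tensor([gate])      # value lives in a buffer: kernel is reusable
print(slow.numpy(), fast.numpy())

# 2) Symbolic start_pos: binding the position to a Variable with range 1..1024
#    keeps the kernel shapes symbolic, so the JIT compiles once for the whole
#    range instead of once per token. Position 0 (prompt processing) still
#    passes a plain int, matching the diff's `0 if start_pos == 0 else ...`.
start_pos = 5                  # toy value
sym = 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos)
print(sym)

The device.py side is accounting plus display: GlobalCounters.mem_used now counts buffers on every device except DISK (previously only the default device), so the loader's "ram used" readout reflects total device memory across the expert GPUs, and under DEBUG>=2 a kernel prints green the first time it runs, making newly compiled kernels easy to spot.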