faster mixtral + green for new kernels (#2701)

* green for new kernels

* track ram
George Hotz
2023-12-10 19:04:58 -08:00
committed by GitHub
parent 2ee6f689c5
commit 59ab3675a3
2 changed files with 13 additions and 12 deletions

View File

@@ -1,6 +1,6 @@
 import functools, argparse, os
 from tqdm import tqdm
-from tinygrad import Tensor, nn, Device, GlobalCounters
+from tinygrad import Tensor, nn, Device, GlobalCounters, Variable
 from tinygrad.helpers import Timing
 from tinygrad.nn.state import torch_load, get_state_dict
 from extra.models.llama import FeedForward, Transformer
@@ -16,8 +16,8 @@ class MixtureFeedForward:
     top = sorted(enumerate(choice), key=lambda x: -x[1])
     norm = top[0][1] + top[1][1]
     e1, e2 = self.experts[top[0][0]], self.experts[top[1][0]]
-    ret = e1(x.to(e1.w1.weight.device)).to(x.device) * (top[0][1]/norm) + \
-          e2(x.to(e2.w1.weight.device)).to(x.device) * (top[1][1]/norm)
+    ret = e1(x.to(e1.w1.weight.device)).to(x.device) * Tensor([top[0][1]/norm]) + \
+          e2(x.to(e2.w1.weight.device)).to(x.device) * Tensor([top[1][1]/norm])
     return ret
 
 if __name__ == "__main__":
@@ -33,12 +33,12 @@ if __name__ == "__main__":
   model_state_dict = get_state_dict(model)
   for k in (t := tqdm(state)):
-    t.set_description(f"loading {k}")
     if 'feed_forward.experts.' in k:
       expert_no = int(k.split('feed_forward.experts.')[1].split('.')[0])
       device = Device.DEFAULT + ":" + str((expert_no//2)+1)
     else:
       device = Device.DEFAULT
+    t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, loading {k} to {device}")
     # NOTE: we have to copy through CLANG to avoid the HIP hang bug when copying directly from the DISK
     model_state_dict[k].assign(state[k].to("CLANG").contiguous().to(device).half()).realize()
@@ -50,7 +50,7 @@ if __name__ == "__main__":
   for i in range(args.count):
     GlobalCounters.reset()
     with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/sec"):
-      tok = model(Tensor([toks[start_pos:]]), start_pos, args.temperature).multinomial().item()
+      tok = model(Tensor([toks[start_pos:]]), 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos), args.temperature).multinomial().item()
     toks.append(tok)
     start_pos += 1
     print(spp.decode(toks))
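
Both Mixtral changes above serve the speedup in the title: the expert outputs are now scaled by a one-element Tensor instead of a Python float, and start_pos is passed to the model as a bound Variable rather than a plain int. The likely intent in both cases is to keep the generated kernels independent of per-token values, so they can be reused instead of recompiled every step. A minimal sketch of the two patterns (the "start_pos" name and the 1..1024 range mirror the diff; the tensors and values are illustrative):

```python
from tinygrad import Tensor, Variable

x = Tensor.rand(1, 1, 4096)   # illustrative activation
gate = 0.73                   # illustrative per-token gate weight

# multiplying by a 1-element Tensor feeds the weight in as a buffer, so one kernel
# serves every token; multiplying by the raw float would fold 0.73 in as a constant
y = x * Tensor([gate])

# the first token uses the literal 0; later tokens bind a symbolic start_pos in [1, 1024],
# so the kernels see the same symbol each step rather than a fresh integer
start_pos = 7
sym_start = 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos)
```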

View File

@@ -48,7 +48,7 @@ class JITRunner:
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
     raise NotImplementedError("override this")
 
-def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str=""):
+def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str="", first_run=False):
   if var_vals is None: var_vals = {}
   op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
   GlobalCounters.kernel_count += num_kernels
@@ -56,7 +56,7 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option
   GlobalCounters.global_mem += mem_estimate
   if et is not None: GlobalCounters.time_sum_s += et
   if DEBUG >= 2:
-    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} dev {device[:10]:10s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else ('green' if first_run else None))} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} dev {device[:10]:10s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
           (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
 
 # **************** Buffer / Allocator ****************
@@ -69,9 +69,9 @@ class Buffer:
     # TODO: image hack shouldn't be here. where should it be?
     self._buf = opaque if opaque is not None else self.allocator.alloc(dtype if isinstance(dtype, ImageDType) else size * dtype.itemsize)
     # TODO: mem_used for all devices
-    if self.device == Device.DEFAULT: GlobalCounters.mem_used += self.size * self.dtype.itemsize
+    if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.size * self.dtype.itemsize
   def __del__(self):
-    if self.device == Device.DEFAULT: GlobalCounters.mem_used -= self.size * self.dtype.itemsize
+    if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.size * self.dtype.itemsize
     if isinstance(self.dtype, ImageDType): self.allocator.free(self._buf, self.dtype)
     else: self.allocator.free(self._buf, self.size * self.dtype.itemsize)
   def __repr__(self): return f"<buf device:{self.device} size:{self.size} dtype:{self.dtype}>"
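
This is the "track ram" half of the commit: GlobalCounters.mem_used previously only counted buffers on Device.DEFAULT, so the expert shards the loader above places on DEFAULT:1, DEFAULT:2, ... never showed up in the new "ram used" readout. Counting every non-DISK buffer fixes that. A small sketch of what the counter now reflects (the second device line is hypothetical and assumes such a device exists):

```python
from tinygrad import Tensor, Device, GlobalCounters

a = Tensor.ones(1024, 1024).realize()   # counted: allocated on Device.DEFAULT
# b = Tensor.ones(1024, 1024).to(Device.DEFAULT + ":1").realize()   # also counted now, if a second device exists
# buffers on DISK devices (e.g. weights still sitting on disk) remain excluded
print(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
```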
@@ -236,8 +236,8 @@ class CompiledASTRunner(JITRunner):
     if DEBUG >= 4: print(prg)
     if global_size is not None: global_size = global_size + [1]*(3-len(global_size))
     if local_size is not None: local_size = local_size + [1]*(3-len(local_size))
-    self.name, self.display_name, self.prg, self.global_size, self.local_size, self.runtime_args = \
-      to_function_name(name), name, prg, global_size, local_size, runtime_args if runtime_args is not None else {}
+    self.name, self.display_name, self.prg, self.global_size, self.local_size, self.runtime_args, self.first_run = \
+      to_function_name(name), name, prg, global_size, local_size, runtime_args if runtime_args is not None else {}, True
     self.vars: List[Variable] = []
     if ast:
       info = get_lazyop_info(ast)
@@ -266,7 +266,8 @@ class CompiledASTRunner(JITRunner):
     if global_size: lra['global_size'] = global_size
     if local_size and 'local_size' not in lra: lra['local_size'] = local_size
     et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
-    update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra, device=rawbufs[0].device)
+    update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra, device=rawbufs[0].device, first_run=self.first_run)
+    self.first_run = False
     return et
 
 class Compiled:
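
The first_run plumbing ties the coloring together: each CompiledASTRunner starts with first_run=True, forwards it to update_stats on its first call, then clears it, so under DEBUG=2 a kernel's very first execution prints green while JIT replays keep magenta/cyan and repeat runs stay uncolored. A sketch of just the color precedence, using colored from tinygrad.helpers (which the print above already relies on):

```python
from tinygrad.helpers import colored

def kernel_color(jit: bool, num_kernels: int, first_run: bool):
  # same precedence as the DEBUG>=2 print: JIT coloring wins, then first-run green, else no color
  return ('magenta' if num_kernels == 1 else 'CYAN') if jit else ('green' if first_run else None)

print(colored("***   42", kernel_color(jit=False, num_kernels=1, first_run=True)))   # green: brand-new kernel
print(colored("***   43", kernel_color(jit=True,  num_kernels=1, first_run=False)))  # magenta: JIT replay
```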