Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 23:48:01 -05:00)
faster mixtral + green for new kernels (#2701)
* green for new kernels
* track ram
@@ -1,6 +1,6 @@
 import functools, argparse, os
 from tqdm import tqdm
-from tinygrad import Tensor, nn, Device, GlobalCounters
+from tinygrad import Tensor, nn, Device, GlobalCounters, Variable
 from tinygrad.helpers import Timing
 from tinygrad.nn.state import torch_load, get_state_dict
 from extra.models.llama import FeedForward, Transformer
@@ -16,8 +16,8 @@ class MixtureFeedForward:
     top = sorted(enumerate(choice), key=lambda x: -x[1])
     norm = top[0][1] + top[1][1]
     e1, e2 = self.experts[top[0][0]], self.experts[top[1][0]]
-    ret = e1(x.to(e1.w1.weight.device)).to(x.device) * (top[0][1]/norm) + \
-          e2(x.to(e2.w1.weight.device)).to(x.device) * (top[1][1]/norm)
+    ret = e1(x.to(e1.w1.weight.device)).to(x.device) * Tensor([top[0][1]/norm]) + \
+          e2(x.to(e2.w1.weight.device)).to(x.device) * Tensor([top[1][1]/norm])
     return ret
 
 if __name__ == "__main__":
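
Note: these first hunks appear to be the Mixtral example (examples/mixtral.py). Scaling the expert outputs by Tensor([top[0][1]/norm]) instead of a bare Python float looks like the speed motivation: a Python scalar is folded into the generated kernel as a compile-time constant, so each token's fresh gate weights would force a brand-new kernel, while a 1-element Tensor is passed in as a runtime buffer and the already-compiled kernel can be reused. A minimal sketch of the difference, not taken from the diff:

    from tinygrad import Tensor

    x = Tensor.rand(1, 8)
    w = 0.73                     # stands in for a gate weight that changes every token
    y_const  = x * w             # the Python float is baked into the generated kernel as a constant
    y_tensor = x * Tensor([w])   # the weight becomes a buffer input, so the compiled kernel is value-independent
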
@@ -33,12 +33,12 @@ if __name__ == "__main__":
   model_state_dict = get_state_dict(model)
 
   for k in (t := tqdm(state)):
-    t.set_description(f"loading {k}")
     if 'feed_forward.experts.' in k:
       expert_no = int(k.split('feed_forward.experts.')[1].split('.')[0])
       device = Device.DEFAULT + ":" + str((expert_no//2)+1)
     else:
       device = Device.DEFAULT
+    t.set_description(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, loading {k} to {device}")
     # NOTE: we have to copy through CLANG to avoid the HIP hang bug when copying directly from the DISK
     model_state_dict[k].assign(state[k].to("CLANG").contiguous().to(device).half()).realize()
 
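
Note: the progress-bar description moves below the device selection so it can name the target device, and it now also reports GlobalCounters.mem_used (the "track ram" part), which the Buffer change further down makes meaningful across devices. A standalone sketch of reading that counter, illustrative only:

    from tinygrad import Tensor, GlobalCounters

    a = Tensor.rand(1024, 1024).realize()   # forces a real allocation
    print(f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB")
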
@@ -50,7 +50,7 @@ if __name__ == "__main__":
   for i in range(args.count):
     GlobalCounters.reset()
     with Timing("total ", enabled=args.timing, on_exit=lambda x: f", {1e9/x:.2f} tok/sec"):
-      tok = model(Tensor([toks[start_pos:]]), start_pos, args.temperature).multinomial().item()
+      tok = model(Tensor([toks[start_pos:]]), 0 if start_pos == 0 else Variable("start_pos", 1, 1024).bind(start_pos), args.temperature).multinomial().item()
     toks.append(tok)
     start_pos += 1
     print(spp.decode(toks))
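
Note: passing a bound Variable("start_pos", 1, 1024) instead of a plain integer lets kernels whose shapes depend on the position be built once over the symbolic range rather than once per concrete position; together with the Tensor([...]) change above this is presumably the "faster mixtral" half of the commit. A hedged sketch of the binding pattern, assuming the same tinygrad Variable API the diff imports:

    from tinygrad import Variable

    start_pos_var = Variable("start_pos", 1, 1024)   # symbolic position with its allowed range
    for pos in range(1, 4):
      bound = start_pos_var.bind(pos)                # same symbol, new concrete value each step

The hunks that follow move to the runtime side (JITRunner, update_stats, Buffer, CompiledASTRunner, presumably tinygrad/device.py at this point in the tree).
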
@@ -48,7 +48,7 @@ class JITRunner:
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
     raise NotImplementedError("override this")
 
-def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str=""):
+def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count:int, jit=False, num_kernels=1, lra: Optional[Dict]=None, device:str="", first_run=False):
   if var_vals is None: var_vals = {}
   op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals)
   GlobalCounters.kernel_count += num_kernels
@@ -56,7 +56,7 @@ def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Option
   GlobalCounters.global_mem += mem_estimate
   if et is not None: GlobalCounters.time_sum_s += et
   if DEBUG >= 2:
-    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} dev {device[:10]:10s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+    print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else ('green' if first_run else None))} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} dev {device[:10]:10s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
           (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
 
 # **************** Buffer / Allocator ****************
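
Note: this is the "green for new kernels" part. With DEBUG>=2, a kernel line is now tinted green the first time that program runs, so freshly compiled kernels stand out from reused ones in the log. Roughly what the coloring call does, using the same helper:

    from tinygrad.helpers import colored

    print(colored("*** 1337", "green"), "first run of this kernel")
    print(colored("*** 1338", None), "seen before, printed uncolored")
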
@@ -69,9 +69,9 @@ class Buffer:
     # TODO: image hack shouldn't be here. where should it be?
     self._buf = opaque if opaque is not None else self.allocator.alloc(dtype if isinstance(dtype, ImageDType) else size * dtype.itemsize)
     # TODO: mem_used for all devices
-    if self.device == Device.DEFAULT: GlobalCounters.mem_used += self.size * self.dtype.itemsize
+    if not self.device.startswith("DISK"): GlobalCounters.mem_used += self.size * self.dtype.itemsize
   def __del__(self):
-    if self.device == Device.DEFAULT: GlobalCounters.mem_used -= self.size * self.dtype.itemsize
+    if not self.device.startswith("DISK"): GlobalCounters.mem_used -= self.size * self.dtype.itemsize
     if isinstance(self.dtype, ImageDType): self.allocator.free(self._buf, self.dtype)
     else: self.allocator.free(self._buf, self.size * self.dtype.itemsize)
   def __repr__(self): return f"<buf device:{self.device} size:{self.size} dtype:{self.dtype}>"
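
Note: the "track ram" change. Buffer now counts toward GlobalCounters.mem_used on every device except DISK, instead of only the default device, so Mixtral experts sharded to devices like Device.DEFAULT + ":1" through ":4" actually show up in the "ram used" readout above. Illustrative sketch only; it assumes your backend exposes a second device:

    from tinygrad import Tensor, Device, GlobalCounters

    # a buffer landing on a secondary device (e.g. Device.DEFAULT + ":1") now counts toward mem_used too
    t = Tensor.rand(1 << 20).to(Device.DEFAULT + ":1").realize()
    print(f"mem_used now {GlobalCounters.mem_used/1e9:.2f} GB")
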
@@ -236,8 +236,8 @@ class CompiledASTRunner(JITRunner):
     if DEBUG >= 4: print(prg)
     if global_size is not None: global_size = global_size + [1]*(3-len(global_size))
     if local_size is not None: local_size = local_size + [1]*(3-len(local_size))
-    self.name, self.display_name, self.prg, self.global_size, self.local_size, self.runtime_args = \
-      to_function_name(name), name, prg, global_size, local_size, runtime_args if runtime_args is not None else {}
+    self.name, self.display_name, self.prg, self.global_size, self.local_size, self.runtime_args, self.first_run = \
+      to_function_name(name), name, prg, global_size, local_size, runtime_args if runtime_args is not None else {}, True
     self.vars: List[Variable] = []
     if ast:
       info = get_lazyop_info(ast)
@@ -266,7 +266,8 @@ class CompiledASTRunner(JITRunner):
     if global_size: lra['global_size'] = global_size
     if local_size and 'local_size' not in lra: lra['local_size'] = local_size
     et = self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.vars), wait=wait or DEBUG>=2)
-    update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra, device=rawbufs[0].device)
+    update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra, device=rawbufs[0].device, first_run=self.first_run)
+    self.first_run = False
     return et
 
 class Compiled:
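
Note: the flag is plumbed end to end: CompiledASTRunner is constructed with first_run=True, hands it to update_stats (which picks green when set), and clears it right after the first call. The pattern in isolation, as a sketch rather than the real class:

    class RunnerSketch:
      def __init__(self):
        self.first_run = True            # set when the program object is built
      def __call__(self, name):
        print(("green " if self.first_run else "plain ") + name)
        self.first_run = False           # every later call logs as a reused kernel

    r = RunnerSketch()
    r("r_256_16_8")   # -> green r_256_16_8
    r("r_256_16_8")   # -> plain r_256_16_8
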