From 3baaf298d6a1f41e487669cbefbe9f98c2898fdc Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Thu, 16 Nov 2023 12:09:53 -0800
Subject: [PATCH] two stage cumsum in tensor.py (#2331)

* two stage cumsum in tensor.py

* 2 more kernels for llama cumsum

* gpt-2 and llama use fast multinomial
---
 examples/gpt2.py                   |  6 +++---
 examples/llama.py                  |  4 ++--
 test/external/external_test_opt.py |  2 +-
 test/test_ops.py                   |  9 ++++++++-
 test/unit/test_helpers.py          | 11 ++++++++++-
 tinygrad/helpers.py                |  1 +
 tinygrad/tensor.py                 | 18 +++++++++++++++---
 7 files changed, 40 insertions(+), 11 deletions(-)

diff --git a/examples/gpt2.py b/examples/gpt2.py
index fd449a23d1..a37e415d40 100644
--- a/examples/gpt2.py
+++ b/examples/gpt2.py
@@ -89,7 +89,7 @@ class Transformer:
     return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
 
   # TODO: fix empty token
-  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
+  def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
     return (self.forward_jit if tokens.shape[0:2] == (1,1) and getenv("JIT") else self.forward)(tokens, start_pos, temperature)
 
 VOCAB_SIZE = 50257
@@ -134,8 +134,8 @@ class GPT2:
                   f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
                   (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
         probs = self.model(Tensor([toks[start_pos:]]), Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature)
-      probs_np = probs.numpy()
-      tok = int(np.random.choice(len(probs_np), p=probs_np))
+      # TODO: fix JIT rand so we can put this in the JIT
+      tok = probs.multinomial().item()
       start_pos = len(toks)
       toks.append(tok)
       output = self.tokenizer.decode(toks)
diff --git a/examples/llama.py b/examples/llama.py
index 3890da5a2f..210f1ae558 100755
--- a/examples/llama.py
+++ b/examples/llama.py
@@ -542,8 +542,8 @@ After you are done speaking, output [EOS]. You are not Chad.
                   f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
                   (f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_count*1e-9*2/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
       probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
-    probs_np = probs.numpy()
-    tok = int(np.random.choice(len(probs_np), p=probs_np))
+    # TODO: fix JIT rand so we can put this in the JIT
+    tok = probs.multinomial().item()
 
     # use the kv cache
     start_pos = len(toks)
diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py
index e66b11cde9..96cd9987f9 100644
--- a/test/external/external_test_opt.py
+++ b/test/external/external_test_opt.py
@@ -89,7 +89,7 @@ class TestInferenceMinKernels(unittest.TestCase):
     args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
     model = Transformer(**args_tiny)
     for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
-    with CLCache(98):
+    with CLCache(100):
       model(Tensor([[1,2,3,4]]), 0).realize()
 
   @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
diff --git a/test/test_ops.py b/test/test_ops.py
index fda1dea177..b27e93eb0d 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -415,8 +415,15 @@ class TestOps(unittest.TestCase):
     with self.assertRaises(AssertionError):
       a = Tensor(3.14)
       a.matmul(a)
+
+  def test_multinomial(self):
+    # NOTE: this is random, so it has a very large atol
+    helper_test_op([(1000,)], lambda x: torch.multinomial(x.clip(0,1), num_samples=1), lambda x: Tensor.multinomial(x.clip(0,1)), forward_only=True, atol=1000.)
+
+  def test_small_cumsum(self):
+    helper_test_op([(10)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)
   def test_simple_cumsum(self):
-    helper_test_op([(1024)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)
+    helper_test_op([(1022)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)
   def test_cumsum(self):
     helper_test_op([(20)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)
     helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)
diff --git a/test/unit/test_helpers.py b/test/unit/test_helpers.py
index 60d00c769f..4d4fd5155d 100644
--- a/test/unit/test_helpers.py
+++ b/test/unit/test_helpers.py
@@ -1,6 +1,6 @@
 import unittest
 import numpy as np
-from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod
+from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod, round_up
 from tinygrad.shape.symbolic import Variable, NumNode
 
 VARIABLE = ContextVar("VARIABLE", 0)
@@ -138,5 +138,14 @@ class TestProd(unittest.TestCase):
   def test_variable_order(self): self.assertEqual("(a*12)", prod((3, 4, Variable("a", 1, 5))).render())
   def test_num_nodes(self): self.assertEqual(NumNode(6), prod((NumNode(2), NumNode(3))))
 
+class TestRoundUp(unittest.TestCase):
+  def test_round_up(self):
+    self.assertEqual(round_up(-3,4), 0)
+    self.assertEqual(round_up(-4,4), -4)
+    self.assertEqual(round_up(6,4), 8)
+    self.assertEqual(round_up(8,4), 8)
+    self.assertEqual(round_up(232, 24984), 24984)
+    self.assertEqual(round_up(24984, 232), 25056)
+
 if __name__ == '__main__':
   unittest.main()
\ No newline at end of file
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 92d61fe083..23f409e2c8 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -25,6 +25,7 @@ def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (
 def flatten(l:Union[List, Iterator]): return [item for sublist in l for item in sublist]
 def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
 def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
+def round_up(num, amt): return num if num%amt == 0 else num+(amt-(num%amt))
 def merge_dicts(ds:Iterable[Dict]) -> Dict:
   assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
   return {k:v for d in ds for k,v in d.items()}
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 144db894dc..0dc1104496 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -7,7 +7,7 @@ from itertools import accumulate
 import numpy as np
 
 from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set
-from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int
+from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int, round_up
 from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import Device, LoadOps
 from tinygrad.shape.symbolic import sint
@@ -124,6 +124,7 @@ class Tensor:
     assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
     assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
     return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().reshape(self.shape)
+  def item(self) -> Union[float, int]: return self.numpy().item()
 
   def to(self, device:str) -> Tensor:
     ret = Tensor(self.lazydata, device)
@@ -210,7 +211,7 @@ class Tensor:
     std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
     return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
 
-  def multinomial(self: Tensor, num_samples: int, replacement: bool = False) -> Tensor:
+  def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
     assert self.ndim <= 2, "p must be 1 or 2 dim"
     assert replacement or num_samples == 1, "supported only with replacement"
     p = self.unsqueeze(0) if self.ndim == 1 else self
@@ -570,7 +571,18 @@ class Tensor:
     w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
     return (x*w).sum(-1)
 
-  def cumsum(self, axis:int=0) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-1,0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
+  def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
+  def cumsum(self, axis:int=0) -> Tensor:
+    # TODO: someday the optimizer will find this on it's own
+    # for now this is a two stage cumsum
+    SPLIT = 256
+    if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
+    ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
+    ret = ret.reshape(*ret.shape[0:-1], ret.shape[-1]//SPLIT, SPLIT)._cumsum(-1)
+    base_add = ret[..., -1]._cumsum(-1, _first_zero=True)[..., :-1]
+    base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
+    def fix(x:Tensor): return x.reshape(*ret.shape[0:-2], ret.shape[-2] * ret.shape[-1])[..., -self.shape[axis]:].transpose(axis,-1)
+    return fix(ret) + fix(base_add)
 
   # ***** mlops (unary) *****
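
The core change is the two-stage cumsum in tinygrad/tensor.py. The old single-call formulation (pad2d + _pool + sum) does work that grows roughly quadratically with the scan length, so for axes longer than SPLIT*2 the new code pads the axis up to a multiple of SPLIT=256, scans each 256-wide block independently, then does an exclusive scan of the block totals (the _first_zero=True path) and broadcasts them back as per-block base offsets. Below is a rough NumPy sketch of that decomposition, 1-D only, written for clarity rather than as the actual tinygrad kernels.

# Sketch of the two-stage cumsum decomposition used in the patch,
# in plain NumPy (1-D only); not the tinygrad Tensor implementation.
import numpy as np

def two_stage_cumsum(x: np.ndarray, split: int = 256) -> np.ndarray:
  n = x.shape[-1]
  if n <= 2 * split: return np.cumsum(x, axis=-1)        # small axis: single stage, like the patch
  pad = -n % split                                       # front-pad to a multiple of split
  xp = np.concatenate([np.zeros(pad, dtype=x.dtype), x])
  blocks = xp.reshape(-1, split)
  stage1 = np.cumsum(blocks, axis=-1)                    # stage 1: scan inside each block
  totals = stage1[:, -1]                                 # per-block sums
  base = np.concatenate([[0], np.cumsum(totals)[:-1]])   # stage 2: exclusive scan of block sums
  out = stage1 + base[:, None]                           # add running base to each block
  return out.reshape(-1)[-n:]                            # drop the front padding

x = np.arange(1, 1001, dtype=np.float64)
assert np.allclose(two_stage_cumsum(x), np.cumsum(x))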
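
cumsum relies on the new round_up helper in tinygrad/helpers.py to pad the scan axis to the next multiple of SPLIT. Because it builds on Python's % semantics it rounds toward positive infinity rather than away from zero, which is exactly what the TestRoundUp cases pin down. A standalone copy for reference:

# Same one-liner the patch adds to tinygrad/helpers.py.
def round_up(num, amt): return num if num%amt == 0 else num+(amt-(num%amt))

assert round_up(6, 4) == 8     # next multiple of 4
assert round_up(8, 4) == 8     # already a multiple, unchanged
assert round_up(-3, 4) == 0    # rounds toward +inf, per Python's % on negatives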
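
On the examples side, gpt2.py and llama.py stop copying probabilities to NumPy for np.random.choice and instead call Tensor.multinomial().item(), now possible with a default num_samples=1; the TODO notes it still cannot go inside the JIT because of rand. The patch only shows the first lines of Tensor.multinomial, so as a hedged illustration of the inverse-CDF idea such a sampler can be built from a cumulative sum and a uniform draw, here is a NumPy sketch, not tinygrad's actual implementation:

# Illustrative NumPy sketch of categorical sampling via a cumulative sum
# (inverse-CDF method); NOT tinygrad's Tensor.multinomial.
import numpy as np

def multinomial_sketch(weights: np.ndarray, num_samples: int = 1) -> np.ndarray:
  cdf = np.cumsum(weights / weights.sum())   # cumulative distribution, ends at 1.0
  u = np.random.rand(num_samples, 1)         # uniform draws in [0, 1)
  # each draw falls into the CDF bucket of its sampled index
  return (u >= cdf[None, :]).sum(axis=1)

tok = int(multinomial_sketch(np.array([0.1, 0.2, 0.7]))[0])   # index 2 with probability 0.7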