two stage cumsum in tensor.py (#2331)

* two stage cumsum in tensor.py

* 2 more kernels for llama cumsum

* gpt-2 and llama use fast multinomial
This commit is contained in:
George Hotz
2023-11-16 12:09:53 -08:00
committed by GitHub
parent 163b2bc26a
commit 3baaf298d6
7 changed files with 40 additions and 11 deletions

View File

@@ -89,7 +89,7 @@ class Transformer:
return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()
# TODO: fix empty token
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
return (self.forward_jit if tokens.shape[0:2] == (1,1) and getenv("JIT") else self.forward)(tokens, start_pos, temperature)
VOCAB_SIZE = 50257
@@ -134,8 +134,8 @@ class GPT2:
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
probs = self.model(Tensor([toks[start_pos:]]), Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature)
probs_np = probs.numpy()
tok = int(np.random.choice(len(probs_np), p=probs_np))
# TODO: fix JIT rand so we can put this in the JIT
tok = probs.multinomial().item()
start_pos = len(toks)
toks.append(tok)
output = self.tokenizer.decode(toks)

View File

@@ -542,8 +542,8 @@ After you are done speaking, output [EOS]. You are not Chad.
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_count*1e-9*2/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
probs_np = probs.numpy()
tok = int(np.random.choice(len(probs_np), p=probs_np))
# TODO: fix JIT rand so we can put this in the JIT
tok = probs.multinomial().item()
# use the kv cache
start_pos = len(toks)

View File

@@ -89,7 +89,7 @@ class TestInferenceMinKernels(unittest.TestCase):
args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
model = Transformer(**args_tiny)
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
with CLCache(98):
with CLCache(100):
model(Tensor([[1,2,3,4]]), 0).realize()
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")

View File

@@ -415,8 +415,15 @@ class TestOps(unittest.TestCase):
with self.assertRaises(AssertionError):
a = Tensor(3.14)
a.matmul(a)
def test_multinomial(self):
# NOTE: this is random, so it has a very large atol
helper_test_op([(1000,)], lambda x: torch.multinomial(x.clip(0,1), num_samples=1), lambda x: Tensor.multinomial(x.clip(0,1)), forward_only=True, atol=1000.)
def test_small_cumsum(self):
helper_test_op([(10)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)
def test_simple_cumsum(self):
helper_test_op([(1024)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)
helper_test_op([(1022)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)
def test_cumsum(self):
helper_test_op([(20)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)
helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)

View File

@@ -1,6 +1,6 @@
import unittest
import numpy as np
from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod
from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod, round_up
from tinygrad.shape.symbolic import Variable, NumNode
VARIABLE = ContextVar("VARIABLE", 0)
@@ -138,5 +138,14 @@ class TestProd(unittest.TestCase):
def test_variable_order(self): self.assertEqual("(a*12)", prod((3, 4, Variable("a", 1, 5))).render())
def test_num_nodes(self): self.assertEqual(NumNode(6), prod((NumNode(2), NumNode(3))))
class TestRoundUp(unittest.TestCase):
def test_round_up(self):
self.assertEqual(round_up(-3,4), 0)
self.assertEqual(round_up(-4,4), -4)
self.assertEqual(round_up(6,4), 8)
self.assertEqual(round_up(8,4), 8)
self.assertEqual(round_up(232, 24984), 24984)
self.assertEqual(round_up(24984, 232), 25056)
if __name__ == '__main__':
unittest.main()

View File

@@ -25,6 +25,7 @@ def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (
def flatten(l:Union[List, Iterator]): return [item for sublist in l for item in sublist]
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
def round_up(num, amt): return num if num%amt == 0 else num+(amt-(num%amt))
def merge_dicts(ds:Iterable[Dict]) -> Dict:
assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
return {k:v for d in ds for k,v in d.items()}

View File

@@ -7,7 +7,7 @@ from itertools import accumulate
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int, round_up
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import Device, LoadOps
from tinygrad.shape.symbolic import sint
@@ -124,6 +124,7 @@ class Tensor:
assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().reshape(self.shape)
def item(self) -> Union[float, int]: return self.numpy().item()
def to(self, device:str) -> Tensor:
ret = Tensor(self.lazydata, device)
@@ -210,7 +211,7 @@ class Tensor:
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)
def multinomial(self: Tensor, num_samples: int, replacement: bool = False) -> Tensor:
def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
assert self.ndim <= 2, "p must be 1 or 2 dim"
assert replacement or num_samples == 1, "supported only with replacement"
p = self.unsqueeze(0) if self.ndim == 1 else self
@@ -570,7 +571,18 @@ class Tensor:
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
return (x*w).sum(-1)
def cumsum(self, axis:int=0) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-1,0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
def cumsum(self, axis:int=0) -> Tensor:
# TODO: someday the optimizer will find this on it's own
# for now this is a two stage cumsum
SPLIT = 256
if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
ret = ret.reshape(*ret.shape[0:-1], ret.shape[-1]//SPLIT, SPLIT)._cumsum(-1)
base_add = ret[..., -1]._cumsum(-1, _first_zero=True)[..., :-1]
base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
def fix(x:Tensor): return x.reshape(*ret.shape[0:-2], ret.shape[-2] * ret.shape[-1])[..., -self.shape[axis]:].transpose(axis,-1)
return fix(ret) + fix(base_add)
# ***** mlops (unary) *****