mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)

two stage cumsum in tensor.py (#2331)

* two stage cumsum in tensor.py
* 2 more kernels for llama cumsum
* gpt-2 and llama use fast multinomial
@@ -89,7 +89,7 @@ class Transformer:
return (logits[:, -1, :] / (temperature+1e-10)).softmax().flatten().realize()

# TODO: fix empty token
- def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0):
+ def __call__(self, tokens:Tensor, start_pos:Variable, temperature:float=0.0) -> Tensor:
return (self.forward_jit if tokens.shape[0:2] == (1,1) and getenv("JIT") else self.forward)(tokens, start_pos, temperature)

VOCAB_SIZE = 50257
@@ -134,8 +134,8 @@ class GPT2:
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=timing):
probs = self.model(Tensor([toks[start_pos:]]), Variable("start_pos", 1 if start_pos else 0, MAX_CONTEXT).bind(start_pos), temperature)
- probs_np = probs.numpy()
- tok = int(np.random.choice(len(probs_np), p=probs_np))
+ # TODO: fix JIT rand so we can put this in the JIT
+ tok = probs.multinomial().item()
start_pos = len(toks)
toks.append(tok)
output = self.tokenizer.decode(toks)

@@ -542,8 +542,8 @@ After you are done speaking, output [EOS]. You are not Chad.
f", {GlobalCounters.global_ops*1e-9:.2f} GOPS, {GlobalCounters.global_mem*1e-9:.2f} GB"+
(f", {GlobalCounters.global_mem*1e-9/(GlobalCounters.time_sum_s-st):.2f} GB/s, param {param_count*1e-9*2/(GlobalCounters.time_sum_s-st):.2f} GB/s" if DEBUG>=2 else "")) if DEBUG else None, enabled=args.timing):
probs = llama.model(Tensor([toks[start_pos:]]), start_pos, args.temperature).realize()
- probs_np = probs.numpy()
- tok = int(np.random.choice(len(probs_np), p=probs_np))
+ # TODO: fix JIT rand so we can put this in the JIT
+ tok = probs.multinomial().item()

# use the kv cache
start_pos = len(toks)
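
Note on the sampling change above: both gpt-2 and llama previously copied the full probability vector to the host with .numpy() and sampled with np.random.choice; they now sample on device with Tensor.multinomial and only transfer the chosen index via .item(). As a rough illustration (not tinygrad code), the same draw can be phrased as a cumulative sum plus a threshold search, which is what makes it expressible as tensor ops. The helper below is a hypothetical NumPy sketch:

```python
import numpy as np

# Hypothetical sketch: inverse-CDF sampling, equivalent in distribution to
# np.random.choice(len(probs), p=probs), but written as cumsum + search so the
# same idea maps onto on-device tensor ops.
def sample_multinomial(probs: np.ndarray, rng: np.random.Generator) -> int:
    cdf = np.cumsum(probs)
    cdf /= cdf[-1]                    # normalize in case probs don't sum to exactly 1
    u = rng.random()                  # one uniform sample in [0, 1)
    return int(np.searchsorted(cdf, u, side="right"))

rng = np.random.default_rng(0)
probs = np.array([0.1, 0.2, 0.3, 0.4])
draws = np.bincount([sample_multinomial(probs, rng) for _ in range(10_000)], minlength=4)
print(draws / draws.sum())            # roughly [0.1, 0.2, 0.3, 0.4]
```

The practical win in the diff is that only a single scalar crosses the device boundary per token instead of a vocab-sized distribution.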

test/external/external_test_opt.py
@@ -89,7 +89,7 @@ class TestInferenceMinKernels(unittest.TestCase):
args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
model = Transformer(**args_tiny)
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
- with CLCache(98):
+ with CLCache(100):
model(Tensor([[1,2,3,4]]), 0).realize()

@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")

@@ -415,8 +415,15 @@ class TestOps(unittest.TestCase):
with self.assertRaises(AssertionError):
a = Tensor(3.14)
a.matmul(a)

def test_multinomial(self):
# NOTE: this is random, so it has a very large atol
helper_test_op([(1000,)], lambda x: torch.multinomial(x.clip(0,1), num_samples=1), lambda x: Tensor.multinomial(x.clip(0,1)), forward_only=True, atol=1000.)

def test_small_cumsum(self):
helper_test_op([(10)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)
def test_simple_cumsum(self):
helper_test_op([(1024)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)
helper_test_op([(1022)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)
def test_cumsum(self):
helper_test_op([(20)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)
helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6)

@@ -1,6 +1,6 @@
import unittest
import numpy as np
- from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod
+ from tinygrad.helpers import Context, ContextVar, DType, dtypes, merge_dicts, strip_parens, prod, round_up
from tinygrad.shape.symbolic import Variable, NumNode

VARIABLE = ContextVar("VARIABLE", 0)

@@ -138,5 +138,14 @@ class TestProd(unittest.TestCase):
def test_variable_order(self): self.assertEqual("(a*12)", prod((3, 4, Variable("a", 1, 5))).render())
def test_num_nodes(self): self.assertEqual(NumNode(6), prod((NumNode(2), NumNode(3))))

+ class TestRoundUp(unittest.TestCase):
+ def test_round_up(self):
+ self.assertEqual(round_up(-3,4), 0)
+ self.assertEqual(round_up(-4,4), -4)
+ self.assertEqual(round_up(6,4), 8)
+ self.assertEqual(round_up(8,4), 8)
+ self.assertEqual(round_up(232, 24984), 24984)
+ self.assertEqual(round_up(24984, 232), 25056)

if __name__ == '__main__':
unittest.main()

@@ -25,6 +25,7 @@ def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (
def flatten(l:Union[List, Iterator]): return [item for sublist in l for item in sublist]
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
+ def round_up(num, amt): return num if num%amt == 0 else num+(amt-(num%amt))
def merge_dicts(ds:Iterable[Dict]) -> Dict:
assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
return {k:v for d in ds for k,v in d.items()}
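
The new round_up helper rounds num up to the next multiple of amt, rounding toward positive infinity for negative inputs (see the TestRoundUp cases above); the two-stage cumsum below uses it to pad the scan axis to a multiple of SPLIT. A quick standalone check of the one-liner:

```python
def round_up(num, amt): return num if num % amt == 0 else num + (amt - (num % amt))

# 24984 = 107*232 + 160, so the next multiple of 232 is 24984 + (232 - 160) = 25056
assert round_up(24984, 232) == 25056
assert round_up(8, 4) == 8    # already a multiple: unchanged
assert round_up(-3, 4) == 0   # Python's % keeps the divisor's sign, so negatives round toward +inf
```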
@@ -7,7 +7,7 @@ from itertools import accumulate
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set

- from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int
+ from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int, round_up
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import Device, LoadOps
from tinygrad.shape.symbolic import sint

@@ -124,6 +124,7 @@ class Tensor:
assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}"
assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().reshape(self.shape)
+ def item(self) -> Union[float, int]: return self.numpy().item()

def to(self, device:str) -> Tensor:
ret = Tensor(self.lazydata, device)

@@ -210,7 +211,7 @@ class Tensor:
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)

- def multinomial(self: Tensor, num_samples: int, replacement: bool = False) -> Tensor:
+ def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor:
assert self.ndim <= 2, "p must be 1 or 2 dim"
assert replacement or num_samples == 1, "supported only with replacement"
p = self.unsqueeze(0) if self.ndim == 1 else self
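
The multinomial hunk is truncated in this view after the unsqueeze; the visible change is that num_samples now defaults to 1, which is what lets gpt-2 and llama call probs.multinomial() with no arguments. The "fast multinomial" the commit message refers to can be built entirely from cumsum, uniform noise, and a comparison-reduce; the function below is a hypothetical NumPy sketch of that general technique, not tinygrad's actual body:

```python
import numpy as np

def multinomial_from_cumsum(weights, num_samples=1, rng=None):
    # weights: (batch, classes) of non-negative values; returns (batch, num_samples) int indices.
    # Hypothetical illustration of inverse-CDF sampling via cumsum + comparison-sum.
    rng = rng or np.random.default_rng()
    cdf = np.cumsum(weights, axis=1)
    cdf = cdf / cdf[:, -1:]                          # normalize so the last column is 1
    u = rng.random((weights.shape[0], num_samples, 1))
    # each uniform sample's index = number of CDF entries it is >= to
    return (u >= cdf[:, None, :]).sum(axis=2)

print(multinomial_from_cumsum(np.array([[1.0, 2.0, 7.0]]), num_samples=5))
```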

@@ -570,7 +571,18 @@ class Tensor:
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
return (x*w).sum(-1)

- def cumsum(self, axis:int=0) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-1,0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
+ def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1)
+ def cumsum(self, axis:int=0) -> Tensor:
+ # TODO: someday the optimizer will find this on it's own
+ # for now this is a two stage cumsum
+ SPLIT = 256
+ if self.shape[axis] <= SPLIT*2: return self._cumsum(axis)
+ ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0))
+ ret = ret.reshape(*ret.shape[0:-1], ret.shape[-1]//SPLIT, SPLIT)._cumsum(-1)
+ base_add = ret[..., -1]._cumsum(-1, _first_zero=True)[..., :-1]
+ base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1])
+ def fix(x:Tensor): return x.reshape(*ret.shape[0:-2], ret.shape[-2] * ret.shape[-1])[..., -self.shape[axis]:].transpose(axis,-1)
+ return fix(ret) + fix(base_add)

# ***** mlops (unary) *****
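
The new cumsum keeps the old pooling trick as _cumsum, pads the scan axis up to a multiple of SPLIT=256 with round_up, cumsums inside each block, then does a second, shifted cumsum over the per-block totals (_first_zero=True) and broadcasts that back as a per-block base offset. A hypothetical NumPy sketch of the same two-stage scheme, for intuition only:

```python
import numpy as np

def two_stage_cumsum(x, split=256):
    # Two-stage cumulative sum over a 1-D array: block-local cumsums plus a
    # cumsum of the preceding block totals used as per-block offsets.
    n = len(x)
    if n <= 2 * split:
        return np.cumsum(x)
    pad = -n % split                                   # round_up(n, split) - n
    xp = np.concatenate([np.zeros(pad, x.dtype), x])   # left-pad with zeros, like pad2d does
    local = np.cumsum(xp.reshape(-1, split), axis=1)   # stage 1: cumsum inside each block
    base = np.concatenate([[0.0], np.cumsum(local[:-1, -1])])  # stage 2: offsets from block totals
    return (local + base[:, None]).reshape(-1)[pad:]   # add offsets, drop the padding

x = np.random.default_rng(0).standard_normal(1022)
assert np.allclose(two_stage_cumsum(x), np.cumsum(x))
```

Splitting like this trades one long serial scan for two much shorter ones, at the cost of extra kernels, which matches the commit message ("2 more kernels for llama cumsum") and the CLCache bump from 98 to 100 in the llama kernel-count test.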