examples of new GPT2 and JIT change (#2261)

* var_vals are global

* working with global ish

* better

* fix export model

* fix tests

* better kv cache

* does it run?

* use where for kvmask

* fix excessive var_vals

* fix import

* how does multigpu use this?

* llama kinda work

* faster and simpler

* cleanup

* fix conversation mode

* test cleanups

* fix one more test

* test cleanup

---------

Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
chenyu
2023-11-10 15:07:02 -05:00
committed by GitHub
parent b6aaf12df7
commit a753c8e071
15 changed files with 189 additions and 249 deletions

View File

@@ -18,14 +18,14 @@ from tinygrad.lazy import PUSH_PERMUTES
from tinygrad.jit import CacheCollector
class CLCache:
def __init__(self, allowed=None, strict=False, preclear=True): self.allowed, self.strict, self.preclear = allowed, strict, preclear
def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None): self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
def __enter__(self):
if self.preclear:
gc.collect()
for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]:
x.realize()
GlobalCounters.reset()
CacheCollector.start()
CacheCollector.start(self.var_vals)
print("cache: entering")
def __exit__(self, type, value, traceback):
cache = CacheCollector.finish()
@@ -85,11 +85,12 @@ class TestInferenceMinKernels(unittest.TestCase):
def test_llama(self):
from examples.llama import Transformer
from tinygrad.shape.symbolic import Variable
args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
model = Transformer(**args_tiny)
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
with CLCache(85):
model(Tensor([[1,2,3,4]]), 0).realize()
with CLCache(98, var_vals={Variable("start_pos", 0, 1024): 0}):
model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 1024).bind(0)).realize()
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
class TestOptBinOp(unittest.TestCase):