mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
examples of new GPT2 and JIT change (#2261)
* var_vals are global * working with global ish * better * fix export model * fix tests * better kv cache * does it run? * use where for kvmask * fix excessive var_vals * fix import * how does multigpu use this? * llama kinda work * faster and simpler * cleanup * fix conversation mode * test cleanups * fix one more test * test cleanup --------- Co-authored-by: George Hotz <geohot@gmail.com>
This commit is contained in:
9
test/external/external_test_opt.py
vendored
9
test/external/external_test_opt.py
vendored
@@ -18,14 +18,14 @@ from tinygrad.lazy import PUSH_PERMUTES
|
||||
from tinygrad.jit import CacheCollector
|
||||
|
||||
class CLCache:
|
||||
def __init__(self, allowed=None, strict=False, preclear=True): self.allowed, self.strict, self.preclear = allowed, strict, preclear
|
||||
def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None): self.allowed, self.strict, self.preclear, self.var_vals = allowed, strict, preclear, var_vals if var_vals is not None else {}
|
||||
def __enter__(self):
|
||||
if self.preclear:
|
||||
gc.collect()
|
||||
for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]:
|
||||
x.realize()
|
||||
GlobalCounters.reset()
|
||||
CacheCollector.start()
|
||||
CacheCollector.start(self.var_vals)
|
||||
print("cache: entering")
|
||||
def __exit__(self, type, value, traceback):
|
||||
cache = CacheCollector.finish()
|
||||
@@ -85,11 +85,12 @@ class TestInferenceMinKernels(unittest.TestCase):
|
||||
|
||||
def test_llama(self):
|
||||
from examples.llama import Transformer
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
args_tiny = {"dim": 512, "multiple_of": 256, "n_heads": 8, "n_layers": 4, "norm_eps": 1e-05, "vocab_size": 1000}
|
||||
model = Transformer(**args_tiny)
|
||||
for p in get_parameters(model): p.assign(np.zeros(p.shape, dtype=p.dtype.np))
|
||||
with CLCache(85):
|
||||
model(Tensor([[1,2,3,4]]), 0).realize()
|
||||
with CLCache(98, var_vals={Variable("start_pos", 0, 1024): 0}):
|
||||
model(Tensor([[1,2,3,4]]), Variable("start_pos", 0, 1024).bind(0)).realize()
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
|
||||
class TestOptBinOp(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user