perf: lazyop as dataclass (#1603)

* perf: lazyop as dataclass

fix: linter

fix: restore eq

* use builtin methods, make buffers a property to allow freezing

* fix: reduce diff

* fix: can't freeze due to KOPT tests, mypy

* fix: explicit hash

* can freeze if tests are fixed

* fix: typo

---------

Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: Roelof van Dijk
Date: 2023-10-25 23:54:30 +02:00
Committed by: GitHub
Parent: 0ca0e9ee5e
Commit: 36ab04ae35
5 changed files with 30 additions and 58 deletions
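For readers skimming the diff, the gist of the change is that LazyOp moves to a dataclass. The sketch below is a hedged illustration of that shape, not the actual tinygrad definition: the field types are simplified and the `buffers` property is a stand-in for however the real class derives its buffer list.

```python
# Hedged sketch of the idea, not the real tinygrad.ops.LazyOp.
from dataclasses import dataclass
from typing import Any, Tuple

@dataclass  # not frozen yet: the commit notes the KOPT tests still mutate ops in place
class LazyOp:
  op: Any                      # an Op enum member, e.g. BinaryOps.ADD or LoadOps.RAND
  src: Tuple[Any, ...] = ()    # child LazyOps / LazyBuffers
  arg: Any = None

  # "fix: explicit hash": a dataclass that generates __eq__ sets __hash__ to None,
  # so hashing has to be restored by hand.
  def __hash__(self): return hash((self.op, self.src, self.arg))

  # "buffers to property to allow freezing": derive the buffer list instead of
  # storing it as a mutable field, so frozen=True stays possible later.
  @property
  def buffers(self):
    return tuple(b for s in self.src for b in (s.buffers if isinstance(s, LazyOp) else (s,)))
```

Keeping `buffers` derived rather than stored is what leaves the door open for the "can freeze if tests are fixed" step at the end of the list above.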


@@ -2,11 +2,12 @@
 import unittest, gc
 import numpy as np
 from tinygrad.tensor import Tensor
-from tinygrad.nn.state import get_parameters, get_state_dict
-from tinygrad.ops import GlobalCounters, LazyOp, LoadOps
+from tinygrad.nn.state import get_state_dict
+from tinygrad.ops import GlobalCounters
 from tinygrad.runtime.lib import RawBuffer, LRUAllocator
 from tinygrad.helpers import dtypes, prod
 from tinygrad.ops import Device
+from test.helpers import derandomize_model
 from examples.llama import Transformer
@@ -86,20 +87,6 @@ def check_gc():
   from extra.introspection import print_objects
   assert print_objects() == 0
-# for speed
-def derandomize(x):
-  if isinstance(x, LazyOp):
-    if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
-    x.src = tuple([derandomize(s) for s in x.src])
-  else:
-    x.op = derandomize(x.op)
-  return x
-def derandomize_model(model):
-  for p in get_parameters(model):
-    p.lazydata = derandomize(p.lazydata)
-    p.realize()
 class TestAllocators(unittest.TestCase):
   @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
   def test_lru_allocator_tiny_llama(self):
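Both test files carried identical copies of derandomize/derandomize_model; the diff drops them in favor of a single import from test.helpers. One plausible shape for that shared helper is sketched here, under the assumption that it rebuilds LazyOp nodes instead of assigning to x.op and x.src (mutation is what a frozen dataclass would forbid); the real test/helpers.py may differ.

```python
# Hedged sketch of test/helpers.py after the refactor; not necessarily the real contents.
from tinygrad.nn.state import get_parameters
from tinygrad.ops import LazyOp, LoadOps

def derandomize(x):
  if isinstance(x, LazyOp):
    # build a new node rather than mutating x.op / x.src in place
    new_op = LoadOps.EMPTY if x.op == LoadOps.RAND else x.op
    return LazyOp(new_op, tuple(derandomize(s) for s in x.src), x.arg)
  # x is a LazyBuffer: rewrite the op it wraps, as the old in-file helpers did
  x.op = derandomize(x.op)
  return x

def derandomize_model(model):
  for p in get_parameters(model):
    p.lazydata = derandomize(p.lazydata)
    p.realize()
```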


@@ -2,28 +2,13 @@
 import unittest
 import numpy as np
 from tinygrad.tensor import Tensor
-from tinygrad.state import get_parameters
-from tinygrad.ops import LazyOp, LoadOps
 from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
 from tinygrad.helpers import dtypes, CI
-from tinygrad.lazy import Device
+from tinygrad.ops import Device
+from test.helpers import derandomize_model
 from examples.llama import Transformer
-# for speed
-def derandomize(x):
-  if isinstance(x, LazyOp):
-    if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
-    x.src = tuple([derandomize(s) for s in x.src])
-  else:
-    x.op = derandomize(x.op)
-  return x
-def derandomize_model(model):
-  for p in get_parameters(model):
-    p.lazydata = derandomize(p.lazydata)
-    p.realize()
 def helper_test_jitted_correctness(gen, train, train_jit):
   nojit = train(*gen()).numpy()
   for _ in range(5): jit = train_jit(*gen()).numpy()
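For context on how these pieces combine, here is a hedged usage sketch. TinyModel, its sizes, and the input generator are made-up placeholders rather than anything from the tinygrad test suite, and it assumes a JIT-supported device.

```python
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from test.helpers import derandomize_model

class TinyModel:
  def __init__(self): self.w = Tensor.rand(4, 4)   # random init shows up as LoadOps.RAND
  def __call__(self, x): return (x @ self.w).relu()

model = TinyModel()
derandomize_model(model)                           # swap RAND for EMPTY "for speed", as the old helpers did

fwd = lambda x: model(x).realize()
gen = lambda: (Tensor.ones(2, 4),)
helper_test_jitted_correctness(gen, fwd, TinyJit(fwd))  # jitted output should match the eager one
```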