tests from remove uop mutability [pr] (#8442)

* tests from remove uop mutability [pr]

* more test fix

* simpler test fix

* remove that
This commit is contained in:
George Hotz
2024-12-29 12:14:10 -05:00
committed by GitHub
parent b7511cc390
commit b71c51191b
6 changed files with 48 additions and 15 deletions

View File

@@ -1,14 +1,13 @@
import unittest, math
from tinygrad import Tensor, Device, dtypes
from tinygrad.ops import Ops
from tinygrad.engine.schedule import create_schedule
from tinygrad.helpers import CI
import numpy as np
from tinygrad.device import is_dtype_supported
def _check_ast_count(desired_count:int, t:Tensor):
# NOTE: this has side effect because everything can be scheduled only once
schedule = create_schedule(t.lazydata.lbs)
schedule = t.schedule()
asts = [s for s in schedule if s.ast.op is Ops.SINK]
assert len(asts) == desired_count, f"{len(asts)} != {desired_count}"

View File

@@ -13,7 +13,7 @@ class TestConvShapetracker(unittest.TestCase):
conv = Conv2d(16, 32, (3, 3))
# first run to init the weights, they are scheduled.
create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata])
conv(Tensor.empty(1, 16, 10, 10)).schedule()
# run it again to get the kernels
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is Ops.SINK]
assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"

View File

@@ -8,7 +8,6 @@ from tinygrad.helpers import CI, Context
from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm, LSTMCell
from tinygrad.nn.state import load_state_dict
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from tinygrad.device import is_dtype_supported
@@ -517,7 +516,7 @@ class TestNN(unittest.TestCase):
a = Tensor([[1, 5, 9, 11],
[12, 19, 8, 1]])
result = layer(a)
schedule = create_schedule([result.lazydata])
schedule = result.schedule()
self.assertEqual(3, len([item for item in schedule if item.ast.op is Ops.SINK]), "first run realizes arange, weight, and embedding")
run_schedule(schedule)
@@ -525,7 +524,7 @@ class TestNN(unittest.TestCase):
[4, 5, 6],
[7, 8, 9]])
result = layer(b)
schedule = create_schedule([result.lazydata])
schedule = result.schedule()
self.assertEqual(1, len([item for item in schedule if item.ast.op is Ops.SINK]), "second run realizes embedding only")
run_schedule(schedule)

View File

@@ -3,7 +3,6 @@ from tinygrad import Device, Tensor, dtypes, TinyJit
from tinygrad.helpers import CI, getenv, Context
from tinygrad.device import Buffer, BufferSpec, Compiled, ProfileRangeEvent, ProfileDeviceEvent, ProfileGraphEvent
from tinygrad.runtime.support.hcq import HCQCompiled
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import get_runner
MOCKGPU = getenv("MOCKGPU")
@@ -34,7 +33,7 @@ class TestProfiler(unittest.TestCase):
TestProfiler.a = Tensor([0.,1.], device=Device.DEFAULT).realize()
TestProfiler.b = self.a + 1
si = create_schedule([self.b.lazydata])[-1]
si = self.b.schedule()[-1]
TestProfiler.runner = get_runner(TestProfiler.d0.device, si.ast)
TestProfiler.b.lazydata.buffer.allocate()

View File

@@ -14,7 +14,7 @@ from tinygrad.dtype import DType, ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, view_supported_devices, symbolic
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleContext, ScheduleItem, create_schedule, view_right, view_left, remove_movement_ops, to_uop
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
@@ -22,12 +22,13 @@ from extra.models.llama import precompute_freqs_cis
class KernelCountException(Exception): pass
def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
if isinstance(t, Tensor): outs = t.lazydata.lbs
elif isinstance(t, List): outs = flatten([r.lazydata.lbs for r in t])
else: outs = [t]
if to_prerealize:
for pre in to_prerealize: pre.schedule()
sched = create_schedule(outs)
if isinstance(t, Tensor): sched = t.schedule()
elif isinstance(t, List) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t)
else:
assert isinstance(t, UOp), f"can't schedule {t}"
sched = create_schedule([t])
if filter_sink: sched = [s for s in sched if s.ast.op is Ops.SINK]
if len(sched) != allowed:
print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")

View File

@@ -1,10 +1,45 @@
import unittest
from tinygrad import Tensor
from tinygrad.ops import UPat, Ops
from tinygrad.ops import UPat, Ops, UOp
realized_pattern = UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),))
const_pattern = UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),),)))
def is_pattern(ten:Tensor, pat:UPat): assert pat.match(ten.lazydata, {})
def is_pattern_uop(u:UOp, pat:UPat): assert pat.match(u, {}), f"{u}\nis not\n{pat}"
def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.lazydata, pat)
class TestTensorMutates(unittest.TestCase):
# this fails because uops are mutating
@unittest.expectedFailure
def test_mutate_add(self):
a = Tensor([1,2,3])
b = Tensor([4,5,6])
ret = a+b
pa = a.lazydata
pb = b.lazydata
pr = ret.lazydata
ret.schedule()
self.assertIsNot(pa, a.lazydata)
self.assertIsNot(pb, b.lazydata)
self.assertIsNot(pr, ret.lazydata)
for t in [a,b,ret]: is_pattern(t, realized_pattern)
def test_reshape_is_same_parent(self):
a = Tensor([1,2,3])
b = Tensor([4,5,6])
c = a+b
d = (a+b).reshape(3,1)
d.realize()
is_pattern_uop(d.lazydata.base, realized_pattern)
is_pattern_uop(c.lazydata.base, realized_pattern)
def test_reshape_is_same_child(self):
a = Tensor([1,2,3])
b = Tensor([4,5,6])
c = a+b
d = (a+b).reshape(3,1)
c.realize()
is_pattern_uop(c.lazydata.base, realized_pattern)
is_pattern_uop(d.lazydata.base, realized_pattern)
class TestTensorUopRepresentation(unittest.TestCase):
def test_realized(self):