mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
remove tests that pre date the uop spec (#12168)
* remove tests that pre date the uop spec * const src * for RANGEIFY=1 * update with bind * remove import
This commit is contained in:
@@ -12,7 +12,7 @@ from tinygrad import nn, dtypes, Device, Tensor
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp, RANGEIFY
|
||||
from tinygrad.schedule.kernelize import merge_views, get_kernelize_map, Kernel
|
||||
@@ -2148,84 +2148,6 @@ class TestSimplifier(unittest.TestCase):
|
||||
assert UPat(Ops.CONST, arg=False).match(sink, {}), f"expected {sink} to collapse to a const False"
|
||||
assert sink.shape == a.shape
|
||||
|
||||
tensor_const_pm = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)),)), lambda: True),
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)))), UPat(Ops.CONST))), lambda: True),
|
||||
])
|
||||
class TestConst(unittest.TestCase):
|
||||
# ** part 1: basic functionality of a tensor directly created from CONST
|
||||
|
||||
def test_tensor_const(self):
|
||||
a = Tensor(1)
|
||||
print(a.uop)
|
||||
self.assertTrue(tensor_const_pm.rewrite(a.uop))
|
||||
|
||||
def test_tensor_variable(self):
|
||||
vv = UOp.variable("a", 0, 10).bind(1)
|
||||
a = Tensor(vv)
|
||||
print(a.uop)
|
||||
self.assertTrue(tensor_const_pm.rewrite(a.uop))
|
||||
|
||||
def test_const_schedule(self):
|
||||
a = Tensor.ones((4, 4))
|
||||
sched = a.schedule()
|
||||
self.assertEqual(len(sched), 0)
|
||||
|
||||
def test_const_contiguous_schedule(self):
|
||||
# this ends up in the big graph
|
||||
a = Tensor.ones((4,)).contiguous()
|
||||
sched = a.schedule()
|
||||
self.assertEqual(len(sched), 1)
|
||||
|
||||
# ** part 2: scheduler behavior when const folding happens later
|
||||
|
||||
def test_const_folding_no_realize(self):
|
||||
a = Tensor([1, 2, 3, 4])*0
|
||||
sched = a.schedule()
|
||||
self.assertEqual(len(sched), 0)
|
||||
|
||||
def test_src_const_folding(self):
|
||||
with Context(TRACK_MATCH_STATS=0):
|
||||
a = Tensor.full((4,), 1).contiguous().realize()
|
||||
b = Tensor.full((4,), 2).contiguous().realize()
|
||||
mul0 = a*0
|
||||
add = b+mul0
|
||||
sched = add.schedule()
|
||||
self.assertEqual(len(sched), 0)
|
||||
# b+0 and b share the same underlying device memory
|
||||
self.assertIs(add.uop.buffer, b.uop.buffer)
|
||||
self.assertListEqual(add.tolist(), [2, 2, 2, 2])
|
||||
|
||||
def test_src_masked_const_folding(self):
|
||||
with Context(TRACK_MATCH_STATS=0):
|
||||
a = Tensor.full((4,), 1).contiguous().realize()
|
||||
b = Tensor.full((6,), 2).contiguous().realize()
|
||||
mul0 = a*0
|
||||
add = b+mul0.pad((1, 1), value=2)
|
||||
sched = add.schedule()
|
||||
self.assertEqual(len(sched), 1)
|
||||
run_schedule(sched)
|
||||
# add gets assigned to a new buffer
|
||||
self.assertIsNot(add.uop.base.realized, b.uop.base.realized)
|
||||
self.assertListEqual(add.tolist(), [4, 2, 2, 2, 2, 4])
|
||||
|
||||
# ** part 3: Tensor variable bindings
|
||||
|
||||
#@unittest.expectedFailure # TODO: should schedule assert if you try to realize a Variable?
|
||||
def test_var_schedule(self):
|
||||
vv = UOp.variable("a", 0, 10).bind(1)
|
||||
a = Tensor(vv)
|
||||
sched = a.schedule()
|
||||
self.assertEqual(len(sched), 0)
|
||||
|
||||
def test_add_tvar(self):
|
||||
vv = UOp.variable("a", 0, 10).bind(1)
|
||||
a = Tensor(vv)+2
|
||||
sched, var_vals = a.schedule_with_vars()
|
||||
self.assertEqual(len(sched), 1)
|
||||
run_schedule(sched, var_vals)
|
||||
self.assertEqual(a.tolist(), 3)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU", "tests copy from another device to cpu")
|
||||
class TestCopyFolding(unittest.TestCase):
|
||||
def test_const_copy_is_free(self):
|
||||
|
||||
@@ -34,7 +34,7 @@ class TestTensorMutates(unittest.TestCase):
|
||||
is_pattern_uop(c.uop.base, realized_pattern)
|
||||
# NOTE: we keep movement ops on top of the buffer view
|
||||
is_pattern_uop(c.uop, UPat(Ops.BUFFER))
|
||||
is_pattern_uop(d.uop, UPat(Ops.VIEW, src=(realized_pattern,)))
|
||||
assert d.uop is not d.uop.base
|
||||
|
||||
def test_reshape_is_same_child(self):
|
||||
a = Tensor([1,2,3])
|
||||
@@ -58,40 +58,6 @@ class TestTensorUopRepresentation(unittest.TestCase):
|
||||
print(c.uop)
|
||||
is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern)))
|
||||
|
||||
def test_const_pattern(self):
|
||||
a = Tensor(1)
|
||||
print(a.uop)
|
||||
is_pattern(a, const_pattern) # const in tensor has a DEVICE and VIEW src
|
||||
is_pattern(a, UPat.cvar("x")) # even cvar works!
|
||||
|
||||
def test_consts_do_not_realize(self):
|
||||
a = Tensor(1)
|
||||
print(a.uop)
|
||||
pre_realize = a.uop
|
||||
a.realize()
|
||||
assert a.uop is pre_realize
|
||||
|
||||
def test_viewed_consts_do_not_realize(self):
|
||||
a = Tensor.ones(10, 10)
|
||||
print(a.uop)
|
||||
a.realize()
|
||||
is_pattern(a, const_pattern)
|
||||
self.assertEqual(a.uop.shape, (10, 10))
|
||||
|
||||
# CONST is EXPAND -> RESHAPE -> CONST -> DEVICE
|
||||
def test_consts_dont_have_buffers(self):
|
||||
a = Tensor.ones(10, 10)
|
||||
buffers_in_parents = [x.op for x in a.uop.toposort() if x.op is Ops.BUFFER]
|
||||
self.assertEqual(len(buffers_in_parents), 0)
|
||||
is_pattern(a, UPat(Ops.EXPAND, src=(UPat(Ops.RESHAPE, src=(const_pattern,)),)))
|
||||
|
||||
# COPY has a copyin source and a device.
|
||||
def test_copyin(self):
|
||||
a = Tensor([1.,2,3]).realize()
|
||||
c = a.to("TEST") # NOTE: this isn't checked
|
||||
print(c.uop)
|
||||
is_pattern(c, UPat(Ops.COPY, src=(realized_pattern, UPat(Ops.DEVICE)), arg=None))
|
||||
|
||||
def test_empty_buf(self):
|
||||
a = Tensor.empty(3, 3)
|
||||
is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))
|
||||
|
||||
@@ -99,8 +99,11 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
|
||||
|
||||
# Tensor const has a device and an unmasked ShapeTracker of stride 0
|
||||
# NOTE: variables in shape can cause multiple views in this ShapeTracker and other issues, see TestSymbolicJit.test_ones_sum
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)),
|
||||
# TODO: remove after rangeify is default
|
||||
(UPat(Ops.CONST, src=(UPat.any(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="st"),
|
||||
UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND)), name="st")),)),
|
||||
lambda st: len(st.st.views) == 1 and all(v.mask is None for v in st.st.views)),
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),
|
||||
|
||||
# DETACH and CONTIGUOUS change how we interpret the source UOp
|
||||
# CONTIGUOUS ensures the source UOp realizes
|
||||
@@ -165,7 +168,7 @@ spec = PatternMatcher([
|
||||
lambda x,src: isinstance(x.arg, ShapeTracker) and src.op is not Ops.STORE and x.dtype.base == src.dtype.base),
|
||||
|
||||
(UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
|
||||
(UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
|
||||
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
|
||||
|
||||
# early LOAD has a <bufview, store?>
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)),)), lambda: True),
|
||||
|
||||
Reference in New Issue
Block a user