remove tests that pre date the uop spec (#12168)

* remove tests that pre date the uop spec

* const src

* for RANGEIFY=1

* update with bind

* remove import
This commit is contained in:
qazal
2025-09-14 18:47:42 +03:00
committed by GitHub
parent 1591e4f66b
commit 02054b53fe
3 changed files with 7 additions and 116 deletions

View File

@@ -12,7 +12,7 @@ from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType, ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites
from tinygrad.uop.symbolic import symbolic_simple
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp, RANGEIFY
from tinygrad.schedule.kernelize import merge_views, get_kernelize_map, Kernel
@@ -2148,84 +2148,6 @@ class TestSimplifier(unittest.TestCase):
assert UPat(Ops.CONST, arg=False).match(sink, {}), f"expected {sink} to collapse to a const False"
assert sink.shape == a.shape
tensor_const_pm = PatternMatcher([
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)),)), lambda: True),
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)))), UPat(Ops.CONST))), lambda: True),
])
class TestConst(unittest.TestCase):
# ** part 1: basic functionality of a tensor directly created from CONST
def test_tensor_const(self):
a = Tensor(1)
print(a.uop)
self.assertTrue(tensor_const_pm.rewrite(a.uop))
def test_tensor_variable(self):
vv = UOp.variable("a", 0, 10).bind(1)
a = Tensor(vv)
print(a.uop)
self.assertTrue(tensor_const_pm.rewrite(a.uop))
def test_const_schedule(self):
a = Tensor.ones((4, 4))
sched = a.schedule()
self.assertEqual(len(sched), 0)
def test_const_contiguous_schedule(self):
# this ends up in the big graph
a = Tensor.ones((4,)).contiguous()
sched = a.schedule()
self.assertEqual(len(sched), 1)
# ** part 2: scheduler behavior when const folding happens later
def test_const_folding_no_realize(self):
a = Tensor([1, 2, 3, 4])*0
sched = a.schedule()
self.assertEqual(len(sched), 0)
def test_src_const_folding(self):
with Context(TRACK_MATCH_STATS=0):
a = Tensor.full((4,), 1).contiguous().realize()
b = Tensor.full((4,), 2).contiguous().realize()
mul0 = a*0
add = b+mul0
sched = add.schedule()
self.assertEqual(len(sched), 0)
# b+0 and b share the same underlying device memory
self.assertIs(add.uop.buffer, b.uop.buffer)
self.assertListEqual(add.tolist(), [2, 2, 2, 2])
def test_src_masked_const_folding(self):
with Context(TRACK_MATCH_STATS=0):
a = Tensor.full((4,), 1).contiguous().realize()
b = Tensor.full((6,), 2).contiguous().realize()
mul0 = a*0
add = b+mul0.pad((1, 1), value=2)
sched = add.schedule()
self.assertEqual(len(sched), 1)
run_schedule(sched)
# add gets assigned to a new buffer
self.assertIsNot(add.uop.base.realized, b.uop.base.realized)
self.assertListEqual(add.tolist(), [4, 2, 2, 2, 2, 4])
# ** part 3: Tensor variable bindings
#@unittest.expectedFailure # TODO: should schedule assert if you try to realize a Variable?
def test_var_schedule(self):
vv = UOp.variable("a", 0, 10).bind(1)
a = Tensor(vv)
sched = a.schedule()
self.assertEqual(len(sched), 0)
def test_add_tvar(self):
vv = UOp.variable("a", 0, 10).bind(1)
a = Tensor(vv)+2
sched, var_vals = a.schedule_with_vars()
self.assertEqual(len(sched), 1)
run_schedule(sched, var_vals)
self.assertEqual(a.tolist(), 3)
@unittest.skipIf(Device.DEFAULT == "CPU", "tests copy from another device to cpu")
class TestCopyFolding(unittest.TestCase):
def test_const_copy_is_free(self):

View File

@@ -34,7 +34,7 @@ class TestTensorMutates(unittest.TestCase):
is_pattern_uop(c.uop.base, realized_pattern)
# NOTE: we keep movement ops on top of the buffer view
is_pattern_uop(c.uop, UPat(Ops.BUFFER))
is_pattern_uop(d.uop, UPat(Ops.VIEW, src=(realized_pattern,)))
assert d.uop is not d.uop.base
def test_reshape_is_same_child(self):
a = Tensor([1,2,3])
@@ -58,40 +58,6 @@ class TestTensorUopRepresentation(unittest.TestCase):
print(c.uop)
is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern)))
def test_const_pattern(self):
a = Tensor(1)
print(a.uop)
is_pattern(a, const_pattern) # const in tensor has a DEVICE and VIEW src
is_pattern(a, UPat.cvar("x")) # even cvar works!
def test_consts_do_not_realize(self):
a = Tensor(1)
print(a.uop)
pre_realize = a.uop
a.realize()
assert a.uop is pre_realize
def test_viewed_consts_do_not_realize(self):
a = Tensor.ones(10, 10)
print(a.uop)
a.realize()
is_pattern(a, const_pattern)
self.assertEqual(a.uop.shape, (10, 10))
# CONST is EXPAND -> RESHAPE -> CONST -> DEVICE
def test_consts_dont_have_buffers(self):
a = Tensor.ones(10, 10)
buffers_in_parents = [x.op for x in a.uop.toposort() if x.op is Ops.BUFFER]
self.assertEqual(len(buffers_in_parents), 0)
is_pattern(a, UPat(Ops.EXPAND, src=(UPat(Ops.RESHAPE, src=(const_pattern,)),)))
# COPY has a copyin source and a device.
def test_copyin(self):
a = Tensor([1.,2,3]).realize()
c = a.to("TEST") # NOTE: this isn't checked
print(c.uop)
is_pattern(c, UPat(Ops.COPY, src=(realized_pattern, UPat(Ops.DEVICE)), arg=None))
def test_empty_buf(self):
a = Tensor.empty(3, 3)
is_pattern(a, UPat(Ops.RESHAPE, src=(UPat(Ops.BUFFER),)))

View File

@@ -99,8 +99,11 @@ tensor_uop_spec = buffer_spec+assign_spec+PatternMatcher([
# Tensor const has a device and an unmasked ShapeTracker of stride 0
# NOTE: variables in shape can cause multiple views in this ShapeTracker and other issues, see TestSymbolicJit.test_ones_sum
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, name="st", src=(UPat(Ops.DEVICE),)),)),
# TODO: remove after rangeify is default
(UPat(Ops.CONST, src=(UPat.any(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="st"),
UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND)), name="st")),)),
lambda st: len(st.st.views) == 1 and all(v.mask is None for v in st.st.views)),
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),
# DETACH and CONTIGUOUS change how we interpret the source UOp
# CONTIGUOUS ensures the source UOp realizes
@@ -165,7 +168,7 @@ spec = PatternMatcher([
lambda x,src: isinstance(x.arg, ShapeTracker) and src.op is not Ops.STORE and x.dtype.base == src.dtype.base),
(UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True),
(UPat(Ops.CONST, name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
(UPat(Ops.CONST, src=(), name="x"), lambda x: type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))),
# early LOAD has a <bufview, store?>
(UPat(Ops.LOAD, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.Defines),)),)), lambda: True),