fix some rangeify tests (#12370)

* fix bad range merges

* fix rng

* fix uop gc

* fix some rangeify tests

* now that needs rangeify 2 also
This commit is contained in:
George Hotz
2025-09-30 20:12:08 +08:00
committed by GitHub
parent 2c397eb2a2
commit 44558a37f7
4 changed files with 13 additions and 3 deletions

View File

@@ -123,6 +123,7 @@ class TestLinearizer(unittest.TestCase):
assert num_loads <= 4, "more load uops than needed"
assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
+ @unittest.skip("this is handled at higher level now")
def test_upcast_cse(self):
# when upcasting, within a subtree, there may be common expressions.

View File

@@ -44,6 +44,7 @@ class TestFuse(unittest.TestCase):
a = Tensor.rand(50,50).realize()
self._test_fuse(lambda a: a / a.mean(axis=1), a)
+ @unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
def test_fuse_argmax(self):
a = Tensor.rand(50,50).realize()
self._test_fuse(lambda a: a.argmax(axis=-1), a)

View File

@@ -1,6 +1,6 @@
import unittest
from tinygrad import Tensor
- from tinygrad.helpers import getenv, GlobalCounters, EMULATE
+ from tinygrad.helpers import getenv, GlobalCounters, EMULATE, RANGEIFY
from tinygrad.engine.realize import lower_schedule_item, ProgramSpec, get_program
from tinygrad.renderer import Estimates
from tinygrad.codegen import full_rewrite
@@ -51,7 +51,11 @@ class TestMemoryCount(unittest.TestCase):
a = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
b = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
_, mem = get_stats(a+b)
- self.assertEqual(mem, 1024*1024 + 2*1024) # 2 lil reads + 1 write
+ if RANGEIFY:
+   # rangeify is smart!
+   self.assertEqual(mem, 1024 + 2*1024) # 2 lil reads + 1 lil write
+ else:
+   self.assertEqual(mem, 1024*1024 + 2*1024) # 2 lil reads + 1 write
def test_self_add(self):
a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)

View File

@@ -7,6 +7,7 @@ from tinygrad.uop.symbolic import sym, symbolic_simple
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup
from tinygrad.schedule.kernelize import Kernel
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
+ from tinygrad.codegen.simplify import pm_flatten_range
# *****************
# 0. do some cleanup rewrites, mostly copied from the old stuff
@@ -587,6 +588,9 @@ to_define_global = PatternMatcher([
# this is only needed if you are using symbolic
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
# remove RANGE with 0 size
(UPat(Ops.RANGE, name="r"), lambda r: UOp.const(dtypes.index, 0) if r.vmax == 0 else None),
# renumber the ranges starting with 0 so that kernel deduping works
(UPat(Ops.RANGE, name="r"), renumber_range),
])
@@ -629,7 +633,7 @@ def split_store(ctx:list[UOp], x:UOp):
# local kernel rewrite
lctx = LocalAddBufferContext()
- ret = graph_rewrite(x, to_define_global+rangeify_codegen+pm_remove_tags, ctx=lctx, name="kernel split", bottom_up=True)
+ ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen+pm_remove_tags, ctx=lctx, name="kernel split", bottom_up=True)
# gather the metadata
metadatas = [ctx[y].metadata for y in lctx.parent_tags]