mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
fix some rangeify tests (#12370)
* fix bad range merges * fix rng * fix uop gc * fix some rangeify tests * now that needs rangeify 2 also
This commit is contained in:
@@ -123,6 +123,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
assert num_loads <= 4, "more load uops than needed"
|
||||
assert num_loads >= 4, "unexpected number of uops, maybe this test needs updating?"
|
||||
|
||||
@unittest.skip("this is handled at higher level now")
|
||||
def test_upcast_cse(self):
|
||||
# when upcasting, within a subtree, there may be common expressions.
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ class TestFuse(unittest.TestCase):
|
||||
a = Tensor.rand(50,50).realize()
|
||||
self._test_fuse(lambda a: a / a.mean(axis=1), a)
|
||||
|
||||
@unittest.skipIf(0<RANGEIFY<2, "needs RANGEIFY>1")
|
||||
def test_fuse_argmax(self):
|
||||
a = Tensor.rand(50,50).realize()
|
||||
self._test_fuse(lambda a: a.argmax(axis=-1), a)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.helpers import getenv, GlobalCounters, EMULATE
|
||||
from tinygrad.helpers import getenv, GlobalCounters, EMULATE, RANGEIFY
|
||||
from tinygrad.engine.realize import lower_schedule_item, ProgramSpec, get_program
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.codegen import full_rewrite
|
||||
@@ -51,7 +51,11 @@ class TestMemoryCount(unittest.TestCase):
|
||||
a = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
|
||||
b = Tensor.empty(1024, 1, dtype=dtypes.uint8).expand(1024, 1024)
|
||||
_, mem = get_stats(a+b)
|
||||
self.assertEqual(mem, 1024*1024 + 2*1024) # 2 lil reads + 1 write
|
||||
if RANGEIFY:
|
||||
# rangeify is smart!
|
||||
self.assertEqual(mem, 1024 + 2*1024) # 2 lil reads + 1 lil write
|
||||
else:
|
||||
self.assertEqual(mem, 1024*1024 + 2*1024) # 2 lil reads + 1 write
|
||||
|
||||
def test_self_add(self):
|
||||
a = Tensor.empty(1024, 1024, dtype=dtypes.uint8)
|
||||
|
||||
@@ -7,6 +7,7 @@ from tinygrad.uop.symbolic import sym, symbolic_simple
|
||||
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup
|
||||
from tinygrad.schedule.kernelize import Kernel
|
||||
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType
|
||||
from tinygrad.codegen.simplify import pm_flatten_range
|
||||
|
||||
# *****************
|
||||
# 0. do some cleanup rewrites, mostly copied from the old stuff
|
||||
@@ -587,6 +588,9 @@ to_define_global = PatternMatcher([
|
||||
# this is only needed if you are using symbolic
|
||||
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"), lambda c: c.replace(src=()) if len(c.src) else None),
|
||||
|
||||
# remove RANGE with 0 size
|
||||
(UPat(Ops.RANGE, name="r"), lambda r: UOp.const(dtypes.index, 0) if r.vmax == 0 else None),
|
||||
|
||||
# renumber the ranges starting with 0 so that kernel deduping works
|
||||
(UPat(Ops.RANGE, name="r"), renumber_range),
|
||||
])
|
||||
@@ -629,7 +633,7 @@ def split_store(ctx:list[UOp], x:UOp):
|
||||
|
||||
# local kernel rewrite
|
||||
lctx = LocalAddBufferContext()
|
||||
ret = graph_rewrite(x, to_define_global+rangeify_codegen+pm_remove_tags, ctx=lctx, name="kernel split", bottom_up=True)
|
||||
ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen+pm_remove_tags, ctx=lctx, name="kernel split", bottom_up=True)
|
||||
|
||||
# gather the metadata
|
||||
metadatas = [ctx[y].metadata for y in lctx.parent_tags]
|
||||
|
||||
Reference in New Issue
Block a user