small changes from rangeify (#11682)

* small changes from rangeify

* const like thing

* ksym
This commit is contained in:
George Hotz
2025-08-15 08:45:52 -07:00
committed by GitHub
parent d0d39885c3
commit 560984fd8d
4 changed files with 10 additions and 8 deletions

View File

@@ -1,7 +1,7 @@
import unittest
from tinygrad import Tensor
from tinygrad.uop.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp
from tinygrad.schedule.kernelize import sym, merge_views
from tinygrad.schedule.kernelize import kernelize_sym, merge_views
class TestRewriteTrackedChildren(unittest.TestCase):
@unittest.skip("track_children no longer supported")
@@ -57,7 +57,7 @@ class TestRewriteTrackedChildren(unittest.TestCase):
extra = PatternMatcher([(UPat(Ops.REDUCE_AXIS, name="r"), print_children)])
a = Tensor.empty(3, 3)
r = (a+0).sum()
graph_rewrite(r.uop, merge_views+sym+extra, track_children=True)
graph_rewrite(r.uop, merge_views+kernelize_sym+extra, track_children=True)
if __name__ == '__main__':
unittest.main()

View File

@@ -227,7 +227,7 @@ block_merge = PatternMatcher([
def finalize(sink:UOp) -> UOp:
if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src):
raise RuntimeError("linearize failure")
raise RuntimeError(f"linearize failure {sink.op} {[x.op for x in sink.src if x.op not in DONT_PLACE_IN_BLOCK]}")
# place the early things
lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)

View File

@@ -48,7 +48,7 @@ def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
return base.copy_to_device(copy.device).view(view.arg)
sym = symbolic_simple+PatternMatcher([
kernelize_sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
@@ -327,7 +327,7 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
"""
# multi + merge_views + simplify
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+kernelize_sym+replace_contiguous, ctx={}, name="merge_views")
# display the cleaned up tensor graph
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")

View File

@@ -232,7 +232,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ret
def sink(self, *srcs:UOp|None, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
def index(self, *srcs:UOp|None): return UOp(Ops.INDEX, self.dtype, (self,)+tuple([x for x in srcs if x is not None]))
def index(self, *srcs:UOp|None, **kwargs):
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
def __getitem__(self, idx): return self.index(idx)
def const_like(self, b:ConstLike):
# constants can optionally have a DEVICE source
@@ -273,7 +274,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
from tinygrad.shape.shapetracker import ShapeTracker
ret = ret.replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(shape, (0,)*len(shape))),))
if device is not None:
ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
if shape is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
return ret
@staticmethod
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
@@ -951,7 +953,7 @@ def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, na
input_map:dict[UOp, UOp]|None=None, ) -> dict[UOp, UOp]:
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
new_map: dict[UOp, UOp] = {}
for k in sink.toposort():
for k in (list(sink.toposort())[::-1] if bottom_up else sink.toposort()):
new_map[k] = v = rewrite_ctx.unified_rewrite(k)
if k is not v and k.metadata is not None: all_metadata[v] = tuple(dedup(all_metadata.get(v, ())))+k.metadata
if input_map is not None: