mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
small changes from rangeify (#11682)
* small changes from rangeify * const like thing * ksym
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.uop.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp
|
||||
from tinygrad.schedule.kernelize import sym, merge_views
|
||||
from tinygrad.schedule.kernelize import kernelize_sym, merge_views
|
||||
|
||||
class TestRewriteTrackedChildren(unittest.TestCase):
|
||||
@unittest.skip("track_children no longer supported")
|
||||
@@ -57,7 +57,7 @@ class TestRewriteTrackedChildren(unittest.TestCase):
|
||||
extra = PatternMatcher([(UPat(Ops.REDUCE_AXIS, name="r"), print_children)])
|
||||
a = Tensor.empty(3, 3)
|
||||
r = (a+0).sum()
|
||||
graph_rewrite(r.uop, merge_views+sym+extra, track_children=True)
|
||||
graph_rewrite(r.uop, merge_views+kernelize_sym+extra, track_children=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -227,7 +227,7 @@ block_merge = PatternMatcher([
|
||||
|
||||
def finalize(sink:UOp) -> UOp:
|
||||
if sink.op is not Ops.BLOCK or not all(x.op in DONT_PLACE_IN_BLOCK for x in sink.src):
|
||||
raise RuntimeError("linearize failure")
|
||||
raise RuntimeError(f"linearize failure {sink.op} {[x.op for x in sink.src if x.op not in DONT_PLACE_IN_BLOCK]}")
|
||||
|
||||
# place the early things
|
||||
lst = sorted(dedup(sink.src), key=lambda x: x.tuplize) + list(sink.arg.lst)
|
||||
|
||||
@@ -48,7 +48,7 @@ def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
|
||||
if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
|
||||
return base.copy_to_device(copy.device).view(view.arg)
|
||||
|
||||
sym = symbolic_simple+PatternMatcher([
|
||||
kernelize_sym = symbolic_simple+PatternMatcher([
|
||||
# UOp with size 0 is zero
|
||||
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
|
||||
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
|
||||
@@ -327,7 +327,7 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
|
||||
"""
|
||||
|
||||
# multi + merge_views + simplify
|
||||
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+kernelize_sym+replace_contiguous, ctx={}, name="merge_views")
|
||||
|
||||
# display the cleaned up tensor graph
|
||||
if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
|
||||
|
||||
@@ -232,7 +232,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return ret
|
||||
def sink(self, *srcs:UOp|None, **kwargs): return UOp(Ops.SINK, dtypes.void, (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def detach(self): return UOp(Ops.DETACH, self.dtype, (self,))
|
||||
def index(self, *srcs:UOp|None): return UOp(Ops.INDEX, self.dtype, (self,)+tuple([x for x in srcs if x is not None]))
|
||||
def index(self, *srcs:UOp|None, **kwargs):
|
||||
return UOp(Ops.INDEX, kwargs.pop("dtype", self.dtype), (self,)+tuple([x for x in srcs if x is not None]), **kwargs)
|
||||
def __getitem__(self, idx): return self.index(idx)
|
||||
def const_like(self, b:ConstLike):
|
||||
# constants can optionally have a DEVICE source
|
||||
@@ -273,7 +274,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
ret = ret.replace(src=(UOp(Ops.VIEW, dtypes.void, (), ShapeTracker.from_shape(shape, (0,)*len(shape))),))
|
||||
if device is not None:
|
||||
ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
|
||||
if shape is not None: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device).view(unwrap(ret.st)),))
|
||||
else: ret = ret.replace(src=(UOp(Ops.DEVICE, arg=device),))
|
||||
return ret
|
||||
@staticmethod
|
||||
def range(dtype:DType, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(end),), arg=idx)
|
||||
@@ -951,7 +953,7 @@ def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, na
|
||||
input_map:dict[UOp, UOp]|None=None, ) -> dict[UOp, UOp]:
|
||||
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
|
||||
new_map: dict[UOp, UOp] = {}
|
||||
for k in sink.toposort():
|
||||
for k in (list(sink.toposort())[::-1] if bottom_up else sink.toposort()):
|
||||
new_map[k] = v = rewrite_ctx.unified_rewrite(k)
|
||||
if k is not v and k.metadata is not None: all_metadata[v] = tuple(dedup(all_metadata.get(v, ())))+k.metadata
|
||||
if input_map is not None:
|
||||
|
||||
Reference in New Issue
Block a user