Revert "use movement ops [pr] (#8222)" (#8224)

This reverts commit 0d26c970ba.
This commit is contained in:
George Hotz
2024-12-13 14:10:47 -08:00
committed by GitHub
parent 0d26c970ba
commit da19c37f0a
5 changed files with 13 additions and 43 deletions

View File

@@ -1350,7 +1350,7 @@ class TestNumpy(unittest.TestCase):
# Empty tuple index creates a view
a = Tensor([1, 2, 3])
numpy_testing_assert_equal_helper(a[()], a)
#self.assertEqual(data_ptr(a[()]), data_ptr(a))
self.assertEqual(data_ptr(a[()]), data_ptr(a))
# TODO jax supports empty tensor indexing
@unittest.skip("empty tensor indexing not supported")
@@ -1372,7 +1372,7 @@ class TestNumpy(unittest.TestCase):
self.assertIsNot(a[...], a)
numpy_testing_assert_equal_helper(a[...], a)
# `a[...]` was `a` in numpy <1.9.
#numpy_testing_assert_equal_helper(data_ptr(a[...]), data_ptr(a))
numpy_testing_assert_equal_helper(data_ptr(a[...]), data_ptr(a))
# Slicing with ellipsis can skip an
# arbitrary number of dimensions

View File

@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites, view_supported_devices
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleContext, ScheduleItem, create_schedule, view_right, view_left, do_realize, remove_movement_ops
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleContext, ScheduleItem, create_schedule, view_right, view_left, do_realize
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
from extra.models.llama import precompute_freqs_cis
@@ -1874,7 +1874,6 @@ class TestSwizzle(unittest.TestCase):
base = ShapeTracker.from_shape((32, 32))
a = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, base.size, dtypes.char), base.to_uop()))
swizzle = a.reshape((64, 16))
swizzle = graph_rewrite(swizzle, remove_movement_ops)
self.assertEqual(swizzle_cnt(swizzle), 1)
ret = swizzle_rewrite(swizzle)
self.assertEqual(ret.st_arg, base.reshape((64, 16))) # late rewrite
@@ -1890,7 +1889,6 @@ class TestSwizzle(unittest.TestCase):
add = r.reshape((16, 32, 1)) + UOp.const_with_shape(r.dtype, 0, (16, 32, 1))
self.assertEqual(add.st, ShapeTracker.from_shape((16, 32, 1)))
to_store = add.permute((1, 0, 2)).contiguous()
to_store = graph_rewrite(to_store, remove_movement_ops)
self.assertEqual(to_store.st, ShapeTracker.from_shape((32, 16, 1)))
self.assertEqual(to_store.src[0].st, add.st.permute((1, 0, 2)))
self.assertIs(to_store.src[0].op, Ops.VIEW)
@@ -1926,8 +1924,6 @@ class TestView(unittest.TestCase):
a = UOp(Ops.VIEW, dtypes.float, (UOp.new_buffer(Device.DEFAULT, 121, dtypes.float), UOp(Ops.EMPTY, dtypes.float)), st)
b = a.pad(pad_arg:=((0, 0), (0, 0), (18, 0)))
self.assertEqual(b.st, st.pad(pad_arg))
# TODO: why does this help?
b = graph_rewrite(b, remove_movement_ops)
self.assertIs(b.base.src[1], UOp.const(dtypes.float, 0))
def test_partial_mask(self):

View File

@@ -493,17 +493,14 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
buf_uop.buffer.ref(1)
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
remove_movement_ops = PatternMatcher([(UPat(GroupOp.Movement, name="x"), lambda x: x.base.view(unwrap(x.st))),])
@track_rewrites(named=True)
def create_schedule_with_vars(outs:List[UOp]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
if len(outs:=dedup(x.base for x in outs if x.base.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {}
# create the big graph
ctx = ScheduleContext()
cache: Dict[UOp, UOp] = {}
# to_uop is removing (many) of the movement ops
for u in (big_graph:=UOp.sink(*(to_uop(x, ctx, cache) for x in outs))).src: ctx.realizes[u.buf_uop] = u
big_graph = graph_rewrite(big_graph, remove_movement_ops+ops_folding+do_realize, ctx)
big_graph = graph_rewrite(big_graph, ops_folding+do_realize, ctx)
big_graph = graph_rewrite(big_graph, merge_bufs, ctx)
# create the scheduler context
graph_rewrite(big_graph, create_ctx, ctx)

View File

@@ -272,7 +272,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def st(self) -> Optional[ShapeTracker]:
if self.op is Ops.VIEW: return self.arg
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
# buffer ops can have a non contiguous shapetracker
if self.op in GroupOp.Buffer and len(src_sts:=[unwrap(x.st) for x in self.src if x.op is Ops.VIEW]) != 0: return src_sts[0]
if len(src_sts:=[x.st for x in self.src if x.st is not None]) == 0: return None
@@ -343,16 +342,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if bitcast: return self.bitcast(dtype, allow_buffer_view)
if self._device is not None and self._device.startswith("DISK"): raise RuntimeError("CAST isn't supported on DISK")
if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
# NOTE: we have to apply the movementops here, we can't use VIEW (yet)
# TODO: move this to the scheduler
ret = self.base.cast(dtype, bitcast)
op_arg = []
mop = self
while mop.op in GroupOp.Movement:
op_arg.append((mop.op, mop.arg))
mop = mop.src[0]
for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg)
return ret
return self.base.cast(dtype, bitcast).view(self.st)
return UOp(Ops.CAST, dtype, (self,))
def bitcast(self, dtype:DType, allow_buffer_view=True):
if self.can_view() and allow_buffer_view:
@@ -472,9 +462,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# *** uop movement ops ***
@property
def base(self) -> UOp:
if self.op in GroupOp.Movement: return self.src[0].base
return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
def view(self, new_st:ShapeTracker) -> UOp:
if self.st is None: return UOp(Ops.VIEW, self.dtype.base if not isinstance(self.dtype, ImageDType) else self.dtype, (self,), new_st)
ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
@@ -482,18 +470,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return ret.const_like(0)
if new_st.contiguous and self.base.shape == new_st.shape: return self.base
return ret
def _mop(self, op:Ops, arg):
ret = UOp(op, self.dtype, (self,), arg)
ret.st # pylint: disable=pointless-statement
return ret
def reshape(self, arg:Tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg)
def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg)
def expand(self, arg:Tuple[sint, ...]): return self._mop(Ops.EXPAND, arg)
def permute(self, arg:Tuple[sint, ...]): return self._mop(Ops.PERMUTE, arg)
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg)
def stride(self, arg:Tuple[sint, ...]): return self._mop(Ops.STRIDE, arg)
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(unwrap(self.st).pad(arg))
def expand(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).expand(arg))
def permute(self, arg:Tuple[int, ...]): return self.view(unwrap(self.st).permute(arg))
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(unwrap(self.st).shrink(arg))
def stride(self, arg:Tuple[int, ...]): return self.view(unwrap(self.st).stride(arg))
# *** uop Buffer stuff ***

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from dataclasses import dataclass
import functools
from typing import Tuple, List, Optional, Dict, Set, Callable
from typing import Tuple, List, Optional, Dict, Set
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.dtype import dtypes
@@ -115,8 +115,3 @@ class ShapeTracker:
def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
return ShapeTracker(self.views + (View.create(new_shape), ))
def mop(self, op, arg): return mops[op](self, arg)
mops: Dict[Ops, Callable] = {Ops.RESHAPE: ShapeTracker.reshape, Ops.PERMUTE: ShapeTracker.permute, Ops.EXPAND: ShapeTracker.expand,
Ops.SHRINK: ShapeTracker.shrink, Ops.STRIDE: ShapeTracker.stride, Ops.PAD: ShapeTracker.pad}