mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
This reverts commit 0d26c970ba.
This commit is contained in:
@@ -1350,7 +1350,7 @@ class TestNumpy(unittest.TestCase):
|
||||
# Empty tuple index creates a view
|
||||
a = Tensor([1, 2, 3])
|
||||
numpy_testing_assert_equal_helper(a[()], a)
|
||||
#self.assertEqual(data_ptr(a[()]), data_ptr(a))
|
||||
self.assertEqual(data_ptr(a[()]), data_ptr(a))
|
||||
|
||||
# TODO jax supports empty tensor indexing
|
||||
@unittest.skip("empty tensor indexing not supported")
|
||||
@@ -1372,7 +1372,7 @@ class TestNumpy(unittest.TestCase):
|
||||
self.assertIsNot(a[...], a)
|
||||
numpy_testing_assert_equal_helper(a[...], a)
|
||||
# `a[...]` was `a` in numpy <1.9.
|
||||
#numpy_testing_assert_equal_helper(data_ptr(a[...]), data_ptr(a))
|
||||
numpy_testing_assert_equal_helper(data_ptr(a[...]), data_ptr(a))
|
||||
|
||||
# Slicing with ellipsis can skip an
|
||||
# arbitrary number of dimensions
|
||||
|
||||
@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
|
||||
from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites, view_supported_devices
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleContext, ScheduleItem, create_schedule, view_right, view_left, do_realize, remove_movement_ops
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleContext, ScheduleItem, create_schedule, view_right, view_left, do_realize
|
||||
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
|
||||
@@ -1874,7 +1874,6 @@ class TestSwizzle(unittest.TestCase):
|
||||
base = ShapeTracker.from_shape((32, 32))
|
||||
a = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, base.size, dtypes.char), base.to_uop()))
|
||||
swizzle = a.reshape((64, 16))
|
||||
swizzle = graph_rewrite(swizzle, remove_movement_ops)
|
||||
self.assertEqual(swizzle_cnt(swizzle), 1)
|
||||
ret = swizzle_rewrite(swizzle)
|
||||
self.assertEqual(ret.st_arg, base.reshape((64, 16))) # late rewrite
|
||||
@@ -1890,7 +1889,6 @@ class TestSwizzle(unittest.TestCase):
|
||||
add = r.reshape((16, 32, 1)) + UOp.const_with_shape(r.dtype, 0, (16, 32, 1))
|
||||
self.assertEqual(add.st, ShapeTracker.from_shape((16, 32, 1)))
|
||||
to_store = add.permute((1, 0, 2)).contiguous()
|
||||
to_store = graph_rewrite(to_store, remove_movement_ops)
|
||||
self.assertEqual(to_store.st, ShapeTracker.from_shape((32, 16, 1)))
|
||||
self.assertEqual(to_store.src[0].st, add.st.permute((1, 0, 2)))
|
||||
self.assertIs(to_store.src[0].op, Ops.VIEW)
|
||||
@@ -1926,8 +1924,6 @@ class TestView(unittest.TestCase):
|
||||
a = UOp(Ops.VIEW, dtypes.float, (UOp.new_buffer(Device.DEFAULT, 121, dtypes.float), UOp(Ops.EMPTY, dtypes.float)), st)
|
||||
b = a.pad(pad_arg:=((0, 0), (0, 0), (18, 0)))
|
||||
self.assertEqual(b.st, st.pad(pad_arg))
|
||||
# TODO: why does this help?
|
||||
b = graph_rewrite(b, remove_movement_ops)
|
||||
self.assertIs(b.base.src[1], UOp.const(dtypes.float, 0))
|
||||
|
||||
def test_partial_mask(self):
|
||||
|
||||
@@ -493,17 +493,14 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
|
||||
buf_uop.buffer.ref(1)
|
||||
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
|
||||
|
||||
remove_movement_ops = PatternMatcher([(UPat(GroupOp.Movement, name="x"), lambda x: x.base.view(unwrap(x.st))),])
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def create_schedule_with_vars(outs:List[UOp]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
|
||||
if len(outs:=dedup(x.base for x in outs if x.base.realized is None and x.base.op is not Ops.CONST)) == 0: return [], {}
|
||||
# create the big graph
|
||||
ctx = ScheduleContext()
|
||||
cache: Dict[UOp, UOp] = {}
|
||||
# to_uop is removing (many) of the movement ops
|
||||
for u in (big_graph:=UOp.sink(*(to_uop(x, ctx, cache) for x in outs))).src: ctx.realizes[u.buf_uop] = u
|
||||
big_graph = graph_rewrite(big_graph, remove_movement_ops+ops_folding+do_realize, ctx)
|
||||
big_graph = graph_rewrite(big_graph, ops_folding+do_realize, ctx)
|
||||
big_graph = graph_rewrite(big_graph, merge_bufs, ctx)
|
||||
# create the scheduler context
|
||||
graph_rewrite(big_graph, create_ctx, ctx)
|
||||
|
||||
@@ -272,7 +272,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@functools.cached_property
|
||||
def st(self) -> Optional[ShapeTracker]:
|
||||
if self.op is Ops.VIEW: return self.arg
|
||||
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
|
||||
# buffer ops can have a non contiguous shapetracker
|
||||
if self.op in GroupOp.Buffer and len(src_sts:=[unwrap(x.st) for x in self.src if x.op is Ops.VIEW]) != 0: return src_sts[0]
|
||||
if len(src_sts:=[x.st for x in self.src if x.st is not None]) == 0: return None
|
||||
@@ -343,16 +342,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if bitcast: return self.bitcast(dtype, allow_buffer_view)
|
||||
if self._device is not None and self._device.startswith("DISK"): raise RuntimeError("CAST isn't supported on DISK")
|
||||
if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
|
||||
# NOTE: we have to apply the movementops here, we can't use VIEW (yet)
|
||||
# TODO: move this to the scheduler
|
||||
ret = self.base.cast(dtype, bitcast)
|
||||
op_arg = []
|
||||
mop = self
|
||||
while mop.op in GroupOp.Movement:
|
||||
op_arg.append((mop.op, mop.arg))
|
||||
mop = mop.src[0]
|
||||
for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg)
|
||||
return ret
|
||||
return self.base.cast(dtype, bitcast).view(self.st)
|
||||
return UOp(Ops.CAST, dtype, (self,))
|
||||
def bitcast(self, dtype:DType, allow_buffer_view=True):
|
||||
if self.can_view() and allow_buffer_view:
|
||||
@@ -472,9 +462,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
# *** uop movement ops ***
|
||||
|
||||
@property
|
||||
def base(self) -> UOp:
|
||||
if self.op in GroupOp.Movement: return self.src[0].base
|
||||
return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
|
||||
def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
|
||||
def view(self, new_st:ShapeTracker) -> UOp:
|
||||
if self.st is None: return UOp(Ops.VIEW, self.dtype.base if not isinstance(self.dtype, ImageDType) else self.dtype, (self,), new_st)
|
||||
ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
|
||||
@@ -482,18 +470,12 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return ret.const_like(0)
|
||||
if new_st.contiguous and self.base.shape == new_st.shape: return self.base
|
||||
return ret
|
||||
|
||||
def _mop(self, op:Ops, arg):
|
||||
ret = UOp(op, self.dtype, (self,), arg)
|
||||
ret.st # pylint: disable=pointless-statement
|
||||
return ret
|
||||
|
||||
def reshape(self, arg:Tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg)
|
||||
def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg)
|
||||
def expand(self, arg:Tuple[sint, ...]): return self._mop(Ops.EXPAND, arg)
|
||||
def permute(self, arg:Tuple[sint, ...]): return self._mop(Ops.PERMUTE, arg)
|
||||
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg)
|
||||
def stride(self, arg:Tuple[sint, ...]): return self._mop(Ops.STRIDE, arg)
|
||||
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
|
||||
def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(unwrap(self.st).pad(arg))
|
||||
def expand(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).expand(arg))
|
||||
def permute(self, arg:Tuple[int, ...]): return self.view(unwrap(self.st).permute(arg))
|
||||
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(unwrap(self.st).shrink(arg))
|
||||
def stride(self, arg:Tuple[int, ...]): return self.view(unwrap(self.st).stride(arg))
|
||||
|
||||
# *** uop Buffer stuff ***
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import functools
|
||||
from typing import Tuple, List, Optional, Dict, Set, Callable
|
||||
from typing import Tuple, List, Optional, Dict, Set
|
||||
from tinygrad.helpers import merge_dicts, getenv
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
from tinygrad.dtype import dtypes
|
||||
@@ -115,8 +115,3 @@ class ShapeTracker:
|
||||
def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
|
||||
if getenv("MERGE_VIEW", 1) and (new_view := self.views[-1].reshape(new_shape)) is not None: return ShapeTracker(self.views[0:-1] + (new_view,))
|
||||
return ShapeTracker(self.views + (View.create(new_shape), ))
|
||||
|
||||
def mop(self, op, arg): return mops[op](self, arg)
|
||||
|
||||
mops: Dict[Ops, Callable] = {Ops.RESHAPE: ShapeTracker.reshape, Ops.PERMUTE: ShapeTracker.permute, Ops.EXPAND: ShapeTracker.expand,
|
||||
Ops.SHRINK: ShapeTracker.shrink, Ops.STRIDE: ShapeTracker.stride, Ops.PAD: ShapeTracker.pad}
|
||||
|
||||
Reference in New Issue
Block a user