pad reduceops to the max of each dimension (#5889)
* early verify
* pad reduceops to the max of each dim
* remove the function
@@ -8,7 +8,7 @@ from typing import List, Optional, Union, cast
 from tinygrad import nn, dtypes
 from tinygrad.device import Device
 from tinygrad.tensor import Tensor
-from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps
+from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps, verify_lazyop
 from tinygrad.helpers import DEBUG, FUSE_ARANGE, flatten, getenv
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import create_schedule
@@ -1270,9 +1270,10 @@ class TestIndexing(unittest.TestCase):
   def check_schedule(self, xt:Tensor, cnt:int):
     with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
       s = xt.schedule()
-      kernel_cnt = len([si for si in s if si.ast.op is MetaOps.KERNEL])
+      kernels = [si for si in s if si.ast.op is MetaOps.KERNEL]
+      for si in kernels: verify_lazyop(si.ast)
       run_schedule(s)
-      if FUSE_ARANGE: self.assertEqual(kernel_cnt, cnt)
+      if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)

   def test_simple_indexing(self):
     X = Tensor.randn(10, 10).realize()
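
The early-verify step added to check_schedule can be reproduced outside the test class. A minimal sketch, assuming the tinygrad revision from this commit, where verify_lazyop is exported from tinygrad.ops and run_schedule from tinygrad.engine.realize; the tensor is illustrative:

from tinygrad import Tensor
from tinygrad.ops import MetaOps, verify_lazyop
from tinygrad.engine.realize import run_schedule

t = (Tensor.arange(16).reshape(4, 4) + 1).sum()
s = t.schedule()                                           # list of ScheduleItems
kernels = [si for si in s if si.ast.op is MetaOps.KERNEL]  # keep only compute kernels
for si in kernels: verify_lazyop(si.ast)                   # fails fast if a kernel AST has inconsistent shapes
run_schedule(s)                                            # only reached if every AST verified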
@@ -1281,11 +1282,10 @@ class TestIndexing(unittest.TestCase):
     self.check_schedule(xt, 2)
     np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()])

-  @unittest.expectedFailure
   def test_simple_indexing_alt(self):
     X = Tensor.arange(16).reshape(4, 4)
     xt = X[[1, 2], [1, 2]]
-    self.check_schedule(xt, 5)
+    self.check_schedule(xt, 3)
     np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])

   @unittest.expectedFailure
@@ -1302,11 +1302,10 @@ class TestIndexing(unittest.TestCase):
     self.check_schedule(xt, 6)
     np.testing.assert_equal(xt.numpy(), 6)

-  @unittest.expectedFailure
   def test_advanced_simple_indexing_combined(self):
     X = Tensor.arange(16).reshape(4, 4)
     xt = X[1:2, [1, 2]]
-    self.check_schedule(xt, 4)
+    self.check_schedule(xt, 2)
     np.testing.assert_equal(xt.numpy(), np.arange(16).reshape(4, 4)[1:2, [1, 2]])

   def test_push_through_reshape(self):
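
Both indexing tests above lose their @unittest.expectedFailure and need fewer kernels (5 to 3, 4 to 2), presumably because the arange-built one-hot mask and the gather reduce now end up with matching (padded) reduce shapes and can fuse. A numpy sketch of this arange-style lowering of X[[1, 2], [1, 2]], illustrative only, not tinygrad code:

import numpy as np

X = np.arange(16).reshape(4, 4)
idx = np.array([1, 2])
onehot = (np.arange(4)[None, :] == idx[:, None]).astype(X.dtype)  # one-hot rows built from an arange compare
rows = onehot @ X                                                 # first reduce: gather rows 1 and 2
out = (rows * onehot).sum(axis=1)                                 # second reduce: gather one column per row
assert (out == X[[1, 2], [1, 2]]).all()                           # [5, 10]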
@@ -131,10 +131,18 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
     rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
     return LazyOp(MetaOps.KERNEL, (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), )), [x.base for x in out.srcs], {}, []
   if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: return LazyOp(out.op, (), out.arg), [x.base for x in out.srcs], {}, []
-  # unify the kernel dims
+  # push through all movementops between reduceops
   reduce_info: Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]] = {}
   seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], None] = {}
   for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
+  # pad all reduceops to the max of each dimension
+  shape_dims = [sorted(dedup(dims)) for dims in zip(*[input_st.shape for input_st,_ in reduce_info.values()])]
+  for i,dims in enumerate(shape_dims):
+    if len(dims) == 1 or (len(dims) == 2 and dims[0] == 1): continue
+    for (r,view),(input_st,axis) in reduce_info.items():
+      if (dim:=input_st.shape[i]) > 1 and dim != max(dims):
+        input_st = input_st.pad(((0, 0),)*i+((0, max(dims)-dim),))
+        reduce_info[(r, view)] = (input_st, axis)
   # create the stores
   var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
   assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
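
The padding rule in the new block can be traced on plain shape tuples. A standalone sketch of the shape arithmetic only; the real code pads ShapeTrackers (whose masked views read the padded region as zeros), and dedup here is a stand-in for tinygrad.helpers.dedup:

def dedup(xs): return list(dict.fromkeys(xs))  # stand-in for tinygrad.helpers.dedup

shapes = [(4, 27, 32), (4, 32, 32)]  # two reduce inputs in one kernel that disagree in dim 1
shape_dims = [sorted(dedup(dims)) for dims in zip(*shapes)]  # [[4], [27, 32], [32]]
padded = []
for shape in shapes:
  shape = list(shape)
  for i, dims in enumerate(shape_dims):
    if len(dims) == 1 or (len(dims) == 2 and dims[0] == 1): continue  # dim already uniform, or a broadcasted 1
    if shape[i] > 1 and shape[i] != max(dims): shape[i] = max(dims)   # pad the short dim up to the max
  padded.append(tuple(shape))
assert padded == [(4, 32, 32), (4, 32, 32)]  # both reduces now run at one kernel shape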