pad reduceops to the max of each dimension (#5889)

* early verify

* pad reduceops to the max of each dim

* remove the function
Author: qazal
Date: 2024-08-03 19:03:30 +08:00
Committed by: GitHub
Parent: 65fa86901a
Commit: 56ef9e453e

2 changed files with 15 additions and 8 deletions
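
When FUSE_ARANGE folds the arange kernels of fancy indexing into the gather reduce, the fused reduceops can end up with different sizes along a kernel dimension; instead of bailing out, the scheduler now pads each reduce input up to the max of that dimension so they can share one kernel shape. A hedged repro sketch, assembled from the updated tests below (it assumes the tinygrad API exactly as the tests use it; the kernel count is the one the new assertion checks):

from tinygrad import Tensor
from tinygrad.helpers import Context
from tinygrad.ops import MetaOps

X = Tensor.arange(16).reshape(4, 4)
with Context(FUSE_ARANGE=1):
  sched = X[[1, 2], [1, 2]].schedule()
# count only real kernels, the way check_schedule below does
kernels = [si for si in sched if si.ast.op is MetaOps.KERNEL]
print(len(kernels))  # 3 after this change (was 5), per test_simple_indexing_alt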

test/test_schedule.py

@@ -8,7 +8,7 @@ from typing import List, Optional, Union, cast
 from tinygrad import nn, dtypes
 from tinygrad.device import Device
 from tinygrad.tensor import Tensor
-from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps
+from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps, verify_lazyop
 from tinygrad.helpers import DEBUG, FUSE_ARANGE, flatten, getenv
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import create_schedule
@@ -1270,9 +1270,10 @@ class TestIndexing(unittest.TestCase):
   def check_schedule(self, xt:Tensor, cnt:int):
     with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
       s = xt.schedule()
-      kernel_cnt = len([si for si in s if si.ast.op is MetaOps.KERNEL])
+      kernels = [si for si in s if si.ast.op is MetaOps.KERNEL]
+      for si in kernels: verify_lazyop(si.ast)
       run_schedule(s)
-      if FUSE_ARANGE: self.assertEqual(kernel_cnt, cnt)
+      if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
 
   def test_simple_indexing(self):
     X = Tensor.randn(10, 10).realize()
@@ -1281,11 +1282,10 @@ class TestIndexing(unittest.TestCase):
     self.check_schedule(xt, 2)
     np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()])
 
-  @unittest.expectedFailure
   def test_simple_indexing_alt(self):
     X = Tensor.arange(16).reshape(4, 4)
     xt = X[[1, 2], [1, 2]]
-    self.check_schedule(xt, 5)
+    self.check_schedule(xt, 3)
     np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])
 
   @unittest.expectedFailure
@@ -1302,11 +1302,10 @@ class TestIndexing(unittest.TestCase):
     self.check_schedule(xt, 6)
     np.testing.assert_equal(xt.numpy(), 6)
 
-  @unittest.expectedFailure
   def test_advanced_simple_indexing_combined(self):
     X = Tensor.arange(16).reshape(4, 4)
     xt = X[1:2, [1, 2]]
-    self.check_schedule(xt, 4)
+    self.check_schedule(xt, 2)
     np.testing.assert_equal(xt.numpy(), np.arange(16).reshape(4, 4)[1:2, [1, 2]])
 
   def test_push_through_reshape(self):
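
The first commit bullet, "early verify", is the test-side half of this diff: check_schedule now runs verify_lazyop over every scheduled kernel AST before executing it, so a kernel whose shapes fail to unify is caught at schedule time rather than later in codegen. A minimal standalone sketch of the same check (hedged: it assumes verify_lazyop raises on a malformed AST, as its use in check_schedule suggests):

from tinygrad import Tensor
from tinygrad.ops import MetaOps, verify_lazyop

s = Tensor.arange(16).reshape(4, 4).sum().schedule()
for si in s:
  if si.ast.op is MetaOps.KERNEL:
    verify_lazyop(si.ast)  # no error expected for a well-formed schedule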

tinygrad/engine/schedule.py

@@ -131,10 +131,18 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
     rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
     return LazyOp(MetaOps.KERNEL, (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), )), [x.base for x in out.srcs], {}, []
   if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: return LazyOp(out.op, (), out.arg), [x.base for x in out.srcs], {}, []
-  # unify the kernel dims
+  # push through all movementops between reduceops
   reduce_info: Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]] = {}
   seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], None] = {}
   for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
+  # pad all reduceops to the max of each dimension
+  shape_dims = [sorted(dedup(dims)) for dims in zip(*[input_st.shape for input_st,_ in reduce_info.values()])]
+  for i,dims in enumerate(shape_dims):
+    if len(dims) == 1 or (len(dims) == 2 and dims[0] == 1): continue
+    for (r,view),(input_st,axis) in reduce_info.items():
+      if (dim:=input_st.shape[i]) > 1 and dim != max(dims):
+        input_st = input_st.pad(((0, 0),)*i+((0, max(dims)-dim),))
+        reduce_info[(r, view)] = (input_st, axis)
   # create the stores
   var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
   assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
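
For illustration, the dimension-unification rule added above, restated over plain shape tuples. This is a sketch only: dedup mirrors tinygrad.helpers.dedup, and the real code pads ShapeTrackers (whose masks keep the padded region out of the reduce) rather than bare tuples:

def dedup(xs): return list(dict.fromkeys(xs))  # keep first occurrence, like tinygrad.helpers.dedup

def pad_reduce_shapes(shapes):
  # shapes: one input shape per fused reduceop, all of the same rank
  shape_dims = [sorted(dedup(dims)) for dims in zip(*shapes)]
  padded = [list(s) for s in shapes]
  for i, dims in enumerate(shape_dims):
    # a single size, or a (1, n) broadcast pair, already unifies this dim
    if len(dims) == 1 or (len(dims) == 2 and dims[0] == 1): continue
    for s in padded:
      if s[i] > 1 and s[i] != max(dims): s[i] = max(dims)  # pad the smaller dim up to the max
  return [tuple(s) for s in padded]

# two reduces sharing a kernel: dim 1 mixes sizes 5 and 7, so 5 pads up to 7
assert pad_reduce_shapes([(4, 7, 1), (4, 5, 1)]) == [(4, 7, 1), (4, 7, 1)]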