pad reduceops to the max of each dimension (#5889)

* early verify

* pad reduceops to the max of each dim

* remove the function
This commit is contained in:
qazal
2024-08-03 19:03:30 +08:00
committed by GitHub
parent 65fa86901a
commit 56ef9e453e
2 changed files with 15 additions and 8 deletions

View File

@@ -8,7 +8,7 @@ from typing import List, Optional, Union, cast
from tinygrad import nn, dtypes
from tinygrad.device import Device
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps
from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps, verify_lazyop
from tinygrad.helpers import DEBUG, FUSE_ARANGE, flatten, getenv
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import create_schedule
@@ -1270,9 +1270,10 @@ class TestIndexing(unittest.TestCase):
def check_schedule(self, xt:Tensor, cnt:int):
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
s = xt.schedule()
kernel_cnt = len([si for si in s if si.ast.op is MetaOps.KERNEL])
kernels = [si for si in s if si.ast.op is MetaOps.KERNEL]
for si in kernels: verify_lazyop(si.ast)
run_schedule(s)
if FUSE_ARANGE: self.assertEqual(kernel_cnt, cnt)
if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
def test_simple_indexing(self):
X = Tensor.randn(10, 10).realize()
@@ -1281,11 +1282,10 @@ class TestIndexing(unittest.TestCase):
self.check_schedule(xt, 2)
np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()])
@unittest.expectedFailure
def test_simple_indexing_alt(self):
X = Tensor.arange(16).reshape(4, 4)
xt = X[[1, 2], [1, 2]]
self.check_schedule(xt, 5)
self.check_schedule(xt, 3)
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])
@unittest.expectedFailure
@@ -1302,11 +1302,10 @@ class TestIndexing(unittest.TestCase):
self.check_schedule(xt, 6)
np.testing.assert_equal(xt.numpy(), 6)
@unittest.expectedFailure
def test_advanced_simple_indexing_combined(self):
X = Tensor.arange(16).reshape(4, 4)
xt = X[1:2, [1, 2]]
self.check_schedule(xt, 4)
self.check_schedule(xt, 2)
np.testing.assert_equal(xt.numpy(), np.arange(16).reshape(4, 4)[1:2, [1, 2]])
def test_push_through_reshape(self):