mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
pad reduceops to the max of each dimension (#5889)
* early verify * pad reduceops to the max of each dim * remove the function
This commit is contained in:
@@ -8,7 +8,7 @@ from typing import List, Optional, Union, cast
|
||||
from tinygrad import nn, dtypes
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps
|
||||
from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps, verify_lazyop
|
||||
from tinygrad.helpers import DEBUG, FUSE_ARANGE, flatten, getenv
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
@@ -1270,9 +1270,10 @@ class TestIndexing(unittest.TestCase):
|
||||
def check_schedule(self, xt:Tensor, cnt:int):
|
||||
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
|
||||
s = xt.schedule()
|
||||
kernel_cnt = len([si for si in s if si.ast.op is MetaOps.KERNEL])
|
||||
kernels = [si for si in s if si.ast.op is MetaOps.KERNEL]
|
||||
for si in kernels: verify_lazyop(si.ast)
|
||||
run_schedule(s)
|
||||
if FUSE_ARANGE: self.assertEqual(kernel_cnt, cnt)
|
||||
if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
|
||||
|
||||
def test_simple_indexing(self):
|
||||
X = Tensor.randn(10, 10).realize()
|
||||
@@ -1281,11 +1282,10 @@ class TestIndexing(unittest.TestCase):
|
||||
self.check_schedule(xt, 2)
|
||||
np.testing.assert_equal(xt.numpy(), X.numpy()[idxs.numpy()])
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_simple_indexing_alt(self):
|
||||
X = Tensor.arange(16).reshape(4, 4)
|
||||
xt = X[[1, 2], [1, 2]]
|
||||
self.check_schedule(xt, 5)
|
||||
self.check_schedule(xt, 3)
|
||||
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])
|
||||
|
||||
@unittest.expectedFailure
|
||||
@@ -1302,11 +1302,10 @@ class TestIndexing(unittest.TestCase):
|
||||
self.check_schedule(xt, 6)
|
||||
np.testing.assert_equal(xt.numpy(), 6)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_advanced_simple_indexing_combined(self):
|
||||
X = Tensor.arange(16).reshape(4, 4)
|
||||
xt = X[1:2, [1, 2]]
|
||||
self.check_schedule(xt, 4)
|
||||
self.check_schedule(xt, 2)
|
||||
np.testing.assert_equal(xt.numpy(), np.arange(16).reshape(4, 4)[1:2, [1, 2]])
|
||||
|
||||
def test_push_through_reshape(self):
|
||||
|
||||
Reference in New Issue
Block a user