disable PADTO on upcasted axis (#4444)

fixed test_failure_31. PADTO on an upcasted axis is at best a no-op, and might fail in edge cases.
chenyu
2024-05-05 21:52:03 -04:00
committed by GitHub
parent 709410071c
commit afe020710d
3 changed files with 25 additions and 5 deletions
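
To illustrate what the commit message means in practice, here is a minimal sketch built from the opts used in the new test below. The import paths for Opt, OptOps, KernelOptError, and helper_linearizer_opt are assumptions, not something this commit specifies.

```python
# Minimal sketch of the new behavior; import locations are assumed, not taken from this commit.
from tinygrad import Tensor
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError  # assumed module path
from test.test_linearizer import helper_linearizer_opt           # assumed helper location

a, b = Tensor.rand(4, 4), Tensor.rand(4, 4)

# PADTO on its own is still allowed.
helper_linearizer_opt(a@b, [[Opt(OptOps.PADTO, 2, 8)]])

# Fully upcasting axis 0 (amt=0 means the whole axis) moves it into the upcasted range,
# so a later PADTO that targets an upcasted axis is now rejected instead of silently padding.
try:
  helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 0, 0), Opt(OptOps.PADTO, 2, 8)]])
except KernelOptError:
  print("cannot pad upcasted")
```

Before this change the second opt sequence was accepted, and the padded-but-upcasted axis could produce wrong results in edge cases such as test_failure_31.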

View File

@@ -11,7 +11,7 @@ from tinygrad.shape.symbolic import MulNode, Variable, NumNode, Node
from tinygrad.tensor import Tensor
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule, lower_schedule
-from tinygrad.helpers import prod, Context, getenv
+from tinygrad.helpers import prod, Context, getenv, CI
from tinygrad.dtype import DType, dtypes
from tinygrad.codegen.uops import UOpGraph
@@ -787,7 +787,7 @@ class TestKernelOpts(unittest.TestCase):
    ], apply_tc=True, atol=atol, rtol=rtol)

  def test_padto_matmul(self):
-    if Device.DEFAULT in ["CUDA", "RHIP"]: self.skipTest("super slow on CUDA and RHIP because of the big grid dims")
+    if CI and Device.DEFAULT in ["CUDA", "RHIP"]: self.skipTest("super slow on CUDA and RHIP because of the big grid dims")
    N = 17 * 17
    Tensor.manual_seed(289)
    a = Tensor.rand(N, N)
@@ -802,6 +802,25 @@ class TestKernelOpts(unittest.TestCase):
      [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.PADTO, 1, 32), Opt(OptOps.UPCAST, 0, 2), Opt(OptOps.UPCAST, 1, 2),],
    ])

+  def test_padto_upcasted_not_ok(self):
+    N = 4
+    a = Tensor.rand(N, N)
+    b = Tensor.rand(N, N)
+    helper_linearizer_opt(a@b, [
+      [Opt(OptOps.UPCAST, 0, 0)],
+      [Opt(OptOps.UPCAST, 1, 0)],
+      [Opt(OptOps.UNROLL, 0, 0)],
+      [Opt(OptOps.PADTO, 0, 8)],
+      [Opt(OptOps.PADTO, 1, 8)],
+      [Opt(OptOps.PADTO, 2, 8)],
+    ])
+    with self.assertRaises(KernelOptError):
+      helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 0, 0), Opt(OptOps.PADTO, 2, 8)]])
+    with self.assertRaises(KernelOptError):
+      helper_linearizer_opt(a@b, [[Opt(OptOps.UPCAST, 1, 0), Opt(OptOps.PADTO, 2, 8)]])
+    with self.assertRaises(KernelOptError):
+      helper_linearizer_opt(a@b, [[Opt(OptOps.UNROLL, 0, 0), Opt(OptOps.PADTO, 2, 8)]])

  def test_padto_sum_ok(self):
    N = 18 * 18
    # NOTE: this setup prevents 17 * 17 contiguous merged into one dimension

View File

@@ -22,7 +22,7 @@ def helper_test_lin(lin: Linearizer, opts, failed_platforms, rtol=1e-2, atol=1e-
      lin.apply_opt(opt)
    except KernelOptError:
      # it's considered fixed if we invalidated the opts
-      assert Device.DEFAULT not in failed_platforms
+      assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
      return
  compare_result = compare_linearizer(lin, rtol=rtol, atol=atol)
@@ -234,7 +234,7 @@ class TestLinearizerFailures(unittest.TestCase):
  def test_failure_31(self):
    ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.EXP2, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.SUB, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1.4426950408889634, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None),), arg=((3,), dtypes.float)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),))))
    opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32)]
-    helper_test_lin(Linearizer(ast), opts=opts, failed_platforms=["METAL", "GPU", "HSA", "CUDA", "CLANG", "LLVM"])
+    helper_test_lin(Linearizer(ast), opts=opts, failed_platforms=[])

if __name__ == '__main__':
  unittest.main()

View File

@@ -493,8 +493,9 @@ class Kernel:
      self.dont_use_locals = True
    elif opt.op is OptOps.PADTO:
      check(not self.vars, "does not work with symbolic shape")
+      check(axis < self.shape_len - self.upcasted, "cannot pad upcasted")
      # ok to pad SUM if all parent ops have f(0) = 0
-      if self.first_reduce <= axis < self.shape_len - self.upcasted:
+      if self.first_reduce <= axis:
        check(self.reduceop.op is ReduceOps.SUM and all(op.op not in UNSAFE_PAD_OPS for ops in self.reduceop.src for op in ops.lazyops), "cannot pad")
      padded = False
      for i,st in enumerate(self.sts):
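
For readers skimming the kernel.py hunk, here is a standalone sketch of the guard it adds. KernelOptError and check mirror the helpers in kernel.py, while padto_allowed is a simplified stand-in for the real Kernel.apply_opt bookkeeping, not the actual method.

```python
# Standalone sketch of the PADTO guard added above; padto_allowed is a stand-in,
# not the real Kernel.apply_opt.
class KernelOptError(Exception): pass

def check(cond: bool, msg: str = ""):
  if not cond: raise KernelOptError(msg)

def padto_allowed(axis: int, shape_len: int, upcasted: int) -> None:
  # upcasted axes sit at the end of the shape, in [shape_len - upcasted, shape_len);
  # padding one of them is at best a no-op, so it is rejected outright
  check(axis < shape_len - upcasted, "cannot pad upcasted")

padto_allowed(axis=2, shape_len=3, upcasted=0)    # ok: axis 2 is a regular axis
try:
  padto_allowed(axis=2, shape_len=3, upcasted=1)  # axis 2 now falls in the upcasted range
except KernelOptError as e:
  print(e)  # -> cannot pad upcasted
```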