replace hardcoded ast with tensors in TestSwizzle [pr] (#9401)

This commit is contained in:
qazal
2025-03-10 20:33:57 +02:00
committed by GitHub
parent 796c3bbb23
commit 59dfb234eb

View File

@@ -1,7 +1,6 @@
# this will be the new test_ops for the next level
# schedule confirms the right things are capable of fusing
# NOTE: this has overlap with external_test_opt.py
# ruff: noqa: E501
import unittest
import numpy as np
@@ -12,7 +11,6 @@ from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType, ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.spec import type_verify, shape_spec
@@ -1958,47 +1956,14 @@ class TestSwizzle(unittest.TestCase):
t_np = (x.numpy()*y.numpy()).sum(axis=(0, 2)).reshape(1, 4, 1).transpose(0, 2, 1)+z.numpy()
np.testing.assert_allclose(t.numpy(), t_np, atol=1e-6, rtol=1e-3)
# NOTE(review): this span is a rendered diff hunk with its +/- markers and indentation lost.
# Going by the hunk header (47 old lines -> 14 new lines), the first skip decorator and the
# hardcoded UOp sink (through the swizzle_cnt assertion) appear to be the removed side, while
# the second skip decorator and the Tensor-based body at the end appear to be the added side
# -- confirm against the commit before relying on this.
@unittest.skip("this swizzle can't be decided after the ADD")
@unittest.skip("TODO: this swizzle isn't resolvable when there's a mask")
def test_swizzle_failure_permute(self):
# Hand-written AST: a SINK wrapping one STORE whose value adds two REDUCE_AXIS(ADD, axis 0) results.
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(20, 65), src=(UOp(Ops.DEVICE, arg="METAL"),)),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 65), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(8, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)),
x10:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
x12:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.CONST, dtypes.float, arg=1.0, src=()),
x15:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),)),)),
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
x12,
UOp(Ops.CONST, dtypes.float, arg=0.0003418803389649838, src=()),
x15,)),)),
x6,)),)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
x12,
UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()),
x15,)),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(2, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)),
x10,)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(1, 89), offset=44, mask=None, contiguous=False),)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(4, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(65, 45, 90), strides=(1, 0, 65), offset=0, mask=((0, 65), (0, 45), (0, 45)), contiguous=False), View(shape=(65, 4094), strides=(4050, 1), offset=0, mask=((0, 65), (0, 4050)), contiguous=False), View(shape=(1, 65, 46, 89), strides=(0, 4094, 89, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),))
# The rewritten sink is asserted to have a swizzle count of zero.
ret = swizzle_rewrite(sink)
self.assertEqual(swizzle_cnt(ret), 0)
# Tensor-based construction: a (45,65) tensor is transposed, reshaped to (65,1,45),
# padded by 45 at the end of the last axis, then expanded to (65,45,90).
a = Tensor.empty(45,65).T.reshape(65,1,45).pad((None,None,(0,45))).expand(65,45,90)
b = Tensor.empty(45,65)
# `a` is summed over axis 2 (kept) and then axis 1; `b` is summed over axis 0.
a_reduce = a.sum(axis=(2,), keepdim=True).sum(axis=(1,))
b_reduce = b.sum(axis=(0,))
t = a_reduce+b_reduce
# NOTE(review): the second argument of check_schedule is presumably the expected kernel count -- confirm.
with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1))
def store_val(si:ScheduleItem):
  """Return `src[2]` of the first source of the schedule item's AST.

  NOTE(review): presumably this is the value operand of the first STORE under the SINK -- confirm.
  """
  first_src = si.ast.src[0]
  return first_src.src[2]
# pattern for a CONST op whose arg is 0
zero_pm = UPat(Ops.CONST, arg=0)