mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
replace hardcoded ast with tensors in TestSwizzle [pr] (#9401)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# this will be the new test_ops for the next level
|
||||
# schedule confirms the right things are capable of fusing
|
||||
# NOTE: this has overlap with external_test_opt.py
|
||||
# ruff: noqa: E501
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
@@ -12,7 +11,6 @@ from tinygrad import nn, dtypes, Device, Tensor
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp
|
||||
from tinygrad.codegen.symbolic import symbolic_simple
|
||||
from tinygrad.spec import type_verify, shape_spec
|
||||
@@ -1958,47 +1956,14 @@ class TestSwizzle(unittest.TestCase):
|
||||
t_np = (x.numpy()*y.numpy()).sum(axis=(0, 2)).reshape(1, 4, 1).transpose(0, 2, 1)+z.numpy()
|
||||
np.testing.assert_allclose(t.numpy(), t_np, atol=1e-6, rtol=1e-3)
|
||||
|
||||
@unittest.skip("this swizzle can't be decided after the ADD")
|
||||
@unittest.skip("TODO: this swizzle isn't resolvable when there's a mask")
|
||||
def test_swizzle_failure_permute(self):
|
||||
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(20, 65), src=(UOp(Ops.DEVICE, arg="METAL"),)),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 65), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(8, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)),
|
||||
x10:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
x12:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(Ops.CONST, dtypes.float, arg=1.0, src=()),
|
||||
x15:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),)),)),
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
x12,
|
||||
UOp(Ops.CONST, dtypes.float, arg=0.0003418803389649838, src=()),
|
||||
x15,)),)),
|
||||
x6,)),)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
x12,
|
||||
UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()),
|
||||
x15,)),
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(2, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)),
|
||||
x10,)),
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(1, 89), offset=44, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float, arg=(4, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(65, 45, 90), strides=(1, 0, 65), offset=0, mask=((0, 65), (0, 45), (0, 45)), contiguous=False), View(shape=(65, 4094), strides=(4050, 1), offset=0, mask=((0, 65), (0, 4050)), contiguous=False), View(shape=(1, 65, 46, 89), strides=(0, 4094, 89, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),))
|
||||
ret = swizzle_rewrite(sink)
|
||||
self.assertEqual(swizzle_cnt(ret), 0)
|
||||
a = Tensor.empty(45,65).T.reshape(65,1,45).pad((None,None,(0,45))).expand(65,45,90)
|
||||
b = Tensor.empty(45,65)
|
||||
a_reduce = a.sum(axis=(2,), keepdim=True).sum(axis=(1,))
|
||||
b_reduce = b.sum(axis=(0,))
|
||||
t = a_reduce+b_reduce
|
||||
with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1))
|
||||
|
||||
def store_val(si:ScheduleItem):
  """Return `si.ast.src[0].src[2]` for a schedule item.

  NOTE(review): presumably `si.ast` is a SINK whose first source is a STORE
  laid out as (buffer, view, value), making this the stored value - confirm
  against the scheduler before relying on it.
  """
  first_src = si.ast.src[0]
  return first_src.src[2]
|
||||
# UPat built for Ops.CONST with arg=0, i.e. a pattern for a constant-zero UOp
zero_pm = UPat(Ops.CONST, arg=0)
|
||||
|
||||
Reference in New Issue
Block a user