split schedule to view_left and view_right [pr] (#7077)

* split schedule to view_left and view_right [pr]

* move valid
This commit is contained in:
qazal
2024-10-16 03:39:38 +03:00
committed by GitHub
parent 8601115976
commit fb29de6cc3
3 changed files with 19 additions and 17 deletions

View File

@@ -14,7 +14,7 @@ from tinygrad.shape.view import View
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, reduceop_fusor, st_fixup
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
from tinygrad.engine.realize import CompiledRunner, run_schedule
from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
from test.helpers import ast_const, is_dtype_supported, Context, timeit
@@ -1614,7 +1614,7 @@ class TestIndexing(unittest.TestCase):
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
rsink = graph_rewrite(sink, reduceop_fusor)
rsink = graph_rewrite(sink, view_right)
self.assertEqual(rsink.key, sink.key)
def test_simple_store_reshape(self):
@@ -1624,7 +1624,7 @@ class TestIndexing(unittest.TestCase):
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
r = r + ast_const(dtypes.int, 2, ())
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink = graph_rewrite(sink, reduceop_fusor)
rsink = graph_rewrite(sink, view_right)
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
verify_ast(sink)
@@ -1635,7 +1635,7 @@ class TestIndexing(unittest.TestCase):
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
rsink = graph_rewrite(sink, reduceop_fusor)
rsink = graph_rewrite(sink, view_right)
verify_ast(sink)
self.assertEqual(sink.key, rsink.key)
@@ -1646,7 +1646,7 @@ class TestIndexing(unittest.TestCase):
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
for _ in range(24): r = r + ast_const(dtypes.int, 2, ())
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink, et = timeit(graph_rewrite, sink, reduceop_fusor)
rsink, et = timeit(graph_rewrite, sink, view_right)
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
verify_ast(sink)
@@ -1664,7 +1664,7 @@ class TestIndexing(unittest.TestCase):
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
for _ in range(sz): r = r + ast_const(dtypes.int, 2, ())
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink, et = timeit(graph_rewrite, sink, reduceop_fusor)
rsink, et = timeit(graph_rewrite, sink, view_right)
with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
verify_ast(rsink)
tms.append(et)
@@ -1693,7 +1693,7 @@ class TestIndexing(unittest.TestCase):
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
x8,
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
sink = graph_rewrite(sink, reduceop_fusor)
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
# verify output
k = Kernel(sink)
p = k.to_program()
@@ -1716,7 +1716,7 @@ class TestIndexing(unittest.TestCase):
alu = swizzle_r+const
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
# graph rewrite
sink = graph_rewrite(sink, reduceop_fusor)
sink = graph_rewrite(sink, view_right)
# verify output
k = Kernel(sink)
p = k.to_program()
@@ -1739,7 +1739,7 @@ class TestIndexing(unittest.TestCase):
alu = UOp(UOps.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(UOps.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501
# graph rewrite
sink = graph_rewrite(sink, reduceop_fusor)
sink = graph_rewrite(sink, view_right)
# verify output
k = Kernel(sink)
p = k.to_program()
@@ -1755,7 +1755,7 @@ class TestIndexing(unittest.TestCase):
UOp(UOps.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
# there's an EXPAND pushing through the REDUCE_AXIS
self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape))
ret = graph_rewrite(swizzle, reduceop_fusor)
ret = graph_rewrite(graph_rewrite(swizzle, view_left), view_right)
# EXPAND is rewritten
self.assertEqual(prod(ret.st.shape), prod(ret.src[0].st.shape))
# and pushed to the LOAD

View File

@@ -8,7 +8,7 @@ from tinygrad.dtype import dtypes, DType, PtrDType
from tinygrad.device import Buffer, Device
from tinygrad.ops import UOps, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule, reduceop_fusor
from tinygrad.engine.schedule import create_schedule, enumerate_bufs
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
@@ -441,7 +441,7 @@ class TestIndexingOrdering(unittest.TestCase):
class TestUPatHelpers(unittest.TestCase):
def test_location(self):
self.assertEqual(sym.patterns[-1][0].location[0].split("/")[-1], "uopgraph.py")
self.assertEqual(reduceop_fusor.patterns[0][0].location[0].split("/")[-1], "schedule.py")
self.assertEqual(enumerate_bufs.patterns[0][0].location[0].split("/")[-1], "schedule.py")
self.assertEqual(spec.patterns[0][0].location[0].split("/")[-1], "ops.py")
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
test_upat = UPat(UOps.CONST, dtypes.bool)