mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
split schedule to view_left and view_right [pr] (#7077)
* split schedule to view_left and view_right [pr] * move valid
This commit is contained in:
@@ -14,7 +14,7 @@ from tinygrad.shape.view import View
|
||||
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, reduceop_fusor, st_fixup
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
||||
from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
|
||||
from test.helpers import ast_const, is_dtype_supported, Context, timeit
|
||||
@@ -1614,7 +1614,7 @@ class TestIndexing(unittest.TestCase):
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
|
||||
rsink = graph_rewrite(sink, reduceop_fusor)
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
self.assertEqual(rsink.key, sink.key)
|
||||
|
||||
def test_simple_store_reshape(self):
|
||||
@@ -1624,7 +1624,7 @@ class TestIndexing(unittest.TestCase):
|
||||
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
||||
r = r + ast_const(dtypes.int, 2, ())
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
rsink = graph_rewrite(sink, reduceop_fusor)
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
|
||||
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
|
||||
verify_ast(sink)
|
||||
@@ -1635,7 +1635,7 @@ class TestIndexing(unittest.TestCase):
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
|
||||
rsink = graph_rewrite(sink, reduceop_fusor)
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
verify_ast(sink)
|
||||
self.assertEqual(sink.key, rsink.key)
|
||||
|
||||
@@ -1646,7 +1646,7 @@ class TestIndexing(unittest.TestCase):
|
||||
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
||||
for _ in range(24): r = r + ast_const(dtypes.int, 2, ())
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
rsink, et = timeit(graph_rewrite, sink, reduceop_fusor)
|
||||
rsink, et = timeit(graph_rewrite, sink, view_right)
|
||||
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
|
||||
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
|
||||
verify_ast(sink)
|
||||
@@ -1664,7 +1664,7 @@ class TestIndexing(unittest.TestCase):
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
for _ in range(sz): r = r + ast_const(dtypes.int, 2, ())
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
rsink, et = timeit(graph_rewrite, sink, reduceop_fusor)
|
||||
rsink, et = timeit(graph_rewrite, sink, view_right)
|
||||
with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
|
||||
verify_ast(rsink)
|
||||
tms.append(et)
|
||||
@@ -1693,7 +1693,7 @@ class TestIndexing(unittest.TestCase):
|
||||
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||
x8,
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
|
||||
sink = graph_rewrite(sink, reduceop_fusor)
|
||||
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
|
||||
# verify output
|
||||
k = Kernel(sink)
|
||||
p = k.to_program()
|
||||
@@ -1716,7 +1716,7 @@ class TestIndexing(unittest.TestCase):
|
||||
alu = swizzle_r+const
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
|
||||
# graph rewrite
|
||||
sink = graph_rewrite(sink, reduceop_fusor)
|
||||
sink = graph_rewrite(sink, view_right)
|
||||
# verify output
|
||||
k = Kernel(sink)
|
||||
p = k.to_program()
|
||||
@@ -1739,7 +1739,7 @@ class TestIndexing(unittest.TestCase):
|
||||
alu = UOp(UOps.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(UOps.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501
|
||||
# graph rewrite
|
||||
sink = graph_rewrite(sink, reduceop_fusor)
|
||||
sink = graph_rewrite(sink, view_right)
|
||||
# verify output
|
||||
k = Kernel(sink)
|
||||
p = k.to_program()
|
||||
@@ -1755,7 +1755,7 @@ class TestIndexing(unittest.TestCase):
|
||||
UOp(UOps.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
|
||||
# there's an EXPAND pushing through the REDUCE_AXIS
|
||||
self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape))
|
||||
ret = graph_rewrite(swizzle, reduceop_fusor)
|
||||
ret = graph_rewrite(graph_rewrite(swizzle, view_left), view_right)
|
||||
# EXPAND is rewritten
|
||||
self.assertEqual(prod(ret.st.shape), prod(ret.src[0].st.shape))
|
||||
# and pushed to the LOAD
|
||||
|
||||
@@ -8,7 +8,7 @@ from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import UOps, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
|
||||
from tinygrad.renderer import Program
|
||||
from tinygrad.engine.schedule import create_schedule, reduceop_fusor
|
||||
from tinygrad.engine.schedule import create_schedule, enumerate_bufs
|
||||
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
|
||||
@@ -441,7 +441,7 @@ class TestIndexingOrdering(unittest.TestCase):
|
||||
class TestUPatHelpers(unittest.TestCase):
|
||||
def test_location(self):
|
||||
self.assertEqual(sym.patterns[-1][0].location[0].split("/")[-1], "uopgraph.py")
|
||||
self.assertEqual(reduceop_fusor.patterns[0][0].location[0].split("/")[-1], "schedule.py")
|
||||
self.assertEqual(enumerate_bufs.patterns[0][0].location[0].split("/")[-1], "schedule.py")
|
||||
self.assertEqual(spec.patterns[0][0].location[0].split("/")[-1], "ops.py")
|
||||
with self.assertRaises(AssertionError): # TODO: location UPat files created in test/*?
|
||||
test_upat = UPat(UOps.CONST, dtypes.bool)
|
||||
|
||||
Reference in New Issue
Block a user