diff --git a/test/test_schedule.py b/test/test_schedule.py index bb84827dc0..e8a684cef5 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -8,13 +8,15 @@ import numpy as np from typing import Dict, List, Optional, Union, cast from tinygrad import nn, dtypes from tinygrad.device import Device +from tinygrad.dtype import PtrDType from tinygrad.shape.shapetracker import ShapeTracker +from tinygrad.shape.view import View from tinygrad.tensor import Tensor -from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps +from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP from tinygrad.codegen.kernel import Kernel, verify_ast -from tinygrad.engine.schedule import create_schedule, get_output_st, st_fixup -from tinygrad.engine.realize import run_schedule +from tinygrad.engine.schedule import create_schedule, get_output_st, st_fixup, reduceop_fusor +from tinygrad.engine.realize import CompiledRunner, run_schedule from test.helpers import is_dtype_supported, Context from tinygrad.lazy import LazyBuffer, view_supported_devices from extra.models.llama import precompute_freqs_cis @@ -1642,5 +1644,31 @@ class TestScheduleRewrite(unittest.TestCase): new_val = st_fixup(val, lambda st:st.reshape((4,)), {}, {}) self.assertIs(new_val, val) + def test_swizzle_rewrite(self): + # graph rewrite + sink = UOp(UOps.SINK, None, arg=None, src=( + UOp(UOps.STORE, None, arg=None, src=( + UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()), + UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), + UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=( + UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, 
contiguous=False),)), src=( # noqa E501 + UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=( + UOp(UOps.LOAD, dtypes.int, arg=None, src=( + x8:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()), + UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501 + UOp(UOps.LOAD, dtypes.int, arg=None, src=( + x8, + UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501 + sink = graph_rewrite(sink, reduceop_fusor) + # verify output + k = Kernel(sink) + p = k.to_program() + a = Tensor.randint(32, 32).realize() + b = Tensor.empty((), dtype=dtypes.int).realize() + CompiledRunner(p).exec([b.lazydata.buffer, a.lazydata.buffer]) + expected_out = (a.numpy() + a.numpy().sum()).sum() + np.testing.assert_equal(b.numpy(), expected_out) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 2e1eb96ae3..9d8f4ea202 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -82,7 +82,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, .. # reduce ops change ShapeTracker if buf.op in ReduceOps: - swizzle = (UOp(UOps.SWIZZLE, src=(st.to_uop(),)),) if not st.contiguous and AST_REWRITE else () + swizzle_arg = st if not st.contiguous and AST_REWRITE else None rinfo: Optional[Tuple[ShapeTracker, Tuple[int, ...]]] = (ShapeTracker.from_shape(buf.srcs[0].shape), buf.arg) \ if AST_REWRITE else reduce_info.get((buf, st)) rsrc = _recursive_uop(buf.srcs[0], st:=(rinfo[0] if rinfo else st), outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) @@ -91,7 +91,9 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, .. 
if rinfo is None: assert rsrc.op is UOps.REDUCE_AXIS and rsrc.arg[0] is alu_op, f"can't merge reduceop {buf.op} with {rsrc}\n{st}" return rsrc - return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,)+swizzle, (alu_op, rinfo[1]))) + ret = UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (alu_op, rinfo[1])) + if swizzle_arg is not None: ret = UOp(UOps.SWIZZLE, dtype, (ret,), swizzle_arg) + return cache.setdefault((buf, st), ret) # elementwise ops pass shapetracker in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, reduce_info, cache) for x in buf.srcs) @@ -172,10 +174,11 @@ def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int # ***** reduceop fusor ***** -def apply_swizzle(root:UOp, rsrc:UOp, swizzle:UOp) -> UOp: +def push_swizzle_through_reduce(swizzle:UOp, reduceop:UOp) -> UOp: uop_sts: Dict[UOp, ShapeTracker] = {} - new_input_st, new_axis = swizzle_reduceop(unwrap(get_output_st(rsrc, uop_sts)), swizzle.arg, root.arg[1]) - return replace(root, src=(st_fixup(rsrc, lambda _:new_input_st, uop_sts, {}),), arg=(root.arg[0], new_axis)) + rsrc = reduceop.src[0] + new_input_st, new_axis = swizzle_reduceop(unwrap(get_output_st(rsrc, uop_sts)), swizzle.arg, reduceop.arg[1]) + return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda _:new_input_st, uop_sts, {}),), (reduceop.arg[0], new_axis)) def push_reduceop_shape(root:UOp) -> Optional[UOp]: reduceops = [x for x in root.parents if x.op is UOps.REDUCE_AXIS] @@ -186,7 +189,7 @@ def push_reduceop_shape(root:UOp) -> Optional[UOp]: return st_fixup(root, lambda st:st.reshape(rshape), uop_sts, {}) reduceop_fusor = PatternMatcher([ - (UPat(UOps.REDUCE_AXIS, src=(UPat(name="rsrc"), UPat(UOps.SWIZZLE, src=(UPat(name="swizzle"),))), name="root"), apply_swizzle), + (UPat(UOps.SWIZZLE, src=(UPat(UOps.REDUCE_AXIS, name="reduceop"),), name="swizzle"), push_swizzle_through_reduce), (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE}, name="root"), 
push_reduceop_shape), ]) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 3d3b00fd53..5b616569fc 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -46,7 +46,7 @@ class UOps(Enum): Holds `UOps.STORE`. SINK defines the AST for a Kernel. - **`dtype`**: `None` - - **`src`**: `Tuple[UOp, ...]`, Only local STOREs are allowed. + - **`src`**: `Tuple[UOp, ...]`, Only global STOREs are allowed. - **`arg`**: `Optional[KernelInfo]` NOTE: `ScheduleItem` ASTs do not have the `KernelInfo` arg, `Kernel` inserts this to the SINK later. @@ -70,6 +70,59 @@ class UOps(Enum): - **`arg`**: `ShapeTracker` """ SWIZZLE = auto() + """ + Swizzle inserts a movement op between a UOp and its children. Because movement ops (reshape, expand, shrink, permute, pad) are not allowed in an AST, + the scheduler rewrites SWIZZLE by pushing its ShapeTracker through reduceops or elementwise ops to the edges of the graph. + + Example: + ```python + a = Tensor.empty(32, 32) + first_reduce = a.sum() + output = (a + first_reduce).sum() + ``` + `first_reduce` must broadcast to `(32, 32)` before ADD. 
This is expressed as the following UOp graph: + + ``` + UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( + UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=( + UOp(UOps.LOAD, dtypes.int, arg=None, src=( + x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()), + UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), + UOp(UOps.LOAD, dtypes.int, arg=None, src=( + x3, + UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)) + ``` + + The scheduler rewrites this by pushing the expand in SWIZZLE through the reduce, to the LOAD: + + ```diff + UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( + - UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( + - UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=( + - UOp(UOps.LOAD, dtypes.int, arg=None, src=( + - x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()), + - UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), + + UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2, 3)), src=( + + UOp(UOps.LOAD, dtypes.int, arg=None, src=( + + x2:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()), + + UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32, 32, 32), strides=(0, 0, 32, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)), + UOp(UOps.LOAD, dtypes.int, arg=None, src=( + - x3, + - UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)) + + x2, + + 
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)) + + ``` + + NOTE: Pushing a SWIZZLE through a reduce changes the axis. + + NOTE: Pushing a SWIZZLE changes the output shape of that UOp. We have to reshape every other adjacent node, e.g. the second LOAD is reshaped to `(32, 32, 1, 1)` above. + + - **`dtype`**: Output DType + - **`src`**: `Tuple[UOp]`, a single UOp to swizzle. + - **`arg`**: ShapeTracker + """ # noqa E501 DEFINE_GLOBAL = auto() DEFINE_VAR = auto() DEFINE_LOCAL = auto()