Proposal: Better UOps.SWIZZLE (#6309)

* better UOps.SWIZZLE

* test_swizzle_rewrite

* add it to docs

* show a diff

* a lil more verbose

* two teeny notes

* hotfix: sink
This commit is contained in:
qazal
2024-08-29 20:39:48 +08:00
committed by GitHub
parent 8c50ef8b7c
commit 07942ef361
3 changed files with 94 additions and 10 deletions

View File

@@ -8,13 +8,15 @@ import numpy as np
from typing import Dict, List, Optional, Union, cast
from tinygrad import nn, dtypes
from tinygrad.device import Device
from tinygrad.dtype import PtrDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import create_schedule, get_output_st, st_fixup
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.schedule import create_schedule, get_output_st, st_fixup, reduceop_fusor
from tinygrad.engine.realize import CompiledRunner, run_schedule
from test.helpers import is_dtype_supported, Context
from tinygrad.lazy import LazyBuffer, view_supported_devices
from extra.models.llama import precompute_freqs_cis
@@ -1642,5 +1644,31 @@ class TestScheduleRewrite(unittest.TestCase):
new_val = st_fixup(val, lambda st:st.reshape((4,)), {}, {})
self.assertIs(new_val, val)
def test_swizzle_rewrite(self):
  """Rewrite a hand-written AST containing a SWIZZLE with reduceop_fusor, run it, and check the result.

  The AST stores sum(SWIZZLE(sum(x)) + x) for a 32x32 int input x, where the SWIZZLE
  carries a (32, 32) view with zero strides. The kernel output is compared against
  the numpy value (x + x.sum()).sum().
  """
  def full_st():
    # each LOAD gets its own contiguous 32x32 SHAPETRACKER node
    return UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(
      View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=())

  # build the AST, then apply the graph rewrite
  ast = UOp(UOps.SINK, None, arg=None, src=(
    UOp(UOps.STORE, None, arg=None, src=(
      UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()),
      UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(
        View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
      UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
        UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
          UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(
            View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
            UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
              UOp(UOps.LOAD, dtypes.int, arg=None, src=(
                # the input buffer node is bound here and reused by the second LOAD below
                in_buf:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
                full_st(),)),)),)),
          UOp(UOps.LOAD, dtypes.int, arg=None, src=(
            in_buf,
            full_st(),)),)),)),)),))
  ast = graph_rewrite(ast, reduceop_fusor)

  # compile the rewritten AST and execute it on realized buffers (output first, then input)
  prg = Kernel(ast).to_program()
  inp = Tensor.randint(32, 32).realize()
  out = Tensor.empty((), dtype=dtypes.int).realize()
  CompiledRunner(prg).exec([out.lazydata.buffer, inp.lazydata.buffer])

  # numpy reference for the same computation
  inp_sum = inp.numpy().sum()
  reference = (inp.numpy() + inp_sum).sum()
  np.testing.assert_equal(out.numpy(), reference)
# when executed as a script, run the test cases in this file with verbose per-test output
if __name__ == '__main__':
  unittest.main(verbosity=2)