Mirror of https://github.com/tinygrad/tinygrad.git
Proposal: Better UOps.SWIZZLE (#6309)
* better UOps.SWIZZLE
* test_swizzle_rewrite
* add it to docs
* show a diff
* a lil more verbose
* two teeny notes
* hotfix: sink
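
For orientation: the computation encoded by the new test below is small. The inner REDUCE_AXIS collapses a 32x32 input to a scalar sum, the SWIZZLE's zero-stride View broadcasts that scalar back to (32, 32), and the outer REDUCE_AXIS sums the elementwise addition. A minimal sketch of the same semantics in plain numpy (no tinygrad assumed; variable names are illustrative):

import numpy as np

# The value test_swizzle_rewrite expects the fused kernel to produce.
a = np.random.randint(0, 100, size=(32, 32))
inner = a.sum()                               # inner REDUCE_AXIS: (32, 32) -> scalar
swizzled = np.broadcast_to(inner, (32, 32))   # SWIZZLE: zero-stride view back to (32, 32)
expected_out = (a + swizzled).sum()           # ALU add, then outer REDUCE_AXIS
assert expected_out == (a + a.sum()).sum()    # matches the test's expected_out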
@@ -8,13 +8,15 @@ import numpy as np
 from typing import Dict, List, Optional, Union, cast
 from tinygrad import nn, dtypes
 from tinygrad.device import Device
+from tinygrad.dtype import PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.view import View
 from tinygrad.tensor import Tensor
-from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps
+from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, FUSE_CONV_BW, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP
 from tinygrad.codegen.kernel import Kernel, verify_ast
-from tinygrad.engine.schedule import create_schedule, get_output_st, st_fixup
-from tinygrad.engine.realize import run_schedule
+from tinygrad.engine.schedule import create_schedule, get_output_st, st_fixup, reduceop_fusor
+from tinygrad.engine.realize import CompiledRunner, run_schedule
 from test.helpers import is_dtype_supported, Context
 from tinygrad.lazy import LazyBuffer, view_supported_devices
 from extra.models.llama import precompute_freqs_cis
@@ -1642,5 +1644,31 @@ class TestScheduleRewrite(unittest.TestCase):
     new_val = st_fixup(val, lambda st:st.reshape((4,)), {}, {})
     self.assertIs(new_val, val)
 
+  def test_swizzle_rewrite(self):
+    # graph rewrite
+    sink = UOp(UOps.SINK, None, arg=None, src=(
+      UOp(UOps.STORE, None, arg=None, src=(
+        UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()),
+        UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
+        UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
+          UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
+            UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
+              UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
+                UOp(UOps.LOAD, dtypes.int, arg=None, src=(
+                  x8:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
+                  UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa: E501
+            UOp(UOps.LOAD, dtypes.int, arg=None, src=(
+              x8,
+              UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa: E501
+    sink = graph_rewrite(sink, reduceop_fusor)
+    # verify output
+    k = Kernel(sink)
+    p = k.to_program()
+    a = Tensor.randint(32, 32).realize()
+    b = Tensor.empty((), dtype=dtypes.int).realize()
+    CompiledRunner(p).exec([b.lazydata.buffer, a.lazydata.buffer])
+    expected_out = (a.numpy() + a.numpy().sum()).sum()
+    np.testing.assert_equal(b.numpy(), expected_out)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
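
A note on the SWIZZLE argument in the hunk above: View(shape=(32, 32), strides=(0, 0)) makes every index alias the single reduced element, so the scalar is broadcast back to the pre-reduce shape without materializing a buffer. The same trick, sketched with numpy's as_strided (note numpy strides are in bytes, unlike the element strides in a View; zero strides point every element at offset 0):

import numpy as np

# Zero-stride view: each (i, j) reads the one stored value, mirroring the
# ShapeTracker view the SWIZZLE carries in the test above.
scalar = np.array([7])
view = np.lib.stride_tricks.as_strided(scalar, shape=(32, 32), strides=(0, 0))
assert view.shape == (32, 32) and (view == 7).all()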