Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-02-12 07:35:16 -05:00
start tests for putting the tensor graph in a single kernel [pr] (#8542)
* start tests for putting the tensor graph in a single kernel [pr]
* parallel actually
* better view_left test
* test a softmax
* put all that in sym
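The centerpiece of the diff is a new test helper, run_tensor_ast, which lowers a Tensor's entire UOp graph into a single ScheduleItem and executes it, rather than letting the scheduler split the graph into one kernel per reduce. A minimal usage sketch (an editor's illustration based on the rewritten tests below, not code from the commit):

    a = Tensor.randint(4,).realize()
    b = Tensor.randint(4,).realize()
    # both reduces and the final add run as one kernel
    out = run_tensor_ast(a.sum(0)+b.sum(0))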
@@ -15,7 +15,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, view_supported_devices, symbolic_simple, merge_views
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
-from tinygrad.codegen.kernel import Kernel, verify_ast
+from tinygrad.codegen.kernel import verify_ast
 from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from extra.models.llama import precompute_freqs_cis
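The only change in this hunk is dropping the Kernel import: the rewritten TestSwizzle bodies below no longer hand-build a Kernel and CompiledRunner, they execute through run_tensor_ast, which relies on ScheduleItem and run_schedule from the imports that remain.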
@@ -1753,73 +1753,65 @@ def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_le
 
 def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0])
 
+# these pattern matchers should move to engine/schedule.py
+sym = symbolic_simple+PatternMatcher([
+  (UPat(Ops.DETACH, name="x"), lambda x:x.src[0]),
+])
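sym extends symbolic_simple with one extra rule: a DETACH node is replaced by its source, since detach carries no meaning once the graph is being lowered. A sketch of the rule firing (editor's illustration; it assumes UOp hash-consing so the identity comparison holds):

    c = UOp.const(dtypes.int, 1)
    # the DETACH wrapper is stripped, leaving the constant itself
    assert graph_rewrite(UOp(Ops.DETACH, c.dtype, (c,)), sym) is c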
+
+def _load_buffer(ctx:list[UOp], buf:UOp):
+  glbl = UOp(Ops.DEFINE_GLOBAL, buf.dtype.ptr(size=buf.size), (), len(ctx))
+  ctx.append(buf)
+  return UOp(Ops.LOAD, buf.dtype, (glbl, ShapeTracker.from_shape((buf.size,)).to_uop()))
+load_buffers = PatternMatcher([
+  (UPat(Ops.BUFFER, name="buf"), _load_buffer),
+])
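load_buffers turns each BUFFER node into a DEFINE_GLOBAL plus a LOAD through a flat ShapeTracker, and appends the buffer to the ctx list so that its list index matches the DEFINE_GLOBAL argument. A sketch (editor's illustration; Device.DEFAULT is assumed to be imported from tinygrad):

    buf = UOp.new_buffer(Device.DEFAULT, 4, dtypes.int)
    ld = graph_rewrite(buf, load_buffers, ctx:=[])
    # the BUFFER became a LOAD, and ctx recorded it at index 0
    assert ld.op is Ops.LOAD and ctx == [buf]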
+
+# put the entire schedule of the tensor in a single ScheduleItem
+@track_rewrites(named=True)
+def run_tensor_ast(r:Tensor):
+  output = UOp.new_buffer(r.device, r.lazydata.size, r.dtype)
+  glbl = UOp(Ops.DEFINE_GLOBAL, output.dtype.ptr(size=output.size), (), 0)
+  sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink()
+  sink = graph_rewrite(sink, remove_movement_ops+sym+load_buffers+view_left, bufs:=[output])
+  sink = graph_rewrite(sink, remove_movement_ops+sym+view_right)
+  si = ScheduleItem(sink, tuple(x.buffer for x in bufs), (), ())
+  run_schedule([si])
+  return output.realized.as_buffer().cast(output.dtype.fmt).tolist()
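run_tensor_ast bypasses create_schedule_with_vars entirely: it stores the tensor's base UOp into a fresh output buffer, then does two rewrite passes. The first pass strips movement ops, applies sym, swaps BUFFERs for LOADs (collecting them in bufs, with the output buffer at index 0 to match DEFINE_GLOBAL 0), and applies view_left; the second applies view_right. Whatever sink survives becomes one ScheduleItem, which run_schedule executes, and the realized output is read back as a Python list.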
 
 class TestSwizzle(unittest.TestCase):
   def test_swizzle_simple(self):
-    sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
-      UOp(Ops.STORE, dtypes.void, arg=None, src=(
-        UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
-        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
-        UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (0, 1)), src=(
-          UOp(Ops.ADD, dtypes.int, arg=None, src=(
-            UOp(Ops.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
-              UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (0, 1)), src=(
-                UOp(Ops.LOAD, dtypes.int, arg=None, src=(
-                  x8:=UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
-                  UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
-            UOp(Ops.LOAD, dtypes.int, arg=None, src=(
-              x8,
-              UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
-    sink = swizzle_rewrite(sink)
-    k = Kernel(sink)
-    p = k.to_program()
-    a = Tensor.randint(32, 32).realize()
-    b = Tensor.empty((), dtype=dtypes.int).realize()
-    CompiledRunner(p).exec([b.lazydata.buffer, a.lazydata.buffer])
-    expected_out = (a.numpy() + a.numpy().sum()).sum()
-    np.testing.assert_equal(b.numpy(), expected_out)
+    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
+      a = Tensor.randint(32, 32).realize()
+      # double reduce collapses to a single reduce
+      r = (a+a).sum(1).sum(0)
+      self.assertEqual(run_tensor_ast(r), (a.numpy()+a.numpy()).sum(1).sum(0))
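The assertion leans on a small identity: reducing axis 1 and then axis 0 touches every element, so the two chained sums equal one full sum, which is the collapse the inline comment refers to. In numpy terms (editor's illustration):

    x = np.arange(9).reshape(3, 3)
    assert (x+x).sum(1).sum(0) == (x+x).sum()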
 
   def test_single_swizzle(self):
-    # ast in tensor style
-    a = Tensor.randint(4,).realize()
-    expected_out = a.numpy().sum(0)+1
-    # LazyBuffer to pre-rewrite AST
-    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
-    ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
-    r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0,)))
-    swizzle_r = UOp(Ops.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
-    alu = swizzle_r+1
-    sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
-    # graph rewrite
-    sink = swizzle_rewrite(sink)
-    # verify output
-    k = Kernel(sink)
-    p = k.to_program()
-    b = Tensor.empty((1,), dtype=dtypes.int).realize()
-    CompiledRunner(p).exec([b.lazydata.buffer, a.lazydata.buffer])
-    np.testing.assert_equal(b.numpy(), expected_out)
+    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
+      a = Tensor.randint(4, 1).realize()
+      b = Tensor.ones((1, 1), dtype=a.dtype).contiguous().realize()
+      # ADD(REDUCE(RESHAPE(LOAD)), LOAD) to ADD(REDUCE(RESHAPE(LOAD)), RESHAPE(LOAD))
+      r = a.sum(0)+b
+      self.assertEqual(run_tensor_ast(r), a.numpy().sum(0)+1)
 
   def test_double_swizzle_possible(self):
-    # ast in tensor style
-    Tensor.manual_seed(0)
-    a = Tensor.randint(4,).realize()
-    b = Tensor.randint(4,).realize()
-    expected_out = a.numpy().sum(0)+b.numpy().sum(0)+2
-    # LazyBuffer to pre-rewrite AST
-    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
-    ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
-    r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (Ops.ADD, (0,)))
-    ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
-    r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (Ops.ADD, (0,)))
-    alu = UOp(Ops.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(Ops.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
-    sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+2,),),)) # noqa: E501
-    # graph rewrite
-    sink = swizzle_rewrite(sink)
-    # verify output
-    k = Kernel(sink)
-    p = k.to_program()
-    c = Tensor.empty((1,), dtype=dtypes.int).realize()
-    CompiledRunner(p).exec([c.lazydata.buffer, a.lazydata.buffer, b.lazydata.buffer])
-    np.testing.assert_equal(c.numpy(), expected_out)
+    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
+      Tensor.manual_seed(0)
+      a = Tensor.randint(4,).realize()
+      b = Tensor.randint(4,).realize()
+      # parallel reduce!
+      add = a.sum(0)+b.sum(0)
+      self.assertEqual(run_tensor_ast(add), a.numpy().sum(0)+b.numpy().sum(0))
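The "# parallel reduce!" comment marks what this case adds over the previous one: two independent reduces over different input buffers must sit side by side in a single kernel AST instead of becoming two ScheduleItems.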
 
+  # TODO: this is failing because it cannot resolve the final shape of two swizzled sources
+  @unittest.expectedFailure
+  def test_softmax(self):
+    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
+      Tensor.manual_seed(0)
+      a = Tensor.randn(32, 32).realize()
+      t = a.softmax()
+      run_tensor_ast(t)
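Softmax is the stress test here: it lowers to two reduces over the same axis whose results are both broadcast back against the elementwise graph, and per the TODO above, resolving the final shape of those two swizzled sources is what still fails. Numerically it expands to (editor's illustration of what a.softmax() computes):

    m = a.max(axis=-1, keepdim=True)
    e = (a - m).exp()
    out = e / e.sum(axis=-1, keepdim=True)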
+
   def test_swizzle_rewrite_alt(self):
     swizzle = UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501