start tests for putting the tensor graph in a single kernel [pr] (#8542)

* start tests for putting the tensor graph in a single kernel [pr]

* parallel actually

* better view_left test

* test a softmax

* put all that in sym
Author: qazal
Date: 2025-01-09 13:33:21 -05:00
Committed by: GitHub
Parent: 83a8217cbf
Commit: f6eb0574f2
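For context on the diff below: tinygrad's stock scheduler splits a tensor graph into one ScheduleItem per kernel, while this commit adds a run_tensor_ast test helper that tries to express the whole graph as a single ScheduleItem. A minimal sketch of the stock behavior, using only the public Tensor.schedule API (the exact kernel count is backend-dependent):

# sketch, not part of this commit: the stock scheduler emits several kernels for a softmax
from tinygrad import Tensor

a = Tensor.randn(32, 32).realize()
sched = a.softmax().schedule()  # list of ScheduleItems, one per kernel
print(len(sched))  # typically > 1; the tests below aim for exactly 1 (softmax itself is still an expectedFailure)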

@@ -15,7 +15,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, view_supported_devices, symbolic_simple, merge_views
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
-from tinygrad.codegen.kernel import Kernel, verify_ast
+from tinygrad.codegen.kernel import verify_ast
 from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from extra.models.llama import precompute_freqs_cis
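Dropping the Kernel import tracks the test rewrite in the next hunk: instead of hand-building a Kernel and executing it with CompiledRunner per AST, the tests now route everything through run_tensor_ast and run_schedule. A minimal sketch of the new flow, assuming run_tensor_ast from the hunk below is in scope:

# sketch: two parallel reduces scheduled and run as a single ScheduleItem
from tinygrad import Tensor
from tinygrad.helpers import Context

with Context(DEBUG=0, TRACK_MATCH_STATS=0):
  a, b = Tensor.randint(4,).realize(), Tensor.randint(4,).realize()
  out = run_tensor_ast(a.sum(0)+b.sum(0))  # no Kernel/CompiledRunner plumbing needed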
@@ -1753,73 +1753,65 @@ def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_le
 def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0])
 
+# these pattern matchers should move to engine/schedule.py
+sym = symbolic_simple+PatternMatcher([
+  (UPat(Ops.DETACH, name="x"), lambda x:x.src[0]),
+])
+
+def _load_buffer(ctx:list[UOp], buf:UOp):
+  glbl = UOp(Ops.DEFINE_GLOBAL, buf.dtype.ptr(size=buf.size), (), len(ctx))
+  ctx.append(buf)
+  return UOp(Ops.LOAD, buf.dtype, (glbl, ShapeTracker.from_shape((buf.size,)).to_uop()))
+load_buffers = PatternMatcher([
+  (UPat(Ops.BUFFER, name="buf"), _load_buffer),
+])
+
+# put the entire schedule of the tensor in a single ScheduleItem
+@track_rewrites(named=True)
+def run_tensor_ast(r:Tensor):
+  output = UOp.new_buffer(r.device, r.lazydata.size, r.dtype)
+  glbl = UOp(Ops.DEFINE_GLOBAL, output.dtype.ptr(size=output.size), (), 0)
+  sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink()
+  sink = graph_rewrite(sink, remove_movement_ops+sym+load_buffers+view_left, bufs:=[output])
+  sink = graph_rewrite(sink, remove_movement_ops+sym+view_right)
+  si = ScheduleItem(sink, tuple(x.buffer for x in bufs), (), ())
+  run_schedule([si])
+  return output.realized.as_buffer().cast(output.dtype.fmt).tolist()
+
 class TestSwizzle(unittest.TestCase):
   def test_swizzle_simple(self):
-    sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
-      UOp(Ops.STORE, dtypes.void, arg=None, src=(
-        UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
-        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
-        UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (0, 1)), src=(
-          UOp(Ops.ADD, dtypes.int, arg=None, src=(
-            UOp(Ops.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa: E501
-              UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (0, 1)), src=(
-                UOp(Ops.LOAD, dtypes.int, arg=None, src=(
-                  x8:=UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
-                  UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa: E501
-            UOp(Ops.LOAD, dtypes.int, arg=None, src=(
-              x8,
-              UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa: E501
-    sink = swizzle_rewrite(sink)
-    k = Kernel(sink)
-    p = k.to_program()
-    a = Tensor.randint(32, 32).realize()
-    b = Tensor.empty((), dtype=dtypes.int).realize()
-    CompiledRunner(p).exec([b.lazydata.buffer, a.lazydata.buffer])
-    expected_out = (a.numpy() + a.numpy().sum()).sum()
-    np.testing.assert_equal(b.numpy(), expected_out)
+    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
+      a = Tensor.randint(32, 32).realize()
+      # double reduce collapses to a single reduce
+      r = (a+a).sum(1).sum(0)
+      self.assertEqual(run_tensor_ast(r), (a.numpy()+a.numpy()).sum(1).sum(0))
   def test_single_swizzle(self):
-    # ast in tensor style
-    a = Tensor.randint(4,).realize()
-    expected_out = a.numpy().sum(0)+1
-    # LazyBuffer to pre-rewrite AST
-    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
-    ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
-    r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0,)))
-    swizzle_r = UOp(Ops.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
-    alu = swizzle_r+1
-    sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
-    # graph rewrite
-    sink = swizzle_rewrite(sink)
-    # verify output
-    k = Kernel(sink)
-    p = k.to_program()
-    b = Tensor.empty((1,), dtype=dtypes.int).realize()
-    CompiledRunner(p).exec([b.lazydata.buffer, a.lazydata.buffer])
-    np.testing.assert_equal(b.numpy(), expected_out)
+    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
+      a = Tensor.randint(4, 1).realize()
+      b = Tensor.ones((1, 1), dtype=a.dtype).contiguous().realize()
+      # ADD(REDUCE(RESHAPE(LOAD)), LOAD) to ADD(REDUCE(RESHAPE(LOAD)), RESHAPE(LOAD))
+      r = a.sum(0)+b
+      self.assertEqual(run_tensor_ast(r), a.numpy().sum(0)+1)
   def test_double_swizzle_possible(self):
-    # ast in tensor style
-    Tensor.manual_seed(0)
-    a = Tensor.randint(4,).realize()
-    b = Tensor.randint(4,).realize()
-    expected_out = a.numpy().sum(0)+b.numpy().sum(0)+2
-    # LazyBuffer to pre-rewrite AST
-    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
-    ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
-    r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (Ops.ADD, (0,)))
-    ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
-    r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (Ops.ADD, (0,)))
-    alu = UOp(Ops.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(Ops.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
-    sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+2,),),)) # noqa: E501
-    # graph rewrite
-    sink = swizzle_rewrite(sink)
-    # verify output
-    k = Kernel(sink)
-    p = k.to_program()
-    c = Tensor.empty((1,), dtype=dtypes.int).realize()
-    CompiledRunner(p).exec([c.lazydata.buffer, a.lazydata.buffer, b.lazydata.buffer])
-    np.testing.assert_equal(c.numpy(), expected_out)
+    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
+      Tensor.manual_seed(0)
+      a = Tensor.randint(4,).realize()
+      b = Tensor.randint(4,).realize()
+      # parallel reduce!
+      add = a.sum(0)+b.sum(0)
+      self.assertEqual(run_tensor_ast(add), a.numpy().sum(0)+b.numpy().sum(0))
+
+  # TODO: this is failing because it cannot resolve the final shape of two swizzled sources
+  @unittest.expectedFailure
+  def test_softmax(self):
+    with Context(DEBUG=0, TRACK_MATCH_STATS=0):
+      Tensor.manual_seed(0)
+      a = Tensor.randn(32, 32).realize()
+      t = a.softmax()
+      run_tensor_ast(t)
   def test_swizzle_rewrite_alt(self):
     swizzle = UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
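run_tensor_ast above does the fusing in two graph_rewrite passes: the first (load_buffers+view_left) turns BUFFERs into DEFINE_GLOBAL/LOAD pairs and moves views one way, the second (view_right) moves them the other before the sink becomes a single ScheduleItem. A hedged end-to-end check mirroring test_swizzle_simple; the import path assumes this diff lives in test/test_schedule.py:

# sketch: verify the single-ScheduleItem result against numpy
import numpy as np
from tinygrad import Tensor
from tinygrad.helpers import Context
from test.test_schedule import run_tensor_ast  # assumed location of the helper

with Context(DEBUG=0, TRACK_MATCH_STATS=0):
  a = Tensor.randint(32, 32).realize()
  out = run_tensor_ast((a+a).sum(1).sum(0))  # run_tensor_ast returns a python list
  np.testing.assert_equal(out, [(a.numpy()+a.numpy()).sum(1).sum(0)])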