Merge branch 'master' into retinanet_mlperf

This commit is contained in:
Francis Lata
2025-03-09 19:24:40 +00:00
6 changed files with 109 additions and 181 deletions

View File

@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.spec import type_verify, shape_spec
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, all_same, temp
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, sym
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
from extra.models.llama import precompute_freqs_cis
@@ -178,6 +178,13 @@ class TestSchedule(unittest.TestCase):
c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
with self.assertRaises(KernelCountException): check_schedule(c, 1)
def test_allow_push_permutes(self):
a = Tensor.randn(10,10,10).realize()
b = Tensor.randn(10,10,1).realize()
c = a.sum(axis=0, keepdim=True).permute(2,1,0) + b
with Context(DONT_GROUP_REDUCES=1): run_schedule(check_schedule(c, 1))
np.testing.assert_allclose(c.numpy(), np.sum(a.numpy(), axis=0, keepdims=True).transpose(2,1,0)+b.numpy())
def test_binop_early_reshape_reduce_fusion(self):
a = Tensor.empty(100)
b = Tensor.empty(100)
@@ -247,20 +254,14 @@ class TestSchedule(unittest.TestCase):
def test_div_collapse_buffer(self):
a = Tensor.full((4,), 4.0).contiguous().realize()
b = Tensor.full((4,), 2.0).contiguous().realize()
GlobalCounters.reset()
expr = (a*b)/b
expr.realize()
self.assertEqual(GlobalCounters.kernel_count, 0) # the scheduler can fold divs now!
self.assertEqual(GlobalCounters.global_ops, 0)
check_schedule(expr, 0)
np.testing.assert_allclose(expr.numpy(), np.full((4,), 4.0))
def test_div_collapse_const(self):
a = Tensor.full((4,), 4.0).contiguous().realize()
GlobalCounters.reset()
expr = a/a
expr.realize()
self.assertEqual(GlobalCounters.kernel_count, 0)
self.assertEqual(GlobalCounters.global_ops, 0)
check_schedule(expr, 0)
np.testing.assert_allclose(expr.numpy(), np.full((4,), 1.0))
def test_div_collapse(self):
@@ -316,7 +317,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_double_unary(self):
y = Tensor.empty(2)
out = y.sum(keepdim=True).sqrt().__neg__()
out = y.sum(keepdim=True).sqrt().neg()
check_schedule(out, 1)
#@unittest.skip("may want to reconsider this")
@@ -1871,7 +1872,7 @@ class TestIndexing(unittest.TestCase):
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
r = r + 2
r = r + r.const_like(2).replace(src=(unwrap(r.st).to_uop(),))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
# this AST first needs to swizzle, but it doesn't have implicit movementops
@@ -1889,126 +1890,73 @@ class TestIndexing(unittest.TestCase):
# Test helper: rewrite u with the view_left patterns first, then rewrite that result with the view_right patterns.
@track_rewrites(named=True)
def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right)
# Test helper: count the VIEW uops in u's toposort that have at least one source.
def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0])
# these pattern matchers should move to engine/schedule.py
# symbolic_simple extended with one extra rule: a DETACH uop is replaced by its first source
ops_folding = symbolic_simple+PatternMatcher([
(UPat(Ops.DETACH, name="x"), lambda x:x.src[0]),
])
# Rewrite callback: turn a BUFFER uop into a LOAD from a fresh DEFINE_GLOBAL and record the buffer in ctx.
# ctx is mutated in place; the returned LOAD has dtype buf.dtype.
def _load_buffer(ctx:list[UOp], buf:UOp):
# the DEFINE_GLOBAL arg is the number of buffers already in ctx, i.e. the index buf gets from the append below
glbl = UOp(Ops.DEFINE_GLOBAL, buf.dtype.ptr(size=buf.size), (), len(ctx))
ctx.append(buf)
# the LOAD addresses the buffer through a one-dimensional ShapeTracker of shape (buf.size,)
return UOp(Ops.LOAD, buf.dtype, (glbl, ShapeTracker.from_shape((buf.size,)).to_uop()))
# single-rule matcher: every BUFFER uop (bound as "buf") is passed to _load_buffer
load_buffers = PatternMatcher([
(UPat(Ops.BUFFER, name="buf"), _load_buffer),
])
# put the entire schedule of the tensor in a single ScheduleItem
# Builds one STORE/SINK ast for r, rewrites it in two passes, runs it as a single ScheduleItem,
# and returns the output buffer's contents as a (nested) Python list with shape r.shape.
@track_rewrites(named=True)
def run_tensor_ast(r:Tensor):
# new buffer uop on r's device that receives the result
output = UOp.new_buffer(r.device, r.lazydata.size, r.dtype)
# global 0 is the output pointer
glbl = UOp(Ops.DEFINE_GLOBAL, output.dtype.ptr(size=output.size), (), 0)
# store the base of r's graph into global 0 using a ShapeTracker built from the base's shape
sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink()
# first pass: bufs is the rewrite context and starts as [output]; load_buffers passes each matched BUFFER to _load_buffer, which appends it
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+load_buffers+view_left, bufs:=[output])
# second pass: same movement-op removal and folding rules, now combined with view_right
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+view_right)
# the ScheduleItem's buffers are taken from bufs in order, so output comes first
si = ScheduleItem(sink, tuple(x.buffer for x in bufs), ())
run_schedule([si])
# read the realized output back, cast it to the dtype's format with r's shape, and convert to a list
return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist()
class TestSwizzle(unittest.TestCase):
def test_swizzle_simple(self):
Tensor.manual_seed(0)
with Context(DEBUG=0, TRACK_MATCH_STATS=0):
a = Tensor.randint(32, 32).realize()
# double reduce collapses to a single reduce
r = (a+a).sum(1).sum(0)
self.assertEqual(run_tensor_ast(r), (a.numpy()+a.numpy()).sum(1).sum(0))
# double reduce collapses to a single reduce
with Context(DONT_GROUP_REDUCES=1):
run_schedule(check_schedule(r, 1))
self.assertEqual(r.numpy(), (a.numpy()+a.numpy()).sum(1).sum(0))
def test_single_swizzle(self):
Tensor.manual_seed(0)
with Context(DEBUG=0, TRACK_MATCH_STATS=0):
a = Tensor.randint(4, 1).realize()
b = Tensor.ones((1, 1), dtype=a.dtype).contiguous().realize()
# ADD(REDUCE(RESHAPE(LOAD)), LOAD) to ADD(REDUCE(RESHAPE(LOAD))), RESHAPE(LOAD)
r = a.sum(0)+b
self.assertEqual(run_tensor_ast(r), a.numpy().sum(0)+1)
run_schedule(check_schedule(r, 1))
self.assertEqual(r.numpy(), a.numpy().sum(0)+1)
def test_double_swizzle_possible(self):
Tensor.manual_seed(0)
with Context(DEBUG=0, TRACK_MATCH_STATS=0):
Tensor.manual_seed(0)
a = Tensor.randint(4,).realize()
b = Tensor.randint(4,).realize()
# parallel reduce!
add = a.sum(0)+b.sum(0)
self.assertEqual(run_tensor_ast(add), a.numpy().sum(0)+b.numpy().sum(0))
with Context(DONT_GROUP_REDUCES=1):
run_schedule(check_schedule(add, 1))
self.assertEqual(add.numpy(), a.numpy().sum(0)+b.numpy().sum(0))
# TODO: this is failing because it cannot resolve the final shape of two swizzled sources
@unittest.expectedFailure
def test_softmax(self):
@unittest.skip("TODO: how do we express the norm")
def test_softmax_one_kernel(self):
Tensor.manual_seed(0)
with Context(DEBUG=0, TRACK_MATCH_STATS=0):
Tensor.manual_seed(0)
a = Tensor.randn(32, 32).realize()
t = a.softmax()
run_tensor_ast(t)
with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1):
check_schedule(t, 1)
def test_swizzle_rewrite_alt(self):
swizzle = UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
# there's an UNROLL pushing through the REDUCE_AXIS
self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape))
ret = swizzle_rewrite(swizzle)
# UNROLL is rewritten
self.assertEqual(prod(ret.st.shape), prod(ret.src[0].st.shape))
# and pushed to the LOAD
new_load_st = unwrap([x for x in ret.toposort if x.op is Ops.VIEW][0].st)
self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
def test_argmax_one_kernel(self):
Tensor.manual_seed(0)
with Context(DEBUG=0, TRACK_MATCH_STATS=0):
a = Tensor.randn(10, 20).realize()
t = a.argmax(0)
with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): t.realize()
def test_swizzle_reduceop(self):
Tensor.manual_seed(0)
x = Tensor.randn(4,4).realize()
y = Tensor.randn(4,4,4).realize()
out = x.reshape(4,4,1).expand(4,4,4).sum(axis=(1,))+y
with Context(DONT_REALIZE_EXPAND=1, DONT_GROUP_REDUCES=1):
run_schedule(check_schedule(out, 1))
np.testing.assert_allclose(out.numpy(), np.tile(x.numpy().reshape(4,4,1), (1,1,4)).sum(axis=1)+y.numpy())
def test_permute_rewrite(self):
x = Tensor.randn(4, 4, 16).realize()
y = Tensor.randn(4, 1, 16).realize()
z = Tensor.randn(4, 4, 1).realize()
t = (x*y).sum(axis=(0, 2)).reshape(1, 4, 1).permute(0, 2, 1)+z
with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1))
t_np = (x.numpy()*y.numpy()).sum(axis=(0, 2)).reshape(1, 4, 1).transpose(0, 2, 1)+z.numpy()
np.testing.assert_allclose(run_tensor_ast(t), t_np, atol=1e-6, rtol=1e-3)
@unittest.expectedFailure
def test_fuse_conv2_relu_bw(self):
# fuse (relu bw, conv2d, conv2d bw, relu)
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(10, ('METAL', 128, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(64, 4, 2, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.CAST, dtypes.float, arg=None, src=(
UOp(Ops.CMPLT, dtypes.bool, arg=None, src=(
x6:=UOp(Ops.WHERE, dtypes.float, arg=None, src=(
UOp(Ops.VALID, dtypes.bool, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
x9:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),
x9,)),
UOp(Ops.MAX, dtypes.float, arg=None, src=(
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(64, 4, 2, 1), offset=0, mask=None, contiguous=True),)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(9, ('METAL', 96, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 16, 2, 2, 3, 3, 3), strides=(48, 0, 0, 4, 1, 16, 4, 1), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(16, ('METAL', 432, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 16, 2, 2, 3, 3, 3), strides=(0, 0, 27, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),
x6,)),)),)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 2), strides=(64, 4, 2, 1), offset=0, mask=None, contiguous=True),)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(18, ('METAL', 128, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 16, 2, 3, 2, 3), strides=(64, 4, 2, 0, 1, 0), offset=0, mask=((0, 2), (0, 16), (0, 2), (0, 1), (0, 2), (0, 1)), contiguous=False), View(shape=(1, 2, 1, 16, 3, 2, 3, 2), strides=(0, 576, 0, 36, 12, 6, 2, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),))
ret = swizzle_rewrite(sink)
self.assertEqual(swizzle_cnt(ret), 0)
np.testing.assert_allclose(t.numpy(), t_np, atol=1e-6, rtol=1e-3)
@unittest.skip("this swizzle can't be decided after the ADD")
def test_swizzle_failure_permute(self):
@@ -2052,39 +2000,6 @@ class TestSwizzle(unittest.TestCase):
ret = swizzle_rewrite(sink)
self.assertEqual(swizzle_cnt(ret), 0)
def test_non_contiguous_view_simplify(self):
st = ShapeTracker(views=(View(shape=(2048, 2048), strides=(1, 2048), offset=0, mask=None, contiguous=False),))
a = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, 4194304, dtypes.char), st.to_uop()))
ret = swizzle_rewrite(a.view(st))
self.assertEqual(ret.st_arg, st+st)
def test_contiguous_view_simplify(self):
base = ShapeTracker.from_shape((32, 32))
a = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, base.size, dtypes.char), base.to_uop()))
swizzle = a.reshape((64, 16))
swizzle = graph_rewrite(swizzle, remove_movement_ops)
self.assertEqual(swizzle_cnt(swizzle), 1)
ret = swizzle_rewrite(swizzle)
self.assertEqual(ret.st_arg, base.reshape((64, 16))) # late rewrite
reswizzle = a.reshape((64, 16)).reshape((32, 32))
self.assertEqual(swizzle_cnt(reswizzle), 0) # instant rule
ret = swizzle_rewrite(reswizzle)
self.assertEqual(ret.st, reswizzle.st)
def test_late_fusion_post_permute_simpler(self):
base = ShapeTracker.from_shape((32, 16, 1))
start = UOp(Ops.LOAD, dtypes.char, (UOp.new_buffer(Device.DEFAULT, base.size, dtypes.char), base.to_uop()))
r = start.expand((32, 16, 16)).r(Ops.ADD, (2,))
add = r.reshape((16, 32, 1)) + UOp.const(r.dtype, 0)
self.assertEqual(add.st, ShapeTracker.from_shape((16, 32, 1)))
to_store = add.permute((1, 0, 2)).contiguous()
to_store = graph_rewrite(to_store, remove_movement_ops)
self.assertEqual(to_store.st, ShapeTracker.from_shape((32, 16, 1)))
self.assertEqual(to_store.src[0].st, add.st.permute((1, 0, 2)))
self.assertIs(to_store.src[0].op, Ops.VIEW)
ret = graph_rewrite(to_store, view_left)
self.assertEqual(swizzle_cnt(ret), 1)
# Return src[2] of the first source of a ScheduleItem's ast.
# NOTE(review): presumably this is the value operand of the STORE under the SINK (STORE sources look like (buffer, view, value) elsewhere in this file) - confirm.
def store_val(si:ScheduleItem): return si.ast.src[0].src[2]
zero_pm = UPat(Ops.CONST, arg=0)
class TestView(unittest.TestCase):
@@ -2159,7 +2074,7 @@ class TestView(unittest.TestCase):
self.assertEqual(other_child.tolist(), [2, 3, 4])
# Test helper: rewrite the base of t.lazydata with remove_movement_ops+symbolic_simple and return the resulting UOp.
def tensor_rewrite(t) -> UOp: return graph_rewrite(t.lazydata.base, remove_movement_ops+symbolic_simple)
class TestBigGraph(unittest.TestCase):
class TestSimplifier(unittest.TestCase):
def test_sink_childless_const(self):
x = Tensor(0)
check_schedule(x, 0)
@@ -2242,12 +2157,11 @@ class TestConst(unittest.TestCase):
a = Tensor.ones((4,)).pad((1, 1)).contiguous()
sched = a.schedule()
print(sched[0].ast)
const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(Ops.WHERE, src=(UPat(Ops.VALID), UPat.cvar("x"), UPat(Ops.CONST, arg=0)))),))
const_ast_pattern = UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat.where(UPat(Ops.VALID), UPat.cvar("x"), UPat(Ops.CONST, arg=0))),))
self.assertEqual(len(const_ast_pattern.match(sched[0].ast, {})), 1)
run_schedule(sched)
self.assertListEqual(a.tolist(), [0, 1, 1, 1, 1, 0])
# TODO: currently even unmasked constants are VALID until codegen
def test_unmasked_const_ast(self):
a = Tensor.ones((4,)).contiguous()
sched = a.schedule()
@@ -2641,5 +2555,12 @@ class TestUOpBecome(unittest.TestCase):
assert b.lazydata is c.lazydata
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.lazydata, {})
def test_setitem_becomes_view_of_base(self):
a = Tensor.full((4,), 2.).contiguous().realize()
b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0))
b.realize()
assert b.lazydata.is_realized
assert b.lazydata.base.buffer._base is None
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -5,7 +5,7 @@ from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, grap
from tinygrad.ops import can_pad, identity_element, resolve, view_left, merge_views
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, unwrap, flatten, getenv, pluralize
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND, SPLIT_REDUCEOP
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP
from tinygrad.dtype import ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
@@ -165,6 +165,7 @@ create_ctx = PatternMatcher([(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"),
def group_realizes(sink:UOp) -> dict[UOp, None]:
# start by adding uops that always realize
sink = graph_rewrite(sink, do_realize+create_ctx, ctx:=GrouperContext({}, {}, defaultdict(dict)))
if DONT_GROUP_REDUCES: return ctx.realizes
# find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
reduce_for_op: dict[UOp, UOp] = {}
double_reduces: list[UOp] = []
@@ -230,6 +231,11 @@ class KernelContext:
realizes: dict[UOp, None]
ops_metadata: dict[UOp, Metadata]
# Wrap x in a KERNEL uop that targets b, and return ASSIGN(buffer, kernel) reshaped to x.shape.
# x: the uop being given a kernel; b: the uop describing where the result is written.
def create_kernel(x:UOp, b:UOp):
# the KERNEL's sources are b followed by all of x's sources; x itself is carried in the Kernel arg
kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x))
# when b is the same size as its base, assign to the base directly;
# otherwise assign to a BUFFER_VIEW of the base whose arg is (b.size, offset of the first view in b.arg)
buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
# the ASSIGN takes x's dtype and is reshaped to x's shape
return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape)
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
def append_to_kernel(ctx:KernelContext, x:UOp):
new_srcs: list[UOp] = []
@@ -244,11 +250,14 @@ def append_to_kernel(ctx:KernelContext, x:UOp):
return x.replace(arg=Kernel(x.arg.ast, new_metadata)) if (new_metadata:=tuple(metadata)) != x.arg.metadata else None
create_kernels = merge_views+PatternMatcher([
# always give assign a kernel
(UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), lambda x,b: b.assign(UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x)))),
# otherwise check if need to assign this UOp to a new buffer
(UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"), lambda ctx,x: UOp(Ops.ASSIGN, x.dtype, (b:=UOp.new_buffer(x.device, x.size, x.dtype).view(x.st),\
UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x)))) if x in ctx.realizes else None),
# always give assign/contiguous a kernel
(UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
(UPat(Ops.CONTIGUOUS, name="x"), lambda x: create_kernel(x, UOp.new_buffer(x.device, x.size, x.dtype))),
# create a buffer for COPY on the new device
(UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="d"), UPat()), name="x"), lambda d,x: create_kernel(x, UOp.new_buffer(d.arg, x.size, x.dtype))),
# otherwise check the context if we're realizing this UOp
(UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"),
lambda ctx,x: create_kernel(x, UOp.new_buffer(x.device, x.size, x.dtype)) if x in ctx.realizes else None),
# walk back the local graph until we reach a buffer/assign parent
(UPat(Ops.KERNEL, name="x"), append_to_kernel),
# remove CONST/BIND from SINK
@@ -260,13 +269,9 @@ create_kernels = merge_views+PatternMatcher([
# ** create buffer ops + enumerate buffers
# Rewrite callback: replace buffer uop x with a LOAD from a DEFINE_GLOBAL whose arg is x's index in ctx.
# ctx is mutated in place: x is appended only if it is not already present, so repeated buffers share one index.
def load_buf(ctx:list[UOp], x:UOp):
if x not in ctx: ctx.append(x)
# the LOAD keeps x's dtype and uses x's own ShapeTracker (unwrap is applied to x.st before to_uop)
return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop()))
add_buffer_ops = PatternMatcher([
# LOAD
(UPat(Ops.BUFFER, name="x"), load_buf),
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), x.st.to_uop()))),
# STORE (except for COPY/BUFFER_VIEW)
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
@@ -278,8 +283,9 @@ add_buffer_ops = PatternMatcher([
# Rewrite u with the view_left patterns and return the result.
def apply_swizzle(u:UOp) -> UOp:
# TRACK_MATCH_STATS is set to 0 only for the duration of this rewrite
with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
input_st = ShapeTracker.from_shape(unwrap(src.st).shape)
def swizzle_reduceop(r:UOp, src:UOp, view:UOp):
if (st:=unwrap(view.st)).contiguous: return None
input_st = ShapeTracker.from_shape(src.shape)
tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
strides = strides_for_shape(rshape)
@@ -290,20 +296,18 @@ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
# Variant of reduceop_view_right with signature (r, v, src); a second definition with the same name follows it in this diff.
# Rebuilds the reduce directly on src and places a view of the reduced shape on top of the result.
# Raises AssertionError when v's ShapeTracker is not contiguous or v and src differ in size.
def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp:
if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
# shape of the view after reducing over r's axes
output_shape = swizzle_st.reduce(r.axis_arg)
# reduce src over every position where src.shape and output_shape differ (zip stops at the shorter shape), then view as output_shape
return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))
# Variant of reduceop_view_right with signature (src, v, r); a definition with the same name precedes it in this diff.
# Rebuilds the reduce r directly on src and places a view of r.shape on top of the result.
def reduceop_view_right(src:UOp, v:UOp, r:UOp):
# precondition: v's ShapeTracker is contiguous and v has the same size as src
# NOTE(review): this is an assert, so the check is skipped when Python runs with -O
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
# reduce src over every position where src.shape and r.shape differ (zip stops at the shorter shape), then view as r.shape
return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u)).view(ShapeTracker.from_shape(r.shape))
def elementwise_view_right(root:UOp) -> UOp|None:
if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None
assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}"
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW]): return None
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
# push the swizzle from src to root
output_swizzle = swizzles[0]
new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
return ret.view(ShapeTracker.from_shape(output_swizzle.shape))
# place view after applying the elementwise op
new_shape = swizzles[0].base.shape
ret = root.replace(src=tuple(x.base if x.base.shape == new_shape else apply_swizzle(x.view(ShapeTracker.from_shape(new_shape))) for x in root.src))
# reshape to match downstream shapes
return ret.reshape(root.shape)
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
@@ -317,12 +321,12 @@ view_right = merge_views+PatternMatcher([
lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
# STORE is the last child, so we just merge the ShapeTrackers and store the base
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
# REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
# REDUCE(src.view()) -> REDUCE(src).view()
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right),
# ALU(src.view()) -> ALU(src).view()
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right),
# push a non contiguous ShapeTracker through reduceop
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
# apply view after reduceops
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat.var("src"),), name="v"),), name="r"), reduceop_view_right),
# apply view after elementwise ops
(UPat(GroupOp.All-GroupOp.Buffer, name="root"), elementwise_view_right),
# double reduce op collapses to a single reduce op
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
@@ -369,10 +373,10 @@ fix_kernel_ops = PatternMatcher([
def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
assert k.op is Ops.KERNEL, f"kernel isn't kernel, it's {k}"
# substitute kernel sources for the target buffer
ast = k.arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in k.src if s.op is Ops.ASSIGN}).sink()
# substitute kernel sources for the target buffer + apply reshapes
ast = k.arg.ast.substitute({(ast:=s.src[1].arg.ast):s.src[0].view(unwrap(ast.st)) for s in k.src if s.op is Ops.ASSIGN}).sink()
# add buffer ops
ast = graph_rewrite(ast, add_buffer_ops, bufs:=[s.buf_uop for s in k.src], bottom_up=True)
ast = graph_rewrite(ast, add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True)
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
# unbind_vars + push views to edges
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
@@ -417,9 +421,11 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
# map tensors to buffer/const, optionally apply a VIEW on top
becomes_map: dict[UOp, UOp] = {}
for k,v in tensor_map.items():
# if we created a KERNEL for this tensor, map it to the assigned buffer
if (a:=kernel_map.get(v.base)) is not None and a.op is Ops.ASSIGN:
becomes_map[k] = a.src[0] if v is v.base else a.src[0].view(unwrap(v.st))
# ASSIGN always becomes the target buffer
if v.op is Ops.ASSIGN: becomes_map[k] = v.src[0]
# if we created a new buffer for this tensor, map it to the assigned buffer
elif (a:=kernel_map.get(v.base)) is not None and (a:=a.base).op is Ops.ASSIGN:
becomes_map[k] = a.src[0] if a.src[0].st == v.st else a.src[0].view(unwrap(v.st))
# tensors can also simplify to an existing buffer/const
else:
if k is v: continue
@@ -463,8 +469,8 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
# TODO: move this to create_kernels
k = fix_kernel_ast(u.src[1], var_vals)
schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
# increment the refcount of the target buf (this is required by the JIT and memory planner)
u.buf_uop.buffer.ref(1)
# increment the refcount of the target buf (this is required by the JIT and memory planner) TODO: this does not belong here
k.src[0].buffer.ref(1)
for x in children.get(u, []):
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(x)

View File

@@ -112,7 +112,7 @@ FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
DONT_REALIZE_EXPAND = ContextVar("DONT_REALIZE_EXPAND", 0)
DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
@dataclass(frozen=True)
class Metadata:

View File

@@ -292,7 +292,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ShapeTracker.from_shape(
tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
if self.op in {Ops.BUFFER, Ops.BUFFER_VIEW}: return ShapeTracker.from_shape((self.size,))
if self.op is Ops.KERNEL: return ShapeTracker.from_shape(self.arg.ast.shape)
if self.op is Ops.KERNEL: return ShapeTracker.from_shape((self.arg.ast.size,))
# these ops define a ShapeTracker from the arg
if self.op is Ops.VIEW: return self.arg
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)

View File

@@ -584,7 +584,8 @@ class AMDDevice(HCQCompiled):
sgrp_size_per_cu, lds_size_per_cu, hwreg_size_per_cu = 0x4000, 0x10000, 0x1000
vgpr_size_per_cu = 0x60000 if self.target in {110000, 110001, 120000, 120001} else 0x40000
wg_data_size = round_up((vgpr_size_per_cu + sgrp_size_per_cu + lds_size_per_cu + hwreg_size_per_cu) * (self.max_cu_id + 1), mmap.PAGESIZE)
ctl_stack_size = round_up(12 * (self.max_cu_id + 1) * (self.max_wave_id + 1) + 8 + 40, mmap.PAGESIZE) if self.target//10000 != 10 else 0x7000
ctl_stack_size = round_up(12 * (self.max_cu_id + 1) * (self.max_wave_id + 1) + 8 + 40, mmap.PAGESIZE)
if self.target//10000 == 10: ctl_stack_size = min(ctl_stack_size, 0x7000)
debug_memory_size = round_up((self.max_cu_id + 1) * (self.max_wave_id + 1) * 32, 64)
self.compute_queue = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_COMPUTE, 0x800000, ctx_save_restore_size=wg_data_size + ctl_stack_size,

View File

@@ -1,7 +1,7 @@
from typing import cast
from tinygrad.ops import PatternMatcher, UPat, GroupOp, Ops, UOp, print_uops
from tinygrad.dtype import DType, ImageDType, dtypes, PtrDType
from tinygrad.helpers import all_same, all_int, dedup, prod
from tinygrad.helpers import all_same, dedup, prod
buffer_spec = PatternMatcher([
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
@@ -9,7 +9,7 @@ buffer_spec = PatternMatcher([
(UPat(Ops.BUFFER, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="buf"),
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all_int(buf_view.arg)),
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
])
# *** this is the spec of a Tensor in UOp ***
@@ -126,10 +126,10 @@ spec = PatternMatcher([
# *** this is the spec of a Kernel in UOp ***
kernel_spec = buffer_spec+PatternMatcher([
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.ASSIGN))), lambda: True),
# assign has a buffer view and kernel source, it can optionally depend on other assigns
(UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}), lambda: False),
(UPat(Ops.KERNEL, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.ASSIGN))), lambda: True),
# assign has a buffer and kernel source, it can optionally depend on other assigns
(UPat(Ops.ASSIGN, src=UPat((Ops.BUFFER, Ops.BUFFER_VIEW, Ops.KERNEL, Ops.ASSIGN))), lambda: True),
(UPat(GroupOp.All-{Ops.SINK}), lambda: False),
])
# *** this is the UOp shape spec ***