@@ -1,10 +1,13 @@
 from tinygrad import Tensor, Device, Context, GlobalCounters, dtypes
-from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite, AxisType
+from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite, AxisType, PatternMatcher, UPat
 from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
 from tinygrad.dtype import AddrSpace
-from tinygrad.schedule.kernelize import merge_views
-from tinygrad.helpers import getenv, colored
-from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.schedule.kernelize import merge_views, view_left
+from tinygrad.helpers import getenv, colored, prod, unwrap
+from tinygrad.shape.shapetracker import ShapeTracker, View, strides_for_shape
+from tinygrad.opt.kernel import axis_colors
+
+def to_colored(full_shape, axis_types): return '_'.join([colored(str(s), axis_colors[at]) for s,at in zip(full_shape, axis_types)])
 
 N = 4096
 run_count = 5
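A quick sketch of what the new to_colored helper produces; the shape and axis
types below are made-up illustration values. axis_colors maps each AxisType to
the terminal color used by helpers.colored:

  from tinygrad.helpers import colored
  from tinygrad.opt.kernel import axis_colors
  from tinygrad.uop.ops import AxisType
  full_shape, axis_types = (256, 16, 8), (AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE)
  print('_'.join(colored(str(s), axis_colors[at]) for s, at in zip(full_shape, axis_types)))
  # prints "256_16_8", each dim ANSI-colored by its axis type

The kernels below use it to build colored kernel names.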
@@ -16,15 +19,50 @@ BK = 8
 TN = 4
 TM = 4
 
+# NOTE: this is from testgrad
+# change reduceop axes and input ShapeTrackers, view gets replaced with a reshape.
+# src->r->view --> src->view->r
+def swizzle_reduceop(src:UOp, r:UOp, view:UOp):
+  if r.tag is not None: return None
+  # confirm the input is in order
+  # TODO: replace this with a UOp that allows for nothing else then remove this
+  permute = tuple(i for i in range(len(src.shape)) if i not in r.axis_arg)+r.axis_arg
+  assert permute == tuple(range(len(permute))), f"reduce axis must already be in order, {permute} isn't"
+
+  # append the reduce shape to each of the views
+  prshape = prod(rshape:=src.shape[-len(r.axis_arg):])
+  rstrides = strides_for_shape(rshape)
+  nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+rstrides, v.offset*prshape,
+                    v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in unwrap(view.st).views]
+
+  # no reshape required with shrinking REDUCE_AXIS
+  return UOp(Ops.REDUCE_AXIS, r.dtype, (src.view(ShapeTracker(tuple(nv))),),
+             (r.arg[0], tuple(range(len(view.shape), len(view.shape) + len(r.axis_arg)))))
+
+pm = PatternMatcher([
+  (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop),
+])
+
+def top_spec_kernel3():
+  a = Tensor.empty(N,N)
+  b = Tensor.empty(N,N)
+  c = a@b
+  sink = c.schedule()[-1].ast
+  L = 16
+  sink = sink.reshape((N//L, L, N//L, L)) #.lift({0:UOp.range(dtypes.int, N//BM, 0), 2:UOp.range(dtypes.int, N//BN, 1)})
+  sink = graph_rewrite(sink, view_left+pm)
+  axis_types = (AxisType.GLOBAL, AxisType.LOCAL, AxisType.GLOBAL, AxisType.LOCAL, AxisType.REDUCE)
+  return sink.replace(arg=KernelInfo(name="top_"+to_colored(sink.full_shape, axis_types), axis_types=axis_types))
+
 def hl_spec_kernel3():
   nbIterWaveM = 2
   nbIterWaveN = 2
 
   # define buffers
   # TODO: remove these views once the defines have a shape
-  a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N,N)))
-  b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N,N))).permute((1,0))
-  c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2).view(ShapeTracker.from_shape((N,N)))
+  a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N,N)))
+  b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2).view(ShapeTracker.from_shape((N,N))).permute((1,0))
+  c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N,N)))
   As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0).view(ShapeTracker.from_shape((BK*BM,)))
   Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1).view(ShapeTracker.from_shape((BK*BN,)))
   A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((nbIterWaveM * TM,)))
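A minimal numpy analogy of the swizzle in this hunk (shapes chosen for
illustration): reducing and then viewing the result is equivalent to viewing
the input with the reduce shape appended and reducing the now-trailing axes,
which is exactly the src->r->view --> src->view->r swap swizzle_reduceop
performs.

  import numpy as np
  src = np.arange(24).reshape(4, 6)
  # before: src -> REDUCE_AXIS(axis=1) -> VIEW (reshape of the reduced result)
  before = src.sum(axis=1, keepdims=True).reshape(2, 2, 1)
  # after:  src -> VIEW (reduce shape appended) -> REDUCE_AXIS(axis=2)
  after = src.reshape(2, 2, 6).sum(axis=2, keepdims=True)
  assert (before == after).all()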
@@ -49,9 +87,7 @@ def hl_spec_kernel3():
                 AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
                 AxisType.REDUCE, AxisType.UNROLL)
 
-  from tinygrad.opt.kernel import axis_colors
-  shape = '_'.join([colored(str(s), axis_colors[at]) for s,at in zip(full_shape, axis_types)])
-  sink = c.store(out).sink(arg=KernelInfo(name="tg_"+shape, axis_types=axis_types))
+  sink = c.store(out).sink(arg=KernelInfo(name="tg_"+to_colored(full_shape, axis_types), axis_types=axis_types))
   sink = graph_rewrite(sink, merge_views)
   return sink
 
@@ -98,9 +134,9 @@ def hand_spec_kernel3():
   blockIdx_x = UOp(Ops.SPECIAL, dtypes.int, arg=("gidx0", N//BN))
   blockIdx_y = UOp(Ops.SPECIAL, dtypes.int, arg=("gidx1", N//BM))
 
-  a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0)
-  b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1)
-  c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)
+  a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1)
+  b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2)
+  c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0)
 
   A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0)
   B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1)
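Note the buffer renumbering in both spec kernels: the output c moves to arg=0
and the inputs a and b shift to arg=1 and arg=2, matching the convention that
buffer 0 is a kernel's output. The ExecItem buffer list in a later hunk is
reordered to agree.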
@@ -167,9 +203,13 @@ def hand_spec_kernel3():
   return sink.sink(arg=KernelInfo(name="tinygemm"))
 
 if __name__ == "__main__":
-  hprg = hl_spec_kernel3() if getenv("HL") else hand_spec_kernel3()
+  HL = getenv("HL")
+  if HL == 2: hprg = top_spec_kernel3()
+  elif HL == 1: hprg = hl_spec_kernel3()
+  else: hprg = hand_spec_kernel3()
   prg = get_program(hprg, Device.default.renderer)
   print(prg.src)
+  if getenv("SRC"): exit(0)
   hrunner = CompiledRunner(prg)
 
   a = Tensor.randn(N, N).realize()
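getenv returns an int (default 0), so running with HL=2 picks the new
top_spec_kernel3, HL=1 the high-level spec kernel, and the default remains the
hand-written kernel; setting SRC=1 now prints the generated source and exits
before building the CompiledRunner.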
@@ -181,7 +221,8 @@ if __name__ == "__main__":
   for _ in range(run_count): tc = (a@b).realize()
 
   GlobalCounters.reset()
-  ei = ExecItem(hrunner, [a.uop.buffer, b.uop.buffer, hc.uop.buffer])
+  buffers = [hc.uop.buffer, a.uop.buffer, b.uop.buffer]
+  ei = ExecItem(hrunner, buffers)
   with Context(DEBUG=2):
     for _ in range(run_count): ei.run(wait=True)
   err = (hc-tc).square().mean().item()
@@ -52,7 +52,8 @@ def lower_load(ctx: IndexContext, x: UOp, buf: UOp):
   return UOp(Ops.LOAD, x.dtype, (buf.index(idx, valid),) + barrier)
 
 def lower_store(ctx: IndexContext, x: UOp, buf: UOp):
-  assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}"
+  # TODO: reenable after REDUCE_AXIS is fixed
+  #assert x.src[1].shape == x.src[0].shape, f"shape mismatch on store {x.src[1].shape} != {x.src[0].shape}"
   idx, valid = x.st_arg.to_indexed_uops(ctx.idxs)
   if cast(PtrDType, buf.dtype).addrspace == AddrSpace.GLOBAL:
     # NOTE: only store the local reduceop in the threads that are actually doing the reduce
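The remaining two hunks touch tinygrad internals rather than the GEMM script.
Here the store-shape assert in lower_store is commented out (with a TODO to
re-enable it), presumably because the swizzled REDUCE_AXIS introduced above can
leave a store whose value and target shapes temporarily disagree.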
@@ -188,7 +188,7 @@ def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
 
 view_left = merge_views+PatternMatcher([
   # view before elementwise and buffer ops
-  (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.LOAD, Ops.STORE, Ops.VALID}, name="e"),), name="view"),
+  (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.BIND, Ops.LOAD, Ops.STORE, Ops.VALID, Ops.SINK}, name="e"),), name="view"),
    lambda e,view: e.replace(src=tuple(s.view(view.st) for s in e.src))),
   # if there's ones added after reduce, put this before the reduce
   (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), reduce_push_add_ones),
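Adding Ops.SINK to this view_left rule is plausibly what lets top_spec_kernel3
work: sink.reshape(...) wraps the entire kernel AST in a VIEW of a SINK, this
rule pushes the view down onto the sink's sources, and the swizzle_reduceop
pattern can then carry it past the REDUCE_AXIS.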