mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
@@ -3,7 +3,7 @@ from tinygrad.uop.ops import UOp, Ops, KernelInfo, graph_rewrite, AxisType
 from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
 from tinygrad.dtype import AddrSpace
 from tinygrad.schedule.kernelize import merge_views
-from tinygrad.helpers import getenv
+from tinygrad.helpers import getenv, colored
 from tinygrad.shape.shapetracker import ShapeTracker
 
 N = 4096
@@ -22,9 +22,9 @@ def hl_spec_kernel3():
 
   # define buffers
   # TODO: remove these views once the defines have a shape
-  a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N*N,)))
-  b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N*N,)))
-  c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2).view(ShapeTracker.from_shape((N*N,)))
+  a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N,N)))
+  b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N,N))).permute((1,0))
+  c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2).view(ShapeTracker.from_shape((N,N)))
   As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0).view(ShapeTracker.from_shape((BK*BM,)))
   Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1).view(ShapeTracker.from_shape((BK*BN,)))
   A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((nbIterWaveM * TM,)))
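The switch from flat (N*N,) views to 2D (N,N) views, with b additionally permuted, makes b read as its transpose without moving any data. A minimal sketch (not part of this diff) showing that permute((1,0)) on a ShapeTracker only swaps strides:

    from tinygrad.shape.shapetracker import ShapeTracker

    st = ShapeTracker.from_shape((4, 4))
    print(st.views[-1].shape, st.views[-1].strides)    # (4, 4) (4, 1): row-major
    stT = st.permute((1, 0))
    print(stT.views[-1].shape, stT.views[-1].strides)  # (4, 4) (1, 4): same buffer, transposed access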
@@ -39,12 +39,19 @@ def hl_spec_kernel3():
   Bs = Bs.reshape((1, 1, 1, 1, 1, nbIterWaveN, BN//(nbIterWaveN * TN), TN, 1, BK)).expand(full_shape)
   A_col = A_col.reshape((1, nbIterWaveM, 1, TM, 1, 1, 1, 1, 1, 1)).expand(full_shape)
   B_row = B_row.reshape((1, 1, 1, 1, 1, nbIterWaveN, 1, TN, 1, 1)).expand(full_shape)
-  out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9))
-  axis_types = [
+
+  #out = (a.load() * b.load()).r(Ops.ADD, (8, 9))
+  out = (As.load(As.store(a.load())) * Bs.load(Bs.store(b.load()))).r(Ops.ADD, (8, 9))
+  #out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9))
+
+  axis_types = (
     AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
     AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
-    AxisType.REDUCE, AxisType.UNROLL]
-  sink = c.store(out).sink(arg=KernelInfo(name="tinygemm", axis_types=tuple(axis_types)))
+    AxisType.REDUCE, AxisType.UNROLL)
+
+  from tinygrad.opt.kernel import axis_colors
+  shape = '_'.join([colored(str(s), axis_colors[at]) for s,at in zip(full_shape, axis_types)])
+  sink = c.store(out).sink(arg=KernelInfo(name="tg_"+shape, axis_types=axis_types))
   sink = graph_rewrite(sink, merge_views)
   return sink
 
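The kernel is now named after its colored shape rather than the fixed "tinygemm". A hedged sketch of how that name string is assembled, mirroring the diff's use of axis_colors and colored; the sizes below are hypothetical stand-ins, not the kernel's real full_shape:

    from tinygrad.uop.ops import AxisType
    from tinygrad.helpers import colored
    from tinygrad.opt.kernel import axis_colors

    full_shape = (64, 4, 16, 4, 64, 4, 16, 4, 512, 8)  # hypothetical sizes, one per axis
    axis_types = (AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
                  AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
                  AxisType.REDUCE, AxisType.UNROLL)
    # each axis size gets the ANSI color of its AxisType, joined with underscores
    print('_'.join(colored(str(s), axis_colors[at]) for s, at in zip(full_shape, axis_types)))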
test/test_define_reg.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+import unittest
+from tinygrad import dtypes, Device, Tensor, Context
+from tinygrad.dtype import AddrSpace
+from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.engine.realize import get_program, ExecItem, CompiledRunner
+
+class TestDefineReg(unittest.TestCase):
+  def test_simple(self):
+    N = 16
+    bout = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N,N)))
+    a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N,N)))
+    a_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(N, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((1,N)))
+
+    out = a_col.load(a_col.store(a.load()))
+    sink = bout.store(out).sink(arg=KernelInfo(name="regcopy", axis_types=(AxisType.LOOP, AxisType.UPCAST)))
+    prg = get_program(sink, Device.default.renderer)
+
+    with Context(DEBUG=0):
+      a = Tensor.randn(N, N).realize()
+      b = Tensor.empty(N, N).realize()
+    hrunner = CompiledRunner(prg)
+    ExecItem(hrunner, [b.uop.buffer, a.uop.buffer]).run(wait=True)
+    with Context(DEBUG=0):
+      self.assertEqual((b-a).mean().item(), 0.0)
+
+if __name__ == '__main__':
+  unittest.main()
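The pattern under test, a_col.load(a_col.store(a.load())), reads the register buffer only after the store into it, so the dependency is explicit in the UOp graph. A small sketch (same APIs as the test above, not part of this diff) that prints that graph:

    from tinygrad import dtypes
    from tinygrad.dtype import AddrSpace
    from tinygrad.uop.ops import UOp, Ops
    from tinygrad.shape.shapetracker import ShapeTracker

    N = 16
    a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N,N)))
    a_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(N, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((1,N)))
    out = a_col.load(a_col.store(a.load()))
    for u in out.toposort(): print(u.op)  # the STORE is a dependency, so it precedes the final LOAD

Note the exact-equality assert in the test is safe despite floats: the kernel is a bit-exact copy, so (b-a).mean() is exactly 0.0.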
@@ -56,8 +56,9 @@ def add_gpudims(ctx:Renderer, s:UOp):
   if not ki.global_dims and not ki.local_dims: return None
   s_topo = list(s.toposort())
   if any(x.op is Ops.SPECIAL for x in s_topo): return None
-  ranges = sorted([x for x in s_topo if x.op is Ops.RANGE and x.arg in (ki.global_dims+ki.local_dims)], key=lambda x: x.arg)
-  if not len(ranges): return None
+  all_ranges = {x.arg:x for x in s_topo if x.op is Ops.RANGE}
+  # NOTE: this supports globals/locals in any position
+  ranges = [all_ranges[r] for r in ki.global_dims+ki.local_dims]
   global_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg in ki.global_dims])
   local_shape = tuple([ssimplify(r.src[0]) for r in ranges if r.arg in ki.local_dims])
   if ki.dont_use_locals:
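The replacement indexes ranges by their arg instead of filtering and sorting, which is what allows globals/locals to sit at any position in the axis list. A toy sketch of the lookup in plain Python (stand-in class and hypothetical dim indices, no tinygrad types):

    class R:  # stand-in for a RANGE uop; only .arg matters here
      def __init__(self, arg): self.arg = arg

    s_topo = [R(3), R(0), R(2), R(1)]       # ranges in whatever order the graph yields
    global_dims, local_dims = (0, 2), (1,)  # hypothetical dim indices

    all_ranges = {x.arg: x for x in s_topo}
    ranges = [all_ranges[r] for r in global_dims + local_dims]
    print([r.arg for r in ranges])  # [0, 2, 1]: globals then locals, regardless of source order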