mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
permute locals for HL uop matmul (#11412)
* permute locals for HL uop matmul * parens fix that * permutes * 20 TFLOPS
This commit is contained in:
@@ -63,8 +63,8 @@ def hl_spec_kernel3():
|
||||
a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N,N)))
|
||||
b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2).view(ShapeTracker.from_shape((N,N))).permute((1,0))
|
||||
c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N,N)))
|
||||
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0).view(ShapeTracker.from_shape((BK*BM,)))
|
||||
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1).view(ShapeTracker.from_shape((BK*BN,)))
|
||||
As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0).view(ShapeTracker.from_shape((BK, BM))).permute((1,0))
|
||||
Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1).view(ShapeTracker.from_shape((BK, BN))).permute((1,0))
|
||||
A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((nbIterWaveM * TM,)))
|
||||
B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1).view(ShapeTracker.from_shape((nbIterWaveN * TN,)))
|
||||
|
||||
@@ -78,14 +78,36 @@ def hl_spec_kernel3():
|
||||
A_col = A_col.reshape((1, nbIterWaveM, 1, TM, 1, 1, 1, 1, 1, 1)).expand(full_shape)
|
||||
B_row = B_row.reshape((1, 1, 1, 1, 1, nbIterWaveN, 1, TN, 1, 1)).expand(full_shape)
|
||||
|
||||
# U1 L2 L3 L4 L5 U6 U7 U9 L10 L11 L12 L13 U14 U15 U17 U18 U19
|
||||
expanded_shape = (32, 2, 2, 2, 2, 2, 2, 2, 32, 2, 2, 2, 2, 2, 2, 2, 512, 2, 2, 2)
|
||||
assert len(expanded_shape) == 20
|
||||
permute_a = list(range(len(expanded_shape)))
|
||||
permute_b = permute_a[:]
|
||||
|
||||
# this makes all the global loads match
|
||||
# this can also be more simply done by rebinding the RANGEs
|
||||
permute_a[17:20] = [11,12,13]
|
||||
permute_a[11:14] = [17,18,19]
|
||||
permute_a[7], permute_a[10] = permute_a[10], permute_a[7]
|
||||
permute_a[2:7] = [3,4,5,6,2]
|
||||
|
||||
permute_b[2:16] = [19,9,10,11,17,18,8,2,12,13,14,15,3,4]
|
||||
permute_b[17:20] = [5,6,7]
|
||||
|
||||
a_permute = a.reshape(expanded_shape).permute(tuple(permute_a)).reshape(full_shape)
|
||||
As_permute = As.reshape(expanded_shape).permute(tuple(permute_a)).reshape(full_shape)
|
||||
|
||||
b_permute = b.reshape(expanded_shape).permute(tuple(permute_b)).reshape(full_shape)
|
||||
Bs_permute = Bs.reshape(expanded_shape).permute(tuple(permute_b)).reshape(full_shape)
|
||||
|
||||
#out = (a.load() * b.load()).r(Ops.ADD, (8, 9))
|
||||
out = (As.load(As.store(a.load())) * Bs.load(Bs.store(b.load()))).r(Ops.ADD, (8, 9))
|
||||
out = (As.load(As_permute.store(a_permute.load())) * Bs.load(Bs_permute.store(b_permute.load()))).r(Ops.ADD, (8, 9))
|
||||
#out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9))
|
||||
|
||||
axis_types = (
|
||||
AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
|
||||
AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST,
|
||||
AxisType.REDUCE, AxisType.UNROLL)
|
||||
AxisType.REDUCE, AxisType.REDUCE)
|
||||
|
||||
sink = c.store(out).sink(arg=KernelInfo(name="tg_"+to_colored(full_shape, axis_types), axis_types=axis_types))
|
||||
sink = graph_rewrite(sink, merge_views)
|
||||
|
||||
Reference in New Issue
Block a user