diff --git a/extra/gemm/amd_uop_matmul.py b/extra/gemm/amd_uop_matmul.py index 22246c31f4..f970680072 100644 --- a/extra/gemm/amd_uop_matmul.py +++ b/extra/gemm/amd_uop_matmul.py @@ -63,8 +63,8 @@ def hl_spec_kernel3(): a = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=1).view(ShapeTracker.from_shape((N,N))) b = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=2).view(ShapeTracker.from_shape((N,N))).permute((1,0)) c = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(N*N), arg=0).view(ShapeTracker.from_shape((N,N))) - As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0).view(ShapeTracker.from_shape((BK*BM,))) - Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1).view(ShapeTracker.from_shape((BK*BN,))) + As = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BM, AddrSpace.LOCAL), arg=0).view(ShapeTracker.from_shape((BK, BM))).permute((1,0)) + Bs = UOp(Ops.DEFINE_LOCAL, dtypes.float.ptr(BK*BN, AddrSpace.LOCAL), arg=1).view(ShapeTracker.from_shape((BK, BN))).permute((1,0)) A_col = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveM * TM, AddrSpace.REG), arg=0).view(ShapeTracker.from_shape((nbIterWaveM * TM,))) B_row = UOp(Ops.DEFINE_REG, dtypes.float.ptr(nbIterWaveN * TN, AddrSpace.REG), arg=1).view(ShapeTracker.from_shape((nbIterWaveN * TN,))) @@ -78,14 +78,36 @@ def hl_spec_kernel3(): A_col = A_col.reshape((1, nbIterWaveM, 1, TM, 1, 1, 1, 1, 1, 1)).expand(full_shape) B_row = B_row.reshape((1, 1, 1, 1, 1, nbIterWaveN, 1, TN, 1, 1)).expand(full_shape) + # U1 L2 L3 L4 L5 U6 U7 U9 L10 L11 L12 L13 U14 U15 U17 U18 U19 + expanded_shape = (32, 2, 2, 2, 2, 2, 2, 2, 32, 2, 2, 2, 2, 2, 2, 2, 512, 2, 2, 2) + assert len(expanded_shape) == 20 + permute_a = list(range(len(expanded_shape))) + permute_b = permute_a[:] + + # this makes all the global loads match + # this can also be more simply done by rebinding the RANGEs + permute_a[17:20] = [11,12,13] + permute_a[11:14] = [17,18,19] + permute_a[7], permute_a[10] = permute_a[10], permute_a[7] + permute_a[2:7] = [3,4,5,6,2] + + permute_b[2:16] = [19,9,10,11,17,18,8,2,12,13,14,15,3,4] + permute_b[17:20] = [5,6,7] + + a_permute = a.reshape(expanded_shape).permute(tuple(permute_a)).reshape(full_shape) + As_permute = As.reshape(expanded_shape).permute(tuple(permute_a)).reshape(full_shape) + + b_permute = b.reshape(expanded_shape).permute(tuple(permute_b)).reshape(full_shape) + Bs_permute = Bs.reshape(expanded_shape).permute(tuple(permute_b)).reshape(full_shape) + #out = (a.load() * b.load()).r(Ops.ADD, (8, 9)) - out = (As.load(As.store(a.load())) * Bs.load(Bs.store(b.load()))).r(Ops.ADD, (8, 9)) + out = (As.load(As_permute.store(a_permute.load())) * Bs.load(Bs_permute.store(b_permute.load()))).r(Ops.ADD, (8, 9)) #out = (A_col.load(A_col.store(As.load(As.store(a.load())))) * B_row.load(B_row.store(Bs.load(Bs.store(b.load()))))).r(Ops.ADD, (8, 9)) axis_types = ( AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST, AxisType.GLOBAL, AxisType.UPCAST, AxisType.LOCAL, AxisType.UPCAST, - AxisType.REDUCE, AxisType.UNROLL) + AxisType.REDUCE, AxisType.REDUCE) sink = c.store(out).sink(arg=KernelInfo(name="tg_"+to_colored(full_shape, axis_types), axis_types=axis_types)) sink = graph_rewrite(sink, merge_views) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 0b14577552..be6472951a 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -47,7 +47,7 @@ base_rewrite = PatternMatcher([ lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"), (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat.var("gate"))).or_casted("bidx"), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"), - (UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"), + (UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"(*{ctx[bidx]})"), (UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"), # alu/gep # TODO: look for left-associative @@ -170,6 +170,7 @@ class CStyleLanguage(Renderer): if u.op in {Ops.ENDIF, Ops.ENDRANGE}: depth -= 1 if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \ + (u.op is Ops.LOAD and cast(PtrDType, u.src[0].dtype).addrspace == AddrSpace.REG) or \ (u.op in {Ops.VECTORIZE, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))): r[u] = l else: