mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-07 22:23:55 -05:00
break swizzle into three chunks [pr] (#11153)
* break swizzle into three chunks [pr]
* test failed
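This PR changes the per-operand swizzle from two index chunks, (local_swizzle, reduce_upcast_swizzle), to three, (local_swizzle, reduce_swizzle, upcast_swizzle). A minimal sketch of the shape change, using the CUDA 8x16x16 values that appear in the diff below (the variable names here are illustrative, not from the library):

# before: two chunks per operand -> (local_swizzle, reduce_upcast_swizzle)
old_swizzle = (((6,7,2,3,4), (0,1,9,5,10,8)),
               ((6,7,9,0,1), (2,3,4,10,5,8)))

# after: three chunks per operand -> (local_swizzle, reduce_swizzle, upcast_swizzle)
new_swizzle = (((6,7,2,3,4), (0,1,9,5), (10,8)),
               ((6,7,9,0,1), (2,3,4,10), (5,8)))

# the split is purely structural: concatenating reduce + upcast recovers the old second chunk
assert all(o[1] == n[1] + n[2] for o, n in zip(old_swizzle, new_swizzle))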
@@ -35,7 +35,7 @@ class TestBenchLog(unittest.TestCase):
     self.assertGreater(_events[event]["wall"][0], 0)
     self.assertGreater(_events[event]["wall"][1], 0)

-  @skipIf(CI and Device.DEFAULT == "CUDA", "ci cuda timing is not accurate")
+  @skipIf(CI, "ci timing is not accurate")
   def test_log_single_kernel_time(self):
     wall_times = []

@@ -464,11 +464,12 @@ class Kernel:
     def get_upcast_axes(buf): # upcast along non-zero dimensions of (tc_reduce + tc_upcast)
       upcast_axes = int(math.log2(tc.elements_per_thread[buf]))
       return tuple((tcd + len(tc.get_reduce_axes()) + len(tc.get_upcast_axes()) - (i+1), 2) for i in range(upcast_axes))
-    def get_tc_swizzle_st(shape, local_perm, upcast_perm):
+    def get_tc_swizzle_st(shape, local_perm, reduce_perm, upcast_perm):
+      ru_perm = reduce_perm + upcast_perm
       offset = (tcd - (wd + len(local_perm)))
       permaxis = list(range(wd)) \
         + [wd + x + (offset if x >= len(local_perm) else 0) for x in local_perm] + list(range(wd + len(local_perm), tcd)) \
-        + [wd + x + (offset if x >= len(local_perm) else 0) for x in upcast_perm] + list(range(tcd + len(upcast_perm), len(shape)))
+        + [wd + x + (offset if x >= len(local_perm) else 0) for x in ru_perm] + list(range(tcd + len(ru_perm), len(shape)))
       return ShapeTracker.from_shape(shape).permute(tuple(permaxis))

     srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
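The kernel-side change is mechanical: get_tc_swizzle_st now receives the reduce and upcast permutations as separate arguments and immediately concatenates them, so the permutation it builds is unchanged. Below is a standalone sketch of that axis bookkeeping, not the kernel's own helper; toy_tc_swizzle_perm and the wd/tcd/shape values are made up for illustration:

def toy_tc_swizzle_perm(shape, wd, tcd, local_perm, reduce_perm, upcast_perm):
  # three chunks collapse back into the old two-chunk form
  ru_perm = reduce_perm + upcast_perm
  # perm indices below len(local_perm) pick the local dims right after the wd global dims;
  # larger indices pick dims starting at tcd, hence the offset
  offset = tcd - (wd + len(local_perm))
  permaxis = list(range(wd)) \
    + [wd + x + (offset if x >= len(local_perm) else 0) for x in local_perm] + list(range(wd + len(local_perm), tcd)) \
    + [wd + x + (offset if x >= len(local_perm) else 0) for x in ru_perm] + list(range(tcd + len(ru_perm), len(shape)))
  return tuple(permaxis)

# toy shape: 1 global dim, 2 local dims, 1 untouched dim, then 3 reduce/upcast dims
perm = toy_tc_swizzle_perm(shape=(4,2,2,2,2,2,2), wd=1, tcd=4,
                           local_perm=(1,0), reduce_perm=(3,2), upcast_perm=(4,))
assert perm == (0, 2, 1, 3, 5, 4, 6)           # a real shuffle of the 7 axes
assert sorted(perm) == list(range(len(perm)))  # and still a valid permutation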
@@ -11,7 +11,8 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
   dtype_in: DType # dtype for A and B
   dtype_out: DType # dtype for C and D
   opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifying kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
-  swizzle: tuple[tuple[tuple[int, ...], tuple[int, ...]], tuple[tuple[int, ...], tuple[int, ...]]] # (local_swizzle, reduce_upcast_swizzle)
+  # (local_swizzle, reduce_swizzle, upcast_swizzle)
+  swizzle: tuple[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]], tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]]
   def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))]
   def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"]
   def get_local_axes(self): return [opt for opt in self.opts if opt[0] == "l"]
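For concreteness, here is what the three axis getters evaluate to for the CUDA 8x16x16 configuration defined further down (dims=(8,16,16), opts=cuda_tc_opts). This is a self-contained worked example rather than library code:

import math

dims = (8, 16, 16)
opts = ("u0", "l0", "l0", "l1", "l1", "l1", "u1")   # cuda_tc_opts from the diff below

reduce_axes = [(i, 2) for i in range(int(math.log2(dims[2])))]  # mirrors get_reduce_axes
upcast_axes = [opt for opt in opts if opt[0] == "u"]            # mirrors get_upcast_axes
local_axes  = [opt for opt in opts if opt[0] == "l"]            # mirrors get_local_axes

assert reduce_axes == [(0, 2), (1, 2), (2, 2), (3, 2)]          # 4 reduce axes
assert upcast_axes == ["u0", "u1"]                              # 2 upcast axes
assert local_axes == ["l0", "l0", "l1", "l1", "l1"]             # 5 local axes
# so each of the three swizzle chunks must hold 5, 4 and 2 indices respectively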
@@ -25,9 +26,11 @@ class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x
     assert 2**upcast_axes == self.elements_per_thread[2], \
       f"{self.elements_per_thread[2]} elements from C are processed per thread but found {2**upcast_axes} in {self.opts}"
     # check swizzle
+    assert len(self.swizzle[0]) == 3 and len(self.swizzle[1]) == 3, "swizzle has wrong part count"
     assert len(self.swizzle[0][0]) == len(self.swizzle[1][0]) == local_axes, "local swizzle size is wrong"
-    assert len(self.swizzle[0][1]) == len(self.swizzle[1][1]) == reduce_axes + upcast_axes, "reduce/upcast swizzle size is wrong"
-    assert all(sorted(s[0] + s[1]) == list(range(local_axes + reduce_axes + upcast_axes)) for s in self.swizzle), "swizzle missing some dims"
+    assert len(self.swizzle[0][1]) == len(self.swizzle[1][1]) == reduce_axes, "reduce swizzle size is wrong"
+    assert len(self.swizzle[0][2]) == len(self.swizzle[1][2]) == upcast_axes, "upcast swizzle size is wrong"
+    assert all(sorted(s[0] + s[1] + s[2]) == list(range(local_axes + reduce_axes + upcast_axes)) for s in self.swizzle), "swizzle missing some dims"

 # ***** NVIDIA *****

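Read as a standalone invariant, the tightened checks say: each of the two per-operand permutations now comes in exactly three chunks whose lengths equal the local, reduce and upcast axis counts, and together the chunks must cover every dimension exactly once. A sketch with a hypothetical check_swizzle helper (not part of the library), exercised on the new CUDA 8x16x16 swizzle from the hunk below:

def check_swizzle(swizzle, local_axes, reduce_axes, upcast_axes):
  # same conditions as the asserts above, restated over a bare tuple
  assert len(swizzle[0]) == 3 and len(swizzle[1]) == 3, "swizzle has wrong part count"
  assert len(swizzle[0][0]) == len(swizzle[1][0]) == local_axes, "local swizzle size is wrong"
  assert len(swizzle[0][1]) == len(swizzle[1][1]) == reduce_axes, "reduce swizzle size is wrong"
  assert len(swizzle[0][2]) == len(swizzle[1][2]) == upcast_axes, "upcast swizzle size is wrong"
  assert all(sorted(s[0] + s[1] + s[2]) == list(range(local_axes + reduce_axes + upcast_axes))
             for s in swizzle), "swizzle missing some dims"

# new CUDA 8x16x16 swizzle: 5 local, 4 reduce, 2 upcast axes
check_swizzle((((6,7,2,3,4),(0,1,9,5),(10,8)), ((6,7,9,0,1),(2,3,4,10),(5,8))), 5, 4, 2)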
@@ -35,12 +38,12 @@ cuda_tc_opts = ("u0","l0","l0","l1","l1","l1","u1") # shared by all shapes with

 # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
 cuda_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
-  swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float),
-  (dtypes.half,dtypes.half)]]
+  swizzle=(((6,7,2,3,4),(0,1,9,5),(10,8)), ((6,7,9,0,1),(2,3,4,10),(5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float),
+  (dtypes.half,dtypes.half)]]
 cuda_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=di, dtype_out=do, opts=cuda_tc_opts,
-  swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
+  swizzle=(((6,7,2,3,4),(0,1,8),(5,9)), ((6,7,8,0,1),(2,3,4),(9,5)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.half,dtypes.half)]]
 cuda_8168_tf32 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.float, dtype_out=dtypes.float, opts=cuda_tc_opts,
-  swizzle=(((5,6,2,3,4),(0,1,8,9,7)), ((5,6,8,0,1),(2,3,4,9,7))))]
+  swizzle=(((5,6,2,3,4),(0,1,8),(9,7)), ((5,6,8,0,1),(2,3,4),(9,7))))]

 cuda_sm80: list[TensorCore] = cuda_81616 + cuda_8168_f16
 if getenv("ALLOW_TF32", 0): cuda_sm80 += cuda_8168_tf32
@@ -50,30 +53,30 @@ cuda_sm75: list[TensorCore] = cuda_8168_f16

 # https://gpuopen.com/learn/wmma_on_rdna3/
 amd_rdna3 = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
-  opts=("l0","l0","l0","l0","l1","u1","u1","u1"), swizzle=(((4,9,10,11,0),(1,2,3,5,6,7,8)), ((0,1,2,3,4),(9,10,11,5,6,7,8))))
+  opts=("l0","l0","l0","l0","l1","u1","u1","u1"), swizzle=(((4,9,10,11,0),(1,2,3,5),(6,7,8)), ((0,1,2,3,4),(9,10,11,5),(6,7,8))))
   for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float)]]
 amd_rdna4 = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(8,8,8), dtype_in=di, dtype_out=do,
-  opts=("l0","l0","l0","l0","u1","u1","u1","l1"), swizzle=(((9,10,11,4,7),(0,1,2,3,5,6,8)),((0,1,2,3,7),(4,9,10,11,5,6,8))))
+  opts=("l0","l0","l0","l0","u1","u1","u1","l1"), swizzle=(((9,10,11,4,7),(0,1,2,3),(5,6,8)),((0,1,2,3,7),(4,9,10,11),(5,6,8))))
   for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]

 # https://gpuopen.com/learn/amd-lab-notes/amd-lab-notes-matrix-cores-readme
 amd_cdna = [TensorCore(dims=(16,16,16), threads=64, elements_per_thread=(4,4,4), dtype_in=di, dtype_out=do,
-  opts=("l0","l0","l0","l0","u1","u1","l1","l1"), swizzle=(((10,11,4,5,8,9),(0,1,2,3,6,7)),((0,1,2,3,8,9),(4,5,10,11,6,7))))
+  opts=("l0","l0","l0","l0","u1","u1","l1","l1"), swizzle=(((10,11,4,5,8,9),(0,1,2,3),(6,7)),((0,1,2,3,8,9),(4,5,10,11),(6,7))))
   for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]

 # ***** Apple Metal *****

 metal = [TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2), dtype_in=di, dtype_out=do, opts=("u0","l0","l1","l1","l0","l1"),
-  swizzle=(((6,1,2,7,4),(8,0,3,5)), ((0,5,6,3,7),(1,2,4,8)))) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),
+  swizzle=(((6,1,2,7,4),(8,0,3),(5,)), ((0,5,6,3,7),(1,2,4),(8,)))) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),
   (dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]

 # ***** Apple AMX *****

 amx = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt,
-  swizzle=(((),(0,1,2,3,4,5,6,7)), ((),(4,5,6,7,0,1,2,3))),
+  swizzle=(((),(),(0,1,2,3,4,5,6,7)), ((),(),(4,5,6,7,0,1,2,3))),
   opts=("u0","u0","u0","u0","u1","u1","u1","u1")) for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]

 # ***** Intel ****

 intel = [TensorCore(dims=(8,8,16), threads=8, elements_per_thread=(16,16,8), dtype_in=dtypes.half, dtype_out=dtypes.float,
-  opts=("l0","l0","l0","u1","u1","u1"), swizzle=(((4,5,6),(0,1,2,3,7,8,9)), ((0,1,2),(7,8,9,3,4,5,6))))]
+  opts=("l0","l0","l0","u1","u1","u1"), swizzle=(((4,5,6),(0,1,2,3),(7,8,9)), ((0,1,2),(7,8,9,3),(4,5,6))))]
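Two edge cases are visible in the tables above: a chunk holding a single axis must stay a tuple, hence the trailing comma in the Metal entry, and a chunk may be empty when a tensor core has no local or reduce axes at all, as in the AMX entry. A small self-contained check with the constants copied from the diff:

metal_swizzle = (((6,1,2,7,4), (8,0,3), (5,)), ((0,5,6,3,7), (1,2,4), (8,)))  # 5 local, 3 reduce, 1 upcast
amx_swizzle   = (((), (), (0,1,2,3,4,5,6,7)), ((), (), (4,5,6,7,0,1,2,3)))    # 0 local, 0 reduce, 8 upcast

for swz, total in [(metal_swizzle, 9), (amx_swizzle, 8)]:
  assert all(sorted(s[0] + s[1] + s[2]) == list(range(total)) for s in swz)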