[BACKEND] no longer uses shared mem or barriers for single-warp reductions (#1915)

Zero-byte shared memory buffers no longer materialize as empty allocation buffers;
these empty allocations could previously lead to unnecessary barriers.

Note: the ReduceOp code has become quite messy and will require some cleanup.
Author: Philippe Tillet
Date: 2023-07-11 00:23:26 -07:00 (committed by GitHub)
Parent: 7e3ebbc4c8
Commit: 8fe5524c75
6 changed files with 104 additions and 55 deletions
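The new check added to test_dot below asserts this behavior directly: for a single-warp softmax epilogue, the PTX between the shfl.sync warp shuffle and the final f32-to-f16 convert must contain neither "shared" nor "bar.sync". As a standalone illustration of the same idea, here is a minimal sketch; it is not part of this commit, and the kernel name, sizes, and printed checks are made up for illustration.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def row_max_kernel(x_ptr, out_ptr, N: tl.constexpr):
        # One program, one warp: reduce N values using warp shuffles only.
        offs = tl.arange(0, N)
        x = tl.load(x_ptr + offs)
        tl.store(out_ptr, tl.max(x, axis=0))

    x = torch.randn(32, device="cuda")
    out = torch.empty(1, device="cuda")
    # num_warps=1 makes this a single-warp reduction.
    pgm = row_max_kernel[(1,)](x, out, N=32, num_warps=1)
    ptx = pgm.asm["ptx"]
    # After this change, the reduction should rely on shfl.sync and should not
    # round-trip through shared memory or hit a barrier.
    print("shfl.sync" in ptx, "bar.sync" in ptx)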


@@ -155,30 +155,30 @@ def test_elementwise(N, dtype_str):
 flash_attention_data = {
     "a100": {
-        (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.424,
-        (4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.379,
-        (4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.098,
-        (4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.201,
-        (4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.199,
-        (4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.087,
-        (4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.240,
-        (4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.210,
-        (4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.061,
-        (4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.135,
+        (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.433,
+        (4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.392,
+        (4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.106,
+        (4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.204,
+        (4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202,
+        (4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.089,
+        (4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.242,
+        (4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.248,
+        (4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.069,
+        (4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.136,
         (4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
         (4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.052,
-        (4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.424,
-        (4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.378,
-        (4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.099,
-        (4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.262,
-        (4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.254,
-        (4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.125,
-        (4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.238,
-        (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.211,
-        (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.062,
-        (4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.158,
-        (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.134,
-        (4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.075,
+        (4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.432,
+        (4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.392,
+        (4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.107,
+        (4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.265,
+        (4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.257,
+        (4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.128,
+        (4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.242,
+        (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.248,
+        (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.069,
+        (4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159,
+        (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.138,
+        (4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.076,
     }
 }


@@ -119,7 +119,7 @@ def check_type_supported(dtype, device):
 class MmaLayout:
     def __init__(self, version, warps_per_cta):
         self.version = version
-        self.warps_per_cta = str(warps_per_cta)
+        self.warps_per_cta = warps_per_cta

     def __str__(self):
         return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>"
@@ -127,10 +127,10 @@ class MmaLayout:
 class BlockedLayout:
     def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
-        self.sz_per_thread = str(size_per_thread)
-        self.threads_per_warp = str(threads_per_warp)
-        self.warps_per_cta = str(warps_per_cta)
-        self.order = str(order)
+        self.sz_per_thread = size_per_thread
+        self.threads_per_warp = threads_per_warp
+        self.warps_per_cta = warps_per_cta
+        self.order = order

     def __str__(self):
         return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
@@ -2072,7 +2072,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
         out_dtype = tl.float16
     else:
         out_dtype = tl.float32
     pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
                          y_tri, y_tri.stride(0), y_tri.stride(1),
                          w_tri, w_tri.stride(0), w_tri.stride(1),
@@ -2087,6 +2086,14 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
                          CHAIN_DOT=epilogue == 'chain-dot',
                          ALLOW_TF32=allow_tf32,
                          num_warps=num_warps)
+    if epilogue == 'softmax' and (in_dtype != 'float32' or allow_tf32):
+        ptx = pgm.asm["ptx"]
+        start = ptx.find("shfl.sync")
+        end = ptx.find("cvt.rn.f16.f32")
+        red_code = ptx[start:end]
+        assert len(red_code) > 0
+        assert "shared" not in red_code
+        assert "bar.sync" not in red_code
     # torch result
     if in_dtype == 'int8':
         z_ref = np.matmul(x.astype(np.float32),