mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] no longer uses shared mem or barriers for single-warp reductions (#1915)
0-byte shared-memory buffers no longer materialize empty allocation buffers, which could lead to unnecessary barriers. Note: the ReduceOp code has become quite messy and will require some cleanup.
This commit is contained in:
@@ -155,30 +155,30 @@ def test_elementwise(N, dtype_str):
|
||||
|
||||
flash_attention_data = {
|
||||
"a100": {
|
||||
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.424,
|
||||
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.379,
|
||||
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.098,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.201,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.199,
|
||||
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.087,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.240,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.210,
|
||||
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.061,
|
||||
(4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.135,
|
||||
(4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.433,
|
||||
(4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.392,
|
||||
(4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.106,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.204,
|
||||
(4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202,
|
||||
(4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.089,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.242,
|
||||
(4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.248,
|
||||
(4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.069,
|
||||
(4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.136,
|
||||
(4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135,
|
||||
(4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.052,
|
||||
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.424,
|
||||
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.378,
|
||||
(4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.099,
|
||||
(4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.262,
|
||||
(4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.254,
|
||||
(4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.125,
|
||||
(4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.238,
|
||||
(4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.211,
|
||||
(4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.062,
|
||||
(4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.158,
|
||||
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.134,
|
||||
(4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.075,
|
||||
(4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.432,
|
||||
(4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.392,
|
||||
(4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.107,
|
||||
(4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.265,
|
||||
(4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.257,
|
||||
(4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.128,
|
||||
(4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.242,
|
||||
(4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.248,
|
||||
(4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.069,
|
||||
(4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159,
|
||||
(4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.138,
|
||||
(4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.076,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ def check_type_supported(dtype, device):
|
||||
class MmaLayout:
|
||||
def __init__(self, version, warps_per_cta):
|
||||
self.version = version
|
||||
self.warps_per_cta = str(warps_per_cta)
|
||||
self.warps_per_cta = warps_per_cta
|
||||
|
||||
def __str__(self):
|
||||
return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>"
|
||||
@@ -127,10 +127,10 @@ class MmaLayout:
|
||||
|
||||
class BlockedLayout:
|
||||
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
|
||||
self.sz_per_thread = str(size_per_thread)
|
||||
self.threads_per_warp = str(threads_per_warp)
|
||||
self.warps_per_cta = str(warps_per_cta)
|
||||
self.order = str(order)
|
||||
self.sz_per_thread = size_per_thread
|
||||
self.threads_per_warp = threads_per_warp
|
||||
self.warps_per_cta = warps_per_cta
|
||||
self.order = order
|
||||
|
||||
def __str__(self):
|
||||
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
|
||||
@@ -2072,7 +2072,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
out_dtype = tl.float16
|
||||
else:
|
||||
out_dtype = tl.float32
|
||||
|
||||
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
|
||||
y_tri, y_tri.stride(0), y_tri.stride(1),
|
||||
w_tri, w_tri.stride(0), w_tri.stride(1),
|
||||
@@ -2087,6 +2086,14 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
CHAIN_DOT=epilogue == 'chain-dot',
|
||||
ALLOW_TF32=allow_tf32,
|
||||
num_warps=num_warps)
|
||||
if epilogue == 'softmax' and (in_dtype != 'float32' or allow_tf32):
|
||||
ptx = pgm.asm["ptx"]
|
||||
start = ptx.find("shfl.sync")
|
||||
end = ptx.find("cvt.rn.f16.f32")
|
||||
red_code = ptx[start:end]
|
||||
assert len(red_code) > 0
|
||||
assert "shared" not in red_code
|
||||
assert "bar.sync" not in red_code
|
||||
# torch result
|
||||
if in_dtype == 'int8':
|
||||
z_ref = np.matmul(x.astype(np.float32),
|
||||
|
||||
Reference in New Issue
Block a user