mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] no longer uses shared mem or barriers for single-warp reductions (#1915)
Zero-byte shared-memory buffers no longer materialize empty allocation buffers, which previously could lead to unnecessary barriers. Note: the reduce-op code has become quite messy and will require some cleanup.
This commit is contained in:
committed by
Ognjen Plavsic
parent
398d2c7dd0
commit
4215086931
@@ -113,7 +113,7 @@ def check_type_supported(dtype):
|
||||
class MmaLayout:
    """Helper that renders a `#triton_gpu.mma` layout attribute as MLIR text.

    Reconstructed from diff residue: the commit replaced
    `self.warps_per_cta = str(warps_per_cta)` with the raw value, so the
    list is stringified only at formatting time in `__str__`.
    """

    def __init__(self, version, warps_per_cta):
        # `version` is indexable: version[0] is versionMajor, version[1] is
        # versionMinor (e.g. a (major, minor) tuple) — TODO confirm at call sites.
        self.version = version
        # Keep the raw sequence; `__str__` relies on its default repr
        # (e.g. [4, 1]) to produce the MLIR attribute text.
        self.warps_per_cta = warps_per_cta

    def __str__(self):
        return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>"
|
||||
@@ -121,10 +121,10 @@ class MmaLayout:
|
||||
|
||||
class BlockedLayout:
    """Helper that renders a `#triton_gpu.blocked` layout attribute as MLIR text.

    Reconstructed from diff residue: the commit dropped the eager `str(...)`
    conversions, storing the raw sequences and formatting them only in
    `__str__`.
    """

    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
        # All four are kept as the raw sequences passed in; `__str__` relies
        # on their default repr (e.g. [1, 4]) for the MLIR attribute text.
        self.sz_per_thread = size_per_thread
        self.threads_per_warp = threads_per_warp
        self.warps_per_cta = warps_per_cta
        self.order = order

    def __str__(self):
        return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
|
||||
@@ -1959,7 +1959,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
out_dtype = tl.float16
|
||||
else:
|
||||
out_dtype = tl.float32
|
||||
|
||||
pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1),
|
||||
y_tri, y_tri.stride(0), y_tri.stride(1),
|
||||
w_tri, w_tri.stride(0), w_tri.stride(1),
|
||||
@@ -1974,6 +1973,14 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
CHAIN_DOT=epilogue == 'chain-dot',
|
||||
ALLOW_TF32=allow_tf32,
|
||||
num_warps=num_warps)
|
||||
if epilogue == 'softmax' and (in_dtype != 'float32' or allow_tf32):
|
||||
ptx = pgm.asm["ptx"]
|
||||
start = ptx.find("shfl.sync")
|
||||
end = ptx.find("cvt.rn.f16.f32")
|
||||
red_code = ptx[start:end]
|
||||
assert len(red_code) > 0
|
||||
assert "shared" not in red_code
|
||||
assert "bar.sync" not in red_code
|
||||
# torch result
|
||||
if in_dtype == 'int8':
|
||||
z_ref = np.matmul(x.astype(np.float32),
|
||||
|
||||
Reference in New Issue
Block a user