mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Fixing the layouts in the tests for Wave64
This commit is contained in:
@@ -1840,8 +1840,22 @@ class BlockedLayout:
|
||||
def __str__(self):
|
||||
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
|
||||
|
||||
|
||||
layouts = [
|
||||
if torch.version.hip is not None:
|
||||
layouts = [
|
||||
# MmaLayout(version=1, warps_per_cta=[1, 4]),
|
||||
# MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
|
||||
# MmaLayout(version=1, warps_per_cta=[4, 1]),
|
||||
# MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
|
||||
BlockedLayout([1, 4], [2, 32], [2, 2], [1, 0]),
|
||||
BlockedLayout([2, 2], [4, 16], [2, 2], [1, 0]),
|
||||
BlockedLayout([1, 1], [1, 64], [2, 2], [1, 0]),
|
||||
BlockedLayout([4, 2], [16, 4], [1, 4], [0, 1]),
|
||||
BlockedLayout([4, 2], [8, 8], [2, 2], [0, 1]),
|
||||
BlockedLayout([1, 1], [32, 2], [2, 2], [0, 1]),
|
||||
BlockedLayout([4, 2], [1, 64], [4, 1], [1, 0])
|
||||
]
|
||||
else:
|
||||
layouts = [
|
||||
# MmaLayout(version=1, warps_per_cta=[1, 4]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
|
||||
# MmaLayout(version=1, warps_per_cta=[4, 1]),
|
||||
@@ -1855,7 +1869,6 @@ layouts = [
|
||||
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("shape", [(128, 128)])
|
||||
@pytest.mark.parametrize("dtype", ['float16'])
|
||||
@pytest.mark.parametrize("src_layout", layouts)
|
||||
|
||||
Reference in New Issue
Block a user