Fix the layouts in the tests for Wave64

This commit is contained in:
B1tway
2023-03-06 18:35:18 +00:00
parent 742295e8a5
commit 625a99aa78

View File

@@ -1840,8 +1840,22 @@ class BlockedLayout:
def __str__(self):
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
layouts = [
if torch.version.hip is not None:
layouts = [
# MmaLayout(version=1, warps_per_cta=[1, 4]),
# MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
# MmaLayout(version=1, warps_per_cta=[4, 1]),
# MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
BlockedLayout([1, 4], [2, 32], [2, 2], [1, 0]),
BlockedLayout([2, 2], [4, 16], [2, 2], [1, 0]),
BlockedLayout([1, 1], [1, 64], [2, 2], [1, 0]),
BlockedLayout([4, 2], [16, 4], [1, 4], [0, 1]),
BlockedLayout([4, 2], [8, 8], [2, 2], [0, 1]),
BlockedLayout([1, 1], [32, 2], [2, 2], [0, 1]),
BlockedLayout([4, 2], [1, 64], [4, 1], [1, 0])
]
else:
layouts = [
# MmaLayout(version=1, warps_per_cta=[1, 4]),
MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
# MmaLayout(version=1, warps_per_cta=[4, 1]),
@@ -1855,7 +1869,6 @@ layouts = [
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
]
@pytest.mark.parametrize("shape", [(128, 128)])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)