Initial commit to resolve merge conflicts

Rename `tl.float8e4` to `tl.float8e4nv` to align with upstream

ROCM IFU: Fix Python arch issues

ROCM IFU: Fix kernel launcher

ROCM IFU: Fix merge conflicts

Fix debug build

Set correct threadsPerCTA
This commit is contained in:
Jason Furmanek
2023-09-12 20:43:59 +00:00
parent 74fd8e9754
commit e5d7bb4fae
36 changed files with 414 additions and 1005 deletions

View File

@@ -1005,7 +1005,7 @@ def deserialize_fp8(np_data, in_dtype):
# return np_data
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4, tl.float8e5])
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4b15x4, tl.float8e4nv, tl.float8e5])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
"""
@@ -1056,9 +1056,9 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
[32, 32, 128],
[128, 128, 64],
[64, 128, 128]]
for ab_type in [[tl.float8e4, tl.float16],
for ab_type in [[tl.float8e4nv, tl.float16],
[tl.float8e5, tl.float16],
[tl.float16, tl.float8e4],
[tl.float16, tl.float8e4nv],
[tl.float16, tl.float8e5]]
for out_dtype in [torch.float16, torch.float32]
])

View File

@@ -53,14 +53,8 @@ def compile_fn(config, device_type, cc):
def test_compile_in_subproc() -> None:
<<<<<<< HEAD
cc, device_type = get_device_type()
config = instance_descriptor(tuple(range(4)), ())
=======
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = instance_descriptor(tuple(range(4)), (), (), ())
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
multiprocessing.set_start_method('fork')
proc = multiprocessing.Process(
@@ -92,14 +86,8 @@ def compile_fn_dot(config, device_type, cc):
def test_compile_in_forked_subproc() -> None:
reset_tmp_dir()
<<<<<<< HEAD
cc, device_type = get_device_type()
config = instance_descriptor(tuple(range(1)), ())
=======
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = instance_descriptor(tuple(range(1)), (), (), ())
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
assert multiprocessing.get_start_method() == 'fork'
proc = multiprocessing.Process(