mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Initial commit to resolve merge conflicts
Rename tl.float8e4 to tl.float8e4nv to align with upstream; ROCM IFU: Fix python arch issues; ROCM IFU: Fix kernel launcher; ROCM IFU: Fix merge conflicts; fix debug build; set correct threadsPerCTA.
This commit is contained in:
@@ -1005,7 +1005,7 @@ def deserialize_fp8(np_data, in_dtype):
|
||||
# return np_data
|
||||
|
||||
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4, tl.float8e5])
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4b15x4, tl.float8e4nv, tl.float8e5])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
|
||||
def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
|
||||
"""
|
||||
@@ -1056,9 +1056,9 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device):
|
||||
[32, 32, 128],
|
||||
[128, 128, 64],
|
||||
[64, 128, 128]]
|
||||
for ab_type in [[tl.float8e4, tl.float16],
|
||||
for ab_type in [[tl.float8e4nv, tl.float16],
|
||||
[tl.float8e5, tl.float16],
|
||||
[tl.float16, tl.float8e4],
|
||||
[tl.float16, tl.float8e4nv],
|
||||
[tl.float16, tl.float8e5]]
|
||||
for out_dtype in [torch.float16, torch.float32]
|
||||
])
|
||||
|
||||
@@ -53,14 +53,8 @@ def compile_fn(config, device_type, cc):
|
||||
|
||||
|
||||
def test_compile_in_subproc() -> None:
|
||||
<<<<<<< HEAD
|
||||
cc, device_type = get_device_type()
|
||||
config = instance_descriptor(tuple(range(4)), ())
|
||||
=======
|
||||
major, minor = torch.cuda.get_device_capability(0)
|
||||
cc = major * 10 + minor
|
||||
config = instance_descriptor(tuple(range(4)), (), (), ())
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
multiprocessing.set_start_method('fork')
|
||||
proc = multiprocessing.Process(
|
||||
@@ -92,14 +86,8 @@ def compile_fn_dot(config, device_type, cc):
|
||||
|
||||
def test_compile_in_forked_subproc() -> None:
|
||||
reset_tmp_dir()
|
||||
<<<<<<< HEAD
|
||||
cc, device_type = get_device_type()
|
||||
config = instance_descriptor(tuple(range(1)), ())
|
||||
=======
|
||||
major, minor = torch.cuda.get_device_capability(0)
|
||||
cc = major * 10 + minor
|
||||
config = instance_descriptor(tuple(range(1)), (), (), ())
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
|
||||
assert multiprocessing.get_start_method() == 'fork'
|
||||
proc = multiprocessing.Process(
|
||||
|
||||
Reference in New Issue
Block a user