Mirror of https://github.com/ROCm/ROCm.git, synced 2026-04-05 03:01:17 -04:00.
[FRONTEND] removed unnecessary comprehension (#1085)
This commit is contained in:
@@ -110,7 +110,7 @@ def check_type_supported(dtype):
         pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")


-@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
+@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"])
 def test_empty_kernel(dtype_x, device='cuda'):
     SIZE = 128

@@ -773,7 +773,7 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
     assert to_numpy(z_tri) == z_ref


-@pytest.mark.parametrize("dtype_str", [dtype_str for dtype_str in torch_dtypes])
+@pytest.mark.parametrize("dtype_str", list(torch_dtypes))
 def test_store_constant(dtype_str):
     check_type_supported(dtype_str)

@@ -502,7 +502,7 @@ def view(input: tl.tensor,


 def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
-    dst_shape = [s for s in input.type.shape]
+    dst_shape = list(input.type.shape)
     dst_shape.insert(axis, 1)
     ret_ty = tl.block_type(input.type.scalar, dst_shape)
     return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)

@@ -69,7 +69,7 @@ class Autotuner(KernelInterface):
     def run(self, *args, **kwargs):
         self.nargs = dict(zip(self.arg_names, args))
         if len(self.configs) > 1:
-            key = tuple([args[i] for i in self.key_idx])
+            key = tuple(args[i] for i in self.key_idx)
             if key not in self.cache:
                 # prune configs
                 pruned_configs = self.prune_configs(kwargs)

@@ -195,7 +195,7 @@ class JITFunction(KernelInterface[T]):
         return signature

     def _make_constants(self, constexpr_key):
-        constants = {i: k for i, k in zip(self.constexprs, constexpr_key)}
+        constants = dict(zip(self.constexprs, constexpr_key))
         return constants

     def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
@@ -298,10 +298,10 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
         # function signature information
         signature = inspect.signature(fn)
         self.arg_names = [v.name for v in signature.parameters.values()]
-        self.has_defaults = any([v.default != inspect._empty for v in signature.parameters.values()])
+        self.has_defaults = any(v.default != inspect._empty for v in signature.parameters.values())
         # specialization hints
         self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
-        self.do_not_specialize = set([self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize])
+        self.do_not_specialize = {self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
         # function source code (without decorators)
         self.src = textwrap.dedent(inspect.getsource(fn))
         self.src = self.src[self.src.find("def"):]

Reference in New Issue
Block a user