[FRONTEND] removed unnecessary comprehension (#1085)

This commit is contained in:
Nishant Sikarwar
2023-01-31 01:12:14 +05:30
committed by GitHub
parent bc8a26d56f
commit e5dbe35cc1
4 changed files with 7 additions and 7 deletions

View File

@@ -110,7 +110,7 @@ def check_type_supported(dtype):
pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
@pytest.mark.parametrize("dtype_x", [dtype_x for dtype_x in dtypes] + ["bfloat16"])
@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"])
def test_empty_kernel(dtype_x, device='cuda'):
SIZE = 128
@@ -773,7 +773,7 @@ def test_cast(dtype_x, dtype_z, bitcast, device='cuda'):
assert to_numpy(z_tri) == z_ref
@pytest.mark.parametrize("dtype_str", [dtype_str for dtype_str in torch_dtypes])
@pytest.mark.parametrize("dtype_str", list(torch_dtypes))
def test_store_constant(dtype_str):
check_type_supported(dtype_str)

View File

@@ -502,7 +502,7 @@ def view(input: tl.tensor,
def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
dst_shape = [s for s in input.type.shape]
dst_shape = list(input.type.shape)
dst_shape.insert(axis, 1)
ret_ty = tl.block_type(input.type.scalar, dst_shape)
return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)

View File

@@ -69,7 +69,7 @@ class Autotuner(KernelInterface):
def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
key = tuple([args[i] for i in self.key_idx])
key = tuple(args[i] for i in self.key_idx)
if key not in self.cache:
# prune configs
pruned_configs = self.prune_configs(kwargs)

View File

@@ -195,7 +195,7 @@ class JITFunction(KernelInterface[T]):
return signature
def _make_constants(self, constexpr_key):
constants = {i: k for i, k in zip(self.constexprs, constexpr_key)}
constants = dict(zip(self.constexprs, constexpr_key))
return constants
def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
@@ -298,10 +298,10 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
# function signature information
signature = inspect.signature(fn)
self.arg_names = [v.name for v in signature.parameters.values()]
self.has_defaults = any([v.default != inspect._empty for v in signature.parameters.values()])
self.has_defaults = any(v.default != inspect._empty for v in signature.parameters.values())
# specialization hints
self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize
self.do_not_specialize = set([self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize])
self.do_not_specialize = {self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize}
# function source code (without decorators)
self.src = textwrap.dedent(inspect.getsource(fn))
self.src = self.src[self.src.find("def"):]