[FRONTEND] Add input dtypes to autotuning key (#2534) (#374)

* [FRONTEND] Add input dtypes to autotuning key (#2534)

* Fix conflict in 06-fused-attention

* Fix get_best_config in FA-transV.py

* Fix leftover get_best_config()

---------

Co-authored-by: Adnan Akhundov <adnan.akhundov@gmail.com>
This commit is contained in:
Lixun Zhang
2023-11-07 19:36:57 -06:00
committed by GitHub
parent 3c1fe617c1
commit 1af893d8a2
5 changed files with 10 additions and 20 deletions

View File

@@ -89,17 +89,7 @@ class Autotuner(KernelInterface):
except OutOfResources:
return [float('inf'), float('inf'), float('inf')]
def get_best_config(self, *args, **kwargs):
if len(args) > 0:
key = tuple(args)
else:
key_names = [self.arg_names[i] for i in self.key_idx]
key_values = []
for name in key_names:
assert (name in kwargs)
key_values.append(kwargs[name])
key = tuple(key_values)
def get_best_config(self):
return self.best_config
@@ -111,7 +101,11 @@ class Autotuner(KernelInterface):
for name in self.arg_names:
if name in all_args:
_args.append(all_args[name])
key = tuple(_args[i] for i in self.key_idx)
key = [_args[i] for i in self.key_idx]
for arg in _args:
if hasattr(arg, "dtype"):
key.append(str(arg.dtype))
key = tuple(key)
if key not in self.cache:
# prune configs
pruned_configs = self.prune_configs(kwargs)