mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
* [FRONTEND] Add input dtypes to autotuning key (#2534) * Fix conflict in 06-fused-attention * Fix get_best_config in FA-transV.py * Fix leftover get_best_config() --------- Co-authored-by: Adnan Akhundov <adnan.akhundov@gmail.com>
This commit is contained in:
@@ -89,17 +89,7 @@ class Autotuner(KernelInterface):
|
||||
except OutOfResources:
|
||||
return [float('inf'), float('inf'), float('inf')]
|
||||
|
||||
def get_best_config(self, *args, **kwargs):
|
||||
if len(args) > 0:
|
||||
key = tuple(args)
|
||||
else:
|
||||
key_names = [self.arg_names[i] for i in self.key_idx]
|
||||
key_values = []
|
||||
for name in key_names:
|
||||
assert (name in kwargs)
|
||||
key_values.append(kwargs[name])
|
||||
key = tuple(key_values)
|
||||
|
||||
def get_best_config(self):
|
||||
return self.best_config
|
||||
|
||||
|
||||
@@ -111,7 +101,11 @@ class Autotuner(KernelInterface):
|
||||
for name in self.arg_names:
|
||||
if name in all_args:
|
||||
_args.append(all_args[name])
|
||||
key = tuple(_args[i] for i in self.key_idx)
|
||||
key = [_args[i] for i in self.key_idx]
|
||||
for arg in _args:
|
||||
if hasattr(arg, "dtype"):
|
||||
key.append(str(arg.dtype))
|
||||
key = tuple(key)
|
||||
if key not in self.cache:
|
||||
# prune configs
|
||||
pruned_configs = self.prune_configs(kwargs)
|
||||
|
||||
Reference in New Issue
Block a user