mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] allow pre-hook in autotuner configs to access config kwargs (#1814)
This is a very quick change that allows the configs' pre-hooks to see the values in the config itself. This is useful if we'd like to allocate intermediate tensor and the shape depends on tile size.
This commit is contained in:
@@ -70,10 +70,11 @@ class Autotuner(KernelInterface):
|
||||
)
|
||||
# augment meta-parameters with tunable ones
|
||||
current = dict(meta, **config.kwargs)
|
||||
full_nargs = {**self.nargs, **current}
|
||||
|
||||
def kernel_call():
|
||||
if config.pre_hook:
|
||||
config.pre_hook(self.nargs)
|
||||
config.pre_hook(full_nargs)
|
||||
self.hook(args)
|
||||
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
|
||||
try:
|
||||
@@ -106,7 +107,8 @@ class Autotuner(KernelInterface):
|
||||
config = self.configs[0]
|
||||
self.best_config = config
|
||||
if config.pre_hook is not None:
|
||||
config.pre_hook(self.nargs)
|
||||
full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs}
|
||||
config.pre_hook(full_nargs)
|
||||
return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
|
||||
|
||||
def prune_configs(self, kwargs):
|
||||
|
||||
Reference in New Issue
Block a user