[FRONTEND] Remove cache key from metadata (#2082)

This commit is contained in:
Zahi Moudallal
2023-08-11 08:58:00 -07:00
committed by GitHub
parent 4828f61894
commit b62b6d6a71
2 changed files with 2 additions and 4 deletions

View File

@@ -519,7 +519,6 @@ def compile(fn, **kwargs):
# Add device type to meta information
metadata["device_type"] = device_type
metadata["cache_key"] = fn_cache_manager.key
first_stage = list(stages.keys()).index(ext)
asm = dict()
@@ -629,7 +628,6 @@ class CompiledKernel:
if "tensormaps_info" in metadata:
self.tensormaps_info = metadata["tensormaps_info"]
self.constants = metadata["constants"]
self.cache_key = metadata["cache_key"]
self.device_type = metadata["device_type"]
self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda", "hip"] else None
# initialize asm dict
@@ -682,7 +680,7 @@ class CompiledKernel:
# tuple for hashable
args_ptr = tuple([arg.data_ptr() if hasattr(arg, 'data_ptr') else arg for arg in args])
for i, e in enumerate(self.tensormaps_info):
args_with_tma.append(CompiledKernel.tensormap_manager[(self.cache_key, e, args_ptr)])
args_with_tma.append(CompiledKernel.tensormap_manager[(e, args_ptr)])
return args_with_tma
def __getitem__(self, grid):

View File

@@ -283,7 +283,7 @@ class TensorMapManager:
if key in self.tensormaps_device:
return int(self.tensormaps_device[key])
else:
(cache_key, e, args) = key
(e, args) = key
t_tensormap = e.tensormap(args)
TENSORMAP_SIZE_IN_BYTES = 128
t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES)