Mirror of https://github.com/ROCm/ROCm.git
[FRONTEND] Remove cache key from metadata (#2082)
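With this change the compiler cache key is no longer stored in the compilation metadata or on CompiledKernel; tensormap cache lookups are keyed by the tensormap info and argument pointers alone.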
@@ -519,7 +519,6 @@ def compile(fn, **kwargs):
     # Add device type to meta information
     metadata["device_type"] = device_type
-    metadata["cache_key"] = fn_cache_manager.key

     first_stage = list(stages.keys()).index(ext)
     asm = dict()
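For context, a compilation cache key of this kind is conventionally a hash of the kernel source plus its compile options. A minimal sketch of that idea; make_cache_key is a hypothetical helper, not Triton's fn_cache_manager API:

import hashlib
import json

def make_cache_key(src: str, options: dict) -> str:
    # Hash the kernel source together with its compile options so any
    # change to either yields a distinct cache entry.
    payload = src + json.dumps(options, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()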
@@ -629,7 +628,6 @@ class CompiledKernel:
         if "tensormaps_info" in metadata:
             self.tensormaps_info = metadata["tensormaps_info"]
         self.constants = metadata["constants"]
-        self.cache_key = metadata["cache_key"]
         self.device_type = metadata["device_type"]
         self.device_backend = get_backend(self.device_type) if self.device_type not in ["cuda", "hip"] else None
         # initialize asm dict
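The device_backend line follows a registry-dispatch pattern: built-in targets ("cuda", "hip") are handled by the core runtime, and anything else goes through a backend lookup. A hedged sketch of that pattern; the registry dict and function names here are illustrative, not Triton's get_backend implementation:

_BACKENDS = {}  # device_type -> backend object (illustrative registry)

def register_backend(device_type, backend):
    _BACKENDS[device_type] = backend

def get_backend_sketch(device_type):
    # Built-in device types need no separate backend object; only
    # third-party device types are served from the registry.
    if device_type in ("cuda", "hip"):
        return None
    return _BACKENDS.get(device_type)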
@@ -682,7 +680,7 @@ class CompiledKernel:
         # tuple for hashable
         args_ptr = tuple([arg.data_ptr() if hasattr(arg, 'data_ptr') else arg for arg in args])
         for i, e in enumerate(self.tensormaps_info):
-            args_with_tma.append(CompiledKernel.tensormap_manager[(self.cache_key, e, args_ptr)])
+            args_with_tma.append(CompiledKernel.tensormap_manager[(e, args_ptr)])
         return args_with_tma

     def __getitem__(self, grid):
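CompiledKernel.tensormap_manager is indexed like a mapping whose lookups create entries on demand, now keyed by (tensormap_info, args_ptr) only. A minimal sketch of that memoizing-mapping pattern, with a stand-in build function rather than real tensormap construction:

class MemoizingCache(dict):
    # dict subclass: a lookup miss calls __missing__, which builds,
    # stores, and returns the value, so indexing doubles as insertion.
    def __init__(self, build):
        super().__init__()
        self.build = build

    def __missing__(self, key):
        value = self.build(*key)
        self[key] = value
        return value

# Usage: the key is (tensormap_info, args_ptr); no cache_key in the tuple.
cache = MemoizingCache(lambda info, args_ptr: (info, hex(args_ptr)))
assert cache[("tma0", 0x1000)] is cache[("tma0", 0x1000)]  # built only once

The "# tuple for hashable" comment in the hunk exists for the same reason: dict keys must be hashable, so tensor arguments are reduced to a tuple of their data pointers before lookup.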
@@ -283,7 +283,7 @@ class TensorMapManager:
         if key in self.tensormaps_device:
             return int(self.tensormaps_device[key])
         else:
-            (cache_key, e, args) = key
+            (e, args) = key
             t_tensormap = e.tensormap(args)
             TENSORMAP_SIZE_IN_BYTES = 128
             t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES)
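TensorMapManager.__getitem__ is a get-or-create cache over device memory: return the cached device pointer if present, otherwise build the host-side tensormap, allocate a 128-byte device buffer, and copy the descriptor over. A hedged sketch of that flow with in-memory stand-ins; DriverStub and TensorMapInfoStub are assumptions for illustration, not Triton's driver.utils API:

TENSORMAP_SIZE_IN_BYTES = 128

class DriverStub:
    # Stand-in for driver.utils: "allocates" by handing out fake addresses.
    def __init__(self):
        self.next_addr = 0x7F0000000000

    def cuMemAlloc(self, nbytes):
        addr, self.next_addr = self.next_addr, self.next_addr + nbytes
        return addr

class TensorMapInfoStub:
    # Stand-in for an entry of metadata["tensormaps_info"].
    def tensormap(self, args):
        return bytes(TENSORMAP_SIZE_IN_BYTES)  # fake host-side descriptor

class TensorMapManagerSketch:
    def __init__(self, driver):
        self.driver = driver
        self.tensormaps_device = {}  # (info, args) -> device address

    def __getitem__(self, key):
        if key in self.tensormaps_device:
            return int(self.tensormaps_device[key])
        (e, args) = key                       # two-tuple key after this commit
        t_tensormap = e.tensormap(args)       # build the descriptor on the host
        ptr = self.driver.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES)
        # A real implementation would copy t_tensormap into the device buffer here.
        self.tensormaps_device[key] = ptr
        return int(ptr)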