mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[RUNTIME] Fix cache dir (#2196)
--------- Co-authored-by: Keren Zhou <kerenzhou@openai.com>
This commit is contained in:
@@ -40,18 +40,20 @@ class FileCacheManager(CacheManager):
|
||||
self.key = key
|
||||
self.lock_path = None
|
||||
# create cache directory if it doesn't exist
|
||||
self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
|
||||
self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir()
|
||||
if self.cache_dir:
|
||||
self.cache_dir = os.path.join(self.cache_dir, self.key)
|
||||
self.lock_path = os.path.join(self.cache_dir, "lock")
|
||||
os.makedirs(self.cache_dir, exist_ok=True)
|
||||
else:
|
||||
raise RuntimeError("Could not create or locate cache dir")
|
||||
|
||||
def _make_path(self, filename) -> str:
|
||||
return os.path.join(self.cache_dir, filename)
|
||||
|
||||
def has_file(self, filename):
|
||||
def has_file(self, filename) -> bool:
|
||||
if not self.cache_dir:
|
||||
return False
|
||||
raise RuntimeError("Could not create or locate cache dir")
|
||||
return os.path.exists(self._make_path(filename))
|
||||
|
||||
def get_file(self, filename) -> Optional[str]:
|
||||
@@ -80,16 +82,16 @@ class FileCacheManager(CacheManager):
|
||||
return result
|
||||
|
||||
# Note a group of pushed files as being part of a group
|
||||
def put_group(self, filename: str, group: Dict[str, str]):
|
||||
def put_group(self, filename: str, group: Dict[str, str]) -> str:
|
||||
if not self.cache_dir:
|
||||
return
|
||||
raise RuntimeError("Could not create or locate cache dir")
|
||||
grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
|
||||
grp_filename = f"__grp__{filename}"
|
||||
return self.put(grp_contents, grp_filename, binary=False)
|
||||
|
||||
def put(self, data, filename, binary=True) -> str:
|
||||
if not self.cache_dir:
|
||||
return
|
||||
raise RuntimeError("Could not create or locate cache dir")
|
||||
binary = isinstance(data, bytes)
|
||||
if not binary:
|
||||
data = str(data)
|
||||
|
||||
Reference in New Issue
Block a user