move access token regex matching into download queue

This commit is contained in:
Lincoln Stein
2024-05-05 21:00:31 -04:00
parent 8e5e9b53d6
commit f211c95dbc
7 changed files with 69 additions and 29 deletions

View File

@@ -75,8 +75,6 @@ class ModelManagerServiceBase(ABC):
def load_ckpt_from_url(
self,
source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
@@ -94,9 +92,6 @@ class ModelManagerServiceBase(ABC):
Args:
source: A URL or a string that can be converted in one. Repo_ids
do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns:

View File

@@ -106,8 +106,6 @@ class ModelManagerService(ModelManagerServiceBase):
def load_ckpt_from_url(
self,
source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
@@ -125,13 +123,10 @@ class ModelManagerService(ModelManagerServiceBase):
Args:
source: A URL or a string that can be converted in one. Repo_ids
do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns:
A LoadedModel object.
"""
model_path = self.install.download_and_cache_ckpt(source=source, access_token=access_token, timeout=timeout)
model_path = self.install.download_and_cache_ckpt(source=source)
return self.load.load_ckpt_from_path(model_path=model_path, loader=loader)