mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-12 06:14:56 -05:00
added hf models import tab and route for getting available hf models
This commit is contained in:
committed by
psychedelicious
parent
efea1a8a7d
commit
f7cd3cf1f4
@@ -20,6 +20,7 @@ from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import RemoteModelFile
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
@@ -405,6 +406,19 @@ class ModelInstallServiceBase(ABC):
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_hugging_face_models(
|
||||
self,
|
||||
source: str,
|
||||
) -> List[AnyHttpUrl]:
|
||||
"""Get the available models in a HuggingFace repo.
|
||||
|
||||
:param source: HuggingFace repo string
|
||||
|
||||
This will get the urls for the available models in the indicated,
|
||||
repo, and return them as a list of AnyHttpUrl strings.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
|
||||
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
||||
|
||||
@@ -233,6 +233,22 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._install_jobs.append(install_job)
|
||||
return install_job
|
||||
|
||||
def get_hugging_face_models(self, source: str) -> List[AnyHttpUrl]:
|
||||
# Add user's cached access token to HuggingFace requests
|
||||
access_token = HfFolder.get_token()
|
||||
if not access_token:
|
||||
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
|
||||
|
||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source)
|
||||
self._logger.info(f"metadata is {metadata}")
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
remote_files = metadata.download_urls(
|
||||
session=self._session,
|
||||
)
|
||||
|
||||
# return array of remote_files.url
|
||||
return [x.url for x in remote_files]
|
||||
|
||||
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
||||
return self._install_jobs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user