turn the HuggingFaceConceptsLibrary into a singleton to prevent redundant downloads

This commit is contained in:
Lincoln Stein
2023-05-03 15:12:34 -04:00
parent 1dcac3929b
commit fa886ee9e0
5 changed files with 16 additions and 8 deletions

View File

@@ -25,7 +25,7 @@ from invokeai.backend.modules.parameters import parameters_to_command
import invokeai.frontend.dist as frontend
from ldm.generate import Generate
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
from ldm.invoke.concepts_lib import get_hf_concepts_lib
from ldm.invoke.conditioning import (
get_tokens_for_prompt_object,
get_prompt_structure,
@@ -538,7 +538,7 @@ class InvokeAIWebServer:
try:
local_triggers = self.generate.model.textual_inversion_manager.get_all_trigger_strings()
locals = [{'name': x} for x in sorted(local_triggers, key=str.casefold)]
concepts = HuggingFaceConceptsLibrary().list_concepts(minimum_likes=5)
concepts = get_hf_concepts_lib().list_concepts(minimum_likes=5)
concepts = [{'name': f'<{x}>'} for x in sorted(concepts, key=str.casefold) if f'<{x}>' not in local_triggers]
socketio.emit("foundTextualInversionTriggers", {'local_triggers': locals, 'huggingface_concepts': concepts})
except Exception as e:

View File

@@ -12,6 +12,14 @@ from urllib import request, error as ul_error
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
from ldm.invoke.globals import Globals
# Process-wide cache for the one shared HuggingFaceConceptsLibrary instance.
singleton = None


def get_hf_concepts_lib():
    """Return the shared HuggingFaceConceptsLibrary, creating it on first call.

    Subsequent calls hand back the same object so callers reuse its state
    (avoiding redundant concept downloads) instead of constructing fresh
    instances.
    """
    global singleton
    lib = singleton
    if lib is None:
        # First caller builds the instance; everyone after reuses it.
        lib = singleton = HuggingFaceConceptsLibrary()
    return lib
class HuggingFaceConceptsLibrary(object):
def __init__(self, root=None):
'''

View File

@@ -13,7 +13,7 @@ import re
import atexit
from typing import List
from ldm.invoke.args import Args
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
from ldm.invoke.concepts_lib import get_hf_concepts_lib
from ldm.invoke.globals import Globals
from ldm.modules.lora_manager import LoraManager
@@ -287,7 +287,7 @@ class Completer(object):
def _concept_completions(self, text, state):
if self.concepts is None:
# cache Concepts() instance so we can check for updates in concepts_list during runtime.
self.concepts = HuggingFaceConceptsLibrary()
self.concepts = get_hf_concepts_lib()
self.embedding_terms.update(set(self.concepts.list_concepts()))
else:
self.embedding_terms.update(set(self.concepts.list_concepts()))

View File

@@ -6,7 +6,7 @@ from torch import nn
import sys
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
from ldm.invoke.concepts_lib import get_hf_concepts_lib
from ldm.data.personalized import per_img_token_list
from transformers import CLIPTokenizer
from functools import partial
@@ -39,7 +39,7 @@ class EmbeddingManager(nn.Module):
super().__init__()
self.embedder = embedder
self.concepts_library=HuggingFaceConceptsLibrary()
self.concepts_library=get_hf_concepts_lib()
self.string_to_token_dict = {}
self.string_to_param_dict = nn.ParameterDict()

View File

@@ -9,7 +9,7 @@ from picklescan.scanner import scan_file_path
from transformers import CLIPTextModel, CLIPTokenizer
from compel.embeddings_provider import BaseTextualInversionManager
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
from ldm.invoke.concepts_lib import get_hf_concepts_lib
@dataclass
@@ -34,7 +34,7 @@ class TextualInversionManager(BaseTextualInversionManager):
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.full_precision = full_precision
self.hf_concepts_library = HuggingFaceConceptsLibrary()
self.hf_concepts_library = get_hf_concepts_lib()
self.trigger_to_sourcefile = dict()
default_textual_inversions: list[TextualInversion] = []
self.textual_inversions = default_textual_inversions