mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
turn the HuggingFaceConceptsLib into a singleton to prevent redundant downloads
This commit is contained in:
@@ -25,7 +25,7 @@ from invokeai.backend.modules.parameters import parameters_to_command
|
||||
import invokeai.frontend.dist as frontend
|
||||
from ldm.generate import Generate
|
||||
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
||||
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||
from ldm.invoke.concepts_lib import get_hf_concepts_lib
|
||||
from ldm.invoke.conditioning import (
|
||||
get_tokens_for_prompt_object,
|
||||
get_prompt_structure,
|
||||
@@ -538,7 +538,7 @@ class InvokeAIWebServer:
|
||||
try:
|
||||
local_triggers = self.generate.model.textual_inversion_manager.get_all_trigger_strings()
|
||||
locals = [{'name': x} for x in sorted(local_triggers, key=str.casefold)]
|
||||
concepts = HuggingFaceConceptsLibrary().list_concepts(minimum_likes=5)
|
||||
concepts = get_hf_concepts_lib().list_concepts(minimum_likes=5)
|
||||
concepts = [{'name': f'<{x}>'} for x in sorted(concepts, key=str.casefold) if f'<{x}>' not in local_triggers]
|
||||
socketio.emit("foundTextualInversionTriggers", {'local_triggers': locals, 'huggingface_concepts': concepts})
|
||||
except Exception as e:
|
||||
|
||||
@@ -12,6 +12,14 @@ from urllib import request, error as ul_error
|
||||
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
|
||||
from ldm.invoke.globals import Globals
|
||||
|
||||
singleton = None
|
||||
|
||||
def get_hf_concepts_lib():
|
||||
global singleton
|
||||
if singleton is None:
|
||||
singleton = HuggingFaceConceptsLibrary()
|
||||
return singleton
|
||||
|
||||
class HuggingFaceConceptsLibrary(object):
|
||||
def __init__(self, root=None):
|
||||
'''
|
||||
|
||||
@@ -13,7 +13,7 @@ import re
|
||||
import atexit
|
||||
from typing import List
|
||||
from ldm.invoke.args import Args
|
||||
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||
from ldm.invoke.concepts_lib import get_hf_concepts_lib
|
||||
from ldm.invoke.globals import Globals
|
||||
from ldm.modules.lora_manager import LoraManager
|
||||
|
||||
@@ -287,7 +287,7 @@ class Completer(object):
|
||||
def _concept_completions(self, text, state):
|
||||
if self.concepts is None:
|
||||
# cache Concepts() instance so we can check for updates in concepts_list during runtime.
|
||||
self.concepts = HuggingFaceConceptsLibrary()
|
||||
self.concepts = get_hf_concepts_lib()
|
||||
self.embedding_terms.update(set(self.concepts.list_concepts()))
|
||||
else:
|
||||
self.embedding_terms.update(set(self.concepts.list_concepts()))
|
||||
|
||||
@@ -6,7 +6,7 @@ from torch import nn
|
||||
|
||||
import sys
|
||||
|
||||
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||
from ldm.invoke.concepts_lib import get_hf_concepts_lib
|
||||
from ldm.data.personalized import per_img_token_list
|
||||
from transformers import CLIPTokenizer
|
||||
from functools import partial
|
||||
@@ -39,7 +39,7 @@ class EmbeddingManager(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.embedder = embedder
|
||||
self.concepts_library=HuggingFaceConceptsLibrary()
|
||||
self.concepts_library=get_hf_concepts_lib()
|
||||
|
||||
self.string_to_token_dict = {}
|
||||
self.string_to_param_dict = nn.ParameterDict()
|
||||
|
||||
@@ -9,7 +9,7 @@ from picklescan.scanner import scan_file_path
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
|
||||
from ldm.invoke.concepts_lib import get_hf_concepts_lib
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,7 +34,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
self.tokenizer = tokenizer
|
||||
self.text_encoder = text_encoder
|
||||
self.full_precision = full_precision
|
||||
self.hf_concepts_library = HuggingFaceConceptsLibrary()
|
||||
self.hf_concepts_library = get_hf_concepts_lib()
|
||||
self.trigger_to_sourcefile = dict()
|
||||
default_textual_inversions: list[TextualInversion] = []
|
||||
self.textual_inversions = default_textual_inversions
|
||||
|
||||
Reference in New Issue
Block a user