mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-17 08:07:59 -05:00
Compare commits
17 Commits
experiment
...
feat/batch
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c1dde83abb | ||
|
|
280ac15da2 | ||
|
|
e751f7d815 | ||
|
|
e26e4740b3 | ||
|
|
835d76af45 | ||
|
|
a3e099bbc0 | ||
|
|
a61685696f | ||
|
|
02aa93c67c | ||
|
|
55b921818d | ||
|
|
bb681a8a11 | ||
|
|
74e0fbce42 | ||
|
|
f080c56771 | ||
|
|
d2f968b902 | ||
|
|
e81601acf3 | ||
|
|
7073dc0d5d | ||
|
|
d090be60e8 | ||
|
|
4bad96d9d6 |
@@ -30,6 +30,8 @@ from ..services.invoker import Invoker
|
|||||||
from ..services.processor import DefaultInvocationProcessor
|
from ..services.processor import DefaultInvocationProcessor
|
||||||
from ..services.sqlite import SqliteItemStorage
|
from ..services.sqlite import SqliteItemStorage
|
||||||
from ..services.model_manager_service import ModelManagerService
|
from ..services.model_manager_service import ModelManagerService
|
||||||
|
from ..services.batch_manager import BatchManager
|
||||||
|
from ..services.batch_manager_storage import SqliteBatchProcessStorage
|
||||||
from .events import FastAPIEventService
|
from .events import FastAPIEventService
|
||||||
|
|
||||||
|
|
||||||
@@ -116,11 +118,15 @@ class ApiDependencies:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
batch_manager_storage = SqliteBatchProcessStorage(db_location)
|
||||||
|
batch_manager = BatchManager(batch_manager_storage)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=ModelManagerService(config, logger),
|
model_manager=ModelManagerService(config, logger),
|
||||||
events=events,
|
events=events,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
images=images,
|
images=images,
|
||||||
|
batch_manager=batch_manager,
|
||||||
boards=boards,
|
boards=boards,
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from ...services.graph import (
|
|||||||
GraphExecutionState,
|
GraphExecutionState,
|
||||||
NodeAlreadyExecutedError,
|
NodeAlreadyExecutedError,
|
||||||
)
|
)
|
||||||
|
from ...services.batch_manager import Batch, BatchProcess
|
||||||
from ...services.item_storage import PaginatedResults
|
from ...services.item_storage import PaginatedResults
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
@@ -37,6 +38,37 @@ async def create_session(
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
@session_router.post(
|
||||||
|
"/batch",
|
||||||
|
operation_id="create_batch",
|
||||||
|
responses={
|
||||||
|
200: {"model": BatchProcess},
|
||||||
|
400: {"description": "Invalid json"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def create_batch(
|
||||||
|
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
||||||
|
batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
|
||||||
|
) -> BatchProcess:
|
||||||
|
"""Creates and starts a new new batch process"""
|
||||||
|
batch_id = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
|
||||||
|
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_id)
|
||||||
|
return {"batch_id":batch_id}
|
||||||
|
|
||||||
|
|
||||||
|
@session_router.delete(
|
||||||
|
"{batch_process_id}/batch",
|
||||||
|
operation_id="cancel_batch",
|
||||||
|
responses={202: {"description": "The batch is canceled"}},
|
||||||
|
)
|
||||||
|
async def cancel_batch(
|
||||||
|
batch_process_id: str = Path(description="The id of the batch process to cancel"),
|
||||||
|
) -> Response:
|
||||||
|
"""Creates and starts a new new batch process"""
|
||||||
|
ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
|
||||||
|
return Response(status_code=202)
|
||||||
|
|
||||||
|
|
||||||
@session_router.get(
|
@session_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_sessions",
|
operation_id="list_sessions",
|
||||||
|
|||||||
@@ -37,6 +37,8 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
|||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
|
from invokeai.app.services.batch_manager import BatchManager
|
||||||
|
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
|
||||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
|
|
||||||
@@ -300,12 +302,16 @@ def invoke_cli():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
batch_manager_storage = SqliteBatchProcessStorage(db_location)
|
||||||
|
batch_manager = BatchManager(batch_manager_storage)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=model_manager,
|
model_manager=model_manager,
|
||||||
events=events,
|
events=events,
|
||||||
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
||||||
images=images,
|
images=images,
|
||||||
boards=boards,
|
boards=boards,
|
||||||
|
batch_manager=batch_manager,
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||||
|
|||||||
@@ -108,15 +108,14 @@ class CompelInvocation(BaseInvocation):
|
|||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_list.append((
|
ti_list.append(
|
||||||
name,
|
|
||||||
context.services.model_manager.get_model(
|
context.services.model_manager.get_model(
|
||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=self.clip.text_encoder.base_model,
|
base_model=self.clip.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
context=context,
|
context=context,
|
||||||
).context.model
|
).context.model
|
||||||
))
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
# import traceback
|
# import traceback
|
||||||
@@ -197,15 +196,14 @@ class SDXLPromptInvocationBase:
|
|||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_list.append((
|
ti_list.append(
|
||||||
name,
|
|
||||||
context.services.model_manager.get_model(
|
context.services.model_manager.get_model(
|
||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=clip_field.text_encoder.base_model,
|
base_model=clip_field.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
context=context,
|
context=context,
|
||||||
).context.model
|
).context.model
|
||||||
))
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
# import traceback
|
# import traceback
|
||||||
@@ -272,15 +270,14 @@ class SDXLPromptInvocationBase:
|
|||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_list.append((
|
ti_list.append(
|
||||||
name,
|
|
||||||
context.services.model_manager.get_model(
|
context.services.model_manager.get_model(
|
||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=clip_field.text_encoder.base_model,
|
base_model=clip_field.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
context=context,
|
context=context,
|
||||||
).context.model
|
).context.model
|
||||||
))
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
# import traceback
|
# import traceback
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
**self.clip.text_encoder.dict(),
|
**self.clip.text_encoder.dict(),
|
||||||
)
|
)
|
||||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
|
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
|
||||||
|
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
|
||||||
loras = [
|
loras = [
|
||||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
||||||
for lora in self.clip.loras
|
for lora in self.clip.loras
|
||||||
@@ -74,14 +75,20 @@ class ONNXPromptInvocation(BaseInvocation):
|
|||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||||
name = trigger[1:-1]
|
name = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_list.append((
|
ti_list.append(
|
||||||
name,
|
# stack.enter_context(
|
||||||
|
# context.services.model_manager.get_model(
|
||||||
|
# model_name=name,
|
||||||
|
# base_model=self.clip.text_encoder.base_model,
|
||||||
|
# model_type=ModelType.TextualInversion,
|
||||||
|
# )
|
||||||
|
# )
|
||||||
context.services.model_manager.get_model(
|
context.services.model_manager.get_model(
|
||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=self.clip.text_encoder.base_model,
|
base_model=self.clip.text_encoder.base_model,
|
||||||
model_type=ModelType.TextualInversion,
|
model_type=ModelType.TextualInversion,
|
||||||
).context.model
|
).context.model
|
||||||
))
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# print(e)
|
# print(e)
|
||||||
# import traceback
|
# import traceback
|
||||||
|
|||||||
139
invokeai/app/services/batch_manager.py
Normal file
139
invokeai/app/services/batch_manager.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
import networkx as nx
|
||||||
|
import copy
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from itertools import product
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from fastapi_events.handlers.local import local_handler
|
||||||
|
from fastapi_events.typing import Event
|
||||||
|
|
||||||
|
from invokeai.app.services.events import EventServiceBase
|
||||||
|
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||||
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
from invokeai.app.services.batch_manager_storage import (
|
||||||
|
BatchProcessStorageBase,
|
||||||
|
BatchSessionNotFoundException,
|
||||||
|
Batch,
|
||||||
|
BatchProcess,
|
||||||
|
BatchSession,
|
||||||
|
BatchSessionChanges,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchManagerBase(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def start(self, invoker: Invoker):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run_batch_process(self, batch_id: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cancel_batch_process(self, batch_process_id: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BatchManager(BatchManagerBase):
|
||||||
|
"""Responsible for managing currently running and scheduled batch jobs"""
|
||||||
|
|
||||||
|
__invoker: Invoker
|
||||||
|
__batches: list[BatchProcess]
|
||||||
|
__batch_process_storage: BatchProcessStorageBase
|
||||||
|
|
||||||
|
def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.__batch_process_storage = batch_process_storage
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
|
# if we do want multithreading at some point, we could make this configurable
|
||||||
|
self.__invoker = invoker
|
||||||
|
self.__batches = list()
|
||||||
|
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
|
||||||
|
|
||||||
|
async def on_event(self, event: Event):
|
||||||
|
event_name = event[1]["event"]
|
||||||
|
|
||||||
|
match event_name:
|
||||||
|
case "graph_execution_state_complete":
|
||||||
|
await self.process(event, False)
|
||||||
|
case "invocation_error":
|
||||||
|
await self.process(event, True)
|
||||||
|
|
||||||
|
return event
|
||||||
|
|
||||||
|
async def process(self, event: Event, err: bool):
|
||||||
|
data = event[1]["data"]
|
||||||
|
batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
|
||||||
|
if not batch_session:
|
||||||
|
return
|
||||||
|
updateSession = BatchSessionChanges(
|
||||||
|
state='error' if err else 'completed'
|
||||||
|
)
|
||||||
|
batch_session = self.__batch_process_storage.update_session_state(
|
||||||
|
batch_session.batch_id,
|
||||||
|
batch_session.session_id,
|
||||||
|
updateSession,
|
||||||
|
)
|
||||||
|
self.run_batch_process(batch_session.batch_id)
|
||||||
|
|
||||||
|
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
|
||||||
|
graph = copy.deepcopy(batch_process.graph)
|
||||||
|
batches = batch_process.batches
|
||||||
|
g = graph.nx_graph_flat()
|
||||||
|
sorted_nodes = nx.topological_sort(g)
|
||||||
|
for npath in sorted_nodes:
|
||||||
|
node = graph.get_node(npath)
|
||||||
|
(index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
|
||||||
|
if batch:
|
||||||
|
batch_index = batch_indices[index]
|
||||||
|
datum = batch.data[batch_index]
|
||||||
|
for key in datum:
|
||||||
|
node.__dict__[key] = datum[key]
|
||||||
|
graph.update_node(npath, node)
|
||||||
|
|
||||||
|
return GraphExecutionState(graph=graph)
|
||||||
|
|
||||||
|
def run_batch_process(self, batch_id: str):
|
||||||
|
try:
|
||||||
|
created_session = self.__batch_process_storage.get_created_session(batch_id)
|
||||||
|
except BatchSessionNotFoundException:
|
||||||
|
return
|
||||||
|
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
|
||||||
|
self.__invoker.invoke(ges, invoke_all=True)
|
||||||
|
|
||||||
|
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
|
||||||
|
batch_process = BatchProcess(
|
||||||
|
batches=batches,
|
||||||
|
graph=graph,
|
||||||
|
)
|
||||||
|
if not self._valid_batch_config(batch_process):
|
||||||
|
return None
|
||||||
|
batch_process = self.__batch_process_storage.save(batch_process)
|
||||||
|
self._create_sessions(batch_process)
|
||||||
|
return batch_process.batch_id
|
||||||
|
|
||||||
|
def _create_sessions(self, batch_process: BatchProcess):
|
||||||
|
batch_indices = list()
|
||||||
|
for batch in batch_process.batches:
|
||||||
|
batch_indices.append(list(range(len(batch.data))))
|
||||||
|
all_batch_indices = product(*batch_indices)
|
||||||
|
for bi in all_batch_indices:
|
||||||
|
ges = self._create_batch_session(batch_process, bi)
|
||||||
|
self.__invoker.services.graph_execution_manager.set(ges)
|
||||||
|
batch_session = BatchSession(
|
||||||
|
batch_id=batch_process.batch_id,
|
||||||
|
session_id=ges.id,
|
||||||
|
state="created"
|
||||||
|
)
|
||||||
|
self.__batch_process_storage.create_session(batch_session)
|
||||||
|
|
||||||
|
def cancel_batch_process(self, batch_process_id: str):
|
||||||
|
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]
|
||||||
505
invokeai/app/services/batch_manager_storage.py
Normal file
505
invokeai/app/services/batch_manager_storage.py
Normal file
@@ -0,0 +1,505 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import cast
|
||||||
|
import uuid
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
import json
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
|
BaseInvocation,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.graph import Graph
|
||||||
|
from invokeai.app.models.image import ImageField
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, Extra, parse_raw_as
|
||||||
|
|
||||||
|
invocations = BaseInvocation.get_invocations()
|
||||||
|
InvocationsUnion = Union[invocations] # type: ignore
|
||||||
|
|
||||||
|
BatchDataType = Union[str, int, float, ImageField]
|
||||||
|
|
||||||
|
class Batch(BaseModel):
|
||||||
|
data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
|
||||||
|
node_id: str = Field(description="ID of the node to batch")
|
||||||
|
|
||||||
|
|
||||||
|
class BatchSession(BaseModel):
|
||||||
|
batch_id: str = Field(description="Identifier for which batch this Index belongs to")
|
||||||
|
session_id: str = Field(description="Session ID Created for this Batch Index")
|
||||||
|
state: Literal["created", "completed", "inprogress", "error"] = Field(
|
||||||
|
description="Is this session created, completed, in progress, or errored?"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def uuid_string():
|
||||||
|
res = uuid.uuid4()
|
||||||
|
return str(res)
|
||||||
|
|
||||||
|
class BatchProcess(BaseModel):
|
||||||
|
batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
|
||||||
|
batches: List[Batch] = Field(
|
||||||
|
description="List of batch configs to apply to this session",
|
||||||
|
default_factory=list,
|
||||||
|
)
|
||||||
|
graph: Graph = Field(description="The graph being executed")
|
||||||
|
|
||||||
|
|
||||||
|
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
|
||||||
|
state: Literal["created", "completed", "inprogress", "error"] = Field(
|
||||||
|
description="Is this session created, completed, in progress, or errored?"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchProcessNotFoundException(Exception):
|
||||||
|
"""Raised when an Batch Process record is not found."""
|
||||||
|
|
||||||
|
def __init__(self, message="BatchProcess record not found"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchProcessSaveException(Exception):
|
||||||
|
"""Raised when an Batch Process record cannot be saved."""
|
||||||
|
|
||||||
|
def __init__(self, message="BatchProcess record not saved"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchProcessDeleteException(Exception):
|
||||||
|
"""Raised when an Batch Process record cannot be deleted."""
|
||||||
|
|
||||||
|
def __init__(self, message="BatchProcess record not deleted"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchSessionNotFoundException(Exception):
|
||||||
|
"""Raised when an Batch Session record is not found."""
|
||||||
|
|
||||||
|
def __init__(self, message="BatchSession record not found"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchSessionSaveException(Exception):
|
||||||
|
"""Raised when an Batch Session record cannot be saved."""
|
||||||
|
|
||||||
|
def __init__(self, message="BatchSession record not saved"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchSessionDeleteException(Exception):
|
||||||
|
"""Raised when an Batch Session record cannot be deleted."""
|
||||||
|
|
||||||
|
def __init__(self, message="BatchSession record not deleted"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchProcessStorageBase(ABC):
|
||||||
|
"""Low-level service responsible for interfacing with the Batch Process record store."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, batch_id: str) -> None:
|
||||||
|
"""Deletes a Batch Process record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
batch_process: BatchProcess,
|
||||||
|
) -> BatchProcess:
|
||||||
|
"""Saves a Batch Process record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
) -> BatchProcess:
|
||||||
|
"""Gets a Batch Process record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_session(
|
||||||
|
self,
|
||||||
|
session: BatchSession,
|
||||||
|
) -> BatchSession:
|
||||||
|
"""Creates a Batch Session attached to a Batch Process."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_session(
|
||||||
|
self,
|
||||||
|
session_id: str
|
||||||
|
) -> BatchSession:
|
||||||
|
"""Gets session by session_id"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_created_session(
|
||||||
|
self,
|
||||||
|
batch_id: str
|
||||||
|
) -> BatchSession:
|
||||||
|
"""Gets all created Batch Sessions for a given Batch Process id."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_created_sessions(
|
||||||
|
self,
|
||||||
|
batch_id: str
|
||||||
|
) -> List[BatchSession]:
|
||||||
|
"""Gets all created Batch Sessions for a given Batch Process id."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_session_state(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
session_id: str,
|
||||||
|
changes: BatchSessionChanges,
|
||||||
|
) -> BatchSession:
|
||||||
|
"""Updates the state of a Batch Session record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteBatchProcessStorage(BatchProcessStorageBase):
|
||||||
|
_filename: str
|
||||||
|
_conn: sqlite3.Connection
|
||||||
|
_cursor: sqlite3.Cursor
|
||||||
|
_lock: threading.Lock
|
||||||
|
|
||||||
|
def __init__(self, filename: str) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._filename = filename
|
||||||
|
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||||
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
|
self._conn.row_factory = sqlite3.Row
|
||||||
|
self._cursor = self._conn.cursor()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
# Enable foreign keys
|
||||||
|
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||||
|
self._create_tables()
|
||||||
|
self._conn.commit()
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def _create_tables(self) -> None:
|
||||||
|
"""Creates the `batch_process` table and `batch_session` junction table."""
|
||||||
|
|
||||||
|
# Create the `batch_process` table.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS batch_process (
|
||||||
|
batch_id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
batches TEXT NOT NULL,
|
||||||
|
graph TEXT NOT NULL,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Soft delete, currently unused
|
||||||
|
deleted_at DATETIME
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add trigger for `updated_at`.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON batch_process FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE batch_process SET updated_at = current_timestamp
|
||||||
|
WHERE batch_id = old.batch_id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the `batch_session` junction table.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS batch_session (
|
||||||
|
batch_id TEXT NOT NULL,
|
||||||
|
session_id TEXT NOT NULL,
|
||||||
|
state TEXT NOT NULL,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Soft delete, currently unused
|
||||||
|
deleted_at DATETIME,
|
||||||
|
-- enforce one-to-many relationship between batch_process and batch_session using PK
|
||||||
|
-- (we can extend this to many-to-many later)
|
||||||
|
PRIMARY KEY (batch_id,session_id),
|
||||||
|
FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add index for batch id
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add index for batch id, sorted by created_at
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add trigger for `updated_at`.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON batch_session FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE batch_id = old.batch_id AND session_id = old.session_id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete(self, batch_id: str) -> None:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
DELETE FROM batch_process
|
||||||
|
WHERE batch_id = ?;
|
||||||
|
""",
|
||||||
|
(batch_id,),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchProcessDeleteException from e
|
||||||
|
except Exception as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchProcessDeleteException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
batch_process: BatchProcess,
|
||||||
|
) -> BatchProcess:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
batches = [batch.json() for batch in batch_process.batches]
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
|
||||||
|
VALUES (?, ?, ?);
|
||||||
|
""",
|
||||||
|
(batch_process.batch_id, json.dumps(batches), batch_process.graph.json()),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchProcessSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return self.get(batch_process.batch_id)
|
||||||
|
|
||||||
|
def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
|
||||||
|
"""Deserializes a batch session."""
|
||||||
|
|
||||||
|
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||||
|
|
||||||
|
batch_id = session_dict.get("batch_id", "unknown")
|
||||||
|
batches_raw = session_dict.get("batches", "unknown")
|
||||||
|
graph_raw = session_dict.get("graph", "unknown")
|
||||||
|
batches = json.loads(batches_raw)
|
||||||
|
batches = [parse_raw_as(Batch, batch) for batch in batches]
|
||||||
|
return BatchProcess(
|
||||||
|
batch_id=batch_id,
|
||||||
|
batches=batches,
|
||||||
|
graph=parse_raw_as(Graph, graph_raw),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
) -> BatchProcess:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM batch_process
|
||||||
|
WHERE batch_id = ?;
|
||||||
|
""",
|
||||||
|
(batch_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchProcessNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
raise BatchProcessNotFoundException
|
||||||
|
return self._deserialize_batch_process(dict(result))
|
||||||
|
|
||||||
|
def create_session(
|
||||||
|
self,
|
||||||
|
session: BatchSession,
|
||||||
|
) -> BatchSession:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state)
|
||||||
|
VALUES (?, ?, ?);
|
||||||
|
""",
|
||||||
|
(session.batch_id, session.session_id, session.state),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return self.get_session(session.session_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_session(
|
||||||
|
self,
|
||||||
|
session_id: str
|
||||||
|
) -> BatchSession:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM batch_session
|
||||||
|
WHERE session_id= ?;
|
||||||
|
""",
|
||||||
|
(session_id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
raise BatchSessionNotFoundException
|
||||||
|
return self._deserialize_batch_session(dict(result))
|
||||||
|
|
||||||
|
def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
|
||||||
|
"""Deserializes a batch session."""
|
||||||
|
|
||||||
|
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||||
|
|
||||||
|
batch_id = session_dict.get("batch_id", "unknown")
|
||||||
|
session_id = session_dict.get("session_id", "unknown")
|
||||||
|
state = session_dict.get("state", "unknown")
|
||||||
|
|
||||||
|
return BatchSession(
|
||||||
|
batch_id=batch_id,
|
||||||
|
session_id=session_id,
|
||||||
|
state=state,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_created_session(
|
||||||
|
self,
|
||||||
|
batch_id: str
|
||||||
|
) -> BatchSession:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM batch_session
|
||||||
|
WHERE batch_id = ? AND state = 'created';
|
||||||
|
""",
|
||||||
|
(batch_id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(list[sqlite3.Row], self._cursor.fetchone())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
raise BatchSessionNotFoundException
|
||||||
|
session = self._deserialize_batch_session(dict(result))
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
def get_created_sessions(
|
||||||
|
self,
|
||||||
|
batch_id: str
|
||||||
|
) -> List[BatchSession]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT *
|
||||||
|
FROM batch_session
|
||||||
|
WHERE batch_id = ? AND state = created;
|
||||||
|
""",
|
||||||
|
(batch_id,),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
if result is None:
|
||||||
|
raise BatchSessionNotFoundException
|
||||||
|
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
|
||||||
|
return sessions
|
||||||
|
|
||||||
|
|
||||||
|
def update_session_state(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
session_id: str,
|
||||||
|
changes: BatchSessionChanges,
|
||||||
|
) -> BatchSession:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
|
||||||
|
# Change the state of a batch session
|
||||||
|
if changes.state is not None:
|
||||||
|
self._cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
UPDATE batch_session
|
||||||
|
SET state = ?
|
||||||
|
WHERE batch_id = ? AND session_id = ?;
|
||||||
|
""",
|
||||||
|
(changes.state, batch_id, session_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise BatchSessionSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
return self.get_session(session_id)
|
||||||
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
from invokeai.app.services.batch_manager import BatchManagerBase
|
||||||
from invokeai.app.services.board_images import BoardImagesServiceABC
|
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||||
from invokeai.app.services.boards import BoardServiceABC
|
from invokeai.app.services.boards import BoardServiceABC
|
||||||
from invokeai.app.services.images import ImageServiceABC
|
from invokeai.app.services.images import ImageServiceABC
|
||||||
@@ -21,6 +22,7 @@ class InvocationServices:
|
|||||||
"""Services that can be used by invocations"""
|
"""Services that can be used by invocations"""
|
||||||
|
|
||||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||||
|
batch_manager: "BatchManagerBase"
|
||||||
board_images: "BoardImagesServiceABC"
|
board_images: "BoardImagesServiceABC"
|
||||||
boards: "BoardServiceABC"
|
boards: "BoardServiceABC"
|
||||||
configuration: "InvokeAIAppConfig"
|
configuration: "InvokeAIAppConfig"
|
||||||
@@ -36,6 +38,7 @@ class InvocationServices:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
batch_manager: "BatchManagerBase",
|
||||||
board_images: "BoardImagesServiceABC",
|
board_images: "BoardImagesServiceABC",
|
||||||
boards: "BoardServiceABC",
|
boards: "BoardServiceABC",
|
||||||
configuration: "InvokeAIAppConfig",
|
configuration: "InvokeAIAppConfig",
|
||||||
@@ -49,6 +52,7 @@ class InvocationServices:
|
|||||||
processor: "InvocationProcessorABC",
|
processor: "InvocationProcessorABC",
|
||||||
queue: "InvocationQueueABC",
|
queue: "InvocationQueueABC",
|
||||||
):
|
):
|
||||||
|
self.batch_manager = batch_manager
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._conn = sqlite3.connect(
|
self._conn = sqlite3.connect(
|
||||||
self._filename, check_same_thread=False
|
self._filename, check_same_thread=False
|
||||||
) # TODO: figure out a better threading solution
|
) # TODO: figure out a better threading solution
|
||||||
|
self._conn.execute('pragma journal_mode=wal')
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
|
||||||
self._create_table()
|
self._create_table()
|
||||||
|
|||||||
@@ -562,7 +562,7 @@ class ModelPatcher:
|
|||||||
cls,
|
cls,
|
||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
text_encoder: CLIPTextModel,
|
text_encoder: CLIPTextModel,
|
||||||
ti_list: List[Tuple[str, Any]],
|
ti_list: List[Any],
|
||||||
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
||||||
init_tokens_count = None
|
init_tokens_count = None
|
||||||
new_tokens_added = None
|
new_tokens_added = None
|
||||||
@@ -572,27 +572,27 @@ class ModelPatcher:
|
|||||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||||
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
|
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
|
||||||
|
|
||||||
def _get_trigger(ti_name, index):
|
def _get_trigger(ti, index):
|
||||||
trigger = ti_name
|
trigger = ti.name
|
||||||
if index > 0:
|
if index > 0:
|
||||||
trigger += f"-!pad-{i}"
|
trigger += f"-!pad-{i}"
|
||||||
return f"<{trigger}>"
|
return f"<{trigger}>"
|
||||||
|
|
||||||
# modify tokenizer
|
# modify tokenizer
|
||||||
new_tokens_added = 0
|
new_tokens_added = 0
|
||||||
for ti_name, ti in ti_list:
|
for ti in ti_list:
|
||||||
for i in range(ti.embedding.shape[0]):
|
for i in range(ti.embedding.shape[0]):
|
||||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
|
||||||
|
|
||||||
# modify text_encoder
|
# modify text_encoder
|
||||||
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
|
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
|
||||||
model_embeddings = text_encoder.get_input_embeddings()
|
model_embeddings = text_encoder.get_input_embeddings()
|
||||||
|
|
||||||
for ti_name, ti in ti_list:
|
for ti in ti_list:
|
||||||
ti_tokens = []
|
ti_tokens = []
|
||||||
for i in range(ti.embedding.shape[0]):
|
for i in range(ti.embedding.shape[0]):
|
||||||
embedding = ti.embedding[i]
|
embedding = ti.embedding[i]
|
||||||
trigger = _get_trigger(ti_name, i)
|
trigger = _get_trigger(ti, i)
|
||||||
|
|
||||||
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
||||||
if token_id == ti_tokenizer.unk_token_id:
|
if token_id == ti_tokenizer.unk_token_id:
|
||||||
@@ -637,6 +637,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
|
|
||||||
class TextualInversionModel:
|
class TextualInversionModel:
|
||||||
|
name: str
|
||||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -650,6 +651,7 @@ class TextualInversionModel:
|
|||||||
file_path = Path(file_path)
|
file_path = Path(file_path)
|
||||||
|
|
||||||
result = cls() # TODO:
|
result = cls() # TODO:
|
||||||
|
result.name = file_path.stem # TODO:
|
||||||
|
|
||||||
if file_path.suffix == ".safetensors":
|
if file_path.suffix == ".safetensors":
|
||||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||||
@@ -826,7 +828,7 @@ class ONNXModelPatcher:
|
|||||||
cls,
|
cls,
|
||||||
tokenizer: CLIPTokenizer,
|
tokenizer: CLIPTokenizer,
|
||||||
text_encoder: IAIOnnxRuntimeModel,
|
text_encoder: IAIOnnxRuntimeModel,
|
||||||
ti_list: List[Tuple[str, Any]],
|
ti_list: List[Any],
|
||||||
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
||||||
from .models.base import IAIOnnxRuntimeModel
|
from .models.base import IAIOnnxRuntimeModel
|
||||||
|
|
||||||
@@ -839,17 +841,17 @@ class ONNXModelPatcher:
|
|||||||
ti_tokenizer = copy.deepcopy(tokenizer)
|
ti_tokenizer = copy.deepcopy(tokenizer)
|
||||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
ti_manager = TextualInversionManager(ti_tokenizer)
|
||||||
|
|
||||||
def _get_trigger(ti_name, index):
|
def _get_trigger(ti, index):
|
||||||
trigger = ti_name
|
trigger = ti.name
|
||||||
if index > 0:
|
if index > 0:
|
||||||
trigger += f"-!pad-{i}"
|
trigger += f"-!pad-{i}"
|
||||||
return f"<{trigger}>"
|
return f"<{trigger}>"
|
||||||
|
|
||||||
# modify tokenizer
|
# modify tokenizer
|
||||||
new_tokens_added = 0
|
new_tokens_added = 0
|
||||||
for ti_name, ti in ti_list:
|
for ti in ti_list:
|
||||||
for i in range(ti.embedding.shape[0]):
|
for i in range(ti.embedding.shape[0]):
|
||||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
|
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
|
||||||
|
|
||||||
# modify text_encoder
|
# modify text_encoder
|
||||||
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
|
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
|
||||||
@@ -859,11 +861,11 @@ class ONNXModelPatcher:
|
|||||||
axis=0,
|
axis=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
for ti_name, ti in ti_list:
|
for ti in ti_list:
|
||||||
ti_tokens = []
|
ti_tokens = []
|
||||||
for i in range(ti.embedding.shape[0]):
|
for i in range(ti.embedding.shape[0]):
|
||||||
embedding = ti.embedding[i].detach().numpy()
|
embedding = ti.embedding[i].detach().numpy()
|
||||||
trigger = _get_trigger(ti_name, i)
|
trigger = _get_trigger(ti, i)
|
||||||
|
|
||||||
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
||||||
if token_id == ti_tokenizer.unk_token_id:
|
if token_id == ti_tokenizer.unk_token_id:
|
||||||
|
|||||||
@@ -210,31 +210,6 @@ class ModelCache(object):
|
|||||||
|
|
||||||
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
|
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
|
||||||
|
|
||||||
def clear_one_model(self) -> bool:
|
|
||||||
reserved = self.max_vram_cache_size * GIG
|
|
||||||
vram_in_use = torch.cuda.memory_allocated()
|
|
||||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
|
|
||||||
smallest_key = None
|
|
||||||
smallest_size = float("inf")
|
|
||||||
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
|
||||||
if not cache_entry.locked and cache_entry.loaded:
|
|
||||||
if cache_entry.size > 0 and cache_entry.size < smallest_size:
|
|
||||||
smallest_key = model_key
|
|
||||||
smallest_size = cache_entry.size
|
|
||||||
|
|
||||||
if smallest_key is not None:
|
|
||||||
cache_entry = self._cached_models[smallest_key]
|
|
||||||
self.logger.debug(f"!!!!!!!!!!!Offloading {smallest_key} from {self.execution_device} into {self.storage_device}")
|
|
||||||
with VRAMUsage() as mem:
|
|
||||||
cache_entry.model.to(self.storage_device)
|
|
||||||
self.logger.debug(f"GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB")
|
|
||||||
vram_in_use += mem.vram_used # note vram_used is negative
|
|
||||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
return smallest_key is not None
|
|
||||||
|
|
||||||
class ModelLocker(object):
|
class ModelLocker(object):
|
||||||
def __init__(self, cache, key, model, gpu_load, size_needed):
|
def __init__(self, cache, key, model, gpu_load, size_needed):
|
||||||
"""
|
"""
|
||||||
@@ -261,48 +236,17 @@ class ModelCache(object):
|
|||||||
self.cache_entry.lock()
|
self.cache_entry.lock()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}")
|
if self.cache.lazy_offloading:
|
||||||
while True:
|
self.cache._offload_unlocked_models(self.size_needed)
|
||||||
try:
|
|
||||||
with VRAMUsage() as mem:
|
|
||||||
self.model.to(self.cache.execution_device) # move into GPU
|
|
||||||
self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
|
|
||||||
|
|
||||||
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
|
if self.model.device != self.cache.execution_device:
|
||||||
self.cache._print_cuda_stats()
|
self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}")
|
||||||
|
with VRAMUsage() as mem:
|
||||||
|
self.model.to(self.cache.execution_device) # move into GPU
|
||||||
|
self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
|
||||||
|
|
||||||
def my_forward(module, cache, *args, **kwargs):
|
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
|
||||||
while True:
|
self.cache._print_cuda_stats()
|
||||||
try:
|
|
||||||
return module._orig_forward(*args, **kwargs)
|
|
||||||
except:
|
|
||||||
if not cache.clear_one_model():
|
|
||||||
raise
|
|
||||||
|
|
||||||
import functools
|
|
||||||
from diffusers.models.unet_2d_blocks import DownBlock2D, CrossAttnDownBlock2D, UpBlock2D, CrossAttnUpBlock2D
|
|
||||||
from transformers.models.clip.modeling_clip import CLIPEncoderLayer
|
|
||||||
from diffusers.models.unet_2d_blocks import DownEncoderBlock2D, UpDecoderBlock2D
|
|
||||||
|
|
||||||
for module_name, module in self.model.named_modules():
|
|
||||||
if type(module) not in [
|
|
||||||
DownBlock2D, CrossAttnDownBlock2D, UpBlock2D, CrossAttnUpBlock2D, # unet blocks
|
|
||||||
CLIPEncoderLayer, # CLIPTextTransformer clip
|
|
||||||
DownEncoderBlock2D, UpDecoderBlock2D, # vae
|
|
||||||
]:
|
|
||||||
continue
|
|
||||||
# better here filter to only specific model modules
|
|
||||||
module._orig_forward = module.forward
|
|
||||||
module.forward = functools.partial(my_forward, module, self.cache)
|
|
||||||
|
|
||||||
self.model._orig_forward = self.model.forward
|
|
||||||
self.model.forward = functools.partial(my_forward, self.model, self.cache)
|
|
||||||
|
|
||||||
break
|
|
||||||
|
|
||||||
except:
|
|
||||||
if not self.cache.clear_one_model():
|
|
||||||
raise
|
|
||||||
|
|
||||||
except:
|
except:
|
||||||
self.cache_entry.unlock()
|
self.cache_entry.unlock()
|
||||||
@@ -320,19 +264,10 @@ class ModelCache(object):
|
|||||||
if not hasattr(self.model, "to"):
|
if not hasattr(self.model, "to"):
|
||||||
return
|
return
|
||||||
|
|
||||||
if hasattr(self.model, "_orig_forward"):
|
|
||||||
self.model.forward = self.model._orig_forward
|
|
||||||
delattr(self.model, "_orig_forward")
|
|
||||||
|
|
||||||
for module_name, module in self.model.named_modules():
|
|
||||||
if hasattr(module, "_orig_forward"):
|
|
||||||
module.forward = module._orig_forward
|
|
||||||
delattr(module, "_orig_forward")
|
|
||||||
|
|
||||||
self.cache_entry.unlock()
|
self.cache_entry.unlock()
|
||||||
#if not self.cache.lazy_offloading:
|
if not self.cache.lazy_offloading:
|
||||||
# self.cache._offload_unlocked_models()
|
self.cache._offload_unlocked_models()
|
||||||
# self.cache._print_cuda_stats()
|
self.cache._print_cuda_stats()
|
||||||
|
|
||||||
# TODO: should it be called untrack_model?
|
# TODO: should it be called untrack_model?
|
||||||
def uncache_model(self, cache_id: str):
|
def uncache_model(self, cache_id: str):
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ def mock_services() -> InvocationServices:
|
|||||||
images=None, # type: ignore
|
images=None, # type: ignore
|
||||||
latents=None, # type: ignore
|
latents=None, # type: ignore
|
||||||
boards=None, # type: ignore
|
boards=None, # type: ignore
|
||||||
|
batch_manager=None, # type: ignore
|
||||||
board_images=None, # type: ignore
|
board_images=None, # type: ignore
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"),
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ def mock_services() -> InvocationServices:
|
|||||||
logger=None, # type: ignore
|
logger=None, # type: ignore
|
||||||
images=None, # type: ignore
|
images=None, # type: ignore
|
||||||
latents=None, # type: ignore
|
latents=None, # type: ignore
|
||||||
|
batch_manager=None, # type: ignore
|
||||||
boards=None, # type: ignore
|
boards=None, # type: ignore
|
||||||
board_images=None, # type: ignore
|
board_images=None, # type: ignore
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
|
|||||||
Reference in New Issue
Block a user