Compare commits

..

17 Commits

Author SHA1 Message Date
Brandon Rising
c1dde83abb Clean up erroniously added lines 2023-08-10 14:28:50 -04:00
Brandon Rising
280ac15da2 Go back to 1 lock per table 2023-08-10 14:26:22 -04:00
Brandon Rising
e751f7d815 More testing 2023-08-10 14:09:00 -04:00
Brandon Rising
e26e4740b3 Testing sqlite issues with batch_manager 2023-08-10 11:38:28 -04:00
Brandon
835d76af45 Merge branch 'main' into feat/batch-graphs 2023-08-01 16:44:30 -04:00
Brandon Rising
a3e099bbc0 Instantiate batch managers 2023-08-01 16:44:17 -04:00
Brandon Rising
a61685696f Run black formatting 2023-08-01 16:41:40 -04:00
Brandon Rising
02aa93c67c Cancel batch endpoint 2023-07-31 16:05:27 -04:00
Brandon Rising
55b921818d Create batch manager 2023-07-31 15:45:35 -04:00
Brandon Rising
bb681a8a11 Merge branch 'main' into feat/batch-graphs 2023-07-31 13:22:11 -04:00
Brandon
74e0fbce42 Merge branch 'main' into feat/batch-graphs 2023-07-25 22:25:55 -04:00
Brandon Rising
f080c56771 Testing out generating a new session for each batch_index 2023-07-25 16:50:07 -04:00
Brandon Rising
d2f968b902 Trying different places of applying batches 2023-07-25 10:23:17 -04:00
Brandon Rising
e81601acf3 add todo 2023-07-24 18:12:05 -04:00
Brandon Rising
7073dc0d5d Fix next call in graphexecutionstate.next 2023-07-24 17:45:05 -04:00
Brandon Rising
d090be60e8 Make batch_indices in graph class more clear 2023-07-24 17:43:49 -04:00
Brandon Rising
4bad96d9d6 WIP running graphs as batches 2023-07-24 17:41:54 -04:00
195 changed files with 4204 additions and 5097 deletions

View File

@@ -1,14 +1,13 @@
name: style checks
# just formatting for now
# TODO: add isort and flake8 later
name: Black # TODO: add isort and flake8 later
on:
pull_request:
pull_request: {}
push:
branches: main
branches: master
tags: "*"
jobs:
black:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
@@ -20,7 +19,8 @@ jobs:
- name: Install dependencies with pip
run: |
pip install black
pip install --upgrade pip wheel
pip install .[test]
# - run: isort --check-only .
- run: black --check .

View File

@@ -1,103 +0,0 @@
name: e2e tests
on:
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}
cancel-in-progress: true
jobs:
matrix:
if: github.event.pull_request.draft == false
strategy:
matrix:
python-version:
- '3.10'
# - '3.11'
include:
- pytorch: linux-cuda-11_7
os: ubuntu-22.04
github-env: $GITHUB_ENV
# - pytorch: linux-rocm-5_2
# os: ubuntu-22.04
# extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
# github-env: $GITHUB_ENV
# - pytorch: linux-cpu
# os: ubuntu-22.04
# extra-index-url: 'https://download.pytorch.org/whl/cpu'
# github-env: $GITHUB_ENV
# - pytorch: macos-default
# os: macOS-12
# github-env: $GITHUB_ENV
# - pytorch: windows-cpu
# os: windows-2022
# github-env: $env:GITHUB_ENV
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
env:
PIP_USE_PEP517: '1'
steps:
- name: Checkout sources
id: checkout-sources
uses: actions/checkout@v3
- name: set test prompt to main branch validation
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
- name: setup python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: Free up more disk space on the runner
# https://github.com/actions/runner-images/issues/2840#issuecomment-1284059930
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo swapoff /mnt/swapfile
sudo rm -rf /mnt/swapfile
- name: install invokeai
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
pip3 install
--editable=".[test]"
- name: run invokeai-configure
env:
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
run: >
invokeai-configure
--yes
--default_only
--full-precision
# can't use fp16 weights without a GPU
- name: run invokeai
id: run-invokeai
env:
# Set offline mode to make sure configure preloaded successfully.
HF_HUB_OFFLINE: 1
HF_DATASETS_OFFLINE: 1
TRANSFORMERS_OFFLINE: 1
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
run: >
invokeai
--no-patchmatch
--no-nsfw_checker
--precision=float32
--always_use_cpu
--use_memory_db
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
--from_file ${{ env.TEST_PROMPTS }}
- name: Archive results
env:
INVOKEAI_OUTDIR: ${{ github.workspace }}/results
uses: actions/upload-artifact@v3
with:
name: results
path: ${{ env.INVOKEAI_OUTDIR }}

View File

@@ -0,0 +1,50 @@
name: Test invoke.py pip
# This is a dummy stand-in for the actual tests
# we don't need to run python tests on non-Python changes
# But PRs require passing tests to be mergeable
on:
pull_request:
paths:
- '**'
- '!pyproject.toml'
- '!invokeai/**'
- '!tests/**'
- 'invokeai/frontend/web/**'
merge_group:
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
matrix:
if: github.event.pull_request.draft == false
strategy:
matrix:
python-version:
- '3.10'
pytorch:
- linux-cuda-11_7
- linux-rocm-5_2
- linux-cpu
- macos-default
- windows-cpu
include:
- pytorch: linux-cuda-11_7
os: ubuntu-22.04
- pytorch: linux-rocm-5_2
os: ubuntu-22.04
- pytorch: linux-cpu
os: ubuntu-22.04
- pytorch: macos-default
os: macOS-12
- pytorch: windows-cpu
os: windows-2022
name: ${{ matrix.pytorch }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
steps:
- name: skip
run: echo "no build required"

View File

@@ -3,7 +3,16 @@ on:
push:
branches:
- 'main'
paths:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
pull_request:
paths:
- 'pyproject.toml'
- 'invokeai/**'
- 'tests/**'
- '!invokeai/frontend/web/**'
types:
- 'ready_for_review'
- 'opened'
@@ -56,23 +65,10 @@ jobs:
id: checkout-sources
uses: actions/checkout@v3
- name: Check for changed python files
id: changed-files
uses: tj-actions/changed-files@v37
with:
files_yaml: |
python:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'tests/**'
- name: set test prompt to main branch validation
if: steps.changed-files.outputs.python_any_changed == 'true'
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> ${{ matrix.github-env }}
- name: setup python
if: steps.changed-files.outputs.python_any_changed == 'true'
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
@@ -80,7 +76,6 @@ jobs:
cache-dependency-path: pyproject.toml
- name: install invokeai
if: steps.changed-files.outputs.python_any_changed == 'true'
env:
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
run: >
@@ -88,6 +83,41 @@ jobs:
--editable=".[test]"
- name: run pytest
if: steps.changed-files.outputs.python_any_changed == 'true'
id: run-pytest
run: pytest
# - name: run invokeai-configure
# env:
# HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
# run: >
# invokeai-configure
# --yes
# --default_only
# --full-precision
# # can't use fp16 weights without a GPU
# - name: run invokeai
# id: run-invokeai
# env:
# # Set offline mode to make sure configure preloaded successfully.
# HF_HUB_OFFLINE: 1
# HF_DATASETS_OFFLINE: 1
# TRANSFORMERS_OFFLINE: 1
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
# run: >
# invokeai
# --no-patchmatch
# --no-nsfw_checker
# --precision=float32
# --always_use_cpu
# --use_memory_db
# --outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
# --from_file ${{ env.TEST_PROMPTS }}
# - name: Archive results
# env:
# INVOKEAI_OUTDIR: ${{ github.workspace }}/results
# uses: actions/upload-artifact@v3
# with:
# name: results
# path: ${{ env.INVOKEAI_OUTDIR }}

View File

@@ -184,9 +184,8 @@ the command `npm install -g yarn` if needed)
6. Configure InvokeAI and install a starting set of image generation models (you only need to do this once):
```terminal
invokeai-configure --root .
invokeai-configure
```
Don't miss the dot at the end!
7. Launch the web server (do it every time you run InvokeAI):
@@ -194,9 +193,15 @@ the command `npm install -g yarn` if needed)
invokeai-web
```
8. Point your browser to http://localhost:9090 to bring up the web interface.
8. Build Node.js assets
9. Type `banana sushi` in the box on the top left and click `Invoke`.
```terminal
cd invokeai/frontend/web/
yarn vite build
```
9. Point your browser to http://localhost:9090 to bring up the web interface.
10. Type `banana sushi` in the box on the top left and click `Invoke`.
Be sure to activate the virtual environment each time before re-launching InvokeAI,
using `source .venv/bin/activate` or `.venv\Scripts\activate`.

View File

@@ -192,10 +192,8 @@ manager, please follow these steps:
your outputs.
```terminal
invokeai-configure --root .
invokeai-configure
```
Don't miss the dot at the end of the command!
The script `invokeai-configure` will interactively guide you through the
process of downloading and installing the weights files needed for InvokeAI.
@@ -227,6 +225,12 @@ manager, please follow these steps:
!!! warning "Make sure that the virtual environment is activated, which should create `(.venv)` in front of your prompt!"
=== "CLI"
```bash
invokeai
```
=== "local Webserver"
```bash
@@ -239,12 +243,6 @@ manager, please follow these steps:
invokeai --web --host 0.0.0.0
```
=== "CLI"
```bash
invokeai
```
If you choose the run the web interface, point your browser at
http://localhost:9090 in order to load the GUI.

View File

@@ -124,7 +124,7 @@ installation. Examples:
invokeai-model-install --list controlnet
# (install the model at the indicated URL)
invokeai-model-install --add https://civitai.com/api/download/models/128713
invokeai-model-install --add http://civitai.com/2860
# (delete the named model)
invokeai-model-install --delete sd-1/main/analog-diffusion
@@ -170,4 +170,4 @@ elsewhere on disk and they will be autoimported. You can also create
subfolders and organize them as you wish.
The location of the autoimport directories are controlled by settings
in `invokeai.yaml`. See [Configuration](../features/CONFIGURATION.md).
in `invokeai.yaml`. See [Configuration](../features/CONFIGURATION.md).

View File

@@ -2,6 +2,7 @@
from typing import Optional
from logging import Logger
import os
from invokeai.app.services.board_image_record_storage import (
SqliteBoardImageRecordStorage,
)
@@ -29,7 +30,8 @@ from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.model_manager_service import ModelManagerService
from ..services.invocation_stats import InvocationStatsService
from ..services.batch_manager import BatchManager
from ..services.batch_manager_storage import SqliteBatchProcessStorage
from .events import FastAPIEventService
@@ -55,7 +57,7 @@ logger = InvokeAILogger.getLogger()
class ApiDependencies:
"""Contains and initializes all dependencies for the API"""
invoker: Invoker
invoker: Optional[Invoker] = None
@staticmethod
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
@@ -68,9 +70,8 @@ class ApiDependencies:
output_folder = config.output_path
# TODO: build a file/path manager?
db_path = config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
db_location = str(db_path)
db_location = config.db_path
db_location.parent.mkdir(parents=True, exist_ok=True)
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
@@ -117,11 +118,15 @@ class ApiDependencies:
)
)
batch_manager_storage = SqliteBatchProcessStorage(db_location)
batch_manager = BatchManager(batch_manager_storage)
services = InvocationServices(
model_manager=ModelManagerService(config, logger),
events=events,
latents=latents,
images=images,
batch_manager=batch_manager,
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
@@ -129,7 +134,6 @@ class ApiDependencies:
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
configuration=config,
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
)

View File

@@ -1,30 +1,24 @@
from fastapi import Body, HTTPException
from fastapi import Body, HTTPException, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.services.board_record_storage import BoardRecord, BoardChanges
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.board_record import BoardDTO
from invokeai.app.services.models.image_record import ImageDTO
from ..dependencies import ApiDependencies
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
class AddImagesToBoardResult(BaseModel):
board_id: str = Field(description="The id of the board the images were added to")
added_image_names: list[str] = Field(description="The image names that were added to the board")
class RemoveImagesFromBoardResult(BaseModel):
removed_image_names: list[str] = Field(description="The image names that were removed from their board")
@board_images_router.post(
"/",
operation_id="add_image_to_board",
operation_id="create_board_image",
responses={
201: {"description": "The image was added to a board successfully"},
},
status_code=201,
)
async def add_image_to_board(
async def create_board_image(
board_id: str = Body(description="The id of the board to add to"),
image_name: str = Body(description="The name of the image to add"),
):
@@ -35,78 +29,26 @@ async def add_image_to_board(
)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to add image to board")
raise HTTPException(status_code=500, detail="Failed to add to board")
@board_images_router.delete(
"/",
operation_id="remove_image_from_board",
operation_id="remove_board_image",
responses={
201: {"description": "The image was removed from the board successfully"},
},
status_code=201,
)
async def remove_image_from_board(
image_name: str = Body(description="The name of the image to remove", embed=True),
async def remove_board_image(
board_id: str = Body(description="The id of the board"),
image_name: str = Body(description="The name of the image to remove"),
):
"""Removes an image from its board, if it had one"""
"""Deletes a board_image"""
try:
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(
board_id=board_id, image_name=image_name
)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to remove image from board")
@board_images_router.post(
"/batch",
operation_id="add_images_to_board",
responses={
201: {"description": "Images were added to board successfully"},
},
status_code=201,
response_model=AddImagesToBoardResult,
)
async def add_images_to_board(
board_id: str = Body(description="The id of the board to add to"),
image_names: list[str] = Body(description="The names of the images to add", embed=True),
) -> AddImagesToBoardResult:
"""Adds a list of images to a board"""
try:
added_image_names: list[str] = []
for image_name in image_names:
try:
ApiDependencies.invoker.services.board_images.add_image_to_board(
board_id=board_id, image_name=image_name
)
added_image_names.append(image_name)
except:
pass
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to add images to board")
@board_images_router.post(
"/batch/delete",
operation_id="remove_images_from_board",
responses={
201: {"description": "Images were removed from board successfully"},
},
status_code=201,
response_model=RemoveImagesFromBoardResult,
)
async def remove_images_from_board(
image_names: list[str] = Body(description="The names of the images to remove", embed=True),
) -> RemoveImagesFromBoardResult:
"""Removes a list of images from their board, if they had one"""
try:
removed_image_names: list[str] = []
for image_name in image_names:
try:
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_image_names.append(image_name)
except:
pass
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to remove images from board")
raise HTTPException(status_code=500, detail="Failed to update board")

View File

@@ -1,20 +1,21 @@
import io
from typing import Optional
from PIL import Image
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from pydantic import BaseModel
from PIL import Image
from invokeai.app.invocations.metadata import ImageMetadata
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.services.models.image_record import (
ImageDTO,
ImageRecordChanges,
ImageUrlsDTO,
)
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@@ -24,7 +25,7 @@ IMAGE_MAX_AGE = 31536000
@images_router.post(
"/upload",
"/",
operation_id="upload_image",
responses={
201: {"description": "The image was uploaded successfully"},
@@ -76,7 +77,7 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image")
@images_router.delete("/i/{image_name}", operation_id="delete_image")
@images_router.delete("/{image_name}", operation_id="delete_image")
async def delete_image(
image_name: str = Path(description="The name of the image to delete"),
) -> None:
@@ -102,7 +103,7 @@ async def clear_intermediates() -> int:
@images_router.patch(
"/i/{image_name}",
"/{image_name}",
operation_id="update_image",
response_model=ImageDTO,
)
@@ -119,7 +120,7 @@ async def update_image(
@images_router.get(
"/i/{image_name}",
"/{image_name}",
operation_id="get_image_dto",
response_model=ImageDTO,
)
@@ -135,7 +136,7 @@ async def get_image_dto(
@images_router.get(
"/i/{image_name}/metadata",
"/{image_name}/metadata",
operation_id="get_image_metadata",
response_model=ImageMetadata,
)
@@ -150,9 +151,8 @@ async def get_image_metadata(
raise HTTPException(status_code=404)
@images_router.api_route(
"/i/{image_name}/full",
methods=["GET", "HEAD"],
@images_router.get(
"/{image_name}/full",
operation_id="get_image_full",
response_class=Response,
responses={
@@ -187,7 +187,7 @@ async def get_image_full(
@images_router.get(
"/i/{image_name}/thumbnail",
"/{image_name}/thumbnail",
operation_id="get_image_thumbnail",
response_class=Response,
responses={
@@ -216,7 +216,7 @@ async def get_image_thumbnail(
@images_router.get(
"/i/{image_name}/urls",
"/{image_name}/urls",
operation_id="get_image_urls",
response_model=ImageUrlsDTO,
)
@@ -265,24 +265,3 @@ async def list_image_dtos(
)
return image_dtos
class DeleteImagesFromListResult(BaseModel):
deleted_images: list[str]
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult)
async def delete_images_from_list(
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
) -> DeleteImagesFromListResult:
try:
deleted_images: list[str] = []
for image_name in image_names:
try:
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.append(image_name)
except:
pass
return DeleteImagesFromListResult(deleted_images=deleted_images)
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to delete images")

View File

@@ -15,6 +15,7 @@ from ...services.graph import (
GraphExecutionState,
NodeAlreadyExecutedError,
)
from ...services.batch_manager import Batch, BatchProcess
from ...services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
@@ -37,6 +38,37 @@ async def create_session(
return session
@session_router.post(
"/batch",
operation_id="create_batch",
responses={
200: {"model": BatchProcess},
400: {"description": "Invalid json"},
},
)
async def create_batch(
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
batches: list[Batch] = Body(description="Batch config to apply to the given graph"),
) -> BatchProcess:
"""Creates and starts a new new batch process"""
batch_id = ApiDependencies.invoker.services.batch_manager.create_batch_process(batches, graph)
ApiDependencies.invoker.services.batch_manager.run_batch_process(batch_id)
return {"batch_id":batch_id}
@session_router.delete(
"{batch_process_id}/batch",
operation_id="cancel_batch",
responses={202: {"description": "The batch is canceled"}},
)
async def cancel_batch(
batch_process_id: str = Path(description="The id of the batch process to cancel"),
) -> Response:
"""Creates and starts a new new batch process"""
ApiDependencies.invoker.services.batch_manager.cancel_batch_process(batch_process_id)
return Response(status_code=202)
@session_router.get(
"/",
operation_id="list_sessions",

View File

@@ -37,7 +37,8 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.services.batch_manager import BatchManager
from invokeai.app.services.batch_manager_storage import SqliteBatchProcessStorage
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@@ -301,18 +302,21 @@ def invoke_cli():
)
)
batch_manager_storage = SqliteBatchProcessStorage(db_location)
batch_manager = BatchManager(batch_manager_storage)
services = InvocationServices(
model_manager=model_manager,
events=events,
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
images=images,
boards=boards,
batch_manager=batch_manager,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
configuration=config,
)

View File

@@ -109,15 +109,12 @@ class CompelInvocation(BaseInvocation):
name = trigger[1:-1]
try:
ti_list.append(
(
name,
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
context=context,
).context.model,
)
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
context=context,
).context.model
)
except ModelNotFoundException:
# print(e)
@@ -176,7 +173,7 @@ class CompelInvocation(BaseInvocation):
class SDXLPromptInvocationBase:
def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
context=context,
@@ -200,15 +197,12 @@ class SDXLPromptInvocationBase:
name = trigger[1:-1]
try:
ti_list.append(
(
name,
context.services.model_manager.get_model(
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
context=context,
).context.model,
)
context.services.model_manager.get_model(
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
context=context,
).context.model
)
except ModelNotFoundException:
# print(e)
@@ -216,8 +210,8 @@ class SDXLPromptInvocationBase:
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with ModelPatcher.apply_lora(
text_encoder_info.context.model, _lora_loader(), lora_prefix
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
@@ -253,7 +247,7 @@ class SDXLPromptInvocationBase:
return c, c_pooled, None
def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix):
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
context=context,
@@ -277,15 +271,12 @@ class SDXLPromptInvocationBase:
name = trigger[1:-1]
try:
ti_list.append(
(
name,
context.services.model_manager.get_model(
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
context=context,
).context.model,
)
context.services.model_manager.get_model(
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
context=context,
).context.model
)
except ModelNotFoundException:
# print(e)
@@ -293,8 +284,8 @@ class SDXLPromptInvocationBase:
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with ModelPatcher.apply_lora(
text_encoder_info.context.model, _lora_loader(), lora_prefix
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
@@ -366,11 +357,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_")
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False)
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_")
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
else:
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_")
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
@@ -424,8 +415,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
# TODO: if there will appear lora for refiner - write proper prefix
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>")
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
@@ -477,11 +467,11 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_")
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False)
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_")
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
else:
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_")
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
@@ -535,8 +525,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
# TODO: if there will appear lora for refiner - write proper prefix
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "<NONE>")
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)

View File

@@ -3,7 +3,6 @@
from typing import Literal, Optional
import numpy
import cv2
from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import Field
from pathlib import Path
@@ -651,143 +650,3 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
width=image_dto.width,
height=image_dto.height,
)
class ImageHueAdjustmentInvocation(BaseInvocation):
"""Adjusts the Hue of an image."""
# fmt: off
type: Literal["img_hue_adjust"] = "img_hue_adjust"
# Inputs
image: ImageField = Field(default=None, description="The image to adjust")
hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
# Convert image to HSV color space
hsv_image = numpy.array(pil_image.convert("HSV"))
# Convert hue from 0..360 to 0..256
hue = int(256 * ((self.hue % 360) / 360))
# Increment each hue and wrap around at 255
hsv_image[:, :, 0] = (hsv_image[:, :, 0] + hue) % 256
# Convert back to PIL format and to original color mode
pil_image = Image.fromarray(hsv_image, mode="HSV").convert("RGBA")
image_dto = context.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
),
width=image_dto.width,
height=image_dto.height,
)
class ImageLuminosityAdjustmentInvocation(BaseInvocation):
"""Adjusts the Luminosity (Value) of an image."""
# fmt: off
type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust"
# Inputs
image: ImageField = Field(default=None, description="The image to adjust")
luminosity: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
# Convert PIL image to OpenCV format (numpy array), note color channel
# ordering is changed from RGB to BGR
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
# Convert image to HSV color space
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
# Adjust the luminosity (value)
hsv_image[:, :, 2] = numpy.clip(hsv_image[:, :, 2] * self.luminosity, 0, 255)
# Convert image back to BGR color space
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
# Convert back to PIL format and to original color mode
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
image_dto = context.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
),
width=image_dto.width,
height=image_dto.height,
)
class ImageSaturationAdjustmentInvocation(BaseInvocation):
"""Adjusts the Saturation of an image."""
# fmt: off
type: Literal["img_saturation_adjust"] = "img_saturation_adjust"
# Inputs
image: ImageField = Field(default=None, description="The image to adjust")
saturation: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_image = context.services.images.get_pil_image(self.image.image_name)
# Convert PIL image to OpenCV format (numpy array), note color channel
# ordering is changed from RGB to BGR
image = numpy.array(pil_image.convert("RGB"))[:, :, ::-1]
# Convert image to HSV color space
hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
# Adjust the saturation
hsv_image[:, :, 1] = numpy.clip(hsv_image[:, :, 1] * self.saturation, 0, 255)
# Convert image back to BGR color space
image = cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
# Convert back to PIL format and to original color mode
pil_image = Image.fromarray(image[:, :, ::-1], "RGB").convert("RGBA")
image_dto = context.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -14,7 +14,7 @@ from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management import ModelPatcher
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,

View File

@@ -1,6 +1,6 @@
from typing import Literal, Optional, Union
from pydantic import Field
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@@ -10,17 +10,16 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
class LoRAMetadataField(BaseModelExcludeNull):
class LoRAMetadataField(BaseModel):
"""LoRA metadata for an image generated in InvokeAI."""
lora: LoRAModelField = Field(description="The LoRA model")
weight: float = Field(description="The weight of the LoRA model")
class CoreMetadata(BaseModelExcludeNull):
class CoreMetadata(BaseModel):
"""Core generation metadata for an image generated in InvokeAI."""
generation_mode: str = Field(
@@ -71,7 +70,7 @@ class CoreMetadata(BaseModelExcludeNull):
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
class ImageMetadata(BaseModelExcludeNull):
class ImageMetadata(BaseModel):
"""An image's generation metadata"""
metadata: Optional[dict] = Field(

View File

@@ -262,103 +262,6 @@ class LoraLoaderInvocation(BaseInvocation):
return output
class SDXLLoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
# fmt: off
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels")
# fmt: on
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
weight: float = Field(default=0.75, description="With what weight to apply lora")
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Lora Loader",
"tags": ["lora", "loader"],
"type_hints": {"lora": "lora_model"},
},
}
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
):
raise Exception(f"Unknown lora name: {lora_name}!")
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip')
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip2')
output = SDXLLoraLoaderOutput()
if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
)
if self.clip2 is not None:
output.clip2 = copy.deepcopy(self.clip2)
output.clip2.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
)
return output
class VAEModelField(BaseModel):
"""Vae model field"""

View File

@@ -65,6 +65,7 @@ class ONNXPromptInvocation(BaseInvocation):
**self.clip.text_encoder.dict(),
)
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
loras = [
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
for lora in self.clip.loras
@@ -75,14 +76,18 @@ class ONNXPromptInvocation(BaseInvocation):
name = trigger[1:-1]
try:
ti_list.append(
(
name,
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model,
)
# stack.enter_context(
# context.services.model_manager.get_model(
# model_name=name,
# base_model=self.clip.text_encoder.base_model,
# model_type=ModelType.TextualInversion,
# )
# )
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
)
except Exception:
# print(e)

View File

@@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Union
from pydantic import Field, validator
from ...backend.model_management import ModelType, SubModelType, ModelPatcher
from ...backend.model_management import ModelType, SubModelType
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
@@ -293,20 +293,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
num_inference_steps = self.steps
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
do_classifier_free_guidance = True
cross_attention_kwargs = None
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
with unet_info as unet:
scheduler.set_timesteps(num_inference_steps, device=unet.device)
timesteps = scheduler.timesteps
@@ -553,19 +543,9 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
context=context,
)
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
do_classifier_free_guidance = True
cross_attention_kwargs = None
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
with unet_info as unet:
# apply denoising_start
num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps, device=unet.device)

View File

@@ -0,0 +1,139 @@
import networkx as nx
import copy
from abc import ABC, abstractmethod
from itertools import product
from pydantic import BaseModel, Field
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Graph, GraphExecutionState
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.batch_manager_storage import (
BatchProcessStorageBase,
BatchSessionNotFoundException,
Batch,
BatchProcess,
BatchSession,
BatchSessionChanges,
)
class BatchManagerBase(ABC):
@abstractmethod
def start(self, invoker: Invoker):
pass
@abstractmethod
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
pass
@abstractmethod
def run_batch_process(self, batch_id: str):
pass
@abstractmethod
def cancel_batch_process(self, batch_process_id: str):
pass
class BatchManager(BatchManagerBase):
"""Responsible for managing currently running and scheduled batch jobs"""
__invoker: Invoker
__batches: list[BatchProcess]
__batch_process_storage: BatchProcessStorageBase
def __init__(self, batch_process_storage: BatchProcessStorageBase) -> None:
super().__init__()
self.__batch_process_storage = batch_process_storage
def start(self, invoker: Invoker) -> None:
# if we do want multithreading at some point, we could make this configurable
self.__invoker = invoker
self.__batches = list()
local_handler.register(event_name=EventServiceBase.session_event, _func=self.on_event)
async def on_event(self, event: Event):
event_name = event[1]["event"]
match event_name:
case "graph_execution_state_complete":
await self.process(event, False)
case "invocation_error":
await self.process(event, True)
return event
async def process(self, event: Event, err: bool):
data = event[1]["data"]
batch_session = self.__batch_process_storage.get_session(data["graph_execution_state_id"])
if not batch_session:
return
updateSession = BatchSessionChanges(
state='error' if err else 'completed'
)
batch_session = self.__batch_process_storage.update_session_state(
batch_session.batch_id,
batch_session.session_id,
updateSession,
)
self.run_batch_process(batch_session.batch_id)
def _create_batch_session(self, batch_process: BatchProcess, batch_indices: list[int]) -> GraphExecutionState:
graph = copy.deepcopy(batch_process.graph)
batches = batch_process.batches
g = graph.nx_graph_flat()
sorted_nodes = nx.topological_sort(g)
for npath in sorted_nodes:
node = graph.get_node(npath)
(index, batch) = next(((i, b) for i, b in enumerate(batches) if b.node_id in node.id), (None, None))
if batch:
batch_index = batch_indices[index]
datum = batch.data[batch_index]
for key in datum:
node.__dict__[key] = datum[key]
graph.update_node(npath, node)
return GraphExecutionState(graph=graph)
def run_batch_process(self, batch_id: str):
try:
created_session = self.__batch_process_storage.get_created_session(batch_id)
except BatchSessionNotFoundException:
return
ges = self.__invoker.services.graph_execution_manager.get(created_session.session_id)
self.__invoker.invoke(ges, invoke_all=True)
def _valid_batch_config(self, batch_process: BatchProcess) -> bool:
return True
def create_batch_process(self, batches: list[Batch], graph: Graph) -> str:
batch_process = BatchProcess(
batches=batches,
graph=graph,
)
if not self._valid_batch_config(batch_process):
return None
batch_process = self.__batch_process_storage.save(batch_process)
self._create_sessions(batch_process)
return batch_process.batch_id
def _create_sessions(self, batch_process: BatchProcess):
batch_indices = list()
for batch in batch_process.batches:
batch_indices.append(list(range(len(batch.data))))
all_batch_indices = product(*batch_indices)
for bi in all_batch_indices:
ges = self._create_batch_session(batch_process, bi)
self.__invoker.services.graph_execution_manager.set(ges)
batch_session = BatchSession(
batch_id=batch_process.batch_id,
session_id=ges.id,
state="created"
)
self.__batch_process_storage.create_session(batch_session)
def cancel_batch_process(self, batch_process_id: str):
self.__batches = [batch for batch in self.__batches if batch.id != batch_process_id]

View File

@@ -0,0 +1,505 @@
from abc import ABC, abstractmethod
from typing import cast
import uuid
import sqlite3
import threading
from typing import (
Any,
List,
Literal,
Optional,
Union,
)
import json
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
)
from invokeai.app.services.graph import Graph
from invokeai.app.models.image import ImageField
from pydantic import BaseModel, Field, Extra, parse_raw_as
invocations = BaseInvocation.get_invocations()
InvocationsUnion = Union[invocations] # type: ignore
BatchDataType = Union[str, int, float, ImageField]
class Batch(BaseModel):
data: list[dict[str, BatchDataType]] = Field(description="Mapping of node field to data value")
node_id: str = Field(description="ID of the node to batch")
class BatchSession(BaseModel):
batch_id: str = Field(description="Identifier for which batch this Index belongs to")
session_id: str = Field(description="Session ID Created for this Batch Index")
state: Literal["created", "completed", "inprogress", "error"] = Field(
description="Is this session created, completed, in progress, or errored?"
)
def uuid_string():
res = uuid.uuid4()
return str(res)
class BatchProcess(BaseModel):
batch_id: Optional[str] = Field(default_factory=uuid_string, description="Identifier for this batch")
batches: List[Batch] = Field(
description="List of batch configs to apply to this session",
default_factory=list,
)
graph: Graph = Field(description="The graph being executed")
class BatchSessionChanges(BaseModel, extra=Extra.forbid):
state: Literal["created", "completed", "inprogress", "error"] = Field(
description="Is this session created, completed, in progress, or errored?"
)
class BatchProcessNotFoundException(Exception):
"""Raised when an Batch Process record is not found."""
def __init__(self, message="BatchProcess record not found"):
super().__init__(message)
class BatchProcessSaveException(Exception):
"""Raised when an Batch Process record cannot be saved."""
def __init__(self, message="BatchProcess record not saved"):
super().__init__(message)
class BatchProcessDeleteException(Exception):
"""Raised when an Batch Process record cannot be deleted."""
def __init__(self, message="BatchProcess record not deleted"):
super().__init__(message)
class BatchSessionNotFoundException(Exception):
"""Raised when an Batch Session record is not found."""
def __init__(self, message="BatchSession record not found"):
super().__init__(message)
class BatchSessionSaveException(Exception):
"""Raised when an Batch Session record cannot be saved."""
def __init__(self, message="BatchSession record not saved"):
super().__init__(message)
class BatchSessionDeleteException(Exception):
"""Raised when an Batch Session record cannot be deleted."""
def __init__(self, message="BatchSession record not deleted"):
super().__init__(message)
class BatchProcessStorageBase(ABC):
"""Low-level service responsible for interfacing with the Batch Process record store."""
@abstractmethod
def delete(self, batch_id: str) -> None:
"""Deletes a Batch Process record."""
pass
@abstractmethod
def save(
self,
batch_process: BatchProcess,
) -> BatchProcess:
"""Saves a Batch Process record."""
pass
@abstractmethod
def get(
self,
batch_id: str,
) -> BatchProcess:
"""Gets a Batch Process record."""
pass
@abstractmethod
def create_session(
self,
session: BatchSession,
) -> BatchSession:
"""Creates a Batch Session attached to a Batch Process."""
pass
@abstractmethod
def get_session(
self,
session_id: str
) -> BatchSession:
"""Gets session by session_id"""
pass
@abstractmethod
def get_created_session(
self,
batch_id: str
) -> BatchSession:
"""Gets all created Batch Sessions for a given Batch Process id."""
pass
@abstractmethod
def get_created_sessions(
self,
batch_id: str
) -> List[BatchSession]:
"""Gets all created Batch Sessions for a given Batch Process id."""
pass
@abstractmethod
def update_session_state(
self,
batch_id: str,
session_id: str,
changes: BatchSessionChanges,
) -> BatchSession:
"""Updates the state of a Batch Session record."""
pass
class SqliteBatchProcessStorage(BatchProcessStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the `batch_process` table and `batch_session` junction table."""
# Create the `batch_process` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS batch_process (
batch_id TEXT NOT NULL PRIMARY KEY,
batches TEXT NOT NULL,
graph TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_batch_process_created_at ON batch_process (created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_batch_process_updated_at
AFTER UPDATE
ON batch_process FOR EACH ROW
BEGIN
UPDATE batch_process SET updated_at = current_timestamp
WHERE batch_id = old.batch_id;
END;
"""
)
# Create the `batch_session` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS batch_session (
batch_id TEXT NOT NULL,
session_id TEXT NOT NULL,
state TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Soft delete, currently unused
deleted_at DATETIME,
-- enforce one-to-many relationship between batch_process and batch_session using PK
-- (we can extend this to many-to-many later)
PRIMARY KEY (batch_id,session_id),
FOREIGN KEY (batch_id) REFERENCES batch_process (batch_id) ON DELETE CASCADE
);
"""
)
# Add index for batch id
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id ON batch_session (batch_id);
"""
)
# Add index for batch id, sorted by created_at
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_batch_session_batch_id_created_at ON batch_session (batch_id,created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_batch_session_updated_at
AFTER UPDATE
ON batch_session FOR EACH ROW
BEGIN
UPDATE batch_session SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE batch_id = old.batch_id AND session_id = old.session_id;
END;
"""
)
def delete(self, batch_id: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM batch_process
WHERE batch_id = ?;
""",
(batch_id,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessDeleteException from e
except Exception as e:
self._conn.rollback()
raise BatchProcessDeleteException from e
finally:
self._lock.release()
def save(
self,
batch_process: BatchProcess,
) -> BatchProcess:
try:
self._lock.acquire()
batches = [batch.json() for batch in batch_process.batches]
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO batch_process (batch_id, batches, graph)
VALUES (?, ?, ?);
""",
(batch_process.batch_id, json.dumps(batches), batch_process.graph.json()),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessSaveException from e
finally:
self._lock.release()
return self.get(batch_process.batch_id)
def _deserialize_batch_process(self, session_dict: dict) -> BatchProcess:
"""Deserializes a batch session."""
# Retrieve all the values, setting "reasonable" defaults if they are not present.
batch_id = session_dict.get("batch_id", "unknown")
batches_raw = session_dict.get("batches", "unknown")
graph_raw = session_dict.get("graph", "unknown")
batches = json.loads(batches_raw)
batches = [parse_raw_as(Batch, batch) for batch in batches]
return BatchProcess(
batch_id=batch_id,
batches=batches,
graph=parse_raw_as(Graph, graph_raw),
)
def get(
self,
batch_id: str,
) -> BatchProcess:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_process
WHERE batch_id = ?;
""",
(batch_id,)
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchProcessNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BatchProcessNotFoundException
return self._deserialize_batch_process(dict(result))
def create_session(
self,
session: BatchSession,
) -> BatchSession:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO batch_session (batch_id, session_id, state)
VALUES (?, ?, ?);
""",
(session.batch_id, session.session_id, session.state),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionSaveException from e
finally:
self._lock.release()
return self.get_session(session.session_id)
def get_session(
self,
session_id: str
) -> BatchSession:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE session_id= ?;
""",
(session_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BatchSessionNotFoundException
return self._deserialize_batch_session(dict(result))
def _deserialize_batch_session(self, session_dict: dict) -> BatchSession:
"""Deserializes a batch session."""
# Retrieve all the values, setting "reasonable" defaults if they are not present.
batch_id = session_dict.get("batch_id", "unknown")
session_id = session_dict.get("session_id", "unknown")
state = session_dict.get("state", "unknown")
return BatchSession(
batch_id=batch_id,
session_id=session_id,
state=state,
)
def get_created_session(
self,
batch_id: str
) -> BatchSession:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE batch_id = ? AND state = 'created';
""",
(batch_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BatchSessionNotFoundException
session = self._deserialize_batch_session(dict(result))
return session
def get_created_sessions(
self,
batch_id: str
) -> List[BatchSession]:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM batch_session
WHERE batch_id = ? AND state = created;
""",
(batch_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionNotFoundException from e
finally:
self._lock.release()
if result is None:
raise BatchSessionNotFoundException
sessions = list(map(lambda r: self._deserialize_batch_session(dict(r)), result))
return sessions
def update_session_state(
self,
batch_id: str,
session_id: str,
changes: BatchSessionChanges,
) -> BatchSession:
try:
self._lock.acquire()
# Change the state of a batch session
if changes.state is not None:
self._cursor.execute(
f"""--sql
UPDATE batch_session
SET state = ?
WHERE batch_id = ? AND session_id = ?;
""",
(changes.state, batch_id, session_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BatchSessionSaveException from e
finally:
self._lock.release()
return self.get_session(session_id)

View File

@@ -25,6 +25,7 @@ class BoardImageRecordStorageBase(ABC):
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
@@ -153,6 +154,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
try:
@@ -160,9 +162,9 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self._cursor.execute(
"""--sql
DELETE FROM board_images
WHERE image_name = ?;
WHERE board_id = ? AND image_name = ?;
""",
(image_name,),
(board_id, image_name),
)
self._conn.commit()
except sqlite3.Error as e:

View File

@@ -31,6 +31,7 @@ class BoardImagesServiceABC(ABC):
@abstractmethod
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
"""Removes an image from a board."""
@@ -92,9 +93,10 @@ class BoardImagesService(BoardImagesServiceABC):
def remove_image_from_board(
self,
board_id: str,
image_name: str,
) -> None:
self._services.board_image_records.remove_image_from_board(image_name)
self._services.board_image_records.remove_image_from_board(board_id, image_name)
def get_all_board_image_names_for_board(
self,

View File

@@ -24,10 +24,11 @@ InvokeAI:
sequential_guidance: false
precision: float16
max_cache_size: 6
max_vram_cache_size: 0.5
max_vram_cache_size: 2.7
always_use_cpu: false
free_gpu_mem: false
Features:
restore: true
esrgan: true
patchmatch: true
internet_available: true
@@ -164,7 +165,7 @@ import pydoc
import os
import sys
from argparse import ArgumentParser
from omegaconf import OmegaConf, DictConfig, ListConfig
from omegaconf import OmegaConf, DictConfig
from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
@@ -172,7 +173,6 @@ from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_ty
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_MAX_VRAM = 0.5
class InvokeAISettings(BaseSettings):
@@ -189,12 +189,7 @@ class InvokeAISettings(BaseSettings):
opt = parser.parse_args(argv)
for name in self.__fields__:
if name not in self._excluded():
value = getattr(opt, name)
if isinstance(value, ListConfig):
value = list(value)
elif isinstance(value, DictConfig):
value = dict(value)
setattr(self, name, value)
setattr(self, name, getattr(opt, name))
def to_yaml(self) -> str:
"""
@@ -279,7 +274,7 @@ class InvokeAISettings(BaseSettings):
@classmethod
def _excluded(self) -> List[str]:
# internal fields that shouldn't be exposed as command line options
return ["type", "initconf"]
return ["type", "initconf", "cached_root"]
@classmethod
def _excluded_from_yaml(self) -> List[str]:
@@ -287,10 +282,15 @@ class InvokeAISettings(BaseSettings):
return [
"type",
"initconf",
"gpu_mem_reserved",
"max_loaded_models",
"version",
"from_file",
"model",
"restore",
"root",
"nsfw_checker",
"cached_root",
]
class Config:
@@ -356,7 +356,7 @@ class InvokeAISettings(BaseSettings):
def _find_root() -> Path:
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ["INVOKEAI_ROOT"])
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
root = (venv.parent).resolve()
else:
@@ -389,17 +389,21 @@ class InvokeAIAppConfig(InvokeAISettings):
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance')
free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance')
max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='DEPRECATED')
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
gpu_mem_reserved : float = Field(default=2.75, ge=0, description="DEPRECATED: use max_vram_cache_size. Amount of VRAM reserved for model storage", category='DEPRECATED')
nsfw_checker : bool = Field(default=True, description="DEPRECATED: use Web settings to enable/disable", category='DEPRECATED')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category='Memory/Performance')
root : Path = Field(default=None, description='InvokeAI runtime root directory', category='Paths')
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
autoimport_dir : Path = Field(default='autoimport', description='Path to a directory of models files to be imported on startup.', category='Paths')
lora_dir : Path = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', category='Paths')
embedding_dir : Path = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', category='Paths')
@@ -411,7 +415,8 @@ class InvokeAIAppConfig(InvokeAISettings):
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features')
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>"', category="Logging")
# note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues
@@ -419,11 +424,9 @@ class InvokeAIAppConfig(InvokeAISettings):
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
cached_root : Path = Field(default=None, description="internal use only", category="DEPRECATED")
# fmt: on
class Config:
validate_assignment = True
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
"""
Update settings with contents of init file, environment, and
@@ -469,12 +472,15 @@ class InvokeAIAppConfig(InvokeAISettings):
"""
Path to the runtime root directory
"""
if self.root:
# we cache value of root to protect against it being '.' and the cwd changing
if self.cached_root:
root = self.cached_root
elif self.root:
root = Path(self.root).expanduser().absolute()
else:
root = self.find_root().expanduser().absolute()
self.root = root # insulate ourselves from relative paths that may change
return root
root = self.find_root()
self.cached_root = root
return self.cached_root
@property
def root_dir(self) -> Path:

View File

@@ -289,10 +289,9 @@ class ImageService(ImageServiceABC):
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
try:
image_record = self._services.image_records.get(image_name)
metadata = self._services.image_records.get_metadata(image_name)
if not image_record.session_id:
return ImageMetadata(metadata=metadata)
return ImageMetadata()
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
graph = None
@@ -304,6 +303,7 @@ class ImageService(ImageServiceABC):
self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None
metadata = self._services.image_records.get_metadata(image_name)
return ImageMetadata(graph=graph, metadata=metadata)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")

View File

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from logging import Logger
from invokeai.app.services.batch_manager import BatchManagerBase
from invokeai.app.services.board_images import BoardImagesServiceABC
from invokeai.app.services.boards import BoardServiceABC
from invokeai.app.services.images import ImageServiceABC
@@ -21,6 +22,7 @@ class InvocationServices:
"""Services that can be used by invocations"""
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
batch_manager: "BatchManagerBase"
board_images: "BoardImagesServiceABC"
boards: "BoardServiceABC"
configuration: "InvokeAIAppConfig"
@@ -32,11 +34,11 @@ class InvocationServices:
logger: "Logger"
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
def __init__(
self,
batch_manager: "BatchManagerBase",
board_images: "BoardImagesServiceABC",
boards: "BoardServiceABC",
configuration: "InvokeAIAppConfig",
@@ -48,9 +50,9 @@ class InvocationServices:
logger: "Logger",
model_manager: "ModelManagerServiceBase",
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
):
self.batch_manager = batch_manager
self.board_images = board_images
self.boards = boards
self.boards = boards
@@ -63,5 +65,4 @@ class InvocationServices:
self.logger = logger
self.model_manager = model_manager
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue

View File

@@ -1,223 +0,0 @@
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
"""Utility to collect execution time and GPU usage stats on invocations in flight"""
"""
Usage:
statistics = InvocationStatsService(graph_execution_manager)
with statistics.collect_stats(invocation, graph_execution_state.id):
... execute graphs...
statistics.log_stats()
Typical output:
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
writes to the system log is stored in InvocationServices.performance_statistics.
"""
import time
from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from dataclasses import dataclass, field
from typing import Dict
import torch
import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import BaseInvocation
from .graph import GraphExecutionState
from .item_storage import ItemStorageABC
class InvocationStatsServiceBase(ABC):
"Abstract base class for recording node memory/time performance statistics"
@abstractmethod
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
"""
Initialize the InvocationStatsService and reset counters to zero
:param graph_execution_manager: Graph execution manager for this session
"""
pass
@abstractmethod
def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
) -> AbstractContextManager:
"""
Return a context object that will capture the statistics on the execution
of invocaation. Use with: to place around the part of the code that executes the invocation.
:param invocation: BaseInvocation object from the current graph.
:param graph_execution_state: GraphExecutionState object from the current session.
"""
pass
@abstractmethod
def reset_stats(self, graph_execution_state_id: str):
"""
Reset all statistics for the indicated graph
:param graph_execution_state_id
"""
pass
@abstractmethod
def reset_all_stats(self):
"""Zero all statistics"""
pass
@abstractmethod
def update_invocation_stats(
self,
graph_id: str,
invocation_type: str,
time_used: float,
vram_used: float,
):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Time used by node's exection (sec)
:param vram_used: Maximum VRAM used during exection (GB)
"""
pass
@abstractmethod
def log_stats(self):
"""
Write out the accumulated statistics to the log or somewhere else.
"""
pass
@dataclass
class NodeStats:
"""Class for tracking execution stats of an invocation node"""
calls: int = 0
time_used: float = 0.0 # seconds
max_vram: float = 0.0 # GB
@dataclass
class NodeLog:
"""Class for tracking node usage"""
# {node_type => NodeStats}
nodes: Dict[str, NodeStats] = field(default_factory=dict)
class InvocationStatsService(InvocationStatsServiceBase):
"""Accumulate performance information about a running graph. Collects time spent in each node,
as well as the maximum and current VRAM utilisation for CUDA systems"""
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
self.graph_execution_manager = graph_execution_manager
# {graph_id => NodeLog}
self._stats: Dict[str, NodeLog] = {}
class StatsContext:
def __init__(self, invocation: BaseInvocation, graph_id: str, collector: "InvocationStatsServiceBase"):
self.invocation = invocation
self.collector = collector
self.graph_id = graph_id
self.start_time = 0
def __enter__(self):
self.start_time = time.time()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
def __exit__(self, *args):
self.collector.update_invocation_stats(
self.graph_id,
self.invocation.type,
time.time() - self.start_time,
torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0,
)
def collect_stats(
self,
invocation: BaseInvocation,
graph_execution_state_id: str,
) -> StatsContext:
"""
Return a context object that will capture the statistics.
:param invocation: BaseInvocation object from the current graph.
:param graph_execution_state: GraphExecutionState object from the current session.
"""
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
self._stats[graph_execution_state_id] = NodeLog()
return self.StatsContext(invocation, graph_execution_state_id, self)
def reset_all_stats(self):
"""Zero all statistics"""
self._stats = {}
def reset_stats(self, graph_execution_id: str):
"""Zero the statistics for the indicated graph."""
try:
self._stats.pop(graph_execution_id)
except KeyError:
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}")
def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float):
"""
Add timing information on execution of a node. Usually
used internally.
:param graph_id: ID of the graph that is currently executing
:param invocation_type: String literal type of the node
:param time_used: Floating point seconds used by node's exection
"""
if not self._stats[graph_id].nodes.get(invocation_type):
self._stats[graph_id].nodes[invocation_type] = NodeStats()
stats = self._stats[graph_id].nodes[invocation_type]
stats.calls += 1
stats.time_used += time_used
stats.max_vram = max(stats.max_vram, vram_used)
def log_stats(self):
"""
Send the statistics to the system logger at the info level.
Stats will only be printed if when the execution of the graph
is complete.
"""
completed = set()
for graph_id, node_log in self._stats.items():
current_graph_state = self.graph_execution_manager.get(graph_id)
if not current_graph_state.is_complete():
continue
total_time = 0
logger.info(f"Graph stats: {graph_id}")
logger.info("Node Calls Seconds VRAM Used")
for node_type, stats in self._stats[graph_id].nodes.items():
logger.info(f"{node_type:<20} {stats.calls:>5} {stats.time_used:7.3f}s {stats.max_vram:4.2f}G")
total_time += stats.time_used
logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s")
if torch.cuda.is_available():
logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9))
completed.add(graph_id)
for graph_id in completed:
del self._stats[graph_id]

View File

@@ -3,10 +3,9 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from logging import Logger
from pathlib import Path
from pydantic import Field
from typing import Literal, Optional, Union, Callable, List, Tuple, TYPE_CHECKING
from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING
from types import ModuleType
from invokeai.backend.model_management import (
@@ -194,7 +193,7 @@ class ModelManagerServiceBase(ABC):
self,
model_name: str,
base_model: BaseModelType,
model_type: Literal[ModelType.Main, ModelType.Vae],
model_type: Union[ModelType.Main, ModelType.Vae],
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
@@ -293,7 +292,7 @@ class ModelManagerService(ModelManagerServiceBase):
def __init__(
self,
config: InvokeAIAppConfig,
logger: Logger,
logger: ModuleType,
):
"""
Initialize with the path to the models.yaml config file.
@@ -397,7 +396,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_type,
)
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Given a model name returns a dict-like (OmegaConf) object describing it.
"""
@@ -417,7 +416,7 @@ class ModelManagerService(ModelManagerServiceBase):
"""
return self.mgr.list_models(base_model, model_type)
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
Return information about the model using the same format as list_models()
"""
@@ -430,7 +429,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> AddModelResult:
) -> None:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
@@ -479,7 +478,7 @@ class ModelManagerService(ModelManagerServiceBase):
self,
model_name: str,
base_model: BaseModelType,
model_type: Literal[ModelType.Main, ModelType.Vae],
model_type: Union[ModelType.Main, ModelType.Vae],
convert_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
@@ -574,9 +573,9 @@ class ModelManagerService(ModelManagerServiceBase):
default=None, description="Base model shared by all models to be merged"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: float = 0.5,
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: bool = False,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
@@ -634,8 +633,8 @@ class ModelManagerService(ModelManagerServiceBase):
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: Optional[str] = None,
new_base: Optional[BaseModelType] = None,
new_name: str = None,
new_base: BaseModelType = None,
):
"""
Rename the indicated model. Can provide a new name and/or a new base.

View File

@@ -1,8 +0,0 @@
from pydantic import Field
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
class BoardImage(BaseModelExcludeNull):
board_id: str = Field(description="The id of the board")
image_name: str = Field(description="The name of the image")

View File

@@ -1,11 +1,10 @@
from typing import Optional, Union
from datetime import datetime
from pydantic import Field
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
class BoardRecord(BaseModelExcludeNull):
class BoardRecord(BaseModel):
"""Deserialized board record."""
board_id: str = Field(description="The unique ID of the board.")

View File

@@ -1,14 +1,13 @@
import datetime
from typing import Optional, Union
from pydantic import Extra, Field, StrictBool, StrictStr
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
class ImageRecord(BaseModelExcludeNull):
class ImageRecord(BaseModel):
"""Deserialized image record without metadata."""
image_name: str = Field(description="The unique name of the image.")
@@ -41,7 +40,7 @@ class ImageRecord(BaseModelExcludeNull):
"""The node ID that generated this image, if it is a generated image."""
class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
"""A set of changes to apply to an image record.
Only limited changes are valid:
@@ -61,7 +60,7 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
"""The image's new `is_intermediate` flag."""
class ImageUrlsDTO(BaseModelExcludeNull):
class ImageUrlsDTO(BaseModel):
"""The URLs for an image and its thumbnail."""
image_name: str = Field(description="The unique name of the image.")
@@ -77,15 +76,11 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
"""The id of the board the image belongs to, if one exists."""
pass
def image_record_to_dto(
image_record: ImageRecord,
image_url: str,
thumbnail_url: str,
board_id: Optional[str],
image_record: ImageRecord, image_url: str, thumbnail_url: str, board_id: Optional[str]
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(

View File

@@ -1,14 +1,13 @@
import time
import traceback
from threading import BoundedSemaphore, Event, Thread
import invokeai.backend.util.logging as logger
from threading import Event, Thread, BoundedSemaphore
from ..invocations.baseinvocation import InvocationContext
from ..models.exceptions import CanceledException
from .invocation_queue import InvocationQueueItem
from .invocation_stats import InvocationStatsServiceBase
from .invoker import InvocationProcessorABC, Invoker
from ..models.exceptions import CanceledException
import invokeai.backend.util.logging as logger
class DefaultInvocationProcessor(InvocationProcessorABC):
@@ -36,8 +35,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
def __process(self, stop_event: Event):
try:
self.__threadLimit.acquire()
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
while not stop_event.is_set():
try:
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
@@ -86,38 +83,35 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Invoke
try:
with statistics.collect_stats(invocation, graph_execution_state.id):
outputs = invocation.invoke(
InvocationContext(
services=self.__invoker.services,
graph_execution_state_id=graph_execution_state.id,
)
)
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
continue
# Save outputs and history
graph_execution_state.complete(invocation.id, outputs)
# Save the state changes
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
outputs = invocation.invoke(
InvocationContext(
services=self.__invoker.services,
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
result=outputs.dict(),
)
statistics.log_stats()
)
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
continue
# Save outputs and history
graph_execution_state.complete(invocation.id, outputs)
# Save the state changes
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
result=outputs.dict(),
)
except KeyboardInterrupt:
pass
except CanceledException:
statistics.reset_stats(graph_execution_state.id)
pass
except Exception as e:
@@ -139,7 +133,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
error_type=e.__class__.__name__,
error=error,
)
statistics.reset_stats(graph_execution_state.id)
pass
# Check queue to see if this is canceled, and skip if so

View File

@@ -29,6 +29,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution
self._conn.execute('pragma journal_mode=wal')
self._cursor = self._conn.cursor()
self._create_table()

View File

@@ -20,6 +20,6 @@ class LocalUrlService(UrlServiceBase):
# These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail:
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
return f"{self._base_url}/images/{image_basename}/thumbnail"
return f"{self._base_url}/images/i/{image_basename}/full"
return f"{self._base_url}/images/{image_basename}/full"

View File

@@ -18,5 +18,5 @@ SEED_MAX = np.iinfo(np.uint32).max
def get_random_seed():
rng = np.random.default_rng(seed=None)
rng = np.random.default_rng(seed=0)
return int(rng.integers(0, SEED_MAX))

View File

@@ -1,23 +0,0 @@
from typing import Any
from pydantic import BaseModel
"""
We want to exclude null values from objects that make their way to the client.
Unfortunately there is no built-in way to do this in pydantic, so we need to override the default
dict method to do this.
From https://github.com/tiangolo/fastapi/discussions/8882#discussioncomment-5154541
"""
class BaseModelExcludeNull(BaseModel):
def dict(self, *args, **kwargs) -> dict[str, Any]:
"""
Override the default dict method to exclude None values in the response
"""
kwargs.pop("exclude_none", None)
return super().dict(*args, exclude_none=True, **kwargs)
pass

View File

@@ -12,17 +12,16 @@ def check_invokeai_root(config: InvokeAIAppConfig):
assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
assert config.models_path.exists(), f"{config.models_path} not found"
if not config.ignore_missing_core_models:
for model in [
"CLIP-ViT-bigG-14-laion2B-39B-b160k",
"bert-base-uncased",
"clip-vit-large-patch14",
"sd-vae-ft-mse",
"stable-diffusion-2-clip",
"stable-diffusion-safety-checker",
]:
path = config.models_path / f"core/convert/{model}"
assert path.exists(), f"{path} is missing"
for model in [
"CLIP-ViT-bigG-14-laion2B-39B-b160k",
"bert-base-uncased",
"clip-vit-large-patch14",
"sd-vae-ft-mse",
"stable-diffusion-2-clip",
"stable-diffusion-safety-checker",
]:
path = config.models_path / f"core/convert/{model}"
assert path.exists(), f"{path} is missing"
except Exception as e:
print()
print(f"An exception has occurred: {str(e)}")
@@ -33,10 +32,5 @@ def check_invokeai_root(config: InvokeAIAppConfig):
print(
'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
)
print(
'** (To skip this check completely, add "--ignore_missing_core_models" to your CLI args. Not installing '
"these core models will prevent the loading of some or all .safetensors and .ckpt files. However, you can "
"always come back and install these core models in the future.)"
)
input("Press any key to continue...")
sys.exit(0)

View File

@@ -10,15 +10,12 @@ import sys
import argparse
import io
import os
import psutil
import shutil
import textwrap
import torch
import traceback
import yaml
import warnings
from argparse import Namespace
from enum import Enum
from pathlib import Path
from shutil import get_terminal_size
from typing import get_type_hints
@@ -47,8 +44,6 @@ from invokeai.app.services.config import (
)
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
# TO DO - Move all the frontend code into invokeai.frontend.install
from invokeai.frontend.install.widgets import (
SingleSelectColumns,
CenteredButtonPress,
@@ -58,7 +53,6 @@ from invokeai.frontend.install.widgets import (
CyclingForm,
MIN_COLS,
MIN_LINES,
WindowTooSmallException,
)
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
from invokeai.backend.install.model_install_backend import (
@@ -67,7 +61,6 @@ from invokeai.backend.install.model_install_backend import (
ModelInstall,
)
from invokeai.backend.model_management.model_probe import ModelType, BaseModelType
from pydantic.error_wrappers import ValidationError
warnings.filterwarnings("ignore")
transformers.logging.set_verbosity_error()
@@ -83,13 +76,6 @@ Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path
PRECISION_CHOICES = ["auto", "float16", "float32"]
GB = 1073741824 # GB in bytes
HAS_CUDA = torch.cuda.is_available()
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
MAX_VRAM /= GB
MAX_RAM = psutil.virtual_memory().total / GB
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# This is the InvokeAI initialization file, which contains command-line default values.
@@ -100,12 +86,6 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
logger = InvokeAILogger.getLogger()
class DummyWidgetValue(Enum):
zero = 0
true = True
false = False
# --------------------------------------------
def postscript(errors: None):
if not any(errors):
@@ -398,35 +378,13 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
)
self.max_cache_size = self.add_widget_intelligent(
IntTitleSlider,
name="RAM cache size (GB). Make this at least large enough to hold a single full model.",
name="Size of the RAM cache used for fast model switching (GB)",
value=old_opts.max_cache_size,
out_of=MAX_RAM,
out_of=20,
lowest=3,
begin_entry_at=6,
scroll_exit=True,
)
if HAS_CUDA:
self.nextrely += 1
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="VRAM cache size (GB). Reserving a small amount of VRAM will modestly speed up the start of image generation.",
begin_entry_at=0,
editable=False,
color="CONTROL",
scroll_exit=True,
)
self.nextrely -= 1
self.max_vram_cache_size = self.add_widget_intelligent(
npyscreen.Slider,
value=old_opts.max_vram_cache_size,
out_of=round(MAX_VRAM * 2) / 2,
lowest=0.0,
relx=8,
step=0.25,
scroll_exit=True,
)
else:
self.max_vram_cache_size = DummyWidgetValue.zero
self.nextrely += 1
self.outdir = self.add_widget_intelligent(
FileBox,
@@ -443,7 +401,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
self.autoimport_dirs = {}
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
FileBox,
name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
name=f"Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
value=str(config.root_path / config.autoimport_dir),
select_dir=True,
must_exist=False,
@@ -518,7 +476,6 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
"outdir",
"free_gpu_mem",
"max_cache_size",
"max_vram_cache_size",
"xformers_enabled",
"always_use_cpu",
]:
@@ -635,13 +592,13 @@ def maybe_create_models_yaml(root: Path):
# -------------------------------------
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
# parse_args() will read from init file if present
invokeai_opts = default_startup_options(initfile)
invokeai_opts.root = program_opts.root
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
raise WindowTooSmallException(
"Could not increase terminal size. Try running again with a larger window or smaller font size."
)
# The third argument is needed in the Windows 11 environment to
# launch a console window running this program.
set_min_terminal_size(MIN_COLS, MIN_LINES)
# the install-models application spawns a subprocess to install
# models, and will crash unless this is set before running.
@@ -697,13 +654,10 @@ def migrate_init_file(legacy_format: Path):
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
new = InvokeAIAppConfig.get_config()
fields = [x for x, y in InvokeAIAppConfig.__fields__.items() if y.field_info.extra.get("category") != "DEPRECATED"]
fields = list(get_type_hints(InvokeAIAppConfig).keys())
for attr in fields:
if hasattr(old, attr):
try:
setattr(new, attr, getattr(old, attr))
except ValidationError as e:
print(f"* Ignoring incompatible value for field {attr}:\n {str(e)}")
setattr(new, attr, getattr(old, attr))
# a few places where the field names have changed and we have to
# manually add in the new names/values
@@ -823,7 +777,6 @@ def main():
models_to_download = default_user_selections(opt)
new_init_file = config.root_path / "invokeai.yaml"
if opt.yes_to_all:
write_default_options(opt, new_init_file)
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
@@ -849,8 +802,6 @@ def main():
postscript(errors=errors)
if not opt.yes_to_all:
input("Press any key to continue...")
except WindowTooSmallException as e:
logger.error(str(e))
except KeyboardInterrupt:
print("\nGoodbye! Come back soon.")

View File

@@ -13,7 +13,6 @@ import requests
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
import onnx
import torch
from huggingface_hub import hf_hub_url, HfFolder, HfApi
from omegaconf import OmegaConf
from tqdm import tqdm
@@ -24,7 +23,6 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
from invokeai.backend.util import download_with_resume
from invokeai.backend.util.devices import torch_dtype, choose_torch_device
from ..util.logging import InvokeAILogger
warnings.filterwarnings("ignore")
@@ -101,9 +99,9 @@ class ModelInstall(object):
def __init__(
self,
config: InvokeAIAppConfig,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
model_manager: Optional[ModelManager] = None,
access_token: Optional[str] = None,
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
model_manager: ModelManager = None,
access_token: str = None,
):
self.config = config
self.mgr = model_manager or ModelManager(config.model_conf_path)
@@ -305,7 +303,7 @@ class ModelInstall(object):
with TemporaryDirectory(dir=self.config.models_path) as staging:
staging = Path(staging)
if "model_index.json" in files:
if "model_index.json" in files and "unet/model.onnx" not in files:
location = self._download_hf_pipeline(repo_id, staging) # pipeline
elif "unet/model.onnx" in files:
location = self._download_hf_model(repo_id, files, staging)
@@ -418,25 +416,15 @@ class ModelInstall(object):
does a save_pretrained() to the indicated staging area.
"""
_, name = repo_id.split("/")
precision = torch_dtype(choose_torch_device())
variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]
revisions = ["fp16", "main"] if self.config.precision == "float16" else ["main"]
model = None
for variant in variants:
for revision in revisions:
try:
model = DiffusionPipeline.from_pretrained(
repo_id,
variant=variant,
torch_dtype=precision,
safety_checker=None,
)
except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors
if "fp16" not in str(e):
print(e)
model = DiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None)
except: # most errors are due to fp16 not being present. Fix this to catch other errors
pass
if model:
break
if not model:
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
return None

View File

@@ -13,4 +13,3 @@ from .models import (
DuplicateModelException,
)
from .model_merge import ModelMerger, MergeInterpolationMethod
from .lora import ModelPatcher

View File

@@ -20,6 +20,424 @@ from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer
# TODO: rename and split this file
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: dict,
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def forward(
self,
module: torch.nn.Module,
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
multiplier: float,
):
if type(module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
extra_args = dict(
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)
else:
op = torch.nn.functional.linear
extra_args = {}
weight = self.get_weight()
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
return (
op(
*input_h,
(weight + bias).view(module.weight.shape),
None,
**extra_args,
)
* multiplier
* scale
)
def get_weight(self):
raise NotImplementedError()
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid = values["lora_mid.weight"]
else:
self.mid = None
self.rank = self.down.shape[0]
def get_weight(self):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
if "hada_t1" in values:
self.t1 = values["hada_t1"]
else:
self.t1 = None
if "hada_t2" in values:
self.t2 = values["hada_t2"]
else:
self.t2 = None
self.rank = self.w1_b.shape[0]
def get_weight(self):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1 = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
if "lokr_w2" in values:
self.w2 = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
if "lokr_t2" in values:
self.t2 = values["lokr_t2"]
else:
self.t2 = None
if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
else:
self.rank = None # unscaled
def get_weight(self):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoRAModel: # (torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
_device: torch.device
_dtype: torch.dtype
def __init__(
self,
name: str,
layers: Dict[str, LoRALayer],
device: torch.device,
dtype: torch.dtype,
):
self._name = name
self._device = device or torch.cpu
self._dtype = dtype or torch.float32
self.layers = layers
@property
def name(self):
return self._name
@property
def device(self):
return self._device
@property
def dtype(self):
return self._dtype
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> LoRAModel:
# TODO: try revert if exception?
for key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
self._device = device
self._dtype = dtype
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
device=device,
dtype=dtype,
name=file_path.stem, # TODO:
layers=dict(),
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
else:
# TODO: diff/ia3/... format
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
return
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: dict):
state_dict_groupped = dict()
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = dict()
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
"""
loras = [
(lora_model1, 0.7),
@@ -98,26 +516,6 @@ class ModelPatcher:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder2(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield
@classmethod
@contextmanager
def apply_lora(
@@ -164,7 +562,7 @@ class ModelPatcher:
cls,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
ti_list: List[Tuple[str, Any]],
ti_list: List[Any],
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
init_tokens_count = None
new_tokens_added = None
@@ -174,27 +572,27 @@ class ModelPatcher:
ti_manager = TextualInversionManager(ti_tokenizer)
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
def _get_trigger(ti_name, index):
trigger = ti_name
def _get_trigger(ti, index):
trigger = ti.name
if index > 0:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
# modify tokenizer
new_tokens_added = 0
for ti_name, ti in ti_list:
for ti in ti_list:
for i in range(ti.embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
# modify text_encoder
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added)
model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list:
for ti in ti_list:
ti_tokens = []
for i in range(ti.embedding.shape[0]):
embedding = ti.embedding[i]
trigger = _get_trigger(ti_name, i)
trigger = _get_trigger(ti, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id:
@@ -239,6 +637,7 @@ class ModelPatcher:
class TextualInversionModel:
name: str
embedding: torch.Tensor # [n, 768]|[n, 1280]
@classmethod
@@ -252,6 +651,7 @@ class TextualInversionModel:
file_path = Path(file_path)
result = cls() # TODO:
result.name = file_path.stem # TODO:
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
@@ -428,7 +828,7 @@ class ONNXModelPatcher:
cls,
tokenizer: CLIPTokenizer,
text_encoder: IAIOnnxRuntimeModel,
ti_list: List[Tuple[str, Any]],
ti_list: List[Any],
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
from .models.base import IAIOnnxRuntimeModel
@@ -441,17 +841,17 @@ class ONNXModelPatcher:
ti_tokenizer = copy.deepcopy(tokenizer)
ti_manager = TextualInversionManager(ti_tokenizer)
def _get_trigger(ti_name, index):
trigger = ti_name
def _get_trigger(ti, index):
trigger = ti.name
if index > 0:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
# modify tokenizer
new_tokens_added = 0
for ti_name, ti in ti_list:
for ti in ti_list:
for i in range(ti.embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
# modify text_encoder
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
@@ -461,11 +861,11 @@ class ONNXModelPatcher:
axis=0,
)
for ti_name, ti in ti_list:
for ti in ti_list:
ti_tokens = []
for i in range(ti.embedding.shape[0]):
embedding = ti.embedding[i].detach().numpy()
trigger = _get_trigger(ti_name, i)
trigger = _get_trigger(ti, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id:

View File

@@ -28,6 +28,8 @@ import torch
import logging
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from .lora import LoRAModel, TextualInversionModel
from .models import BaseModelType, ModelType, SubModelType, ModelBase
# Maximum size of the cache, in gigs
@@ -186,7 +188,7 @@ class ModelCache(object):
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
f"Loading model {model_path}, type {base_model.value}:{model_type.value}:{submodel.value if submodel else ''}"
)
# this will remove older cached models until

View File

@@ -228,19 +228,19 @@ the root is the InvokeAI ROOTDIR.
"""
from __future__ import annotations
import hashlib
import os
import hashlib
import textwrap
import types
import yaml
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types
from shutil import rmtree, move
from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable
import torch
import yaml
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pydantic import BaseModel, Field
import invokeai.backend.util.logging as logger
@@ -259,7 +259,6 @@ from .models import (
ModelNotFoundException,
InvalidModelException,
DuplicateModelException,
ModelBase,
)
# We are only starting to number the config file with release 3.
@@ -362,7 +361,7 @@ class ModelManager(object):
if model_key.startswith("_"):
continue
model_name, base_model, model_type = self.parse_key(model_key)
model_class = self._get_implementation(base_model, model_type)
model_class = MODEL_CLASSES[base_model][model_type]
# alias for config file
model_config["model_format"] = model_config.pop("format")
self.models[model_key] = model_class.create_config(**model_config)
@@ -382,24 +381,18 @@ class ModelManager(object):
# causing otherwise unreferenced models to be removed from memory
self._read_models()
def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
def model_exists(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> bool:
"""
Given a model name, returns True if it is a valid identifier.
:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param rescan: if True, scan_models_directory
Given a model name, returns True if it is a valid
identifier.
"""
model_key = self.create_key(model_name, base_model, model_type)
exists = model_key in self.models
# if model not found try to find it (maybe file just pasted)
if rescan and not exists:
self.scan_models_directory(base_model=base_model, model_type=model_type)
exists = self.model_exists(model_name, base_model, model_type, rescan=False)
return exists
return model_key in self.models
@classmethod
def create_key(
@@ -450,32 +443,39 @@ class ModelManager(object):
:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param submodel_type: an ModelType enum indicating the portion of
:param submode_typel: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
model_class = MODEL_CLASSES[base_model][model_type]
model_key = self.create_key(model_name, base_model, model_type)
if not self.model_exists(model_name, base_model, model_type, rescan=True):
raise ModelNotFoundException(f"Model not found - {model_key}")
# if model not found try to find it (maybe file just pasted)
if model_key not in self.models:
self.scan_models_directory(base_model=base_model, model_type=model_type)
if model_key not in self.models:
raise ModelNotFoundException(f"Model not found - {model_key}")
model_config = self._get_model_config(base_model, model_name, model_type)
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
if is_submodel_override:
model_type = submodel_type
submodel_type = None
model_class = self._get_implementation(base_model, model_type)
model_config = self.models[model_key]
model_path = self.resolve_model_path(model_config.path)
if not model_path.exists():
if model_class.save_to_config:
self.models[model_key].error = ModelError.NotFound
raise Exception(f'Files for model "{model_key}" not found at {model_path}')
raise Exception(f'Files for model "{model_key}" not found')
else:
self.models.pop(model_key, None)
raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}')
raise ModelNotFoundException(f"Model not found - {model_key}")
# vae/movq override
# TODO:
if submodel_type is not None and hasattr(model_config, submodel_type):
override_path = getattr(model_config, submodel_type)
if override_path:
model_path = self.app_config.root_path / override_path
model_type = submodel_type
submodel_type = None
model_class = MODEL_CLASSES[base_model][model_type]
# TODO: path
# TODO: is it accurate to use path as id
@@ -513,61 +513,12 @@ class ModelManager(object):
_cache=self.cache,
)
def _get_model_path(
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
) -> (Path, bool):
"""Extract a model's filesystem path from its config.
:return: The fully qualified Path of the module (or submodule).
"""
model_path = model_config.path
is_submodel_override = False
# Does the config explicitly override the submodel?
if submodel_type is not None and hasattr(model_config, submodel_type):
submodel_path = getattr(model_config, submodel_type)
if submodel_path is not None:
model_path = getattr(model_config, submodel_type)
is_submodel_override = True
model_path = self.resolve_model_path(model_path)
return model_path, is_submodel_override
def _get_model_config(self, base_model: BaseModelType, model_name: str, model_type: ModelType) -> ModelConfigBase:
"""Get a model's config object."""
model_key = self.create_key(model_name, base_model, model_type)
try:
model_config = self.models[model_key]
except KeyError:
raise ModelNotFoundException(f"Model not found - {model_key}")
return model_config
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
"""Get the concrete implementation class for a specific model type."""
model_class = MODEL_CLASSES[base_model][model_type]
return model_class
def _instantiate(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
) -> ModelBase:
"""Make a new instance of this model, without loading it."""
model_config = self._get_model_config(base_model, model_name, model_type)
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
# FIXME: do non-overriden submodels get the right class?
constructor = self._get_implementation(base_model, model_type)
instance = constructor(model_path, base_model, model_type)
return instance
def model_info(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> Union[dict, None]:
) -> dict:
"""
Given a model name returns the OmegaConf (dict-like) object describing it.
"""
@@ -589,16 +540,13 @@ class ModelManager(object):
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
) -> Union[dict, None]:
) -> dict:
"""
Returns a dict describing one installed model, using
the combined format of the list_models() method.
"""
models = self.list_models(base_model, model_type, model_name)
if len(models) >= 1:
return models[0]
else:
return None
return models[0] if models else None
def list_models(
self,
@@ -612,7 +560,7 @@ class ModelManager(object):
model_keys = (
[self.create_key(model_name, base_model, model_type)]
if model_name and base_model and model_type
if model_name
else sorted(self.models, key=str.casefold)
)
models = []
@@ -648,7 +596,7 @@ class ModelManager(object):
Print a table of models and their descriptions. This needs to be redone
"""
# TODO: redo
for model_dict in self.list_models():
for model_type, model_dict in self.list_models().items():
for model_name, model_info in model_dict.items():
line = f'{model_info["name"]:25s} {model_info["type"]:10s} {model_info["description"]}'
print(line)
@@ -710,7 +658,7 @@ class ModelManager(object):
if path := model_attributes.get("path"):
model_attributes["path"] = str(self.relative_model_path(Path(path)))
model_class = self._get_implementation(base_model, model_type)
model_class = MODEL_CLASSES[base_model][model_type]
model_config = model_class.create_config(**model_attributes)
model_key = self.create_key(model_name, base_model, model_type)
@@ -722,7 +670,7 @@ class ModelManager(object):
# TODO: if path changed and old_model.path inside models folder should we delete this too?
# remove conversion cache as config changed
old_model_path = self.resolve_model_path(old_model.path)
old_model_path = self.app_config.root_path / old_model.path
old_model_cache = self._get_model_cache_path(old_model_path)
if old_model_cache.exists():
if old_model_cache.is_dir():
@@ -751,8 +699,8 @@ class ModelManager(object):
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: Optional[str] = None,
new_base: Optional[BaseModelType] = None,
new_name: str = None,
new_base: BaseModelType = None,
):
"""
Rename or rebase a model.
@@ -805,7 +753,7 @@ class ModelManager(object):
self,
model_name: str,
base_model: BaseModelType,
model_type: Literal[ModelType.Main, ModelType.Vae],
model_type: Union[ModelType.Main, ModelType.Vae],
dest_directory: Optional[Path] = None,
) -> AddModelResult:
"""
@@ -819,10 +767,6 @@ class ModelManager(object):
This will raise a ValueError unless the model is a checkpoint.
"""
info = self.model_info(model_name, base_model, model_type)
if info is None:
raise FileNotFoundError(f"model not found: {model_name}")
if info["model_format"] != "checkpoint":
raise ValueError(f"not a checkpoint format model: {model_name}")
@@ -836,7 +780,7 @@ class ModelManager(object):
model_type,
**submodel,
)
checkpoint_path = self.resolve_model_path(info["path"])
checkpoint_path = self.app_config.root_path / info["path"]
old_diffusers_path = self.resolve_model_path(model.location)
new_diffusers_path = (
dest_directory or self.app_config.models_path / base_model.value / model_type.value
@@ -892,7 +836,7 @@ class ModelManager(object):
return search_folder, found_models
def commit(self, conf_file: Optional[Path] = None) -> None:
def commit(self, conf_file: Path = None) -> None:
"""
Write current configuration out to the indicated file.
"""
@@ -901,7 +845,7 @@ class ModelManager(object):
for model_key, model_config in self.models.items():
model_name, base_model, model_type = self.parse_key(model_key)
model_class = self._get_implementation(base_model, model_type)
model_class = MODEL_CLASSES[base_model][model_type]
if model_class.save_to_config:
# TODO: or exclude_unset better fits here?
data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"})
@@ -959,7 +903,7 @@ class ModelManager(object):
model_path = self.resolve_model_path(model_config.path).absolute()
if not model_path.exists():
model_class = self._get_implementation(cur_base_model, cur_model_type)
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
if model_class.save_to_config:
model_config.error = ModelError.NotFound
self.models.pop(model_key, None)
@@ -975,7 +919,7 @@ class ModelManager(object):
for cur_model_type in ModelType:
if model_type is not None and cur_model_type != model_type:
continue
model_class = self._get_implementation(cur_base_model, cur_model_type)
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
if not models_dir.exists():
@@ -991,9 +935,7 @@ class ModelManager(object):
raise DuplicateModelException(f"Model with key {model_key} added twice")
model_path = self.relative_model_path(model_path)
model_config: ModelConfigBase = model_class.probe_config(
str(model_path), model_base=cur_base_model
)
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
self.models[model_key] = model_config
new_models_found = True
except DuplicateModelException as e:
@@ -1041,7 +983,7 @@ class ModelManager(object):
# LS: hacky
# Patch in the SD VAE from core so that it is available for use by the UI
try:
self.heuristic_import({str(self.resolve_model_path("core/convert/sd-vae-ft-mse"))})
self.heuristic_import({self.resolve_model_path("core/convert/sd-vae-ft-mse")})
except:
pass
@@ -1050,7 +992,7 @@ class ModelManager(object):
model_manager=self,
prediction_type_helper=ask_user_for_prediction_type,
)
known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()}
known_paths = {config.root_path / x["path"] for x in self.list_models()}
directories = {
config.root_path / x
for x in [
@@ -1069,7 +1011,7 @@ class ModelManager(object):
def heuristic_import(
self,
items_to_import: Set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
) -> Dict[str, AddModelResult]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.

View File

@@ -33,7 +33,7 @@ class ModelMerger(object):
self,
model_paths: List[Path],
alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
interp: MergeInterpolationMethod = None,
force: bool = False,
**kwargs,
) -> DiffusionPipeline:
@@ -73,7 +73,7 @@ class ModelMerger(object):
base_model: Union[BaseModelType, str],
merged_model_name: str,
alpha: float = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
interp: MergeInterpolationMethod = None,
force: bool = False,
merge_dest_directory: Optional[Path] = None,
**kwargs,
@@ -122,7 +122,7 @@ class ModelMerger(object):
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name
merged_pipe.save_pretrained(dump_path, safe_serialization=True)
merged_pipe.save_pretrained(dump_path, safe_serialization=1)
attributes = dict(
path=str(dump_path),
description=f"Merge of models {', '.join(model_names)}",

View File

@@ -315,38 +315,21 @@ class LoRACheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
# SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
# There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
# misclassified as SD-1
key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
if key in checkpoint and checkpoint[key].shape[0] == 320:
return BaseModelType.StableDiffusion2
key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
if key in checkpoint:
return BaseModelType.StableDiffusionXL
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
lora_token_vector_length = (
checkpoint[key1].shape[1]
if key1 in checkpoint
else checkpoint[key2].shape[1]
else checkpoint[key2].shape[0]
if key2 in checkpoint
else checkpoint[key3].shape[0]
if key3 in checkpoint
else None
else 768
)
if lora_token_vector_length == 768:
return BaseModelType.StableDiffusion1
elif lora_token_vector_length == 1024:
return BaseModelType.StableDiffusion2
else:
raise InvalidModelException(f"Unknown LoRA type")
return None
class TextualInversionCheckpointProbe(CheckpointProbeBase):

View File

@@ -292,9 +292,8 @@ class DiffusersModel(ModelBase):
)
break
except Exception as e:
if not str(e).startswith("Error no file"):
print("====ERR LOAD====")
print(f"{variant}: {e}")
# print("====ERR LOAD====")
# print(f"{variant}: {e}")
pass
else:
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")

View File

@@ -1,9 +1,7 @@
import os
import torch
from enum import Enum
from typing import Optional, Dict, Union, Literal, Any
from pathlib import Path
from safetensors.torch import load_file
from typing import Optional, Union, Literal
from .base import (
ModelBase,
ModelConfigBase,
@@ -15,6 +13,9 @@ from .base import (
ModelNotFoundException,
)
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
@@ -49,7 +50,6 @@ class LoRAModel(ModelBase):
model = LoRAModelRaw.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
base_model=self.base_model,
)
self.model_size = model.calc_size()
@@ -87,582 +87,3 @@ class LoRAModel(ModelBase):
raise NotImplementedError("Diffusers lora not supported")
else:
return model_path
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: dict,
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def forward(
self,
module: torch.nn.Module,
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
multiplier: float,
):
if type(module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
extra_args = dict(
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)
else:
op = torch.nn.functional.linear
extra_args = {}
weight = self.get_weight()
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
return (
op(
*input_h,
(weight + bias).view(module.weight.shape),
None,
**extra_args,
)
* multiplier
* scale
)
def get_weight(self):
raise NotImplementedError()
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid = values["lora_mid.weight"]
else:
self.mid = None
self.rank = self.down.shape[0]
def get_weight(self):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
if "hada_t1" in values:
self.t1 = values["hada_t1"]
else:
self.t1 = None
if "hada_t2" in values:
self.t2 = values["hada_t2"]
else:
self.t2 = None
self.rank = self.w1_b.shape[0]
def get_weight(self):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1 = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
if "lokr_w2" in values:
self.w2 = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
if "lokr_t2" in values:
self.t2 = values["lokr_t2"]
else:
self.t2 = None
if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
else:
self.rank = None # unscaled
def get_weight(self):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class FullLayer(LoRALayerBase):
# weight: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.weight = values["diff"]
if len(values.keys()) > 1:
_keys = list(values.keys())
_keys.remove("diff")
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
self.rank = None # unscaled
def get_weight(self):
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
_device: torch.device
_dtype: torch.dtype
def __init__(
self,
name: str,
layers: Dict[str, LoRALayer],
device: torch.device,
dtype: torch.dtype,
):
self._name = name
self._device = device or torch.cpu
self._dtype = dtype or torch.float32
self.layers = layers
@property
def name(self):
return self._name
@property
def device(self):
return self._device
@property
def dtype(self):
return self._dtype
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
# TODO: try revert if exception?
for key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
self._device = device
self._dtype = dtype
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def _convert_sdxl_compvis_keys(cls, state_dict):
new_state_dict = dict()
for full_key, value in state_dict.items():
if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
continue # clip same
if not full_key.startswith("lora_unet_"):
raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
src_key = full_key.replace("lora_unet_", "")
try:
dst_key = None
while "_" in src_key:
if src_key in SDXL_UNET_COMPVIS_MAP:
dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
break
src_key = "_".join(src_key.split("_")[:-1])
if dst_key is None:
raise Exception(f"Unknown sdxl lora key - {full_key}")
new_key = full_key.replace(src_key, dst_key)
except:
print(SDXL_UNET_COMPVIS_MAP)
raise
new_state_dict[new_key] = value
return new_state_dict
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
):
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
device=device,
dtype=dtype,
name=file_path.stem, # TODO:
layers=dict(),
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_compvis_keys(state_dict)
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
elif "diff" in values:
layer = FullLayer(layer_key, values)
else:
# TODO: ia3/... format
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: dict):
state_dict_groupped = dict()
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = dict()
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map():
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
SDXL_UNET_COMPVIS_MAP = {
f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
for sd, hf in make_sdxl_unet_conversion_map()
}

View File

@@ -80,10 +80,8 @@ class StableDiffusionXLModel(DiffusersModel):
raise Exception("Unkown stable diffusion 2.* model format")
if ckpt_config_path is None:
# avoid circular import
from .stable_diffusion import _select_ckpt_config
ckpt_config_path = _select_ckpt_config(kwargs.get("model_base", BaseModelType.StableDiffusionXL), variant)
# TO DO: implement picking
pass
return cls.create_config(
path=path,

View File

@@ -4,7 +4,6 @@ from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
from .base import (
ModelConfigBase,
BaseModelType,
@@ -264,8 +263,6 @@ def _convert_ckpt_and_cache(
weights = app_config.models_path / model_config.path
config_file = app_config.root_path / model_config.config
output_path = Path(output_path)
variant = model_config.variant
pipeline_class = StableDiffusionInpaintPipeline if variant == "inpaint" else StableDiffusionPipeline
# return cached version if it exists
if output_path.exists():
@@ -292,7 +289,6 @@ def _convert_ckpt_and_cache(
original_config_file=config_file,
extract_ema=True,
scan_needed=True,
pipeline_class=pipeline_class,
from_safetensors=weights.suffix == ".safetensors",
precision=torch_dtype(choose_torch_device()),
**kwargs,

View File

@@ -1,14 +1,9 @@
import os
import torch
import safetensors
from enum import Enum
from pathlib import Path
from typing import Optional
import safetensors
import torch
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
from invokeai.app.services.config import InvokeAIAppConfig
from typing import Optional, Union, Literal
from .base import (
ModelBase,
ModelConfigBase,
@@ -23,6 +18,9 @@ from .base import (
InvalidModelException,
ModelNotFoundException,
)
from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
class VaeModelFormat(str, Enum):
@@ -82,7 +80,7 @@ class VaeModel(ModelBase):
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException(f"Does not exist as local file: {path}")
raise ModelNotFoundException()
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")):

View File

@@ -78,9 +78,10 @@ class InvokeAIDiffuserComponent:
self.cross_attention_control_context = None
self.sequential_guidance = config.sequential_guidance
@classmethod
@contextmanager
def custom_attention_context(
self,
cls,
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int,
@@ -90,19 +91,18 @@ class InvokeAIDiffuserComponent:
old_attn_processors = unet.attn_processors
# Load lora conditions into the model
if extra_conditioning_info.wants_cross_attention_control:
self.cross_attention_control_context = Context(
cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
self.cross_attention_control_context,
cross_attention_control_context,
)
try:
yield None
finally:
self.cross_attention_control_context = None
if old_attn_processors is not None:
unet.set_attn_processor(old_attn_processors)
# TODO resuscitate attention map saving

View File

@@ -1,8 +1,6 @@
from __future__ import annotations
from contextlib import nullcontext
from packaging import version
import platform
import torch
from torch import autocast
@@ -32,7 +30,7 @@ def choose_precision(device: torch.device) -> str:
device_name = torch.cuda.get_device_name(device)
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
return "float16"
elif device.type == "mps" and version.parse(platform.mac_ver()[0]) < version.parse("14.0.0"):
elif device.type == "mps":
return "float16"
return "float32"

View File

@@ -28,6 +28,7 @@ from npyscreen import widget
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.install.model_install_backend import (
ModelInstallList,
InstallSelections,
ModelInstall,
SchedulerPredictionType,
@@ -40,12 +41,12 @@ from invokeai.frontend.install.widgets import (
SingleSelectColumns,
TextBox,
BufferBox,
FileBox,
set_min_terminal_size,
select_stable_diffusion_config_file,
CyclingForm,
MIN_COLS,
MIN_LINES,
WindowTooSmallException,
)
from invokeai.app.services.config import InvokeAIAppConfig
@@ -155,7 +156,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
BufferBox,
name="Log Messages",
editable=False,
max_height=6,
max_height=15,
)
self.nextrely += 1
@@ -692,11 +693,7 @@ def select_and_download_models(opt: Namespace):
# needed to support the probe() method running under a subprocess
torch.multiprocessing.set_start_method("spawn")
if not set_min_terminal_size(MIN_COLS, MIN_LINES):
raise WindowTooSmallException(
"Could not increase terminal size. Try running again with a larger window or smaller font size."
)
set_min_terminal_size(MIN_COLS, MIN_LINES)
installApp = AddModelApplication(opt)
try:
installApp.run()
@@ -790,8 +787,6 @@ def main():
curses.echo()
curses.endwin()
logger.info("Goodbye! Come back soon.")
except WindowTooSmallException as e:
logger.error(str(e))
except widget.NotEnoughSpaceForWidget as e:
if str(e).startswith("Height of 1 allocated"):
logger.error("Insufficient vertical space for the interface. Please make your window taller and try again")

View File

@@ -21,40 +21,31 @@ MIN_COLS = 130
MIN_LINES = 38
class WindowTooSmallException(Exception):
pass
# -------------------------------------
def set_terminal_size(columns: int, lines: int) -> bool:
def set_terminal_size(columns: int, lines: int):
ts = get_terminal_size()
width = max(columns, ts.columns)
height = max(lines, ts.lines)
OS = platform.uname().system
screen_ok = False
while not screen_ok:
ts = get_terminal_size()
width = max(columns, ts.columns)
height = max(lines, ts.lines)
if OS == "Windows":
pass
# not working reliably - ask user to adjust the window
# _set_terminal_size_powershell(width,height)
elif OS in ["Darwin", "Linux"]:
_set_terminal_size_unix(width, height)
if OS == "Windows":
pass
# not working reliably - ask user to adjust the window
# _set_terminal_size_powershell(width,height)
elif OS in ["Darwin", "Linux"]:
_set_terminal_size_unix(width, height)
# check whether it worked....
ts = get_terminal_size()
if ts.columns < columns or ts.lines < lines:
print(
f"\033[1mThis window is too small for the interface. InvokeAI requires {columns}x{lines} (w x h) characters, but window is {ts.columns}x{ts.lines}\033[0m"
)
resp = input(
"Maximize the window and/or decrease the font size then press any key to continue. Type [Q] to give up.."
)
if resp.upper().startswith("Q"):
break
else:
screen_ok = True
return screen_ok
# check whether it worked....
ts = get_terminal_size()
pause = False
if ts.columns < columns:
print("\033[1mThis window is too narrow for the user interface.\033[0m")
pause = True
if ts.lines < lines:
print("\033[1mThis window is too short for the user interface.\033[0m")
pause = True
if pause:
input("Maximize the window then press any key to continue..")
def _set_terminal_size_powershell(width: int, height: int):
@@ -89,14 +80,14 @@ def _set_terminal_size_unix(width: int, height: int):
sys.stdout.flush()
def set_min_terminal_size(min_cols: int, min_lines: int) -> bool:
def set_min_terminal_size(min_cols: int, min_lines: int):
# make sure there's enough room for the ui
term_cols, term_lines = get_terminal_size()
if term_cols >= min_cols and term_lines >= min_lines:
return True
return
cols = max(term_cols, min_cols)
lines = max(term_lines, min_lines)
return set_terminal_size(cols, lines)
set_terminal_size(cols, lines)
class IntSlider(npyscreen.Slider):
@@ -173,7 +164,7 @@ class FloatSlider(npyscreen.Slider):
class FloatTitleSlider(npyscreen.TitleText):
_entry_type = npyscreen.Slider
_entry_type = FloatSlider
class SelectColumnBase:

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,4 +1,4 @@
import{B as m,g7 as Je,A as y,a5 as Ka,g8 as Xa,af as va,aj as d,g9 as b,ga as t,gb as Ya,gc as h,gd as ua,ge as Ja,gf as Qa,aL as Za,gg as et,ad as rt,gh as at}from"./index-de589048.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./menu-11348abc.js";var za=String.raw,Ca=za`
import{A as m,f$ as Je,z as y,a4 as Ka,g0 as Xa,af as va,aj as d,g1 as b,g2 as t,g3 as Ya,g4 as h,g5 as ua,g6 as Ja,g7 as Qa,aI as Za,g8 as et,ad as rt,g9 as at}from"./index-18f2f740.js";import{s as fa,n as o,t as tt,o as ha,p as ot,q as ma,v as ga,w as ya,x as it,y as Sa,z as pa,A as xr,B as nt,D as lt,E as st,F as xa,G as $a,H as ka,J as dt,K as _a,L as ct,M as bt,N as vt,O as ut,Q as wa,R as ft,S as ht,T as mt,U as gt,V as yt,W as St,e as pt,X as xt}from"./MantineProvider-b20a2267.js";var za=String.raw,Ca=za`
:root,
:host {
--chakra-vh: 100vh;

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-de589048.js"></script>
<script type="module" crossorigin src="./assets/index-18f2f740.js"></script>
</head>
<body dir="ltr">

View File

@@ -124,8 +124,7 @@
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
"deleteImagePermanent": "Deleted images cannot be restored.",
"images": "Images",
"assets": "Assets",
"autoAssignBoardOnClick": "Auto-Assign Board on Click"
"assets": "Assets"
},
"hotkeys": {
"keyboardShortcuts": "Keyboard Shortcuts",

View File

@@ -23,7 +23,7 @@
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"typegen": "node scripts/typegen.js",
"typegen": "npx ts-node scripts/typegen.ts",
"preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",

View File

@@ -124,8 +124,7 @@
"deleteImageBin": "Deleted images will be sent to your operating system's Bin.",
"deleteImagePermanent": "Deleted images cannot be restored.",
"images": "Images",
"assets": "Assets",
"autoAssignBoardOnClick": "Auto-Assign Board on Click"
"assets": "Assets"
},
"hotkeys": {
"keyboardShortcuts": "Keyboard Shortcuts",

View File

@@ -4,9 +4,8 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { PartialAppConfig } from 'app/types/invokeai';
import ImageUploader from 'common/components/ImageUploader';
import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
import DeleteImageModal from 'features/imageDeletion/components/DeleteImageModal';
import SiteHeader from 'features/system/components/SiteHeader';
import { configChanged } from 'features/system/store/configSlice';
import { languageSelector } from 'features/system/store/systemSelectors';
@@ -17,6 +16,7 @@ import ParametersDrawer from 'features/ui/components/ParametersDrawer';
import i18n from 'i18n';
import { size } from 'lodash-es';
import { ReactNode, memo, useEffect } from 'react';
import UpdateImageBoardModal from '../../features/gallery/components/Boards/UpdateImageBoardModal';
import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster';
@@ -84,7 +84,7 @@ const App = ({ config = DEFAULT_CONFIG, headerComponent }: Props) => {
</Portal>
</Grid>
<DeleteImageModal />
<ChangeBoardModal />
<UpdateImageBoardModal />
<Toaster />
<GlobalHotkeys />
</>

View File

@@ -58,7 +58,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
);
}
if (props.dragData.payloadType === 'IMAGE_DTOS') {
if (props.dragData.payloadType === 'IMAGE_NAMES') {
return (
<Flex
sx={{
@@ -71,7 +71,7 @@ const DragPreview = (props: OverlayDragImageProps) => {
...STYLES,
}}
>
<Heading>{props.dragData.payload.imageDTOs.length}</Heading>
<Heading>{props.dragData.payload.image_names.length}</Heading>
<Heading size="sm">Images</Heading>
</Flex>
);

View File

@@ -18,32 +18,27 @@ import {
DragStartEvent,
TypesafeDraggableData,
} from './typesafeDnd';
import { logger } from 'app/logging/logger';
type ImageDndContextProps = PropsWithChildren;
const ImageDndContext = (props: ImageDndContextProps) => {
const [activeDragData, setActiveDragData] =
useState<TypesafeDraggableData | null>(null);
const log = logger('images');
const dispatch = useAppDispatch();
const handleDragStart = useCallback(
(event: DragStartEvent) => {
log.trace({ dragData: event.active.data.current }, 'Drag started');
const activeData = event.active.data.current;
if (!activeData) {
return;
}
setActiveDragData(activeData);
},
[log]
);
const handleDragStart = useCallback((event: DragStartEvent) => {
console.log('dragStart', event.active.data.current);
const activeData = event.active.data.current;
if (!activeData) {
return;
}
setActiveDragData(activeData);
}, []);
const handleDragEnd = useCallback(
(event: DragEndEvent) => {
log.trace({ dragData: event.active.data.current }, 'Drag ended');
console.log('dragEnd', event.active.data.current);
const overData = event.over?.data.current;
if (!activeDragData || !overData) {
return;
@@ -51,7 +46,7 @@ const ImageDndContext = (props: ImageDndContextProps) => {
dispatch(dndDropped({ overData, activeData: activeDragData }));
setActiveDragData(null);
},
[activeDragData, dispatch, log]
[activeDragData, dispatch]
);
const mouseSensor = useSensor(MouseSensor, {

View File

@@ -11,6 +11,7 @@ import {
useDraggable as useOriginalDraggable,
useDroppable as useOriginalDroppable,
} from '@dnd-kit/core';
import { BoardId } from 'features/gallery/store/types';
import { ImageDTO } from 'services/api/types';
type BaseDropData = {
@@ -53,13 +54,9 @@ export type AddToBatchDropData = BaseDropData & {
actionType: 'ADD_TO_BATCH';
};
export type AddToBoardDropData = BaseDropData & {
actionType: 'ADD_TO_BOARD';
context: { boardId: string };
};
export type RemoveFromBoardDropData = BaseDropData & {
actionType: 'REMOVE_FROM_BOARD';
export type MoveBoardDropData = BaseDropData & {
actionType: 'MOVE_BOARD';
context: { boardId: BoardId };
};
export type TypesafeDroppableData =
@@ -70,8 +67,7 @@ export type TypesafeDroppableData =
| NodesImageDropData
| AddToBatchDropData
| NodesMultiImageDropData
| AddToBoardDropData
| RemoveFromBoardDropData;
| MoveBoardDropData;
type BaseDragData = {
id: string;
@@ -82,12 +78,14 @@ export type ImageDraggableData = BaseDragData & {
payload: { imageDTO: ImageDTO };
};
export type ImageDTOsDraggableData = BaseDragData & {
payloadType: 'IMAGE_DTOS';
payload: { imageDTOs: ImageDTO[] };
export type ImageNamesDraggableData = BaseDragData & {
payloadType: 'IMAGE_NAMES';
payload: { image_names: string[] };
};
export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData;
export type TypesafeDraggableData =
| ImageDraggableData
| ImageNamesDraggableData;
interface UseDroppableTypesafeArguments
extends Omit<UseDroppableArguments, 'data'> {
@@ -158,39 +156,14 @@ export const isValidDrop = (
case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_MULTI_NODES_IMAGE':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
case 'ADD_TO_BATCH':
return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
case 'ADD_TO_BOARD': {
return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
case 'MOVE_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
if (!isPayloadValid) {
return false;
}
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id ?? 'none';
const destinationBoard = overData.context.boardId;
return currentBoard !== destinationBoard;
}
if (payloadType === 'IMAGE_DTOS') {
// TODO (multi-select)
return true;
}
return false;
}
case 'REMOVE_FROM_BOARD': {
// If the board is the same, don't allow the drop
// Check the payload types
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS';
const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES';
if (!isPayloadValid) {
return false;
}
@@ -199,16 +172,20 @@ export const isValidDrop = (
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id;
const destinationBoard = overData.context.boardId;
return currentBoard !== 'none';
const isSameBoard = currentBoard === destinationBoard;
const isDestinationValid = !currentBoard ? destinationBoard : true;
return !isSameBoard && isDestinationValid;
}
if (payloadType === 'IMAGE_DTOS') {
if (payloadType === 'IMAGE_NAMES') {
// TODO (multi-select)
return true;
return false;
}
return false;
return true;
}
default:
return false;

View File

@@ -1,6 +1,4 @@
import { Middleware } from '@reduxjs/toolkit';
import { store } from 'app/store/store';
import { PartialAppConfig } from 'app/types/invokeai';
import React, {
lazy,
memo,
@@ -9,11 +7,16 @@ import React, {
useEffect,
} from 'react';
import { Provider } from 'react-redux';
import { PartialAppConfig } from 'app/types/invokeai';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
import { $authToken, $baseUrl, $projectId } from 'services/api/client';
import { socketMiddleware } from 'services/events/middleware';
import Loading from '../../common/components/Loading/Loading';
import { Middleware } from '@reduxjs/toolkit';
import { $authToken, $baseUrl } from 'services/api/client';
import { socketMiddleware } from 'services/events/middleware';
import '../../i18n';
import { AddImageToBoardContextProvider } from '../contexts/AddImageToBoardContext';
import ImageDndContext from './ImageDnd/ImageDndContext';
const App = lazy(() => import('./App'));
@@ -34,7 +37,6 @@ const InvokeAIUI = ({
config,
headerComponent,
middleware,
projectId,
}: Props) => {
useEffect(() => {
// configure API client token
@@ -47,11 +49,6 @@ const InvokeAIUI = ({
$baseUrl.set(apiUrl);
}
// configure API client project header
if (projectId) {
$projectId.set(projectId);
}
// reset dynamically added middlewares
resetMiddlewares();
@@ -71,9 +68,8 @@ const InvokeAIUI = ({
// Reset the API client token and base url on unmount
$baseUrl.set(undefined);
$authToken.set(undefined);
$projectId.set(undefined);
};
}, [apiUrl, token, middleware, projectId]);
}, [apiUrl, token, middleware]);
return (
<React.StrictMode>
@@ -81,7 +77,9 @@ const InvokeAIUI = ({
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<ImageDndContext>
<App config={config} headerComponent={headerComponent} />
<AddImageToBoardContextProvider>
<App config={config} headerComponent={headerComponent} />
</AddImageToBoardContextProvider>
</ImageDndContext>
</ThemeLocaleProvider>
</React.Suspense>

View File

@@ -0,0 +1,91 @@
import { useDisclosure } from '@chakra-ui/react';
import { PropsWithChildren, createContext, useCallback, useState } from 'react';
import { ImageDTO } from 'services/api/types';
import { imagesApi } from 'services/api/endpoints/images';
import { useAppDispatch } from '../store/storeHooks';
export type ImageUsage = {
isInitialImage: boolean;
isCanvasImage: boolean;
isNodesImage: boolean;
isControlNetImage: boolean;
};
type AddImageToBoardContextValue = {
/**
* Whether the move image dialog is open.
*/
isOpen: boolean;
/**
* Closes the move image dialog.
*/
onClose: () => void;
/**
* The image pending movement
*/
image?: ImageDTO;
onClickAddToBoard: (image: ImageDTO) => void;
handleAddToBoard: (boardId: string) => void;
};
export const AddImageToBoardContext =
createContext<AddImageToBoardContextValue>({
isOpen: false,
onClose: () => undefined,
onClickAddToBoard: () => undefined,
handleAddToBoard: () => undefined,
});
type Props = PropsWithChildren;
export const AddImageToBoardContextProvider = (props: Props) => {
const [imageToMove, setImageToMove] = useState<ImageDTO>();
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
// Clean up after deleting or dismissing the modal
const closeAndClearImageToDelete = useCallback(() => {
setImageToMove(undefined);
onClose();
}, [onClose]);
const onClickAddToBoard = useCallback(
(image?: ImageDTO) => {
if (!image) {
return;
}
setImageToMove(image);
onOpen();
},
[setImageToMove, onOpen]
);
const handleAddToBoard = useCallback(
(boardId: string) => {
if (imageToMove) {
dispatch(
imagesApi.endpoints.addImageToBoard.initiate({
imageDTO: imageToMove,
board_id: boardId,
})
);
closeAndClearImageToDelete();
}
},
[dispatch, closeAndClearImageToDelete, imageToMove]
);
return (
<AddImageToBoardContext.Provider
value={{
isOpen,
image: imageToMove,
onClose: closeAndClearImageToDelete,
onClickAddToBoard,
handleAddToBoard,
}}
>
{props.children}
</AddImageToBoardContext.Provider>
);
};

View File

@@ -0,0 +1,8 @@
import { createContext } from 'react';
type VoidFunc = () => void;
type ImageUploaderTriggerContextType = VoidFunc | null;
export const ImageUploaderTriggerContext =
createContext<ImageUploaderTriggerContextType>(null);

View File

@@ -23,6 +23,6 @@ const serializationDenylist: {
};
export const serialize: SerializeFunction = (data, key) => {
const result = omit(data, serializationDenylist[key] ?? []);
const result = omit(data, serializationDenylist[key]);
return JSON.stringify(result);
};

View File

@@ -27,8 +27,7 @@ import {
addImageDeletedFulfilledListener,
addImageDeletedPendingListener,
addImageDeletedRejectedListener,
addRequestedSingleImageDeletionListener,
addRequestedMultipleImageDeletionListener,
addRequestedImageDeletionListener,
} from './listeners/imageDeleted';
import { addImageDroppedListener } from './listeners/imageDropped';
import {
@@ -112,8 +111,7 @@ addImageUploadedRejectedListener();
addInitialImageSelectedListener();
// Image deleted
addRequestedSingleImageDeletionListener();
addRequestedMultipleImageDeletionListener();
addRequestedImageDeletionListener();
addImageDeletedPendingListener();
addImageDeletedFulfilledListener();
addImageDeletedRejectedListener();

View File

@@ -1,10 +1,12 @@
import { createAction } from '@reduxjs/toolkit';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images';
import {
ImageCache,
getListImagesUrl,
imagesApi,
} from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { getListImagesUrl, imagesAdapter } from 'services/api/util';
import { ImageCache } from 'services/api/types';
export const appStarted = createAction('app/appStarted');
@@ -32,8 +34,7 @@ export const addFirstListImagesListener = () => {
if (data.ids.length > 0) {
// Select the first image
const firstImage = imagesAdapter.getSelectors().selectAll(data)[0];
dispatch(imageSelected(firstImage ?? null));
dispatch(imageSelected(data.ids[0] as string));
}
},
});

View File

@@ -18,9 +18,7 @@ export const addAppConfigReceivedListener = () => {
const infillMethod = getState().generation.infillMethod;
if (!infill_methods.includes(infillMethod)) {
// if there is no infill method, set it to the first one
// if there is no first one... god help us
dispatch(setInfillMethod(infill_methods[0] as string));
dispatch(setInfillMethod(infill_methods[0]));
}
if (!nsfw_methods.includes('nsfw_checker')) {

View File

@@ -1,14 +1,14 @@
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { getImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { boardsApi } from '../../../../../services/api/endpoints/boards';
export const addDeleteBoardAndImagesFulfilledListener = () => {
startAppListening({
matcher: imagesApi.endpoints.deleteBoardAndImages.matchFulfilled,
matcher: boardsApi.endpoints.deleteBoardAndImages.matchFulfilled,
effect: async (action, { dispatch, getState }) => {
const { deleted_images } = action.payload;

View File

@@ -10,7 +10,6 @@ import {
} from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { imagesSelectors } from 'services/api/util';
export const addBoardIdSelectedListener = () => {
startAppListening({
@@ -53,9 +52,8 @@ export const addBoardIdSelectedListener = () => {
queryArgs
)(getState());
if (boardImagesData) {
const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
dispatch(imageSelected(firstImage ?? null));
if (boardImagesData?.ids.length) {
dispatch(imageSelected((boardImagesData.ids[0] as string) ?? null));
} else {
// board has no images - deselect
dispatch(imageSelected(null));

View File

@@ -26,8 +26,6 @@ export const addCanvasSavedToGalleryListener = () => {
return;
}
const { autoAddBoardId } = state.gallery;
dispatch(
imagesApi.endpoints.uploadImage.initiate({
file: new File([blob], 'savedCanvas.png', {
@@ -35,7 +33,7 @@ export const addCanvasSavedToGalleryListener = () => {
}),
image_category: 'general',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
board_id: state.gallery.autoAddBoardId,
crop_visible: true,
postUploadAction: {
type: 'TOAST',

View File

@@ -31,20 +31,15 @@ const predicate: AnyListenerPredicate<RootState> = (
// do not process if the user just disabled auto-config
if (
prevState.controlNet.controlNets[action.payload.controlNetId]
?.shouldAutoConfig === true
.shouldAutoConfig === true
) {
return false;
}
}
const cn = state.controlNet.controlNets[action.payload.controlNetId];
const { controlImage, processorType, shouldAutoConfig } =
state.controlNet.controlNets[action.payload.controlNetId];
if (!cn) {
// something is wrong, the controlNet should exist
return false;
}
const { controlImage, processorType, shouldAutoConfig } = cn;
if (controlNetModelChanged.match(action) && !shouldAutoConfig) {
// do not process if the action is a model change but the processor settings are dirty
return false;

View File

@@ -17,7 +17,7 @@ export const addControlNetImageProcessedListener = () => {
const { controlNetId } = action.payload;
const controlNet = getState().controlNet.controlNets[controlNetId];
if (!controlNet?.controlImage) {
if (!controlNet.controlImage) {
log.error('Unable to process ControlNet image');
return;
}

View File

@@ -1,72 +1,57 @@
import { logger } from 'app/logging/logger';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imageDeletionConfirmed } from 'features/imageDeletion/store/actions';
import { isModalOpenChanged } from 'features/imageDeletion/store/imageDeletionSlice';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { clamp } from 'lodash-es';
import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
import { startAppListening } from '..';
export const addRequestedSingleImageDeletionListener = () => {
/**
* Called when the user requests an image deletion
*/
export const addRequestedImageDeletionListener = () => {
startAppListening({
actionCreator: imageDeletionConfirmed,
effect: async (action, { dispatch, getState, condition }) => {
const { imageDTOs, imagesUsage } = action.payload;
if (imageDTOs.length !== 1 || imagesUsage.length !== 1) {
// handle multiples in separate listener
return;
}
const imageDTO = imageDTOs[0];
const imageUsage = imagesUsage[0];
if (!imageDTO || !imageUsage) {
// satisfy noUncheckedIndexedAccess
return;
}
const { imageDTO, imageUsage } = action.payload;
dispatch(isModalOpenChanged(false));
const { image_name } = imageDTO;
const state = getState();
const lastSelectedImage =
state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
const { image_name } = imageDTO;
state.gallery.selection[state.gallery.selection.length - 1];
if (lastSelectedImage === image_name) {
const baseQueryArgs = selectListImagesBaseQueryArgs(state);
const { data } =
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
const cachedImageDTOs = data
? imagesAdapter.getSelectors().selectAll(data)
: [];
const ids = data?.ids ?? [];
const deletedImageIndex = cachedImageDTOs.findIndex(
(i) => i.image_name === image_name
const deletedImageIndex = ids.findIndex(
(result) => result.toString() === image_name
);
const filteredImageDTOs = cachedImageDTOs.filter(
(i) => i.image_name !== image_name
);
const filteredIds = ids.filter((id) => id.toString() !== image_name);
const newSelectedImageIndex = clamp(
deletedImageIndex,
0,
filteredImageDTOs.length - 1
filteredIds.length - 1
);
const newSelectedImageDTO = filteredImageDTOs[newSelectedImageIndex];
const newSelectedImageId = filteredIds[newSelectedImageIndex];
if (newSelectedImageDTO) {
dispatch(imageSelected(newSelectedImageDTO));
if (newSelectedImageId) {
dispatch(imageSelected(newSelectedImageId as string));
} else {
dispatch(imageSelected(null));
}
@@ -112,66 +97,6 @@ export const addRequestedSingleImageDeletionListener = () => {
});
};
/**
* Called when the user requests an image deletion
*/
export const addRequestedMultipleImageDeletionListener = () => {
startAppListening({
actionCreator: imageDeletionConfirmed,
effect: async (action, { dispatch, getState }) => {
const { imageDTOs, imagesUsage } = action.payload;
if (imageDTOs.length < 1 || imagesUsage.length < 1) {
// handle singles in separate listener
return;
}
try {
// Delete from server
await dispatch(
imagesApi.endpoints.deleteImages.initiate({ imageDTOs })
).unwrap();
const state = getState();
const baseQueryArgs = selectListImagesBaseQueryArgs(state);
const { data } =
imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
const newSelectedImageDTO = data
? imagesAdapter.getSelectors().selectAll(data)[0]
: undefined;
if (newSelectedImageDTO) {
dispatch(imageSelected(newSelectedImageDTO));
} else {
dispatch(imageSelected(null));
}
dispatch(isModalOpenChanged(false));
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
if (imagesUsage.some((i) => i.isCanvasImage)) {
dispatch(resetCanvas());
}
if (imagesUsage.some((i) => i.isControlNetImage)) {
dispatch(controlNetReset());
}
if (imagesUsage.some((i) => i.isInitialImage)) {
dispatch(clearInitialImage());
}
if (imagesUsage.some((i) => i.isNodesImage)) {
dispatch(nodeEditorReset());
}
} catch {
// no-op
}
},
});
};
/**
* Called when the actual delete request is sent to the server
*/

View File

@@ -6,7 +6,10 @@ import {
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import {
imageSelected,
imagesAddedToBatch,
} from 'features/gallery/store/gallerySlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
@@ -24,32 +27,19 @@ export const addImageDroppedListener = () => {
const log = logger('images');
const { activeData, overData } = action.payload;
if (activeData.payloadType === 'IMAGE_DTO') {
log.debug({ activeData, overData }, 'Image dropped');
} else if (activeData.payloadType === 'IMAGE_DTOS') {
log.debug(
{ activeData, overData },
`Images (${activeData.payload.imageDTOs.length}) dropped`
);
} else {
log.debug({ activeData, overData }, `Unknown payload dropped`);
}
log.debug({ activeData, overData }, 'Image or selection dropped');
/**
* Image dropped on current image
*/
// set current image
if (
overData.actionType === 'SET_CURRENT_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(imageSelected(activeData.payload.imageDTO));
dispatch(imageSelected(activeData.payload.imageDTO.image_name));
return;
}
/**
* Image dropped on initial image
*/
// set initial image
if (
overData.actionType === 'SET_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
@@ -59,9 +49,27 @@ export const addImageDroppedListener = () => {
return;
}
/**
* Image dropped on ControlNet
*/
// add image to batch
if (
overData.actionType === 'ADD_TO_BATCH' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(imagesAddedToBatch([activeData.payload.imageDTO.image_name]));
return;
}
// add multiple images to batch
if (
overData.actionType === 'ADD_TO_BATCH' &&
activeData.payloadType === 'IMAGE_NAMES'
) {
dispatch(imagesAddedToBatch(activeData.payload.image_names));
return;
}
// set control image
if (
overData.actionType === 'SET_CONTROLNET_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
@@ -77,9 +85,7 @@ export const addImageDroppedListener = () => {
return;
}
/**
* Image dropped on Canvas
*/
// set canvas image
if (
overData.actionType === 'SET_CANVAS_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
@@ -89,9 +95,7 @@ export const addImageDroppedListener = () => {
return;
}
/**
* Image dropped on node image field
*/
// set nodes image
if (
overData.actionType === 'SET_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
@@ -108,36 +112,61 @@ export const addImageDroppedListener = () => {
return;
}
/**
* TODO
* Image selection dropped on node image collection field
*/
// set multiple nodes images (single image handler)
if (
overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { fieldName, nodeId } = overData.context;
dispatch(
fieldValueChanged({
nodeId,
fieldName,
value: [activeData.payload.imageDTO],
})
);
return;
}
// // set multiple nodes images (multiple images handler)
// if (
// overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
// activeData.payloadType === 'IMAGE_DTO' &&
// activeData.payload.imageDTO
// activeData.payloadType === 'IMAGE_NAMES'
// ) {
// const { fieldName, nodeId } = overData.context;
// dispatch(
// fieldValueChanged({
// imageCollectionFieldValueChanged({
// nodeId,
// fieldName,
// value: [activeData.payload.imageDTO],
// value: activeData.payload.image_names.map((image_name) => ({
// image_name,
// })),
// })
// );
// return;
// }
/**
* Image dropped on user board
*/
// add image to board
if (
overData.actionType === 'ADD_TO_BOARD' &&
overData.actionType === 'MOVE_BOARD' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
const { boardId } = overData.context;
// image was droppe on the "NoBoardBoard"
if (!boardId) {
dispatch(
imagesApi.endpoints.removeImageFromBoard.initiate({
imageDTO,
})
);
return;
}
// image was dropped on a user board
dispatch(
imagesApi.endpoints.addImageToBoard.initiate({
imageDTO,
@@ -147,58 +176,67 @@ export const addImageDroppedListener = () => {
return;
}
/**
* Image dropped on 'none' board
*/
if (
overData.actionType === 'REMOVE_FROM_BOARD' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
dispatch(
imagesApi.endpoints.removeImageFromBoard.initiate({
imageDTO,
})
);
return;
}
// // add gallery selection to board
// if (
// overData.actionType === 'MOVE_BOARD' &&
// activeData.payloadType === 'IMAGE_NAMES' &&
// overData.context.boardId
// ) {
// console.log('adding gallery selection to board');
// const board_id = overData.context.boardId;
// dispatch(
// boardImagesApi.endpoints.addManyBoardImages.initiate({
// board_id,
// image_names: activeData.payload.image_names,
// })
// );
// return;
// }
/**
* Multiple images dropped on user board
*/
if (
overData.actionType === 'ADD_TO_BOARD' &&
activeData.payloadType === 'IMAGE_DTOS' &&
activeData.payload.imageDTOs
) {
const { imageDTOs } = activeData.payload;
const { boardId } = overData.context;
dispatch(
imagesApi.endpoints.addImagesToBoard.initiate({
imageDTOs,
board_id: boardId,
})
);
return;
}
// // remove gallery selection from board
// if (
// overData.actionType === 'MOVE_BOARD' &&
// activeData.payloadType === 'IMAGE_NAMES' &&
// overData.context.boardId === null
// ) {
// console.log('removing gallery selection to board');
// dispatch(
// boardImagesApi.endpoints.deleteManyBoardImages.initiate({
// image_names: activeData.payload.image_names,
// })
// );
// return;
// }
/**
* Multiple images dropped on 'none' board
*/
if (
overData.actionType === 'REMOVE_FROM_BOARD' &&
activeData.payloadType === 'IMAGE_DTOS' &&
activeData.payload.imageDTOs
) {
const { imageDTOs } = activeData.payload;
dispatch(
imagesApi.endpoints.removeImagesFromBoard.initiate({
imageDTOs,
})
);
return;
}
// // add batch selection to board
// if (
// overData.actionType === 'MOVE_BOARD' &&
// activeData.payloadType === 'IMAGE_NAMES' &&
// overData.context.boardId
// ) {
// const board_id = overData.context.boardId;
// dispatch(
// boardImagesApi.endpoints.addManyBoardImages.initiate({
// board_id,
// image_names: activeData.payload.image_names,
// })
// );
// return;
// }
// // remove batch selection from board
// if (
// overData.actionType === 'MOVE_BOARD' &&
// activeData.payloadType === 'IMAGE_NAMES' &&
// overData.context.boardId === null
// ) {
// dispatch(
// boardImagesApi.endpoints.deleteManyBoardImages.initiate({
// image_names: activeData.payload.image_names,
// })
// );
// return;
// }
},
});
};

View File

@@ -1,32 +1,37 @@
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { selectImageUsage } from 'features/deleteImageModal/store/selectors';
import { imageDeletionConfirmed } from 'features/imageDeletion/store/actions';
import { selectImageUsage } from 'features/imageDeletion/store/imageDeletionSelectors';
import {
imagesToDeleteSelected,
imageToDeleteSelected,
isModalOpenChanged,
} from 'features/deleteImageModal/store/slice';
} from 'features/imageDeletion/store/imageDeletionSlice';
import { startAppListening } from '..';
export const addImageToDeleteSelectedListener = () => {
startAppListening({
actionCreator: imagesToDeleteSelected,
actionCreator: imageToDeleteSelected,
effect: async (action, { dispatch, getState }) => {
const imageDTOs = action.payload;
const imageDTO = action.payload;
const state = getState();
const { shouldConfirmOnDelete } = state.system;
const imagesUsage = selectImageUsage(getState());
const imageUsage = selectImageUsage(getState());
if (!imageUsage) {
// should never happen
return;
}
const isImageInUse =
imagesUsage.some((i) => i.isCanvasImage) ||
imagesUsage.some((i) => i.isInitialImage) ||
imagesUsage.some((i) => i.isControlNetImage) ||
imagesUsage.some((i) => i.isNodesImage);
imageUsage.isCanvasImage ||
imageUsage.isInitialImage ||
imageUsage.isControlNetImage ||
imageUsage.isNodesImage;
if (shouldConfirmOnDelete || isImageInUse) {
dispatch(isModalOpenChanged(true));
return;
}
dispatch(imageDeletionConfirmed({ imageDTOs, imagesUsage }));
dispatch(imageDeletionConfirmed({ imageDTO, imageUsage }));
},
});
};

View File

@@ -2,13 +2,14 @@ import { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { imagesAddedToBatch } from 'features/gallery/store/gallerySlice';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import { startAppListening } from '..';
import { imagesApi } from '../../../../../services/api/endpoints/images';
import { omit } from 'lodash-es';
const DEFAULT_UPLOADED_TOAST: UseToastOptions = {
title: 'Image Uploaded',
@@ -40,7 +41,7 @@ export const addImageUploadedFulfilledListener = () => {
// default action - just upload and alert user
if (postUploadAction?.type === 'TOAST') {
const { toastOptions } = postUploadAction;
if (!autoAddBoardId || autoAddBoardId === 'none') {
if (!autoAddBoardId) {
dispatch(addToast({ ...DEFAULT_UPLOADED_TOAST, ...toastOptions }));
} else {
// Add this image to the board
@@ -120,6 +121,17 @@ export const addImageUploadedFulfilledListener = () => {
);
return;
}
if (postUploadAction?.type === 'ADD_TO_BATCH') {
dispatch(imagesAddedToBatch([imageDTO.image_name]));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: 'Added to batch',
})
);
return;
}
},
});
};

View File

@@ -15,7 +15,7 @@ import {
setShouldUseSDXLRefiner,
} from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models';
import { modelsApi } from 'services/api/endpoints/models';
import { startAppListening } from '..';
export const addModelsLoadedListener = () => {
@@ -144,9 +144,8 @@ export const addModelsLoadedListener = () => {
return;
}
const firstModel = vaeModelsAdapter
.getSelectors()
.selectAll(action.payload)[0];
const firstModelId = action.payload.ids[0];
const firstModel = action.payload.entities[firstModelId];
if (!firstModel) {
// No custom VAEs loaded at all; use the default

View File

@@ -8,10 +8,9 @@ import {
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { progressImageSet } from 'features/system/store/systemSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter, imagesApi } from 'services/api/endpoints/images';
import { isImageOutput } from 'services/api/guards';
import { sessionCanceled } from 'services/api/thunks/session';
import { imagesAdapter } from 'services/api/util';
import {
appSocketInvocationComplete,
socketInvocationComplete,
@@ -68,7 +67,7 @@ export const addInvocationCompleteEventListener = () => {
*/
const { autoAddBoardId } = gallery;
if (autoAddBoardId && autoAddBoardId !== 'none') {
if (autoAddBoardId) {
dispatch(
imagesApi.endpoints.addImageToBoard.initiate({
board_id: autoAddBoardId,
@@ -84,7 +83,10 @@ export const addInvocationCompleteEventListener = () => {
categories: IMAGE_CATEGORIES,
},
(draft) => {
imagesAdapter.addOne(draft, imageDTO);
const oldTotal = draft.total;
const newState = imagesAdapter.addOne(draft, imageDTO);
const delta = newState.total - oldTotal;
draft.total = draft.total + delta;
}
)
);
@@ -92,8 +94,8 @@ export const addInvocationCompleteEventListener = () => {
dispatch(
imagesApi.util.invalidateTags([
{ type: 'BoardImagesTotal', id: autoAddBoardId },
{ type: 'BoardAssetsTotal', id: autoAddBoardId },
{ type: 'BoardImagesTotal', id: autoAddBoardId ?? 'none' },
{ type: 'BoardAssetsTotal', id: autoAddBoardId ?? 'none' },
])
);
@@ -108,7 +110,7 @@ export const addInvocationCompleteEventListener = () => {
} else if (!autoAddBoardId) {
dispatch(galleryViewChanged('images'));
}
dispatch(imageSelected(imageDTO));
dispatch(imageSelected(imageDTO.image_name));
}
}

View File

@@ -8,9 +8,9 @@ import {
import canvasReducer from 'features/canvas/store/canvasSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import boardsReducer from 'features/gallery/store/boardSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
import deleteImageModalReducer from 'features/deleteImageModal/store/slice';
import changeBoardModalReducer from 'features/changeBoardModal/store/slice';
import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice';
import loraReducer from 'features/lora/store/loraSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import generationReducer from 'features/parameters/store/generationSlice';
@@ -43,9 +43,9 @@ const allReducers = {
ui: uiReducer,
hotkeys: hotkeysReducer,
controlNet: controlNetReducer,
boards: boardsReducer,
dynamicPrompts: dynamicPromptsReducer,
deleteImageModal: deleteImageModalReducer,
changeBoardModal: changeBoardModalReducer,
imageDeletion: imageDeletionReducer,
lora: loraReducer,
modelmanager: modelmanagerReducer,
sdxl: sdxlReducer,

View File

@@ -96,8 +96,7 @@ export type AppFeature =
| 'consoleLogging'
| 'dynamicPrompting'
| 'batches'
| 'syncModels'
| 'multiselect';
| 'syncModels';
/**
* A disable-able Stable Diffusion feature

View File

@@ -1,4 +1,4 @@
import { Box, Flex, useColorMode } from '@chakra-ui/react';
import { Flex, Text, useColorMode } from '@chakra-ui/react';
import { motion } from 'framer-motion';
import { ReactNode, memo, useRef } from 'react';
import { mode } from 'theme/util/mode';
@@ -74,7 +74,7 @@ export const IAIDropOverlay = (props: Props) => {
justifyContent: 'center',
}}
>
<Box
<Text
sx={{
fontSize: '2xl',
fontWeight: 600,
@@ -87,7 +87,7 @@ export const IAIDropOverlay = (props: Props) => {
}}
>
{label}
</Box>
</Text>
</Flex>
</Flex>
</motion.div>

View File

@@ -53,9 +53,7 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
// wrap onChange to clear search value on select
const handleChange = useCallback(
(v: string | null) => {
// cannot figure out why we were doing this, but it was causing an issue where if you
// select the currently-selected item, it reset the search value to empty
// setSearchValue('');
setSearchValue('');
if (!onChange) {
return;

View File

@@ -78,7 +78,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
image_category: 'user',
is_intermediate: false,
postUploadAction,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
board_id: autoAddBoardId,
});
},
[autoAddBoardId, postUploadAction, uploadImage]

View File

@@ -49,7 +49,7 @@ export const useImageUploadButton = ({
image_category: 'user',
is_intermediate: false,
postUploadAction: postUploadAction ?? { type: 'TOAST' },
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
board_id: autoAddBoardId,
});
},
[autoAddBoardId, postUploadAction, uploadImage]

View File

@@ -33,10 +33,6 @@ const useColorPicker = () => {
1
).data;
if (!(a && r && g && b)) {
return;
}
dispatch(setColorPickerColor({ r, g, b, a }));
},
commitColorUnderCursor: () => {

View File

@@ -727,13 +727,10 @@ export const canvasSlice = createSlice({
state.pastLayerStates.shift();
}
const imageToCommit = images[selectedImageIndex];
state.layerState.objects.push({
...images[selectedImageIndex],
});
if (imageToCommit) {
state.layerState.objects.push({
...imageToCommit,
});
}
state.layerState.stagingArea = {
...initialLayerState.stagingArea,
};

View File

@@ -1,132 +0,0 @@
import {
AlertDialog,
AlertDialogBody,
AlertDialogContent,
AlertDialogFooter,
AlertDialogHeader,
AlertDialogOverlay,
Flex,
Text,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect';
import { memo, useCallback, useMemo, useRef, useState } from 'react';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import {
useAddImagesToBoardMutation,
useRemoveImagesFromBoardMutation,
} from 'services/api/endpoints/images';
import { changeBoardReset, isModalOpenChanged } from '../store/slice';
const selector = createSelector(
[stateSelector],
({ changeBoardModal }) => {
const { isModalOpen, imagesToChange } = changeBoardModal;
return {
isModalOpen,
imagesToChange,
};
},
defaultSelectorOptions
);
const ChangeBoardModal = () => {
const dispatch = useAppDispatch();
const [selectedBoard, setSelectedBoard] = useState<string | null>();
const { data: boards, isFetching } = useListAllBoardsQuery();
const { imagesToChange, isModalOpen } = useAppSelector(selector);
const [addImagesToBoard] = useAddImagesToBoardMutation();
const [removeImagesFromBoard] = useRemoveImagesFromBoardMutation();
const data = useMemo(() => {
const data: { label: string; value: string }[] = [
{ label: 'Uncategorized', value: 'none' },
];
(boards ?? []).forEach((board) =>
data.push({
label: board.board_name,
value: board.board_id,
})
);
return data;
}, [boards]);
const handleClose = useCallback(() => {
dispatch(changeBoardReset());
dispatch(isModalOpenChanged(false));
}, [dispatch]);
const handleChangeBoard = useCallback(() => {
if (!imagesToChange.length || !selectedBoard) {
return;
}
if (selectedBoard === 'none') {
removeImagesFromBoard({ imageDTOs: imagesToChange });
} else {
addImagesToBoard({
imageDTOs: imagesToChange,
board_id: selectedBoard,
});
}
setSelectedBoard(null);
dispatch(changeBoardReset());
}, [
addImagesToBoard,
dispatch,
imagesToChange,
removeImagesFromBoard,
selectedBoard,
]);
const cancelRef = useRef<HTMLButtonElement>(null);
return (
<AlertDialog
isOpen={isModalOpen}
onClose={handleClose}
leastDestructiveRef={cancelRef}
isCentered
>
<AlertDialogOverlay>
<AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold">
Change Board
</AlertDialogHeader>
<AlertDialogBody>
<Flex sx={{ flexDir: 'column', gap: 4 }}>
<Text>
Moving {`${imagesToChange.length}`} image
{`${imagesToChange.length > 1 ? 's' : ''}`} to board:
</Text>
<IAIMantineSearchableSelect
placeholder={isFetching ? 'Loading...' : 'Select Board'}
disabled={isFetching}
onChange={(v) => setSelectedBoard(v)}
value={selectedBoard}
data={data}
/>
</Flex>
</AlertDialogBody>
<AlertDialogFooter>
<IAIButton ref={cancelRef} onClick={handleClose}>
Cancel
</IAIButton>
<IAIButton colorScheme="accent" onClick={handleChangeBoard} ml={3}>
Move
</IAIButton>
</AlertDialogFooter>
</AlertDialogContent>
</AlertDialogOverlay>
</AlertDialog>
);
};
export default memo(ChangeBoardModal);

View File

@@ -1,6 +0,0 @@
import { ChangeBoardModalState } from './types';
export const initialState: ChangeBoardModalState = {
isModalOpen: false,
imagesToChange: [],
};

View File

@@ -1,25 +0,0 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { ImageDTO } from 'services/api/types';
import { initialState } from './initialState';
const changeBoardModal = createSlice({
name: 'changeBoardModal',
initialState,
reducers: {
isModalOpenChanged: (state, action: PayloadAction<boolean>) => {
state.isModalOpen = action.payload;
},
imagesToChangeSelected: (state, action: PayloadAction<ImageDTO[]>) => {
state.imagesToChange = action.payload;
},
changeBoardReset: (state) => {
state.imagesToChange = [];
state.isModalOpen = false;
},
},
});
export const { isModalOpenChanged, imagesToChangeSelected, changeBoardReset } =
changeBoardModal.actions;
export default changeBoardModal.reducer;

View File

@@ -1,6 +0,0 @@
import { ImageDTO } from 'services/api/types';
export type ChangeBoardModalState = {
isModalOpen: boolean;
imagesToChange: ImageDTO[];
};

Some files were not shown because too many files have changed in this diff Show More