Compare commits
4 Commits
v4.2.9.dev
...
ryan/promp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8c09b345ec | ||
|
|
b7a1086325 | ||
|
|
3062fe2752 | ||
|
|
76a65d30cd |
2
.github/workflows/python-checks.yml
vendored
@@ -62,7 +62,7 @@ jobs:
|
|||||||
|
|
||||||
- name: install ruff
|
- name: install ruff
|
||||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||||
run: pip install ruff==0.6.0
|
run: pip install ruff
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|
||||||
- name: ruff check
|
- name: ruff check
|
||||||
|
|||||||
@@ -55,7 +55,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
|||||||
FROM node:20-slim AS web-builder
|
FROM node:20-slim AS web-builder
|
||||||
ENV PNPM_HOME="/pnpm"
|
ENV PNPM_HOME="/pnpm"
|
||||||
ENV PATH="$PNPM_HOME:$PATH"
|
ENV PATH="$PNPM_HOME:$PATH"
|
||||||
RUN corepack use pnpm@8.x
|
|
||||||
RUN corepack enable
|
RUN corepack enable
|
||||||
|
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
|
|||||||
@@ -1,22 +1,20 @@
|
|||||||
# Invoke in Docker
|
# Invoke in Docker
|
||||||
|
|
||||||
First things first:
|
- Ensure that Docker can use the GPU on your system
|
||||||
|
- This documentation assumes Linux, but should work similarly under Windows with WSL2
|
||||||
- Ensure that Docker can use your [NVIDIA][nvidia docker docs] or [AMD][amd docker docs] GPU.
|
|
||||||
- This document assumes a Linux system, but should work similarly under Windows with WSL2.
|
|
||||||
- We don't recommend running Invoke in Docker on macOS at this time. It works, but very slowly.
|
- We don't recommend running Invoke in Docker on macOS at this time. It works, but very slowly.
|
||||||
|
|
||||||
## Quickstart
|
## Quickstart :lightning:
|
||||||
|
|
||||||
No `docker compose`, no persistence, single command, using the official images:
|
No `docker compose`, no persistence, just a simple one-liner using the official images:
|
||||||
|
|
||||||
**CUDA (NVIDIA GPU):**
|
**CUDA:**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
|
docker run --runtime=nvidia --gpus=all --publish 9090:9090 ghcr.io/invoke-ai/invokeai
|
||||||
```
|
```
|
||||||
|
|
||||||
**ROCm (AMD GPU):**
|
**ROCm:**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invoke-ai/invokeai:main-rocm
|
docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invoke-ai/invokeai:main-rocm
|
||||||
@@ -24,20 +22,12 @@ docker run --device /dev/kfd --device /dev/dri --publish 9090:9090 ghcr.io/invok
|
|||||||
|
|
||||||
Open `http://localhost:9090` in your browser once the container finishes booting, install some models, and generate away!
|
Open `http://localhost:9090` in your browser once the container finishes booting, install some models, and generate away!
|
||||||
|
|
||||||
### Data persistence
|
> [!TIP]
|
||||||
|
> To persist your data (including downloaded models) outside of the container, add a `--volume/-v` flag to the above command, e.g.: `docker run --volume /some/local/path:/invokeai <...the rest of the command>`
|
||||||
To persist your generated images and downloaded models outside of the container, add a `--volume/-v` flag to the above command, e.g.:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker run --volume /some/local/path:/invokeai {...etc...}
|
|
||||||
```
|
|
||||||
|
|
||||||
`/some/local/path/invokeai` will contain all your data.
|
|
||||||
It can *usually* be reused between different installs of Invoke. Tread with caution and read the release notes!
|
|
||||||
|
|
||||||
## Customize the container
|
## Customize the container
|
||||||
|
|
||||||
The included `run.sh` script is a convenience wrapper around `docker compose`. It can be helpful for passing additional build arguments to `docker compose`. Alternatively, the familiar `docker compose` commands work just as well.
|
We ship the `run.sh` script, which is a convenient wrapper around `docker compose` for cases where custom image build args are needed. Alternatively, the familiar `docker compose` commands work just as well.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cd docker
|
cd docker
|
||||||
@@ -48,14 +38,11 @@ cp .env.sample .env
|
|||||||
|
|
||||||
It will take a few minutes to build the image the first time. Once the application starts up, open `http://localhost:9090` in your browser to invoke!
|
It will take a few minutes to build the image the first time. Once the application starts up, open `http://localhost:9090` in your browser to invoke!
|
||||||
|
|
||||||
>[!TIP]
|
|
||||||
>When using the `run.sh` script, the container will continue running after Ctrl+C. To shut it down, use the `docker compose down` command.
|
|
||||||
|
|
||||||
## Docker setup in detail
|
## Docker setup in detail
|
||||||
|
|
||||||
#### Linux
|
#### Linux
|
||||||
|
|
||||||
1. Ensure buildkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
|
1. Ensure builkit is enabled in the Docker daemon settings (`/etc/docker/daemon.json`)
|
||||||
2. Install the `docker compose` plugin using your package manager, or follow a [tutorial](https://docs.docker.com/compose/install/linux/#install-using-the-repository).
|
2. Install the `docker compose` plugin using your package manager, or follow a [tutorial](https://docs.docker.com/compose/install/linux/#install-using-the-repository).
|
||||||
- The deprecated `docker-compose` (hyphenated) CLI probably won't work. Update to a recent version.
|
- The deprecated `docker-compose` (hyphenated) CLI probably won't work. Update to a recent version.
|
||||||
3. Ensure docker daemon is able to access the GPU.
|
3. Ensure docker daemon is able to access the GPU.
|
||||||
@@ -111,7 +98,25 @@ GPU_DRIVER=cuda
|
|||||||
|
|
||||||
Any environment variables supported by InvokeAI can be set here. See the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
|
Any environment variables supported by InvokeAI can be set here. See the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
|
||||||
|
|
||||||
---
|
## Even More Customizing!
|
||||||
|
|
||||||
[nvidia docker docs]: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
|
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.
|
||||||
[amd docker docs]: https://rocm.docs.amd.com/projects/install-on-linux/en/latest/how-to/docker.html
|
|
||||||
|
### Reconfigure the runtime directory
|
||||||
|
|
||||||
|
Can be used to download additional models from the supported model list
|
||||||
|
|
||||||
|
In conjunction with `INVOKEAI_ROOT` can be also used to initialize a runtime directory
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
command:
|
||||||
|
- invokeai-configure
|
||||||
|
- --yes
|
||||||
|
```
|
||||||
|
|
||||||
|
Or install models:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
command:
|
||||||
|
- invokeai-model-install
|
||||||
|
```
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
set -eu
|
set -eu
|
||||||
|
|
||||||
# Ensure we're in the correct folder in case user's CWD is somewhere else
|
# Ensure we're in the correct folder in case user's CWD is somewhere else
|
||||||
scriptdir=$(dirname $(readlink -f "$0"))
|
scriptdir=$(dirname "$0")
|
||||||
cd "$scriptdir"
|
cd "$scriptdir"
|
||||||
|
|
||||||
. .venv/bin/activate
|
. .venv/bin/activate
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -32,8 +31,6 @@ from invokeai.app.services.session_processor.session_processor_default import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||||
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
|
|
||||||
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
|
|
||||||
from invokeai.app.services.urls.urls_default import LocalUrlService
|
from invokeai.app.services.urls.urls_default import LocalUrlService
|
||||||
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||||
@@ -66,12 +63,7 @@ class ApiDependencies:
|
|||||||
invoker: Invoker
|
invoker: Invoker
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def initialize(
|
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
|
||||||
config: InvokeAIAppConfig,
|
|
||||||
event_handler_id: int,
|
|
||||||
loop: asyncio.AbstractEventLoop,
|
|
||||||
logger: Logger = logger,
|
|
||||||
) -> None:
|
|
||||||
logger.info(f"InvokeAI version {__version__}")
|
logger.info(f"InvokeAI version {__version__}")
|
||||||
logger.info(f"Root directory = {str(config.root_path)}")
|
logger.info(f"Root directory = {str(config.root_path)}")
|
||||||
|
|
||||||
@@ -82,7 +74,6 @@ class ApiDependencies:
|
|||||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
|
|
||||||
model_images_folder = config.models_path
|
model_images_folder = config.models_path
|
||||||
style_presets_folder = config.style_presets_path
|
|
||||||
|
|
||||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||||
|
|
||||||
@@ -93,7 +84,7 @@ class ApiDependencies:
|
|||||||
board_images = BoardImagesService()
|
board_images = BoardImagesService()
|
||||||
board_records = SqliteBoardRecordStorage(db=db)
|
board_records = SqliteBoardRecordStorage(db=db)
|
||||||
boards = BoardService()
|
boards = BoardService()
|
||||||
events = FastAPIEventService(event_handler_id, loop=loop)
|
events = FastAPIEventService(event_handler_id)
|
||||||
bulk_download = BulkDownloadService()
|
bulk_download = BulkDownloadService()
|
||||||
image_records = SqliteImageRecordStorage(db=db)
|
image_records = SqliteImageRecordStorage(db=db)
|
||||||
images = ImageService()
|
images = ImageService()
|
||||||
@@ -118,8 +109,6 @@ class ApiDependencies:
|
|||||||
session_queue = SqliteSessionQueue(db=db)
|
session_queue = SqliteSessionQueue(db=db)
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||||
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
|
|
||||||
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
|
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
board_image_records=board_image_records,
|
board_image_records=board_image_records,
|
||||||
@@ -145,8 +134,6 @@ class ApiDependencies:
|
|||||||
workflow_records=workflow_records,
|
workflow_records=workflow_records,
|
||||||
tensors=tensors,
|
tensors=tensors,
|
||||||
conditioning=conditioning,
|
conditioning=conditioning,
|
||||||
style_preset_records=style_preset_records,
|
|
||||||
style_preset_image_files=style_preset_image_files,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ApiDependencies.invoker = Invoker(services)
|
ApiDependencies.invoker = Invoker(services)
|
||||||
|
|||||||
@@ -218,8 +218,9 @@ async def get_image_workflow(
|
|||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.api_route(
|
||||||
"/i/{image_name}/full",
|
"/i/{image_name}/full",
|
||||||
|
methods=["GET", "HEAD"],
|
||||||
operation_id="get_image_full",
|
operation_id="get_image_full",
|
||||||
response_class=Response,
|
response_class=Response,
|
||||||
responses={
|
responses={
|
||||||
@@ -230,18 +231,6 @@ async def get_image_workflow(
|
|||||||
404: {"description": "Image not found"},
|
404: {"description": "Image not found"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@images_router.head(
|
|
||||||
"/i/{image_name}/full",
|
|
||||||
operation_id="get_image_full_head",
|
|
||||||
response_class=Response,
|
|
||||||
responses={
|
|
||||||
200: {
|
|
||||||
"description": "Return the full-resolution image",
|
|
||||||
"content": {"image/png": {}},
|
|
||||||
},
|
|
||||||
404: {"description": "Image not found"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
async def get_image_full(
|
async def get_image_full(
|
||||||
image_name: str = Path(description="The name of full-resolution image file to get"),
|
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
@@ -253,7 +242,6 @@ async def get_image_full(
|
|||||||
content = f.read()
|
content = f.read()
|
||||||
response = Response(content, media_type="image/png")
|
response = Response(content, media_type="image/png")
|
||||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||||
response.headers["Content-Disposition"] = f'inline; filename="{image_name}"'
|
|
||||||
return response
|
return response
|
||||||
except Exception:
|
except Exception:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import pathlib
|
|||||||
import traceback
|
import traceback
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import List, Optional, Type
|
from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
from fastapi import Body, Path, Query, Response, UploadFile
|
from fastapi import Body, Path, Query, Response, UploadFile
|
||||||
from fastapi.responses import FileResponse, HTMLResponse
|
from fastapi.responses import FileResponse, HTMLResponse
|
||||||
@@ -430,11 +430,13 @@ async def delete_model_image(
|
|||||||
async def install_model(
|
async def install_model(
|
||||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||||
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
|
||||||
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
|
# TODO(MM2): Can we type this?
|
||||||
config: ModelRecordChanges = Body(
|
config: Optional[Dict[str, Any]] = Body(
|
||||||
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||||
|
default=None,
|
||||||
example={"name": "string", "description": "string"},
|
example={"name": "string", "description": "string"},
|
||||||
),
|
),
|
||||||
|
access_token: Optional[str] = None,
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
"""Install a model using a string identifier.
|
"""Install a model using a string identifier.
|
||||||
|
|
||||||
@@ -449,9 +451,8 @@ async def install_model(
|
|||||||
- model/name:fp16:path/to/model.safetensors
|
- model/name:fp16:path/to/model.safetensors
|
||||||
- model/name::path/to/model.safetensors
|
- model/name::path/to/model.safetensors
|
||||||
|
|
||||||
`config` is a ModelRecordChanges object. Fields in this object will override
|
`config` is an optional dict containing model configuration values that will override
|
||||||
the ones that are probed automatically. Pass an empty object to accept
|
the ones that are probed automatically.
|
||||||
all the defaults.
|
|
||||||
|
|
||||||
`access_token` is an optional access token for use with Urls that require
|
`access_token` is an optional access token for use with Urls that require
|
||||||
authentication.
|
authentication.
|
||||||
@@ -736,7 +737,7 @@ async def convert_model(
|
|||||||
# write the converted file to the convert path
|
# write the converted file to the convert path
|
||||||
raw_model = converted_model.model
|
raw_model = converted_model.model
|
||||||
assert hasattr(raw_model, "save_pretrained")
|
assert hasattr(raw_model, "save_pretrained")
|
||||||
raw_model.save_pretrained(convert_path) # type: ignore
|
raw_model.save_pretrained(convert_path)
|
||||||
assert convert_path.exists()
|
assert convert_path.exists()
|
||||||
|
|
||||||
# temporarily rename the original safetensors file so that there is no naming conflict
|
# temporarily rename the original safetensors file so that there is no naming conflict
|
||||||
@@ -749,12 +750,12 @@ async def convert_model(
|
|||||||
try:
|
try:
|
||||||
new_key = installer.install_path(
|
new_key = installer.install_path(
|
||||||
convert_path,
|
convert_path,
|
||||||
config=ModelRecordChanges(
|
config={
|
||||||
name=original_name,
|
"name": original_name,
|
||||||
description=model_config.description,
|
"description": model_config.description,
|
||||||
hash=model_config.hash,
|
"hash": model_config.hash,
|
||||||
source=model_config.source,
|
"source": model_config.source,
|
||||||
),
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
Batch,
|
Batch,
|
||||||
BatchStatus,
|
BatchStatus,
|
||||||
CancelByBatchIDsResult,
|
CancelByBatchIDsResult,
|
||||||
CancelByOriginResult,
|
|
||||||
ClearResult,
|
ClearResult,
|
||||||
EnqueueBatchResult,
|
EnqueueBatchResult,
|
||||||
PruneResult,
|
PruneResult,
|
||||||
@@ -106,19 +105,6 @@ async def cancel_by_batch_ids(
|
|||||||
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
|
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
|
||||||
|
|
||||||
|
|
||||||
@session_queue_router.put(
|
|
||||||
"/{queue_id}/cancel_by_origin",
|
|
||||||
operation_id="cancel_by_origin",
|
|
||||||
responses={200: {"model": CancelByBatchIDsResult}},
|
|
||||||
)
|
|
||||||
async def cancel_by_origin(
|
|
||||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
|
||||||
origin: str = Query(description="The origin to cancel all queue items for"),
|
|
||||||
) -> CancelByOriginResult:
|
|
||||||
"""Immediately cancels all queue items with the given origin"""
|
|
||||||
return ApiDependencies.invoker.services.session_queue.cancel_by_origin(queue_id=queue_id, origin=origin)
|
|
||||||
|
|
||||||
|
|
||||||
@session_queue_router.put(
|
@session_queue_router.put(
|
||||||
"/{queue_id}/clear",
|
"/{queue_id}/clear",
|
||||||
operation_id="clear",
|
operation_id="clear",
|
||||||
|
|||||||
@@ -1,274 +0,0 @@
|
|||||||
import csv
|
|
||||||
import io
|
|
||||||
import json
|
|
||||||
import traceback
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import pydantic
|
|
||||||
from fastapi import APIRouter, File, Form, HTTPException, Path, Response, UploadFile
|
|
||||||
from fastapi.responses import FileResponse
|
|
||||||
from PIL import Image
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from invokeai.app.api.dependencies import ApiDependencies
|
|
||||||
from invokeai.app.api.routers.model_manager import IMAGE_MAX_AGE
|
|
||||||
from invokeai.app.services.style_preset_images.style_preset_images_common import StylePresetImageFileNotFoundException
|
|
||||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
|
||||||
InvalidPresetImportDataError,
|
|
||||||
PresetData,
|
|
||||||
PresetType,
|
|
||||||
StylePresetChanges,
|
|
||||||
StylePresetNotFoundError,
|
|
||||||
StylePresetRecordWithImage,
|
|
||||||
StylePresetWithoutId,
|
|
||||||
UnsupportedFileTypeError,
|
|
||||||
parse_presets_from_file,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StylePresetFormData(BaseModel):
|
|
||||||
name: str = Field(description="Preset name")
|
|
||||||
positive_prompt: str = Field(description="Positive prompt")
|
|
||||||
negative_prompt: str = Field(description="Negative prompt")
|
|
||||||
type: PresetType = Field(description="Preset type")
|
|
||||||
|
|
||||||
|
|
||||||
style_presets_router = APIRouter(prefix="/v1/style_presets", tags=["style_presets"])
|
|
||||||
|
|
||||||
|
|
||||||
@style_presets_router.get(
|
|
||||||
"/i/{style_preset_id}",
|
|
||||||
operation_id="get_style_preset",
|
|
||||||
responses={
|
|
||||||
200: {"model": StylePresetRecordWithImage},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
async def get_style_preset(
|
|
||||||
style_preset_id: str = Path(description="The style preset to get"),
|
|
||||||
) -> StylePresetRecordWithImage:
|
|
||||||
"""Gets a style preset"""
|
|
||||||
try:
|
|
||||||
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
|
|
||||||
style_preset = ApiDependencies.invoker.services.style_preset_records.get(style_preset_id)
|
|
||||||
return StylePresetRecordWithImage(image=image, **style_preset.model_dump())
|
|
||||||
except StylePresetNotFoundError:
|
|
||||||
raise HTTPException(status_code=404, detail="Style preset not found")
|
|
||||||
|
|
||||||
|
|
||||||
@style_presets_router.patch(
|
|
||||||
"/i/{style_preset_id}",
|
|
||||||
operation_id="update_style_preset",
|
|
||||||
responses={
|
|
||||||
200: {"model": StylePresetRecordWithImage},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
async def update_style_preset(
|
|
||||||
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
|
|
||||||
style_preset_id: str = Path(description="The id of the style preset to update"),
|
|
||||||
data: str = Form(description="The data of the style preset to update"),
|
|
||||||
) -> StylePresetRecordWithImage:
|
|
||||||
"""Updates a style preset"""
|
|
||||||
if image is not None:
|
|
||||||
if not image.content_type or not image.content_type.startswith("image"):
|
|
||||||
raise HTTPException(status_code=415, detail="Not an image")
|
|
||||||
|
|
||||||
contents = await image.read()
|
|
||||||
try:
|
|
||||||
pil_image = Image.open(io.BytesIO(contents))
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
|
||||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
|
||||||
|
|
||||||
try:
|
|
||||||
ApiDependencies.invoker.services.style_preset_image_files.save(style_preset_id, pil_image)
|
|
||||||
except ValueError as e:
|
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
|
|
||||||
except StylePresetImageFileNotFoundException:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
parsed_data = json.loads(data)
|
|
||||||
validated_data = StylePresetFormData(**parsed_data)
|
|
||||||
|
|
||||||
name = validated_data.name
|
|
||||||
type = validated_data.type
|
|
||||||
positive_prompt = validated_data.positive_prompt
|
|
||||||
negative_prompt = validated_data.negative_prompt
|
|
||||||
|
|
||||||
except pydantic.ValidationError:
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid preset data")
|
|
||||||
|
|
||||||
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
|
|
||||||
changes = StylePresetChanges(name=name, preset_data=preset_data, type=type)
|
|
||||||
|
|
||||||
style_preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(style_preset_id)
|
|
||||||
style_preset = ApiDependencies.invoker.services.style_preset_records.update(
|
|
||||||
style_preset_id=style_preset_id, changes=changes
|
|
||||||
)
|
|
||||||
return StylePresetRecordWithImage(image=style_preset_image, **style_preset.model_dump())
|
|
||||||
|
|
||||||
|
|
||||||
@style_presets_router.delete(
|
|
||||||
"/i/{style_preset_id}",
|
|
||||||
operation_id="delete_style_preset",
|
|
||||||
)
|
|
||||||
async def delete_style_preset(
|
|
||||||
style_preset_id: str = Path(description="The style preset to delete"),
|
|
||||||
) -> None:
|
|
||||||
"""Deletes a style preset"""
|
|
||||||
try:
|
|
||||||
ApiDependencies.invoker.services.style_preset_image_files.delete(style_preset_id)
|
|
||||||
except StylePresetImageFileNotFoundException:
|
|
||||||
pass
|
|
||||||
|
|
||||||
ApiDependencies.invoker.services.style_preset_records.delete(style_preset_id)
|
|
||||||
|
|
||||||
|
|
||||||
@style_presets_router.post(
|
|
||||||
"/",
|
|
||||||
operation_id="create_style_preset",
|
|
||||||
responses={
|
|
||||||
200: {"model": StylePresetRecordWithImage},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
async def create_style_preset(
|
|
||||||
image: Optional[UploadFile] = File(description="The image file to upload", default=None),
|
|
||||||
data: str = Form(description="The data of the style preset to create"),
|
|
||||||
) -> StylePresetRecordWithImage:
|
|
||||||
"""Creates a style preset"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
parsed_data = json.loads(data)
|
|
||||||
validated_data = StylePresetFormData(**parsed_data)
|
|
||||||
|
|
||||||
name = validated_data.name
|
|
||||||
type = validated_data.type
|
|
||||||
positive_prompt = validated_data.positive_prompt
|
|
||||||
negative_prompt = validated_data.negative_prompt
|
|
||||||
|
|
||||||
except pydantic.ValidationError:
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid preset data")
|
|
||||||
|
|
||||||
preset_data = PresetData(positive_prompt=positive_prompt, negative_prompt=negative_prompt)
|
|
||||||
style_preset = StylePresetWithoutId(name=name, preset_data=preset_data, type=type)
|
|
||||||
new_style_preset = ApiDependencies.invoker.services.style_preset_records.create(style_preset=style_preset)
|
|
||||||
|
|
||||||
if image is not None:
|
|
||||||
if not image.content_type or not image.content_type.startswith("image"):
|
|
||||||
raise HTTPException(status_code=415, detail="Not an image")
|
|
||||||
|
|
||||||
contents = await image.read()
|
|
||||||
try:
|
|
||||||
pil_image = Image.open(io.BytesIO(contents))
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
|
||||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
|
||||||
|
|
||||||
try:
|
|
||||||
ApiDependencies.invoker.services.style_preset_image_files.save(new_style_preset.id, pil_image)
|
|
||||||
except ValueError as e:
|
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
|
||||||
|
|
||||||
preset_image = ApiDependencies.invoker.services.style_preset_image_files.get_url(new_style_preset.id)
|
|
||||||
return StylePresetRecordWithImage(image=preset_image, **new_style_preset.model_dump())
|
|
||||||
|
|
||||||
|
|
||||||
@style_presets_router.get(
|
|
||||||
"/",
|
|
||||||
operation_id="list_style_presets",
|
|
||||||
responses={
|
|
||||||
200: {"model": list[StylePresetRecordWithImage]},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
async def list_style_presets() -> list[StylePresetRecordWithImage]:
|
|
||||||
"""Gets a page of style presets"""
|
|
||||||
style_presets_with_image: list[StylePresetRecordWithImage] = []
|
|
||||||
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many()
|
|
||||||
for preset in style_presets:
|
|
||||||
image = ApiDependencies.invoker.services.style_preset_image_files.get_url(preset.id)
|
|
||||||
style_preset_with_image = StylePresetRecordWithImage(image=image, **preset.model_dump())
|
|
||||||
style_presets_with_image.append(style_preset_with_image)
|
|
||||||
|
|
||||||
return style_presets_with_image
|
|
||||||
|
|
||||||
|
|
||||||
@style_presets_router.get(
|
|
||||||
"/i/{style_preset_id}/image",
|
|
||||||
operation_id="get_style_preset_image",
|
|
||||||
responses={
|
|
||||||
200: {
|
|
||||||
"description": "The style preset image was fetched successfully",
|
|
||||||
},
|
|
||||||
400: {"description": "Bad request"},
|
|
||||||
404: {"description": "The style preset image could not be found"},
|
|
||||||
},
|
|
||||||
status_code=200,
|
|
||||||
)
|
|
||||||
async def get_style_preset_image(
|
|
||||||
style_preset_id: str = Path(description="The id of the style preset image to get"),
|
|
||||||
) -> FileResponse:
|
|
||||||
"""Gets an image file that previews the model"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
path = ApiDependencies.invoker.services.style_preset_image_files.get_path(style_preset_id)
|
|
||||||
|
|
||||||
response = FileResponse(
|
|
||||||
path,
|
|
||||||
media_type="image/png",
|
|
||||||
filename=style_preset_id + ".png",
|
|
||||||
content_disposition_type="inline",
|
|
||||||
)
|
|
||||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
|
||||||
return response
|
|
||||||
except Exception:
|
|
||||||
raise HTTPException(status_code=404)
|
|
||||||
|
|
||||||
|
|
||||||
@style_presets_router.get(
|
|
||||||
"/export",
|
|
||||||
operation_id="export_style_presets",
|
|
||||||
responses={200: {"content": {"text/csv": {}}, "description": "A CSV file with the requested data."}},
|
|
||||||
status_code=200,
|
|
||||||
)
|
|
||||||
async def export_style_presets():
|
|
||||||
# Create an in-memory stream to store the CSV data
|
|
||||||
output = io.StringIO()
|
|
||||||
writer = csv.writer(output)
|
|
||||||
|
|
||||||
# Write the header
|
|
||||||
writer.writerow(["name", "prompt", "negative_prompt"])
|
|
||||||
|
|
||||||
style_presets = ApiDependencies.invoker.services.style_preset_records.get_many(type=PresetType.User)
|
|
||||||
|
|
||||||
for preset in style_presets:
|
|
||||||
writer.writerow([preset.name, preset.preset_data.positive_prompt, preset.preset_data.negative_prompt])
|
|
||||||
|
|
||||||
csv_data = output.getvalue()
|
|
||||||
output.close()
|
|
||||||
|
|
||||||
return Response(
|
|
||||||
content=csv_data,
|
|
||||||
media_type="text/csv",
|
|
||||||
headers={"Content-Disposition": "attachment; filename=prompt_templates.csv"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@style_presets_router.post(
|
|
||||||
"/import",
|
|
||||||
operation_id="import_style_presets",
|
|
||||||
)
|
|
||||||
async def import_style_presets(file: UploadFile = File(description="The file to import")):
|
|
||||||
try:
|
|
||||||
style_presets = await parse_presets_from_file(file)
|
|
||||||
ApiDependencies.invoker.services.style_preset_records.create_many(style_presets)
|
|
||||||
except InvalidPresetImportDataError as e:
|
|
||||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
||||||
except UnsupportedFileTypeError as e:
|
|
||||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
|
||||||
raise HTTPException(status_code=415, detail=str(e))
|
|
||||||
@@ -30,7 +30,6 @@ from invokeai.app.api.routers import (
|
|||||||
images,
|
images,
|
||||||
model_manager,
|
model_manager,
|
||||||
session_queue,
|
session_queue,
|
||||||
style_presets,
|
|
||||||
utilities,
|
utilities,
|
||||||
workflows,
|
workflows,
|
||||||
)
|
)
|
||||||
@@ -56,13 +55,11 @@ mimetypes.add_type("text/css", ".css")
|
|||||||
torch_device_name = TorchDevice.get_torch_device_name()
|
torch_device_name = TorchDevice.get_torch_device_name()
|
||||||
logger.info(f"Using torch device: {torch_device_name}")
|
logger.info(f"Using torch device: {torch_device_name}")
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Add startup event to load dependencies
|
# Add startup event to load dependencies
|
||||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
|
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
||||||
yield
|
yield
|
||||||
# Shut down threads
|
# Shut down threads
|
||||||
ApiDependencies.shutdown()
|
ApiDependencies.shutdown()
|
||||||
@@ -109,7 +106,6 @@ app.include_router(board_images.board_images_router, prefix="/api")
|
|||||||
app.include_router(app_info.app_router, prefix="/api")
|
app.include_router(app_info.app_router, prefix="/api")
|
||||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||||
app.include_router(workflows.workflows_router, prefix="/api")
|
app.include_router(workflows.workflows_router, prefix="/api")
|
||||||
app.include_router(style_presets.style_presets_router, prefix="/api")
|
|
||||||
|
|
||||||
app.openapi = get_openapi_func(app)
|
app.openapi = get_openapi_func(app)
|
||||||
|
|
||||||
@@ -188,6 +184,8 @@ def invoke_api() -> None:
|
|||||||
|
|
||||||
check_cudnn(logger)
|
check_cudnn(logger)
|
||||||
|
|
||||||
|
# Start our own event loop for eventing usage
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
config = uvicorn.Config(
|
config = uvicorn.Config(
|
||||||
app=app,
|
app=app,
|
||||||
host=app_config.host,
|
host=app_config.host,
|
||||||
|
|||||||
@@ -80,12 +80,12 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
# apply all patches while the model is on the target device
|
# apply all patches while the model is on the target device
|
||||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
|
||||||
tokenizer_info as tokenizer,
|
tokenizer_info as tokenizer,
|
||||||
ModelPatcher.apply_lora_text_encoder(
|
ModelPatcher.apply_lora_text_encoder(
|
||||||
text_encoder,
|
text_encoder,
|
||||||
loras=_lora_loader(),
|
loras=_lora_loader(),
|
||||||
cached_weights=cached_weights,
|
model_state_dict=model_state_dict,
|
||||||
),
|
),
|
||||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||||
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
|
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
|
||||||
@@ -175,13 +175,13 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
# apply all patches while the model is on the target device
|
# apply all patches while the model is on the target device
|
||||||
text_encoder_info.model_on_device() as (cached_weights, text_encoder),
|
text_encoder_info.model_on_device() as (state_dict, text_encoder),
|
||||||
tokenizer_info as tokenizer,
|
tokenizer_info as tokenizer,
|
||||||
ModelPatcher.apply_lora(
|
ModelPatcher.apply_lora(
|
||||||
text_encoder,
|
text_encoder,
|
||||||
loras=_lora_loader(),
|
loras=_lora_loader(),
|
||||||
prefix=lora_prefix,
|
prefix=lora_prefix,
|
||||||
cached_weights=cached_weights,
|
model_state_dict=state_dict,
|
||||||
),
|
),
|
||||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||||
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
|
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
|
||||||
|
|||||||
@@ -21,8 +21,6 @@ from controlnet_aux import (
|
|||||||
from controlnet_aux.util import HWC3, ade_palette
|
from controlnet_aux.util import HWC3, ade_palette
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
from transformers import pipeline
|
|
||||||
from transformers.pipelines import DepthEstimationPipeline
|
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@@ -46,12 +44,13 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
|
|||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||||
from invokeai.backend.image_util.canny import get_canny_edges
|
from invokeai.backend.image_util.canny import get_canny_edges
|
||||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
|
||||||
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
||||||
from invokeai.backend.image_util.hed import HEDProcessor
|
from invokeai.backend.image_util.hed import HEDProcessor
|
||||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||||
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
|
|
||||||
class ControlField(BaseModel):
|
class ControlField(BaseModel):
|
||||||
@@ -593,14 +592,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
return color_map
|
return color_map
|
||||||
|
|
||||||
|
|
||||||
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
|
DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
|
||||||
# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
|
|
||||||
DEPTH_ANYTHING_MODELS = {
|
|
||||||
"large": "LiheYoung/depth-anything-large-hf",
|
|
||||||
"base": "LiheYoung/depth-anything-base-hf",
|
|
||||||
"small": "LiheYoung/depth-anything-small-hf",
|
|
||||||
"small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@@ -608,33 +600,28 @@ DEPTH_ANYTHING_MODELS = {
|
|||||||
title="Depth Anything Processor",
|
title="Depth Anything Processor",
|
||||||
tags=["controlnet", "depth", "depth anything"],
|
tags=["controlnet", "depth", "depth anything"],
|
||||||
category="controlnet",
|
category="controlnet",
|
||||||
version="1.1.3",
|
version="1.1.2",
|
||||||
)
|
)
|
||||||
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||||
"""Generates a depth map based on the Depth Anything algorithm"""
|
"""Generates a depth map based on the Depth Anything algorithm"""
|
||||||
|
|
||||||
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
|
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
|
||||||
default="small_v2", description="The size of the depth model to use"
|
default="small", description="The size of the depth model to use"
|
||||||
)
|
)
|
||||||
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||||
|
|
||||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
def load_depth_anything(model_path: Path):
|
def loader(model_path: Path):
|
||||||
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
|
return DepthAnythingDetector.load_model(
|
||||||
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
|
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
|
||||||
return DepthAnythingPipeline(depth_anything_pipeline)
|
)
|
||||||
|
|
||||||
with self._context.models.load_remote_model(
|
with self._context.models.load_remote_model(
|
||||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
|
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
|
||||||
) as depth_anything_detector:
|
) as model:
|
||||||
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
|
||||||
depth_map = depth_anything_detector.generate_depth(image)
|
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
||||||
|
return processed_image
|
||||||
# Resizing to user target specified size
|
|
||||||
new_height = int(image.size[1] * (self.resolution / image.size[0]))
|
|
||||||
depth_map = depth_map.resize((self.resolution, new_height))
|
|
||||||
|
|
||||||
return depth_map
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ class GradientMaskOutput(BaseInvocationOutput):
|
|||||||
title="Create Gradient Mask",
|
title="Create Gradient Mask",
|
||||||
tags=["mask", "denoise"],
|
tags=["mask", "denoise"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.2.0",
|
version="1.1.0",
|
||||||
)
|
)
|
||||||
class CreateGradientMaskInvocation(BaseInvocation):
|
class CreateGradientMaskInvocation(BaseInvocation):
|
||||||
"""Creates mask for denoising model run."""
|
"""Creates mask for denoising model run."""
|
||||||
@@ -93,7 +93,6 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
# redistribute blur so that the original edges are 0 and blur outwards to 1
|
||||||
blur_tensor = (blur_tensor - 0.5) * 2
|
blur_tensor = (blur_tensor - 0.5) * 2
|
||||||
blur_tensor[blur_tensor < 0] = 0.0
|
|
||||||
|
|
||||||
threshold = 1 - self.minimum_denoise
|
threshold = 1 - self.minimum_denoise
|
||||||
|
|
||||||
|
|||||||
@@ -37,9 +37,9 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
|||||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
from invokeai.backend.model_manager import BaseModelType
|
||||||
from invokeai.backend.model_patcher import ModelPatcher
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState
|
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
|
||||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||||
ControlNetData,
|
ControlNetData,
|
||||||
@@ -58,15 +58,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|||||||
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
|
||||||
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
|
||||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
|
||||||
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
|
|
||||||
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
|
||||||
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||||
@@ -471,65 +463,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
return controlnet_data
|
return controlnet_data
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse_controlnet_field(
|
|
||||||
exit_stack: ExitStack,
|
|
||||||
context: InvocationContext,
|
|
||||||
control_input: ControlField | list[ControlField] | None,
|
|
||||||
ext_manager: ExtensionsManager,
|
|
||||||
) -> None:
|
|
||||||
# Normalize control_input to a list.
|
|
||||||
control_list: list[ControlField]
|
|
||||||
if isinstance(control_input, ControlField):
|
|
||||||
control_list = [control_input]
|
|
||||||
elif isinstance(control_input, list):
|
|
||||||
control_list = control_input
|
|
||||||
elif control_input is None:
|
|
||||||
control_list = []
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
|
|
||||||
|
|
||||||
for control_info in control_list:
|
|
||||||
model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
|
||||||
ext_manager.add_extension(
|
|
||||||
ControlNetExt(
|
|
||||||
model=model,
|
|
||||||
image=context.images.get_pil(control_info.image.image_name),
|
|
||||||
weight=control_info.control_weight,
|
|
||||||
begin_step_percent=control_info.begin_step_percent,
|
|
||||||
end_step_percent=control_info.end_step_percent,
|
|
||||||
control_mode=control_info.control_mode,
|
|
||||||
resize_mode=control_info.resize_mode,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse_t2i_adapter_field(
|
|
||||||
exit_stack: ExitStack,
|
|
||||||
context: InvocationContext,
|
|
||||||
t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
|
|
||||||
ext_manager: ExtensionsManager,
|
|
||||||
) -> None:
|
|
||||||
if t2i_adapters is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
|
|
||||||
if isinstance(t2i_adapters, T2IAdapterField):
|
|
||||||
t2i_adapters = [t2i_adapters]
|
|
||||||
|
|
||||||
for t2i_adapter_field in t2i_adapters:
|
|
||||||
ext_manager.add_extension(
|
|
||||||
T2IAdapterExt(
|
|
||||||
node_context=context,
|
|
||||||
model_id=t2i_adapter_field.t2i_adapter_model,
|
|
||||||
image=context.images.get_pil(t2i_adapter_field.image.image_name),
|
|
||||||
weight=t2i_adapter_field.weight,
|
|
||||||
begin_step_percent=t2i_adapter_field.begin_step_percent,
|
|
||||||
end_step_percent=t2i_adapter_field.end_step_percent,
|
|
||||||
resize_mode=t2i_adapter_field.resize_mode,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def prep_ip_adapter_image_prompts(
|
def prep_ip_adapter_image_prompts(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
@@ -739,7 +672,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
else:
|
else:
|
||||||
masked_latents = torch.where(mask < 0.5, 0.0, latents)
|
masked_latents = torch.where(mask < 0.5, 0.0, latents)
|
||||||
|
|
||||||
return mask, masked_latents, self.denoise_mask.gradient
|
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare_noise_and_latents(
|
def prepare_noise_and_latents(
|
||||||
@@ -797,6 +730,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
dtype = TorchDevice.choose_torch_dtype()
|
dtype = TorchDevice.choose_torch_dtype()
|
||||||
|
|
||||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||||
|
latents = latents.to(device=device, dtype=dtype)
|
||||||
|
if noise is not None:
|
||||||
|
noise = noise.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
_, _, latent_height, latent_width = latents.shape
|
_, _, latent_height, latent_width = latents.shape
|
||||||
|
|
||||||
conditioning_data = self.get_conditioning_data(
|
conditioning_data = self.get_conditioning_data(
|
||||||
@@ -829,52 +766,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
denoising_end=self.denoising_end,
|
denoising_end=self.denoising_end,
|
||||||
)
|
)
|
||||||
|
|
||||||
# get the unet's config so that we can pass the base to sd_step_callback()
|
|
||||||
unet_config = context.models.get_config(self.unet.unet.key)
|
|
||||||
|
|
||||||
### preview
|
|
||||||
def step_callback(state: PipelineIntermediateState) -> None:
|
|
||||||
context.util.sd_step_callback(state, unet_config.base)
|
|
||||||
|
|
||||||
ext_manager.add_extension(PreviewExt(step_callback))
|
|
||||||
|
|
||||||
### cfg rescale
|
|
||||||
if self.cfg_rescale_multiplier > 0:
|
|
||||||
ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
|
|
||||||
|
|
||||||
### freeu
|
|
||||||
if self.unet.freeu_config:
|
|
||||||
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
|
|
||||||
|
|
||||||
### lora
|
|
||||||
if self.unet.loras:
|
|
||||||
for lora_field in self.unet.loras:
|
|
||||||
ext_manager.add_extension(
|
|
||||||
LoRAExt(
|
|
||||||
node_context=context,
|
|
||||||
model_id=lora_field.lora,
|
|
||||||
weight=lora_field.weight,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
### seamless
|
|
||||||
if self.unet.seamless_axes:
|
|
||||||
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
|
|
||||||
|
|
||||||
### inpaint
|
|
||||||
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
|
|
||||||
# NOTE: We used to identify inpainting models by inpecting the shape of the loaded UNet model weights. Now we
|
|
||||||
# use the ModelVariantType config. During testing, there was a report of a user with models that had an
|
|
||||||
# incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
|
|
||||||
# prevalent, we will have to revisit how we initialize the inpainting extensions.
|
|
||||||
if unet_config.variant == ModelVariantType.Inpaint:
|
|
||||||
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
|
|
||||||
elif mask is not None:
|
|
||||||
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
|
|
||||||
|
|
||||||
# Initialize context for modular denoise
|
|
||||||
latents = latents.to(device=device, dtype=dtype)
|
|
||||||
if noise is not None:
|
|
||||||
noise = noise.to(device=device, dtype=dtype)
|
|
||||||
denoise_ctx = DenoiseContext(
|
denoise_ctx = DenoiseContext(
|
||||||
inputs=DenoiseInputs(
|
inputs=DenoiseInputs(
|
||||||
orig_latents=latents,
|
orig_latents=latents,
|
||||||
@@ -890,31 +781,31 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
# context for loading additional models
|
# get the unet's config so that we can pass the base to sd_step_callback()
|
||||||
with ExitStack() as exit_stack:
|
unet_config = context.models.get_config(self.unet.unet.key)
|
||||||
# later should be smth like:
|
|
||||||
# for extension_field in self.extensions:
|
|
||||||
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
|
|
||||||
# ext_manager.add_extension(ext)
|
|
||||||
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
|
|
||||||
self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
|
|
||||||
|
|
||||||
# ext: t2i/ip adapter
|
### preview
|
||||||
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
def step_callback(state: PipelineIntermediateState) -> None:
|
||||||
|
context.util.sd_step_callback(state, unet_config.base)
|
||||||
|
|
||||||
unet_info = context.models.load(self.unet.unet)
|
ext_manager.add_extension(PreviewExt(step_callback))
|
||||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
|
||||||
with (
|
# ext: t2i/ip adapter
|
||||||
unet_info.model_on_device() as (cached_weights, unet),
|
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
|
||||||
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
|
|
||||||
# ext: controlnet
|
unet_info = context.models.load(self.unet.unet)
|
||||||
ext_manager.patch_extensions(denoise_ctx),
|
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||||
# ext: freeu, seamless, ip adapter, lora
|
with (
|
||||||
ext_manager.patch_unet(unet, cached_weights),
|
unet_info.model_on_device() as (model_state_dict, unet),
|
||||||
):
|
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
|
||||||
sd_backend = StableDiffusionBackend(unet, scheduler)
|
# ext: controlnet
|
||||||
denoise_ctx.unet = unet
|
ext_manager.patch_extensions(unet),
|
||||||
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
|
# ext: freeu, seamless, ip adapter, lora
|
||||||
|
ext_manager.patch_unet(model_state_dict, unet),
|
||||||
|
):
|
||||||
|
sd_backend = StableDiffusionBackend(unet, scheduler)
|
||||||
|
denoise_ctx.unet = unet
|
||||||
|
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
result_latents = result_latents.detach().to("cpu")
|
result_latents = result_latents.detach().to("cpu")
|
||||||
@@ -929,10 +820,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||||
|
|
||||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||||
# At this point, the mask ranges from 0 (leave unchanged) to 1 (inpaint).
|
|
||||||
# We invert the mask here for compatibility with the old backend implementation.
|
|
||||||
if mask is not None:
|
|
||||||
mask = 1 - mask
|
|
||||||
|
|
||||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||||
# below. Investigate whether this is appropriate.
|
# below. Investigate whether this is appropriate.
|
||||||
@@ -975,14 +862,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||||
with (
|
with (
|
||||||
ExitStack() as exit_stack,
|
ExitStack() as exit_stack,
|
||||||
unet_info.model_on_device() as (cached_weights, unet),
|
unet_info.model_on_device() as (model_state_dict, unet),
|
||||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||||
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
|
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
||||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||||
ModelPatcher.apply_lora_unet(
|
ModelPatcher.apply_lora_unet(
|
||||||
unet,
|
unet,
|
||||||
loras=_lora_loader(),
|
loras=_lora_loader(),
|
||||||
cached_weights=cached_weights,
|
model_state_dict=model_state_dict,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
assert isinstance(unet, UNet2DConditionModel)
|
assert isinstance(unet, UNet2DConditionModel)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, Optional, Tuple
|
from typing import Any, Callable, Optional, Tuple
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
|
||||||
from pydantic.fields import _Unset
|
from pydantic.fields import _Unset
|
||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
@@ -242,31 +242,6 @@ class ConditioningField(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BoundingBoxField(BaseModel):
|
|
||||||
"""A bounding box primitive value."""
|
|
||||||
|
|
||||||
x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
|
|
||||||
x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
|
|
||||||
y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
|
|
||||||
y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
|
|
||||||
|
|
||||||
score: Optional[float] = Field(
|
|
||||||
default=None,
|
|
||||||
ge=0.0,
|
|
||||||
le=1.0,
|
|
||||||
description="The score associated with the bounding box. In the range [0, 1]. This value is typically set "
|
|
||||||
"when the bounding box was produced by a detector and has an associated confidence score.",
|
|
||||||
)
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def check_coords(self):
|
|
||||||
if self.x_min > self.x_max:
|
|
||||||
raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
|
|
||||||
if self.y_min > self.y_max:
|
|
||||||
raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataField(RootModel[dict[str, Any]]):
|
class MetadataField(RootModel[dict[str, Any]]):
|
||||||
"""
|
"""
|
||||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||||
|
|||||||
@@ -1,100 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from transformers import pipeline
|
|
||||||
from transformers.pipelines import ZeroShotObjectDetectionPipeline
|
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
|
||||||
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField
|
|
||||||
from invokeai.app.invocations.primitives import BoundingBoxCollectionOutput
|
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
||||||
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
|
|
||||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
|
||||||
|
|
||||||
GroundingDinoModelKey = Literal["grounding-dino-tiny", "grounding-dino-base"]
|
|
||||||
GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = {
|
|
||||||
"grounding-dino-tiny": "IDEA-Research/grounding-dino-tiny",
|
|
||||||
"grounding-dino-base": "IDEA-Research/grounding-dino-base",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
|
||||||
"grounding_dino",
|
|
||||||
title="Grounding DINO (Text Prompt Object Detection)",
|
|
||||||
tags=["prompt", "object detection"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class GroundingDinoInvocation(BaseInvocation):
|
|
||||||
"""Runs a Grounding DINO model. Performs zero-shot bounding-box object detection from a text prompt."""
|
|
||||||
|
|
||||||
# Reference:
|
|
||||||
# - https://arxiv.org/pdf/2303.05499
|
|
||||||
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
|
|
||||||
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
|
|
||||||
|
|
||||||
model: GroundingDinoModelKey = InputField(description="The Grounding DINO model to use.")
|
|
||||||
prompt: str = InputField(description="The prompt describing the object to segment.")
|
|
||||||
image: ImageField = InputField(description="The image to segment.")
|
|
||||||
detection_threshold: float = InputField(
|
|
||||||
description="The detection threshold for the Grounding DINO model. All detected bounding boxes with scores above this threshold will be returned.",
|
|
||||||
ge=0.0,
|
|
||||||
le=1.0,
|
|
||||||
default=0.3,
|
|
||||||
)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def invoke(self, context: InvocationContext) -> BoundingBoxCollectionOutput:
|
|
||||||
# The model expects a 3-channel RGB image.
|
|
||||||
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
|
|
||||||
|
|
||||||
detections = self._detect(
|
|
||||||
context=context, image=image_pil, labels=[self.prompt], threshold=self.detection_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert detections to BoundingBoxCollectionOutput.
|
|
||||||
bounding_boxes: list[BoundingBoxField] = []
|
|
||||||
for detection in detections:
|
|
||||||
bounding_boxes.append(
|
|
||||||
BoundingBoxField(
|
|
||||||
x_min=detection.box.xmin,
|
|
||||||
x_max=detection.box.xmax,
|
|
||||||
y_min=detection.box.ymin,
|
|
||||||
y_max=detection.box.ymax,
|
|
||||||
score=detection.score,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return BoundingBoxCollectionOutput(collection=bounding_boxes)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _load_grounding_dino(model_path: Path):
|
|
||||||
grounding_dino_pipeline = pipeline(
|
|
||||||
model=str(model_path),
|
|
||||||
task="zero-shot-object-detection",
|
|
||||||
local_files_only=True,
|
|
||||||
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
|
||||||
# model, and figure out how to make it work in the pipeline.
|
|
||||||
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
|
||||||
)
|
|
||||||
assert isinstance(grounding_dino_pipeline, ZeroShotObjectDetectionPipeline)
|
|
||||||
return GroundingDinoPipeline(grounding_dino_pipeline)
|
|
||||||
|
|
||||||
def _detect(
|
|
||||||
self,
|
|
||||||
context: InvocationContext,
|
|
||||||
image: Image.Image,
|
|
||||||
labels: list[str],
|
|
||||||
threshold: float = 0.3,
|
|
||||||
) -> list[DetectionResult]:
|
|
||||||
"""Use Grounding DINO to detect bounding boxes for a set of labels in an image."""
|
|
||||||
# TODO(ryand): I copied this "."-handling logic from the transformers example code. Test it and see if it
|
|
||||||
# actually makes a difference.
|
|
||||||
labels = [label if label.endswith(".") else label + "." for label in labels]
|
|
||||||
|
|
||||||
with context.models.load_remote_model(
|
|
||||||
source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
|
|
||||||
) as detector:
|
|
||||||
assert isinstance(detector, GroundingDinoPipeline)
|
|
||||||
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)
|
|
||||||
@@ -6,19 +6,13 @@ import cv2
|
|||||||
import numpy
|
import numpy
|
||||||
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
from PIL import Image, ImageChops, ImageFilter, ImageOps
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||||
BaseInvocation,
|
|
||||||
Classification,
|
|
||||||
invocation,
|
|
||||||
invocation_output,
|
|
||||||
)
|
|
||||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||||
from invokeai.app.invocations.fields import (
|
from invokeai.app.invocations.fields import (
|
||||||
ColorField,
|
ColorField,
|
||||||
FieldDescriptions,
|
FieldDescriptions,
|
||||||
ImageField,
|
ImageField,
|
||||||
InputField,
|
InputField,
|
||||||
OutputField,
|
|
||||||
WithBoard,
|
WithBoard,
|
||||||
WithMetadata,
|
WithMetadata,
|
||||||
)
|
)
|
||||||
@@ -1013,62 +1007,3 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
|
image_dto = context.images.save(image=mask, image_category=ImageCategory.MASK)
|
||||||
|
|
||||||
return ImageOutput.build(image_dto)
|
return ImageOutput.build(image_dto)
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("canvas_v2_mask_and_crop_output")
|
|
||||||
class CanvasV2MaskAndCropOutput(ImageOutput):
|
|
||||||
offset_x: int = OutputField(description="The x offset of the image, after cropping")
|
|
||||||
offset_y: int = OutputField(description="The y offset of the image, after cropping")
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
|
||||||
"canvas_v2_mask_and_crop",
|
|
||||||
title="Canvas V2 Mask and Crop",
|
|
||||||
tags=["image", "mask", "id"],
|
|
||||||
category="image",
|
|
||||||
version="1.0.0",
|
|
||||||
classification=Classification.Prototype,
|
|
||||||
)
|
|
||||||
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|
||||||
"""Handles Canvas V2 image output masking and cropping"""
|
|
||||||
|
|
||||||
source_image: ImageField | None = InputField(
|
|
||||||
default=None,
|
|
||||||
description="The source image onto which the masked generated image is pasted. If omitted, the masked generated image is returned with transparency.",
|
|
||||||
)
|
|
||||||
generated_image: ImageField = InputField(description="The image to apply the mask to")
|
|
||||||
mask: ImageField = InputField(description="The mask to apply")
|
|
||||||
mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
|
|
||||||
|
|
||||||
def _prepare_mask(self, mask: Image.Image) -> Image.Image:
|
|
||||||
mask_array = numpy.array(mask)
|
|
||||||
kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
|
|
||||||
dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
|
|
||||||
dilated_mask = Image.fromarray(dilated_mask_array)
|
|
||||||
if self.mask_blur > 0:
|
|
||||||
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
|
|
||||||
return ImageOps.invert(mask.convert("L"))
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput:
|
|
||||||
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
|
|
||||||
|
|
||||||
if self.source_image:
|
|
||||||
generated_image = context.images.get_pil(self.generated_image.image_name)
|
|
||||||
source_image = context.images.get_pil(self.source_image.image_name)
|
|
||||||
source_image.paste(generated_image, (0, 0), mask)
|
|
||||||
image_dto = context.images.save(image=source_image)
|
|
||||||
else:
|
|
||||||
generated_image = context.images.get_pil(self.generated_image.image_name)
|
|
||||||
generated_image.putalpha(mask)
|
|
||||||
image_dto = context.images.save(image=generated_image)
|
|
||||||
|
|
||||||
# bbox = image.getbbox()
|
|
||||||
# image = image.crop(bbox)
|
|
||||||
|
|
||||||
return CanvasV2MaskAndCropOutput(
|
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
offset_x=0,
|
|
||||||
offset_y=0,
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from invokeai.app.invocations.fields import (
|
|||||||
from invokeai.app.invocations.model import VAEField
|
from invokeai.app.invocations.model import VAEField
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
|
from invokeai.backend.stable_diffusion import set_seamless
|
||||||
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
@@ -59,7 +59,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
|
|
||||||
vae_info = context.models.load(self.vae.vae)
|
vae_info = context.models.load(self.vae.vae)
|
||||||
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
|
||||||
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||||
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
|
||||||
latents = latents.to(vae.device)
|
latents = latents.to(vae.device)
|
||||||
if self.fp32:
|
if self.fp32:
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, InvocationContext, invocation
|
||||||
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithBoard, WithMetadata
|
from invokeai.app.invocations.fields import ImageField, InputField, TensorField, WithMetadata
|
||||||
from invokeai.app.invocations.primitives import ImageOutput, MaskOutput
|
from invokeai.app.invocations.primitives import MaskOutput
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@@ -119,27 +118,3 @@ class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata):
|
|||||||
height=mask.shape[1],
|
height=mask.shape[1],
|
||||||
width=mask.shape[2],
|
width=mask.shape[2],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
|
||||||
"tensor_mask_to_image",
|
|
||||||
title="Tensor Mask to Image",
|
|
||||||
tags=["mask"],
|
|
||||||
category="mask",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class MaskTensorToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|
||||||
"""Convert a mask tensor to an image."""
|
|
||||||
|
|
||||||
mask: TensorField = InputField(description="The mask tensor to convert.")
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
||||||
mask = context.tensors.load(self.mask.tensor_name)
|
|
||||||
# Ensure that the mask is binary.
|
|
||||||
if mask.dtype != torch.bool:
|
|
||||||
mask = mask > 0.5
|
|
||||||
mask_np = (mask.float() * 255).byte().cpu().numpy()
|
|
||||||
|
|
||||||
mask_pil = Image.fromarray(mask_np, mode="L")
|
|
||||||
image_dto = context.images.save(image=mask_pil)
|
|
||||||
return ImageOutput.build(image_dto)
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import torch
|
|||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||||
from invokeai.app.invocations.fields import (
|
from invokeai.app.invocations.fields import (
|
||||||
BoundingBoxField,
|
|
||||||
ColorField,
|
ColorField,
|
||||||
ConditioningField,
|
ConditioningField,
|
||||||
DenoiseMaskField,
|
DenoiseMaskField,
|
||||||
@@ -470,42 +469,3 @@ class ConditioningCollectionInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
# region BoundingBox
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("bounding_box_output")
|
|
||||||
class BoundingBoxOutput(BaseInvocationOutput):
|
|
||||||
"""Base class for nodes that output a single bounding box"""
|
|
||||||
|
|
||||||
bounding_box: BoundingBoxField = OutputField(description="The output bounding box.")
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("bounding_box_collection_output")
|
|
||||||
class BoundingBoxCollectionOutput(BaseInvocationOutput):
|
|
||||||
"""Base class for nodes that output a collection of bounding boxes"""
|
|
||||||
|
|
||||||
collection: list[BoundingBoxField] = OutputField(description="The output bounding boxes.", title="Bounding Boxes")
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
|
||||||
"bounding_box",
|
|
||||||
title="Bounding Box",
|
|
||||||
tags=["primitives", "segmentation", "collection", "bounding box"],
|
|
||||||
category="primitives",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class BoundingBoxInvocation(BaseInvocation):
|
|
||||||
"""Create a bounding box manually by supplying box coordinates"""
|
|
||||||
|
|
||||||
x_min: int = InputField(default=0, description="x-coordinate of the bounding box's top left vertex")
|
|
||||||
y_min: int = InputField(default=0, description="y-coordinate of the bounding box's top left vertex")
|
|
||||||
x_max: int = InputField(default=0, description="x-coordinate of the bounding box's bottom right vertex")
|
|
||||||
y_max: int = InputField(default=0, description="y-coordinate of the bounding box's bottom right vertex")
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> BoundingBoxOutput:
|
|
||||||
bounding_box = BoundingBoxField(x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max)
|
|
||||||
return BoundingBoxOutput(bounding_box=bounding_box)
|
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
|
||||||
|
|||||||
@@ -1,161 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from transformers import AutoModelForMaskGeneration, AutoProcessor
|
|
||||||
from transformers.models.sam import SamModel
|
|
||||||
from transformers.models.sam.processing_sam import SamProcessor
|
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
|
||||||
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
|
|
||||||
from invokeai.app.invocations.primitives import MaskOutput
|
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
||||||
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
|
|
||||||
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
|
||||||
|
|
||||||
SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
|
|
||||||
SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
|
|
||||||
"segment-anything-base": "facebook/sam-vit-base",
|
|
||||||
"segment-anything-large": "facebook/sam-vit-large",
|
|
||||||
"segment-anything-huge": "facebook/sam-vit-huge",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
|
||||||
"segment_anything",
|
|
||||||
title="Segment Anything",
|
|
||||||
tags=["prompt", "segmentation"],
|
|
||||||
category="segmentation",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class SegmentAnythingInvocation(BaseInvocation):
|
|
||||||
"""Runs a Segment Anything Model."""
|
|
||||||
|
|
||||||
# Reference:
|
|
||||||
# - https://arxiv.org/pdf/2304.02643
|
|
||||||
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
|
|
||||||
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
|
|
||||||
|
|
||||||
model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
|
|
||||||
image: ImageField = InputField(description="The image to segment.")
|
|
||||||
bounding_boxes: list[BoundingBoxField] = InputField(description="The bounding boxes to prompt the SAM model with.")
|
|
||||||
apply_polygon_refinement: bool = InputField(
|
|
||||||
description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
|
|
||||||
default=True,
|
|
||||||
)
|
|
||||||
mask_filter: Literal["all", "largest", "highest_box_score"] = InputField(
|
|
||||||
description="The filtering to apply to the detected masks before merging them into a final output.",
|
|
||||||
default="all",
|
|
||||||
)
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
|
||||||
# The models expect a 3-channel RGB image.
|
|
||||||
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
|
|
||||||
|
|
||||||
if len(self.bounding_boxes) == 0:
|
|
||||||
combined_mask = torch.zeros(image_pil.size[::-1], dtype=torch.bool)
|
|
||||||
else:
|
|
||||||
masks = self._segment(context=context, image=image_pil)
|
|
||||||
masks = self._filter_masks(masks=masks, bounding_boxes=self.bounding_boxes)
|
|
||||||
|
|
||||||
# masks contains bool values, so we merge them via max-reduce.
|
|
||||||
combined_mask, _ = torch.stack(masks).max(dim=0)
|
|
||||||
|
|
||||||
mask_tensor_name = context.tensors.save(combined_mask)
|
|
||||||
height, width = combined_mask.shape
|
|
||||||
return MaskOutput(mask=TensorField(tensor_name=mask_tensor_name), width=width, height=height)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _load_sam_model(model_path: Path):
|
|
||||||
sam_model = AutoModelForMaskGeneration.from_pretrained(
|
|
||||||
model_path,
|
|
||||||
local_files_only=True,
|
|
||||||
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
|
|
||||||
# model, and figure out how to make it work in the pipeline.
|
|
||||||
# torch_dtype=TorchDevice.choose_torch_dtype(),
|
|
||||||
)
|
|
||||||
assert isinstance(sam_model, SamModel)
|
|
||||||
|
|
||||||
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
|
|
||||||
assert isinstance(sam_processor, SamProcessor)
|
|
||||||
return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
|
|
||||||
|
|
||||||
def _segment(
|
|
||||||
self,
|
|
||||||
context: InvocationContext,
|
|
||||||
image: Image.Image,
|
|
||||||
) -> list[torch.Tensor]:
|
|
||||||
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
|
|
||||||
# Convert the bounding boxes to the SAM input format.
|
|
||||||
sam_bounding_boxes = [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes]
|
|
||||||
|
|
||||||
with (
|
|
||||||
context.models.load_remote_model(
|
|
||||||
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
|
|
||||||
) as sam_pipeline,
|
|
||||||
):
|
|
||||||
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
|
|
||||||
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes)
|
|
||||||
|
|
||||||
masks = self._process_masks(masks)
|
|
||||||
if self.apply_polygon_refinement:
|
|
||||||
masks = self._apply_polygon_refinement(masks)
|
|
||||||
|
|
||||||
return masks
|
|
||||||
|
|
||||||
def _process_masks(self, masks: torch.Tensor) -> list[torch.Tensor]:
|
|
||||||
"""Convert the tensor output from the Segment Anything model from a tensor of shape
|
|
||||||
[num_masks, channels, height, width] to a list of tensors of shape [height, width].
|
|
||||||
"""
|
|
||||||
assert masks.dtype == torch.bool
|
|
||||||
# [num_masks, channels, height, width] -> [num_masks, height, width]
|
|
||||||
masks, _ = masks.max(dim=1)
|
|
||||||
# Split the first dimension into a list of masks.
|
|
||||||
return list(masks.cpu().unbind(dim=0))
|
|
||||||
|
|
||||||
def _apply_polygon_refinement(self, masks: list[torch.Tensor]) -> list[torch.Tensor]:
|
|
||||||
"""Apply polygon refinement to the masks.
|
|
||||||
|
|
||||||
Convert each mask to a polygon, then back to a mask. This has the following effect:
|
|
||||||
- Smooth the edges of the mask slightly.
|
|
||||||
- Ensure that each mask consists of a single closed polygon
|
|
||||||
- Removes small mask pieces.
|
|
||||||
- Removes holes from the mask.
|
|
||||||
"""
|
|
||||||
# Convert tensor masks to np masks.
|
|
||||||
np_masks = [mask.cpu().numpy().astype(np.uint8) for mask in masks]
|
|
||||||
|
|
||||||
# Apply polygon refinement.
|
|
||||||
for idx, mask in enumerate(np_masks):
|
|
||||||
shape = mask.shape
|
|
||||||
assert len(shape) == 2 # Assert length to satisfy type checker.
|
|
||||||
polygon = mask_to_polygon(mask)
|
|
||||||
mask = polygon_to_mask(polygon, shape)
|
|
||||||
np_masks[idx] = mask
|
|
||||||
|
|
||||||
# Convert np masks back to tensor masks.
|
|
||||||
masks = [torch.tensor(mask, dtype=torch.bool) for mask in np_masks]
|
|
||||||
|
|
||||||
return masks
|
|
||||||
|
|
||||||
def _filter_masks(self, masks: list[torch.Tensor], bounding_boxes: list[BoundingBoxField]) -> list[torch.Tensor]:
|
|
||||||
"""Filter the detected masks based on the specified mask filter."""
|
|
||||||
assert len(masks) == len(bounding_boxes)
|
|
||||||
|
|
||||||
if self.mask_filter == "all":
|
|
||||||
return masks
|
|
||||||
elif self.mask_filter == "largest":
|
|
||||||
# Find the largest mask.
|
|
||||||
return [max(masks, key=lambda x: float(x.sum()))]
|
|
||||||
elif self.mask_filter == "highest_box_score":
|
|
||||||
# Find the index of the bounding box with the highest score.
|
|
||||||
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
|
|
||||||
# cases the scores should all be non-None when using this filtering mode. That being said, -1.0 is a
|
|
||||||
# reasonable fallback since the expected score range is [0.0, 1.0].
|
|
||||||
max_score_idx = max(range(len(bounding_boxes)), key=lambda i: bounding_boxes[i].score or -1.0)
|
|
||||||
return [masks[max_score_idx]]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid mask filter: {self.mask_filter}")
|
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
from typing import Callable
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -23,7 +21,7 @@ from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
|
|||||||
from invokeai.backend.tiles.utils import TBLR, Tile
|
from invokeai.backend.tiles.utils import TBLR, Tile
|
||||||
|
|
||||||
|
|
||||||
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.3.0")
|
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0")
|
||||||
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
|
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
|
||||||
|
|
||||||
@@ -37,8 +35,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
|
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
def _scale_tile(self, tile: Tile, scale: int) -> Tile:
|
||||||
def scale_tile(cls, tile: Tile, scale: int) -> Tile:
|
|
||||||
return Tile(
|
return Tile(
|
||||||
coords=TBLR(
|
coords=TBLR(
|
||||||
top=tile.coords.top * scale,
|
top=tile.coords.top * scale,
|
||||||
@@ -54,22 +51,20 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@torch.inference_mode()
|
||||||
def upscale_image(
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
cls,
|
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||||
image: Image.Image,
|
# revisit this.
|
||||||
tile_size: int,
|
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||||
spandrel_model: SpandrelImageToImageModel,
|
|
||||||
is_canceled: Callable[[], bool],
|
|
||||||
) -> Image.Image:
|
|
||||||
# Compute the image tiles.
|
# Compute the image tiles.
|
||||||
if tile_size > 0:
|
if self.tile_size > 0:
|
||||||
min_overlap = 20
|
min_overlap = 20
|
||||||
tiles = calc_tiles_min_overlap(
|
tiles = calc_tiles_min_overlap(
|
||||||
image_height=image.height,
|
image_height=image.height,
|
||||||
image_width=image.width,
|
image_width=image.width,
|
||||||
tile_height=tile_size,
|
tile_height=self.tile_size,
|
||||||
tile_width=tile_size,
|
tile_width=self.tile_size,
|
||||||
min_overlap=min_overlap,
|
min_overlap=min_overlap,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -90,164 +85,60 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
# Prepare input image for inference.
|
# Prepare input image for inference.
|
||||||
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
|
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
|
||||||
|
|
||||||
# Scale the tiles for re-assembling the final image.
|
# Load the model.
|
||||||
scale = spandrel_model.scale
|
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||||
scaled_tiles = [cls.scale_tile(tile, scale=scale) for tile in tiles]
|
|
||||||
|
|
||||||
# Prepare the output tensor.
|
|
||||||
_, channels, height, width = image_tensor.shape
|
|
||||||
output_tensor = torch.zeros(
|
|
||||||
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
|
|
||||||
)
|
|
||||||
|
|
||||||
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
|
||||||
|
|
||||||
# Run the model on each tile.
|
# Run the model on each tile.
|
||||||
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
|
with spandrel_model_info as spandrel_model:
|
||||||
# Exit early if the invocation has been canceled.
|
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||||
if is_canceled():
|
|
||||||
raise CanceledException
|
|
||||||
|
|
||||||
# Extract the current tile from the input tensor.
|
# Scale the tiles for re-assembling the final image.
|
||||||
input_tile = image_tensor[
|
scale = spandrel_model.scale
|
||||||
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
|
scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
|
||||||
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
|
||||||
|
|
||||||
# Run the model on the tile.
|
# Prepare the output tensor.
|
||||||
output_tile = spandrel_model.run(input_tile)
|
_, channels, height, width = image_tensor.shape
|
||||||
|
output_tensor = torch.zeros(
|
||||||
|
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
|
||||||
|
)
|
||||||
|
|
||||||
# Convert the output tile into the output tensor's format.
|
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||||
# (N, C, H, W) -> (C, H, W)
|
|
||||||
output_tile = output_tile.squeeze(0)
|
|
||||||
# (C, H, W) -> (H, W, C)
|
|
||||||
output_tile = output_tile.permute(1, 2, 0)
|
|
||||||
output_tile = output_tile.clamp(0, 1)
|
|
||||||
output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
|
|
||||||
|
|
||||||
# Merge the output tile into the output tensor.
|
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
|
||||||
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are
|
# Exit early if the invocation has been canceled.
|
||||||
# edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
|
if context.util.is_canceled():
|
||||||
# it seems unnecessary, but we may find a need in the future.
|
raise CanceledException
|
||||||
top_overlap = scaled_tile.overlap.top // 2
|
|
||||||
left_overlap = scaled_tile.overlap.left // 2
|
# Extract the current tile from the input tensor.
|
||||||
output_tensor[
|
input_tile = image_tensor[
|
||||||
scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
|
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
|
||||||
scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
|
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||||
:,
|
|
||||||
] = output_tile[top_overlap:, left_overlap:, :]
|
# Run the model on the tile.
|
||||||
|
output_tile = spandrel_model.run(input_tile)
|
||||||
|
|
||||||
|
# Convert the output tile into the output tensor's format.
|
||||||
|
# (N, C, H, W) -> (C, H, W)
|
||||||
|
output_tile = output_tile.squeeze(0)
|
||||||
|
# (C, H, W) -> (H, W, C)
|
||||||
|
output_tile = output_tile.permute(1, 2, 0)
|
||||||
|
output_tile = output_tile.clamp(0, 1)
|
||||||
|
output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
|
||||||
|
|
||||||
|
# Merge the output tile into the output tensor.
|
||||||
|
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are
|
||||||
|
# edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
|
||||||
|
# it seems unnecessary, but we may find a need in the future.
|
||||||
|
top_overlap = scaled_tile.overlap.top // 2
|
||||||
|
left_overlap = scaled_tile.overlap.left // 2
|
||||||
|
output_tensor[
|
||||||
|
scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
|
||||||
|
scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
|
||||||
|
:,
|
||||||
|
] = output_tile[top_overlap:, left_overlap:, :]
|
||||||
|
|
||||||
# Convert the output tensor to a PIL image.
|
# Convert the output tensor to a PIL image.
|
||||||
np_image = output_tensor.detach().numpy().astype(np.uint8)
|
np_image = output_tensor.detach().numpy().astype(np.uint8)
|
||||||
pil_image = Image.fromarray(np_image)
|
pil_image = Image.fromarray(np_image)
|
||||||
|
|
||||||
return pil_image
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
||||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
|
||||||
# revisit this.
|
|
||||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
|
||||||
|
|
||||||
# Load the model.
|
|
||||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
|
||||||
|
|
||||||
# Do the upscaling.
|
|
||||||
with spandrel_model_info as spandrel_model:
|
|
||||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
|
||||||
|
|
||||||
# Upscale the image
|
|
||||||
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
|
|
||||||
|
|
||||||
image_dto = context.images.save(image=pil_image)
|
|
||||||
return ImageOutput.build(image_dto)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
|
||||||
"spandrel_image_to_image_autoscale",
|
|
||||||
title="Image-to-Image (Autoscale)",
|
|
||||||
tags=["upscale"],
|
|
||||||
category="upscale",
|
|
||||||
version="1.0.0",
|
|
||||||
)
|
|
||||||
class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
|
|
||||||
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel) until the target scale is reached."""
|
|
||||||
|
|
||||||
scale: float = InputField(
|
|
||||||
default=4.0,
|
|
||||||
gt=0.0,
|
|
||||||
le=16.0,
|
|
||||||
description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
|
|
||||||
)
|
|
||||||
fit_to_multiple_of_8: bool = InputField(
|
|
||||||
default=False,
|
|
||||||
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
|
|
||||||
)
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
||||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
|
||||||
# revisit this.
|
|
||||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
|
||||||
|
|
||||||
# Load the model.
|
|
||||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
|
||||||
|
|
||||||
# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
|
|
||||||
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
|
|
||||||
target_width = int(image.width * self.scale)
|
|
||||||
target_height = int(image.height * self.scale)
|
|
||||||
|
|
||||||
# Do the upscaling.
|
|
||||||
with spandrel_model_info as spandrel_model:
|
|
||||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
|
||||||
|
|
||||||
# First pass of upscaling. Note: `pil_image` will be mutated.
|
|
||||||
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
|
|
||||||
|
|
||||||
# Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
|
|
||||||
# upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
|
|
||||||
# to be considered an upscale model.
|
|
||||||
is_upscale_model = pil_image.width > image.width and pil_image.height > image.height
|
|
||||||
|
|
||||||
if is_upscale_model:
|
|
||||||
# This is an upscale model, so we should keep upscaling until we reach the target size.
|
|
||||||
iterations = 1
|
|
||||||
while pil_image.width < target_width or pil_image.height < target_height:
|
|
||||||
pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
|
|
||||||
iterations += 1
|
|
||||||
|
|
||||||
# Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
|
|
||||||
# Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
|
|
||||||
# We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
|
|
||||||
# we should never reach this limit.
|
|
||||||
if iterations >= 5:
|
|
||||||
context.logger.warning(
|
|
||||||
"Upscale loop reached maximum iteration count of 5, stopping upscaling early."
|
|
||||||
)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# This model doesn't upscale the image. We should ignore the scale parameter, modifying the output size
|
|
||||||
# to be the same as the processed image size.
|
|
||||||
|
|
||||||
# The output size is now the size of the processed image.
|
|
||||||
target_width = pil_image.width
|
|
||||||
target_height = pil_image.height
|
|
||||||
|
|
||||||
# Warn the user if they requested a scale greater than 1.
|
|
||||||
if self.scale > 1:
|
|
||||||
context.logger.warning(
|
|
||||||
"Model does not increase the size of the image, but a greater scale than 1 was requested. Image will not be scaled."
|
|
||||||
)
|
|
||||||
|
|
||||||
# We may need to resize the image to a multiple of 8. Use floor division to ensure we don't scale the image up
|
|
||||||
# in the final resize
|
|
||||||
if self.fit_to_multiple_of_8:
|
|
||||||
target_width = int(target_width // 8 * 8)
|
|
||||||
target_height = int(target_height // 8 * 8)
|
|
||||||
|
|
||||||
# Final resize. Per PIL documentation, Lanczos provides the best quality for both upscale and downscale.
|
|
||||||
# See: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table
|
|
||||||
pil_image = pil_image.resize((target_width, target_height), resample=Image.Resampling.LANCZOS)
|
|
||||||
|
|
||||||
image_dto = context.images.save(image=pil_image)
|
image_dto = context.images.save(image=pil_image)
|
||||||
return ImageOutput.build(image_dto)
|
return ImageOutput.build(image_dto)
|
||||||
|
|||||||
@@ -91,7 +91,6 @@ class InvokeAIAppConfig(BaseSettings):
|
|||||||
db_dir: Path to InvokeAI databases directory.
|
db_dir: Path to InvokeAI databases directory.
|
||||||
outputs_dir: Path to directory for outputs.
|
outputs_dir: Path to directory for outputs.
|
||||||
custom_nodes_dir: Path to directory for custom nodes.
|
custom_nodes_dir: Path to directory for custom nodes.
|
||||||
style_presets_dir: Path to directory for style presets.
|
|
||||||
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
|
log_handlers: Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".
|
||||||
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
|
log_format: Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style.<br>Valid values: `plain`, `color`, `syslog`, `legacy`
|
||||||
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
|
log_level: Emit logging messages at this level or higher.<br>Valid values: `debug`, `info`, `warning`, `error`, `critical`
|
||||||
@@ -154,7 +153,6 @@ class InvokeAIAppConfig(BaseSettings):
|
|||||||
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
|
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
|
||||||
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
|
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
|
||||||
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
|
custom_nodes_dir: Path = Field(default=Path("nodes"), description="Path to directory for custom nodes.")
|
||||||
style_presets_dir: Path = Field(default=Path("style_presets"), description="Path to directory for style presets.")
|
|
||||||
|
|
||||||
# LOGGING
|
# LOGGING
|
||||||
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')
|
log_handlers: list[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=<path>", "syslog=path|address:host:port", "http=<url>".')
|
||||||
@@ -302,11 +300,6 @@ class InvokeAIAppConfig(BaseSettings):
|
|||||||
"""Path to the models directory, resolved to an absolute path.."""
|
"""Path to the models directory, resolved to an absolute path.."""
|
||||||
return self._resolve(self.models_dir)
|
return self._resolve(self.models_dir)
|
||||||
|
|
||||||
@property
|
|
||||||
def style_presets_path(self) -> Path:
|
|
||||||
"""Path to the style presets directory, resolved to an absolute path.."""
|
|
||||||
return self._resolve(self.style_presets_dir)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def convert_cache_path(self) -> Path:
|
def convert_cache_path(self) -> Path:
|
||||||
"""Path to the converted cache models directory, resolved to an absolute path.."""
|
"""Path to the converted cache models directory, resolved to an absolute path.."""
|
||||||
|
|||||||
@@ -88,7 +88,6 @@ class QueueItemEventBase(QueueEventBase):
|
|||||||
|
|
||||||
item_id: int = Field(description="The ID of the queue item")
|
item_id: int = Field(description="The ID of the queue item")
|
||||||
batch_id: str = Field(description="The ID of the queue batch")
|
batch_id: str = Field(description="The ID of the queue batch")
|
||||||
origin: str | None = Field(default=None, description="The origin of the batch")
|
|
||||||
|
|
||||||
|
|
||||||
class InvocationEventBase(QueueItemEventBase):
|
class InvocationEventBase(QueueItemEventBase):
|
||||||
@@ -96,6 +95,8 @@ class InvocationEventBase(QueueItemEventBase):
|
|||||||
|
|
||||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||||
queue_id: str = Field(description="The ID of the queue")
|
queue_id: str = Field(description="The ID of the queue")
|
||||||
|
item_id: int = Field(description="The ID of the queue item")
|
||||||
|
batch_id: str = Field(description="The ID of the queue batch")
|
||||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||||
invocation: AnyInvocation = Field(description="The ID of the invocation")
|
invocation: AnyInvocation = Field(description="The ID of the invocation")
|
||||||
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
||||||
@@ -113,7 +114,6 @@ class InvocationStartedEvent(InvocationEventBase):
|
|||||||
queue_id=queue_item.queue_id,
|
queue_id=queue_item.queue_id,
|
||||||
item_id=queue_item.item_id,
|
item_id=queue_item.item_id,
|
||||||
batch_id=queue_item.batch_id,
|
batch_id=queue_item.batch_id,
|
||||||
origin=queue_item.origin,
|
|
||||||
session_id=queue_item.session_id,
|
session_id=queue_item.session_id,
|
||||||
invocation=invocation,
|
invocation=invocation,
|
||||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
@@ -147,7 +147,6 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
|
|||||||
queue_id=queue_item.queue_id,
|
queue_id=queue_item.queue_id,
|
||||||
item_id=queue_item.item_id,
|
item_id=queue_item.item_id,
|
||||||
batch_id=queue_item.batch_id,
|
batch_id=queue_item.batch_id,
|
||||||
origin=queue_item.origin,
|
|
||||||
session_id=queue_item.session_id,
|
session_id=queue_item.session_id,
|
||||||
invocation=invocation,
|
invocation=invocation,
|
||||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
@@ -185,7 +184,6 @@ class InvocationCompleteEvent(InvocationEventBase):
|
|||||||
queue_id=queue_item.queue_id,
|
queue_id=queue_item.queue_id,
|
||||||
item_id=queue_item.item_id,
|
item_id=queue_item.item_id,
|
||||||
batch_id=queue_item.batch_id,
|
batch_id=queue_item.batch_id,
|
||||||
origin=queue_item.origin,
|
|
||||||
session_id=queue_item.session_id,
|
session_id=queue_item.session_id,
|
||||||
invocation=invocation,
|
invocation=invocation,
|
||||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
@@ -218,7 +216,6 @@ class InvocationErrorEvent(InvocationEventBase):
|
|||||||
queue_id=queue_item.queue_id,
|
queue_id=queue_item.queue_id,
|
||||||
item_id=queue_item.item_id,
|
item_id=queue_item.item_id,
|
||||||
batch_id=queue_item.batch_id,
|
batch_id=queue_item.batch_id,
|
||||||
origin=queue_item.origin,
|
|
||||||
session_id=queue_item.session_id,
|
session_id=queue_item.session_id,
|
||||||
invocation=invocation,
|
invocation=invocation,
|
||||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
@@ -256,7 +253,6 @@ class QueueItemStatusChangedEvent(QueueItemEventBase):
|
|||||||
queue_id=queue_item.queue_id,
|
queue_id=queue_item.queue_id,
|
||||||
item_id=queue_item.item_id,
|
item_id=queue_item.item_id,
|
||||||
batch_id=queue_item.batch_id,
|
batch_id=queue_item.batch_id,
|
||||||
origin=queue_item.origin,
|
|
||||||
session_id=queue_item.session_id,
|
session_id=queue_item.session_id,
|
||||||
status=queue_item.status,
|
status=queue_item.status,
|
||||||
error_type=queue_item.error_type,
|
error_type=queue_item.error_type,
|
||||||
@@ -283,14 +279,12 @@ class BatchEnqueuedEvent(QueueEventBase):
|
|||||||
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
|
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
|
||||||
)
|
)
|
||||||
priority: int = Field(description="The priority of the batch")
|
priority: int = Field(description="The priority of the batch")
|
||||||
origin: str | None = Field(default=None, description="The origin of the batch")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
|
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
|
||||||
return cls(
|
return cls(
|
||||||
queue_id=enqueue_result.queue_id,
|
queue_id=enqueue_result.queue_id,
|
||||||
batch_id=enqueue_result.batch.batch_id,
|
batch_id=enqueue_result.batch.batch_id,
|
||||||
origin=enqueue_result.batch.origin,
|
|
||||||
enqueued=enqueue_result.enqueued,
|
enqueued=enqueue_result.enqueued,
|
||||||
requested=enqueue_result.requested,
|
requested=enqueue_result.requested,
|
||||||
priority=enqueue_result.priority,
|
priority=enqueue_result.priority,
|
||||||
|
|||||||
@@ -1,44 +1,46 @@
|
|||||||
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
|
from queue import Empty, Queue
|
||||||
|
|
||||||
from fastapi_events.dispatcher import dispatch
|
from fastapi_events.dispatcher import dispatch
|
||||||
|
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.events.events_common import EventBase
|
from invokeai.app.services.events.events_common import (
|
||||||
|
EventBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FastAPIEventService(EventServiceBase):
|
class FastAPIEventService(EventServiceBase):
|
||||||
def __init__(self, event_handler_id: int, loop: asyncio.AbstractEventLoop) -> None:
|
def __init__(self, event_handler_id: int) -> None:
|
||||||
self.event_handler_id = event_handler_id
|
self.event_handler_id = event_handler_id
|
||||||
self._queue = asyncio.Queue[EventBase | None]()
|
self._queue = Queue[EventBase | None]()
|
||||||
self._stop_event = threading.Event()
|
self._stop_event = threading.Event()
|
||||||
self._loop = loop
|
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
||||||
|
|
||||||
# We need to store a reference to the task so it doesn't get GC'd
|
|
||||||
# See: https://docs.python.org/3/library/asyncio-task.html#creating-tasks
|
|
||||||
self._background_tasks: set[asyncio.Task[None]] = set()
|
|
||||||
task = self._loop.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
|
||||||
self._background_tasks.add(task)
|
|
||||||
task.add_done_callback(self._background_tasks.remove)
|
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def stop(self, *args, **kwargs):
|
def stop(self, *args, **kwargs):
|
||||||
self._stop_event.set()
|
self._stop_event.set()
|
||||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
|
self._queue.put(None)
|
||||||
|
|
||||||
def dispatch(self, event: EventBase) -> None:
|
def dispatch(self, event: EventBase) -> None:
|
||||||
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
|
self._queue.put(event)
|
||||||
|
|
||||||
async def _dispatch_from_queue(self, stop_event: threading.Event):
|
async def _dispatch_from_queue(self, stop_event: threading.Event):
|
||||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
event = await self._queue.get()
|
event = self._queue.get(block=False)
|
||||||
if not event: # Probably stopping
|
if not event: # Probably stopping
|
||||||
continue
|
continue
|
||||||
# Leave the payloads as live pydantic models
|
# Leave the payloads as live pydantic models
|
||||||
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
|
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
|
||||||
|
|
||||||
|
except Empty:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
pass
|
||||||
|
|
||||||
except asyncio.CancelledError as e:
|
except asyncio.CancelledError as e:
|
||||||
raise e # Raise a proper error
|
raise e # Raise a proper error
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
|
from send2trash import send2trash
|
||||||
|
|
||||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||||
from invokeai.app.services.image_files.image_files_common import (
|
from invokeai.app.services.image_files.image_files_common import (
|
||||||
@@ -19,12 +20,18 @@ from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
|||||||
class DiskImageFileStorage(ImageFileStorageBase):
|
class DiskImageFileStorage(ImageFileStorageBase):
|
||||||
"""Stores images on disk"""
|
"""Stores images on disk"""
|
||||||
|
|
||||||
|
__output_folder: Path
|
||||||
|
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||||
|
__cache: Dict[Path, PILImageType]
|
||||||
|
__max_cache_size: int
|
||||||
|
__invoker: Invoker
|
||||||
|
|
||||||
def __init__(self, output_folder: Union[str, Path]):
|
def __init__(self, output_folder: Union[str, Path]):
|
||||||
self.__cache: dict[Path, PILImageType] = {}
|
self.__cache = {}
|
||||||
self.__cache_ids = Queue[Path]()
|
self.__cache_ids = Queue()
|
||||||
self.__max_cache_size = 10 # TODO: get this from config
|
self.__max_cache_size = 10 # TODO: get this from config
|
||||||
|
|
||||||
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||||
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
||||||
# Validate required output folders at launch
|
# Validate required output folders at launch
|
||||||
self.__validate_storage_folders()
|
self.__validate_storage_folders()
|
||||||
@@ -96,7 +103,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
image_path = self.get_path(image_name)
|
image_path = self.get_path(image_name)
|
||||||
|
|
||||||
if image_path.exists():
|
if image_path.exists():
|
||||||
image_path.unlink()
|
send2trash(image_path)
|
||||||
if image_path in self.__cache:
|
if image_path in self.__cache:
|
||||||
del self.__cache[image_path]
|
del self.__cache[image_path]
|
||||||
|
|
||||||
@@ -104,7 +111,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
thumbnail_path = self.get_path(thumbnail_name, True)
|
thumbnail_path = self.get_path(thumbnail_name, True)
|
||||||
|
|
||||||
if thumbnail_path.exists():
|
if thumbnail_path.exists():
|
||||||
thumbnail_path.unlink()
|
send2trash(thumbnail_path)
|
||||||
if thumbnail_path in self.__cache:
|
if thumbnail_path in self.__cache:
|
||||||
del self.__cache[thumbnail_path]
|
del self.__cache[thumbnail_path]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ from __future__ import annotations
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||||
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
|
|
||||||
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
@@ -63,8 +61,6 @@ class InvocationServices:
|
|||||||
workflow_records: "WorkflowRecordsStorageBase",
|
workflow_records: "WorkflowRecordsStorageBase",
|
||||||
tensors: "ObjectSerializerBase[torch.Tensor]",
|
tensors: "ObjectSerializerBase[torch.Tensor]",
|
||||||
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
|
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
|
||||||
style_preset_records: "StylePresetRecordsStorageBase",
|
|
||||||
style_preset_image_files: "StylePresetImageFileStorageBase",
|
|
||||||
):
|
):
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.board_image_records = board_image_records
|
self.board_image_records = board_image_records
|
||||||
@@ -89,5 +85,3 @@ class InvocationServices:
|
|||||||
self.workflow_records = workflow_records
|
self.workflow_records = workflow_records
|
||||||
self.tensors = tensors
|
self.tensors = tensors
|
||||||
self.conditioning = conditioning
|
self.conditioning = conditioning
|
||||||
self.style_preset_records = style_preset_records
|
|
||||||
self.style_preset_image_files = style_preset_image_files
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
|
from send2trash import send2trash
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase
|
from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase
|
||||||
@@ -69,7 +70,7 @@ class ModelImageFileStorageDisk(ModelImageFileStorageBase):
|
|||||||
if not self._validate_path(path):
|
if not self._validate_path(path):
|
||||||
raise ModelImageFileNotFoundException
|
raise ModelImageFileNotFoundException
|
||||||
|
|
||||||
path.unlink()
|
send2trash(path)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ModelImageFileDeleteException from e
|
raise ModelImageFileDeleteException from e
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
|
|
||||||
@@ -12,7 +12,7 @@ from invokeai.app.services.download import DownloadQueueServiceBase
|
|||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
|
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
|
||||||
from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase
|
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||||
from invokeai.backend.model_manager import AnyModelConfig
|
from invokeai.backend.model_manager import AnyModelConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -64,7 +64,7 @@ class ModelInstallServiceBase(ABC):
|
|||||||
def register_path(
|
def register_path(
|
||||||
self,
|
self,
|
||||||
model_path: Union[Path, str],
|
model_path: Union[Path, str],
|
||||||
config: Optional[ModelRecordChanges] = None,
|
config: Optional[Dict[str, Any]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Probe and register the model at model_path.
|
Probe and register the model at model_path.
|
||||||
@@ -72,7 +72,7 @@ class ModelInstallServiceBase(ABC):
|
|||||||
This keeps the model in its current location.
|
This keeps the model in its current location.
|
||||||
|
|
||||||
:param model_path: Filesystem Path to the model.
|
:param model_path: Filesystem Path to the model.
|
||||||
:param config: ModelRecordChanges object that will override autoassigned model record values.
|
:param config: Dict of attributes that will override autoassigned values.
|
||||||
:returns id: The string ID of the registered model.
|
:returns id: The string ID of the registered model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -92,7 +92,7 @@ class ModelInstallServiceBase(ABC):
|
|||||||
def install_path(
|
def install_path(
|
||||||
self,
|
self,
|
||||||
model_path: Union[Path, str],
|
model_path: Union[Path, str],
|
||||||
config: Optional[ModelRecordChanges] = None,
|
config: Optional[Dict[str, Any]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Probe, register and install the model in the models directory.
|
Probe, register and install the model in the models directory.
|
||||||
@@ -101,7 +101,7 @@ class ModelInstallServiceBase(ABC):
|
|||||||
the models directory handled by InvokeAI.
|
the models directory handled by InvokeAI.
|
||||||
|
|
||||||
:param model_path: Filesystem Path to the model.
|
:param model_path: Filesystem Path to the model.
|
||||||
:param config: ModelRecordChanges object that will override autoassigned model record values.
|
:param config: Dict of attributes that will override autoassigned values.
|
||||||
:returns id: The string ID of the registered model.
|
:returns id: The string ID of the registered model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -109,14 +109,14 @@ class ModelInstallServiceBase(ABC):
|
|||||||
def heuristic_import(
|
def heuristic_import(
|
||||||
self,
|
self,
|
||||||
source: str,
|
source: str,
|
||||||
config: Optional[ModelRecordChanges] = None,
|
config: Optional[Dict[str, Any]] = None,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
inplace: Optional[bool] = False,
|
inplace: Optional[bool] = False,
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
r"""Install the indicated model using heuristics to interpret user intentions.
|
r"""Install the indicated model using heuristics to interpret user intentions.
|
||||||
|
|
||||||
:param source: String source
|
:param source: String source
|
||||||
:param config: Optional ModelRecordChanges object. Any fields in this object
|
:param config: Optional dict. Any fields in this dict
|
||||||
will override corresponding autoassigned probe fields in the
|
will override corresponding autoassigned probe fields in the
|
||||||
model's config record as described in `import_model()`.
|
model's config record as described in `import_model()`.
|
||||||
:param access_token: Optional access token for remote sources.
|
:param access_token: Optional access token for remote sources.
|
||||||
@@ -147,7 +147,7 @@ class ModelInstallServiceBase(ABC):
|
|||||||
def import_model(
|
def import_model(
|
||||||
self,
|
self,
|
||||||
source: ModelSource,
|
source: ModelSource,
|
||||||
config: Optional[ModelRecordChanges] = None,
|
config: Optional[Dict[str, Any]] = None,
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
"""Install the indicated model.
|
"""Install the indicated model.
|
||||||
|
|
||||||
|
|||||||
@@ -2,14 +2,13 @@ import re
|
|||||||
import traceback
|
import traceback
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Optional, Set, Union
|
from typing import Any, Dict, Literal, Optional, Set, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
|
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
|
||||||
from invokeai.app.services.model_records import ModelRecordChanges
|
|
||||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||||
from invokeai.backend.model_manager.config import ModelSourceType
|
from invokeai.backend.model_manager.config import ModelSourceType
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||||
@@ -134,9 +133,8 @@ class ModelInstallJob(BaseModel):
|
|||||||
id: int = Field(description="Unique ID for this job")
|
id: int = Field(description="Unique ID for this job")
|
||||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||||
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
||||||
config_in: ModelRecordChanges = Field(
|
config_in: Dict[str, Any] = Field(
|
||||||
default_factory=ModelRecordChanges,
|
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||||
description="Configuration information (e.g. 'description') to apply to model.",
|
|
||||||
)
|
)
|
||||||
config_out: Optional[AnyModelConfig] = Field(
|
config_out: Optional[AnyModelConfig] = Field(
|
||||||
default=None, description="After successful installation, this will hold the configuration object."
|
default=None, description="After successful installation, this will hold the configuration object."
|
||||||
|
|||||||
@@ -163,27 +163,26 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def register_path(
|
def register_path(
|
||||||
self,
|
self,
|
||||||
model_path: Union[Path, str],
|
model_path: Union[Path, str],
|
||||||
config: Optional[ModelRecordChanges] = None,
|
config: Optional[Dict[str, Any]] = None,
|
||||||
) -> str: # noqa D102
|
) -> str: # noqa D102
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
config = config or ModelRecordChanges()
|
config = config or {}
|
||||||
if not config.source:
|
if not config.get("source"):
|
||||||
config.source = model_path.resolve().as_posix()
|
config["source"] = model_path.resolve().as_posix()
|
||||||
config.source_type = ModelSourceType.Path
|
config["source_type"] = ModelSourceType.Path
|
||||||
return self._register(model_path, config)
|
return self._register(model_path, config)
|
||||||
|
|
||||||
def install_path(
|
def install_path(
|
||||||
self,
|
self,
|
||||||
model_path: Union[Path, str],
|
model_path: Union[Path, str],
|
||||||
config: Optional[ModelRecordChanges] = None,
|
config: Optional[Dict[str, Any]] = None,
|
||||||
) -> str: # noqa D102
|
) -> str: # noqa D102
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
config = config or ModelRecordChanges()
|
config = config or {}
|
||||||
info: AnyModelConfig = ModelProbe.probe(
|
|
||||||
Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm
|
|
||||||
) # type: ignore
|
|
||||||
|
|
||||||
if preferred_name := config.name:
|
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)
|
||||||
|
|
||||||
|
if preferred_name := config.get("name"):
|
||||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||||
|
|
||||||
dest_path = (
|
dest_path = (
|
||||||
@@ -205,7 +204,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def heuristic_import(
|
def heuristic_import(
|
||||||
self,
|
self,
|
||||||
source: str,
|
source: str,
|
||||||
config: Optional[ModelRecordChanges] = None,
|
config: Optional[Dict[str, Any]] = None,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
inplace: Optional[bool] = False,
|
inplace: Optional[bool] = False,
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
@@ -217,7 +216,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
source_obj.access_token = access_token
|
source_obj.access_token = access_token
|
||||||
return self.import_model(source_obj, config)
|
return self.import_model(source_obj, config)
|
||||||
|
|
||||||
def import_model(self, source: ModelSource, config: Optional[ModelRecordChanges] = None) -> ModelInstallJob: # noqa D102
|
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
|
||||||
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
|
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
|
||||||
if similar_jobs:
|
if similar_jobs:
|
||||||
self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
|
self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
|
||||||
@@ -319,17 +318,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
model_path = self._app_config.models_path / model_path
|
model_path = self._app_config.models_path / model_path
|
||||||
model_path = model_path.resolve()
|
model_path = model_path.resolve()
|
||||||
|
|
||||||
config = ModelRecordChanges(
|
config: dict[str, Any] = {}
|
||||||
name=model_name,
|
config["name"] = model_name
|
||||||
description=stanza.get("description"),
|
config["description"] = stanza.get("description")
|
||||||
)
|
|
||||||
legacy_config_path = stanza.get("config")
|
legacy_config_path = stanza.get("config")
|
||||||
if legacy_config_path:
|
if legacy_config_path:
|
||||||
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
|
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
|
||||||
legacy_config_path = self._app_config.root_path / legacy_config_path
|
legacy_config_path = self._app_config.root_path / legacy_config_path
|
||||||
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
|
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
|
||||||
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
|
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
|
||||||
config.config_path = str(legacy_config_path)
|
config["config_path"] = str(legacy_config_path)
|
||||||
try:
|
try:
|
||||||
id = self.register_path(model_path=model_path, config=config)
|
id = self.register_path(model_path=model_path, config=config)
|
||||||
self._logger.info(f"Migrated {model_name} with id {id}")
|
self._logger.info(f"Migrated {model_name} with id {id}")
|
||||||
@@ -502,11 +500,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
job.total_bytes = self._stat_size(job.local_path)
|
job.total_bytes = self._stat_size(job.local_path)
|
||||||
job.bytes = job.total_bytes
|
job.bytes = job.total_bytes
|
||||||
self._signal_job_running(job)
|
self._signal_job_running(job)
|
||||||
job.config_in.source = str(job.source)
|
job.config_in["source"] = str(job.source)
|
||||||
job.config_in.source_type = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||||
# enter the metadata, if there is any
|
# enter the metadata, if there is any
|
||||||
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
|
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
|
||||||
job.config_in.source_api_response = job.source_metadata.api_response
|
job.config_in["source_api_response"] = job.source_metadata.api_response
|
||||||
|
|
||||||
if job.inplace:
|
if job.inplace:
|
||||||
key = self.register_path(job.local_path, job.config_in)
|
key = self.register_path(job.local_path, job.config_in)
|
||||||
@@ -641,11 +639,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
return new_path
|
return new_path
|
||||||
|
|
||||||
def _register(
|
def _register(
|
||||||
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
|
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
config = config or ModelRecordChanges()
|
config = config or {}
|
||||||
|
|
||||||
info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore
|
info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)
|
||||||
|
|
||||||
model_path = model_path.resolve()
|
model_path = model_path.resolve()
|
||||||
|
|
||||||
@@ -676,13 +674,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
precision = TorchDevice.choose_torch_dtype()
|
precision = TorchDevice.choose_torch_dtype()
|
||||||
return ModelRepoVariant.FP16 if precision == torch.float16 else None
|
return ModelRepoVariant.FP16 if precision == torch.float16 else None
|
||||||
|
|
||||||
def _import_local_model(
|
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||||
self, source: LocalModelSource, config: Optional[ModelRecordChanges] = None
|
|
||||||
) -> ModelInstallJob:
|
|
||||||
return ModelInstallJob(
|
return ModelInstallJob(
|
||||||
id=self._next_id(),
|
id=self._next_id(),
|
||||||
source=source,
|
source=source,
|
||||||
config_in=config or ModelRecordChanges(),
|
config_in=config or {},
|
||||||
local_path=Path(source.path),
|
local_path=Path(source.path),
|
||||||
inplace=source.inplace or False,
|
inplace=source.inplace or False,
|
||||||
)
|
)
|
||||||
@@ -690,7 +686,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def _import_from_hf(
|
def _import_from_hf(
|
||||||
self,
|
self,
|
||||||
source: HFModelSource,
|
source: HFModelSource,
|
||||||
config: Optional[ModelRecordChanges] = None,
|
config: Optional[Dict[str, Any]] = None,
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
# Add user's cached access token to HuggingFace requests
|
# Add user's cached access token to HuggingFace requests
|
||||||
if source.access_token is None:
|
if source.access_token is None:
|
||||||
@@ -706,7 +702,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def _import_from_url(
|
def _import_from_url(
|
||||||
self,
|
self,
|
||||||
source: URLModelSource,
|
source: URLModelSource,
|
||||||
config: Optional[ModelRecordChanges] = None,
|
config: Optional[Dict[str, Any]],
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
remote_files, metadata = self._remote_files_from_source(source)
|
remote_files, metadata = self._remote_files_from_source(source)
|
||||||
return self._import_remote_model(
|
return self._import_remote_model(
|
||||||
@@ -721,7 +717,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
source: HFModelSource | URLModelSource,
|
source: HFModelSource | URLModelSource,
|
||||||
remote_files: List[RemoteModelFile],
|
remote_files: List[RemoteModelFile],
|
||||||
metadata: Optional[AnyModelRepoMetadata],
|
metadata: Optional[AnyModelRepoMetadata],
|
||||||
config: Optional[ModelRecordChanges],
|
config: Optional[Dict[str, Any]],
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
if len(remote_files) == 0:
|
if len(remote_files) == 0:
|
||||||
raise ValueError(f"{source}: No downloadable files found")
|
raise ValueError(f"{source}: No downloadable files found")
|
||||||
@@ -734,7 +730,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
install_job = ModelInstallJob(
|
install_job = ModelInstallJob(
|
||||||
id=self._next_id(),
|
id=self._next_id(),
|
||||||
source=source,
|
source=source,
|
||||||
config_in=config or ModelRecordChanges(),
|
config_in=config or {},
|
||||||
source_metadata=metadata,
|
source_metadata=metadata,
|
||||||
local_path=destdir, # local path may change once the download has started due to content-disposition handling
|
local_path=destdir, # local path may change once the download has started due to content-disposition handling
|
||||||
bytes=0,
|
bytes=0,
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from invokeai.backend.model_manager.config import (
|
|||||||
ControlAdapterDefaultSettings,
|
ControlAdapterDefaultSettings,
|
||||||
MainModelDefaultSettings,
|
MainModelDefaultSettings,
|
||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelSourceType,
|
|
||||||
ModelType,
|
ModelType,
|
||||||
ModelVariantType,
|
ModelVariantType,
|
||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
@@ -67,16 +66,10 @@ class ModelRecordChanges(BaseModelExcludeNull):
|
|||||||
"""A set of changes to apply to a model."""
|
"""A set of changes to apply to a model."""
|
||||||
|
|
||||||
# Changes applicable to all models
|
# Changes applicable to all models
|
||||||
source: Optional[str] = Field(description="original source of the model", default=None)
|
|
||||||
source_type: Optional[ModelSourceType] = Field(description="type of model source", default=None)
|
|
||||||
source_api_response: Optional[str] = Field(description="metadata from remote source", default=None)
|
|
||||||
name: Optional[str] = Field(description="Name of the model.", default=None)
|
name: Optional[str] = Field(description="Name of the model.", default=None)
|
||||||
path: Optional[str] = Field(description="Path to the model.", default=None)
|
path: Optional[str] = Field(description="Path to the model.", default=None)
|
||||||
description: Optional[str] = Field(description="Model description", default=None)
|
description: Optional[str] = Field(description="Model description", default=None)
|
||||||
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
|
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
|
||||||
type: Optional[ModelType] = Field(description="Type of model", default=None)
|
|
||||||
key: Optional[str] = Field(description="Database ID for this model", default=None)
|
|
||||||
hash: Optional[str] = Field(description="hash of model file", default=None)
|
|
||||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||||
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
|
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
|
||||||
description="Default settings for this model", default=None
|
description="Default settings for this model", default=None
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
Batch,
|
Batch,
|
||||||
BatchStatus,
|
BatchStatus,
|
||||||
CancelByBatchIDsResult,
|
CancelByBatchIDsResult,
|
||||||
CancelByOriginResult,
|
|
||||||
CancelByQueueIDResult,
|
CancelByQueueIDResult,
|
||||||
ClearResult,
|
ClearResult,
|
||||||
EnqueueBatchResult,
|
EnqueueBatchResult,
|
||||||
@@ -96,11 +95,6 @@ class SessionQueueBase(ABC):
|
|||||||
"""Cancels all queue items with matching batch IDs"""
|
"""Cancels all queue items with matching batch IDs"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
|
|
||||||
"""Cancels all queue items with the given batch origin"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||||
"""Cancels all queue items with matching queue ID"""
|
"""Cancels all queue items with matching queue ID"""
|
||||||
|
|||||||
@@ -77,7 +77,6 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]
|
|||||||
|
|
||||||
class Batch(BaseModel):
|
class Batch(BaseModel):
|
||||||
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
|
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
|
||||||
origin: str | None = Field(default=None, description="The origin of this batch.")
|
|
||||||
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
|
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
|
||||||
graph: Graph = Field(description="The graph to initialize the session with")
|
graph: Graph = Field(description="The graph to initialize the session with")
|
||||||
workflow: Optional[WorkflowWithoutID] = Field(
|
workflow: Optional[WorkflowWithoutID] = Field(
|
||||||
@@ -196,7 +195,6 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
|||||||
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
|
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
|
||||||
priority: int = Field(default=0, description="The priority of this queue item")
|
priority: int = Field(default=0, description="The priority of this queue item")
|
||||||
batch_id: str = Field(description="The ID of the batch associated with this queue item")
|
batch_id: str = Field(description="The ID of the batch associated with this queue item")
|
||||||
origin: str | None = Field(default=None, description="The origin of this queue item. ")
|
|
||||||
session_id: str = Field(
|
session_id: str = Field(
|
||||||
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
|
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
|
||||||
)
|
)
|
||||||
@@ -296,7 +294,6 @@ class SessionQueueStatus(BaseModel):
|
|||||||
class BatchStatus(BaseModel):
|
class BatchStatus(BaseModel):
|
||||||
queue_id: str = Field(..., description="The ID of the queue")
|
queue_id: str = Field(..., description="The ID of the queue")
|
||||||
batch_id: str = Field(..., description="The ID of the batch")
|
batch_id: str = Field(..., description="The ID of the batch")
|
||||||
origin: str | None = Field(..., description="The origin of the batch")
|
|
||||||
pending: int = Field(..., description="Number of queue items with status 'pending'")
|
pending: int = Field(..., description="Number of queue items with status 'pending'")
|
||||||
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
|
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
|
||||||
completed: int = Field(..., description="Number of queue items with status 'complete'")
|
completed: int = Field(..., description="Number of queue items with status 'complete'")
|
||||||
@@ -331,12 +328,6 @@ class CancelByBatchIDsResult(BaseModel):
|
|||||||
canceled: int = Field(..., description="Number of queue items canceled")
|
canceled: int = Field(..., description="Number of queue items canceled")
|
||||||
|
|
||||||
|
|
||||||
class CancelByOriginResult(BaseModel):
|
|
||||||
"""Result of canceling by list of batch ids"""
|
|
||||||
|
|
||||||
canceled: int = Field(..., description="Number of queue items canceled")
|
|
||||||
|
|
||||||
|
|
||||||
class CancelByQueueIDResult(CancelByBatchIDsResult):
|
class CancelByQueueIDResult(CancelByBatchIDsResult):
|
||||||
"""Result of canceling by queue id"""
|
"""Result of canceling by queue id"""
|
||||||
|
|
||||||
@@ -442,7 +433,6 @@ class SessionQueueValueToInsert(NamedTuple):
|
|||||||
field_values: Optional[str] # field_values json
|
field_values: Optional[str] # field_values json
|
||||||
priority: int # priority
|
priority: int # priority
|
||||||
workflow: Optional[str] # workflow json
|
workflow: Optional[str] # workflow json
|
||||||
origin: str | None
|
|
||||||
|
|
||||||
|
|
||||||
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
|
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
|
||||||
@@ -463,7 +453,6 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
|
|||||||
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
|
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
|
||||||
priority, # priority
|
priority, # priority
|
||||||
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
|
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
|
||||||
batch.origin, # origin
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return values_to_insert
|
return values_to_insert
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
Batch,
|
Batch,
|
||||||
BatchStatus,
|
BatchStatus,
|
||||||
CancelByBatchIDsResult,
|
CancelByBatchIDsResult,
|
||||||
CancelByOriginResult,
|
|
||||||
CancelByQueueIDResult,
|
CancelByQueueIDResult,
|
||||||
ClearResult,
|
ClearResult,
|
||||||
EnqueueBatchResult,
|
EnqueueBatchResult,
|
||||||
@@ -128,8 +127,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
|
|
||||||
self.__cursor.executemany(
|
self.__cursor.executemany(
|
||||||
"""--sql
|
"""--sql
|
||||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin)
|
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
""",
|
""",
|
||||||
values_to_insert,
|
values_to_insert,
|
||||||
)
|
)
|
||||||
@@ -418,7 +417,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
)
|
)
|
||||||
self.__conn.commit()
|
self.__conn.commit()
|
||||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||||
|
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||||
|
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||||
|
current_queue_item, batch_status, queue_status
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.__conn.rollback()
|
self.__conn.rollback()
|
||||||
raise
|
raise
|
||||||
@@ -426,46 +429,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
self.__lock.release()
|
self.__lock.release()
|
||||||
return CancelByBatchIDsResult(canceled=count)
|
return CancelByBatchIDsResult(canceled=count)
|
||||||
|
|
||||||
def cancel_by_origin(self, queue_id: str, origin: str) -> CancelByOriginResult:
|
|
||||||
try:
|
|
||||||
current_queue_item = self.get_current(queue_id)
|
|
||||||
self.__lock.acquire()
|
|
||||||
where = """--sql
|
|
||||||
WHERE
|
|
||||||
queue_id == ?
|
|
||||||
AND origin == ?
|
|
||||||
AND status != 'canceled'
|
|
||||||
AND status != 'completed'
|
|
||||||
AND status != 'failed'
|
|
||||||
"""
|
|
||||||
params = (queue_id, origin)
|
|
||||||
self.__cursor.execute(
|
|
||||||
f"""--sql
|
|
||||||
SELECT COUNT(*)
|
|
||||||
FROM session_queue
|
|
||||||
{where};
|
|
||||||
""",
|
|
||||||
params,
|
|
||||||
)
|
|
||||||
count = self.__cursor.fetchone()[0]
|
|
||||||
self.__cursor.execute(
|
|
||||||
f"""--sql
|
|
||||||
UPDATE session_queue
|
|
||||||
SET status = 'canceled'
|
|
||||||
{where};
|
|
||||||
""",
|
|
||||||
params,
|
|
||||||
)
|
|
||||||
self.__conn.commit()
|
|
||||||
if current_queue_item is not None and current_queue_item.origin == origin:
|
|
||||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
|
||||||
except Exception:
|
|
||||||
self.__conn.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self.__lock.release()
|
|
||||||
return CancelByOriginResult(canceled=count)
|
|
||||||
|
|
||||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||||
try:
|
try:
|
||||||
current_queue_item = self.get_current(queue_id)
|
current_queue_item = self.get_current(queue_id)
|
||||||
@@ -578,8 +541,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
started_at,
|
started_at,
|
||||||
session_id,
|
session_id,
|
||||||
batch_id,
|
batch_id,
|
||||||
queue_id,
|
queue_id
|
||||||
origin
|
|
||||||
FROM session_queue
|
FROM session_queue
|
||||||
WHERE queue_id = ?
|
WHERE queue_id = ?
|
||||||
"""
|
"""
|
||||||
@@ -659,7 +621,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
self.__lock.acquire()
|
self.__lock.acquire()
|
||||||
self.__cursor.execute(
|
self.__cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
SELECT status, count(*), origin
|
SELECT status, count(*)
|
||||||
FROM session_queue
|
FROM session_queue
|
||||||
WHERE
|
WHERE
|
||||||
queue_id = ?
|
queue_id = ?
|
||||||
@@ -671,7 +633,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||||
total = sum(row[1] for row in result)
|
total = sum(row[1] for row in result)
|
||||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||||
origin = result[0]["origin"] if result else None
|
|
||||||
except Exception:
|
except Exception:
|
||||||
self.__conn.rollback()
|
self.__conn.rollback()
|
||||||
raise
|
raise
|
||||||
@@ -680,7 +641,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
|
|
||||||
return BatchStatus(
|
return BatchStatus(
|
||||||
batch_id=batch_id,
|
batch_id=batch_id,
|
||||||
origin=origin,
|
|
||||||
queue_id=queue_id,
|
queue_id=queue_id,
|
||||||
pending=counts.get("pending", 0),
|
pending=counts.get("pending", 0),
|
||||||
in_progress=counts.get("in_progress", 0),
|
in_progress=counts.get("in_progress", 0),
|
||||||
|
|||||||
@@ -16,8 +16,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import
|
|||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_14 import build_migration_14
|
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_15 import build_migration_15
|
|
||||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||||
|
|
||||||
|
|
||||||
@@ -51,8 +49,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
|||||||
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
|
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
|
||||||
migrator.register_migration(build_migration_12(app_config=config))
|
migrator.register_migration(build_migration_12(app_config=config))
|
||||||
migrator.register_migration(build_migration_13())
|
migrator.register_migration(build_migration_13())
|
||||||
migrator.register_migration(build_migration_14())
|
|
||||||
migrator.register_migration(build_migration_15())
|
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
|
|
||||||
return db
|
return db
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
import sqlite3
|
|
||||||
|
|
||||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
|
||||||
|
|
||||||
|
|
||||||
class Migration14Callback:
|
|
||||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
|
||||||
self._create_style_presets(cursor)
|
|
||||||
|
|
||||||
def _create_style_presets(self, cursor: sqlite3.Cursor) -> None:
|
|
||||||
"""Create the table used to store style presets."""
|
|
||||||
tables = [
|
|
||||||
"""--sql
|
|
||||||
CREATE TABLE IF NOT EXISTS style_presets (
|
|
||||||
id TEXT NOT NULL PRIMARY KEY,
|
|
||||||
name TEXT NOT NULL,
|
|
||||||
preset_data TEXT NOT NULL,
|
|
||||||
type TEXT NOT NULL DEFAULT "user",
|
|
||||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
|
||||||
-- Updated via trigger
|
|
||||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
|
||||||
);
|
|
||||||
"""
|
|
||||||
]
|
|
||||||
|
|
||||||
# Add trigger for `updated_at`.
|
|
||||||
triggers = [
|
|
||||||
"""--sql
|
|
||||||
CREATE TRIGGER IF NOT EXISTS style_presets
|
|
||||||
AFTER UPDATE
|
|
||||||
ON style_presets FOR EACH ROW
|
|
||||||
BEGIN
|
|
||||||
UPDATE style_presets SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
|
||||||
WHERE id = old.id;
|
|
||||||
END;
|
|
||||||
"""
|
|
||||||
]
|
|
||||||
|
|
||||||
# Add indexes for searchable fields
|
|
||||||
indices = [
|
|
||||||
"CREATE INDEX IF NOT EXISTS idx_style_presets_name ON style_presets(name);",
|
|
||||||
]
|
|
||||||
|
|
||||||
for stmt in tables + indices + triggers:
|
|
||||||
cursor.execute(stmt)
|
|
||||||
|
|
||||||
|
|
||||||
def build_migration_14() -> Migration:
|
|
||||||
"""
|
|
||||||
Build the migration from database version 13 to 14..
|
|
||||||
|
|
||||||
This migration does the following:
|
|
||||||
- Create the table used to store style presets.
|
|
||||||
"""
|
|
||||||
migration_14 = Migration(
|
|
||||||
from_version=13,
|
|
||||||
to_version=14,
|
|
||||||
callback=Migration14Callback(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return migration_14
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
import sqlite3
|
|
||||||
|
|
||||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
|
||||||
|
|
||||||
|
|
||||||
class Migration15Callback:
|
|
||||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
|
||||||
self._add_origin_col(cursor)
|
|
||||||
|
|
||||||
def _add_origin_col(self, cursor: sqlite3.Cursor) -> None:
|
|
||||||
"""
|
|
||||||
- Adds `origin` column to the session queue table.
|
|
||||||
"""
|
|
||||||
|
|
||||||
cursor.execute("ALTER TABLE session_queue ADD COLUMN origin TEXT;")
|
|
||||||
|
|
||||||
|
|
||||||
def build_migration_15() -> Migration:
|
|
||||||
"""
|
|
||||||
Build the migration from database version 14 to 15.
|
|
||||||
|
|
||||||
This migration does the following:
|
|
||||||
- Adds `origin` column to the session queue table.
|
|
||||||
"""
|
|
||||||
migration_15 = Migration(
|
|
||||||
from_version=14,
|
|
||||||
to_version=15,
|
|
||||||
callback=Migration15Callback(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return migration_15
|
|
||||||
|
Before Width: | Height: | Size: 98 KiB |
|
Before Width: | Height: | Size: 138 KiB |
|
Before Width: | Height: | Size: 122 KiB |
|
Before Width: | Height: | Size: 123 KiB |
|
Before Width: | Height: | Size: 160 KiB |
|
Before Width: | Height: | Size: 146 KiB |
|
Before Width: | Height: | Size: 119 KiB |
|
Before Width: | Height: | Size: 117 KiB |
|
Before Width: | Height: | Size: 110 KiB |
|
Before Width: | Height: | Size: 46 KiB |
|
Before Width: | Height: | Size: 79 KiB |
|
Before Width: | Height: | Size: 156 KiB |
|
Before Width: | Height: | Size: 141 KiB |
|
Before Width: | Height: | Size: 96 KiB |
|
Before Width: | Height: | Size: 91 KiB |
|
Before Width: | Height: | Size: 88 KiB |
|
Before Width: | Height: | Size: 107 KiB |
|
Before Width: | Height: | Size: 132 KiB |
@@ -1,33 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
|
||||||
|
|
||||||
|
|
||||||
class StylePresetImageFileStorageBase(ABC):
|
|
||||||
"""Low-level service responsible for storing and retrieving image files."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, style_preset_id: str) -> PILImageType:
|
|
||||||
"""Retrieves a style preset image as PIL Image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_path(self, style_preset_id: str) -> Path:
|
|
||||||
"""Gets the internal path to a style preset image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_url(self, style_preset_id: str) -> str | None:
|
|
||||||
"""Gets the URL to fetch a style preset image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save(self, style_preset_id: str, image: PILImageType) -> None:
|
|
||||||
"""Saves a style preset image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, style_preset_id: str) -> None:
|
|
||||||
"""Deletes a style preset image."""
|
|
||||||
pass
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
class StylePresetImageFileNotFoundException(Exception):
|
|
||||||
"""Raised when an image file is not found in storage."""
|
|
||||||
|
|
||||||
def __init__(self, message: str = "Style preset image file not found"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class StylePresetImageFileSaveException(Exception):
|
|
||||||
"""Raised when an image cannot be saved."""
|
|
||||||
|
|
||||||
def __init__(self, message: str = "Style preset image file not saved"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class StylePresetImageFileDeleteException(Exception):
|
|
||||||
"""Raised when an image cannot be deleted."""
|
|
||||||
|
|
||||||
def __init__(self, message: str = "Style preset image file not deleted"):
|
|
||||||
super().__init__(message)
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
from PIL.Image import Image as PILImageType
|
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
|
||||||
from invokeai.app.services.style_preset_images.style_preset_images_base import StylePresetImageFileStorageBase
|
|
||||||
from invokeai.app.services.style_preset_images.style_preset_images_common import (
|
|
||||||
StylePresetImageFileDeleteException,
|
|
||||||
StylePresetImageFileNotFoundException,
|
|
||||||
StylePresetImageFileSaveException,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.style_preset_records.style_preset_records_common import PresetType
|
|
||||||
from invokeai.app.util.misc import uuid_string
|
|
||||||
from invokeai.app.util.thumbnails import make_thumbnail
|
|
||||||
|
|
||||||
|
|
||||||
class StylePresetImageFileStorageDisk(StylePresetImageFileStorageBase):
|
|
||||||
"""Stores images on disk"""
|
|
||||||
|
|
||||||
def __init__(self, style_preset_images_folder: Path):
|
|
||||||
self._style_preset_images_folder = style_preset_images_folder
|
|
||||||
self._validate_storage_folders()
|
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
|
||||||
self._invoker = invoker
|
|
||||||
|
|
||||||
def get(self, style_preset_id: str) -> PILImageType:
|
|
||||||
try:
|
|
||||||
path = self.get_path(style_preset_id)
|
|
||||||
|
|
||||||
return Image.open(path)
|
|
||||||
except FileNotFoundError as e:
|
|
||||||
raise StylePresetImageFileNotFoundException from e
|
|
||||||
|
|
||||||
def save(self, style_preset_id: str, image: PILImageType) -> None:
|
|
||||||
try:
|
|
||||||
self._validate_storage_folders()
|
|
||||||
image_path = self._style_preset_images_folder / (style_preset_id + ".webp")
|
|
||||||
thumbnail = make_thumbnail(image, 256)
|
|
||||||
thumbnail.save(image_path, format="webp")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise StylePresetImageFileSaveException from e
|
|
||||||
|
|
||||||
def get_path(self, style_preset_id: str) -> Path:
|
|
||||||
style_preset = self._invoker.services.style_preset_records.get(style_preset_id)
|
|
||||||
if style_preset.type is PresetType.Default:
|
|
||||||
default_images_dir = Path(__file__).parent / Path("default_style_preset_images")
|
|
||||||
path = default_images_dir / (style_preset.name + ".png")
|
|
||||||
else:
|
|
||||||
path = self._style_preset_images_folder / (style_preset_id + ".webp")
|
|
||||||
|
|
||||||
return path
|
|
||||||
|
|
||||||
def get_url(self, style_preset_id: str) -> str | None:
|
|
||||||
path = self.get_path(style_preset_id)
|
|
||||||
if not self._validate_path(path):
|
|
||||||
return
|
|
||||||
|
|
||||||
url = self._invoker.services.urls.get_style_preset_image_url(style_preset_id)
|
|
||||||
|
|
||||||
# The image URL never changes, so we must add random query string to it to prevent caching
|
|
||||||
url += f"?{uuid_string()}"
|
|
||||||
|
|
||||||
return url
|
|
||||||
|
|
||||||
def delete(self, style_preset_id: str) -> None:
|
|
||||||
try:
|
|
||||||
path = self.get_path(style_preset_id)
|
|
||||||
|
|
||||||
if not self._validate_path(path):
|
|
||||||
raise StylePresetImageFileNotFoundException
|
|
||||||
|
|
||||||
path.unlink()
|
|
||||||
|
|
||||||
except StylePresetImageFileNotFoundException as e:
|
|
||||||
raise StylePresetImageFileNotFoundException from e
|
|
||||||
except Exception as e:
|
|
||||||
raise StylePresetImageFileDeleteException from e
|
|
||||||
|
|
||||||
def _validate_path(self, path: Path) -> bool:
|
|
||||||
"""Validates the path given for an image."""
|
|
||||||
return path.exists()
|
|
||||||
|
|
||||||
def _validate_storage_folders(self) -> None:
|
|
||||||
"""Checks if the required folders exist and create them if they don't"""
|
|
||||||
self._style_preset_images_folder.mkdir(parents=True, exist_ok=True)
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
[
|
|
||||||
{
|
|
||||||
"name": "Photography (General)",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "{prompt}. photography. f/2.8 macro photo, bokeh, photorealism",
|
|
||||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Photography (Studio Lighting)",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "{prompt}, photography. f/8 photo. centered subject, studio lighting.",
|
|
||||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Photography (Landscape)",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "{prompt}, landscape photograph, f/12, lifelike, highly detailed.",
|
|
||||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Photography (Portrait)",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "{prompt}. photography. portraiture. catch light in eyes. one flash. rembrandt lighting. Soft box. dark shadows. High contrast. 80mm lens. F2.8.",
|
|
||||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Photography (Black and White)",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "{prompt} photography. natural light. 80mm lens. F1.4. strong contrast, hard light. dark contrast. blurred background. black and white",
|
|
||||||
"negative_prompt": "painting, digital art. sketch, colour+"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Architectural Visualization",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "{prompt}. architectural photography, f/12, luxury, aesthetically pleasing form and function.",
|
|
||||||
"negative_prompt": "painting, digital art. sketch, blurry"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Concept Art (Fantasy)",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "concept artwork of a {prompt}. (digital painterly art style)++, mythological, (textured 2d dry media brushpack)++, glazed brushstrokes, otherworldly. painting+, illustration+",
|
|
||||||
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Concept Art (Sci-Fi)",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "(concept art)++, {prompt}, (sleek futurism)++, (textured 2d dry media)++, metallic highlights, digital painting style",
|
|
||||||
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Concept Art (Character)",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "(character concept art)++, stylized painterly digital painting of {prompt}, (painterly, impasto. Dry brush.)++",
|
|
||||||
"negative_prompt": "photo. distorted, blurry, out of focus. sketch. (cgi, 3d.)++"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Concept Art (Painterly)",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "{prompt} oil painting. high contrast. impasto. sfumato. chiaroscuro. Palette knife.",
|
|
||||||
"negative_prompt": "photo. smooth. border. frame"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Environment Art",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "{prompt} environment artwork, hyper-realistic digital painting style with cinematic composition, atmospheric, depth and detail, voluminous. textured dry brush 2d media",
|
|
||||||
"negative_prompt": "photo, distorted, blurry, out of focus. sketch."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Interior Design (Visualization)",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "{prompt} interior design photo, gentle shadows, light mid-tones, dimension, mix of smooth and textured surfaces, focus on negative space and clean lines, focus",
|
|
||||||
"negative_prompt": "photo, distorted. sketch."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Product Rendering",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "{prompt} high quality product photography, 3d rendering with key lighting, shallow depth of field, simple plain background, studio lighting.",
|
|
||||||
"negative_prompt": "blurry, sketch, messy, dirty. unfinished."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Sketch",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "{prompt} black and white pencil drawing, off-center composition, cross-hatching for shadows, bold strokes, textured paper. sketch+++",
|
|
||||||
"negative_prompt": "blurry, photo, painting, color. messy, dirty. unfinished. frame, borders."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Line Art",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "{prompt} Line art. bold outline. simplistic. white background. 2d",
|
|
||||||
"negative_prompt": "photo. digital art. greyscale. solid black. painting"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Anime",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "{prompt} anime++, bold outline, cel-shaded coloring, shounen, seinen",
|
|
||||||
"negative_prompt": "(photo)+++. greyscale. solid black. painting"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Illustration",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "{prompt} illustration, bold linework, illustrative details, vector art style, flat coloring",
|
|
||||||
"negative_prompt": "(photo)+++. greyscale. painting, black and white."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "Vehicles",
|
|
||||||
"type": "default",
|
|
||||||
"preset_data": {
|
|
||||||
"positive_prompt": "A weird futuristic normal auto, {prompt} elegant design, nice color, nice wheels",
|
|
||||||
"negative_prompt": "sketch. digital art. greyscale. painting"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
|
||||||
PresetType,
|
|
||||||
StylePresetChanges,
|
|
||||||
StylePresetRecordDTO,
|
|
||||||
StylePresetWithoutId,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StylePresetRecordsStorageBase(ABC):
|
|
||||||
"""Base class for style preset storage services."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
|
||||||
"""Get style preset by id."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
|
||||||
"""Creates a style preset."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
|
||||||
"""Creates many style presets."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
|
||||||
"""Updates a style preset."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, style_preset_id: str) -> None:
|
|
||||||
"""Deletes a style preset."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
|
||||||
"""Gets many workflows."""
|
|
||||||
pass
|
|
||||||
@@ -1,139 +0,0 @@
|
|||||||
import codecs
|
|
||||||
import csv
|
|
||||||
import json
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
import pydantic
|
|
||||||
from fastapi import UploadFile
|
|
||||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter
|
|
||||||
|
|
||||||
from invokeai.app.util.metaenum import MetaEnum
|
|
||||||
|
|
||||||
|
|
||||||
class StylePresetNotFoundError(Exception):
|
|
||||||
"""Raised when a style preset is not found"""
|
|
||||||
|
|
||||||
|
|
||||||
class PresetData(BaseModel, extra="forbid"):
|
|
||||||
positive_prompt: str = Field(description="Positive prompt")
|
|
||||||
negative_prompt: str = Field(description="Negative prompt")
|
|
||||||
|
|
||||||
|
|
||||||
PresetDataValidator = TypeAdapter(PresetData)
|
|
||||||
|
|
||||||
|
|
||||||
class PresetType(str, Enum, metaclass=MetaEnum):
|
|
||||||
User = "user"
|
|
||||||
Default = "default"
|
|
||||||
Project = "project"
|
|
||||||
|
|
||||||
|
|
||||||
class StylePresetChanges(BaseModel, extra="forbid"):
|
|
||||||
name: Optional[str] = Field(default=None, description="The style preset's new name.")
|
|
||||||
preset_data: Optional[PresetData] = Field(default=None, description="The updated data for style preset.")
|
|
||||||
type: Optional[PresetType] = Field(description="The updated type of the style preset")
|
|
||||||
|
|
||||||
|
|
||||||
class StylePresetWithoutId(BaseModel):
|
|
||||||
name: str = Field(description="The name of the style preset.")
|
|
||||||
preset_data: PresetData = Field(description="The preset data")
|
|
||||||
type: PresetType = Field(description="The type of style preset")
|
|
||||||
|
|
||||||
|
|
||||||
class StylePresetRecordDTO(StylePresetWithoutId):
|
|
||||||
id: str = Field(description="The style preset ID.")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: dict[str, Any]) -> "StylePresetRecordDTO":
|
|
||||||
data["preset_data"] = PresetDataValidator.validate_json(data.get("preset_data", ""))
|
|
||||||
return StylePresetRecordDTOValidator.validate_python(data)
|
|
||||||
|
|
||||||
|
|
||||||
StylePresetRecordDTOValidator = TypeAdapter(StylePresetRecordDTO)
|
|
||||||
|
|
||||||
|
|
||||||
class StylePresetRecordWithImage(StylePresetRecordDTO):
|
|
||||||
image: Optional[str] = Field(description="The path for image")
|
|
||||||
|
|
||||||
|
|
||||||
class StylePresetImportRow(BaseModel):
|
|
||||||
name: str = Field(min_length=1, description="The name of the preset.")
|
|
||||||
positive_prompt: str = Field(
|
|
||||||
default="",
|
|
||||||
description="The positive prompt for the preset.",
|
|
||||||
validation_alias=AliasChoices("positive_prompt", "prompt"),
|
|
||||||
)
|
|
||||||
negative_prompt: str = Field(default="", description="The negative prompt for the preset.")
|
|
||||||
|
|
||||||
model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
|
|
||||||
|
|
||||||
|
|
||||||
StylePresetImportList = list[StylePresetImportRow]
|
|
||||||
StylePresetImportListTypeAdapter = TypeAdapter(StylePresetImportList)
|
|
||||||
|
|
||||||
|
|
||||||
class UnsupportedFileTypeError(ValueError):
|
|
||||||
"""Raised when an unsupported file type is encountered"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidPresetImportDataError(ValueError):
|
|
||||||
"""Raised when invalid preset import data is encountered"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def parse_presets_from_file(file: UploadFile) -> list[StylePresetWithoutId]:
|
|
||||||
"""Parses style presets from a file. The file must be a CSV or JSON file.
|
|
||||||
|
|
||||||
If CSV, the file must have the following columns:
|
|
||||||
- name
|
|
||||||
- prompt (or positive_prompt)
|
|
||||||
- negative_prompt
|
|
||||||
|
|
||||||
If JSON, the file must be a list of objects with the following keys:
|
|
||||||
- name
|
|
||||||
- prompt (or positive_prompt)
|
|
||||||
- negative_prompt
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file (UploadFile): The file to parse.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[StylePresetWithoutId]: The parsed style presets.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
UnsupportedFileTypeError: If the file type is not supported.
|
|
||||||
InvalidPresetImportDataError: If the data in the file is invalid.
|
|
||||||
"""
|
|
||||||
if file.content_type not in ["text/csv", "application/json"]:
|
|
||||||
raise UnsupportedFileTypeError()
|
|
||||||
|
|
||||||
if file.content_type == "text/csv":
|
|
||||||
csv_reader = csv.DictReader(codecs.iterdecode(file.file, "utf-8"))
|
|
||||||
data = list(csv_reader)
|
|
||||||
else: # file.content_type == "application/json":
|
|
||||||
json_data = await file.read()
|
|
||||||
data = json.loads(json_data)
|
|
||||||
|
|
||||||
try:
|
|
||||||
imported_presets = StylePresetImportListTypeAdapter.validate_python(data)
|
|
||||||
|
|
||||||
style_presets: list[StylePresetWithoutId] = []
|
|
||||||
|
|
||||||
for imported in imported_presets:
|
|
||||||
preset_data = PresetData(positive_prompt=imported.positive_prompt, negative_prompt=imported.negative_prompt)
|
|
||||||
style_preset = StylePresetWithoutId(name=imported.name, preset_data=preset_data, type=PresetType.User)
|
|
||||||
style_presets.append(style_preset)
|
|
||||||
except pydantic.ValidationError as e:
|
|
||||||
if file.content_type == "text/csv":
|
|
||||||
msg = "Invalid CSV format: must include columns 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
|
|
||||||
else: # file.content_type == "application/json":
|
|
||||||
msg = "Invalid JSON format: must be a list of objects with keys 'name', 'prompt', and 'negative_prompt' and name cannot be blank"
|
|
||||||
raise InvalidPresetImportDataError(msg) from e
|
|
||||||
finally:
|
|
||||||
file.file.close()
|
|
||||||
|
|
||||||
return style_presets
|
|
||||||
@@ -1,215 +0,0 @@
|
|||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
|
||||||
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
|
|
||||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
|
||||||
PresetType,
|
|
||||||
StylePresetChanges,
|
|
||||||
StylePresetNotFoundError,
|
|
||||||
StylePresetRecordDTO,
|
|
||||||
StylePresetWithoutId,
|
|
||||||
)
|
|
||||||
from invokeai.app.util.misc import uuid_string
|
|
||||||
|
|
||||||
|
|
||||||
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
|
||||||
def __init__(self, db: SqliteDatabase) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self._lock = db.lock
|
|
||||||
self._conn = db.conn
|
|
||||||
self._cursor = self._conn.cursor()
|
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
|
||||||
self._invoker = invoker
|
|
||||||
self._sync_default_style_presets()
|
|
||||||
|
|
||||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
|
||||||
"""Gets a style preset by ID."""
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
SELECT *
|
|
||||||
FROM style_presets
|
|
||||||
WHERE id = ?;
|
|
||||||
""",
|
|
||||||
(style_preset_id,),
|
|
||||||
)
|
|
||||||
row = self._cursor.fetchone()
|
|
||||||
if row is None:
|
|
||||||
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
|
|
||||||
return StylePresetRecordDTO.from_dict(dict(row))
|
|
||||||
except Exception:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
|
||||||
style_preset_id = uuid_string()
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
INSERT OR IGNORE INTO style_presets (
|
|
||||||
id,
|
|
||||||
name,
|
|
||||||
preset_data,
|
|
||||||
type
|
|
||||||
)
|
|
||||||
VALUES (?, ?, ?, ?);
|
|
||||||
""",
|
|
||||||
(
|
|
||||||
style_preset_id,
|
|
||||||
style_preset.name,
|
|
||||||
style_preset.preset_data.model_dump_json(),
|
|
||||||
style_preset.type,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self._conn.commit()
|
|
||||||
except Exception:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
return self.get(style_preset_id)
|
|
||||||
|
|
||||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
|
||||||
style_preset_ids = []
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
for style_preset in style_presets:
|
|
||||||
style_preset_id = uuid_string()
|
|
||||||
style_preset_ids.append(style_preset_id)
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
INSERT OR IGNORE INTO style_presets (
|
|
||||||
id,
|
|
||||||
name,
|
|
||||||
preset_data,
|
|
||||||
type
|
|
||||||
)
|
|
||||||
VALUES (?, ?, ?, ?);
|
|
||||||
""",
|
|
||||||
(
|
|
||||||
style_preset_id,
|
|
||||||
style_preset.name,
|
|
||||||
style_preset.preset_data.model_dump_json(),
|
|
||||||
style_preset.type,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self._conn.commit()
|
|
||||||
except Exception:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
# Change the name of a style preset
|
|
||||||
if changes.name is not None:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
UPDATE style_presets
|
|
||||||
SET name = ?
|
|
||||||
WHERE id = ?;
|
|
||||||
""",
|
|
||||||
(changes.name, style_preset_id),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Change the preset data for a style preset
|
|
||||||
if changes.preset_data is not None:
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
UPDATE style_presets
|
|
||||||
SET preset_data = ?
|
|
||||||
WHERE id = ?;
|
|
||||||
""",
|
|
||||||
(changes.preset_data.model_dump_json(), style_preset_id),
|
|
||||||
)
|
|
||||||
|
|
||||||
self._conn.commit()
|
|
||||||
except Exception:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
return self.get(style_preset_id)
|
|
||||||
|
|
||||||
def delete(self, style_preset_id: str) -> None:
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
DELETE from style_presets
|
|
||||||
WHERE id = ?;
|
|
||||||
""",
|
|
||||||
(style_preset_id,),
|
|
||||||
)
|
|
||||||
self._conn.commit()
|
|
||||||
except Exception:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
main_query = """
|
|
||||||
SELECT
|
|
||||||
*
|
|
||||||
FROM style_presets
|
|
||||||
"""
|
|
||||||
|
|
||||||
if type is not None:
|
|
||||||
main_query += "WHERE type = ? "
|
|
||||||
|
|
||||||
main_query += "ORDER BY LOWER(name) ASC"
|
|
||||||
|
|
||||||
if type is not None:
|
|
||||||
self._cursor.execute(main_query, (type,))
|
|
||||||
else:
|
|
||||||
self._cursor.execute(main_query)
|
|
||||||
|
|
||||||
rows = self._cursor.fetchall()
|
|
||||||
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
|
|
||||||
|
|
||||||
return style_presets
|
|
||||||
except Exception:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
|
|
||||||
def _sync_default_style_presets(self) -> None:
|
|
||||||
"""Syncs default style presets to the database. Internal use only."""
|
|
||||||
|
|
||||||
# First delete all existing default style presets
|
|
||||||
try:
|
|
||||||
self._lock.acquire()
|
|
||||||
self._cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
DELETE FROM style_presets
|
|
||||||
WHERE type = "default";
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
self._conn.commit()
|
|
||||||
except Exception:
|
|
||||||
self._conn.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self._lock.release()
|
|
||||||
# Next, parse and create the default style presets
|
|
||||||
with self._lock, open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
|
|
||||||
presets = json.load(file)
|
|
||||||
for preset in presets:
|
|
||||||
style_preset = StylePresetWithoutId.model_validate(preset)
|
|
||||||
self.create(style_preset)
|
|
||||||
@@ -13,8 +13,3 @@ class UrlServiceBase(ABC):
|
|||||||
def get_model_image_url(self, model_key: str) -> str:
|
def get_model_image_url(self, model_key: str) -> str:
|
||||||
"""Gets the URL for a model image"""
|
"""Gets the URL for a model image"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_style_preset_image_url(self, style_preset_id: str) -> str:
|
|
||||||
"""Gets the URL for a style preset image"""
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -19,6 +19,3 @@ class LocalUrlService(UrlServiceBase):
|
|||||||
|
|
||||||
def get_model_image_url(self, model_key: str) -> str:
|
def get_model_image_url(self, model_key: str) -> str:
|
||||||
return f"{self._base_url_v2}/models/i/{model_key}/image"
|
return f"{self._base_url_v2}/models/i/{model_key}/image"
|
||||||
|
|
||||||
def get_style_preset_image_url(self, style_preset_id: str) -> str:
|
|
||||||
return f"{self._base_url}/style_presets/i/{style_preset_id}/image"
|
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ def get_openapi_func(
|
|||||||
# Add the output map to the schema
|
# Add the output map to the schema
|
||||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": dict(sorted(invocation_output_map_properties.items())),
|
"properties": invocation_output_map_properties,
|
||||||
"required": invocation_output_map_required,
|
"required": invocation_output_map_required,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
90
invokeai/backend/image_util/depth_anything/__init__.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import repeat
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision.transforms import Compose
|
||||||
|
|
||||||
|
from invokeai.app.services.config.config_default import get_config
|
||||||
|
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||||
|
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
|
config = get_config()
|
||||||
|
logger = InvokeAILogger.get_logger(config=config)
|
||||||
|
|
||||||
|
DEPTH_ANYTHING_MODELS = {
|
||||||
|
"large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
|
||||||
|
"base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
|
||||||
|
"small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
transform = Compose(
|
||||||
|
[
|
||||||
|
Resize(
|
||||||
|
width=518,
|
||||||
|
height=518,
|
||||||
|
resize_target=False,
|
||||||
|
keep_aspect_ratio=True,
|
||||||
|
ensure_multiple_of=14,
|
||||||
|
resize_method="lower_bound",
|
||||||
|
image_interpolation_method=cv2.INTER_CUBIC,
|
||||||
|
),
|
||||||
|
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
|
PrepareForNet(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DepthAnythingDetector:
|
||||||
|
def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
|
||||||
|
self.model = model
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_model(
|
||||||
|
model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
|
||||||
|
) -> DPT_DINOv2:
|
||||||
|
match model_size:
|
||||||
|
case "small":
|
||||||
|
model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||||
|
case "base":
|
||||||
|
model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||||
|
case "large":
|
||||||
|
model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||||
|
|
||||||
|
model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
|
||||||
|
if not self.model:
|
||||||
|
logger.warn("DepthAnything model was not loaded. Returning original image")
|
||||||
|
return image
|
||||||
|
|
||||||
|
np_image = np.array(image, dtype=np.uint8)
|
||||||
|
np_image = np_image[:, :, ::-1] / 255.0
|
||||||
|
|
||||||
|
image_height, image_width = np_image.shape[:2]
|
||||||
|
np_image = transform({"image": np_image})["image"]
|
||||||
|
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
depth = self.model(tensor_image)
|
||||||
|
depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
|
||||||
|
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
|
||||||
|
|
||||||
|
depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
|
||||||
|
depth_map = Image.fromarray(depth_map)
|
||||||
|
|
||||||
|
new_height = int(image_height * (resolution / image_width))
|
||||||
|
depth_map = depth_map.resize((resolution, new_height))
|
||||||
|
|
||||||
|
return depth_map
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from transformers.pipelines import DepthEstimationPipeline
|
|
||||||
|
|
||||||
from invokeai.backend.raw_model import RawModel
|
|
||||||
|
|
||||||
|
|
||||||
class DepthAnythingPipeline(RawModel):
|
|
||||||
"""Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
|
|
||||||
for Invoke's Model Management System"""
|
|
||||||
|
|
||||||
def __init__(self, pipeline: DepthEstimationPipeline) -> None:
|
|
||||||
self._pipeline = pipeline
|
|
||||||
|
|
||||||
def generate_depth(self, image: Image.Image) -> Image.Image:
|
|
||||||
depth_map = self._pipeline(image)["depth"]
|
|
||||||
assert isinstance(depth_map, Image.Image)
|
|
||||||
return depth_map
|
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
|
||||||
if device is not None and device.type not in {"cpu", "cuda"}:
|
|
||||||
device = None
|
|
||||||
self._pipeline.model.to(device=device, dtype=dtype)
|
|
||||||
self._pipeline.device = self._pipeline.model.device
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
|
||||||
|
|
||||||
return calc_module_size(self._pipeline.model)
|
|
||||||
145
invokeai/backend/image_util/depth_anything/model/blocks.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
||||||
|
scratch = nn.Module()
|
||||||
|
|
||||||
|
out_shape1 = out_shape
|
||||||
|
out_shape2 = out_shape
|
||||||
|
out_shape3 = out_shape
|
||||||
|
if len(in_shape) >= 4:
|
||||||
|
out_shape4 = out_shape
|
||||||
|
|
||||||
|
if expand:
|
||||||
|
out_shape1 = out_shape
|
||||||
|
out_shape2 = out_shape * 2
|
||||||
|
out_shape3 = out_shape * 4
|
||||||
|
if len(in_shape) >= 4:
|
||||||
|
out_shape4 = out_shape * 8
|
||||||
|
|
||||||
|
scratch.layer1_rn = nn.Conv2d(
|
||||||
|
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||||
|
)
|
||||||
|
scratch.layer2_rn = nn.Conv2d(
|
||||||
|
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||||
|
)
|
||||||
|
scratch.layer3_rn = nn.Conv2d(
|
||||||
|
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||||
|
)
|
||||||
|
if len(in_shape) >= 4:
|
||||||
|
scratch.layer4_rn = nn.Conv2d(
|
||||||
|
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
||||||
|
)
|
||||||
|
|
||||||
|
return scratch
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualConvUnit(nn.Module):
|
||||||
|
"""Residual convolution module."""
|
||||||
|
|
||||||
|
def __init__(self, features, activation, bn):
|
||||||
|
"""Init.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features (int): number of features
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.bn = bn
|
||||||
|
|
||||||
|
self.groups = 1
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
||||||
|
|
||||||
|
if self.bn:
|
||||||
|
self.bn1 = nn.BatchNorm2d(features)
|
||||||
|
self.bn2 = nn.BatchNorm2d(features)
|
||||||
|
|
||||||
|
self.activation = activation
|
||||||
|
|
||||||
|
self.skip_add = nn.quantized.FloatFunctional()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (tensor): input
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tensor: output
|
||||||
|
"""
|
||||||
|
|
||||||
|
out = self.activation(x)
|
||||||
|
out = self.conv1(out)
|
||||||
|
if self.bn:
|
||||||
|
out = self.bn1(out)
|
||||||
|
|
||||||
|
out = self.activation(out)
|
||||||
|
out = self.conv2(out)
|
||||||
|
if self.bn:
|
||||||
|
out = self.bn2(out)
|
||||||
|
|
||||||
|
if self.groups > 1:
|
||||||
|
out = self.conv_merge(out)
|
||||||
|
|
||||||
|
return self.skip_add.add(out, x)
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureFusionBlock(nn.Module):
|
||||||
|
"""Feature fusion block."""
|
||||||
|
|
||||||
|
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
|
||||||
|
"""Init.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features (int): number of features
|
||||||
|
"""
|
||||||
|
super(FeatureFusionBlock, self).__init__()
|
||||||
|
|
||||||
|
self.deconv = deconv
|
||||||
|
self.align_corners = align_corners
|
||||||
|
|
||||||
|
self.groups = 1
|
||||||
|
|
||||||
|
self.expand = expand
|
||||||
|
out_features = features
|
||||||
|
if self.expand:
|
||||||
|
out_features = features // 2
|
||||||
|
|
||||||
|
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
||||||
|
|
||||||
|
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
|
||||||
|
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
|
||||||
|
|
||||||
|
self.skip_add = nn.quantized.FloatFunctional()
|
||||||
|
|
||||||
|
self.size = size
|
||||||
|
|
||||||
|
def forward(self, *xs, size=None):
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tensor: output
|
||||||
|
"""
|
||||||
|
output = xs[0]
|
||||||
|
|
||||||
|
if len(xs) == 2:
|
||||||
|
res = self.resConfUnit1(xs[1])
|
||||||
|
output = self.skip_add.add(output, res)
|
||||||
|
|
||||||
|
output = self.resConfUnit2(output)
|
||||||
|
|
||||||
|
if (size is None) and (self.size is None):
|
||||||
|
modifier = {"scale_factor": 2}
|
||||||
|
elif size is None:
|
||||||
|
modifier = {"size": self.size}
|
||||||
|
else:
|
||||||
|
modifier = {"size": size}
|
||||||
|
|
||||||
|
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
||||||
|
|
||||||
|
output = self.out_conv(output)
|
||||||
|
|
||||||
|
return output
|
||||||
183
invokeai/backend/image_util/depth_anything/model/dpt.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from invokeai.backend.image_util.depth_anything.model.blocks import FeatureFusionBlock, _make_scratch
|
||||||
|
|
||||||
|
torchhub_path = Path(__file__).parent.parent / "torchhub"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_fusion_block(features, use_bn, size=None):
|
||||||
|
return FeatureFusionBlock(
|
||||||
|
features,
|
||||||
|
nn.ReLU(False),
|
||||||
|
deconv=False,
|
||||||
|
bn=use_bn,
|
||||||
|
expand=False,
|
||||||
|
align_corners=True,
|
||||||
|
size=size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DPTHead(nn.Module):
|
||||||
|
def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
|
||||||
|
super(DPTHead, self).__init__()
|
||||||
|
|
||||||
|
self.nclass = nclass
|
||||||
|
self.use_clstoken = use_clstoken
|
||||||
|
|
||||||
|
self.projects = nn.ModuleList(
|
||||||
|
[
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channel,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
)
|
||||||
|
for out_channel in out_channels
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.resize_layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
nn.ConvTranspose2d(
|
||||||
|
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
|
||||||
|
),
|
||||||
|
nn.ConvTranspose2d(
|
||||||
|
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
|
||||||
|
),
|
||||||
|
nn.Identity(),
|
||||||
|
nn.Conv2d(
|
||||||
|
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_clstoken:
|
||||||
|
self.readout_projects = nn.ModuleList()
|
||||||
|
for _ in range(len(self.projects)):
|
||||||
|
self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))
|
||||||
|
|
||||||
|
self.scratch = _make_scratch(
|
||||||
|
out_channels,
|
||||||
|
features,
|
||||||
|
groups=1,
|
||||||
|
expand=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scratch.stem_transpose = None
|
||||||
|
|
||||||
|
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
||||||
|
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
||||||
|
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
||||||
|
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
||||||
|
|
||||||
|
head_features_1 = features
|
||||||
|
head_features_2 = 32
|
||||||
|
|
||||||
|
if nclass > 1:
|
||||||
|
self.scratch.output_conv = nn.Sequential(
|
||||||
|
nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1),
|
||||||
|
nn.ReLU(True),
|
||||||
|
nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.scratch.output_conv1 = nn.Conv2d(
|
||||||
|
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.scratch.output_conv2 = nn.Sequential(
|
||||||
|
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
||||||
|
nn.ReLU(True),
|
||||||
|
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
||||||
|
nn.ReLU(True),
|
||||||
|
nn.Identity(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, out_features, patch_h, patch_w):
|
||||||
|
out = []
|
||||||
|
for i, x in enumerate(out_features):
|
||||||
|
if self.use_clstoken:
|
||||||
|
x, cls_token = x[0], x[1]
|
||||||
|
readout = cls_token.unsqueeze(1).expand_as(x)
|
||||||
|
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
||||||
|
else:
|
||||||
|
x = x[0]
|
||||||
|
|
||||||
|
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
||||||
|
|
||||||
|
x = self.projects[i](x)
|
||||||
|
x = self.resize_layers[i](x)
|
||||||
|
|
||||||
|
out.append(x)
|
||||||
|
|
||||||
|
layer_1, layer_2, layer_3, layer_4 = out
|
||||||
|
|
||||||
|
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
||||||
|
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
||||||
|
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
||||||
|
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
||||||
|
|
||||||
|
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
||||||
|
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
||||||
|
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
||||||
|
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
||||||
|
|
||||||
|
out = self.scratch.output_conv1(path_1)
|
||||||
|
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
|
||||||
|
out = self.scratch.output_conv2(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class DPT_DINOv2(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
features,
|
||||||
|
out_channels,
|
||||||
|
encoder="vitl",
|
||||||
|
use_bn=False,
|
||||||
|
use_clstoken=False,
|
||||||
|
):
|
||||||
|
super(DPT_DINOv2, self).__init__()
|
||||||
|
|
||||||
|
assert encoder in ["vits", "vitb", "vitl"]
|
||||||
|
|
||||||
|
# # in case the Internet connection is not stable, please load the DINOv2 locally
|
||||||
|
# if use_local:
|
||||||
|
# self.pretrained = torch.hub.load(
|
||||||
|
# torchhub_path / "facebookresearch_dinov2_main",
|
||||||
|
# "dinov2_{:}14".format(encoder),
|
||||||
|
# source="local",
|
||||||
|
# pretrained=False,
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
# self.pretrained = torch.hub.load(
|
||||||
|
# "facebookresearch/dinov2",
|
||||||
|
# "dinov2_{:}14".format(encoder),
|
||||||
|
# )
|
||||||
|
|
||||||
|
self.pretrained = torch.hub.load(
|
||||||
|
"facebookresearch/dinov2",
|
||||||
|
"dinov2_{:}14".format(encoder),
|
||||||
|
)
|
||||||
|
|
||||||
|
dim = self.pretrained.blocks[0].attn.qkv.in_features
|
||||||
|
|
||||||
|
self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h, w = x.shape[-2:]
|
||||||
|
|
||||||
|
features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True)
|
||||||
|
|
||||||
|
patch_h, patch_w = h // 14, w // 14
|
||||||
|
|
||||||
|
depth = self.depth_head(features, patch_h, patch_w)
|
||||||
|
depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True)
|
||||||
|
depth = F.relu(depth)
|
||||||
|
|
||||||
|
return depth.squeeze(1)
|
||||||
227
invokeai/backend/image_util/depth_anything/utilities/util.py
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
import math
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
||||||
|
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample (dict): sample
|
||||||
|
size (tuple): image size
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: new size
|
||||||
|
"""
|
||||||
|
shape = list(sample["disparity"].shape)
|
||||||
|
|
||||||
|
if shape[0] >= size[0] and shape[1] >= size[1]:
|
||||||
|
return sample
|
||||||
|
|
||||||
|
scale = [0, 0]
|
||||||
|
scale[0] = size[0] / shape[0]
|
||||||
|
scale[1] = size[1] / shape[1]
|
||||||
|
|
||||||
|
scale = max(scale)
|
||||||
|
|
||||||
|
shape[0] = math.ceil(scale * shape[0])
|
||||||
|
shape[1] = math.ceil(scale * shape[1])
|
||||||
|
|
||||||
|
# resize
|
||||||
|
sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method)
|
||||||
|
|
||||||
|
sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST)
|
||||||
|
sample["mask"] = cv2.resize(
|
||||||
|
sample["mask"].astype(np.float32),
|
||||||
|
tuple(shape[::-1]),
|
||||||
|
interpolation=cv2.INTER_NEAREST,
|
||||||
|
)
|
||||||
|
sample["mask"] = sample["mask"].astype(bool)
|
||||||
|
|
||||||
|
return tuple(shape)
|
||||||
|
|
||||||
|
|
||||||
|
class Resize(object):
|
||||||
|
"""Resize sample to given size (width, height)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
resize_target=True,
|
||||||
|
keep_aspect_ratio=False,
|
||||||
|
ensure_multiple_of=1,
|
||||||
|
resize_method="lower_bound",
|
||||||
|
image_interpolation_method=cv2.INTER_AREA,
|
||||||
|
):
|
||||||
|
"""Init.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
width (int): desired output width
|
||||||
|
height (int): desired output height
|
||||||
|
resize_target (bool, optional):
|
||||||
|
True: Resize the full sample (image, mask, target).
|
||||||
|
False: Resize image only.
|
||||||
|
Defaults to True.
|
||||||
|
keep_aspect_ratio (bool, optional):
|
||||||
|
True: Keep the aspect ratio of the input sample.
|
||||||
|
Output sample might not have the given width and height, and
|
||||||
|
resize behaviour depends on the parameter 'resize_method'.
|
||||||
|
Defaults to False.
|
||||||
|
ensure_multiple_of (int, optional):
|
||||||
|
Output width and height is constrained to be multiple of this parameter.
|
||||||
|
Defaults to 1.
|
||||||
|
resize_method (str, optional):
|
||||||
|
"lower_bound": Output will be at least as large as the given size.
|
||||||
|
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller
|
||||||
|
than given size.)
|
||||||
|
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
||||||
|
Defaults to "lower_bound".
|
||||||
|
"""
|
||||||
|
self.__width = width
|
||||||
|
self.__height = height
|
||||||
|
|
||||||
|
self.__resize_target = resize_target
|
||||||
|
self.__keep_aspect_ratio = keep_aspect_ratio
|
||||||
|
self.__multiple_of = ensure_multiple_of
|
||||||
|
self.__resize_method = resize_method
|
||||||
|
self.__image_interpolation_method = image_interpolation_method
|
||||||
|
|
||||||
|
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
||||||
|
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||||
|
|
||||||
|
if max_val is not None and y > max_val:
|
||||||
|
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||||
|
|
||||||
|
if y < min_val:
|
||||||
|
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
def get_size(self, width, height):
|
||||||
|
# determine new height and width
|
||||||
|
scale_height = self.__height / height
|
||||||
|
scale_width = self.__width / width
|
||||||
|
|
||||||
|
if self.__keep_aspect_ratio:
|
||||||
|
if self.__resize_method == "lower_bound":
|
||||||
|
# scale such that output size is lower bound
|
||||||
|
if scale_width > scale_height:
|
||||||
|
# fit width
|
||||||
|
scale_height = scale_width
|
||||||
|
else:
|
||||||
|
# fit height
|
||||||
|
scale_width = scale_height
|
||||||
|
elif self.__resize_method == "upper_bound":
|
||||||
|
# scale such that output size is upper bound
|
||||||
|
if scale_width < scale_height:
|
||||||
|
# fit width
|
||||||
|
scale_height = scale_width
|
||||||
|
else:
|
||||||
|
# fit height
|
||||||
|
scale_width = scale_height
|
||||||
|
elif self.__resize_method == "minimal":
|
||||||
|
# scale as least as possbile
|
||||||
|
if abs(1 - scale_width) < abs(1 - scale_height):
|
||||||
|
# fit width
|
||||||
|
scale_height = scale_width
|
||||||
|
else:
|
||||||
|
# fit height
|
||||||
|
scale_width = scale_height
|
||||||
|
else:
|
||||||
|
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
||||||
|
|
||||||
|
if self.__resize_method == "lower_bound":
|
||||||
|
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
|
||||||
|
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
|
||||||
|
elif self.__resize_method == "upper_bound":
|
||||||
|
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
|
||||||
|
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
|
||||||
|
elif self.__resize_method == "minimal":
|
||||||
|
new_height = self.constrain_to_multiple_of(scale_height * height)
|
||||||
|
new_width = self.constrain_to_multiple_of(scale_width * width)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
||||||
|
|
||||||
|
return (new_width, new_height)
|
||||||
|
|
||||||
|
def __call__(self, sample):
|
||||||
|
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
|
||||||
|
|
||||||
|
# resize sample
|
||||||
|
sample["image"] = cv2.resize(
|
||||||
|
sample["image"],
|
||||||
|
(width, height),
|
||||||
|
interpolation=self.__image_interpolation_method,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.__resize_target:
|
||||||
|
if "disparity" in sample:
|
||||||
|
sample["disparity"] = cv2.resize(
|
||||||
|
sample["disparity"],
|
||||||
|
(width, height),
|
||||||
|
interpolation=cv2.INTER_NEAREST,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "depth" in sample:
|
||||||
|
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
|
||||||
|
|
||||||
|
if "semseg_mask" in sample:
|
||||||
|
# sample["semseg_mask"] = cv2.resize(
|
||||||
|
# sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
|
||||||
|
# )
|
||||||
|
sample["semseg_mask"] = F.interpolate(
|
||||||
|
torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode="nearest"
|
||||||
|
).numpy()[0, 0]
|
||||||
|
|
||||||
|
if "mask" in sample:
|
||||||
|
sample["mask"] = cv2.resize(
|
||||||
|
sample["mask"].astype(np.float32),
|
||||||
|
(width, height),
|
||||||
|
interpolation=cv2.INTER_NEAREST,
|
||||||
|
)
|
||||||
|
# sample["mask"] = sample["mask"].astype(bool)
|
||||||
|
|
||||||
|
# print(sample['image'].shape, sample['depth'].shape)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizeImage(object):
|
||||||
|
"""Normlize image by given mean and std."""
|
||||||
|
|
||||||
|
def __init__(self, mean, std):
|
||||||
|
self.__mean = mean
|
||||||
|
self.__std = std
|
||||||
|
|
||||||
|
def __call__(self, sample):
|
||||||
|
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class PrepareForNet(object):
|
||||||
|
"""Prepare sample for usage as network input."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __call__(self, sample):
|
||||||
|
image = np.transpose(sample["image"], (2, 0, 1))
|
||||||
|
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
||||||
|
|
||||||
|
if "mask" in sample:
|
||||||
|
sample["mask"] = sample["mask"].astype(np.float32)
|
||||||
|
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
||||||
|
|
||||||
|
if "depth" in sample:
|
||||||
|
depth = sample["depth"].astype(np.float32)
|
||||||
|
sample["depth"] = np.ascontiguousarray(depth)
|
||||||
|
|
||||||
|
if "semseg_mask" in sample:
|
||||||
|
sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
|
||||||
|
sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
|
||||||
|
|
||||||
|
return sample
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
|
|
||||||
|
|
||||||
class BoundingBox(BaseModel):
|
|
||||||
"""Bounding box helper class."""
|
|
||||||
|
|
||||||
xmin: int
|
|
||||||
ymin: int
|
|
||||||
xmax: int
|
|
||||||
ymax: int
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionResult(BaseModel):
|
|
||||||
"""Detection result from Grounding DINO."""
|
|
||||||
|
|
||||||
score: float
|
|
||||||
label: str
|
|
||||||
box: BoundingBox
|
|
||||||
model_config = ConfigDict(
|
|
||||||
# Allow arbitrary types for mask, since it will be a numpy array.
|
|
||||||
arbitrary_types_allowed=True
|
|
||||||
)
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from transformers.pipelines import ZeroShotObjectDetectionPipeline
|
|
||||||
|
|
||||||
from invokeai.backend.image_util.grounding_dino.detection_result import DetectionResult
|
|
||||||
from invokeai.backend.raw_model import RawModel
|
|
||||||
|
|
||||||
|
|
||||||
class GroundingDinoPipeline(RawModel):
|
|
||||||
"""A wrapper class for a ZeroShotObjectDetectionPipeline that makes it compatible with the model manager's memory
|
|
||||||
management system.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, pipeline: ZeroShotObjectDetectionPipeline):
|
|
||||||
self._pipeline = pipeline
|
|
||||||
|
|
||||||
def detect(self, image: Image.Image, candidate_labels: list[str], threshold: float = 0.1) -> list[DetectionResult]:
|
|
||||||
results = self._pipeline(image=image, candidate_labels=candidate_labels, threshold=threshold)
|
|
||||||
assert results is not None
|
|
||||||
results = [DetectionResult.model_validate(result) for result in results]
|
|
||||||
return results
|
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
|
||||||
# HACK(ryand): The GroundingDinoPipeline does not work on MPS devices. We only allow it to be moved to CPU or
|
|
||||||
# CUDA.
|
|
||||||
if device is not None and device.type not in {"cpu", "cuda"}:
|
|
||||||
device = None
|
|
||||||
self._pipeline.model.to(device=device, dtype=dtype)
|
|
||||||
self._pipeline.device = self._pipeline.model.device
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
# HACK(ryand): Fix the circular import issue.
|
|
||||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
|
||||||
|
|
||||||
return calc_module_size(self._pipeline.model)
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
# This file contains utilities for Grounded-SAM mask refinement based on:
|
|
||||||
# https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
|
|
||||||
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import numpy.typing as npt
|
|
||||||
|
|
||||||
|
|
||||||
def mask_to_polygon(mask: npt.NDArray[np.uint8]) -> list[tuple[int, int]]:
|
|
||||||
"""Convert a binary mask to a polygon.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[list[int]]: List of (x, y) coordinates representing the vertices of the polygon.
|
|
||||||
"""
|
|
||||||
# Find contours in the binary mask.
|
|
||||||
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
||||||
|
|
||||||
# Find the contour with the largest area.
|
|
||||||
largest_contour = max(contours, key=cv2.contourArea)
|
|
||||||
|
|
||||||
# Extract the vertices of the contour.
|
|
||||||
polygon = largest_contour.reshape(-1, 2).tolist()
|
|
||||||
|
|
||||||
return polygon
|
|
||||||
|
|
||||||
|
|
||||||
def polygon_to_mask(
|
|
||||||
polygon: list[tuple[int, int]], image_shape: tuple[int, int], fill_value: int = 1
|
|
||||||
) -> npt.NDArray[np.uint8]:
|
|
||||||
"""Convert a polygon to a segmentation mask.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
|
|
||||||
image_shape (tuple): Shape of the image (height, width) for the mask.
|
|
||||||
fill_value (int): Value to fill the polygon with.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
np.ndarray: Segmentation mask with the polygon filled (with value 255).
|
|
||||||
"""
|
|
||||||
# Create an empty mask.
|
|
||||||
mask = np.zeros(image_shape, dtype=np.uint8)
|
|
||||||
|
|
||||||
# Convert polygon to an array of points.
|
|
||||||
pts = np.array(polygon, dtype=np.int32)
|
|
||||||
|
|
||||||
# Fill the polygon with white color (255).
|
|
||||||
cv2.fillPoly(mask, [pts], color=(fill_value,))
|
|
||||||
|
|
||||||
return mask
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from transformers.models.sam import SamModel
|
|
||||||
from transformers.models.sam.processing_sam import SamProcessor
|
|
||||||
|
|
||||||
from invokeai.backend.raw_model import RawModel
|
|
||||||
|
|
||||||
|
|
||||||
class SegmentAnythingPipeline(RawModel):
|
|
||||||
"""A wrapper class for the transformers SAM model and processor that makes it compatible with the model manager."""
|
|
||||||
|
|
||||||
def __init__(self, sam_model: SamModel, sam_processor: SamProcessor):
|
|
||||||
self._sam_model = sam_model
|
|
||||||
self._sam_processor = sam_processor
|
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
|
||||||
# HACK(ryand): The SAM pipeline does not work on MPS devices. We only allow it to be moved to CPU or CUDA.
|
|
||||||
if device is not None and device.type not in {"cpu", "cuda"}:
|
|
||||||
device = None
|
|
||||||
self._sam_model.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
# HACK(ryand): Fix the circular import issue.
|
|
||||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
|
||||||
|
|
||||||
return calc_module_size(self._sam_model)
|
|
||||||
|
|
||||||
def segment(self, image: Image.Image, bounding_boxes: list[list[int]]) -> torch.Tensor:
|
|
||||||
"""Run the SAM model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (Image.Image): The image to segment.
|
|
||||||
bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
|
|
||||||
[xmin, ymin, xmax, ymax].
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
|
|
||||||
"""
|
|
||||||
# Add batch dimension of 1 to the bounding boxes.
|
|
||||||
boxes = [bounding_boxes]
|
|
||||||
inputs = self._sam_processor(images=image, input_boxes=boxes, return_tensors="pt").to(self._sam_model.device)
|
|
||||||
outputs = self._sam_model(**inputs)
|
|
||||||
masks = self._sam_processor.post_process_masks(
|
|
||||||
masks=outputs.pred_masks,
|
|
||||||
original_sizes=inputs.original_sizes,
|
|
||||||
reshaped_input_sizes=inputs.reshaped_input_sizes,
|
|
||||||
)
|
|
||||||
|
|
||||||
# There should be only one batch.
|
|
||||||
assert len(masks) == 1
|
|
||||||
return masks[0]
|
|
||||||
@@ -3,13 +3,12 @@
|
|||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.model_manager import BaseModelType
|
from invokeai.backend.model_manager import BaseModelType
|
||||||
from invokeai.backend.raw_model import RawModel
|
from invokeai.backend.raw_model import RawModel
|
||||||
|
|
||||||
@@ -47,19 +46,9 @@ class LoRALayerBase:
|
|||||||
self.rank = None # set in layer implementation
|
self.rank = None # set in layer implementation
|
||||||
self.layer_key = layer_key
|
self.layer_key = layer_key
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
|
|
||||||
return self.bias
|
|
||||||
|
|
||||||
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
|
||||||
params = {"weight": self.get_weight(orig_module.weight)}
|
|
||||||
bias = self.get_bias(orig_module.bias)
|
|
||||||
if bias is not None:
|
|
||||||
params["bias"] = bias
|
|
||||||
return params
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
def calc_size(self) -> int:
|
||||||
model_size = 0
|
model_size = 0
|
||||||
for val in [self.bias]:
|
for val in [self.bias]:
|
||||||
@@ -71,17 +60,6 @@ class LoRALayerBase:
|
|||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
|
|
||||||
"""Log a warning if values contains unhandled keys."""
|
|
||||||
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
|
|
||||||
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
|
|
||||||
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
|
|
||||||
unknown_keys = set(values.keys()) - all_known_keys
|
|
||||||
if unknown_keys:
|
|
||||||
logger.warning(
|
|
||||||
f"Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: find and debug lora/locon with bias
|
# TODO: find and debug lora/locon with bias
|
||||||
class LoRALayer(LoRALayerBase):
|
class LoRALayer(LoRALayerBase):
|
||||||
@@ -98,19 +76,14 @@ class LoRALayer(LoRALayerBase):
|
|||||||
|
|
||||||
self.up = values["lora_up.weight"]
|
self.up = values["lora_up.weight"]
|
||||||
self.down = values["lora_down.weight"]
|
self.down = values["lora_down.weight"]
|
||||||
self.mid = values.get("lora_mid.weight", None)
|
if "lora_mid.weight" in values:
|
||||||
|
self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
|
||||||
|
else:
|
||||||
|
self.mid = None
|
||||||
|
|
||||||
self.rank = self.down.shape[0]
|
self.rank = self.down.shape[0]
|
||||||
self.check_keys(
|
|
||||||
values,
|
|
||||||
{
|
|
||||||
"lora_up.weight",
|
|
||||||
"lora_down.weight",
|
|
||||||
"lora_mid.weight",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
if self.mid is not None:
|
if self.mid is not None:
|
||||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||||
@@ -152,23 +125,20 @@ class LoHALayer(LoRALayerBase):
|
|||||||
self.w1_b = values["hada_w1_b"]
|
self.w1_b = values["hada_w1_b"]
|
||||||
self.w2_a = values["hada_w2_a"]
|
self.w2_a = values["hada_w2_a"]
|
||||||
self.w2_b = values["hada_w2_b"]
|
self.w2_b = values["hada_w2_b"]
|
||||||
self.t1 = values.get("hada_t1", None)
|
|
||||||
self.t2 = values.get("hada_t2", None)
|
if "hada_t1" in values:
|
||||||
|
self.t1: Optional[torch.Tensor] = values["hada_t1"]
|
||||||
|
else:
|
||||||
|
self.t1 = None
|
||||||
|
|
||||||
|
if "hada_t2" in values:
|
||||||
|
self.t2: Optional[torch.Tensor] = values["hada_t2"]
|
||||||
|
else:
|
||||||
|
self.t2 = None
|
||||||
|
|
||||||
self.rank = self.w1_b.shape[0]
|
self.rank = self.w1_b.shape[0]
|
||||||
self.check_keys(
|
|
||||||
values,
|
|
||||||
{
|
|
||||||
"hada_w1_a",
|
|
||||||
"hada_w1_b",
|
|
||||||
"hada_w2_a",
|
|
||||||
"hada_w2_b",
|
|
||||||
"hada_t1",
|
|
||||||
"hada_t2",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
if self.t1 is None:
|
if self.t1 is None:
|
||||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||||
|
|
||||||
@@ -216,45 +186,37 @@ class LoKRLayer(LoRALayerBase):
|
|||||||
):
|
):
|
||||||
super().__init__(layer_key, values)
|
super().__init__(layer_key, values)
|
||||||
|
|
||||||
self.w1 = values.get("lokr_w1", None)
|
if "lokr_w1" in values:
|
||||||
if self.w1 is None:
|
self.w1: Optional[torch.Tensor] = values["lokr_w1"]
|
||||||
|
self.w1_a = None
|
||||||
|
self.w1_b = None
|
||||||
|
else:
|
||||||
|
self.w1 = None
|
||||||
self.w1_a = values["lokr_w1_a"]
|
self.w1_a = values["lokr_w1_a"]
|
||||||
self.w1_b = values["lokr_w1_b"]
|
self.w1_b = values["lokr_w1_b"]
|
||||||
else:
|
|
||||||
self.w1_b = None
|
|
||||||
self.w1_a = None
|
|
||||||
|
|
||||||
self.w2 = values.get("lokr_w2", None)
|
if "lokr_w2" in values:
|
||||||
if self.w2 is None:
|
self.w2: Optional[torch.Tensor] = values["lokr_w2"]
|
||||||
self.w2_a = values["lokr_w2_a"]
|
|
||||||
self.w2_b = values["lokr_w2_b"]
|
|
||||||
else:
|
|
||||||
self.w2_a = None
|
self.w2_a = None
|
||||||
self.w2_b = None
|
self.w2_b = None
|
||||||
|
else:
|
||||||
|
self.w2 = None
|
||||||
|
self.w2_a = values["lokr_w2_a"]
|
||||||
|
self.w2_b = values["lokr_w2_b"]
|
||||||
|
|
||||||
self.t2 = values.get("lokr_t2", None)
|
if "lokr_t2" in values:
|
||||||
|
self.t2: Optional[torch.Tensor] = values["lokr_t2"]
|
||||||
|
else:
|
||||||
|
self.t2 = None
|
||||||
|
|
||||||
if self.w1_b is not None:
|
if "lokr_w1_b" in values:
|
||||||
self.rank = self.w1_b.shape[0]
|
self.rank = values["lokr_w1_b"].shape[0]
|
||||||
elif self.w2_b is not None:
|
elif "lokr_w2_b" in values:
|
||||||
self.rank = self.w2_b.shape[0]
|
self.rank = values["lokr_w2_b"].shape[0]
|
||||||
else:
|
else:
|
||||||
self.rank = None # unscaled
|
self.rank = None # unscaled
|
||||||
|
|
||||||
self.check_keys(
|
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
values,
|
|
||||||
{
|
|
||||||
"lokr_w1",
|
|
||||||
"lokr_w1_a",
|
|
||||||
"lokr_w1_b",
|
|
||||||
"lokr_w2",
|
|
||||||
"lokr_w2_a",
|
|
||||||
"lokr_w2_b",
|
|
||||||
"lokr_t2",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
|
||||||
w1: Optional[torch.Tensor] = self.w1
|
w1: Optional[torch.Tensor] = self.w1
|
||||||
if w1 is None:
|
if w1 is None:
|
||||||
assert self.w1_a is not None
|
assert self.w1_a is not None
|
||||||
@@ -310,9 +272,7 @@ class LoKRLayer(LoRALayerBase):
|
|||||||
|
|
||||||
|
|
||||||
class FullLayer(LoRALayerBase):
|
class FullLayer(LoRALayerBase):
|
||||||
# bias handled in LoRALayerBase(calc_size, to)
|
|
||||||
# weight: torch.Tensor
|
# weight: torch.Tensor
|
||||||
# bias: Optional[torch.Tensor]
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -322,12 +282,15 @@ class FullLayer(LoRALayerBase):
|
|||||||
super().__init__(layer_key, values)
|
super().__init__(layer_key, values)
|
||||||
|
|
||||||
self.weight = values["diff"]
|
self.weight = values["diff"]
|
||||||
self.bias = values.get("diff_b", None)
|
|
||||||
|
if len(values.keys()) > 1:
|
||||||
|
_keys = list(values.keys())
|
||||||
|
_keys.remove("diff")
|
||||||
|
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
|
||||||
|
|
||||||
self.rank = None # unscaled
|
self.rank = None # unscaled
|
||||||
self.check_keys(values, {"diff", "diff_b"})
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
return self.weight
|
return self.weight
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
def calc_size(self) -> int:
|
||||||
@@ -356,9 +319,8 @@ class IA3Layer(LoRALayerBase):
|
|||||||
self.on_input = values["on_input"]
|
self.on_input = values["on_input"]
|
||||||
|
|
||||||
self.rank = None # unscaled
|
self.rank = None # unscaled
|
||||||
self.check_keys(values, {"weight", "on_input"})
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
weight = self.weight
|
weight = self.weight
|
||||||
if not self.on_input:
|
if not self.on_input:
|
||||||
weight = weight.reshape(-1, 1)
|
weight = weight.reshape(-1, 1)
|
||||||
@@ -378,39 +340,7 @@ class IA3Layer(LoRALayerBase):
|
|||||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
class NormLayer(LoRALayerBase):
|
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||||
# bias handled in LoRALayerBase(calc_size, to)
|
|
||||||
# weight: torch.Tensor
|
|
||||||
# bias: Optional[torch.Tensor]
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layer_key: str,
|
|
||||||
values: Dict[str, torch.Tensor],
|
|
||||||
):
|
|
||||||
super().__init__(layer_key, values)
|
|
||||||
|
|
||||||
self.weight = values["w_norm"]
|
|
||||||
self.bias = values.get("b_norm", None)
|
|
||||||
|
|
||||||
self.rank = None # unscaled
|
|
||||||
self.check_keys(values, {"w_norm", "b_norm"})
|
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.weight
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
|
||||||
model_size = super().calc_size()
|
|
||||||
model_size += self.weight.nelement() * self.weight.element_size()
|
|
||||||
return model_size
|
|
||||||
|
|
||||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
|
||||||
super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||||
@@ -528,19 +458,16 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
|||||||
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||||
|
|
||||||
for layer_key, values in state_dict.items():
|
for layer_key, values in state_dict.items():
|
||||||
# Detect layers according to LyCORIS detection logic(`weight_list_det`)
|
|
||||||
# https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
|
|
||||||
|
|
||||||
# lora and locon
|
# lora and locon
|
||||||
if "lora_up.weight" in values:
|
if "lora_down.weight" in values:
|
||||||
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
||||||
|
|
||||||
# loha
|
# loha
|
||||||
elif "hada_w1_a" in values:
|
elif "hada_w1_b" in values:
|
||||||
layer = LoHALayer(layer_key, values)
|
layer = LoHALayer(layer_key, values)
|
||||||
|
|
||||||
# lokr
|
# lokr
|
||||||
elif "lokr_w1" in values or "lokr_w1_a" in values:
|
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
||||||
layer = LoKRLayer(layer_key, values)
|
layer = LoKRLayer(layer_key, values)
|
||||||
|
|
||||||
# diff
|
# diff
|
||||||
@@ -548,13 +475,9 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
|||||||
layer = FullLayer(layer_key, values)
|
layer = FullLayer(layer_key, values)
|
||||||
|
|
||||||
# ia3
|
# ia3
|
||||||
elif "on_input" in values:
|
elif "weight" in values and "on_input" in values:
|
||||||
layer = IA3Layer(layer_key, values)
|
layer = IA3Layer(layer_key, values)
|
||||||
|
|
||||||
# norms
|
|
||||||
elif "w_norm" in values:
|
|
||||||
layer = NormLayer(layer_key, values)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||||
raise Exception("Unknown lora format!")
|
raise Exception("Unknown lora format!")
|
||||||
|
|||||||
@@ -354,7 +354,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase):
|
|||||||
"""Model config for CLIPVision."""
|
"""Model config for CLIPVision."""
|
||||||
|
|
||||||
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_tag() -> Tag:
|
def get_tag() -> Tag:
|
||||||
@@ -365,7 +365,7 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
|
|||||||
"""Model config for T2I."""
|
"""Model config for T2I."""
|
||||||
|
|
||||||
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_tag() -> Tag:
|
def get_tag() -> Tag:
|
||||||
|
|||||||
@@ -98,9 +98,6 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
|||||||
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
||||||
ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline,
|
ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline,
|
||||||
},
|
},
|
||||||
BaseModelType.StableDiffusionXLRefiner: {
|
|
||||||
ModelVariantType.Normal: StableDiffusionXLPipeline,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
assert isinstance(config, MainCheckpointConfig)
|
assert isinstance(config, MainCheckpointConfig)
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -11,9 +11,6 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
|||||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
|
||||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
|
||||||
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
from invokeai.backend.model_manager.config import AnyModel
|
from invokeai.backend.model_manager.config import AnyModel
|
||||||
@@ -37,18 +34,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
|||||||
elif isinstance(model, CLIPTokenizer):
|
elif isinstance(model, CLIPTokenizer):
|
||||||
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
|
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
|
||||||
return 0
|
return 0
|
||||||
elif isinstance(
|
elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
|
||||||
model,
|
|
||||||
(
|
|
||||||
TextualInversionModelRaw,
|
|
||||||
IPAdapter,
|
|
||||||
LoRAModelRaw,
|
|
||||||
SpandrelImageToImageModel,
|
|
||||||
GroundingDinoPipeline,
|
|
||||||
SegmentAnythingPipeline,
|
|
||||||
DepthAnythingPipeline,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
return model.calc_size()
|
return model.calc_size()
|
||||||
else:
|
else:
|
||||||
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
|
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
|
||||||
|
|||||||
@@ -187,171 +187,164 @@ STARTER_MODELS: list[StarterModel] = [
|
|||||||
# endregion
|
# endregion
|
||||||
# region ControlNet
|
# region ControlNet
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="QRCode Monster v2 (SD1.5)",
|
name="QRCode Monster",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="monster-labs/control_v1p_sd15_qrcode_monster::v2",
|
source="monster-labs/control_v1p_sd15_qrcode_monster",
|
||||||
description="ControlNet model that generates scannable creative QR codes",
|
description="Controlnet model that generates scannable creative QR codes",
|
||||||
type=ModelType.ControlNet,
|
|
||||||
),
|
|
||||||
StarterModel(
|
|
||||||
name="QRCode Monster (SDXL)",
|
|
||||||
base=BaseModelType.StableDiffusionXL,
|
|
||||||
source="monster-labs/control_v1p_sdxl_qrcode_monster",
|
|
||||||
description="ControlNet model that generates scannable creative QR codes",
|
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="canny",
|
name="canny",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_canny",
|
source="lllyasviel/control_v11p_sd15_canny",
|
||||||
description="ControlNet weights trained on sd-1.5 with canny conditioning.",
|
description="Controlnet weights trained on sd-1.5 with canny conditioning.",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="inpaint",
|
name="inpaint",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_inpaint",
|
source="lllyasviel/control_v11p_sd15_inpaint",
|
||||||
description="ControlNet weights trained on sd-1.5 with canny conditioning, inpaint version",
|
description="Controlnet weights trained on sd-1.5 with canny conditioning, inpaint version",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="mlsd",
|
name="mlsd",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_mlsd",
|
source="lllyasviel/control_v11p_sd15_mlsd",
|
||||||
description="ControlNet weights trained on sd-1.5 with canny conditioning, MLSD version",
|
description="Controlnet weights trained on sd-1.5 with canny conditioning, MLSD version",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="depth",
|
name="depth",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11f1p_sd15_depth",
|
source="lllyasviel/control_v11f1p_sd15_depth",
|
||||||
description="ControlNet weights trained on sd-1.5 with depth conditioning",
|
description="Controlnet weights trained on sd-1.5 with depth conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="normal_bae",
|
name="normal_bae",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_normalbae",
|
source="lllyasviel/control_v11p_sd15_normalbae",
|
||||||
description="ControlNet weights trained on sd-1.5 with normalbae image conditioning",
|
description="Controlnet weights trained on sd-1.5 with normalbae image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="seg",
|
name="seg",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_seg",
|
source="lllyasviel/control_v11p_sd15_seg",
|
||||||
description="ControlNet weights trained on sd-1.5 with seg image conditioning",
|
description="Controlnet weights trained on sd-1.5 with seg image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="lineart",
|
name="lineart",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_lineart",
|
source="lllyasviel/control_v11p_sd15_lineart",
|
||||||
description="ControlNet weights trained on sd-1.5 with lineart image conditioning",
|
description="Controlnet weights trained on sd-1.5 with lineart image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="lineart_anime",
|
name="lineart_anime",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15s2_lineart_anime",
|
source="lllyasviel/control_v11p_sd15s2_lineart_anime",
|
||||||
description="ControlNet weights trained on sd-1.5 with anime image conditioning",
|
description="Controlnet weights trained on sd-1.5 with anime image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="openpose",
|
name="openpose",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_openpose",
|
source="lllyasviel/control_v11p_sd15_openpose",
|
||||||
description="ControlNet weights trained on sd-1.5 with openpose image conditioning",
|
description="Controlnet weights trained on sd-1.5 with openpose image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="scribble",
|
name="scribble",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_scribble",
|
source="lllyasviel/control_v11p_sd15_scribble",
|
||||||
description="ControlNet weights trained on sd-1.5 with scribble image conditioning",
|
description="Controlnet weights trained on sd-1.5 with scribble image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="softedge",
|
name="softedge",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11p_sd15_softedge",
|
source="lllyasviel/control_v11p_sd15_softedge",
|
||||||
description="ControlNet weights trained on sd-1.5 with soft edge conditioning",
|
description="Controlnet weights trained on sd-1.5 with soft edge conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="shuffle",
|
name="shuffle",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11e_sd15_shuffle",
|
source="lllyasviel/control_v11e_sd15_shuffle",
|
||||||
description="ControlNet weights trained on sd-1.5 with shuffle image conditioning",
|
description="Controlnet weights trained on sd-1.5 with shuffle image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="tile",
|
name="tile",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11f1e_sd15_tile",
|
source="lllyasviel/control_v11f1e_sd15_tile",
|
||||||
description="ControlNet weights trained on sd-1.5 with tiled image conditioning",
|
description="Controlnet weights trained on sd-1.5 with tiled image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="ip2p",
|
name="ip2p",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
source="lllyasviel/control_v11e_sd15_ip2p",
|
source="lllyasviel/control_v11e_sd15_ip2p",
|
||||||
description="ControlNet weights trained on sd-1.5 with ip2p conditioning.",
|
description="Controlnet weights trained on sd-1.5 with ip2p conditioning.",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="canny-sdxl",
|
name="canny-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="xinsir/controlNet-canny-sdxl-1.0",
|
source="xinsir/controlnet-canny-sdxl-1.0",
|
||||||
description="ControlNet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
|
description="Controlnet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="depth-sdxl",
|
name="depth-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="diffusers/controlNet-depth-sdxl-1.0",
|
source="diffusers/controlnet-depth-sdxl-1.0",
|
||||||
description="ControlNet weights trained on sdxl-1.0 with depth conditioning.",
|
description="Controlnet weights trained on sdxl-1.0 with depth conditioning.",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="softedge-dexined-sdxl",
|
name="softedge-dexined-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="SargeZT/controlNet-sd-xl-1.0-softedge-dexined",
|
source="SargeZT/controlnet-sd-xl-1.0-softedge-dexined",
|
||||||
description="ControlNet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
|
description="Controlnet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="depth-16bit-zoe-sdxl",
|
name="depth-16bit-zoe-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="SargeZT/controlNet-sd-xl-1.0-depth-16bit-zoe",
|
source="SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe",
|
||||||
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
|
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="depth-zoe-sdxl",
|
name="depth-zoe-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="diffusers/controlNet-zoe-depth-sdxl-1.0",
|
source="diffusers/controlnet-zoe-depth-sdxl-1.0",
|
||||||
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
|
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="openpose-sdxl",
|
name="openpose-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="xinsir/controlNet-openpose-sdxl-1.0",
|
source="xinsir/controlnet-openpose-sdxl-1.0",
|
||||||
description="ControlNet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
|
description="Controlnet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="scribble-sdxl",
|
name="scribble-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="xinsir/controlNet-scribble-sdxl-1.0",
|
source="xinsir/controlnet-scribble-sdxl-1.0",
|
||||||
description="ControlNet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
|
description="Controlnet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
StarterModel(
|
StarterModel(
|
||||||
name="tile-sdxl",
|
name="tile-sdxl",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
source="xinsir/controlNet-tile-sdxl-1.0",
|
source="xinsir/controlnet-tile-sdxl-1.0",
|
||||||
description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
|
description="Controlnet weights trained on sdxl-1.0 with tiled image conditioning",
|
||||||
type=ModelType.ControlNet,
|
type=ModelType.ControlNet,
|
||||||
),
|
),
|
||||||
# endregion
|
# endregion
|
||||||
|
|||||||
@@ -17,9 +17,8 @@ from invokeai.backend.lora import LoRAModelRaw
|
|||||||
from invokeai.backend.model_manager import AnyModel
|
from invokeai.backend.model_manager import AnyModel
|
||||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||||
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
|
|
||||||
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
|
||||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
"""
|
"""
|
||||||
loras = [
|
loras = [
|
||||||
@@ -86,13 +85,13 @@ class ModelPatcher:
|
|||||||
cls,
|
cls,
|
||||||
unet: UNet2DConditionModel,
|
unet: UNet2DConditionModel,
|
||||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> Generator[None, None, None]:
|
) -> Generator[None, None, None]:
|
||||||
with cls.apply_lora(
|
with cls.apply_lora(
|
||||||
unet,
|
unet,
|
||||||
loras=loras,
|
loras=loras,
|
||||||
prefix="lora_unet_",
|
prefix="lora_unet_",
|
||||||
cached_weights=cached_weights,
|
model_state_dict=model_state_dict,
|
||||||
):
|
):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@@ -102,9 +101,9 @@ class ModelPatcher:
|
|||||||
cls,
|
cls,
|
||||||
text_encoder: CLIPTextModel,
|
text_encoder: CLIPTextModel,
|
||||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> Generator[None, None, None]:
|
) -> Generator[None, None, None]:
|
||||||
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
|
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -114,7 +113,7 @@ class ModelPatcher:
|
|||||||
model: AnyModel,
|
model: AnyModel,
|
||||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||||
prefix: str,
|
prefix: str,
|
||||||
cached_weights: Optional[Dict[str, torch.Tensor]] = None,
|
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> Generator[None, None, None]:
|
) -> Generator[None, None, None]:
|
||||||
"""
|
"""
|
||||||
Apply one or more LoRAs to a model.
|
Apply one or more LoRAs to a model.
|
||||||
@@ -122,26 +121,66 @@ class ModelPatcher:
|
|||||||
:param model: The model to patch.
|
:param model: The model to patch.
|
||||||
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
|
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
|
||||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||||
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
:model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||||
"""
|
"""
|
||||||
original_weights = OriginalWeightsStorage(cached_weights)
|
original_weights = {}
|
||||||
try:
|
try:
|
||||||
for lora_model, lora_weight in loras:
|
with torch.no_grad():
|
||||||
LoRAExt.patch_model(
|
for lora, lora_weight in loras:
|
||||||
model=model,
|
# assert lora.device.type == "cpu"
|
||||||
prefix=prefix,
|
for layer_key, layer in lora.layers.items():
|
||||||
lora=lora_model,
|
if not layer_key.startswith(prefix):
|
||||||
lora_weight=lora_weight,
|
continue
|
||||||
original_weights=original_weights,
|
|
||||||
)
|
|
||||||
del lora_model
|
|
||||||
|
|
||||||
yield
|
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
|
||||||
|
# should be improved in the following ways:
|
||||||
|
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
|
||||||
|
# LoRA model is applied.
|
||||||
|
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
|
||||||
|
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
|
||||||
|
# weights to have valid keys.
|
||||||
|
assert isinstance(model, torch.nn.Module)
|
||||||
|
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
||||||
|
|
||||||
|
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
||||||
|
# (Performance will be best if this is a CUDA device.)
|
||||||
|
device = module.weight.device
|
||||||
|
dtype = module.weight.dtype
|
||||||
|
|
||||||
|
if module_key not in original_weights:
|
||||||
|
if model_state_dict is not None: # we were provided with the CPU copy of the state dict
|
||||||
|
original_weights[module_key] = model_state_dict[module_key + ".weight"]
|
||||||
|
else:
|
||||||
|
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||||
|
|
||||||
|
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||||
|
|
||||||
|
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||||
|
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||||
|
# same thing in a single call to '.to(...)'.
|
||||||
|
layer.to(device=device)
|
||||||
|
layer.to(dtype=torch.float32)
|
||||||
|
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||||
|
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||||
|
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
||||||
|
layer.to(device=TorchDevice.CPU_DEVICE)
|
||||||
|
|
||||||
|
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||||
|
if module.weight.shape != layer_weight.shape:
|
||||||
|
# TODO: debug on lycoris
|
||||||
|
assert hasattr(layer_weight, "reshape")
|
||||||
|
layer_weight = layer_weight.reshape(module.weight.shape)
|
||||||
|
|
||||||
|
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||||
|
module.weight += layer_weight.to(dtype=dtype)
|
||||||
|
|
||||||
|
yield # wait for context manager exit
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for param_key, weight in original_weights.get_changed_weights():
|
for module_key, weight in original_weights.items():
|
||||||
model.get_parameter(param_key).copy_(weight)
|
model.get_submodule(module_key).weight.copy_(weight)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
|||||||
@@ -7,9 +7,11 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import ( # noqa: F401
|
|||||||
StableDiffusionGeneratorPipeline,
|
StableDiffusionGeneratorPipeline,
|
||||||
)
|
)
|
||||||
from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent # noqa: F401
|
from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||||
|
from invokeai.backend.stable_diffusion.seamless import set_seamless # noqa: F401
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"PipelineIntermediateState",
|
"PipelineIntermediateState",
|
||||||
"StableDiffusionGeneratorPipeline",
|
"StableDiffusionGeneratorPipeline",
|
||||||
"InvokeAIDiffuserComponent",
|
"InvokeAIDiffuserComponent",
|
||||||
|
"set_seamless",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -83,47 +83,47 @@ class DenoiseContext:
|
|||||||
unet: Optional[UNet2DConditionModel] = None
|
unet: Optional[UNet2DConditionModel] = None
|
||||||
|
|
||||||
# Current state of latent-space image in denoising process.
|
# Current state of latent-space image in denoising process.
|
||||||
# None until `PRE_DENOISE_LOOP` callback.
|
# None until `pre_denoise_loop` callback.
|
||||||
# Shape: [batch, channels, latent_height, latent_width]
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
latents: Optional[torch.Tensor] = None
|
latents: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# Current denoising step index.
|
# Current denoising step index.
|
||||||
# None until `PRE_STEP` callback.
|
# None until `pre_step` callback.
|
||||||
step_index: Optional[int] = None
|
step_index: Optional[int] = None
|
||||||
|
|
||||||
# Current denoising step timestep.
|
# Current denoising step timestep.
|
||||||
# None until `PRE_STEP` callback.
|
# None until `pre_step` callback.
|
||||||
timestep: Optional[torch.Tensor] = None
|
timestep: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# Arguments which will be passed to UNet model.
|
# Arguments which will be passed to UNet model.
|
||||||
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
|
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
|
||||||
unet_kwargs: Optional[UNetKwargs] = None
|
unet_kwargs: Optional[UNetKwargs] = None
|
||||||
|
|
||||||
# SchedulerOutput class returned from step function(normally, generated by scheduler).
|
# SchedulerOutput class returned from step function(normally, generated by scheduler).
|
||||||
# Supposed to be used only in `POST_STEP` callback, otherwise can be None.
|
# Supposed to be used only in `post_step` callback, otherwise can be None.
|
||||||
step_output: Optional[SchedulerOutput] = None
|
step_output: Optional[SchedulerOutput] = None
|
||||||
|
|
||||||
# Scaled version of `latents`, which will be passed to unet_kwargs initialization.
|
# Scaled version of `latents`, which will be passed to unet_kwargs initialization.
|
||||||
# Available in events inside step(between `PRE_STEP` and `POST_STEP`).
|
# Available in events inside step(between `pre_step` and `post_stop`).
|
||||||
# Shape: [batch, channels, latent_height, latent_width]
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
latent_model_input: Optional[torch.Tensor] = None
|
latent_model_input: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# [TMP] Defines on which conditionings current unet call will be runned.
|
# [TMP] Defines on which conditionings current unet call will be runned.
|
||||||
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
|
# Available in `pre_unet`/`post_unet` callbacks, otherwise will be None.
|
||||||
conditioning_mode: Optional[ConditioningMode] = None
|
conditioning_mode: Optional[ConditioningMode] = None
|
||||||
|
|
||||||
# [TMP] Noise predictions from negative conditioning.
|
# [TMP] Noise predictions from negative conditioning.
|
||||||
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
|
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
||||||
# Shape: [batch, channels, latent_height, latent_width]
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
negative_noise_pred: Optional[torch.Tensor] = None
|
negative_noise_pred: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# [TMP] Noise predictions from positive conditioning.
|
# [TMP] Noise predictions from positive conditioning.
|
||||||
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
|
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
||||||
# Shape: [batch, channels, latent_height, latent_width]
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
positive_noise_pred: Optional[torch.Tensor] = None
|
positive_noise_pred: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# Combined noise prediction from passed conditionings.
|
# Combined noise prediction from passed conditionings.
|
||||||
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
|
# Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None.
|
||||||
# Shape: [batch, channels, latent_height, latent_width]
|
# Shape: [batch, channels, latent_height, latent_width]
|
||||||
noise_pred: Optional[torch.Tensor] = None
|
noise_pred: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|||||||
@@ -76,12 +76,12 @@ class StableDiffusionBackend:
|
|||||||
both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
|
both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
|
||||||
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
|
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
|
||||||
|
|
||||||
# ext: override combine_noise_preds
|
# ext: override apply_cfg
|
||||||
ctx.noise_pred = self.combine_noise_preds(ctx)
|
ctx.noise_pred = self.apply_cfg(ctx)
|
||||||
|
|
||||||
# ext: cfg_rescale [modify_noise_prediction]
|
# ext: cfg_rescale [modify_noise_prediction]
|
||||||
# TODO: rename
|
# TODO: rename
|
||||||
ext_manager.run_callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS, ctx)
|
ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
|
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
|
||||||
@@ -95,15 +95,13 @@ class StableDiffusionBackend:
|
|||||||
return step_output
|
return step_output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
|
def apply_cfg(ctx: DenoiseContext) -> torch.Tensor:
|
||||||
guidance_scale = ctx.inputs.conditioning_data.guidance_scale
|
guidance_scale = ctx.inputs.conditioning_data.guidance_scale
|
||||||
if isinstance(guidance_scale, list):
|
if isinstance(guidance_scale, list):
|
||||||
guidance_scale = guidance_scale[ctx.step_index]
|
guidance_scale = guidance_scale[ctx.step_index]
|
||||||
|
|
||||||
# Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
|
return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
|
||||||
# in slightly different outputs. It is suspected that this is caused by small precision differences.
|
# return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
|
||||||
# return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
|
|
||||||
return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
|
|
||||||
|
|
||||||
def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
|
def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
|
||||||
sample = ctx.latent_model_input
|
sample = ctx.latent_model_input
|
||||||
|
|||||||
@@ -9,4 +9,4 @@ class ExtensionCallbackType(Enum):
|
|||||||
POST_STEP = "post_step"
|
POST_STEP = "post_step"
|
||||||
PRE_UNET = "pre_unet"
|
PRE_UNET = "pre_unet"
|
||||||
POST_UNET = "post_unet"
|
POST_UNET = "post_unet"
|
||||||
POST_COMBINE_NOISE_PREDS = "post_combine_noise_preds"
|
POST_APPLY_CFG = "post_apply_cfg"
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ from contextlib import contextmanager
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, List
|
from typing import TYPE_CHECKING, Callable, Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
from diffusers import UNet2DConditionModel
|
from diffusers import UNet2DConditionModel
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -52,21 +52,9 @@ class ExtensionBase:
|
|||||||
return self._callbacks
|
return self._callbacks
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_extension(self, ctx: DenoiseContext):
|
def patch_extension(self, context: DenoiseContext):
|
||||||
yield None
|
yield None
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
|
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
|
||||||
"""A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
|
yield None
|
||||||
diffusion process. Weight unpatching is handled upstream, and is achieved by saving unchanged weights by
|
|
||||||
`original_weights.save` function. Note that this enables some performance optimization by avoiding redundant
|
|
||||||
operations. All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched
|
|
||||||
by this context manager.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
unet (UNet2DConditionModel): The UNet model on execution device to patch.
|
|
||||||
original_weights (OriginalWeightsStorage): A storage with copy of the model's original weights in CPU, for
|
|
||||||
unpatching purposes. Extension should save tensor which being modified in this storage, also extensions
|
|
||||||
can access original weights values.
|
|
||||||
"""
|
|
||||||
yield
|
|
||||||
|
|||||||
@@ -1,158 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import math
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import TYPE_CHECKING, List, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from PIL.Image import Image
|
|
||||||
|
|
||||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
|
||||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
|
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
|
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
|
||||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
|
||||||
from invokeai.backend.util.hotfixes import ControlNetModel
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetExt(ExtensionBase):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: ControlNetModel,
|
|
||||||
image: Image,
|
|
||||||
weight: Union[float, List[float]],
|
|
||||||
begin_step_percent: float,
|
|
||||||
end_step_percent: float,
|
|
||||||
control_mode: CONTROLNET_MODE_VALUES,
|
|
||||||
resize_mode: CONTROLNET_RESIZE_VALUES,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self._model = model
|
|
||||||
self._image = image
|
|
||||||
self._weight = weight
|
|
||||||
self._begin_step_percent = begin_step_percent
|
|
||||||
self._end_step_percent = end_step_percent
|
|
||||||
self._control_mode = control_mode
|
|
||||||
self._resize_mode = resize_mode
|
|
||||||
|
|
||||||
self._image_tensor: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def patch_extension(self, ctx: DenoiseContext):
|
|
||||||
original_processors = self._model.attn_processors
|
|
||||||
try:
|
|
||||||
self._model.set_attn_processor(ctx.inputs.attention_processor_cls())
|
|
||||||
|
|
||||||
yield None
|
|
||||||
finally:
|
|
||||||
self._model.set_attn_processor(original_processors)
|
|
||||||
|
|
||||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
|
||||||
def resize_image(self, ctx: DenoiseContext):
|
|
||||||
_, _, latent_height, latent_width = ctx.latents.shape
|
|
||||||
image_height = latent_height * LATENT_SCALE_FACTOR
|
|
||||||
image_width = latent_width * LATENT_SCALE_FACTOR
|
|
||||||
|
|
||||||
self._image_tensor = prepare_control_image(
|
|
||||||
image=self._image,
|
|
||||||
do_classifier_free_guidance=False,
|
|
||||||
width=image_width,
|
|
||||||
height=image_height,
|
|
||||||
device=ctx.latents.device,
|
|
||||||
dtype=ctx.latents.dtype,
|
|
||||||
control_mode=self._control_mode,
|
|
||||||
resize_mode=self._resize_mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
@callback(ExtensionCallbackType.PRE_UNET)
|
|
||||||
def pre_unet_step(self, ctx: DenoiseContext):
|
|
||||||
# skip if model not active in current step
|
|
||||||
total_steps = len(ctx.inputs.timesteps)
|
|
||||||
first_step = math.floor(self._begin_step_percent * total_steps)
|
|
||||||
last_step = math.ceil(self._end_step_percent * total_steps)
|
|
||||||
if ctx.step_index < first_step or ctx.step_index > last_step:
|
|
||||||
return
|
|
||||||
|
|
||||||
# convert mode to internal flags
|
|
||||||
soft_injection = self._control_mode in ["more_prompt", "more_control"]
|
|
||||||
cfg_injection = self._control_mode in ["more_control", "unbalanced"]
|
|
||||||
|
|
||||||
# no negative conditioning in cfg_injection mode
|
|
||||||
if cfg_injection:
|
|
||||||
if ctx.conditioning_mode == ConditioningMode.Negative:
|
|
||||||
return
|
|
||||||
down_samples, mid_sample = self._run(ctx, soft_injection, ConditioningMode.Positive)
|
|
||||||
|
|
||||||
if ctx.conditioning_mode == ConditioningMode.Both:
|
|
||||||
# add zeros as samples for negative conditioning
|
|
||||||
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
|
|
||||||
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
|
|
||||||
|
|
||||||
else:
|
|
||||||
down_samples, mid_sample = self._run(ctx, soft_injection, ctx.conditioning_mode)
|
|
||||||
|
|
||||||
if (
|
|
||||||
ctx.unet_kwargs.down_block_additional_residuals is None
|
|
||||||
and ctx.unet_kwargs.mid_block_additional_residual is None
|
|
||||||
):
|
|
||||||
ctx.unet_kwargs.down_block_additional_residuals = down_samples
|
|
||||||
ctx.unet_kwargs.mid_block_additional_residual = mid_sample
|
|
||||||
else:
|
|
||||||
# add controlnet outputs together if have multiple controlnets
|
|
||||||
ctx.unet_kwargs.down_block_additional_residuals = [
|
|
||||||
samples_prev + samples_curr
|
|
||||||
for samples_prev, samples_curr in zip(
|
|
||||||
ctx.unet_kwargs.down_block_additional_residuals, down_samples, strict=True
|
|
||||||
)
|
|
||||||
]
|
|
||||||
ctx.unet_kwargs.mid_block_additional_residual += mid_sample
|
|
||||||
|
|
||||||
def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode):
|
|
||||||
total_steps = len(ctx.inputs.timesteps)
|
|
||||||
|
|
||||||
model_input = ctx.latent_model_input
|
|
||||||
image_tensor = self._image_tensor
|
|
||||||
if conditioning_mode == ConditioningMode.Both:
|
|
||||||
model_input = torch.cat([model_input] * 2)
|
|
||||||
image_tensor = torch.cat([image_tensor] * 2)
|
|
||||||
|
|
||||||
cn_unet_kwargs = UNetKwargs(
|
|
||||||
sample=model_input,
|
|
||||||
timestep=ctx.timestep,
|
|
||||||
encoder_hidden_states=None, # set later by conditioning
|
|
||||||
cross_attention_kwargs=dict( # noqa: C408
|
|
||||||
percent_through=ctx.step_index / total_steps,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)
|
|
||||||
|
|
||||||
# get static weight, or weight corresponding to current step
|
|
||||||
weight = self._weight
|
|
||||||
if isinstance(weight, list):
|
|
||||||
weight = weight[ctx.step_index]
|
|
||||||
|
|
||||||
tmp_kwargs = vars(cn_unet_kwargs)
|
|
||||||
|
|
||||||
# Remove kwargs not related to ControlNet unet
|
|
||||||
# ControlNet guidance fields
|
|
||||||
del tmp_kwargs["down_block_additional_residuals"]
|
|
||||||
del tmp_kwargs["mid_block_additional_residual"]
|
|
||||||
|
|
||||||
# T2i Adapter guidance fields
|
|
||||||
del tmp_kwargs["down_intrablock_additional_residuals"]
|
|
||||||
|
|
||||||
# controlnet(s) inference
|
|
||||||
down_samples, mid_sample = self._model(
|
|
||||||
controlnet_cond=image_tensor,
|
|
||||||
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
|
|
||||||
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
|
|
||||||
return_dict=False,
|
|
||||||
**vars(cn_unet_kwargs),
|
|
||||||
)
|
|
||||||
|
|
||||||
return down_samples, mid_sample
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from diffusers import UNet2DConditionModel
|
|
||||||
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from invokeai.app.shared.models import FreeUConfig
|
|
||||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
|
||||||
|
|
||||||
|
|
||||||
class FreeUExt(ExtensionBase):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
freeu_config: FreeUConfig,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self._freeu_config = freeu_config
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
|
|
||||||
unet.enable_freeu(
|
|
||||||
b1=self._freeu_config.b1,
|
|
||||||
b2=self._freeu_config.b2,
|
|
||||||
s1=self._freeu_config.s1,
|
|
||||||
s2=self._freeu_config.s2,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
unet.disable_freeu()
|
|
||||||
@@ -1,120 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional
|
|
||||||
|
|
||||||
import einops
|
|
||||||
import torch
|
|
||||||
from diffusers import UNet2DConditionModel
|
|
||||||
|
|
||||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
|
||||||
|
|
||||||
|
|
||||||
class InpaintExt(ExtensionBase):
|
|
||||||
"""An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
|
|
||||||
models.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
mask: torch.Tensor,
|
|
||||||
is_gradient_mask: bool,
|
|
||||||
):
|
|
||||||
"""Initialize InpaintExt.
|
|
||||||
Args:
|
|
||||||
mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
|
||||||
expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
|
|
||||||
inpainted.
|
|
||||||
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
|
|
||||||
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
|
|
||||||
1.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self._mask = mask
|
|
||||||
self._is_gradient_mask = is_gradient_mask
|
|
||||||
|
|
||||||
# Noise, which used to noisify unmasked part of image
|
|
||||||
# if noise provided to context, then it will be used
|
|
||||||
# if no noise provided, then noise will be generated based on seed
|
|
||||||
self._noise: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _is_normal_model(unet: UNet2DConditionModel):
|
|
||||||
"""Checks if the provided UNet belongs to a regular model.
|
|
||||||
The `in_channels` of a UNet vary depending on model type:
|
|
||||||
- normal - 4
|
|
||||||
- depth - 5
|
|
||||||
- inpaint - 9
|
|
||||||
"""
|
|
||||||
return unet.conv_in.in_channels == 4
|
|
||||||
|
|
||||||
def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
|
||||||
batch_size = latents.size(0)
|
|
||||||
mask = einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
|
||||||
if t.dim() == 0:
|
|
||||||
# some schedulers expect t to be one-dimensional.
|
|
||||||
# TODO: file diffusers bug about inconsistency?
|
|
||||||
t = einops.repeat(t, "-> batch", batch=batch_size)
|
|
||||||
# Noise shouldn't be re-randomized between steps here. The multistep schedulers
|
|
||||||
# get very confused about what is happening from step to step when we do that.
|
|
||||||
mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t)
|
|
||||||
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
|
|
||||||
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
|
||||||
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
|
||||||
if self._is_gradient_mask:
|
|
||||||
threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
|
|
||||||
mask_bool = mask < 1 - threshold
|
|
||||||
masked_input = torch.where(mask_bool, latents, mask_latents)
|
|
||||||
else:
|
|
||||||
masked_input = torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype))
|
|
||||||
return masked_input
|
|
||||||
|
|
||||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
|
||||||
def init_tensors(self, ctx: DenoiseContext):
|
|
||||||
if not self._is_normal_model(ctx.unet):
|
|
||||||
raise ValueError(
|
|
||||||
"InpaintExt should be used only on normal (non-inpainting) models. This could be caused by an "
|
|
||||||
"inpainting model that was incorrectly marked as a non-inpainting model. In some cases, this can be "
|
|
||||||
"fixed by removing and re-adding the model (so that it gets re-probed)."
|
|
||||||
)
|
|
||||||
|
|
||||||
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
|
||||||
|
|
||||||
self._noise = ctx.inputs.noise
|
|
||||||
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
|
||||||
# We still need noise for inpainting, so we generate it from the seed here.
|
|
||||||
if self._noise is None:
|
|
||||||
self._noise = torch.randn(
|
|
||||||
ctx.latents.shape,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device="cpu",
|
|
||||||
generator=torch.Generator(device="cpu").manual_seed(ctx.seed),
|
|
||||||
).to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
|
||||||
|
|
||||||
# Use negative order to make extensions with default order work with patched latents
|
|
||||||
@callback(ExtensionCallbackType.PRE_STEP, order=-100)
|
|
||||||
def apply_mask_to_initial_latents(self, ctx: DenoiseContext):
|
|
||||||
ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)
|
|
||||||
|
|
||||||
# TODO: redo this with preview events rewrite
|
|
||||||
# Use negative order to make extensions with default order work with patched latents
|
|
||||||
@callback(ExtensionCallbackType.POST_STEP, order=-100)
|
|
||||||
def apply_mask_to_step_output(self, ctx: DenoiseContext):
|
|
||||||
timestep = ctx.scheduler.timesteps[-1]
|
|
||||||
if hasattr(ctx.step_output, "denoised"):
|
|
||||||
ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep)
|
|
||||||
elif hasattr(ctx.step_output, "pred_original_sample"):
|
|
||||||
ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep)
|
|
||||||
else:
|
|
||||||
ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)
|
|
||||||
|
|
||||||
# Restore unmasked part after the last step is completed
|
|
||||||
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
|
|
||||||
def restore_unmasked(self, ctx: DenoiseContext):
|
|
||||||
if self._is_gradient_mask:
|
|
||||||
ctx.latents = torch.where(self._mask < 1, ctx.latents, ctx.inputs.orig_latents)
|
|
||||||
else:
|
|
||||||
ctx.latents = torch.lerp(ctx.latents, ctx.inputs.orig_latents, self._mask)
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from diffusers import UNet2DConditionModel
|
|
||||||
|
|
||||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
|
||||||
|
|
||||||
|
|
||||||
class InpaintModelExt(ExtensionBase):
|
|
||||||
"""An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting
|
|
||||||
models.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
mask: Optional[torch.Tensor],
|
|
||||||
masked_latents: Optional[torch.Tensor],
|
|
||||||
is_gradient_mask: bool,
|
|
||||||
):
|
|
||||||
"""Initialize InpaintModelExt.
|
|
||||||
Args:
|
|
||||||
mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
|
|
||||||
expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
|
|
||||||
inpainted.
|
|
||||||
masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area.
|
|
||||||
If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width)
|
|
||||||
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
|
|
||||||
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
|
|
||||||
1.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
if mask is not None and masked_latents is None:
|
|
||||||
raise ValueError("Source image required for inpaint mask when inpaint model used!")
|
|
||||||
|
|
||||||
# Inverse mask, because inpaint models treat mask as: 0 - remain same, 1 - inpaint
|
|
||||||
self._mask = None
|
|
||||||
if mask is not None:
|
|
||||||
self._mask = 1 - mask
|
|
||||||
self._masked_latents = masked_latents
|
|
||||||
self._is_gradient_mask = is_gradient_mask
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _is_inpaint_model(unet: UNet2DConditionModel):
|
|
||||||
"""Checks if the provided UNet belongs to a regular model.
|
|
||||||
The `in_channels` of a UNet vary depending on model type:
|
|
||||||
- normal - 4
|
|
||||||
- depth - 5
|
|
||||||
- inpaint - 9
|
|
||||||
"""
|
|
||||||
return unet.conv_in.in_channels == 9
|
|
||||||
|
|
||||||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
|
|
||||||
def init_tensors(self, ctx: DenoiseContext):
|
|
||||||
if not self._is_inpaint_model(ctx.unet):
|
|
||||||
raise ValueError("InpaintModelExt should be used only on inpaint models!")
|
|
||||||
|
|
||||||
if self._mask is None:
|
|
||||||
self._mask = torch.ones_like(ctx.latents[:1, :1])
|
|
||||||
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
|
||||||
|
|
||||||
if self._masked_latents is None:
|
|
||||||
self._masked_latents = torch.zeros_like(ctx.latents[:1])
|
|
||||||
self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
|
|
||||||
|
|
||||||
# Do last so that other extensions works with normal latents
|
|
||||||
@callback(ExtensionCallbackType.PRE_UNET, order=1000)
|
|
||||||
def append_inpaint_layers(self, ctx: DenoiseContext):
|
|
||||||
batch_size = ctx.unet_kwargs.sample.shape[0]
|
|
||||||
b_mask = torch.cat([self._mask] * batch_size)
|
|
||||||
b_masked_latents = torch.cat([self._masked_latents] * batch_size)
|
|
||||||
ctx.unet_kwargs.sample = torch.cat(
|
|
||||||
[ctx.unet_kwargs.sample, b_mask, b_masked_latents],
|
|
||||||
dim=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Restore unmasked part as inpaint model can change unmasked part slightly
|
|
||||||
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
|
|
||||||
def restore_unmasked(self, ctx: DenoiseContext):
|
|
||||||
if self._is_gradient_mask:
|
|
||||||
ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
|
|
||||||
else:
|
|
||||||
ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)
|
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import TYPE_CHECKING, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from diffusers import UNet2DConditionModel
|
|
||||||
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from invokeai.app.invocations.model import ModelIdentifierField
|
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
|
||||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAExt(ExtensionBase):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
node_context: InvocationContext,
|
|
||||||
model_id: ModelIdentifierField,
|
|
||||||
weight: float,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self._node_context = node_context
|
|
||||||
self._model_id = model_id
|
|
||||||
self._weight = weight
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
|
|
||||||
lora_model = self._node_context.models.load(self._model_id).model
|
|
||||||
self.patch_model(
|
|
||||||
model=unet,
|
|
||||||
prefix="lora_unet_",
|
|
||||||
lora=lora_model,
|
|
||||||
lora_weight=self._weight,
|
|
||||||
original_weights=original_weights,
|
|
||||||
)
|
|
||||||
del lora_model
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@torch.no_grad()
|
|
||||||
def patch_model(
|
|
||||||
cls,
|
|
||||||
model: torch.nn.Module,
|
|
||||||
prefix: str,
|
|
||||||
lora: LoRAModelRaw,
|
|
||||||
lora_weight: float,
|
|
||||||
original_weights: OriginalWeightsStorage,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Apply one or more LoRAs to a model.
|
|
||||||
:param model: The model to patch.
|
|
||||||
:param lora: LoRA model to patch in.
|
|
||||||
:param lora_weight: LoRA patch weight.
|
|
||||||
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
|
||||||
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if lora_weight == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
# assert lora.device.type == "cpu"
|
|
||||||
for layer_key, layer in lora.layers.items():
|
|
||||||
if not layer_key.startswith(prefix):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
|
|
||||||
# should be improved in the following ways:
|
|
||||||
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
|
|
||||||
# LoRA model is applied.
|
|
||||||
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
|
|
||||||
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
|
|
||||||
# weights to have valid keys.
|
|
||||||
assert isinstance(model, torch.nn.Module)
|
|
||||||
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
|
|
||||||
|
|
||||||
# All of the LoRA weight calculations will be done on the same device as the module weight.
|
|
||||||
# (Performance will be best if this is a CUDA device.)
|
|
||||||
device = module.weight.device
|
|
||||||
dtype = module.weight.dtype
|
|
||||||
|
|
||||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
|
||||||
|
|
||||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
|
||||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
|
||||||
# same thing in a single call to '.to(...)'.
|
|
||||||
layer.to(device=device)
|
|
||||||
layer.to(dtype=torch.float32)
|
|
||||||
|
|
||||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
|
||||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
|
||||||
for param_name, lora_param_weight in layer.get_parameters(module).items():
|
|
||||||
param_key = module_key + "." + param_name
|
|
||||||
module_param = module.get_parameter(param_name)
|
|
||||||
|
|
||||||
# save original weight
|
|
||||||
original_weights.save(param_key, module_param)
|
|
||||||
|
|
||||||
if module_param.shape != lora_param_weight.shape:
|
|
||||||
# TODO: debug on lycoris
|
|
||||||
lora_param_weight = lora_param_weight.reshape(module_param.shape)
|
|
||||||
|
|
||||||
lora_param_weight *= lora_weight * layer_scale
|
|
||||||
module_param += lora_param_weight.to(dtype=dtype)
|
|
||||||
|
|
||||||
layer.to(device=TorchDevice.CPU_DEVICE)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
|
||||||
assert "." not in lora_key
|
|
||||||
|
|
||||||
if not lora_key.startswith(prefix):
|
|
||||||
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
|
|
||||||
|
|
||||||
module = model
|
|
||||||
module_key = ""
|
|
||||||
key_parts = lora_key[len(prefix) :].split("_")
|
|
||||||
|
|
||||||
submodule_name = key_parts.pop(0)
|
|
||||||
|
|
||||||
while len(key_parts) > 0:
|
|
||||||
try:
|
|
||||||
module = module.get_submodule(submodule_name)
|
|
||||||
module_key += "." + submodule_name
|
|
||||||
submodule_name = key_parts.pop(0)
|
|
||||||
except Exception:
|
|
||||||
submodule_name += "_" + key_parts.pop(0)
|
|
||||||
|
|
||||||
module = module.get_submodule(submodule_name)
|
|
||||||
module_key = (module_key + "." + submodule_name).lstrip(".")
|
|
||||||
|
|
||||||
return (module_key, module)
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
|
||||||
|
|
||||||
|
|
||||||
class RescaleCFGExt(ExtensionBase):
|
|
||||||
def __init__(self, rescale_multiplier: float):
|
|
||||||
super().__init__()
|
|
||||||
self._rescale_multiplier = rescale_multiplier
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7):
|
|
||||||
"""Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
|
|
||||||
ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
|
|
||||||
ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
|
|
||||||
|
|
||||||
x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
|
|
||||||
x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
|
|
||||||
return x_final
|
|
||||||
|
|
||||||
@callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
|
|
||||||
def rescale_noise_pred(self, ctx: DenoiseContext):
|
|
||||||
if self._rescale_multiplier > 0:
|
|
||||||
ctx.noise_pred = self._rescale_cfg(
|
|
||||||
ctx.noise_pred,
|
|
||||||
ctx.positive_noise_pred,
|
|
||||||
self._rescale_multiplier,
|
|
||||||
)
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Callable, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from diffusers import UNet2DConditionModel
|
|
||||||
from diffusers.models.lora import LoRACompatibleConv
|
|
||||||
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
|
|
||||||
|
|
||||||
|
|
||||||
class SeamlessExt(ExtensionBase):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
seamless_axes: List[str],
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self._seamless_axes = seamless_axes
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
|
|
||||||
with self.static_patch_model(
|
|
||||||
model=unet,
|
|
||||||
seamless_axes=self._seamless_axes,
|
|
||||||
):
|
|
||||||
yield
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@contextmanager
|
|
||||||
def static_patch_model(
|
|
||||||
model: torch.nn.Module,
|
|
||||||
seamless_axes: List[str],
|
|
||||||
):
|
|
||||||
if not seamless_axes:
|
|
||||||
yield
|
|
||||||
return
|
|
||||||
|
|
||||||
x_mode = "circular" if "x" in seamless_axes else "constant"
|
|
||||||
y_mode = "circular" if "y" in seamless_axes else "constant"
|
|
||||||
|
|
||||||
# override conv_forward
|
|
||||||
# https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019
|
|
||||||
def _conv_forward_asymmetric(
|
|
||||||
self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
|
|
||||||
):
|
|
||||||
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
|
|
||||||
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
|
|
||||||
working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
|
|
||||||
working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
|
|
||||||
return torch.nn.functional.conv2d(
|
|
||||||
working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups
|
|
||||||
)
|
|
||||||
|
|
||||||
original_layers: List[Tuple[nn.Conv2d, Callable]] = []
|
|
||||||
try:
|
|
||||||
for layer in model.modules():
|
|
||||||
if not isinstance(layer, torch.nn.Conv2d):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
|
|
||||||
layer.lora_layer = lambda *x: 0
|
|
||||||
original_layers.append((layer, layer._conv_forward))
|
|
||||||
layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d)
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
finally:
|
|
||||||
for layer, orig_conv_forward in original_layers:
|
|
||||||
layer._conv_forward = orig_conv_forward
|
|
||||||
@@ -1,120 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import TYPE_CHECKING, List, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from diffusers import T2IAdapter
|
|
||||||
from PIL.Image import Image
|
|
||||||
|
|
||||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
|
||||||
from invokeai.backend.model_manager import BaseModelType
|
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
|
||||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
|
||||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from invokeai.app.invocations.model import ModelIdentifierField
|
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
||||||
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
|
|
||||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterExt(ExtensionBase):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
node_context: InvocationContext,
|
|
||||||
model_id: ModelIdentifierField,
|
|
||||||
image: Image,
|
|
||||||
weight: Union[float, List[float]],
|
|
||||||
begin_step_percent: float,
|
|
||||||
end_step_percent: float,
|
|
||||||
resize_mode: CONTROLNET_RESIZE_VALUES,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self._node_context = node_context
|
|
||||||
self._model_id = model_id
|
|
||||||
self._image = image
|
|
||||||
self._weight = weight
|
|
||||||
self._resize_mode = resize_mode
|
|
||||||
self._begin_step_percent = begin_step_percent
|
|
||||||
self._end_step_percent = end_step_percent
|
|
||||||
|
|
||||||
self._adapter_state: Optional[List[torch.Tensor]] = None
|
|
||||||
|
|
||||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
|
||||||
model_config = self._node_context.models.get_config(self._model_id.key)
|
|
||||||
if model_config.base == BaseModelType.StableDiffusion1:
|
|
||||||
self._max_unet_downscale = 8
|
|
||||||
elif model_config.base == BaseModelType.StableDiffusionXL:
|
|
||||||
self._max_unet_downscale = 4
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unexpected T2I-Adapter base model type: '{model_config.base}'.")
|
|
||||||
|
|
||||||
@callback(ExtensionCallbackType.SETUP)
|
|
||||||
def setup(self, ctx: DenoiseContext):
|
|
||||||
t2i_model: T2IAdapter
|
|
||||||
with self._node_context.models.load(self._model_id) as t2i_model:
|
|
||||||
_, _, latents_height, latents_width = ctx.inputs.orig_latents.shape
|
|
||||||
|
|
||||||
self._adapter_state = self._run_model(
|
|
||||||
model=t2i_model,
|
|
||||||
image=self._image,
|
|
||||||
latents_height=latents_height,
|
|
||||||
latents_width=latents_width,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_model(
|
|
||||||
self,
|
|
||||||
model: T2IAdapter,
|
|
||||||
image: Image,
|
|
||||||
latents_height: int,
|
|
||||||
latents_width: int,
|
|
||||||
):
|
|
||||||
# Resize the T2I-Adapter input image.
|
|
||||||
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
|
|
||||||
# result will match the latent image's dimensions after max_unet_downscale is applied.
|
|
||||||
input_height = latents_height // self._max_unet_downscale * model.total_downscale_factor
|
|
||||||
input_width = latents_width // self._max_unet_downscale * model.total_downscale_factor
|
|
||||||
|
|
||||||
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
|
|
||||||
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
|
|
||||||
# T2I-Adapter model.
|
|
||||||
#
|
|
||||||
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
|
|
||||||
# of the same requirements (e.g. preserving binary masks during resize).
|
|
||||||
t2i_image = prepare_control_image(
|
|
||||||
image=image,
|
|
||||||
do_classifier_free_guidance=False,
|
|
||||||
width=input_width,
|
|
||||||
height=input_height,
|
|
||||||
num_channels=model.config["in_channels"],
|
|
||||||
device=model.device,
|
|
||||||
dtype=model.dtype,
|
|
||||||
resize_mode=self._resize_mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
return model(t2i_image)
|
|
||||||
|
|
||||||
@callback(ExtensionCallbackType.PRE_UNET)
|
|
||||||
def pre_unet_step(self, ctx: DenoiseContext):
|
|
||||||
# skip if model not active in current step
|
|
||||||
total_steps = len(ctx.inputs.timesteps)
|
|
||||||
first_step = math.floor(self._begin_step_percent * total_steps)
|
|
||||||
last_step = math.ceil(self._end_step_percent * total_steps)
|
|
||||||
if ctx.step_index < first_step or ctx.step_index > last_step:
|
|
||||||
return
|
|
||||||
|
|
||||||
weight = self._weight
|
|
||||||
if isinstance(weight, list):
|
|
||||||
weight = weight[ctx.step_index]
|
|
||||||
|
|
||||||
adapter_state = self._adapter_state
|
|
||||||
if ctx.conditioning_mode == ConditioningMode.Both:
|
|
||||||
adapter_state = [torch.cat([v] * 2) for v in adapter_state]
|
|
||||||
|
|
||||||
if ctx.unet_kwargs.down_intrablock_additional_residuals is None:
|
|
||||||
ctx.unet_kwargs.down_intrablock_additional_residuals = [v * weight for v in adapter_state]
|
|
||||||
else:
|
|
||||||
for i, value in enumerate(adapter_state):
|
|
||||||
ctx.unet_kwargs.down_intrablock_additional_residuals[i] += value * weight
|
|
||||||