mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-20 09:08:02 -05:00
Compare commits
142 Commits
v5.9.0rc1
...
psyche/fea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8a8f4c593f | ||
|
|
29c78f0e5e | ||
|
|
501534e2e1 | ||
|
|
50c7318004 | ||
|
|
7f14597012 | ||
|
|
dbe68b364f | ||
|
|
0c7aa85a5c | ||
|
|
703e1c8001 | ||
|
|
b056c93ea3 | ||
|
|
4289241943 | ||
|
|
51f5abf5f9 | ||
|
|
e59fa59ad7 | ||
|
|
2407cb64b3 | ||
|
|
70f704ab44 | ||
|
|
b786032b89 | ||
|
|
e8cc06cc92 | ||
|
|
8e6c56c93d | ||
|
|
69d4ee7f93 | ||
|
|
567fd3e0da | ||
|
|
0b8f88e554 | ||
|
|
60f0c4bf99 | ||
|
|
900ec92ef1 | ||
|
|
2594768479 | ||
|
|
91ab81eca9 | ||
|
|
b20c745c6e | ||
|
|
e41a37bca0 | ||
|
|
9ca44f27a5 | ||
|
|
b9ddf67853 | ||
|
|
afe088045f | ||
|
|
09ca61a962 | ||
|
|
dd69a96c03 | ||
|
|
4a54e594d0 | ||
|
|
936ed1960a | ||
|
|
9fac7986c7 | ||
|
|
e4b603f44e | ||
|
|
7edfe6edcf | ||
|
|
bfb117d0e0 | ||
|
|
b31c1022c3 | ||
|
|
a5851ca31c | ||
|
|
77bf5c15bb | ||
|
|
595133463e | ||
|
|
6155f9ff9e | ||
|
|
7be87c8048 | ||
|
|
9868c3bfe3 | ||
|
|
8b299d0bac | ||
|
|
a44bfb4658 | ||
|
|
96fb5f6881 | ||
|
|
4109ea5324 | ||
|
|
f6c2ee5040 | ||
|
|
965753bf8b | ||
|
|
40c53ab95c | ||
|
|
aaa6211625 | ||
|
|
f6d770eac9 | ||
|
|
47cb61cd62 | ||
|
|
b0fdc8ae1c | ||
|
|
ed9b30efda | ||
|
|
168e5eeff0 | ||
|
|
7acaa86bdf | ||
|
|
96c0393fe7 | ||
|
|
403f795c5e | ||
|
|
c0f88a083e | ||
|
|
542b182899 | ||
|
|
3f58c68c09 | ||
|
|
e50c7e5947 | ||
|
|
4a83700fe4 | ||
|
|
c25f6d1f84 | ||
|
|
a53e1ccf08 | ||
|
|
1af9930951 | ||
|
|
c276c1cbee | ||
|
|
c619348f29 | ||
|
|
c6f96613fc | ||
|
|
258bf736da | ||
|
|
0d75c99476 | ||
|
|
323d409fb6 | ||
|
|
f251722f56 | ||
|
|
7004fde41b | ||
|
|
c9dc27afbb | ||
|
|
efd14ec0e4 | ||
|
|
21ee2b6251 | ||
|
|
82dd2d508f | ||
|
|
ffb5f6c6a6 | ||
|
|
5c5fff9ecb | ||
|
|
9ca071819b | ||
|
|
b14d8e8192 | ||
|
|
5a59f6e3b8 | ||
|
|
60b5aef16a | ||
|
|
35222a8835 | ||
|
|
0e8b5484d5 | ||
|
|
454506c83e | ||
|
|
8f6ab67376 | ||
|
|
5afcc7778f | ||
|
|
325e07d330 | ||
|
|
a016bdc159 | ||
|
|
a14f0b2864 | ||
|
|
721483318a | ||
|
|
be04743649 | ||
|
|
92f0c28d6c | ||
|
|
a6b94e8ca4 | ||
|
|
00b11ef795 | ||
|
|
182580ff69 | ||
|
|
8e9d5c1187 | ||
|
|
99aac5870e | ||
|
|
c1b475c585 | ||
|
|
ec44e68cbf | ||
|
|
73dbebbcc3 | ||
|
|
09f971467d | ||
|
|
2c71b0e873 | ||
|
|
92f69ac463 | ||
|
|
3b154df71a | ||
|
|
64aa965160 | ||
|
|
d715c27d07 | ||
|
|
515084577c | ||
|
|
7596c07a64 | ||
|
|
98fd1d949b | ||
|
|
6312e6aa8f | ||
|
|
6435f11bae | ||
|
|
1c69b9b1fa | ||
|
|
731970ff88 | ||
|
|
038bac1614 | ||
|
|
ed9efe7740 | ||
|
|
ffa0beba7a | ||
|
|
75d793f1c4 | ||
|
|
2b086917e0 | ||
|
|
a9f2738086 | ||
|
|
3a56799ea5 | ||
|
|
3162ce94dc | ||
|
|
c0dc6ac4e1 | ||
|
|
fed1995525 | ||
|
|
5006e23456 | ||
|
|
2f063bddda | ||
|
|
23a26422fd | ||
|
|
434f195a96 | ||
|
|
6a4c2d692c | ||
|
|
5127a07cf9 | ||
|
|
0b4c6f0ab4 | ||
|
|
d8450033ea | ||
|
|
3938736bd8 | ||
|
|
fb2c7b9566 | ||
|
|
29449ec27d | ||
|
|
e38f778d28 | ||
|
|
f5e78436a8 | ||
|
|
6a15b5d9be |
@@ -1,2 +1,5 @@
|
||||
b3dccfaeb636599c02effc377cdd8a87d658256c
|
||||
218b6d0546b990fc449c876fb99f44b50c4daa35
|
||||
182580ff6970caed400be178c5b888514b75d7f2
|
||||
8e9d5c1187b0d36da80571ce4c8ba9b3a37b6c46
|
||||
99aac5870e1092b182e6c5f21abcaab6936a4ad1
|
||||
21
.github/workflows/python-checks.yml
vendored
21
.github/workflows/python-checks.yml
vendored
@@ -34,6 +34,9 @@ on:
|
||||
|
||||
jobs:
|
||||
python-checks:
|
||||
env:
|
||||
# uv requires a venv by default - but for this, we can simply use the system python
|
||||
UV_SYSTEM_PYTHON: 1
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5 # expected run time: <1 min
|
||||
steps:
|
||||
@@ -57,25 +60,19 @@ jobs:
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: setup python
|
||||
- name: setup uv
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: actions/setup-python@v5
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install ruff
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: pip install ruff==0.9.9
|
||||
shell: bash
|
||||
version: '0.6.10'
|
||||
enable-cache: true
|
||||
|
||||
- name: ruff check
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: ruff check --output-format=github .
|
||||
run: uv tool run ruff@0.11.2 check --output-format=github .
|
||||
shell: bash
|
||||
|
||||
- name: ruff format
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: ruff format --check .
|
||||
run: uv tool run ruff@0.11.2 format --check .
|
||||
shell: bash
|
||||
|
||||
30
.github/workflows/python-tests.yml
vendored
30
.github/workflows/python-tests.yml
vendored
@@ -39,24 +39,15 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
python-version:
|
||||
- '3.10'
|
||||
- '3.11'
|
||||
- '3.12'
|
||||
platform:
|
||||
- linux-cuda-11_7
|
||||
- linux-rocm-5_2
|
||||
- linux-cpu
|
||||
- macos-default
|
||||
- windows-cpu
|
||||
include:
|
||||
- platform: linux-cuda-11_7
|
||||
os: ubuntu-22.04
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: linux-rocm-5_2
|
||||
os: ubuntu-22.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/rocm5.2'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: linux-cpu
|
||||
os: ubuntu-22.04
|
||||
os: ubuntu-24.04
|
||||
extra-index-url: 'https://download.pytorch.org/whl/cpu'
|
||||
github-env: $GITHUB_ENV
|
||||
- platform: macos-default
|
||||
@@ -70,6 +61,8 @@ jobs:
|
||||
timeout-minutes: 15 # expected run time: 2-6 min, depending on platform
|
||||
env:
|
||||
PIP_USE_PEP517: '1'
|
||||
UV_SYSTEM_PYTHON: 1
|
||||
|
||||
steps:
|
||||
- name: checkout
|
||||
# https://github.com/nschloe/action-cached-lfs-checkout
|
||||
@@ -92,20 +85,25 @@ jobs:
|
||||
- '!invokeai/frontend/web/**'
|
||||
- 'tests/**'
|
||||
|
||||
- name: setup uv
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: '0.6.10'
|
||||
enable-cache: true
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: setup python
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
env:
|
||||
PIP_EXTRA_INDEX_URL: ${{ matrix.extra-index-url }}
|
||||
run: >
|
||||
pip3 install --editable=".[test]"
|
||||
UV_INDEX: ${{ matrix.extra-index-url }}
|
||||
run: uv pip install --editable ".[test]"
|
||||
|
||||
- name: run pytest
|
||||
if: ${{ steps.changed-files.outputs.python_any_changed == 'true' || inputs.always_run == true }}
|
||||
|
||||
20
.github/workflows/typegen-checks.yml
vendored
20
.github/workflows/typegen-checks.yml
vendored
@@ -54,17 +54,25 @@ jobs:
|
||||
- 'pyproject.toml'
|
||||
- 'invokeai/**'
|
||||
|
||||
- name: setup uv
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: '0.6.10'
|
||||
enable-cache: true
|
||||
python-version: '3.11'
|
||||
|
||||
- name: setup python
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
cache: pip
|
||||
cache-dependency-path: pyproject.toml
|
||||
python-version: '3.11'
|
||||
|
||||
- name: install python dependencies
|
||||
- name: install dependencies
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: pip3 install --use-pep517 --editable="."
|
||||
env:
|
||||
UV_INDEX: ${{ matrix.extra-index-url }}
|
||||
run: uv pip install --editable .
|
||||
|
||||
- name: install frontend dependencies
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
@@ -77,7 +85,7 @@ jobs:
|
||||
|
||||
- name: generate schema
|
||||
if: ${{ steps.changed-files.outputs.src_any_changed == 'true' || inputs.always_run == true }}
|
||||
run: make frontend-typegen
|
||||
run: cd invokeai/frontend/web && uv run ../../../scripts/generate_openapi_schema.py | pnpm typegen
|
||||
shell: bash
|
||||
|
||||
- name: compare files
|
||||
|
||||
@@ -12,6 +12,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
|
||||
from invokeai.backend.util.logging import logging
|
||||
@@ -99,7 +100,7 @@ async def get_app_deps() -> AppDependencyVersions:
|
||||
|
||||
|
||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||
async def get_config() -> AppConfig:
|
||||
async def get_config_() -> AppConfig:
|
||||
infill_methods = ["lama", "tile", "cv2", "color"] # TODO: add mosaic back
|
||||
if PatchMatch.patchmatch_available():
|
||||
infill_methods.append("patchmatch")
|
||||
@@ -121,6 +122,21 @@ async def get_config() -> AppConfig:
|
||||
)
|
||||
|
||||
|
||||
class InvokeAIAppConfigWithSetFields(BaseModel):
|
||||
"""InvokeAI App Config with model fields set"""
|
||||
|
||||
set_fields: set[str] = Field(description="The set fields")
|
||||
config: InvokeAIAppConfig = Field(description="The InvokeAI App Config")
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/runtime_config", operation_id="get_runtime_config", status_code=200, response_model=InvokeAIAppConfigWithSetFields
|
||||
)
|
||||
async def get_runtime_config() -> InvokeAIAppConfigWithSetFields:
|
||||
config = get_config()
|
||||
return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config)
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/logging",
|
||||
operation_id="get_log_level",
|
||||
|
||||
@@ -96,6 +96,22 @@ async def upload_image(
|
||||
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||
|
||||
|
||||
class ImageUploadEntry(BaseModel):
|
||||
image_dto: ImageDTO = Body(description="The image DTO")
|
||||
presigned_url: str = Body(description="The URL to get the presigned URL for the image upload")
|
||||
|
||||
|
||||
@images_router.post("/", operation_id="create_image_upload_entry")
|
||||
async def create_image_upload_entry(
|
||||
width: int = Body(description="The width of the image"),
|
||||
height: int = Body(description="The height of the image"),
|
||||
board_id: Optional[str] = Body(default=None, description="The board to add this image to, if any"),
|
||||
) -> ImageUploadEntry:
|
||||
"""Uploads an image from a URL, not implemented"""
|
||||
|
||||
raise HTTPException(status_code=501, detail="Not implemented")
|
||||
|
||||
|
||||
@images_router.delete("/i/{image_name}", operation_id="delete_image")
|
||||
async def delete_image(
|
||||
image_name: str = Path(description="The name of the image to delete"),
|
||||
|
||||
@@ -28,12 +28,10 @@ from invokeai.app.services.model_records import (
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.util.suppress_output import SuppressOutput
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
MainCheckpointConfig,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from typing import Optional
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastapi import Body, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.invocations.fields import BoardField
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
@@ -15,6 +18,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
CancelByDestinationResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
FieldIdentifier,
|
||||
PruneResult,
|
||||
RetryItemsResult,
|
||||
SessionQueueCountsByDestination,
|
||||
@@ -22,6 +26,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.compose_pydantic_model import compose_model_from_fields
|
||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||
|
||||
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
|
||||
@@ -34,6 +39,17 @@ class SessionQueueAndProcessorStatus(BaseModel):
|
||||
processor: SessionProcessorStatus
|
||||
|
||||
|
||||
class SimpleModelIdentifer(BaseModel):
|
||||
id: str = Field(description="The model id")
|
||||
|
||||
|
||||
model_field_overrides = {ModelIdentifierField: (SimpleModelIdentifer, Field(description="The model identifier"))}
|
||||
|
||||
|
||||
def model_field_filter(field_type: type[Any]) -> bool:
|
||||
return field_type not in {BoardField, Optional[BoardField]}
|
||||
|
||||
|
||||
@session_queue_router.post(
|
||||
"/{queue_id}/enqueue_batch",
|
||||
operation_id="enqueue_batch",
|
||||
@@ -45,9 +61,52 @@ async def enqueue_batch(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch: Batch = Body(description="Batch to process"),
|
||||
prepend: bool = Body(default=False, description="Whether or not to prepend this batch in the queue"),
|
||||
is_api_validation_run: bool = Body(
|
||||
default=False,
|
||||
description="Whether or not this is a validation run.",
|
||||
),
|
||||
api_input_fields: Optional[list[FieldIdentifier]] = Body(
|
||||
default=None, description="The fields that were used as input to the API"
|
||||
),
|
||||
api_output_fields: Optional[list[FieldIdentifier]] = Body(
|
||||
default=None, description="The fields that were used as output from the API"
|
||||
),
|
||||
) -> EnqueueBatchResult:
|
||||
"""Processes a batch and enqueues the output graphs for execution."""
|
||||
|
||||
if is_api_validation_run:
|
||||
session_count = batch.get_session_count()
|
||||
assert session_count == 1, "API validation run only supports single session batches"
|
||||
|
||||
if api_input_fields:
|
||||
composed_model = compose_model_from_fields(
|
||||
g=batch.graph,
|
||||
field_identifiers=api_input_fields,
|
||||
composed_model_class_name="APIInputModel",
|
||||
model_field_overrides=model_field_overrides,
|
||||
model_field_filter=model_field_filter,
|
||||
)
|
||||
json_schema = composed_model.model_json_schema(mode="validation")
|
||||
print("API Input Model")
|
||||
print(json.dumps(json_schema))
|
||||
|
||||
if api_output_fields:
|
||||
composed_model = compose_model_from_fields(
|
||||
g=batch.graph,
|
||||
field_identifiers=api_output_fields,
|
||||
composed_model_class_name="APIOutputModel",
|
||||
)
|
||||
json_schema = composed_model.model_json_schema(mode="validation")
|
||||
print("API Output Model")
|
||||
print(json.dumps(json_schema))
|
||||
|
||||
print("graph")
|
||||
print(batch.graph.model_dump_json())
|
||||
|
||||
if batch.workflow is not None:
|
||||
print("workflow")
|
||||
print(batch.workflow.model_dump_json())
|
||||
|
||||
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
|
||||
queue_id=queue_id, batch=batch, prepend=prepend
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import io
|
||||
import random
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
@@ -24,6 +25,37 @@ from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_common import
|
||||
IMAGE_MAX_AGE = 31536000
|
||||
workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
|
||||
|
||||
ids = {
|
||||
"6614752a-0420-4d81-98fc-e110069d4f38": random.choice([True, False]),
|
||||
"default_5e8b008d-c697-45d0-8883-085a954c6ace": random.choice([True, False]),
|
||||
"4b2b297a-0d47-4f43-8113-ebbf3f403089": random.choice([True, False]),
|
||||
"d0ce602a-049e-4368-97ae-977b49eed042": random.choice([True, False]),
|
||||
"f170a187-fd74-40b8-ba9c-00de173ea4b9": random.choice([True, False]),
|
||||
"default_f96e794f-eb3e-4d01-a960-9b4e43402bcf": random.choice([True, False]),
|
||||
"default_cbf0e034-7b54-4b2c-b670-3b1e2e4b4a88": random.choice([True, False]),
|
||||
"default_dec5a2e9-f59c-40d9-8869-a056751d79b8": random.choice([True, False]),
|
||||
"default_dbe46d95-22aa-43fb-9c16-94400d0ce2fd": random.choice([True, False]),
|
||||
"default_d7a1c60f-ca2f-4f90-9e33-75a826ca6d8f": random.choice([True, False]),
|
||||
"default_e71d153c-2089-43c7-bd2c-f61f37d4c1c1": random.choice([True, False]),
|
||||
"default_7dde3e36-d78f-4152-9eea-00ef9c8124ed": random.choice([True, False]),
|
||||
"default_444fe292-896b-44fd-bfc6-c0b5d220fffc": random.choice([True, False]),
|
||||
"default_2d05e719-a6b9-4e64-9310-b875d3b2f9d2": random.choice([True, False]),
|
||||
"acae7e87-070b-4999-9074-c5b593c86618": random.choice([True, False]),
|
||||
"3008fc77-1521-49c7-ba95-94c5a4508d1d": random.choice([True, False]),
|
||||
"default_686bb1d0-d086-4c70-9fa3-2f600b922023": random.choice([True, False]),
|
||||
"36905c46-e768-4dc3-8ecd-e55fe69bf03c": random.choice([True, False]),
|
||||
"7c3e4951-183b-40ef-a890-28eef4d50097": random.choice([True, False]),
|
||||
"7a053b2f-64e4-4152-80e9-296006e77131": random.choice([True, False]),
|
||||
"27d4f1be-4156-46e9-8d22-d0508cd72d4f": random.choice([True, False]),
|
||||
"e881dc06-70d2-438f-b007-6f3e0c3c0e78": random.choice([True, False]),
|
||||
"265d2244-a1d7-495c-a2eb-88217f5eae37": random.choice([True, False]),
|
||||
"caebcbc7-2bf0-41c4-b553-106b585fddda": random.choice([True, False]),
|
||||
"a7998705-474e-417d-bd37-a2a9480beedf": random.choice([True, False]),
|
||||
"554d94b5-94b3-4d8e-8aed-51ebfc9deea5": random.choice([True, False]),
|
||||
"e6898540-c1bc-408b-b944-c1e242cddbcd": random.choice([True, False]),
|
||||
"363b0960-ab2c-4902-8df3-f592d6194bb3": random.choice([True, False]),
|
||||
}
|
||||
|
||||
|
||||
@workflows_router.get(
|
||||
"/i/{workflow_id}",
|
||||
@@ -39,6 +71,8 @@ async def get_workflow(
|
||||
try:
|
||||
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
|
||||
workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
|
||||
workflow.is_published = ids.get(workflow_id, False)
|
||||
workflow.workflow.is_published = ids.get(workflow_id, False)
|
||||
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
|
||||
except WorkflowNotFoundError:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
@@ -106,10 +140,11 @@ async def list_workflows(
|
||||
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
|
||||
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
|
||||
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
|
||||
is_published: Optional[bool] = Query(default=None, description="Whether to include/exclude published workflows"),
|
||||
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
|
||||
"""Gets a page of workflows"""
|
||||
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
|
||||
workflows = ApiDependencies.invoker.services.workflow_records.get_many(
|
||||
workflow_record_list_items = ApiDependencies.invoker.services.workflow_records.get_many(
|
||||
order_by=order_by,
|
||||
direction=direction,
|
||||
page=page,
|
||||
@@ -118,20 +153,23 @@ async def list_workflows(
|
||||
categories=categories,
|
||||
tags=tags,
|
||||
has_been_opened=has_been_opened,
|
||||
is_published=is_published,
|
||||
)
|
||||
for workflow in workflows.items:
|
||||
for item in workflow_record_list_items.items:
|
||||
data = item.model_dump()
|
||||
data["is_published"] = ids.get(item.workflow_id, False)
|
||||
workflows_with_thumbnails.append(
|
||||
WorkflowRecordListItemWithThumbnailDTO(
|
||||
thumbnail_url=ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow.workflow_id),
|
||||
**workflow.model_dump(),
|
||||
thumbnail_url=ApiDependencies.invoker.services.workflow_thumbnails.get_url(item.workflow_id),
|
||||
**data,
|
||||
)
|
||||
)
|
||||
return PaginatedResults[WorkflowRecordListItemWithThumbnailDTO](
|
||||
items=workflows_with_thumbnails,
|
||||
total=workflows.total,
|
||||
page=workflows.page,
|
||||
pages=workflows.pages,
|
||||
per_page=workflows.per_page,
|
||||
total=workflow_record_list_items.total,
|
||||
page=workflow_record_list_items.page,
|
||||
pages=workflow_record_list_items.pages,
|
||||
per_page=workflow_record_list_items.per_page,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import sys
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -27,7 +28,6 @@ import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from typing_extensions import TypeAliasType
|
||||
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldKind,
|
||||
@@ -100,37 +100,6 @@ class BaseInvocationOutput(BaseModel):
|
||||
All invocation outputs must use the `@invocation_output` decorator to provide their unique type.
|
||||
"""
|
||||
|
||||
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
_typeadapter_needs_update: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||
"""Registers an invocation output."""
|
||||
cls._output_classes.add(output)
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
|
||||
"""Gets all invocation outputs."""
|
||||
return cls._output_classes
|
||||
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
|
||||
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||
AnyInvocationOutput = TypeAliasType(
|
||||
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
|
||||
cls._typeadapter_needs_update = False
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
def get_output_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation output types."""
|
||||
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
|
||||
@@ -173,76 +142,16 @@ class BaseInvocation(ABC, BaseModel):
|
||||
All invocations must use the `@invocation` decorator to provide their unique type.
|
||||
"""
|
||||
|
||||
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
_typeadapter_needs_update: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> str:
|
||||
"""Gets the invocation's type, as provided by the `@invocation` decorator."""
|
||||
return cls.model_fields["type"].default
|
||||
|
||||
@classmethod
|
||||
def register_invocation(cls, invocation: BaseInvocation) -> None:
|
||||
"""Registers an invocation."""
|
||||
cls._invocation_classes.add(invocation)
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||
AnyInvocation = TypeAliasType(
|
||||
"AnyInvocation", Annotated[Union[tuple(cls.get_invocations())], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocation)
|
||||
cls._typeadapter_needs_update = False
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
def invalidate_typeadapter(cls) -> None:
|
||||
"""Invalidates the typeadapter, forcing it to be rebuilt on next access. If the invocation allowlist or
|
||||
denylist is changed, this should be called to ensure the typeadapter is updated and validation respects
|
||||
the updated allowlist and denylist."""
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
||||
"""Gets all invocations, respecting the allowlist and denylist."""
|
||||
app_config = get_config()
|
||||
allowed_invocations: set[BaseInvocation] = set()
|
||||
for sc in cls._invocation_classes:
|
||||
invocation_type = sc.get_type()
|
||||
is_in_allowlist = (
|
||||
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
|
||||
)
|
||||
is_in_denylist = (
|
||||
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
|
||||
)
|
||||
if is_in_allowlist and not is_in_denylist:
|
||||
allowed_invocations.add(sc)
|
||||
return allowed_invocations
|
||||
|
||||
@classmethod
|
||||
def get_invocations_map(cls) -> dict[str, BaseInvocation]:
|
||||
"""Gets a map of all invocation types to their invocation classes."""
|
||||
return {i.get_type(): i for i in BaseInvocation.get_invocations()}
|
||||
|
||||
@classmethod
|
||||
def get_invocation_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation types."""
|
||||
return (i.get_type() for i in BaseInvocation.get_invocations())
|
||||
|
||||
@classmethod
|
||||
def get_output_annotation(cls) -> BaseInvocationOutput:
|
||||
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
@classmethod
|
||||
def get_invocation_for_type(cls, invocation_type: str) -> BaseInvocation | None:
|
||||
"""Gets the invocation class for a given invocation type."""
|
||||
return cls.get_invocations_map().get(invocation_type)
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
||||
@@ -340,6 +249,105 @@ class BaseInvocation(ABC, BaseModel):
|
||||
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
|
||||
|
||||
|
||||
class InvocationRegistry:
|
||||
_invocation_classes: ClassVar[set[type[BaseInvocation]]] = set()
|
||||
_output_classes: ClassVar[set[type[BaseInvocationOutput]]] = set()
|
||||
|
||||
@classmethod
|
||||
def register_invocation(cls, invocation: type[BaseInvocation]) -> None:
|
||||
"""Registers an invocation."""
|
||||
cls._invocation_classes.add(invocation)
|
||||
cls.invalidate_invocation_typeadapter()
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=1)
|
||||
def get_invocation_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantic TypeAdapter for the union of all invocation types.
|
||||
|
||||
This is used to parse serialized invocations into the correct invocation class.
|
||||
|
||||
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
|
||||
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
|
||||
the updated allowlist and denylist.
|
||||
|
||||
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
|
||||
"""
|
||||
return TypeAdapter(Annotated[Union[tuple(cls.get_invocation_classes())], Field(discriminator="type")])
|
||||
|
||||
@classmethod
|
||||
def invalidate_invocation_typeadapter(cls) -> None:
|
||||
"""Invalidates the cached invocation type adapter."""
|
||||
cls.get_invocation_typeadapter.cache_clear()
|
||||
|
||||
@classmethod
|
||||
def get_invocation_classes(cls) -> Iterable[type[BaseInvocation]]:
|
||||
"""Gets all invocations, respecting the allowlist and denylist."""
|
||||
app_config = get_config()
|
||||
allowed_invocations: set[type[BaseInvocation]] = set()
|
||||
for sc in cls._invocation_classes:
|
||||
invocation_type = sc.get_type()
|
||||
is_in_allowlist = (
|
||||
invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
|
||||
)
|
||||
is_in_denylist = (
|
||||
invocation_type in app_config.deny_nodes if isinstance(app_config.deny_nodes, list) else False
|
||||
)
|
||||
if is_in_allowlist and not is_in_denylist:
|
||||
allowed_invocations.add(sc)
|
||||
return allowed_invocations
|
||||
|
||||
@classmethod
|
||||
def get_invocations_map(cls) -> dict[str, type[BaseInvocation]]:
|
||||
"""Gets a map of all invocation types to their invocation classes."""
|
||||
return {i.get_type(): i for i in cls.get_invocation_classes()}
|
||||
|
||||
@classmethod
|
||||
def get_invocation_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation types."""
|
||||
return (i.get_type() for i in cls.get_invocation_classes())
|
||||
|
||||
@classmethod
|
||||
def get_invocation_for_type(cls, invocation_type: str) -> type[BaseInvocation] | None:
|
||||
"""Gets the invocation class for a given invocation type."""
|
||||
return cls.get_invocations_map().get(invocation_type)
|
||||
|
||||
@classmethod
|
||||
def register_output(cls, output: "type[TBaseInvocationOutput]") -> None:
|
||||
"""Registers an invocation output."""
|
||||
cls._output_classes.add(output)
|
||||
cls.invalidate_output_typeadapter()
|
||||
|
||||
@classmethod
|
||||
def get_output_classes(cls) -> Iterable[type[BaseInvocationOutput]]:
|
||||
"""Gets all invocation outputs."""
|
||||
return cls._output_classes
|
||||
|
||||
@classmethod
|
||||
@lru_cache(maxsize=1)
|
||||
def get_output_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantic TypeAdapter for the union of all invocation output types.
|
||||
|
||||
This is used to parse serialized invocation outputs into the correct invocation output class.
|
||||
|
||||
This method is cached to avoid rebuilding the TypeAdapter on every access. If the invocation allowlist or
|
||||
denylist is changed, the cache should be cleared to ensure the TypeAdapter is updated and validation respects
|
||||
the updated allowlist and denylist.
|
||||
|
||||
@see https://docs.pydantic.dev/latest/concepts/type_adapter/
|
||||
"""
|
||||
return TypeAdapter(Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")])
|
||||
|
||||
@classmethod
|
||||
def invalidate_output_typeadapter(cls) -> None:
|
||||
"""Invalidates the cached invocation output type adapter."""
|
||||
cls.get_output_typeadapter.cache_clear()
|
||||
|
||||
@classmethod
|
||||
def get_output_types(cls) -> Iterable[str]:
|
||||
"""Gets all invocation output types."""
|
||||
return (i.get_type() for i in cls.get_output_classes())
|
||||
|
||||
|
||||
RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
|
||||
"id",
|
||||
"is_intermediate",
|
||||
@@ -453,8 +461,8 @@ def invocation(
|
||||
node_pack = cls.__module__.split(".")[0]
|
||||
|
||||
# Handle the case where an existing node is being clobbered by the one we are registering
|
||||
if invocation_type in BaseInvocation.get_invocation_types():
|
||||
clobbered_invocation = BaseInvocation.get_invocation_for_type(invocation_type)
|
||||
if invocation_type in InvocationRegistry.get_invocation_types():
|
||||
clobbered_invocation = InvocationRegistry.get_invocation_for_type(invocation_type)
|
||||
# This should always be true - we just checked if the invocation type was in the set
|
||||
assert clobbered_invocation is not None
|
||||
|
||||
@@ -539,8 +547,7 @@ def invocation(
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
|
||||
# TODO: how to type this correctly? it's typed as ModelMetaclass, a private class in pydantic
|
||||
BaseInvocation.register_invocation(cls) # type: ignore
|
||||
InvocationRegistry.register_invocation(cls)
|
||||
|
||||
return cls
|
||||
|
||||
@@ -565,7 +572,7 @@ def invocation_output(
|
||||
if re.compile(r"^\S+$").match(output_type) is None:
|
||||
raise ValueError(f'"output_type" must consist of non-whitespace characters, got "{output_type}"')
|
||||
|
||||
if output_type in BaseInvocationOutput.get_output_types():
|
||||
if output_type in InvocationRegistry.get_output_types():
|
||||
raise ValueError(f'Invocation type "{output_type}" already exists')
|
||||
|
||||
validate_fields(cls.model_fields, output_type)
|
||||
@@ -586,7 +593,7 @@ def invocation_output(
|
||||
)
|
||||
cls.__doc__ = docstring
|
||||
|
||||
BaseInvocationOutput.register_output(cls) # type: ignore # TODO: how to type this correctly?
|
||||
InvocationRegistry.register_output(cls)
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
@@ -19,7 +19,8 @@ from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
||||
from invokeai.app.invocations.model import UNetField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager import LoadedModel
|
||||
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
|
||||
from invokeai.backend.model_manager.config import MainConfigBase
|
||||
from invokeai.backend.model_manager.taxonomy import ModelVariantType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||||
|
||||
|
||||
|
||||
@@ -39,8 +39,8 @@ from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelVariantType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -25,7 +24,6 @@ class FluxControlLoRALoaderOutput(BaseInvocationOutput):
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.1.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxControlLoRALoaderInvocation(BaseInvocation):
|
||||
"""LoRA model and Image to use with FLUX transformer generation."""
|
||||
|
||||
@@ -3,7 +3,6 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -52,7 +51,6 @@ class FluxControlNetOutput(BaseInvocationOutput):
|
||||
tags=["controlnet", "flux"],
|
||||
category="controlnet",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxControlNetInvocation(BaseInvocation):
|
||||
"""Collect FLUX ControlNet info to pass to other nodes."""
|
||||
|
||||
@@ -10,7 +10,7 @@ from PIL import Image
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
@@ -49,7 +49,7 @@ from invokeai.backend.flux.sampling_utils import (
|
||||
unpack,
|
||||
)
|
||||
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
|
||||
from invokeai.backend.model_manager.config import ModelFormat, ModelVariantType
|
||||
from invokeai.backend.model_manager.taxonomy import ModelFormat, ModelVariantType
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -64,7 +64,6 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
tags=["image", "flux"],
|
||||
category="image",
|
||||
version="3.3.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run denoising process with a FLUX transformer model."""
|
||||
|
||||
@@ -31,7 +31,7 @@ class FluxFillOutput(BaseInvocationOutput):
|
||||
tags=["inpaint"],
|
||||
category="inpaint",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class FluxFillInvocation(BaseInvocation):
|
||||
"""Prepare the FLUX Fill conditioning data."""
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Literal, Union
|
||||
from pydantic import field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import InputField, UIType
|
||||
from invokeai.app.invocations.ip_adapter import (
|
||||
CLIP_VISION_MODEL_MAP,
|
||||
@@ -28,7 +28,6 @@ from invokeai.backend.model_manager.config import (
|
||||
tags=["ip_adapter", "control"],
|
||||
category="ip_adapter",
|
||||
version="1.0.0",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxIPAdapterInvocation(BaseInvocation):
|
||||
"""Collects FLUX IP-Adapter info to pass to other nodes."""
|
||||
|
||||
@@ -3,14 +3,13 @@ from typing import Optional
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, T5EncoderField, TransformerField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType
|
||||
|
||||
|
||||
@invocation_output("flux_lora_loader_output")
|
||||
@@ -28,11 +27,10 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"flux_lora_loader",
|
||||
title="FLUX LoRA",
|
||||
title="Apply LoRA - FLUX",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.2.0",
|
||||
classification=Classification.Prototype,
|
||||
version="1.2.1",
|
||||
)
|
||||
class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply a LoRA model to a FLUX transformer and/or text encoder."""
|
||||
@@ -107,11 +105,10 @@ class FluxLoRALoaderInvocation(BaseInvocation):
|
||||
|
||||
@invocation(
|
||||
"flux_lora_collection_loader",
|
||||
title="FLUX LoRA Collection Loader",
|
||||
title="Apply LoRA Collection - FLUX",
|
||||
tags=["lora", "model", "flux"],
|
||||
category="model",
|
||||
version="1.3.0",
|
||||
classification=Classification.Prototype,
|
||||
version="1.3.1",
|
||||
)
|
||||
class FLUXLoRACollectionLoader(BaseInvocation):
|
||||
"""Applies a collection of LoRAs to a FLUX transformer."""
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Literal
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -17,8 +16,8 @@ from invokeai.app.util.t5_model_identifier import (
|
||||
from invokeai.backend.flux.util import max_seq_lengths
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import SubModelType
|
||||
|
||||
|
||||
@invocation_output("flux_model_loader_output")
|
||||
@@ -41,7 +40,6 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||
tags=["model", "flux"],
|
||||
category="model",
|
||||
version="1.0.6",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a flux base model, outputting its submodels."""
|
||||
|
||||
@@ -23,7 +23,8 @@ from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.starter_models import siglip
|
||||
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
@@ -44,7 +45,7 @@ class FluxReduxOutput(BaseInvocationOutput):
|
||||
tags=["ip_adapter", "control"],
|
||||
category="ip_adapter",
|
||||
version="2.0.0",
|
||||
classification=Classification.Prototype,
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class FluxReduxInvocation(BaseInvocation):
|
||||
"""Runs a FLUX Redux model to generate a conditioning tensor."""
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Iterator, Literal, Optional, Tuple
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
@@ -17,7 +17,7 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import FluxConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.model_manager import ModelFormat
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -30,7 +30,6 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
|
||||
tags=["prompt", "conditioning", "flux"],
|
||||
category="conditioning",
|
||||
version="1.1.2",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class FluxTextEncoderInvocation(BaseInvocation):
|
||||
"""Encodes and preps a prompt for a flux image."""
|
||||
|
||||
@@ -6,7 +6,7 @@ from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
|
||||
from invokeai.app.invocations.model import UNetField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType
|
||||
|
||||
|
||||
@invocation_output("ideal_size_output")
|
||||
|
||||
@@ -355,7 +355,6 @@ class ImageBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
tags=["image", "unsharp_mask"],
|
||||
category="image",
|
||||
version="1.2.2",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class UnsharpMaskInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Applies an unsharp mask filter to an image"""
|
||||
@@ -1090,12 +1089,13 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
|
||||
@invocation(
|
||||
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.0"
|
||||
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1"
|
||||
)
|
||||
class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
|
||||
The mask is thresholded to create a binary mask, and then a distance transform is applied to create a fade effect.
|
||||
The fade size is specified in pixels, and the mask is expanded by that amount. The result is a mask with a smooth transition from black to white.
|
||||
If the fade size is 0, the mask is returned as-is.
|
||||
"""
|
||||
|
||||
mask: ImageField = InputField(description="The mask to expand")
|
||||
@@ -1105,6 +1105,11 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
pil_mask = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
|
||||
if self.fade_size_px == 0:
|
||||
# If the fade size is 0, just return the mask as-is.
|
||||
image_dto = context.images.save(image=pil_mask, image_category=ImageCategory.MASK)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
np_mask = numpy.array(pil_mask)
|
||||
|
||||
# Threshold the mask to create a binary mask - 0 for black, 255 for white
|
||||
@@ -1142,8 +1147,21 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
coeffs = numpy.polyfit(x_control, y_control, 3)
|
||||
poly = numpy.poly1d(coeffs)
|
||||
|
||||
# Evaluate and clip the smooth mapping
|
||||
feather = numpy.clip(poly(d_norm), 0, 1)
|
||||
# Evaluate the polynomial
|
||||
feather = poly(d_norm)
|
||||
|
||||
# The polynomial fit isn't perfect. Points beyond the fade distance are likely to be slightly less than 1.0,
|
||||
# even though the control points indicate that they should be exactly 1.0. This is due to the nature of the
|
||||
# polynomial fit, which is a best approximation of the control points but not an exact match.
|
||||
|
||||
# When this occurs, the area outside the mask and fade-out will not be 100% transparent. For example, it may
|
||||
# have an alpha value of 1 instead of 0. So we must force pixels at or beyond the fade distance to exactly 1.0.
|
||||
|
||||
# Force pixels at or beyond the fade distance to exactly 1.0
|
||||
feather = numpy.where(d_norm >= 1.0, 1.0, feather)
|
||||
|
||||
# Clip any other values to ensure they're in the valid range [0,1]
|
||||
feather = numpy.clip(feather, 0, 1)
|
||||
|
||||
# Build final image.
|
||||
np_result = numpy.where(black_mask == 1, 0, (feather * 255).astype(numpy.uint8))
|
||||
@@ -1265,7 +1283,6 @@ class ImageNoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
tags=["image", "crop"],
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class CropImageToBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Crop an image to the given bounding box. If the bounding box is omitted, the image is cropped to the non-transparent pixels."""
|
||||
@@ -1292,7 +1309,6 @@ class CropImageToBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
category="image",
|
||||
version="1.0.0",
|
||||
tags=["image", "crop"],
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Paste the source image into the target image at the given bounding box.
|
||||
|
||||
@@ -13,10 +13,8 @@ from invokeai.app.services.model_records.model_records_base import ModelRecordCh
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
IPAdapterCheckpointConfig,
|
||||
IPAdapterInvokeAIConfig,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.starter_models import (
|
||||
StarterModel,
|
||||
@@ -24,6 +22,7 @@ from invokeai.backend.model_manager.starter_models import (
|
||||
ip_adapter_sd_image_encoder,
|
||||
ip_adapter_sdxl_image_encoder,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
|
||||
|
||||
class IPAdapterField(BaseModel):
|
||||
|
||||
@@ -4,7 +4,7 @@ import torch
|
||||
from PIL.Image import Image
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import StringOutput
|
||||
@@ -13,7 +13,14 @@ from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@invocation("llava_onevision_vllm", title="LLaVA OneVision VLLM", tags=["vllm"], category="vllm", version="1.0.0")
|
||||
@invocation(
|
||||
"llava_onevision_vllm",
|
||||
title="LLaVA OneVision VLLM",
|
||||
tags=["vllm"],
|
||||
category="vllm",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class LlavaOnevisionVllmInvocation(BaseInvocation):
|
||||
"""Run a LLaVA OneVision VLLM model."""
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
Classification,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
)
|
||||
@@ -58,7 +57,6 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class AlphaMaskToTensorInvocation(BaseInvocation):
|
||||
"""Convert a mask image to a tensor. Opaque regions are 1 and transparent regions are 0."""
|
||||
@@ -87,7 +85,6 @@ class AlphaMaskToTensorInvocation(BaseInvocation):
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
version="1.1.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class InvertTensorMaskInvocation(BaseInvocation):
|
||||
"""Inverts a tensor mask."""
|
||||
@@ -234,7 +231,6 @@ WHITE = ColorField(r=255, g=255, b=255, a=255)
|
||||
tags=["mask"],
|
||||
category="mask",
|
||||
version="1.0.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class GetMaskBoundingBoxInvocation(BaseInvocation):
|
||||
"""Gets the bounding box of the given mask image."""
|
||||
|
||||
@@ -43,7 +43,7 @@ from invokeai.app.invocations.primitives import BooleanOutput, FloatOutput, Inte
|
||||
from invokeai.app.invocations.scheduler import SchedulerOutput
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField, T2IAdapterInvocation
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import ModelType, SubModelType
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
from invokeai.version import __version__
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from pydantic import BaseModel, Field
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -15,10 +14,8 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType
|
||||
|
||||
|
||||
class ModelIdentifierField(BaseModel):
|
||||
@@ -126,7 +123,6 @@ class ModelIdentifierOutput(BaseInvocationOutput):
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.0.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class ModelIdentifierInvocation(BaseInvocation):
|
||||
"""Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as
|
||||
@@ -181,7 +177,7 @@ class LoRALoaderOutput(BaseInvocationOutput):
|
||||
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
|
||||
|
||||
|
||||
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3")
|
||||
@invocation("lora_loader", title="Apply LoRA - SD1.5", tags=["model"], category="model", version="1.0.4")
|
||||
class LoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
|
||||
@@ -244,7 +240,7 @@ class LoRASelectorOutput(BaseInvocationOutput):
|
||||
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
|
||||
|
||||
|
||||
@invocation("lora_selector", title="LoRA Model - SD1.5", tags=["model"], category="model", version="1.0.2")
|
||||
@invocation("lora_selector", title="Select LoRA", tags=["model"], category="model", version="1.0.3")
|
||||
class LoRASelectorInvocation(BaseInvocation):
|
||||
"""Selects a LoRA model and weight."""
|
||||
|
||||
@@ -258,7 +254,7 @@ class LoRASelectorInvocation(BaseInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"lora_collection_loader", title="LoRA Collection - SD1.5", tags=["model"], category="model", version="1.1.1"
|
||||
"lora_collection_loader", title="Apply LoRA Collection - SD1.5", tags=["model"], category="model", version="1.1.2"
|
||||
)
|
||||
class LoRACollectionLoader(BaseInvocation):
|
||||
"""Applies a collection of LoRAs to the provided UNet and CLIP models."""
|
||||
@@ -322,10 +318,10 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
|
||||
|
||||
@invocation(
|
||||
"sdxl_lora_loader",
|
||||
title="LoRA Model - SDXL",
|
||||
title="Apply LoRA - SDXL",
|
||||
tags=["lora", "model"],
|
||||
category="model",
|
||||
version="1.0.4",
|
||||
version="1.0.5",
|
||||
)
|
||||
class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
"""Apply selected lora to unet and text_encoder."""
|
||||
@@ -402,10 +398,10 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
|
||||
|
||||
@invocation(
|
||||
"sdxl_lora_collection_loader",
|
||||
title="LoRA Collection - SDXL",
|
||||
title="Apply LoRA Collection - SDXL",
|
||||
tags=["model"],
|
||||
category="model",
|
||||
version="1.1.1",
|
||||
version="1.1.2",
|
||||
)
|
||||
class SDXLLoRACollectionLoader(BaseInvocation):
|
||||
"""Applies a collection of SDXL LoRAs to the provided UNet and CLIP models."""
|
||||
|
||||
@@ -6,7 +6,7 @@ from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
DenoiseMaskField,
|
||||
@@ -23,7 +23,7 @@ from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.sd3.extensions.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
|
||||
@@ -36,7 +36,6 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
tags=["image", "sd3"],
|
||||
category="image",
|
||||
version="1.1.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run denoising process with a SD3 model."""
|
||||
|
||||
@@ -2,7 +2,7 @@ import einops
|
||||
import torch
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
@@ -25,7 +25,6 @@ from invokeai.backend.util.devices import TorchDevice
|
||||
tags=["image", "latents", "vae", "i2l", "sd3"],
|
||||
category="image",
|
||||
version="1.0.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Generates latents from an image."""
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Optional
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -14,7 +13,7 @@ from invokeai.app.util.t5_model_identifier import (
|
||||
preprocess_t5_encoder_model_identifier,
|
||||
preprocess_t5_tokenizer_model_identifier,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import SubModelType
|
||||
|
||||
|
||||
@invocation_output("sd3_model_loader_output")
|
||||
@@ -34,7 +33,6 @@ class Sd3ModelLoaderOutput(BaseInvocationOutput):
|
||||
tags=["model", "sd3"],
|
||||
category="model",
|
||||
version="1.0.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class Sd3ModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads a SD3 base model, outputting its submodels."""
|
||||
|
||||
@@ -11,12 +11,12 @@ from transformers import (
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import SD3ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import ModelFormat
|
||||
from invokeai.backend.model_manager.taxonomy import ModelFormat
|
||||
from invokeai.backend.patches.layer_patcher import LayerPatcher
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
@@ -33,7 +33,6 @@ SD3_T5_MAX_SEQ_LEN = 256
|
||||
tags=["prompt", "conditioning", "sd3"],
|
||||
category="conditioning",
|
||||
version="1.0.1",
|
||||
classification=Classification.Prototype,
|
||||
)
|
||||
class Sd3TextEncoderInvocation(BaseInvocation):
|
||||
"""Encodes and preps a prompt for a SD3 image."""
|
||||
|
||||
@@ -2,7 +2,7 @@ from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocati
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, UNetField, VAEField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import SubModelType
|
||||
|
||||
|
||||
@invocation_output("sdxl_model_loader_output")
|
||||
|
||||
@@ -7,7 +7,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
|
||||
@@ -56,7 +56,6 @@ def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> C
|
||||
title="Tiled Multi-Diffusion Denoise - SD1.5, SDXL",
|
||||
tags=["upscale", "denoise"],
|
||||
category="latents",
|
||||
classification=Classification.Beta,
|
||||
version="1.0.1",
|
||||
)
|
||||
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
|
||||
@@ -7,7 +7,6 @@ from pydantic import BaseModel
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -40,7 +39,6 @@ class CalculateImageTilesOutput(BaseInvocationOutput):
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.1",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class CalculateImageTilesInvocation(BaseInvocation):
|
||||
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
|
||||
@@ -74,7 +72,6 @@ class CalculateImageTilesInvocation(BaseInvocation):
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.1.1",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
|
||||
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
|
||||
@@ -117,7 +114,6 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.1",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
|
||||
"""Calculate the coordinates and overlaps of tiles that cover a target image shape."""
|
||||
@@ -168,7 +164,6 @@ class TileToPropertiesOutput(BaseInvocationOutput):
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.1",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class TileToPropertiesInvocation(BaseInvocation):
|
||||
"""Split a Tile into its individual properties."""
|
||||
@@ -201,7 +196,6 @@ class PairTileImageOutput(BaseInvocationOutput):
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.0.1",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class PairTileImageInvocation(BaseInvocation):
|
||||
"""Pair an image with its tile properties."""
|
||||
@@ -230,7 +224,6 @@ BLEND_MODES = Literal["Linear", "Seam"]
|
||||
tags=["tiles"],
|
||||
category="tiles",
|
||||
version="1.1.1",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Merge multiple tile images into a single image."""
|
||||
|
||||
@@ -41,16 +41,15 @@ def run_app() -> None:
|
||||
)
|
||||
|
||||
# Find an open port, and modify the config accordingly.
|
||||
orig_config_port = app_config.port
|
||||
app_config.port = find_open_port(app_config.port)
|
||||
if orig_config_port != app_config.port:
|
||||
first_open_port = find_open_port(app_config.port)
|
||||
if app_config.port != first_open_port:
|
||||
orig_config_port = app_config.port
|
||||
app_config.port = first_open_port
|
||||
logger.warning(f"Port {orig_config_port} is already in use. Using port {app_config.port}.")
|
||||
|
||||
# Miscellaneous startup tasks.
|
||||
apply_monkeypatches()
|
||||
register_mime_types()
|
||||
if app_config.dev_reload:
|
||||
enable_dev_reload()
|
||||
check_cudnn(logger)
|
||||
|
||||
# Initialize the app and event loop.
|
||||
@@ -61,6 +60,11 @@ def run_app() -> None:
|
||||
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
|
||||
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path, logger=logger)
|
||||
|
||||
if app_config.dev_reload:
|
||||
# load_custom_nodes seems to bypass jurrigged's import sniffer, so be sure to call it *after* they're already
|
||||
# imported.
|
||||
enable_dev_reload(custom_nodes_path=app_config.custom_nodes_path)
|
||||
|
||||
# Start the server.
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
|
||||
@@ -44,7 +44,8 @@ if TYPE_CHECKING:
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
|
||||
@@ -16,7 +16,8 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
)
|
||||
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.download.download_base import DownloadJob
|
||||
|
||||
@@ -10,9 +10,9 @@ from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
|
||||
from invokeai.app.services.model_records import ModelRecordChanges
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
|
||||
@@ -39,8 +39,6 @@ from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
InvalidModelConfigException,
|
||||
ModelConfigBase,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
)
|
||||
from invokeai.backend.model_manager.legacy_probe import ModelProbe
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
@@ -52,6 +50,7 @@ from invokeai.backend.model_manager.metadata import (
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant, ModelSourceType
|
||||
from invokeai.backend.util import InvokeAILogger
|
||||
from invokeai.backend.util.catch_sigint import catch_sigint
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@@ -5,9 +5,10 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
|
||||
|
||||
|
||||
class ModelLoadServiceBase(ABC):
|
||||
|
||||
@@ -11,7 +11,7 @@ from torch import load as torch_load
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBase
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.load import (
|
||||
LoadedModel,
|
||||
LoadedModelWithoutConfig,
|
||||
@@ -20,6 +20,7 @@ from invokeai.backend.model_manager.load import (
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
"""Initialization file for model manager service."""
|
||||
|
||||
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService, ModelManagerServiceBase
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
|
||||
__all__ = [
|
||||
"ModelManagerServiceBase",
|
||||
"ModelManagerService",
|
||||
"AnyModel",
|
||||
"AnyModelConfig",
|
||||
"BaseModelType",
|
||||
"ModelType",
|
||||
"SubModelType",
|
||||
"LoadedModel",
|
||||
]
|
||||
|
||||
@@ -14,10 +14,12 @@ from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ClipVariantType,
|
||||
ControlAdapterDefaultSettings,
|
||||
MainModelDefaultSettings,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
BaseModelType,
|
||||
ClipVariantType,
|
||||
ModelFormat,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
|
||||
@@ -60,11 +60,9 @@ from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
|
||||
|
||||
|
||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
@@ -304,7 +302,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
# We catch this error so that the app can still run if there are invalid model configs in the database.
|
||||
# One reason that an invalid model config might be in the database is if someone had to rollback from a
|
||||
# newer version of the app that added a new model type.
|
||||
self._logger.warning(f"Found an invalid model config in the database. Ignoring this model. ({row[0]})")
|
||||
row_data = f"{row[0][:64]}..." if len(row[0]) > 64 else row[0]
|
||||
self._logger.warning(
|
||||
f"Found an invalid model config in the database. Ignoring this model. ({row_data})"
|
||||
)
|
||||
else:
|
||||
results.append(model_config)
|
||||
|
||||
|
||||
@@ -33,7 +33,12 @@ class SessionQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> Coroutine[Any, Any, EnqueueBatchResult]:
|
||||
def enqueue_batch(
|
||||
self,
|
||||
queue_id: str,
|
||||
batch: Batch,
|
||||
prepend: bool,
|
||||
) -> Coroutine[Any, Any, EnqueueBatchResult]:
|
||||
"""Enqueues all permutations of a batch for execution."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -157,6 +157,28 @@ class Batch(BaseModel):
|
||||
v.validate_self()
|
||||
return v
|
||||
|
||||
def get_session_count(self) -> int:
|
||||
"""
|
||||
Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
|
||||
creating them, as is done in `create_session_nfv_tuples()`.
|
||||
|
||||
The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
|
||||
many were _actually_ created (which may be less due to the maximum number of sessions).
|
||||
|
||||
If the session count has already been calculated, return the cached value.
|
||||
"""
|
||||
if not self.data:
|
||||
return self.runs
|
||||
data = []
|
||||
for batch_datum_list in self.data:
|
||||
to_zip = []
|
||||
for batch_datum in batch_datum_list:
|
||||
batch_data_items = range(len(batch_datum.items))
|
||||
to_zip.append(batch_data_items)
|
||||
data.append(list(zip(*to_zip, strict=True)))
|
||||
data_product = list(product(*data))
|
||||
return len(data_product) * self.runs
|
||||
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"required": [
|
||||
@@ -201,6 +223,12 @@ def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
|
||||
return None
|
||||
|
||||
|
||||
class FieldIdentifier(BaseModel):
|
||||
kind: Literal["input", "output"] = Field(description="The kind of field")
|
||||
node_id: str = Field(description="The ID of the node")
|
||||
field_name: str = Field(description="The name of the field")
|
||||
|
||||
|
||||
class SessionQueueItemWithoutGraph(BaseModel):
|
||||
"""Session queue item without the full graph. Used for serialization."""
|
||||
|
||||
@@ -237,6 +265,16 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
retried_from_item_id: Optional[int] = Field(
|
||||
default=None, description="The item_id of the queue item that this item was retried from"
|
||||
)
|
||||
is_api_validation_run: bool = Field(
|
||||
default=False,
|
||||
description="Whether this queue item is an API validation run.",
|
||||
)
|
||||
api_input_fields: Optional[list[FieldIdentifier]] = Field(
|
||||
default=None, description="The fields that were used as input to the API"
|
||||
)
|
||||
api_output_fields: Optional[list[FieldIdentifier]] = Field(
|
||||
default=None, description="The nodes that were used as output from the API"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
|
||||
@@ -536,28 +574,6 @@ def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str
|
||||
count += 1
|
||||
|
||||
|
||||
def calc_session_count(batch: Batch) -> int:
|
||||
"""
|
||||
Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
|
||||
creating them, as is done in `create_session_nfv_tuples()`.
|
||||
|
||||
The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
|
||||
many were _actually_ created (which may be less due to the maximum number of sessions).
|
||||
"""
|
||||
# TODO: Should this be a class method on Batch?
|
||||
if not batch.data:
|
||||
return batch.runs
|
||||
data = []
|
||||
for batch_datum_list in batch.data:
|
||||
to_zip = []
|
||||
for batch_datum in batch_datum_list:
|
||||
batch_data_items = range(len(batch_datum.items))
|
||||
to_zip.append(batch_data_items)
|
||||
data.append(list(zip(*to_zip, strict=True)))
|
||||
data_product = list(product(*data))
|
||||
return len(data_product) * batch.runs
|
||||
|
||||
|
||||
ValueToInsertTuple: TypeAlias = tuple[
|
||||
str, # queue_id
|
||||
str, # session (as stringified JSON)
|
||||
|
||||
@@ -28,7 +28,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueStatus,
|
||||
ValueToInsertTuple,
|
||||
calc_session_count,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import GraphExecutionState
|
||||
@@ -118,7 +117,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
if prepend:
|
||||
priority = self._get_highest_priority(queue_id) + 1
|
||||
|
||||
requested_count = calc_session_count(batch)
|
||||
requested_count = batch.get_session_count()
|
||||
|
||||
values_to_insert = prepare_values_to_insert(
|
||||
queue_id=queue_id,
|
||||
batch=batch,
|
||||
|
||||
204
invokeai/app/services/shared/compose_pydantic_model.py
Normal file
204
invokeai/app/services/shared/compose_pydantic_model.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, TypeAlias, get_args
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from invokeai.app.services.session_queue.session_queue_common import FieldIdentifier
|
||||
from invokeai.app.services.shared.graph import Graph
|
||||
|
||||
DictOfFieldsMetadata: TypeAlias = dict[str, tuple[type[Any], FieldInfo]]
|
||||
|
||||
|
||||
class ComposedFieldMetadata(BaseModel):
|
||||
node_id: str
|
||||
field_name: str
|
||||
field_type_class_name: str
|
||||
|
||||
|
||||
def dedupe_field_name(field_metadata: DictOfFieldsMetadata, field_name: str) -> str:
|
||||
"""Given a field name, return a name that is not already in the field metadata.
|
||||
If the field name is not in the field metadata, return the field name.
|
||||
If the field name is in the field metadata, generate a new name by appending an underscore and integer to the field name, starting with 2.
|
||||
"""
|
||||
|
||||
if field_name not in field_metadata:
|
||||
return field_name
|
||||
|
||||
i = 2
|
||||
while True:
|
||||
new_field_name = f"{field_name}_{i}"
|
||||
if new_field_name not in field_metadata:
|
||||
return new_field_name
|
||||
i += 1
|
||||
|
||||
|
||||
def compose_model_from_fields(
|
||||
g: Graph,
|
||||
field_identifiers: list[FieldIdentifier],
|
||||
composed_model_class_name: str = "ComposedModel",
|
||||
model_field_overrides: dict[type[Any], tuple[type[Any], FieldInfo]] | None = None,
|
||||
model_field_filter: Callable[[type[Any]], bool] | None = None,
|
||||
) -> type[BaseModel]:
|
||||
"""Given a graph and a list of field identifiers, create a new pydantic model composed of the fields of the nodes in the graph.
|
||||
|
||||
The resultant model can be used to validate a JSON payload that contains the fields of the nodes in the graph, or generate an
|
||||
OpenAPI schema for the model.
|
||||
|
||||
Args:
|
||||
g: The graph containing the nodes whose fields will be composed into the new model.
|
||||
field_identifiers: A list of FieldIdentifier instances, each representing a field on a node in the graph.
|
||||
model_name: The name of the composed model.
|
||||
kind: The kind of model to create. Must be "input" or "output". Defaults to "input".
|
||||
model_field_overrides: A dictionary mapping type annotations to tuples of (new_type_annotation, new_field_info).
|
||||
This can be used to override the type annotation and field info of a field in the composed model. For example,
|
||||
if `ModelIdentifierField` should be replaced by a string, the dictionary would look like this:
|
||||
```python
|
||||
{ModelIdentifierField: (str, Field(description="The model id."))}
|
||||
```
|
||||
model_field_filter: A function that takes a type annotation and returns True if the field should be included in the composed model.
|
||||
If None, all fields will be included. For example, to omit `BoardField` fields, the filter would look like this:
|
||||
```python
|
||||
def model_field_filter(field_type: type[Any]) -> bool:
|
||||
return field_type not in {BoardField}
|
||||
```
|
||||
Optional fields - or any other complex field types like unions - must be explicitly included in the filter. For example,
|
||||
to omit `BoardField` _and_ `Optional[BoardField]`:
|
||||
```python
|
||||
def model_field_filter(field_type: type[Any]) -> bool:
|
||||
return field_type not in {BoardField, Optional[BoardField]}
|
||||
```
|
||||
Note that the filter is applied to the type annotation of the field, not the field itself.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
# Create some nodes.
|
||||
add_node = AddInvocation()
|
||||
sub_node = SubtractInvocation()
|
||||
color_node = ColorInvocation()
|
||||
|
||||
# Create a graph with the nodes.
|
||||
g = Graph(
|
||||
nodes={
|
||||
add_node.id: add_node,
|
||||
sub_node.id: sub_node,
|
||||
color_node.id: color_node,
|
||||
}
|
||||
)
|
||||
|
||||
# Select the fields to compose.
|
||||
fields_to_compose = [
|
||||
FieldIdentifier(node_id=add_node.id, field_name="a"),
|
||||
FieldIdentifier(node_id=sub_node.id, field_name="a"), # this will be deduped to "a_2"
|
||||
FieldIdentifier(node_id=add_node.id, field_name="b"),
|
||||
FieldIdentifier(node_id=color_node.id, field_name="color"),
|
||||
]
|
||||
|
||||
# Compose the model from the fields.
|
||||
composed_model = compose_model_from_fields(g, fields_to_compose, model_name="ComposedModel")
|
||||
|
||||
# Generate the OpenAPI schema for the model.
|
||||
json_schema = composed_model.model_json_schema(mode="validation")
|
||||
```
|
||||
"""
|
||||
|
||||
# Temp storage for the composed fields. Pydantic needs a type annotation and instance of FieldInfo to create a model.
|
||||
field_metadata: DictOfFieldsMetadata = {}
|
||||
model_field_overrides = model_field_overrides or {}
|
||||
|
||||
# The list of required fields. This is used to ensure the composed model's fields retain their required state.
|
||||
required: list[str] = []
|
||||
|
||||
for field_identifier in field_identifiers:
|
||||
node_id = field_identifier.node_id
|
||||
field_name = field_identifier.field_name
|
||||
|
||||
# Pull the node instance from the graph so we can introspect it.
|
||||
node_instance = g.nodes[node_id]
|
||||
|
||||
if field_identifier.kind == "input":
|
||||
# Get the class of the node. This will be a BaseInvocation subclass, e.g. AddInvocation, DenoiseLatentsInvocation, etc.
|
||||
pydantic_model = type(node_instance)
|
||||
else:
|
||||
# Otherwise the the type of the node's output class. This will be a BaseInvocationOutput subclass, e.g. IntegerOutput, ImageOutput, etc.
|
||||
pydantic_model = type(node_instance).get_output_annotation()
|
||||
|
||||
# Get the FieldInfo instance for the field. For example:
|
||||
# a: int = Field(..., description="The first number to add.")
|
||||
# ^^^^^ The return value of this Field call is the FieldInfo instance (Field is a function).
|
||||
og_field_info = pydantic_model.model_fields[field_name]
|
||||
|
||||
# Get the type annotation of the field. For example:
|
||||
# a: int = Field(..., description="The first number to add.")
|
||||
# ^^^ this is the type annotation
|
||||
og_field_type = og_field_info.annotation
|
||||
|
||||
# Apparently pydantic allows fields without type annotations. We don't support that.
|
||||
assert og_field_type is not None, (
|
||||
f"{field_identifier.kind.capitalize()} field {field_name} on node {node_id} has no type annotation."
|
||||
)
|
||||
|
||||
# Now that we have the type annotation, we can apply the filter to see if we should include the field in the composed model.
|
||||
if model_field_filter and not model_field_filter(og_field_type):
|
||||
continue
|
||||
|
||||
# Ok, we want this type of field. Retrieve any overrides for the field type. This is a dictionary mapping
|
||||
# type annotations to tuples of (override_type_annotation, override_field_info).
|
||||
(override_field_type, override_field_info) = model_field_overrides.get(og_field_type, (None, None))
|
||||
|
||||
# The override tuple's first element is the new type annotation, if it exists.
|
||||
composed_field_type = override_field_type if override_field_type is not None else og_field_type
|
||||
|
||||
# Create a deep copy of the FieldInfo instance (or override it if it exists) so we can modify it without
|
||||
# affecting the original. This is important because we are going to modify the FieldInfo instance and
|
||||
# don't want to affect the original model's schema.
|
||||
composed_field_info = deepcopy(override_field_info if override_field_info is not None else og_field_info)
|
||||
|
||||
json_schema_extra = og_field_info.json_schema_extra if isinstance(og_field_info.json_schema_extra, dict) else {}
|
||||
|
||||
# The field's original required state is stored in the json_schema_extra dict. For more information about why,
|
||||
# see the definition of `InputField` in invokeai/app/invocations/fields.py.
|
||||
#
|
||||
# Add the field to the required list if it is required, which we will use when creating the composed model.
|
||||
if json_schema_extra.get("orig_required", False):
|
||||
required.append(field_name)
|
||||
|
||||
# Invocation fields have some extra metadata, used by the UI to render the field in the frontend. This data is
|
||||
# included in the OpenAPI schema for each field. For example, we add a "ui_order" field, which the UI uses to
|
||||
# sort fields when rendering them.
|
||||
#
|
||||
# The composed model's OpenAPI schema should not have this information. It should only have a standard OpenAPI
|
||||
# schema for the field. We need to strip out the UI-specific metadata from the FieldInfo instance before adding
|
||||
# it to the composed model.
|
||||
#
|
||||
# We will replace this metadata with some custom metadata:
|
||||
# - node_id: The id of the node that this field belongs to.
|
||||
# - field_name: The name of the field on the node.
|
||||
# - original_data_type: The original data type of the field.
|
||||
|
||||
field_type_class = get_args(og_field_type)[0] if hasattr(og_field_type, "__args__") else og_field_type
|
||||
field_type_class_name = field_type_class.__name__
|
||||
|
||||
composed_field_metadata = ComposedFieldMetadata(
|
||||
node_id=node_id,
|
||||
field_name=field_name,
|
||||
field_type_class_name=field_type_class_name,
|
||||
)
|
||||
|
||||
composed_field_info.json_schema_extra = {
|
||||
"composed_field_extra": composed_field_metadata.model_dump(),
|
||||
}
|
||||
|
||||
# Override the name, title and description if overrides are provided. Dedupe the field name if necessary.
|
||||
final_field_name = dedupe_field_name(field_metadata, field_name)
|
||||
|
||||
# Store the field metadata.
|
||||
field_metadata.update({final_field_name: (composed_field_type, composed_field_info)})
|
||||
|
||||
# Splat in the composed fields to create the new model. There are type errors here because create_model's kwargs are not typed,
|
||||
# and for some reason pydantic's ConfigDict doesn't like lists in `json_schema_extra`. Anyways, the inputs here are correct.
|
||||
return create_model(
|
||||
composed_model_class_name,
|
||||
**field_metadata,
|
||||
__config__=ConfigDict(json_schema_extra={"required": required}),
|
||||
)
|
||||
@@ -21,6 +21,7 @@ from invokeai.app.invocations import * # noqa: F401 F403
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationRegistry,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -283,7 +284,7 @@ class AnyInvocation(BaseInvocation):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
def validate_invocation(v: Any) -> "AnyInvocation":
|
||||
return BaseInvocation.get_typeadapter().validate_python(v)
|
||||
return InvocationRegistry.get_invocation_typeadapter().validate_python(v)
|
||||
|
||||
return core_schema.no_info_plain_validator_function(validate_invocation)
|
||||
|
||||
@@ -294,7 +295,7 @@ class AnyInvocation(BaseInvocation):
|
||||
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
oneOf: list[dict[str, str]] = []
|
||||
names = [i.__name__ for i in BaseInvocation.get_invocations()]
|
||||
names = [i.__name__ for i in InvocationRegistry.get_invocation_classes()]
|
||||
for name in sorted(names):
|
||||
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
||||
return {"oneOf": oneOf}
|
||||
@@ -304,7 +305,7 @@ class AnyInvocationOutput(BaseInvocationOutput):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
|
||||
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
|
||||
return BaseInvocationOutput.get_typeadapter().validate_python(v)
|
||||
return InvocationRegistry.get_output_typeadapter().validate_python(v)
|
||||
|
||||
return core_schema.no_info_plain_validator_function(validate_invocation_output)
|
||||
|
||||
@@ -316,7 +317,7 @@ class AnyInvocationOutput(BaseInvocationOutput):
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
|
||||
oneOf: list[dict[str, str]] = []
|
||||
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
|
||||
names = [i.__name__ for i in InvocationRegistry.get_output_classes()]
|
||||
for name in sorted(names):
|
||||
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
||||
return {"oneOf": oneOf}
|
||||
|
||||
@@ -20,14 +20,10 @@ from invokeai.app.services.session_processor.session_processor_common import Pro
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.util.step_callback import flux_step_callback, stable_diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
query: Optional[str],
|
||||
tags: Optional[list[str]],
|
||||
has_been_opened: Optional[bool],
|
||||
is_published: Optional[bool],
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
"""Gets many workflows."""
|
||||
pass
|
||||
@@ -56,6 +57,7 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
self,
|
||||
categories: list[WorkflowCategory],
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
"""Gets a dictionary of counts for each of the provided categories."""
|
||||
pass
|
||||
@@ -66,6 +68,7 @@ class WorkflowRecordsStorageBase(ABC):
|
||||
tags: list[str],
|
||||
categories: Optional[list[WorkflowCategory]] = None,
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
"""Gets a dictionary of counts for each of the provided tags."""
|
||||
pass
|
||||
|
||||
@@ -67,6 +67,7 @@ class WorkflowWithoutID(BaseModel):
|
||||
# This is typed as optional to prevent errors when pulling workflows from the DB. The frontend adds a default form if
|
||||
# it is None.
|
||||
form: dict[str, JsonValue] | None = Field(default=None, description="The form of the workflow.")
|
||||
is_published: bool | None = Field(default=None, description="Whether the workflow is published or not.")
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
@@ -101,6 +102,7 @@ class WorkflowRecordDTOBase(BaseModel):
|
||||
opened_at: Optional[Union[datetime.datetime, str]] = Field(
|
||||
default=None, description="The opened timestamp of the workflow."
|
||||
)
|
||||
is_published: bool | None = Field(default=None, description="Whether the workflow is published or not.")
|
||||
|
||||
|
||||
class WorkflowRecordDTO(WorkflowRecordDTOBase):
|
||||
|
||||
@@ -119,6 +119,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
query: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
@@ -241,6 +242,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
tags: list[str],
|
||||
categories: Optional[list[WorkflowCategory]] = None,
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
if not tags:
|
||||
return {}
|
||||
@@ -292,6 +294,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
self,
|
||||
categories: list[WorkflowCategory],
|
||||
has_been_opened: Optional[bool] = None,
|
||||
is_published: Optional[bool] = None,
|
||||
) -> dict[str, int]:
|
||||
cursor = self._conn.cursor()
|
||||
result: dict[str, int] = {}
|
||||
|
||||
@@ -4,7 +4,10 @@ from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from pydantic.json_schema import models_json_schema
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
InvocationRegistry,
|
||||
UIConfigBase,
|
||||
)
|
||||
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
@@ -56,14 +59,14 @@ def get_openapi_func(
|
||||
invocation_output_map_required: list[str] = []
|
||||
|
||||
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
|
||||
for output in BaseInvocationOutput.get_outputs():
|
||||
for output in InvocationRegistry.get_output_classes():
|
||||
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
|
||||
move_defs_to_top_level(openapi_schema, json_schema)
|
||||
openapi_schema["components"]["schemas"][output.__name__] = json_schema
|
||||
|
||||
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
|
||||
# property, so we'll just do it all manually.
|
||||
for invocation in BaseInvocation.get_invocations():
|
||||
for invocation in InvocationRegistry.get_invocation_classes():
|
||||
json_schema = invocation.model_json_schema(
|
||||
mode="serialization", ref_template="#/components/schemas/{model}"
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
import socket
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
@@ -33,7 +34,16 @@ def check_cudnn(logger: logging.Logger) -> None:
|
||||
)
|
||||
|
||||
|
||||
def enable_dev_reload() -> None:
|
||||
def invokeai_source_dir() -> Path:
|
||||
# `invokeai.__file__` doesn't always work for editable installs
|
||||
this_module_path = Path(__file__).resolve()
|
||||
# https://youtrack.jetbrains.com/issue/PY-38382/Unresolved-reference-spec-but-this-is-standard-builtin
|
||||
# noinspection PyUnresolvedReferences
|
||||
depth = len(__spec__.parent.split("."))
|
||||
return this_module_path.parents[depth - 1]
|
||||
|
||||
|
||||
def enable_dev_reload(custom_nodes_path=None) -> None:
|
||||
"""Enable hot reloading on python file changes during development."""
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
@@ -44,7 +54,10 @@ def enable_dev_reload() -> None:
|
||||
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.'
|
||||
) from e
|
||||
else:
|
||||
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
|
||||
paths = [str(invokeai_source_dir() / "*.py")]
|
||||
if custom_nodes_path:
|
||||
paths.append(str(custom_nodes_path / "*.py"))
|
||||
jurigged.watch(pattern=paths, logger=InvokeAILogger.get_logger(name="jurigged").info)
|
||||
|
||||
|
||||
def apply_monkeypatches() -> None:
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
# fast latents preview matrix for sdxl
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.backend.model_manager.config import BaseModelType, SubModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, SubModelType
|
||||
|
||||
|
||||
def preprocess_t5_encoder_model_identifier(model_identifier: ModelIdentifierField) -> ModelIdentifierField:
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Tuple
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
|
||||
|
||||
23
invokeai/backend/flux/flux_state_dict_utils.py
Normal file
23
invokeai/backend/flux/flux_state_dict_utils.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.model_manager.legacy_probe import CkptType
|
||||
|
||||
|
||||
def get_flux_in_channels_from_state_dict(state_dict: "CkptType") -> int | None:
|
||||
"""Gets the in channels from the state dict."""
|
||||
|
||||
# "Standard" FLUX models use "img_in.weight", but some community fine tunes use
|
||||
# "model.diffusion_model.img_in.weight". Known models that use the latter key:
|
||||
# - https://civitai.com/models/885098?modelVersionId=990775
|
||||
# - https://civitai.com/models/1018060?modelVersionId=1596255
|
||||
# - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133
|
||||
|
||||
keys = {"img_in.weight", "model.diffusion_model.img_in.weight"}
|
||||
|
||||
for key in keys:
|
||||
val = state_dict.get(key)
|
||||
if val is not None:
|
||||
return val.shape[1]
|
||||
|
||||
return None
|
||||
@@ -6,8 +6,8 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel
|
||||
|
||||
|
||||
def norm_img(np_img):
|
||||
|
||||
@@ -16,7 +16,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .config import *
|
||||
from .config import is_exportable, is_scriptable
|
||||
|
||||
|
||||
# From PyTorch internals
|
||||
|
||||
@@ -5,8 +5,8 @@ Copyright 2020 Ross Wightman
|
||||
import re
|
||||
from copy import deepcopy
|
||||
|
||||
from .conv2d_layers import *
|
||||
from geffnet.activations import *
|
||||
from .conv2d_layers import CondConv2d, get_condconv_initializer, math, partial, select_conv2d
|
||||
from geffnet.activations import F, get_act_layer, nn, sigmoid, torch
|
||||
|
||||
__all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible',
|
||||
'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv',
|
||||
|
||||
@@ -32,7 +32,9 @@ import torch.nn.functional as F
|
||||
from .config import layer_config_kwargs, is_scriptable
|
||||
from .conv2d_layers import select_conv2d
|
||||
from .helpers import load_pretrained
|
||||
from .efficientnet_builder import *
|
||||
from .efficientnet_builder import (BN_EPS_TF_DEFAULT, EfficientNetBuilder, decode_arch_def,
|
||||
initialize_weight_default, initialize_weight_goog,
|
||||
resolve_act_layer, resolve_bn_args, round_channels)
|
||||
|
||||
__all__ = ['GenEfficientNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140',
|
||||
'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small',
|
||||
|
||||
@@ -13,7 +13,9 @@ from .activations import get_act_fn, get_act_layer, HardSwish
|
||||
from .config import layer_config_kwargs
|
||||
from .conv2d_layers import select_conv2d
|
||||
from .helpers import load_pretrained
|
||||
from .efficientnet_builder import *
|
||||
from .efficientnet_builder import (BN_EPS_TF_DEFAULT, EfficientNetBuilder, decode_arch_def,
|
||||
initialize_weight_default, initialize_weight_goog,
|
||||
resolve_act_layer, resolve_bn_args, round_channels)
|
||||
|
||||
__all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100',
|
||||
'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100',
|
||||
|
||||
@@ -10,7 +10,7 @@ from cv2.typing import MatLike
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
"""
|
||||
|
||||
@@ -47,3 +47,10 @@ class LlavaOnevisionModel(RawModel):
|
||||
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
self._vllm_model.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
"""Get size of the model in memory in bytes."""
|
||||
# HACK(ryand): Fix this issue with circular imports.
|
||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||
|
||||
return calc_module_size(self._vllm_model)
|
||||
|
||||
@@ -1,37 +1,45 @@
|
||||
"""Re-export frequently-used symbols from the Model Manager backend."""
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelConfigBase,
|
||||
ModelConfigFactory,
|
||||
)
|
||||
from invokeai.backend.model_manager.legacy_probe import ModelProbe
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyModel,
|
||||
AnyVariant,
|
||||
BaseModelType,
|
||||
ClipVariantType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.legacy_probe import ModelProbe
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
|
||||
__all__ = [
|
||||
"AnyModel",
|
||||
"AnyModelConfig",
|
||||
"BaseModelType",
|
||||
"ModelRepoVariant",
|
||||
"InvalidModelConfigException",
|
||||
"LoadedModel",
|
||||
"ModelConfigFactory",
|
||||
"ModelFormat",
|
||||
"ModelProbe",
|
||||
"ModelSearch",
|
||||
"ModelConfigBase",
|
||||
"AnyModel",
|
||||
"AnyVariant",
|
||||
"BaseModelType",
|
||||
"ClipVariantType",
|
||||
"ModelFormat",
|
||||
"ModelRepoVariant",
|
||||
"ModelSourceType",
|
||||
"ModelType",
|
||||
"ModelVariantType",
|
||||
"SchedulerPredictionType",
|
||||
"SubModelType",
|
||||
"ModelConfigBase",
|
||||
]
|
||||
|
||||
@@ -21,6 +21,7 @@ Validation errors will raise an InvalidModelConfigException error.
|
||||
"""
|
||||
|
||||
# pyright: reportIncompatibleVariableOverride=false
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -29,153 +30,41 @@ from inspect import isabstract
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Literal, Optional, TypeAlias, Union
|
||||
|
||||
import diffusers
|
||||
import onnxruntime as ort
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from picklescan.scanner import scan_file_path
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.model_hash.hash_validator import validate_hash
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyVariant,
|
||||
BaseModelType,
|
||||
ClipVariantType,
|
||||
FluxLoRAFormat,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
|
||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||
AnyModel = Union[
|
||||
ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline, ort.InferenceSession
|
||||
]
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Exception for when config parser doesn't recognize this combination of model type and format."""
|
||||
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
"""Base model type."""
|
||||
|
||||
Any = "any"
|
||||
StableDiffusion1 = "sd-1"
|
||||
StableDiffusion2 = "sd-2"
|
||||
StableDiffusion3 = "sd-3"
|
||||
StableDiffusionXL = "sdxl"
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
Flux = "flux"
|
||||
# Kandinsky2_1 = "kandinsky-2.1"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
"""Model type."""
|
||||
|
||||
ONNX = "onnx"
|
||||
Main = "main"
|
||||
VAE = "vae"
|
||||
LoRA = "lora"
|
||||
ControlLoRa = "control_lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
CLIPVision = "clip_vision"
|
||||
CLIPEmbed = "clip_embed"
|
||||
T2IAdapter = "t2i_adapter"
|
||||
T5Encoder = "t5_encoder"
|
||||
SpandrelImageToImage = "spandrel_image_to_image"
|
||||
SigLIP = "siglip"
|
||||
FluxRedux = "flux_redux"
|
||||
LlavaOnevision = "llava_onevision"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
"""Submodel type."""
|
||||
|
||||
UNet = "unet"
|
||||
Transformer = "transformer"
|
||||
TextEncoder = "text_encoder"
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
TextEncoder3 = "text_encoder_3"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Tokenizer3 = "tokenizer_3"
|
||||
VAE = "vae"
|
||||
VAEDecoder = "vae_decoder"
|
||||
VAEEncoder = "vae_encoder"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
|
||||
|
||||
class ClipVariantType(str, Enum):
|
||||
"""Variant type."""
|
||||
|
||||
L = "large"
|
||||
G = "gigantic"
|
||||
|
||||
|
||||
class ModelVariantType(str, Enum):
|
||||
"""Variant type."""
|
||||
|
||||
Normal = "normal"
|
||||
Inpaint = "inpaint"
|
||||
Depth = "depth"
|
||||
|
||||
|
||||
class ModelFormat(str, Enum):
|
||||
"""Storage format of model."""
|
||||
|
||||
Diffusers = "diffusers"
|
||||
Checkpoint = "checkpoint"
|
||||
LyCORIS = "lycoris"
|
||||
ONNX = "onnx"
|
||||
Olive = "olive"
|
||||
EmbeddingFile = "embedding_file"
|
||||
EmbeddingFolder = "embedding_folder"
|
||||
InvokeAI = "invokeai"
|
||||
T5Encoder = "t5_encoder"
|
||||
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
|
||||
BnbQuantizednf4b = "bnb_quantized_nf4b"
|
||||
GGUFQuantized = "gguf_quantized"
|
||||
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
|
||||
"""Scheduler prediction type."""
|
||||
|
||||
Epsilon = "epsilon"
|
||||
VPrediction = "v_prediction"
|
||||
Sample = "sample"
|
||||
|
||||
|
||||
class ModelRepoVariant(str, Enum):
|
||||
"""Various hugging face variants on the diffusers format."""
|
||||
|
||||
Default = "" # model files without "fp16" or other qualifier
|
||||
FP16 = "fp16"
|
||||
FP32 = "fp32"
|
||||
ONNX = "onnx"
|
||||
OpenVINO = "openvino"
|
||||
Flax = "flax"
|
||||
|
||||
|
||||
class ModelSourceType(str, Enum):
|
||||
"""Model source type."""
|
||||
|
||||
Path = "path"
|
||||
Url = "url"
|
||||
HFRepoID = "hf_repo_id"
|
||||
pass
|
||||
|
||||
|
||||
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
|
||||
|
||||
|
||||
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
|
||||
|
||||
|
||||
class SubmodelDefinition(BaseModel):
|
||||
path_or_prefix: str
|
||||
model_type: ModelType
|
||||
@@ -206,51 +95,6 @@ class ControlAdapterDefaultSettings(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
class ModelOnDisk:
|
||||
"""A utility class representing a model stored on disk."""
|
||||
|
||||
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
|
||||
self.path = path
|
||||
self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
|
||||
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
|
||||
self.name = path.stem
|
||||
else:
|
||||
self.name = path.name
|
||||
self.hash_algo = hash_algo
|
||||
|
||||
def hash(self):
|
||||
return ModelHash(algorithm=self.hash_algo).hash(self.path)
|
||||
|
||||
def size(self):
|
||||
if self.format_type == ModelFormat.Checkpoint:
|
||||
return self.path.stat().st_size
|
||||
return sum(file.stat().st_size for file in self.path.rglob("*"))
|
||||
|
||||
def component_paths(self):
|
||||
if self.format_type == ModelFormat.Checkpoint:
|
||||
return {self.path}
|
||||
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
|
||||
return {f for f in self.path.rglob("*") if f.suffix in extensions}
|
||||
|
||||
@staticmethod
|
||||
def load_state_dict(path: Path):
|
||||
with SilenceWarnings():
|
||||
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
|
||||
scan_result = scan_file_path(path)
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
|
||||
checkpoint = torch.load(path, map_location="cpu")
|
||||
elif path.suffix.endswith(".gguf"):
|
||||
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
|
||||
elif path.suffix.endswith(".safetensors"):
|
||||
checkpoint = safetensors.torch.load_file(path)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized model extension: {path.suffix}")
|
||||
|
||||
state_dict = checkpoint.get("state_dict", checkpoint)
|
||||
return state_dict
|
||||
|
||||
|
||||
class MatchSpeed(int, Enum):
|
||||
"""Represents the estimated runtime speed of a config's 'matches' method."""
|
||||
|
||||
@@ -325,16 +169,18 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
Created to deprecate ModelProbe.probe
|
||||
"""
|
||||
candidates = ModelConfigBase._USING_CLASSIFY_API
|
||||
sorted_by_match_speed = sorted(candidates, key=lambda cls: cls._MATCH_SPEED)
|
||||
sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
|
||||
mod = ModelOnDisk(model_path, hash_algo)
|
||||
|
||||
for config_cls in sorted_by_match_speed:
|
||||
try:
|
||||
return config_cls.from_model_on_disk(mod, **overrides)
|
||||
except InvalidModelConfigException:
|
||||
logger.debug(f"ModelConfig '{config_cls.__name__}' failed to parse '{mod.path}', trying next config")
|
||||
if not config_cls.matches(mod):
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected exception while parsing '{config_cls.__name__}': {e}, trying next config")
|
||||
logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}")
|
||||
continue
|
||||
else:
|
||||
return config_cls.from_model_on_disk(mod, **overrides)
|
||||
|
||||
raise InvalidModelConfigException("No valid config found")
|
||||
|
||||
@@ -359,21 +205,43 @@ class ModelConfigBase(ABC, BaseModel):
|
||||
This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def cast_overrides(overrides: dict[str, Any]):
|
||||
"""Casts user overrides from str to Enum"""
|
||||
if "type" in overrides:
|
||||
overrides["type"] = ModelType(overrides["type"])
|
||||
|
||||
if "format" in overrides:
|
||||
overrides["format"] = ModelFormat(overrides["format"])
|
||||
|
||||
if "base" in overrides:
|
||||
overrides["base"] = BaseModelType(overrides["base"])
|
||||
|
||||
if "source_type" in overrides:
|
||||
overrides["source_type"] = ModelSourceType(overrides["source_type"])
|
||||
|
||||
if "variant" in overrides:
|
||||
overrides["variant"] = ModelVariantType(overrides["variant"])
|
||||
|
||||
@classmethod
|
||||
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
|
||||
"""Creates an instance of this config or raises InvalidModelConfigException."""
|
||||
if not cls.matches(mod):
|
||||
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
|
||||
|
||||
fields = cls.parse(mod)
|
||||
cls.cast_overrides(overrides)
|
||||
fields.update(overrides)
|
||||
|
||||
type = fields.get("type") or cls.model_fields["type"].default
|
||||
base = fields.get("base") or cls.model_fields["base"].default
|
||||
|
||||
fields["path"] = mod.path.as_posix()
|
||||
fields["source"] = fields.get("source") or fields["path"]
|
||||
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
|
||||
fields["name"] = mod.name
|
||||
fields["name"] = name = fields.get("name") or mod.name
|
||||
fields["hash"] = fields.get("hash") or mod.hash()
|
||||
fields["key"] = fields.get("key") or uuid_string()
|
||||
fields["description"] = fields.get("description") or f"{base.value} {type.value} model {name}"
|
||||
fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant()
|
||||
|
||||
fields.update(overrides)
|
||||
return cls(**fields)
|
||||
|
||||
|
||||
@@ -414,6 +282,38 @@ class LoRAConfigBase(ABC, BaseModel):
|
||||
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
|
||||
|
||||
@classmethod
|
||||
def flux_lora_format(cls, mod: ModelOnDisk):
|
||||
key = "FLUX_LORA_FORMAT"
|
||||
if key in mod.cache:
|
||||
return mod.cache[key]
|
||||
|
||||
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
|
||||
|
||||
sd = mod.load_state_dict(mod.path)
|
||||
value = flux_format_from_state_dict(sd)
|
||||
mod.cache[key] = value
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def base_model(cls, mod: ModelOnDisk) -> BaseModelType:
|
||||
if cls.flux_lora_format(mod):
|
||||
return BaseModelType.Flux
|
||||
|
||||
state_dict = mod.load_state_dict()
|
||||
# If we've gotten here, we assume that the model is a Stable Diffusion model
|
||||
token_vector_length = lora_token_vector_length(state_dict)
|
||||
if token_vector_length == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_vector_length == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif token_vector_length == 1280:
|
||||
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
|
||||
elif token_vector_length == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelConfigException("Unknown LoRA type")
|
||||
|
||||
|
||||
class T5EncoderConfigBase(ABC, BaseModel):
|
||||
"""Base class for diffusers-style models."""
|
||||
@@ -429,11 +329,40 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin,
|
||||
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
|
||||
|
||||
|
||||
class LoRALyCORISConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
if mod.path.is_dir():
|
||||
return False
|
||||
|
||||
# Avoid false positive match against ControlLoRA and Diffusers
|
||||
if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]:
|
||||
return False
|
||||
|
||||
state_dict = mod.load_state_dict()
|
||||
for key in state_dict.keys():
|
||||
if type(key) is int:
|
||||
continue
|
||||
|
||||
if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
|
||||
return True
|
||||
# "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT
|
||||
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
|
||||
if key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
return {
|
||||
"base": cls.base_model(mod),
|
||||
}
|
||||
|
||||
|
||||
class ControlAdapterConfigBase(ABC, BaseModel):
|
||||
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
|
||||
@@ -457,11 +386,26 @@ class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, Mod
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
|
||||
class LoRADiffusersConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase):
|
||||
"""Model config for LoRA/Diffusers models."""
|
||||
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
if mod.path.is_file():
|
||||
return cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers
|
||||
|
||||
suffixes = ["bin", "safetensors"]
|
||||
weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes]
|
||||
return any(wf.exists() for wf in weight_files)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
return {
|
||||
"base": cls.base_model(mod),
|
||||
}
|
||||
|
||||
|
||||
class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
"""Model config for standalone VAE models."""
|
||||
@@ -625,12 +569,34 @@ class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase):
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
|
||||
class LlavaOnevisionConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
|
||||
class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
|
||||
"""Model config for Llava Onevision models."""
|
||||
|
||||
type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
|
||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
if mod.path.is_file():
|
||||
return False
|
||||
|
||||
config_path = mod.path / "config.json"
|
||||
try:
|
||||
with open(config_path, "r") as file:
|
||||
config = json.load(file)
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
||||
architectures = config.get("architectures")
|
||||
return architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration"
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
return {
|
||||
"base": BaseModelType.Any,
|
||||
"variant": ModelVariantType.Normal,
|
||||
}
|
||||
|
||||
|
||||
def get_model_discriminator_value(v: Any) -> str:
|
||||
"""
|
||||
|
||||
@@ -14,27 +14,30 @@ from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
is_state_dict_instantx_controlnet,
|
||||
is_state_dict_xlabs_controlnet,
|
||||
)
|
||||
from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict
|
||||
from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter
|
||||
from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
AnyVariant,
|
||||
BaseModelType,
|
||||
ControlAdapterDefaultSettings,
|
||||
InvalidModelConfigException,
|
||||
MainModelDefaultSettings,
|
||||
ModelConfigFactory,
|
||||
SubmodelDefinition,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyVariant,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SubmodelDefinition,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
|
||||
from invokeai.backend.model_manager.util.model_util import (
|
||||
get_clip_variant_type,
|
||||
lora_token_vector_length,
|
||||
@@ -562,15 +565,28 @@ class CheckpointProbeBase(ProbeBase):
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
|
||||
if base_type == BaseModelType.Flux:
|
||||
in_channels = state_dict["img_in.weight"].shape[1]
|
||||
if in_channels == 64:
|
||||
in_channels = get_flux_in_channels_from_state_dict(state_dict)
|
||||
|
||||
if in_channels is None:
|
||||
# If we cannot find the in_channels, we assume that this is a normal variant. Log a warning.
|
||||
logger.warning(
|
||||
f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant."
|
||||
)
|
||||
return ModelVariantType.Normal
|
||||
elif in_channels == 384:
|
||||
|
||||
# FLUX Model variant types are distinguished by input channels:
|
||||
# - Unquantized Dev and Schnell have in_channels=64
|
||||
# - BNB-NF4 Dev and Schnell have in_channels=1
|
||||
# - FLUX Fill has in_channels=384
|
||||
# - Unsure of quantized FLUX Fill models
|
||||
# - Unsure of GGUF-quantized models
|
||||
if in_channels == 384:
|
||||
# This is a FLUX Fill model. FLUX Fill needs special handling throughout the application. The variant
|
||||
# type is used to determine whether to use the fill model or the base model.
|
||||
return ModelVariantType.Inpaint
|
||||
else:
|
||||
raise InvalidModelConfigException(
|
||||
f"Unexpected in_channels (in_channels={in_channels}) for FLUX model at {self.model_path}."
|
||||
)
|
||||
# Fall back on "normal" variant type for all other FLUX models.
|
||||
return ModelVariantType.Normal
|
||||
|
||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
if in_channels == 9:
|
||||
|
||||
@@ -13,12 +13,11 @@ import torch
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
|
||||
|
||||
|
||||
class LoadedModelWithoutConfig:
|
||||
|
||||
@@ -6,18 +6,16 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
InvalidModelConfigException,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, DiffusersConfigBase, InvalidModelConfigException
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyModel,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Any, Callable, Dict, List, Optional
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord
|
||||
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
|
||||
@@ -23,6 +22,7 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
|
||||
apply_custom_layers_to_model,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.util.prefix_logger_adapter import PrefixedLoggerAdapter
|
||||
|
||||
@@ -20,13 +20,10 @@ from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load import ModelLoaderBase
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
|
||||
|
||||
class ModelLoaderRegistryBase(ABC):
|
||||
|
||||
@@ -4,16 +4,12 @@ from typing import Optional
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
DiffusersConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
|
||||
|
||||
@@ -5,19 +5,19 @@ from typing import Optional
|
||||
|
||||
from diffusers import ControlNetModel
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
AnyModelConfig,
|
||||
ControlNetCheckpointConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(
|
||||
|
||||
@@ -27,15 +27,8 @@ from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
|
||||
from invokeai.backend.flux.util import ae_params, params
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
CheckpointConfigBase,
|
||||
CLIPEmbedDiffusersConfig,
|
||||
ControlNetCheckpointConfig,
|
||||
@@ -51,6 +44,13 @@ from invokeai.backend.model_manager.config import (
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.model_util import (
|
||||
convert_bundle_to_flux_transformer_checkpoint,
|
||||
)
|
||||
|
||||
@@ -8,18 +8,16 @@ from typing import Any, Optional
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, DiffusersConfigBase, InvalidModelConfigException
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import DiffusersConfigBase
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
|
||||
|
||||
@@ -7,8 +7,9 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
|
||||
@@ -3,15 +3,11 @@ from typing import Optional
|
||||
|
||||
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LlavaOnevision, format=ModelFormat.Diffusers)
|
||||
|
||||
@@ -9,17 +9,17 @@ import torch
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import (
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
|
||||
is_state_dict_likely_flux_control,
|
||||
lora_model_from_flux_control_state_dict,
|
||||
|
||||
@@ -5,16 +5,16 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX)
|
||||
|
||||
@@ -2,15 +2,11 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
|
||||
|
||||
|
||||
|
||||
@@ -4,15 +4,11 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
|
||||
|
||||
|
||||
@@ -11,16 +11,8 @@ from diffusers import (
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
CheckpointConfigBase,
|
||||
DiffusersConfigBase,
|
||||
MainCheckpointConfig,
|
||||
@@ -28,6 +20,14 @@ from invokeai.backend.model_manager.config import (
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
VARIANT_TO_IN_CHANNEL_MAP = {
|
||||
|
||||
@@ -4,16 +4,16 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
|
||||
|
||||
@@ -5,15 +5,16 @@ from typing import Optional
|
||||
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, VAECheckpointConfig
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import AnyModel, SubModelType, VAECheckpointConfig
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
|
||||
|
||||
@@ -15,7 +15,8 @@ from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import D
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
|
||||
from invokeai.backend.model_manager.taxonomy import AnyModel
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
|
||||
@@ -50,6 +51,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
||||
SegmentAnythingPipeline,
|
||||
DepthAnythingPipeline,
|
||||
SigLipPipeline,
|
||||
LlavaOnevisionModel,
|
||||
),
|
||||
):
|
||||
return model.calc_size()
|
||||
|
||||
@@ -17,12 +17,12 @@ from typing import Optional
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
AnyModelRepoMetadataValidator,
|
||||
BaseMetadata,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
|
||||
|
||||
|
||||
class ModelMetadataFetchBase(ABC):
|
||||
|
||||
@@ -24,7 +24,6 @@ from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundErro
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.metadata.fetch.fetch_base import ModelMetadataFetchBase
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import (
|
||||
AnyModelRepoMetadata,
|
||||
@@ -32,6 +31,7 @@ from invokeai.backend.model_manager.metadata.metadata_base import (
|
||||
RemoteModelFile,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
|
||||
|
||||
HF_MODEL_RE = r"https?://huggingface.co/([\w\-.]+/[\w\-.]+)"
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.util.select_hf_files import filter_files
|
||||
|
||||
|
||||
|
||||
96
invokeai/backend/model_manager/model_on_disk.py
Normal file
96
invokeai/backend/model_manager/model_on_disk.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, TypeAlias
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
StateDict: TypeAlias = dict[str | int, Any] # When are the keys int?
|
||||
|
||||
|
||||
class ModelOnDisk:
|
||||
"""A utility class representing a model stored on disk."""
|
||||
|
||||
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
|
||||
self.path = path
|
||||
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
|
||||
self.name = path.stem
|
||||
else:
|
||||
self.name = path.name
|
||||
self.hash_algo = hash_algo
|
||||
# Having a cache helps users of ModelOnDisk (i.e. configs) to save state
|
||||
# This prevents redundant computations during matching and parsing
|
||||
self.cache = {"_CACHED_STATE_DICTS": {}}
|
||||
|
||||
def hash(self) -> str:
|
||||
return ModelHash(algorithm=self.hash_algo).hash(self.path)
|
||||
|
||||
def size(self) -> int:
|
||||
if self.path.is_file():
|
||||
return self.path.stat().st_size
|
||||
return sum(file.stat().st_size for file in self.path.rglob("*"))
|
||||
|
||||
def component_paths(self) -> set[Path]:
|
||||
if self.path.is_file():
|
||||
return {self.path}
|
||||
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
|
||||
return {f for f in self.path.rglob("*") if f.suffix in extensions}
|
||||
|
||||
def repo_variant(self) -> Optional[ModelRepoVariant]:
|
||||
if self.path.is_file():
|
||||
return None
|
||||
|
||||
weight_files = list(self.path.glob("**/*.safetensors"))
|
||||
weight_files.extend(list(self.path.glob("**/*.bin")))
|
||||
for x in weight_files:
|
||||
if ".fp16" in x.suffixes:
|
||||
return ModelRepoVariant.FP16
|
||||
if "openvino_model" in x.name:
|
||||
return ModelRepoVariant.OpenVINO
|
||||
if "flax_model" in x.name:
|
||||
return ModelRepoVariant.Flax
|
||||
if x.suffix == ".onnx":
|
||||
return ModelRepoVariant.ONNX
|
||||
return ModelRepoVariant.Default
|
||||
|
||||
def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
|
||||
sd_cache = self.cache["_CACHED_STATE_DICTS"]
|
||||
|
||||
if path in sd_cache:
|
||||
return sd_cache[path]
|
||||
|
||||
if not path:
|
||||
components = list(self.component_paths())
|
||||
match components:
|
||||
case []:
|
||||
raise ValueError("No weight files found for this model")
|
||||
case [p]:
|
||||
path = p
|
||||
case ps if len(ps) >= 2:
|
||||
raise ValueError(
|
||||
f"Multiple weight files found for this model: {ps}. "
|
||||
f"Please specify the intended file using the 'path' argument"
|
||||
)
|
||||
|
||||
with SilenceWarnings():
|
||||
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
|
||||
scan_result = scan_file_path(path)
|
||||
if scan_result.infected_files != 0 or scan_result.scan_err:
|
||||
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
|
||||
checkpoint = torch.load(path, map_location="cpu")
|
||||
assert isinstance(checkpoint, dict)
|
||||
elif path.suffix.endswith(".gguf"):
|
||||
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
|
||||
elif path.suffix.endswith(".safetensors"):
|
||||
checkpoint = safetensors.torch.load_file(path)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized model extension: {path.suffix}")
|
||||
|
||||
state_dict = checkpoint.get("state_dict", checkpoint)
|
||||
sd_cache[path] = state_dict
|
||||
return state_dict
|
||||
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
|
||||
|
||||
|
||||
class StarterModelWithoutDependencies(BaseModel):
|
||||
|
||||
138
invokeai/backend/model_manager/taxonomy.py
Normal file
138
invokeai/backend/model_manager/taxonomy.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, TypeAlias, Union
|
||||
|
||||
import diffusers
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
from diffusers import ModelMixin
|
||||
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||
AnyModel = Union[
|
||||
ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline, ort.InferenceSession
|
||||
]
|
||||
|
||||
|
||||
class BaseModelType(str, Enum):
|
||||
"""Base model type."""
|
||||
|
||||
Any = "any"
|
||||
StableDiffusion1 = "sd-1"
|
||||
StableDiffusion2 = "sd-2"
|
||||
StableDiffusion3 = "sd-3"
|
||||
StableDiffusionXL = "sdxl"
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
Flux = "flux"
|
||||
# Kandinsky2_1 = "kandinsky-2.1"
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
"""Model type."""
|
||||
|
||||
ONNX = "onnx"
|
||||
Main = "main"
|
||||
VAE = "vae"
|
||||
LoRA = "lora"
|
||||
ControlLoRa = "control_lora"
|
||||
ControlNet = "controlnet" # used by model_probe
|
||||
TextualInversion = "embedding"
|
||||
IPAdapter = "ip_adapter"
|
||||
CLIPVision = "clip_vision"
|
||||
CLIPEmbed = "clip_embed"
|
||||
T2IAdapter = "t2i_adapter"
|
||||
T5Encoder = "t5_encoder"
|
||||
SpandrelImageToImage = "spandrel_image_to_image"
|
||||
SigLIP = "siglip"
|
||||
FluxRedux = "flux_redux"
|
||||
LlavaOnevision = "llava_onevision"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
"""Submodel type."""
|
||||
|
||||
UNet = "unet"
|
||||
Transformer = "transformer"
|
||||
TextEncoder = "text_encoder"
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
TextEncoder3 = "text_encoder_3"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Tokenizer3 = "tokenizer_3"
|
||||
VAE = "vae"
|
||||
VAEDecoder = "vae_decoder"
|
||||
VAEEncoder = "vae_encoder"
|
||||
Scheduler = "scheduler"
|
||||
SafetyChecker = "safety_checker"
|
||||
|
||||
|
||||
class ClipVariantType(str, Enum):
|
||||
"""Variant type."""
|
||||
|
||||
L = "large"
|
||||
G = "gigantic"
|
||||
|
||||
|
||||
class ModelVariantType(str, Enum):
|
||||
"""Variant type."""
|
||||
|
||||
Normal = "normal"
|
||||
Inpaint = "inpaint"
|
||||
Depth = "depth"
|
||||
|
||||
|
||||
class ModelFormat(str, Enum):
|
||||
"""Storage format of model."""
|
||||
|
||||
Diffusers = "diffusers"
|
||||
Checkpoint = "checkpoint"
|
||||
LyCORIS = "lycoris"
|
||||
ONNX = "onnx"
|
||||
Olive = "olive"
|
||||
EmbeddingFile = "embedding_file"
|
||||
EmbeddingFolder = "embedding_folder"
|
||||
InvokeAI = "invokeai"
|
||||
T5Encoder = "t5_encoder"
|
||||
BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
|
||||
BnbQuantizednf4b = "bnb_quantized_nf4b"
|
||||
GGUFQuantized = "gguf_quantized"
|
||||
|
||||
|
||||
class SchedulerPredictionType(str, Enum):
|
||||
"""Scheduler prediction type."""
|
||||
|
||||
Epsilon = "epsilon"
|
||||
VPrediction = "v_prediction"
|
||||
Sample = "sample"
|
||||
|
||||
|
||||
class ModelRepoVariant(str, Enum):
|
||||
"""Various hugging face variants on the diffusers format."""
|
||||
|
||||
Default = "" # model files without "fp16" or other qualifier
|
||||
FP16 = "fp16"
|
||||
FP32 = "fp32"
|
||||
ONNX = "onnx"
|
||||
OpenVINO = "openvino"
|
||||
Flax = "flax"
|
||||
|
||||
|
||||
class ModelSourceType(str, Enum):
|
||||
"""Model source type."""
|
||||
|
||||
Path = "path"
|
||||
Url = "url"
|
||||
HFRepoID = "hf_repo_id"
|
||||
|
||||
|
||||
class FluxLoRAFormat(str, Enum):
|
||||
"""Flux LoRA formats."""
|
||||
|
||||
Diffusers = "flux.diffusers"
|
||||
Kohya = "flux.kohya"
|
||||
OneTrainer = "flux.onetrainer"
|
||||
Control = "flux.control"
|
||||
|
||||
|
||||
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
|
||||
@@ -8,7 +8,7 @@ import picklescan.scanner as pscan
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.config import ClipVariantType
|
||||
from invokeai.backend.model_manager.taxonomy import ClipVariantType
|
||||
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
|
||||
|
||||
|
||||
def filter_files(
|
||||
|
||||
24
invokeai/backend/patches/lora_conversions/formats.py
Normal file
24
invokeai/backend/patches/lora_conversions/formats.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
|
||||
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
|
||||
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_diffusers_format,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_kohya_format,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_onetrainer_format,
|
||||
)
|
||||
|
||||
|
||||
def flux_format_from_state_dict(state_dict):
|
||||
if is_state_dict_likely_in_flux_kohya_format(state_dict):
|
||||
return FluxLoRAFormat.Kohya
|
||||
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):
|
||||
return FluxLoRAFormat.OneTrainer
|
||||
elif is_state_dict_likely_in_flux_diffusers_format(state_dict):
|
||||
return FluxLoRAFormat.Diffusers
|
||||
elif is_state_dict_likely_flux_control(state_dict):
|
||||
return FluxLoRAFormat.Control
|
||||
else:
|
||||
return None
|
||||
@@ -8,7 +8,7 @@ from diffusers import T2IAdapter
|
||||
from PIL.Image import Image
|
||||
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
|
||||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
|
||||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
|
||||
|
||||
@@ -62,7 +62,7 @@
|
||||
"@nanostores/react": "^0.7.3",
|
||||
"@reduxjs/toolkit": "2.6.1",
|
||||
"@roarr/browser-log-writer": "^1.3.0",
|
||||
"@xyflow/react": "^12.4.2",
|
||||
"@xyflow/react": "^12.5.1",
|
||||
"async-mutex": "^0.5.0",
|
||||
"chakra-react-select": "^4.9.2",
|
||||
"cmdk": "^1.0.0",
|
||||
@@ -150,7 +150,7 @@
|
||||
"prettier": "^3.3.3",
|
||||
"rollup-plugin-visualizer": "^5.12.0",
|
||||
"storybook": "^8.3.4",
|
||||
"tsafe": "^1.7.5",
|
||||
"tsafe": "^1.8.5",
|
||||
"type-fest": "^4.26.1",
|
||||
"typescript": "^5.6.2",
|
||||
"vite": "^6.1.0",
|
||||
|
||||
56
invokeai/frontend/web/pnpm-lock.yaml
generated
56
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -36,8 +36,8 @@ dependencies:
|
||||
specifier: ^1.3.0
|
||||
version: 1.3.0
|
||||
'@xyflow/react':
|
||||
specifier: ^12.4.2
|
||||
version: 12.4.2(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
|
||||
specifier: ^12.5.1
|
||||
version: 12.5.1(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1)
|
||||
async-mutex:
|
||||
specifier: ^0.5.0
|
||||
version: 0.5.0
|
||||
@@ -284,8 +284,8 @@ devDependencies:
|
||||
specifier: ^8.3.4
|
||||
version: 8.3.4
|
||||
tsafe:
|
||||
specifier: ^1.7.5
|
||||
version: 1.7.5
|
||||
specifier: ^1.8.5
|
||||
version: 1.8.5
|
||||
type-fest:
|
||||
specifier: ^4.26.1
|
||||
version: 4.26.1
|
||||
@@ -3323,7 +3323,7 @@ packages:
|
||||
/@types/d3-drag@3.0.7:
|
||||
resolution: {integrity: sha512-HE3jVKlzU9AaMazNufooRJ5ZpWmLIoc90A37WU2JMmeq28w1FQqCZswHZ3xR+SuxYftzHq6WU6KJHvqxKzTxxQ==}
|
||||
dependencies:
|
||||
'@types/d3-selection': 3.0.10
|
||||
'@types/d3-selection': 3.0.11
|
||||
dev: false
|
||||
|
||||
/@types/d3-interpolate@3.0.4:
|
||||
@@ -3332,21 +3332,21 @@ packages:
|
||||
'@types/d3-color': 3.1.3
|
||||
dev: false
|
||||
|
||||
/@types/d3-selection@3.0.10:
|
||||
resolution: {integrity: sha512-cuHoUgS/V3hLdjJOLTT691+G2QoqAjCVLmr4kJXR4ha56w1Zdu8UUQ5TxLRqudgNjwXeQxKMq4j+lyf9sWuslg==}
|
||||
/@types/d3-selection@3.0.11:
|
||||
resolution: {integrity: sha512-bhAXu23DJWsrI45xafYpkQ4NtcKMwWnAC/vKrd2l+nxMFuvOT3XMYTIj2opv8vq8AO5Yh7Qac/nSeP/3zjTK0w==}
|
||||
dev: false
|
||||
|
||||
/@types/d3-transition@3.0.8:
|
||||
resolution: {integrity: sha512-ew63aJfQ/ms7QQ4X7pk5NxQ9fZH/z+i24ZfJ6tJSfqxJMrYLiK01EAs2/Rtw/JreGUsS3pLPNV644qXFGnoZNQ==}
|
||||
/@types/d3-transition@3.0.9:
|
||||
resolution: {integrity: sha512-uZS5shfxzO3rGlu0cC3bjmMFKsXv+SmZZcgp0KD22ts4uGXp5EVYGzu/0YdwZeKmddhcAccYtREJKkPfXkZuCg==}
|
||||
dependencies:
|
||||
'@types/d3-selection': 3.0.10
|
||||
'@types/d3-selection': 3.0.11
|
||||
dev: false
|
||||
|
||||
/@types/d3-zoom@3.0.8:
|
||||
resolution: {integrity: sha512-iqMC4/YlFCSlO8+2Ii1GGGliCAY4XdeG748w5vQUbevlbDu0zSjH/+jojorQVBK/se0j6DUFNPBGSqD3YWYnDw==}
|
||||
dependencies:
|
||||
'@types/d3-interpolate': 3.0.4
|
||||
'@types/d3-selection': 3.0.10
|
||||
'@types/d3-selection': 3.0.11
|
||||
dev: false
|
||||
|
||||
/@types/diff-match-patch@1.0.36:
|
||||
@@ -3951,28 +3951,28 @@ packages:
|
||||
resolution: {integrity: sha512-N8tkAACJx2ww8vFMneJmaAgmjAG1tnVBZJRLRcx061tmsLRZHSEZSLuGWnwPtunsSLvSqXQ2wfp7Mgqg1I+2dQ==}
|
||||
dev: false
|
||||
|
||||
/@xyflow/react@12.4.2(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-AFJKVc/fCPtgSOnRst3xdYJwiEcUN9lDY7EO/YiRvFHYCJGgfzg+jpvZjkTOnBLGyrMJre9378pRxAc3fsR06A==}
|
||||
/@xyflow/react@12.5.1(@types/react@18.3.11)(react-dom@18.3.1)(react@18.3.1):
|
||||
resolution: {integrity: sha512-jMKQVqGwCz0x6pUyvxTIuCMbyehfua7CfEEWDj29zQSHigQpCy0/5d8aOmZrqK4cwur/pVHLQomT6Rm10gXfHg==}
|
||||
peerDependencies:
|
||||
react: '>=17'
|
||||
react-dom: '>=17'
|
||||
dependencies:
|
||||
'@xyflow/system': 0.0.50
|
||||
'@xyflow/system': 0.0.53
|
||||
classcat: 5.0.5
|
||||
react: 18.3.1
|
||||
react-dom: 18.3.1(react@18.3.1)
|
||||
zustand: 4.5.5(@types/react@18.3.11)(react@18.3.1)
|
||||
zustand: 4.5.6(@types/react@18.3.11)(react@18.3.1)
|
||||
transitivePeerDependencies:
|
||||
- '@types/react'
|
||||
- immer
|
||||
dev: false
|
||||
|
||||
/@xyflow/system@0.0.50:
|
||||
resolution: {integrity: sha512-HVUZd4LlY88XAaldFh2nwVxDOcdIBxGpQ5txzwfJPf+CAjj2BfYug1fHs2p4yS7YO8H6A3EFJQovBE8YuHkAdg==}
|
||||
/@xyflow/system@0.0.53:
|
||||
resolution: {integrity: sha512-QTWieiTtvNYyQAz1fxpzgtUGXNpnhfh6vvZa7dFWpWS2KOz6bEHODo/DTK3s07lDu0Bq0Db5lx/5M5mNjb9VDQ==}
|
||||
dependencies:
|
||||
'@types/d3-drag': 3.0.7
|
||||
'@types/d3-selection': 3.0.10
|
||||
'@types/d3-transition': 3.0.8
|
||||
'@types/d3-selection': 3.0.11
|
||||
'@types/d3-transition': 3.0.9
|
||||
'@types/d3-zoom': 3.0.8
|
||||
d3-drag: 3.0.0
|
||||
d3-selection: 3.0.0
|
||||
@@ -8791,8 +8791,8 @@ packages:
|
||||
resolution: {integrity: sha512-tLJxacIQUM82IR7JO1UUkKlYuUTmoY9HBJAmNWFzheSlDS5SPMcNIepejHJa4BpPQLAcbRhRf3GDJzyj6rbKvA==}
|
||||
dev: false
|
||||
|
||||
/tsafe@1.7.5:
|
||||
resolution: {integrity: sha512-tbNyyBSbwfbilFfiuXkSOj82a6++ovgANwcoqBAcO9/REPoZMEQoE8kWPeO0dy5A2D/2Lajr8Ohue5T0ifIvLQ==}
|
||||
/tsafe@1.8.5:
|
||||
resolution: {integrity: sha512-LFWTWQrW6rwSY+IBNFl2ridGfUzVsPwrZ26T4KUJww/py8rzaQ/SY+MIz6YROozpUCaRcuISqagmlwub9YT9kw==}
|
||||
dev: true
|
||||
|
||||
/tsconfck@3.1.5(typescript@5.6.2):
|
||||
@@ -9123,6 +9123,14 @@ packages:
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/use-sync-external-store@1.4.0(react@18.3.1):
|
||||
resolution: {integrity: sha512-9WXSPC5fMv61vaupRkCKCxsPxBocVnwakBEkMIHHpkTTg6icbJtg6jzgtLDm4bl3cSHAca52rYWih0k4K3PfHw==}
|
||||
peerDependencies:
|
||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
|
||||
dependencies:
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/util-deprecate@1.0.2:
|
||||
resolution: {integrity: sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==}
|
||||
dev: true
|
||||
@@ -9567,8 +9575,8 @@ packages:
|
||||
/zod@3.23.8:
|
||||
resolution: {integrity: sha512-XBx9AXhXktjUqnepgTiE5flcKIYWi/rme0Eaj+5Y0lftuGBq+jyRu/md4WnuxqgP1ubdpNCsYEYPxrzVHD8d6g==}
|
||||
|
||||
/zustand@4.5.5(@types/react@18.3.11)(react@18.3.1):
|
||||
resolution: {integrity: sha512-+0PALYNJNgK6hldkgDq2vLrw5f6g/jCInz52n9RTpropGgeAf/ioFUCdtsjCqu4gNhW9D01rUQBROoRjdzyn2Q==}
|
||||
/zustand@4.5.6(@types/react@18.3.11)(react@18.3.1):
|
||||
resolution: {integrity: sha512-ibr/n1hBzLLj5Y+yUcU7dYw8p6WnIVzdJbnX+1YpaScvZVF2ziugqHs+LAmHw4lWO9c/zRj+K1ncgWDQuthEdQ==}
|
||||
engines: {node: '>=12.7.0'}
|
||||
peerDependencies:
|
||||
'@types/react': '>=16.8'
|
||||
@@ -9584,5 +9592,5 @@ packages:
|
||||
dependencies:
|
||||
'@types/react': 18.3.11
|
||||
react: 18.3.1
|
||||
use-sync-external-store: 1.2.2(react@18.3.1)
|
||||
use-sync-external-store: 1.4.0(react@18.3.1)
|
||||
dev: false
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user