mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-20 05:38:15 -05:00
Compare commits
85 Commits
v2.3.3-rc7
...
pre-nodes
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fdad62e88b | ||
|
|
955c81acef | ||
|
|
e1058f3416 | ||
|
|
edf16a253d | ||
|
|
46f5ef4100 | ||
|
|
b843255236 | ||
|
|
3a968e5072 | ||
|
|
fd80e84ea6 | ||
|
|
4824237a98 | ||
|
|
2c9a05eb59 | ||
|
|
ecb5bdaf7e | ||
|
|
f6cdff2c5b | ||
|
|
63d10027a4 | ||
|
|
ef0773b8a3 | ||
|
|
3daaddf15b | ||
|
|
570c3fe690 | ||
|
|
cbd1a7263a | ||
|
|
7fc5fbd4ce | ||
|
|
6f6de402ad | ||
|
|
281662a6e1 | ||
|
|
50eb02f68b | ||
|
|
d73f3adc43 | ||
|
|
116107f464 | ||
|
|
da44bb1707 | ||
|
|
f43aed677e | ||
|
|
0d051aaae2 | ||
|
|
e4e48ff995 | ||
|
|
442a6bffa4 | ||
|
|
23d65e7162 | ||
|
|
024fd54d0b | ||
|
|
c44c19e911 | ||
|
|
d923d1d66b | ||
|
|
1f2c1e14db | ||
|
|
07e3a0ec15 | ||
|
|
427db7c7e2 | ||
|
|
dad3a7f263 | ||
|
|
5bd0bb637f | ||
|
|
f05095770c | ||
|
|
de189f2db6 | ||
|
|
4463124bdd | ||
|
|
34402cc46a | ||
|
|
54d9833db0 | ||
|
|
5fe8cb56fc | ||
|
|
7919d81fb1 | ||
|
|
9d80b28a4f | ||
|
|
1fcd91bcc5 | ||
|
|
e456e2e63a | ||
|
|
ee41b99049 | ||
|
|
111d674e71 | ||
|
|
8f048cfbd9 | ||
|
|
7103ac6a32 | ||
|
|
f6b131e706 | ||
|
|
d1b2b99226 | ||
|
|
e356f2511b | ||
|
|
e5f8b22a43 | ||
|
|
45b84fb4bb | ||
|
|
f022c89249 | ||
|
|
ab05144716 | ||
|
|
aeb4914e67 | ||
|
|
76bcd4d44f | ||
|
|
50f5e1bc83 | ||
|
|
4c339dd4b0 | ||
|
|
7268131f57 | ||
|
|
85b020f76c | ||
|
|
a7833cc9a9 | ||
|
|
919294e977 | ||
|
|
d44151d6ff | ||
|
|
7640acfb1f | ||
|
|
aed9ecef2a | ||
|
|
18cddd7972 | ||
|
|
e6b25f4ae3 | ||
|
|
d1c0050e65 | ||
|
|
ecdfa136a0 | ||
|
|
5cd513ee63 | ||
|
|
ab45086546 | ||
|
|
77ba7359f4 | ||
|
|
8cbe2e14d9 | ||
|
|
ee86eedf01 | ||
|
|
1f89cf3343 | ||
|
|
c4e6511a59 | ||
|
|
44843be4c8 | ||
|
|
b9df9e26f2 | ||
|
|
e11c1d66ab | ||
|
|
cdb3616dca | ||
|
|
abe4dc8ac1 |
14
.github/CODEOWNERS
vendored
14
.github/CODEOWNERS
vendored
@@ -1,16 +1,16 @@
|
||||
# continuous integration
|
||||
/.github/workflows/ @mauwii @lstein @blessedcoolant
|
||||
/.github/workflows/ @lstein @blessedcoolant
|
||||
|
||||
# documentation
|
||||
/docs/ @lstein @mauwii @tildebyte @blessedcoolant
|
||||
/mkdocs.yml @lstein @mauwii @blessedcoolant
|
||||
/docs/ @lstein @tildebyte @blessedcoolant
|
||||
/mkdocs.yml @lstein @blessedcoolant
|
||||
|
||||
# nodes
|
||||
/invokeai/app/ @Kyle0654 @blessedcoolant
|
||||
|
||||
# installation and configuration
|
||||
/pyproject.toml @mauwii @lstein @blessedcoolant
|
||||
/docker/ @mauwii @lstein @blessedcoolant
|
||||
/pyproject.toml @lstein @blessedcoolant
|
||||
/docker/ @lstein @blessedcoolant
|
||||
/scripts/ @ebr @lstein
|
||||
/installer/ @lstein @ebr
|
||||
/invokeai/assets @lstein @ebr
|
||||
@@ -22,11 +22,11 @@
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein
|
||||
|
||||
# generation, model management, postprocessing
|
||||
/invokeai/backend @keturn @damian0815 @lstein @blessedcoolant @jpphoto
|
||||
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
|
||||
|
||||
# front ends
|
||||
/invokeai/frontend/CLI @lstein
|
||||
/invokeai/frontend/install @lstein @ebr @mauwii
|
||||
/invokeai/frontend/install @lstein @ebr
|
||||
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
|
||||
/invokeai/frontend/web @psychedelicious @blessedcoolant
|
||||
|
||||
19
.github/stale.yaml
vendored
Normal file
19
.github/stale.yaml
vendored
Normal file
@@ -0,0 +1,19 @@
|
||||
# Number of days of inactivity before an issue becomes stale
|
||||
daysUntilStale: 28
|
||||
# Number of days of inactivity before a stale issue is closed
|
||||
daysUntilClose: 14
|
||||
# Issues with these labels will never be considered stale
|
||||
exemptLabels:
|
||||
- pinned
|
||||
- security
|
||||
# Label to use when marking an issue as stale
|
||||
staleLabel: stale
|
||||
# Comment to post when marking an issue as stale. Set to `false` to disable
|
||||
markComment: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity. It will be closed if no further activity occurs. Please
|
||||
update the ticket if this is still a problem on the latest release.
|
||||
# Comment to post when closing a stale issue. Set to `false` to disable
|
||||
closeComment: >
|
||||
Due to inactivity, this issue has been automatically closed. If this is
|
||||
still a problem on the latest release, please recreate the issue.
|
||||
1
.github/workflows/build-container.yml
vendored
1
.github/workflows/build-container.yml
vendored
@@ -18,6 +18,7 @@ on:
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
docker:
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -9,6 +9,8 @@ models/ldm/stable-diffusion-v1/model.ckpt
|
||||
configs/models.user.yaml
|
||||
config/models.user.yml
|
||||
invokeai.init
|
||||
.version
|
||||
.last_model
|
||||
|
||||
# ignore the Anaconda/Miniconda installer used while building Docker image
|
||||
anaconda.sh
|
||||
|
||||
@@ -84,7 +84,7 @@ installing lots of models.
|
||||
|
||||
6. Wait while the installer does its thing. After installing the software,
|
||||
the installer will launch a script that lets you configure InvokeAI and
|
||||
select a set of starting image generaiton models.
|
||||
select a set of starting image generation models.
|
||||
|
||||
7. Find the folder that InvokeAI was installed into (it is not the
|
||||
same as the unpacked zip file directory!) The default location of this
|
||||
@@ -148,6 +148,11 @@ not supported.
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||
```
|
||||
|
||||
_For non-GPU systems:_
|
||||
```terminal
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
```
|
||||
|
||||
_For Macintoshes, either Intel or M1/M2:_
|
||||
|
||||
```sh
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
# Invocations
|
||||
|
||||
Invocations represent a single operation, its inputs, and its outputs. These operations and their outputs can be chained together to generate and modify images.
|
||||
Invocations represent a single operation, its inputs, and its outputs. These
|
||||
operations and their outputs can be chained together to generate and modify
|
||||
images.
|
||||
|
||||
## Creating a new invocation
|
||||
|
||||
To create a new invocation, either find the appropriate module file in `/ldm/invoke/app/invocations` to add your invocation to, or create a new one in that folder. All invocations in that folder will be discovered and made available to the CLI and API automatically. Invocations make use of [typing](https://docs.python.org/3/library/typing.html) and [pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration into the CLI and API.
|
||||
To create a new invocation, either find the appropriate module file in
|
||||
`/ldm/invoke/app/invocations` to add your invocation to, or create a new one in
|
||||
that folder. All invocations in that folder will be discovered and made
|
||||
available to the CLI and API automatically. Invocations make use of
|
||||
[typing](https://docs.python.org/3/library/typing.html) and
|
||||
[pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration
|
||||
into the CLI and API.
|
||||
|
||||
An invocation looks like this:
|
||||
|
||||
@@ -41,34 +49,54 @@ class UpscaleInvocation(BaseInvocation):
|
||||
Each portion is important to implement correctly.
|
||||
|
||||
### Class definition and type
|
||||
|
||||
```py
|
||||
class UpscaleInvocation(BaseInvocation):
|
||||
"""Upscales an image."""
|
||||
type: Literal['upscale'] = 'upscale'
|
||||
```
|
||||
All invocations must derive from `BaseInvocation`. They should have a docstring that declares what they do in a single, short line. They should also have a `type` with a type hint that's `Literal["command_name"]`, where `command_name` is what the user will type on the CLI or use in the API to create this invocation. The `command_name` must be unique. The `type` must be assigned to the value of the literal in the type hint.
|
||||
|
||||
All invocations must derive from `BaseInvocation`. They should have a docstring
|
||||
that declares what they do in a single, short line. They should also have a
|
||||
`type` with a type hint that's `Literal["command_name"]`, where `command_name`
|
||||
is what the user will type on the CLI or use in the API to create this
|
||||
invocation. The `command_name` must be unique. The `type` must be assigned to
|
||||
the value of the literal in the type hint.
|
||||
|
||||
### Inputs
|
||||
|
||||
```py
|
||||
# Inputs
|
||||
image: Union[ImageField,None] = Field(description="The input image")
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
||||
level: Literal[2,4] = Field(default=2, description="The upscale level")
|
||||
```
|
||||
Inputs consist of three parts: a name, a type hint, and a `Field` with default, description, and validation information. For example:
|
||||
| Part | Value | Description |
|
||||
| ---- | ----- | ----------- |
|
||||
| Name | `strength` | This field is referred to as `strength` |
|
||||
| Type Hint | `float` | This field must be of type `float` |
|
||||
| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
|
||||
|
||||
Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this field to be parsed with `None` as a value, which enables linking to previous invocations. All fields should either provide a default value or allow `None` as a value, so that they can be overwritten with a linked output from another invocation.
|
||||
Inputs consist of three parts: a name, a type hint, and a `Field` with default,
|
||||
description, and validation information. For example:
|
||||
|
||||
The special type `ImageField` is also used here. All images are passed as `ImageField`, which protects them from pydantic validation errors (since images only ever come from links).
|
||||
| Part | Value | Description |
|
||||
| --------- | ------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Name | `strength` | This field is referred to as `strength` |
|
||||
| Type Hint | `float` | This field must be of type `float` |
|
||||
| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
|
||||
|
||||
Finally, note that for all linking, the `type` of the linked fields must match. If the `name` also matches, then the field can be **automatically linked** to a previous invocation by name and matching.
|
||||
Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this
|
||||
field to be parsed with `None` as a value, which enables linking to previous
|
||||
invocations. All fields should either provide a default value or allow `None` as
|
||||
a value, so that they can be overwritten with a linked output from another
|
||||
invocation.
|
||||
|
||||
The special type `ImageField` is also used here. All images are passed as
|
||||
`ImageField`, which protects them from pydantic validation errors (since images
|
||||
only ever come from links).
|
||||
|
||||
Finally, note that for all linking, the `type` of the linked fields must match.
|
||||
If the `name` also matches, then the field can be **automatically linked** to a
|
||||
previous invocation by name and matching.
|
||||
|
||||
### Invoke Function
|
||||
|
||||
```py
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(self.image.image_type, self.image.image_name)
|
||||
@@ -88,13 +116,22 @@ Finally, note that for all linking, the `type` of the linked fields must match.
|
||||
image = ImageField(image_type = image_type, image_name = image_name)
|
||||
)
|
||||
```
|
||||
The `invoke` function is the last portion of an invocation. It is provided an `InvocationContext` which contains services to perform work as well as a `session_id` for use as needed. It should return a class with output values that derives from `BaseInvocationOutput`.
|
||||
|
||||
Before being called, the invocation will have all of its fields set from defaults, inputs, and finally links (overriding in that order).
|
||||
The `invoke` function is the last portion of an invocation. It is provided an
|
||||
`InvocationContext` which contains services to perform work as well as a
|
||||
`session_id` for use as needed. It should return a class with output values that
|
||||
derives from `BaseInvocationOutput`.
|
||||
|
||||
Assume that this invocation may be running simultaneously with other invocations, may be running on another machine, or in other interesting scenarios. If you need functionality, please provide it as a service in the `InvocationServices` class, and make sure it can be overridden.
|
||||
Before being called, the invocation will have all of its fields set from
|
||||
defaults, inputs, and finally links (overriding in that order).
|
||||
|
||||
Assume that this invocation may be running simultaneously with other
|
||||
invocations, may be running on another machine, or in other interesting
|
||||
scenarios. If you need functionality, please provide it as a service in the
|
||||
`InvocationServices` class, and make sure it can be overridden.
|
||||
|
||||
### Outputs
|
||||
|
||||
```py
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
@@ -102,4 +139,64 @@ class ImageOutput(BaseInvocationOutput):
|
||||
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
```
|
||||
Output classes look like an invocation class without the invoke method. Prefer to use an existing output class if available, and prefer to name inputs the same as outputs when possible, to promote automatic invocation linking.
|
||||
|
||||
Output classes look like an invocation class without the invoke method. Prefer
|
||||
to use an existing output class if available, and prefer to name inputs the same
|
||||
as outputs when possible, to promote automatic invocation linking.
|
||||
|
||||
## Schema Generation
|
||||
|
||||
Invocation, output and related classes are used to generate an OpenAPI schema.
|
||||
|
||||
### Required Properties
|
||||
|
||||
The schema generation treat all properties with default values as optional. This
|
||||
makes sense internally, but when when using these classes via the generated
|
||||
schema, we end up with e.g. the `ImageOutput` class having its `image` property
|
||||
marked as optional.
|
||||
|
||||
We know that this property will always be present, so the additional logic
|
||||
needed to always check if the property exists adds a lot of extraneous cruft.
|
||||
|
||||
To fix this, we can leverage `pydantic`'s
|
||||
[schema customisation](https://docs.pydantic.dev/usage/schema/#schema-customization)
|
||||
to mark properties that we know will always be present as required.
|
||||
|
||||
Here's that `ImageOutput` class, without the needed schema customisation:
|
||||
|
||||
```python
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
```
|
||||
|
||||
The generated OpenAPI schema, and all clients/types generated from it, will have
|
||||
the `type` and `image` properties marked as optional, even though we know they
|
||||
will always have a value by the time we can interact with them via the API.
|
||||
|
||||
Here's the same class, but with the schema customisation added:
|
||||
|
||||
```python
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
type: Literal["image"] = "image"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
'type',
|
||||
'image',
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
The resultant schema (and any API client or types generated from it) will now
|
||||
have see `type` as string literal `"image"` and `image` as an `ImageField`
|
||||
object.
|
||||
|
||||
See this `pydantic` issue for discussion on this solution:
|
||||
<https://github.com/pydantic/pydantic/discussions/4577>
|
||||
|
||||
@@ -32,7 +32,7 @@ turned on and off on the command line using `--nsfw_checker` and
|
||||
At installation time, InvokeAI will ask whether the checker should be
|
||||
activated by default (neither argument given on the command line). The
|
||||
response is stored in the InvokeAI initialization file (usually
|
||||
`.invokeai` in your home directory). You can change the default at any
|
||||
`invokeai.init` in your home directory). You can change the default at any
|
||||
time by opening this file in a text editor and commenting or
|
||||
uncommenting the line `--nsfw_checker`.
|
||||
|
||||
|
||||
@@ -268,7 +268,7 @@ model is so good at inpainting, a good substitute is to use the `clipseg` text
|
||||
masking option:
|
||||
|
||||
```bash
|
||||
invoke> a fluffy cat eating a hotdot
|
||||
invoke> a fluffy cat eating a hotdog
|
||||
Outputs:
|
||||
[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
|
||||
invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat
|
||||
|
||||
@@ -50,7 +50,7 @@ subset that are currently installed are found in
|
||||
|stable-diffusion-1.5|runwayml/stable-diffusion-v1-5|Stable Diffusion version 1.5 diffusers model (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-v1-5 |
|
||||
|sd-inpainting-1.5|runwayml/stable-diffusion-inpainting|RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-inpainting |
|
||||
|stable-diffusion-2.1|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
|
||||
|sd-inpainting-2.0|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
|
||||
|sd-inpainting-2.0|stabilityai/stable-diffusion-2-inpainting|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-inpainting |
|
||||
|analog-diffusion-1.0|wavymulder/Analog-Diffusion|An SD-1.5 model trained on diverse analog photographs (2.13 GB)|https://huggingface.co/wavymulder/Analog-Diffusion |
|
||||
|deliberate-1.0|XpucT/Deliberate|Versatile model that produces detailed images up to 768px (4.27 GB)|https://huggingface.co/XpucT/Deliberate |
|
||||
|d&d-diffusion-1.0|0xJustin/Dungeons-and-Diffusion|Dungeons & Dragons characters (2.13 GB)|https://huggingface.co/0xJustin/Dungeons-and-Diffusion |
|
||||
|
||||
@@ -456,7 +456,7 @@ def get_torch_source() -> (Union[str, None],str):
|
||||
optional_modules = None
|
||||
if OS == "Linux":
|
||||
if device == "rocm":
|
||||
url = "https://download.pytorch.org/whl/rocm5.2"
|
||||
url = "https://download.pytorch.org/whl/rocm5.4.2"
|
||||
elif device == "cpu":
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
|
||||
|
||||
@@ -3,10 +3,14 @@
|
||||
import os
|
||||
from argparse import Namespace
|
||||
|
||||
from ..services.default_graphs import create_system_graphs
|
||||
|
||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ...backend import Globals
|
||||
from ..services.model_manager_initializer import get_model_manager
|
||||
from ..services.restoration_services import RestorationServices
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
@@ -54,7 +58,9 @@ class ApiDependencies:
|
||||
os.path.join(os.path.dirname(__file__), "../../../../outputs")
|
||||
)
|
||||
|
||||
images = DiskImageStorage(output_folder)
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
|
||||
|
||||
images = DiskImageStorage(f'{output_folder}/images')
|
||||
|
||||
# TODO: build a file/path manager?
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
@@ -62,8 +68,12 @@ class ApiDependencies:
|
||||
services = InvocationServices(
|
||||
model_manager=get_model_manager(config),
|
||||
events=events,
|
||||
latents=latents,
|
||||
images=images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
@@ -71,6 +81,8 @@ class ApiDependencies:
|
||||
restoration=RestorationServices(config),
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@staticmethod
|
||||
|
||||
14
invokeai/app/api/models/images.py
Normal file
14
invokeai/app/api/models/images.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.models.image import ImageType
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
|
||||
|
||||
class ImageResponse(BaseModel):
|
||||
"""The response type for images"""
|
||||
|
||||
image_type: ImageType = Field(description="The type of the image")
|
||||
image_name: str = Field(description="The name of the image")
|
||||
image_url: str = Field(description="The url of the image")
|
||||
thumbnail_url: str = Field(description="The url of the image's thumbnail")
|
||||
metadata: ImageMetadata = Field(description="The image's metadata")
|
||||
@@ -1,18 +1,20 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from fastapi import Path, Request, UploadFile
|
||||
from fastapi import Path, Query, Request, UploadFile
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from invokeai.app.api.models.images import ImageResponse
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
from ...services.image_storage import ImageType
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||
|
||||
|
||||
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
||||
async def get_image(
|
||||
image_type: ImageType = Path(description="The type of image to get"),
|
||||
@@ -23,6 +25,16 @@ async def get_image(
|
||||
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
|
||||
return FileResponse(filename)
|
||||
|
||||
@images_router.get("/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail")
|
||||
async def get_thumbnail(
|
||||
image_type: ImageType = Path(description="The type of image to get"),
|
||||
image_name: str = Path(description="The name of the image to get"),
|
||||
):
|
||||
"""Gets a thumbnail"""
|
||||
# TODO: This is not really secure at all. At least make sure only output results are served
|
||||
filename = ApiDependencies.invoker.services.images.get_path(image_type, 'thumbnails/' + image_name)
|
||||
return FileResponse(filename)
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/uploads/",
|
||||
@@ -43,14 +55,30 @@ async def upload_image(file: UploadFile, request: Request):
|
||||
# Error opening the image
|
||||
return Response(status_code=415)
|
||||
|
||||
filename = f"{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
||||
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
|
||||
|
||||
return Response(
|
||||
status_code=201,
|
||||
headers={
|
||||
"Location": request.url_for(
|
||||
"get_image", image_type=ImageType.UPLOAD, image_name=filename
|
||||
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@images_router.get(
|
||||
"/",
|
||||
operation_id="list_images",
|
||||
responses={200: {"model": PaginatedResults[ImageResponse]}},
|
||||
)
|
||||
async def list_images(
|
||||
image_type: ImageType = Query(default=ImageType.RESULT, description="The type of images to get"),
|
||||
page: int = Query(default=0, description="The page of images to get"),
|
||||
per_page: int = Query(default=10, description="The number of images per page"),
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
"""Gets a list of images"""
|
||||
result = ApiDependencies.invoker.services.images.list(
|
||||
image_type, page, per_page
|
||||
)
|
||||
return result
|
||||
|
||||
251
invokeai/app/api/routers/models.py
Normal file
251
invokeai/app/api/routers/models.py
Normal file
@@ -0,0 +1,251 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
||||
|
||||
import shutil
|
||||
import asyncio
|
||||
from typing import Annotated, Any, List, Literal, Optional, Union
|
||||
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field, parse_obj_as
|
||||
from pathlib import Path
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend.globals import Globals, global_converted_ckpts_dir
|
||||
from invokeai.backend.args import Args
|
||||
|
||||
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
|
||||
class VaeRepo(BaseModel):
|
||||
repo_id: str = Field(description="The repo ID to use for this VAE")
|
||||
path: Optional[str] = Field(description="The path to the VAE")
|
||||
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
description: Optional[str] = Field(description="A description of the model")
|
||||
|
||||
class CkptModelInfo(ModelInfo):
|
||||
format: Literal['ckpt'] = 'ckpt'
|
||||
|
||||
config: str = Field(description="The path to the model config")
|
||||
weights: str = Field(description="The path to the model weights")
|
||||
vae: str = Field(description="The path to the model VAE")
|
||||
width: Optional[int] = Field(description="The width of the model")
|
||||
height: Optional[int] = Field(description="The height of the model")
|
||||
|
||||
class DiffusersModelInfo(ModelInfo):
|
||||
format: Literal['diffusers'] = 'diffusers'
|
||||
|
||||
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
||||
path: Optional[str] = Field(description="The path to the model")
|
||||
|
||||
class CreateModelRequest(BaseModel):
|
||||
name: str = Field(description="The name of the model")
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
|
||||
class CreateModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
||||
status: str = Field(description="The status of the API response")
|
||||
|
||||
class ConversionRequest(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: CkptModelInfo = Field(description="The converted model info")
|
||||
save_location: str = Field(description="The path to save the converted model weights")
|
||||
|
||||
|
||||
class ConvertedModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
operation_id="list_models",
|
||||
responses={200: {"model": ModelsList }},
|
||||
)
|
||||
async def list_models() -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models()
|
||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||
return models
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/",
|
||||
operation_id="update_model",
|
||||
responses={200: {"status": "success"}},
|
||||
)
|
||||
async def update_model(
|
||||
model_request: CreateModelRequest
|
||||
) -> CreateModelResponse:
|
||||
""" Add Model """
|
||||
model_request_info = model_request.info
|
||||
info_dict = model_request_info.dict()
|
||||
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
|
||||
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
model_name=model_request.name,
|
||||
model_attributes=info_dict,
|
||||
clobber=True,
|
||||
)
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
@models_router.delete(
|
||||
"/{model_name}",
|
||||
operation_id="del_model",
|
||||
responses={
|
||||
204: {
|
||||
"description": "Model deleted successfully"
|
||||
},
|
||||
404: {
|
||||
"description": "Model not found"
|
||||
}
|
||||
},
|
||||
)
|
||||
async def delete_model(model_name: str) -> None:
|
||||
"""Delete Model"""
|
||||
model_names = ApiDependencies.invoker.services.model_manager.model_names()
|
||||
model_exists = model_name in model_names
|
||||
|
||||
# check if model exists
|
||||
print(f">> Checking for model {model_name}...")
|
||||
|
||||
if model_exists:
|
||||
print(f">> Deleting Model: {model_name}")
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
|
||||
print(f">> Model Deleted: {model_name}")
|
||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
||||
|
||||
else:
|
||||
print(f">> Model not found")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
|
||||
# @socketio.on("convertToDiffusers")
|
||||
# def convert_to_diffusers(model_to_convert: dict):
|
||||
# try:
|
||||
# if model_info := self.generate.model_manager.model_info(
|
||||
# model_name=model_to_convert["model_name"]
|
||||
# ):
|
||||
# if "weights" in model_info:
|
||||
# ckpt_path = Path(model_info["weights"])
|
||||
# original_config_file = Path(model_info["config"])
|
||||
# model_name = model_to_convert["model_name"]
|
||||
# model_description = model_info["description"]
|
||||
# else:
|
||||
# self.socketio.emit(
|
||||
# "error", {"message": "Model is not a valid checkpoint file"}
|
||||
# )
|
||||
# else:
|
||||
# self.socketio.emit(
|
||||
# "error", {"message": "Could not retrieve model info."}
|
||||
# )
|
||||
|
||||
# if not ckpt_path.is_absolute():
|
||||
# ckpt_path = Path(Globals.root, ckpt_path)
|
||||
|
||||
# if original_config_file and not original_config_file.is_absolute():
|
||||
# original_config_file = Path(Globals.root, original_config_file)
|
||||
|
||||
# diffusers_path = Path(
|
||||
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if model_to_convert["save_location"] == "root":
|
||||
# diffusers_path = Path(
|
||||
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if (
|
||||
# model_to_convert["save_location"] == "custom"
|
||||
# and model_to_convert["custom_location"] is not None
|
||||
# ):
|
||||
# diffusers_path = Path(
|
||||
# model_to_convert["custom_location"], f"{model_name}_diffusers"
|
||||
# )
|
||||
|
||||
# if diffusers_path.exists():
|
||||
# shutil.rmtree(diffusers_path)
|
||||
|
||||
# self.generate.model_manager.convert_and_import(
|
||||
# ckpt_path,
|
||||
# diffusers_path,
|
||||
# model_name=model_name,
|
||||
# model_description=model_description,
|
||||
# vae=None,
|
||||
# original_config_file=original_config_file,
|
||||
# commit_to_conf=opt.conf,
|
||||
# )
|
||||
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
# socketio.emit(
|
||||
# "modelConverted",
|
||||
# {
|
||||
# "new_model_name": model_name,
|
||||
# "model_list": new_model_list,
|
||||
# "update": True,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Model Converted: {model_name}")
|
||||
# except Exception as e:
|
||||
# self.handle_exceptions(e)
|
||||
|
||||
# @socketio.on("mergeDiffusersModels")
|
||||
# def merge_diffusers_models(model_merge_info: dict):
|
||||
# try:
|
||||
# models_to_merge = model_merge_info["models_to_merge"]
|
||||
# model_ids_or_paths = [
|
||||
# self.generate.model_manager.model_name_or_path(x)
|
||||
# for x in models_to_merge
|
||||
# ]
|
||||
# merged_pipe = merge_diffusion_models(
|
||||
# model_ids_or_paths,
|
||||
# model_merge_info["alpha"],
|
||||
# model_merge_info["interp"],
|
||||
# model_merge_info["force"],
|
||||
# )
|
||||
|
||||
# dump_path = global_models_dir() / "merged_models"
|
||||
# if model_merge_info["model_merge_save_path"] is not None:
|
||||
# dump_path = Path(model_merge_info["model_merge_save_path"])
|
||||
|
||||
# os.makedirs(dump_path, exist_ok=True)
|
||||
# dump_path = dump_path / model_merge_info["merged_model_name"]
|
||||
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
|
||||
|
||||
# merged_model_config = dict(
|
||||
# model_name=model_merge_info["merged_model_name"],
|
||||
# description=f'Merge of models {", ".join(models_to_merge)}',
|
||||
# commit_to_conf=opt.conf,
|
||||
# )
|
||||
|
||||
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
||||
# "vae", None
|
||||
# ):
|
||||
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
|
||||
# merged_model_config.update(vae=vae)
|
||||
|
||||
# self.generate.model_manager.import_diffuser_model(
|
||||
# dump_path, **merged_model_config
|
||||
# )
|
||||
# new_model_list = self.generate.model_manager.list_models()
|
||||
|
||||
# socketio.emit(
|
||||
# "modelsMerged",
|
||||
# {
|
||||
# "merged_models": models_to_merge,
|
||||
# "merged_model_name": model_merge_info["merged_model_name"],
|
||||
# "model_list": new_model_list,
|
||||
# "update": True,
|
||||
# },
|
||||
# )
|
||||
# print(f">> Models Merged: {models_to_merge}")
|
||||
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||
# except Exception as e:
|
||||
@@ -51,7 +51,7 @@ async def list_sessions(
|
||||
query: str = Query(default="", description="The query string to search for"),
|
||||
) -> PaginatedResults[GraphExecutionState]:
|
||||
"""Gets a list of sessions, optionally searching"""
|
||||
if filter == "":
|
||||
if query == "":
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(
|
||||
page, per_page
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@ from pydantic.schema import schema
|
||||
|
||||
from ..backend import Args
|
||||
from .api.dependencies import ApiDependencies
|
||||
from .api.routers import images, sessions
|
||||
from .api.routers import images, sessions, models
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
@@ -76,6 +76,8 @@ app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
|
||||
@@ -4,12 +4,43 @@ from abc import ABC, abstractmethod
|
||||
import argparse
|
||||
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
|
||||
from pydantic import BaseModel, Field
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from ..invocations.baseinvocation import BaseInvocation
|
||||
from ..invocations.image import ImageField
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.graph import GraphExecutionState, LibraryGraph, GraphInvocation, Edge
|
||||
from ..services.invoker import Invoker
|
||||
|
||||
|
||||
def add_field_argument(command_parser, name: str, field, default_override = None):
|
||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=default,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
|
||||
def add_parsers(
|
||||
subparsers,
|
||||
commands: list[type],
|
||||
@@ -34,30 +65,26 @@ def add_parsers(
|
||||
if name in exclude_fields:
|
||||
continue
|
||||
|
||||
if get_origin(field.type_) == Literal:
|
||||
allowed_values = get_args(field.type_)
|
||||
allowed_types = set()
|
||||
for val in allowed_values:
|
||||
allowed_types.add(type(val))
|
||||
allowed_types_list = list(allowed_types)
|
||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||
add_field_argument(command_parser, name, field)
|
||||
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field_type,
|
||||
default=field.default,
|
||||
choices=allowed_values,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
else:
|
||||
command_parser.add_argument(
|
||||
f"--{name}",
|
||||
dest=name,
|
||||
type=field.type_,
|
||||
default=field.default,
|
||||
help=field.field_info.description,
|
||||
)
|
||||
|
||||
def add_graph_parsers(
|
||||
subparsers,
|
||||
graphs: list[LibraryGraph],
|
||||
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
||||
):
|
||||
for graph in graphs:
|
||||
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
||||
|
||||
if add_arguments is not None:
|
||||
add_arguments(command_parser)
|
||||
|
||||
# Add arguments for inputs
|
||||
for exposed_input in graph.exposed_inputs:
|
||||
node = graph.graph.get_node(exposed_input.node_path)
|
||||
field = node.__fields__[exposed_input.field]
|
||||
default_override = getattr(node, exposed_input.field)
|
||||
add_field_argument(command_parser, exposed_input.alias, field, default_override)
|
||||
|
||||
|
||||
class CliContext:
|
||||
@@ -65,17 +92,38 @@ class CliContext:
|
||||
session: GraphExecutionState
|
||||
parser: argparse.ArgumentParser
|
||||
defaults: dict[str, Any]
|
||||
graph_nodes: dict[str, str]
|
||||
nodes_added: list[str]
|
||||
|
||||
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
|
||||
self.invoker = invoker
|
||||
self.session = session
|
||||
self.parser = parser
|
||||
self.defaults = dict()
|
||||
self.graph_nodes = dict()
|
||||
self.nodes_added = list()
|
||||
|
||||
def get_session(self):
|
||||
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
|
||||
return self.session
|
||||
|
||||
def reset(self):
|
||||
self.session = self.invoker.create_execution_state()
|
||||
self.graph_nodes = dict()
|
||||
self.nodes_added = list()
|
||||
# Leave defaults unchanged
|
||||
|
||||
def add_node(self, node: BaseInvocation):
|
||||
self.get_session()
|
||||
self.session.graph.add_node(node)
|
||||
self.nodes_added.append(node.id)
|
||||
self.invoker.services.graph_execution_manager.set(self.session)
|
||||
|
||||
def add_edge(self, edge: Edge):
|
||||
self.get_session()
|
||||
self.session.add_edge(edge)
|
||||
self.invoker.services.graph_execution_manager.set(self.session)
|
||||
|
||||
|
||||
class ExitCli(Exception):
|
||||
"""Exception to exit the CLI"""
|
||||
@@ -200,3 +248,39 @@ class SetDefaultCommand(BaseCommand):
|
||||
del context.defaults[self.field]
|
||||
else:
|
||||
context.defaults[self.field] = self.value
|
||||
|
||||
|
||||
class DrawGraphCommand(BaseCommand):
|
||||
"""Debugs a graph"""
|
||||
type: Literal['draw_graph'] = 'draw_graph'
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||
nxgraph = session.graph.nx_graph_flat()
|
||||
|
||||
# Draw the networkx graph
|
||||
plt.figure(figsize=(20, 20))
|
||||
pos = nx.spectral_layout(nxgraph)
|
||||
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
|
||||
nx.draw_networkx_edges(nxgraph, pos, width=2)
|
||||
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
|
||||
class DrawExecutionGraphCommand(BaseCommand):
|
||||
"""Debugs an execution graph"""
|
||||
type: Literal['draw_xgraph'] = 'draw_xgraph'
|
||||
|
||||
def run(self, context: CliContext) -> None:
|
||||
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||
nxgraph = session.execution_graph.nx_graph_flat()
|
||||
|
||||
# Draw the networkx graph
|
||||
plt.figure(figsize=(20, 20))
|
||||
pos = nx.spectral_layout(nxgraph)
|
||||
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
|
||||
nx.draw_networkx_edges(nxgraph, pos, width=2)
|
||||
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
|
||||
plt.axis("off")
|
||||
plt.show()
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import time
|
||||
from typing import (
|
||||
@@ -12,15 +13,20 @@ from typing import (
|
||||
from pydantic import BaseModel
|
||||
from pydantic.fields import Field
|
||||
|
||||
from .services.default_graphs import create_system_graphs
|
||||
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from ..backend import Args
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers, get_graph_execution_history
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.model_manager_initializer import get_model_manager
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState
|
||||
from .services.graph import Edge, EdgeConnection, ExposedNodeInput, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||
from .services.default_graphs import default_text_to_image_graph_id
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
@@ -44,7 +50,7 @@ def add_invocation_args(command_parser):
|
||||
"-l",
|
||||
action="append",
|
||||
nargs=3,
|
||||
help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
|
||||
help="A link in the format 'source_node source_field dest_field'. source_node can be relative to history (e.g. -1)",
|
||||
)
|
||||
|
||||
command_parser.add_argument(
|
||||
@@ -55,7 +61,7 @@ def add_invocation_args(command_parser):
|
||||
)
|
||||
|
||||
|
||||
def get_command_parser() -> argparse.ArgumentParser:
|
||||
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
||||
# Create invocation parser
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -73,20 +79,72 @@ def get_command_parser() -> argparse.ArgumentParser:
|
||||
commands = BaseCommand.get_all_subclasses()
|
||||
add_parsers(subparsers, commands, exclude_fields=["type"])
|
||||
|
||||
# Create subparsers for exposed CLI graphs
|
||||
# TODO: add a way to identify these graphs
|
||||
text_to_image = services.graph_library.get(default_text_to_image_graph_id)
|
||||
add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
class NodeField():
|
||||
alias: str
|
||||
node_path: str
|
||||
field: str
|
||||
field_type: type
|
||||
|
||||
def __init__(self, alias: str, node_path: str, field: str, field_type: type):
|
||||
self.alias = alias
|
||||
self.node_path = node_path
|
||||
self.field = field
|
||||
self.field_type = field_type
|
||||
|
||||
|
||||
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
|
||||
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
||||
|
||||
|
||||
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_input.node_path))
|
||||
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
|
||||
|
||||
|
||||
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||
"""Gets the node field for the specified field alias"""
|
||||
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
|
||||
node_type = type(graph.graph.get_node(exposed_output.node_path))
|
||||
node_output_type = node_type.get_output_type()
|
||||
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
|
||||
|
||||
|
||||
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the inputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs}
|
||||
|
||||
|
||||
def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||
"""Gets the outputs for the specified invocation from the context"""
|
||||
node_type = type(invocation)
|
||||
if node_type is not GraphInvocation:
|
||||
return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id)
|
||||
else:
|
||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
||||
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
|
||||
|
||||
|
||||
def generate_matching_edges(
|
||||
a: BaseInvocation, b: BaseInvocation
|
||||
a: BaseInvocation, b: BaseInvocation, context: CliContext
|
||||
) -> list[Edge]:
|
||||
"""Generates all possible edges between two invocations"""
|
||||
atype = type(a)
|
||||
btype = type(b)
|
||||
|
||||
aoutputtype = atype.get_output_type()
|
||||
|
||||
afields = get_type_hints(aoutputtype)
|
||||
bfields = get_type_hints(btype)
|
||||
afields = get_node_outputs(a, context)
|
||||
bfields = get_node_inputs(b, context)
|
||||
|
||||
matching_fields = set(afields.keys()).intersection(bfields.keys())
|
||||
|
||||
@@ -94,12 +152,15 @@ def generate_matching_edges(
|
||||
invalid_fields = set(["type", "id"])
|
||||
matching_fields = matching_fields.difference(invalid_fields)
|
||||
|
||||
# Validate types
|
||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
|
||||
|
||||
edges = [
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=a.id, field=field),
|
||||
destination=EdgeConnection(node_id=b.id, field=field)
|
||||
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
|
||||
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
|
||||
)
|
||||
for field in matching_fields
|
||||
for alias in matching_fields
|
||||
]
|
||||
return edges
|
||||
|
||||
@@ -149,8 +210,12 @@ def invoke_cli():
|
||||
services = InvocationServices(
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
images=DiskImageStorage(output_folder),
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||
images=DiskImageStorage(f'{output_folder}/images'),
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||
filename=db_location, table_name="graph_executions"
|
||||
),
|
||||
@@ -158,9 +223,14 @@ def invoke_cli():
|
||||
restoration=RestorationServices(config),
|
||||
)
|
||||
|
||||
system_graphs = create_system_graphs(services.graph_library)
|
||||
system_graph_names = set([g.name for g in system_graphs])
|
||||
|
||||
invoker = Invoker(services)
|
||||
session: GraphExecutionState = invoker.create_execution_state()
|
||||
parser = get_command_parser()
|
||||
parser = get_command_parser(services)
|
||||
|
||||
re_negid = re.compile('^-[0-9]+$')
|
||||
|
||||
# Uncomment to print out previous sessions at startup
|
||||
# print(services.session_manager.list())
|
||||
@@ -176,11 +246,12 @@ def invoke_cli():
|
||||
|
||||
try:
|
||||
# Refresh the state of the session
|
||||
history = list(get_graph_execution_history(context.session))
|
||||
#history = list(get_graph_execution_history(context.session))
|
||||
history = list(reversed(context.nodes_added))
|
||||
|
||||
# Split the command for piping
|
||||
cmds = cmd_input.split("|")
|
||||
start_id = len(history)
|
||||
start_id = len(context.nodes_added)
|
||||
current_id = start_id
|
||||
new_invocations = list()
|
||||
for cmd in cmds:
|
||||
@@ -196,8 +267,24 @@ def invoke_cli():
|
||||
args[field_name] = field_default
|
||||
|
||||
# Parse invocation
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
command: CliCommand = None # type:ignore
|
||||
system_graph: LibraryGraph|None = None
|
||||
if args['type'] in system_graph_names:
|
||||
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
|
||||
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
||||
for exposed_input in system_graph.exposed_inputs:
|
||||
if exposed_input.alias in args:
|
||||
node = invocation.graph.get_node(exposed_input.node_path)
|
||||
field = exposed_input.field
|
||||
setattr(node, field, args[exposed_input.alias])
|
||||
command = CliCommand(command = invocation)
|
||||
context.graph_nodes[invocation.id] = system_graph.id
|
||||
else:
|
||||
args["id"] = current_id
|
||||
command = CliCommand(command=args)
|
||||
|
||||
if command is None:
|
||||
continue
|
||||
|
||||
# Run any CLI commands immediately
|
||||
if isinstance(command.command, BaseCommand):
|
||||
@@ -208,6 +295,7 @@ def invoke_cli():
|
||||
command.command.run(context)
|
||||
continue
|
||||
|
||||
# TODO: handle linking with library graphs
|
||||
# Pipe previous command output (if there was a previous command)
|
||||
edges: list[Edge] = list()
|
||||
if len(history) > 0 or current_id != start_id:
|
||||
@@ -220,16 +308,20 @@ def invoke_cli():
|
||||
else context.session.graph.get_node(from_id)
|
||||
)
|
||||
matching_edges = generate_matching_edges(
|
||||
from_node, command.command
|
||||
from_node, command.command, context
|
||||
)
|
||||
edges.extend(matching_edges)
|
||||
|
||||
# Parse provided links
|
||||
if "link_node" in args and args["link_node"]:
|
||||
for link in args["link_node"]:
|
||||
link_node = context.session.graph.get_node(link)
|
||||
node_id = link
|
||||
if re_negid.match(node_id):
|
||||
node_id = str(current_id + int(node_id))
|
||||
|
||||
link_node = context.session.graph.get_node(node_id)
|
||||
matching_edges = generate_matching_edges(
|
||||
link_node, command.command
|
||||
link_node, command.command, context
|
||||
)
|
||||
matching_destinations = [e.destination for e in matching_edges]
|
||||
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||
@@ -237,13 +329,20 @@ def invoke_cli():
|
||||
|
||||
if "link" in args and args["link"]:
|
||||
for link in args["link"]:
|
||||
edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
|
||||
edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
|
||||
|
||||
node_id = link[0]
|
||||
if re_negid.match(node_id):
|
||||
node_id = str(current_id + int(node_id))
|
||||
|
||||
# TODO: handle missing input/output
|
||||
node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]]
|
||||
node_input = get_node_inputs(command.command, context)[link[2]]
|
||||
|
||||
edges.append(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=link[1], field=link[0]),
|
||||
destination=EdgeConnection(
|
||||
node_id=command.command.id, field=link[2]
|
||||
)
|
||||
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
|
||||
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -252,10 +351,10 @@ def invoke_cli():
|
||||
current_id = current_id + 1
|
||||
|
||||
# Add the node to the session
|
||||
context.session.add_node(command.command)
|
||||
context.add_node(command.command)
|
||||
for edge in edges:
|
||||
print(edge)
|
||||
context.session.add_edge(edge)
|
||||
context.add_edge(edge)
|
||||
|
||||
# Execute all remaining nodes
|
||||
invoke_all(context)
|
||||
@@ -267,7 +366,7 @@ def invoke_cli():
|
||||
except SessionError:
|
||||
# Start a new session
|
||||
print("Session error: creating a new session")
|
||||
context.session = context.invoker.create_execution_state()
|
||||
context.reset()
|
||||
|
||||
except ExitCli:
|
||||
break
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
from typing import get_args, get_type_hints
|
||||
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -76,3 +76,56 @@ class BaseInvocation(ABC, BaseModel):
|
||||
#fmt: off
|
||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
||||
#fmt: on
|
||||
|
||||
|
||||
# TODO: figure out a better way to provide these hints
|
||||
# TODO: when we can upgrade to python 3.11, we can use the`NotRequired` type instead of `total=False`
|
||||
class UIConfig(TypedDict, total=False):
|
||||
type_hints: Dict[
|
||||
str,
|
||||
Literal[
|
||||
"integer",
|
||||
"float",
|
||||
"boolean",
|
||||
"string",
|
||||
"enum",
|
||||
"image",
|
||||
"latents",
|
||||
"model",
|
||||
],
|
||||
]
|
||||
tags: List[str]
|
||||
|
||||
|
||||
class CustomisedSchemaExtra(TypedDict):
|
||||
ui: UIConfig
|
||||
|
||||
|
||||
class InvocationConfig(BaseModel.Config):
|
||||
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
|
||||
|
||||
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
|
||||
|
||||
`tags`
|
||||
- A list of strings, used to categorise invocations.
|
||||
|
||||
`type_hints`
|
||||
- A dict of field types which override the types in the invocation definition.
|
||||
- Each key should be the name of one of the invocation's fields.
|
||||
- Each value should be one of the valid types:
|
||||
- `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`
|
||||
|
||||
```python
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
"type_hints": {
|
||||
"initial_image": "image",
|
||||
},
|
||||
},
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
schema_extra: CustomisedSchemaExtra
|
||||
|
||||
50
invokeai/app/invocations/collections.py
Normal file
50
invokeai/app/invocations/collections.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import cv2 as cv
|
||||
import numpy as np
|
||||
import numpy.random
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
|
||||
from .image import ImageField, ImageOutput
|
||||
|
||||
|
||||
class IntCollectionOutput(BaseInvocationOutput):
|
||||
"""A collection of integers"""
|
||||
|
||||
type: Literal["int_collection"] = "int_collection"
|
||||
|
||||
# Outputs
|
||||
collection: list[int] = Field(default=[], description="The int collection")
|
||||
|
||||
|
||||
class RangeInvocation(BaseInvocation):
|
||||
"""Creates a range"""
|
||||
|
||||
type: Literal["range"] = "range"
|
||||
|
||||
# Inputs
|
||||
start: int = Field(default=0, description="The start of the range")
|
||||
stop: int = Field(default=10, description="The stop of the range")
|
||||
step: int = Field(default=1, description="The step of the range")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
||||
|
||||
|
||||
class RandomRangeInvocation(BaseInvocation):
|
||||
"""Creates a collection of random numbers"""
|
||||
|
||||
type: Literal["random_range"] = "random_range"
|
||||
|
||||
# Inputs
|
||||
low: int = Field(default=0, description="The inclusive low value")
|
||||
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||
size: int = Field(default=1, description="The number of values to generate")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||
return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size)))
|
||||
@@ -5,14 +5,26 @@ from typing import Literal
|
||||
import cv2 as cv
|
||||
import numpy
|
||||
from PIL import Image, ImageOps
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput
|
||||
|
||||
|
||||
class CvInpaintInvocation(BaseInvocation):
|
||||
class CvInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all OpenCV invocations with additional config"""
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["cv", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||
"""Simple inpaint using opencv."""
|
||||
#fmt: off
|
||||
type: Literal["cv_inpaint"] = "cv_inpaint"
|
||||
|
||||
@@ -6,21 +6,37 @@ from typing import Literal, Optional, Union
|
||||
import numpy as np
|
||||
from torch import Tensor
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.invocations.util.get_model import choose_model
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ..util.util import diffusers_step_callback_adapter, CanceledException
|
||||
from ..models.exceptions import CanceledException
|
||||
from ..util.step_callback import diffusers_step_callback_adapter
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||
|
||||
|
||||
class SDImageInvocation(BaseModel):
|
||||
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["stable-diffusion", "image"],
|
||||
"type_hints": {
|
||||
"model": "model",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(InvokeAIGenerator.schedulers())
|
||||
]
|
||||
|
||||
# Text to image
|
||||
class TextToImageInvocation(BaseInvocation):
|
||||
class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
||||
"""Generates an image using text2img."""
|
||||
|
||||
type: Literal["txt2img"] = "txt2img"
|
||||
@@ -34,7 +50,7 @@ class TextToImageInvocation(BaseInvocation):
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
@@ -58,16 +74,9 @@ class TextToImageInvocation(BaseInvocation):
|
||||
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# def step_callback(state: PipelineIntermediateState):
|
||||
# if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||
# raise CanceledException
|
||||
# self.dispatch_progress(context, state.latents, state.step)
|
||||
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
# (right now uses whatever current model is set in model manager)
|
||||
model= context.services.model_manager.get_model()
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
outputs = Txt2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
step_callback=partial(self.dispatch_progress, context),
|
||||
@@ -134,9 +143,8 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
mask = None
|
||||
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
model = context.services.model_manager.get_model()
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
outputs = Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_image=image,
|
||||
@@ -210,9 +218,8 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
)
|
||||
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
model = context.services.model_manager.get_model()
|
||||
model = choose_model(context.services.model_manager, self.model)
|
||||
|
||||
outputs = Inpaint(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_img=image,
|
||||
|
||||
@@ -7,19 +7,20 @@ import numpy
|
||||
from PIL import Image, ImageFilter, ImageOps
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from ..models.image import ImageField, ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_type: str = Field(
|
||||
default=ImageType.RESULT, description="The type of the image"
|
||||
)
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
class PILInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all PIL invocations with additional config"""
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["PIL", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
@@ -92,7 +93,7 @@ class ShowImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
class CropImageInvocation(BaseInvocation):
|
||||
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
#fmt: off
|
||||
type: Literal["crop"] = "crop"
|
||||
@@ -125,7 +126,7 @@ class CropImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
class PasteImageInvocation(BaseInvocation):
|
||||
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Pastes an image into another image."""
|
||||
#fmt: off
|
||||
type: Literal["paste"] = "paste"
|
||||
@@ -149,7 +150,7 @@ class PasteImageInvocation(BaseInvocation):
|
||||
None
|
||||
if self.mask is None
|
||||
else ImageOps.invert(
|
||||
services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
context.services.images.get(self.mask.image_type, self.mask.image_name)
|
||||
)
|
||||
)
|
||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||
@@ -175,7 +176,7 @@ class PasteImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
class MaskFromAlphaInvocation(BaseInvocation):
|
||||
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Extracts the alpha channel of an image as a mask."""
|
||||
#fmt: off
|
||||
type: Literal["tomask"] = "tomask"
|
||||
@@ -202,7 +203,7 @@ class MaskFromAlphaInvocation(BaseInvocation):
|
||||
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
|
||||
|
||||
|
||||
class BlurInvocation(BaseInvocation):
|
||||
class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Blurs an image"""
|
||||
|
||||
#fmt: off
|
||||
@@ -236,7 +237,7 @@ class BlurInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
class LerpInvocation(BaseInvocation):
|
||||
class LerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Linear interpolation of all pixels of an image"""
|
||||
#fmt: off
|
||||
type: Literal["lerp"] = "lerp"
|
||||
@@ -267,7 +268,7 @@ class LerpInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
class InverseLerpInvocation(BaseInvocation):
|
||||
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||
"""Inverse linear interpolation of all pixels of an image"""
|
||||
#fmt: off
|
||||
type: Literal["ilerp"] = "ilerp"
|
||||
|
||||
361
invokeai/app/invocations/latent.py
Normal file
361
invokeai/app/invocations/latent.py
Normal file
@@ -0,0 +1,361 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import random
|
||||
from typing import Literal, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
import torch
|
||||
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from invokeai.app.invocations.util.get_model import choose_model
|
||||
from invokeai.app.util.step_callback import diffusers_step_callback_adapter
|
||||
|
||||
from ...backend.model_management.model_manager import ModelManager
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
import numpy as np
|
||||
from ..services.image_storage import ImageType
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
import diffusers
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents field used for passing latents between invocations"""
|
||||
|
||||
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
||||
|
||||
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output latents"""
|
||||
#fmt: off
|
||||
type: Literal["latent_output"] = "latent_output"
|
||||
latents: LatentsField = Field(default=None, description="The output latents")
|
||||
#fmt: on
|
||||
|
||||
class NoiseOutput(BaseInvocationOutput):
|
||||
"""Invocation noise output"""
|
||||
#fmt: off
|
||||
type: Literal["noise_output"] = "noise_output"
|
||||
noise: LatentsField = Field(default=None, description="The output noise")
|
||||
#fmt: on
|
||||
|
||||
|
||||
# TODO: this seems like a hack
|
||||
scheduler_map = dict(
|
||||
ddim=diffusers.DDIMScheduler,
|
||||
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||
k_euler=diffusers.EulerDiscreteScheduler,
|
||||
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||
k_heun=diffusers.HeunDiscreteScheduler,
|
||||
k_lms=diffusers.LMSDiscreteScheduler,
|
||||
plms=diffusers.PNDMScheduler,
|
||||
)
|
||||
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(list(scheduler_map.keys()))
|
||||
]
|
||||
|
||||
|
||||
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class = scheduler_map.get(scheduler_name,'ddim')
|
||||
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
return scheduler
|
||||
|
||||
|
||||
def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8):
|
||||
# limit noise to only the diffusion image channels, not the mask channels
|
||||
input_channels = min(latent_channels, 4)
|
||||
use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
|
||||
generator = torch.Generator(device=use_device).manual_seed(seed)
|
||||
x = torch.randn(
|
||||
[
|
||||
1,
|
||||
input_channels,
|
||||
height // downsampling_factor,
|
||||
width // downsampling_factor,
|
||||
],
|
||||
dtype=torch_dtype(device),
|
||||
device=use_device,
|
||||
generator=generator,
|
||||
).to(device)
|
||||
# if self.perlin > 0.0:
|
||||
# perlin_noise = self.get_perlin_noise(
|
||||
# width // self.downsampling_factor, height // self.downsampling_factor
|
||||
# )
|
||||
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
|
||||
return x
|
||||
|
||||
|
||||
def random_seed():
|
||||
return random.randint(0, np.iinfo(np.uint32).max)
|
||||
|
||||
|
||||
class NoiseInvocation(BaseInvocation):
|
||||
"""Generates latent noise."""
|
||||
|
||||
type: Literal["noise"] = "noise"
|
||||
|
||||
# Inputs
|
||||
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
|
||||
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "noise"],
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||
device = torch.device(choose_torch_device())
|
||||
noise = get_noise(self.width, self.height, device, self.seed)
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, noise)
|
||||
return NoiseOutput(
|
||||
noise=LatentsField(latents_name=name)
|
||||
)
|
||||
|
||||
|
||||
# Text to image
|
||||
class TextToLatentsInvocation(BaseInvocation):
|
||||
"""Generates latents from a prompt."""
|
||||
|
||||
type: Literal["t2l"] = "t2l"
|
||||
|
||||
# Inputs
|
||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||
# fmt: off
|
||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||
# fmt: on
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||
) -> None:
|
||||
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||
raise CanceledException
|
||||
|
||||
step = intermediate_state.step
|
||||
if intermediate_state.predicted_original is not None:
|
||||
# Some schedulers report not only the noisy latents at the current timestep,
|
||||
# but also their estimate so far of what the de-noised latents will be.
|
||||
sample = intermediate_state.predicted_original
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||
|
||||
|
||||
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
|
||||
model_info = choose_model(model_manager, self.model)
|
||||
model_name = model_info['model_name']
|
||||
model_hash = model_info['hash']
|
||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||
model.scheduler = get_scheduler(
|
||||
model=model,
|
||||
scheduler_name=self.scheduler
|
||||
)
|
||||
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
for component in [model.unet, model.vae]:
|
||||
configure_model_padding(component,
|
||||
self.seamless,
|
||||
self.seamless_axes
|
||||
)
|
||||
else:
|
||||
configure_model_padding(model,
|
||||
self.seamless,
|
||||
self.seamless_axes
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
|
||||
conditioning_data = ConditioningData(
|
||||
uc,
|
||||
c,
|
||||
self.cfg_scale,
|
||||
extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=0.0,#threshold,
|
||||
warmup=0.2,#warmup,
|
||||
h_symmetry_time_pct=None,#h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None#v_symmetry_time_pct,
|
||||
),
|
||||
).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
|
||||
return conditioning_data
|
||||
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(model)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=name)
|
||||
)
|
||||
|
||||
|
||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||
"""Generates latents using latents as base image."""
|
||||
|
||||
type: Literal["l2l"] = "l2l"
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
latent = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, state)
|
||||
|
||||
model = self.get_model(context.services.model_manager)
|
||||
conditioning_data = self.get_conditioning_data(model)
|
||||
|
||||
# TODO: Verify the noise is the right size
|
||||
|
||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||
latent, device=model.device, dtype=latent.dtype
|
||||
)
|
||||
|
||||
timesteps, _ = model.get_img2img_timesteps(
|
||||
self.steps,
|
||||
self.strength,
|
||||
device=model.device,
|
||||
)
|
||||
|
||||
result_latents, result_attention_map_saver = model.latents_from_embeddings(
|
||||
latents=initial_latents,
|
||||
timesteps=timesteps,
|
||||
noise=noise,
|
||||
num_inference_steps=self.steps,
|
||||
conditioning_data=conditioning_data,
|
||||
callback=step_callback
|
||||
)
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||
context.services.latents.set(name, result_latents)
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=name)
|
||||
)
|
||||
|
||||
|
||||
# Latent to image
|
||||
class LatentsToImageInvocation(BaseInvocation):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
type: Literal["l2i"] = "l2i"
|
||||
|
||||
# Inputs
|
||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||
model: str = Field(default="", description="The model to use")
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["latents", "image"],
|
||||
"type_hints": {
|
||||
"model": "model"
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# TODO: this only really needs the vae
|
||||
model_info = choose_model(context.services.model_manager, self.model)
|
||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||
|
||||
with torch.inference_mode():
|
||||
np_image = model.decode_latents(latents)
|
||||
image = model.numpy_to_pil(np_image)[0]
|
||||
|
||||
image_type = ImageType.RESULT
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
75
invokeai/app/invocations/math.py
Normal file
75
invokeai/app/invocations/math.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
|
||||
|
||||
class MathInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all math invocations with additional config"""
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["math"],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class IntOutput(BaseInvocationOutput):
|
||||
"""An integer output"""
|
||||
#fmt: off
|
||||
type: Literal["int_output"] = "int_output"
|
||||
a: int = Field(default=None, description="The output integer")
|
||||
#fmt: on
|
||||
|
||||
|
||||
class AddInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Adds two numbers"""
|
||||
#fmt: off
|
||||
type: Literal["add"] = "add"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a + self.b)
|
||||
|
||||
|
||||
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Subtracts two numbers"""
|
||||
#fmt: off
|
||||
type: Literal["sub"] = "sub"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a - self.b)
|
||||
|
||||
|
||||
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Multiplies two numbers"""
|
||||
#fmt: off
|
||||
type: Literal["mul"] = "mul"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a * self.b)
|
||||
|
||||
|
||||
class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
||||
"""Divides two numbers"""
|
||||
#fmt: off
|
||||
type: Literal["div"] = "div"
|
||||
a: int = Field(default=0, description="The first number")
|
||||
b: int = Field(default=0, description="The second number")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=int(self.a / self.b))
|
||||
18
invokeai/app/invocations/params.py
Normal file
18
invokeai/app/invocations/params.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Literal
|
||||
from pydantic import Field
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from .math import IntOutput
|
||||
|
||||
# Pass-through parameter nodes - used by subgraphs
|
||||
|
||||
class ParamIntInvocation(BaseInvocation):
|
||||
"""An integer parameter"""
|
||||
#fmt: off
|
||||
type: Literal["param_int"] = "param_int"
|
||||
a: int = Field(default=0, description="The integer value")
|
||||
#fmt: on
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||
return IntOutput(a=self.a)
|
||||
@@ -3,10 +3,10 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput
|
||||
|
||||
class RestoreFaceInvocation(BaseInvocation):
|
||||
"""Restores faces in an image."""
|
||||
@@ -18,6 +18,14 @@ class RestoreFaceInvocation(BaseInvocation):
|
||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
|
||||
#fmt: on
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["restoration", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
|
||||
@@ -5,10 +5,10 @@ from typing import Literal, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from ..services.image_storage import ImageType
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||
from .image import ImageOutput
|
||||
|
||||
|
||||
class UpscaleInvocation(BaseInvocation):
|
||||
@@ -22,6 +22,15 @@ class UpscaleInvocation(BaseInvocation):
|
||||
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
||||
#fmt: on
|
||||
|
||||
|
||||
# Schema customisation
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["upscaling", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get(
|
||||
self.image.image_type, self.image.image_name
|
||||
|
||||
11
invokeai/app/invocations/util/get_model.py
Normal file
11
invokeai/app/invocations/util/get_model.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.backend.model_management.model_manager import ModelManager
|
||||
|
||||
|
||||
def choose_model(model_manager: ModelManager, model_name: str):
|
||||
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
||||
if model_manager.valid_model(model_name):
|
||||
return model_manager.get_model(model_name)
|
||||
else:
|
||||
print(f"* Warning: '{model_name}' is not a valid model name. Using default model instead.")
|
||||
return model_manager.get_model()
|
||||
0
invokeai/app/models/__init__.py
Normal file
0
invokeai/app/models/__init__.py
Normal file
3
invokeai/app/models/exceptions.py
Normal file
3
invokeai/app/models/exceptions.py
Normal file
@@ -0,0 +1,3 @@
|
||||
class CanceledException(Exception):
|
||||
"""Execution canceled by user."""
|
||||
pass
|
||||
26
invokeai/app/models/image.py
Normal file
26
invokeai/app/models/image.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ImageType(str, Enum):
|
||||
RESULT = "results"
|
||||
INTERMEDIATE = "intermediates"
|
||||
UPLOAD = "uploads"
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_type: ImageType = Field(
|
||||
default=ImageType.RESULT, description="The type of the image"
|
||||
)
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"image_type",
|
||||
"image_name",
|
||||
]
|
||||
}
|
||||
11
invokeai/app/models/metadata.py
Normal file
11
invokeai/app/models/metadata.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class ImageMetadata(BaseModel):
|
||||
"""An image's metadata"""
|
||||
|
||||
timestamp: float = Field(description="The creation timestamp of the image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# TODO: figure out metadata
|
||||
sd_metadata: Optional[dict] = Field(default={}, description="The image's SD-specific metadata")
|
||||
56
invokeai/app/services/default_graphs.py
Normal file
56
invokeai/app/services/default_graphs.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
|
||||
from ..invocations.params import ParamIntInvocation
|
||||
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
||||
from .item_storage import ItemStorageABC
|
||||
|
||||
|
||||
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
|
||||
|
||||
|
||||
def create_text_to_image() -> LibraryGraph:
|
||||
return LibraryGraph(
|
||||
id=default_text_to_image_graph_id,
|
||||
name='t2i',
|
||||
description='Converts text to an image',
|
||||
graph=Graph(
|
||||
nodes={
|
||||
'width': ParamIntInvocation(id='width', a=512),
|
||||
'height': ParamIntInvocation(id='height', a=512),
|
||||
'3': NoiseInvocation(id='3'),
|
||||
'4': TextToLatentsInvocation(id='4'),
|
||||
'5': LatentsToImageInvocation(id='5')
|
||||
},
|
||||
edges=[
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
|
||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
|
||||
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
|
||||
Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
|
||||
]
|
||||
),
|
||||
exposed_inputs=[
|
||||
ExposedNodeInput(node_path='4', field='prompt', alias='prompt'),
|
||||
ExposedNodeInput(node_path='width', field='a', alias='width'),
|
||||
ExposedNodeInput(node_path='height', field='a', alias='height')
|
||||
],
|
||||
exposed_outputs=[
|
||||
ExposedNodeOutput(node_path='5', field='image', alias='image')
|
||||
])
|
||||
|
||||
|
||||
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
|
||||
|
||||
graphs: list[LibraryGraph] = list()
|
||||
|
||||
text_to_image = graph_library.get(default_text_to_image_graph_id)
|
||||
|
||||
# TODO: Check if the graph is the same as the default one, and if not, update it
|
||||
#if text_to_image is None:
|
||||
text_to_image = create_text_to_image()
|
||||
graph_library.set(text_to_image)
|
||||
|
||||
graphs.append(text_to_image)
|
||||
|
||||
return graphs
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import traceback
|
||||
import uuid
|
||||
from types import NoneType
|
||||
from typing import (
|
||||
@@ -17,7 +16,7 @@ from typing import (
|
||||
)
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, root_validator, validator
|
||||
from pydantic.fields import Field
|
||||
|
||||
from ..invocations import *
|
||||
@@ -26,7 +25,6 @@ from ..invocations.baseinvocation import (
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
)
|
||||
from .invocation_services import InvocationServices
|
||||
|
||||
|
||||
class EdgeConnection(BaseModel):
|
||||
@@ -215,7 +213,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid.uuid4)
|
||||
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The nodes in this graph", default_factory=dict
|
||||
@@ -283,7 +281,8 @@ class Graph(BaseModel):
|
||||
:raises InvalidEdgeError: the provided edge is invalid.
|
||||
"""
|
||||
|
||||
if self._is_edge_valid(edge) and edge not in self.edges:
|
||||
self._validate_edge(edge)
|
||||
if edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
else:
|
||||
raise InvalidEdgeError()
|
||||
@@ -354,7 +353,7 @@ class Graph(BaseModel):
|
||||
|
||||
return True
|
||||
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
def _validate_edge(self, edge: Edge):
|
||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||
|
||||
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
||||
@@ -362,54 +361,53 @@ class Graph(BaseModel):
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
except NodeNotFoundError:
|
||||
return False
|
||||
raise InvalidEdgeError("One or both nodes don't exist")
|
||||
|
||||
# Validate that an edge to this node+field doesn't already exist
|
||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
|
||||
|
||||
# Validate that no cycles would be created
|
||||
g = self.nx_graph_flat()
|
||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||
if not nx.is_directed_acyclic_graph(g):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
|
||||
|
||||
# Validate that the field types are compatible
|
||||
if not are_connections_compatible(
|
||||
from_node, edge.source.field, to_node, edge.destination.field
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Fields are incompatible')
|
||||
|
||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
|
||||
|
||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||
if not self._is_iterator_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
|
||||
|
||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.destination.node_id, new_input=edge.source
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Collector output type does not match collector input type')
|
||||
|
||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||
if not self._is_collector_connection_valid(
|
||||
edge.source.node_id, new_output=edge.destination
|
||||
):
|
||||
return False
|
||||
raise InvalidEdgeError(f'Collector input type does not match collector output type')
|
||||
|
||||
return True
|
||||
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
"""Determines whether or not a node exists in the graph."""
|
||||
@@ -733,7 +731,7 @@ class Graph(BaseModel):
|
||||
for sgn in (
|
||||
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
|
||||
):
|
||||
sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
@@ -750,9 +748,7 @@ class Graph(BaseModel):
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
|
||||
id: str = Field(
|
||||
description="The id of the execution state", default_factory=uuid.uuid4
|
||||
)
|
||||
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
|
||||
|
||||
# TODO: Store a reference to the graph instead of the actual graph?
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
@@ -794,9 +790,6 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
# Declare all fields as required; necessary for OpenAPI schema generation build.
|
||||
# Technically only fields without a `default_factory` need to be listed here.
|
||||
# See: https://github.com/pydantic/pydantic/discussions/4577
|
||||
class Config:
|
||||
schema_extra = {
|
||||
'required': [
|
||||
@@ -861,7 +854,8 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Returns true if the graph is complete"""
|
||||
return self.has_error() or all((k in self.executed for k in self.graph.nodes))
|
||||
node_ids = set(self.graph.nx_graph_flat().nodes)
|
||||
return self.has_error() or all((k in self.executed for k in node_ids))
|
||||
|
||||
def has_error(self) -> bool:
|
||||
"""Returns true if the graph has any errors"""
|
||||
@@ -949,11 +943,11 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def _iterator_graph(self) -> nx.DiGraph:
|
||||
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
||||
g = self.graph.nx_graph()
|
||||
g = self.graph.nx_graph_flat()
|
||||
collectors = (
|
||||
n
|
||||
for n in self.graph.nodes
|
||||
if isinstance(self.graph.nodes[n], CollectInvocation)
|
||||
if isinstance(self.graph.get_node(n), CollectInvocation)
|
||||
)
|
||||
for c in collectors:
|
||||
g.remove_edges_from(list(g.in_edges(c)))
|
||||
@@ -965,7 +959,7 @@ class GraphExecutionState(BaseModel):
|
||||
iterators = [
|
||||
n
|
||||
for n in nx.ancestors(g, node_id)
|
||||
if isinstance(self.graph.nodes[n], IterateInvocation)
|
||||
if isinstance(self.graph.get_node(n), IterateInvocation)
|
||||
]
|
||||
return iterators
|
||||
|
||||
@@ -1069,9 +1063,8 @@ class GraphExecutionState(BaseModel):
|
||||
n
|
||||
for n in prepared_nodes
|
||||
if all(
|
||||
pit
|
||||
nx.has_path(execution_graph, pit[0], n)
|
||||
for pit in parent_iterators
|
||||
if nx.has_path(execution_graph, pit[0], n)
|
||||
)
|
||||
),
|
||||
None,
|
||||
@@ -1102,7 +1095,9 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
if not self._is_edge_valid(edge):
|
||||
try:
|
||||
self.graph._validate_edge(edge)
|
||||
except InvalidEdgeError:
|
||||
return False
|
||||
|
||||
# Invalid if destination has already been prepared or executed
|
||||
@@ -1148,4 +1143,52 @@ class GraphExecutionState(BaseModel):
|
||||
self.graph.delete_edge(edge)
|
||||
|
||||
|
||||
class ExposedNodeInput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the input")
|
||||
field: str = Field(description="The field name of the input")
|
||||
alias: str = Field(description="The alias of the input")
|
||||
|
||||
|
||||
class ExposedNodeOutput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the output")
|
||||
field: str = Field(description="The field name of the output")
|
||||
alias: str = Field(description="The alias of the output")
|
||||
|
||||
class LibraryGraph(BaseModel):
|
||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
|
||||
graph: Graph = Field(description="The graph")
|
||||
name: str = Field(description="The name of the graph")
|
||||
description: str = Field(description="The description of the graph")
|
||||
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
|
||||
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
|
||||
|
||||
@validator('exposed_inputs', 'exposed_outputs')
|
||||
def validate_exposed_aliases(cls, v):
|
||||
if len(v) != len(set(i.alias for i in v)):
|
||||
raise ValueError("Duplicate exposed alias")
|
||||
return v
|
||||
|
||||
@root_validator
|
||||
def validate_exposed_nodes(cls, values):
|
||||
graph = values['graph']
|
||||
|
||||
# Validate exposed inputs
|
||||
for exposed_input in values['exposed_inputs']:
|
||||
if not graph.has_node(exposed_input.node_path):
|
||||
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
||||
node = graph.get_node(exposed_input.node_path)
|
||||
if get_input_field(node, exposed_input.field) is None:
|
||||
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
|
||||
|
||||
# Validate exposed outputs
|
||||
for exposed_output in values['exposed_outputs']:
|
||||
if not graph.has_node(exposed_output.node_path):
|
||||
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
||||
node = graph.get_node(exposed_output.node_path)
|
||||
if get_output_field(node, exposed_output.field) is None:
|
||||
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
GraphInvocation.update_forward_refs()
|
||||
|
||||
@@ -2,23 +2,25 @@
|
||||
|
||||
import datetime
|
||||
import os
|
||||
from glob import glob
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from PIL.Image import Image
|
||||
import PIL.Image as PILImage
|
||||
from pydantic import BaseModel
|
||||
from invokeai.app.api.models.images import ImageResponse
|
||||
from invokeai.app.models.image import ImageField, ImageType
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
from invokeai.app.util.save_thumbnail import save_thumbnail
|
||||
|
||||
from invokeai.backend.image_util import PngWriter
|
||||
|
||||
|
||||
class ImageType(str, Enum):
|
||||
RESULT = "results"
|
||||
INTERMEDIATE = "intermediates"
|
||||
UPLOAD = "uploads"
|
||||
|
||||
|
||||
class ImageStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving images."""
|
||||
|
||||
@@ -26,9 +28,17 @@ class ImageStorageBase(ABC):
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list(
|
||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
pass
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
@abstractmethod
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -66,6 +76,57 @@ class DiskImageStorage(ImageStorageBase):
|
||||
Path(os.path.join(output_folder, image_type)).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
|
||||
def list(
|
||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[ImageResponse]:
|
||||
dir_path = os.path.join(self.__output_folder, image_type)
|
||||
image_paths = glob(f"{dir_path}/*.png")
|
||||
count = len(image_paths)
|
||||
|
||||
sorted_image_paths = sorted(
|
||||
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
|
||||
)
|
||||
|
||||
page_of_image_paths = sorted_image_paths[
|
||||
page * per_page : (page + 1) * per_page
|
||||
]
|
||||
|
||||
page_of_images: List[ImageResponse] = []
|
||||
|
||||
for path in page_of_image_paths:
|
||||
filename = os.path.basename(path)
|
||||
img = PILImage.open(path)
|
||||
page_of_images.append(
|
||||
ImageResponse(
|
||||
image_type=image_type.value,
|
||||
image_name=filename,
|
||||
# TODO: DiskImageStorage should not be building URLs...?
|
||||
image_url=f"api/v1/images/{image_type.value}/{filename}",
|
||||
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
||||
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
|
||||
metadata=ImageMetadata(
|
||||
timestamp=os.path.getctime(path),
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
page_count_trunc = int(count / per_page)
|
||||
page_count_mod = count % per_page
|
||||
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
|
||||
|
||||
return PaginatedResults[ImageResponse](
|
||||
items=page_of_images,
|
||||
page=page,
|
||||
pages=page_count,
|
||||
per_page=per_page,
|
||||
total=count,
|
||||
)
|
||||
|
||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
@@ -73,13 +134,20 @@ class DiskImageStorage(ImageStorageBase):
|
||||
if cache_item:
|
||||
return cache_item
|
||||
|
||||
image = Image.open(image_path)
|
||||
image = PILImage.open(image_path)
|
||||
self.__set_cache(image_path, image)
|
||||
return image
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
||||
def get_path(
|
||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||
) -> str:
|
||||
if is_thumbnail:
|
||||
path = os.path.join(
|
||||
self.__output_folder, image_type, "thumbnails", image_name
|
||||
)
|
||||
else:
|
||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
||||
return path
|
||||
|
||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
||||
@@ -87,18 +155,29 @@ class DiskImageStorage(ImageStorageBase):
|
||||
self.__pngWriter.save_image_and_prompt_to_png(
|
||||
image, "", image_subpath, None
|
||||
) # TODO: just pass full path to png writer
|
||||
|
||||
save_thumbnail(
|
||||
image=image,
|
||||
filename=image_name,
|
||||
path=os.path.join(self.__output_folder, image_type, "thumbnails"),
|
||||
)
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
self.__set_cache(image_path, image)
|
||||
|
||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||
image_path = self.get_path(image_type, image_name)
|
||||
thumbnail_path = self.get_path(image_type, image_name, True)
|
||||
if os.path.exists(image_path):
|
||||
os.remove(image_path)
|
||||
|
||||
if image_path in self.__cache:
|
||||
del self.__cache[image_path]
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
os.remove(thumbnail_path)
|
||||
|
||||
if thumbnail_path in self.__cache:
|
||||
del self.__cache[thumbnail_path]
|
||||
|
||||
def __get_cache(self, image_name: str) -> Image:
|
||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||
|
||||
|
||||
@@ -1,30 +1,17 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from queue import Queue
|
||||
import time
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# TODO: make this serializable
|
||||
class InvocationQueueItem:
|
||||
# session_id: str
|
||||
graph_execution_state_id: str
|
||||
invocation_id: str
|
||||
invoke_all: bool
|
||||
timestamp: float
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# session_id: str,
|
||||
graph_execution_state_id: str,
|
||||
invocation_id: str,
|
||||
invoke_all: bool = False,
|
||||
):
|
||||
# self.session_id = session_id
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.invocation_id = invocation_id
|
||||
self.invoke_all = invoke_all
|
||||
self.timestamp = time.time()
|
||||
class InvocationQueueItem(BaseModel):
|
||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||
invoke_all: bool = Field(default=False)
|
||||
timestamp: float = Field(default_factory=time.time)
|
||||
|
||||
|
||||
class InvocationQueueABC(ABC):
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from invokeai.backend import ModelManager
|
||||
|
||||
from .events import EventServiceBase
|
||||
from .latent_storage import LatentsStorageBase
|
||||
from .image_storage import ImageStorageBase
|
||||
from .restoration_services import RestorationServices
|
||||
from .invocation_queue import InvocationQueueABC
|
||||
@@ -11,12 +12,14 @@ class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
|
||||
events: EventServiceBase
|
||||
latents: LatentsStorageBase
|
||||
images: ImageStorageBase
|
||||
queue: InvocationQueueABC
|
||||
model_manager: ModelManager
|
||||
restoration: RestorationServices
|
||||
|
||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||
graph_library: ItemStorageABC["LibraryGraph"]
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||
processor: "InvocationProcessorABC"
|
||||
|
||||
@@ -24,16 +27,20 @@ class InvocationServices:
|
||||
self,
|
||||
model_manager: ModelManager,
|
||||
events: EventServiceBase,
|
||||
latents: LatentsStorageBase,
|
||||
images: ImageStorageBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_library: ItemStorageABC["LibraryGraph"],
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
restoration: RestorationServices,
|
||||
):
|
||||
self.model_manager = model_manager
|
||||
self.events = events
|
||||
self.latents = latents
|
||||
self.images = images
|
||||
self.queue = queue
|
||||
self.graph_library = graph_library
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.processor = processor
|
||||
self.restoration = restoration
|
||||
|
||||
@@ -33,7 +33,6 @@ class Invoker:
|
||||
self.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Queue the invocation
|
||||
print(f"queueing item {invocation.id}")
|
||||
self.services.queue.put(
|
||||
InvocationQueueItem(
|
||||
# session_id = session.id,
|
||||
|
||||
93
invokeai/app/services/latent_storage.py
Normal file
93
invokeai/app/services/latent_storage.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
class LatentsStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving latents."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, name: str, data: torch.Tensor) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, name: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
||||
|
||||
__cache: Dict[str, torch.Tensor]
|
||||
__cache_ids: Queue
|
||||
__max_cache_size: int
|
||||
__underlying_storage: LatentsStorageBase
|
||||
|
||||
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
||||
self.__underlying_storage = underlying_storage
|
||||
self.__cache = dict()
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = max_cache_size
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
cache_item = self.__get_cache(name)
|
||||
if cache_item is not None:
|
||||
return cache_item
|
||||
|
||||
latent = self.__underlying_storage.get(name)
|
||||
self.__set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
def set(self, name: str, data: torch.Tensor) -> None:
|
||||
self.__underlying_storage.set(name, data)
|
||||
self.__set_cache(name, data)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
self.__underlying_storage.delete(name)
|
||||
if name in self.__cache:
|
||||
del self.__cache[name]
|
||||
|
||||
def __get_cache(self, name: str) -> torch.Tensor|None:
|
||||
return None if name not in self.__cache else self.__cache[name]
|
||||
|
||||
def __set_cache(self, name: str, data: torch.Tensor):
|
||||
if not name in self.__cache:
|
||||
self.__cache[name] = data
|
||||
self.__cache_ids.put(name)
|
||||
if self.__cache_ids.qsize() > self.__max_cache_size:
|
||||
self.__cache.pop(self.__cache_ids.get())
|
||||
|
||||
|
||||
class DiskLatentsStorage(LatentsStorageBase):
|
||||
"""Stores latents in a folder on disk without caching"""
|
||||
|
||||
__output_folder: str
|
||||
|
||||
def __init__(self, output_folder: str):
|
||||
self.__output_folder = output_folder
|
||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
latent_path = self.get_path(name)
|
||||
return torch.load(latent_path)
|
||||
|
||||
def set(self, name: str, data: torch.Tensor) -> None:
|
||||
latent_path = self.get_path(name)
|
||||
torch.save(data, latent_path)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
latent_path = self.get_path(name)
|
||||
os.remove(latent_path)
|
||||
|
||||
def get_path(self, name: str) -> str:
|
||||
return os.path.join(self.__output_folder, name)
|
||||
|
||||
@@ -4,7 +4,7 @@ from threading import Event, Thread
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from .invocation_queue import InvocationQueueItem
|
||||
from .invoker import InvocationProcessorABC, Invoker
|
||||
from ..util.util import CanceledException
|
||||
from ..models.exceptions import CanceledException
|
||||
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker_thread: Thread
|
||||
|
||||
@@ -35,8 +35,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._create_table()
|
||||
|
||||
def _create_table(self):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
|
||||
item TEXT,
|
||||
@@ -45,33 +44,27 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
self._cursor.execute(
|
||||
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
|
||||
)
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._conn.commit()
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
item_type = get_args(self.__orig_class__)[0]
|
||||
return parse_raw_as(item_type, item)
|
||||
|
||||
def set(self, item: T):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
||||
(item.json(),),
|
||||
)
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._conn.commit()
|
||||
self._on_changed(item)
|
||||
|
||||
def get(self, id: str) -> Union[T, None]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
@@ -79,18 +72,15 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
return self._parse_item(result[0])
|
||||
|
||||
def delete(self, id: str):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||
)
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._conn.commit()
|
||||
self._on_deleted(id)
|
||||
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
|
||||
(per_page, page * per_page),
|
||||
@@ -101,8 +91,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
|
||||
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
@@ -113,8 +101,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
def search(
|
||||
self, query: str, page: int = 0, per_page: int = 10
|
||||
) -> PaginatedResults[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
with self._lock:
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
|
||||
(f"%{query}%", per_page, page * per_page),
|
||||
@@ -128,8 +115,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
(f"%{query}%",),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
|
||||
0
invokeai/app/util/__init__.py
Normal file
0
invokeai/app/util/__init__.py
Normal file
25
invokeai/app/util/save_thumbnail.py
Normal file
25
invokeai/app/util/save_thumbnail.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def save_thumbnail(
|
||||
image: Image.Image,
|
||||
filename: str,
|
||||
path: str,
|
||||
size: int = 256,
|
||||
) -> str:
|
||||
"""
|
||||
Saves a thumbnail of an image, returning its path.
|
||||
"""
|
||||
base_filename = os.path.splitext(filename)[0]
|
||||
thumbnail_path = os.path.join(path, base_filename + ".webp")
|
||||
|
||||
if os.path.exists(thumbnail_path):
|
||||
return thumbnail_path
|
||||
|
||||
image_copy = image.copy()
|
||||
image_copy.thumbnail(size=(size, size))
|
||||
|
||||
image_copy.save(thumbnail_path, "WEBP")
|
||||
|
||||
return thumbnail_path
|
||||
@@ -1,14 +1,16 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ...backend.generator.base import Generator
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
|
||||
class CanceledException(Exception):
|
||||
pass
|
||||
|
||||
def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext, ):
|
||||
def fast_latents_step_callback(
|
||||
sample: torch.Tensor,
|
||||
step: int,
|
||||
steps: int,
|
||||
id: str,
|
||||
context: InvocationContext,
|
||||
):
|
||||
# TODO: only output a preview image when requested
|
||||
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||
|
||||
@@ -21,15 +23,12 @@ def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id:
|
||||
context.services.events.emit_generator_progress(
|
||||
context.graph_execution_state_id,
|
||||
id,
|
||||
{
|
||||
"width": width,
|
||||
"height": height,
|
||||
"dataURL": dataURL
|
||||
},
|
||||
{"width": width, "height": height, "dataURL": dataURL},
|
||||
step,
|
||||
steps,
|
||||
)
|
||||
|
||||
|
||||
def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
||||
"""
|
||||
txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
|
||||
@@ -37,6 +36,8 @@ def diffusers_step_callback_adapter(*cb_args, **kwargs):
|
||||
"""
|
||||
if isinstance(cb_args[0], PipelineIntermediateState):
|
||||
progress_state: PipelineIntermediateState = cb_args[0]
|
||||
return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)
|
||||
return fast_latents_step_callback(
|
||||
progress_state.latents, progress_state.step, **kwargs
|
||||
)
|
||||
else:
|
||||
return fast_latents_step_callback(*cb_args, **kwargs)
|
||||
@@ -561,7 +561,7 @@ class Args(object):
|
||||
"--autoimport",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
|
||||
help="(DEPRECATED - NONFUNCTIONAL). Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--autoconvert",
|
||||
|
||||
@@ -67,7 +67,6 @@ def install_requested_models(
|
||||
scan_directory: Path = None,
|
||||
external_models: List[str] = None,
|
||||
scan_at_startup: bool = False,
|
||||
convert_to_diffusers: bool = False,
|
||||
precision: str = "float16",
|
||||
purge_deleted: bool = False,
|
||||
config_file_path: Path = None,
|
||||
@@ -113,7 +112,6 @@ def install_requested_models(
|
||||
try:
|
||||
model_manager.heuristic_import(
|
||||
path_url_or_repo,
|
||||
convert=convert_to_diffusers,
|
||||
commit_to_conf=config_file_path,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
@@ -122,7 +120,7 @@ def install_requested_models(
|
||||
pass
|
||||
|
||||
if scan_at_startup and scan_directory.is_dir():
|
||||
argument = "--autoconvert" if convert_to_diffusers else "--autoimport"
|
||||
argument = "--autoconvert"
|
||||
initfile = Path(Globals.root, Globals.initfile)
|
||||
replacement = Path(Globals.root, f"{Globals.initfile}.new")
|
||||
directory = str(scan_directory).replace("\\", "/")
|
||||
|
||||
@@ -7,3 +7,4 @@ from .convert_ckpt_to_diffusers import (
|
||||
)
|
||||
from .model_manager import ModelManager
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""
|
||||
"""enum
|
||||
Manage a cache of Stable Diffusion model files for fast switching.
|
||||
They are moved between GPU and CPU as necessary. If CPU memory falls
|
||||
below a preset minimum, the least recently used model will be
|
||||
@@ -15,7 +15,7 @@ import sys
|
||||
import textwrap
|
||||
import time
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from shutil import move, rmtree
|
||||
from typing import Any, Optional, Union, Callable
|
||||
@@ -24,8 +24,12 @@ import safetensors
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import transformers
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers import logging as dlogging
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
UNet2DConditionModel,
|
||||
SchedulerMixin,
|
||||
logging as dlogging,
|
||||
)
|
||||
from huggingface_hub import scan_cache_dir
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
@@ -33,37 +37,58 @@ from picklescan.scanner import scan_file_path
|
||||
|
||||
from invokeai.backend.globals import Globals, global_cache_dir
|
||||
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
CLIPFeatureExtractor,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from ..stable_diffusion import (
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
||||
|
||||
class SDLegacyType(Enum):
|
||||
V1 = 1
|
||||
V1_INPAINT = 2
|
||||
V2 = 3
|
||||
V2_e = 4
|
||||
V2_v = 5
|
||||
UNKNOWN = 99
|
||||
|
||||
class SDLegacyType(Enum):
|
||||
V1 = auto()
|
||||
V1_INPAINT = auto()
|
||||
V2 = auto()
|
||||
V2_e = auto()
|
||||
V2_v = auto()
|
||||
UNKNOWN = auto()
|
||||
|
||||
class SDModelComponent(Enum):
|
||||
vae="vae"
|
||||
text_encoder="text_encoder"
|
||||
tokenizer="tokenizer"
|
||||
unet="unet"
|
||||
scheduler="scheduler"
|
||||
safety_checker="safety_checker"
|
||||
feature_extractor="feature_extractor"
|
||||
|
||||
DEFAULT_MAX_MODELS = 2
|
||||
|
||||
class ModelManager(object):
|
||||
'''
|
||||
"""
|
||||
Model manager handles loading, caching, importing, deleting, converting, and editing models.
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: OmegaConf|Path,
|
||||
device_type: torch.device = CUDA_DEVICE,
|
||||
precision: str = "float16",
|
||||
max_loaded_models=DEFAULT_MAX_MODELS,
|
||||
sequential_offload=False,
|
||||
embedding_path: Path=None,
|
||||
self,
|
||||
config: OmegaConf | Path,
|
||||
device_type: torch.device = CUDA_DEVICE,
|
||||
precision: str = "float16",
|
||||
max_loaded_models=DEFAULT_MAX_MODELS,
|
||||
sequential_offload=False,
|
||||
embedding_path: Path = None,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file or
|
||||
an initialized OmegaConf dictionary. Optional parameters
|
||||
are the torch device type, precision, max_loaded_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
# prevent nasty-looking CLIP log message
|
||||
@@ -87,15 +112,25 @@ class ModelManager(object):
|
||||
"""
|
||||
return model_name in self.config
|
||||
|
||||
def get_model(self, model_name: str=None)->dict:
|
||||
"""
|
||||
Given a model named identified in models.yaml, return
|
||||
the model object. If in RAM will load into GPU VRAM.
|
||||
If on disk, will load from there.
|
||||
def get_model(self, model_name: str = None) -> dict:
|
||||
"""Given a model named identified in models.yaml, return a dict
|
||||
containing the model object and some of its key features. If
|
||||
in RAM will load into GPU VRAM. If on disk, will load from
|
||||
there.
|
||||
The dict has the following keys:
|
||||
'model': The StableDiffusionGeneratorPipeline object
|
||||
'model_name': The name of the model in models.yaml
|
||||
'width': The width of images trained by this model
|
||||
'height': The height of images trained by this model
|
||||
'hash': A unique hash of this model's files on disk.
|
||||
"""
|
||||
if not model_name:
|
||||
return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model())
|
||||
|
||||
return (
|
||||
self.get_model(self.current_model)
|
||||
if self.current_model
|
||||
else self.get_model(self.default_model())
|
||||
)
|
||||
|
||||
if not self.valid_model(model_name):
|
||||
print(
|
||||
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
@@ -135,6 +170,81 @@ class ModelManager(object):
|
||||
"hash": hash,
|
||||
}
|
||||
|
||||
def get_model_vae(self, model_name: str=None)->AutoencoderKL:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned VAE as an
|
||||
AutoencoderKL object. If no model name is provided, return the
|
||||
vae from the model currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.vae)
|
||||
|
||||
def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned CLIPTokenizer. If no
|
||||
model name is provided, return the tokenizer from the model
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.tokenizer)
|
||||
|
||||
def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned UNet2DConditionModel. If no model
|
||||
name is provided, return the UNet from the model
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.unet)
|
||||
|
||||
def get_model_text_encoder(self, model_name: str=None)->CLIPTextModel:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned CLIPTextModel. If no
|
||||
model name is provided, return the text encoder from the model
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.text_encoder)
|
||||
|
||||
def get_model_feature_extractor(self, model_name: str=None)->CLIPFeatureExtractor:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned CLIPFeatureExtractor. If no
|
||||
model name is provided, return the text encoder from the model
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.feature_extractor)
|
||||
|
||||
def get_model_scheduler(self, model_name: str=None)->SchedulerMixin:
|
||||
"""Given a model name identified in models.yaml, load the model into
|
||||
GPU if necessary and return its assigned scheduler. If no
|
||||
model name is provided, return the text encoder from the model
|
||||
currently in the GPU.
|
||||
"""
|
||||
return self._get_sub_model(model_name, SDModelComponent.scheduler)
|
||||
|
||||
def _get_sub_model(
|
||||
self,
|
||||
model_name: str=None,
|
||||
model_part: SDModelComponent=SDModelComponent.vae,
|
||||
) -> Union[
|
||||
AutoencoderKL,
|
||||
CLIPTokenizer,
|
||||
CLIPFeatureExtractor,
|
||||
UNet2DConditionModel,
|
||||
CLIPTextModel,
|
||||
StableDiffusionSafetyChecker,
|
||||
]:
|
||||
"""Given a model name identified in models.yaml, and the part of the
|
||||
model you wish to retrieve, return that part. Parts are in an Enum
|
||||
class named SDModelComponent, and consist of:
|
||||
SDModelComponent.vae
|
||||
SDModelComponent.text_encoder
|
||||
SDModelComponent.tokenizer
|
||||
SDModelComponent.unet
|
||||
SDModelComponent.scheduler
|
||||
SDModelComponent.safety_checker
|
||||
SDModelComponent.feature_extractor
|
||||
"""
|
||||
model_dict = self.get_model(model_name)
|
||||
model = model_dict["model"]
|
||||
return getattr(model, model_part.value)
|
||||
|
||||
def default_model(self) -> str | None:
|
||||
"""
|
||||
Returns the name of the default model, or None
|
||||
@@ -360,7 +470,7 @@ class ModelManager(object):
|
||||
f"Unknown model format {model_name}: {model_format}"
|
||||
)
|
||||
self._add_embeddings_to_model(model)
|
||||
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
print(">> Model loaded in", "%4.2fs" % (toc - tic))
|
||||
@@ -433,7 +543,7 @@ class ModelManager(object):
|
||||
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
||||
height = width
|
||||
print(f" | Default image dimensions = {width} x {height}")
|
||||
|
||||
|
||||
return pipeline, width, height, model_hash
|
||||
|
||||
def _load_ckpt_model(self, model_name, mconfig):
|
||||
@@ -454,14 +564,18 @@ class ModelManager(object):
|
||||
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
||||
|
||||
try:
|
||||
if self.list_models()[self.current_model]['status'] == 'active':
|
||||
if self.list_models()[self.current_model]["status"] == "active":
|
||||
self.offload_model(self.current_model)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
vae_path = None
|
||||
if vae:
|
||||
vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
|
||||
vae_path = (
|
||||
vae
|
||||
if os.path.isabs(vae)
|
||||
else os.path.normpath(os.path.join(Globals.root, vae))
|
||||
)
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
@@ -571,9 +685,7 @@ class ModelManager(object):
|
||||
models.yaml file.
|
||||
"""
|
||||
model_name = model_name or Path(repo_or_path).stem
|
||||
model_description = (
|
||||
description or f"Imported diffusers model {model_name}"
|
||||
)
|
||||
model_description = description or f"Imported diffusers model {model_name}"
|
||||
new_config = dict(
|
||||
description=model_description,
|
||||
vae=vae,
|
||||
@@ -602,7 +714,7 @@ class ModelManager(object):
|
||||
SDLegacyType.V2_v (V2 using 'v_prediction' prediction type)
|
||||
SDLegacyType.UNKNOWN
|
||||
"""
|
||||
global_step = checkpoint.get('global_step')
|
||||
global_step = checkpoint.get("global_step")
|
||||
state_dict = checkpoint.get("state_dict") or checkpoint
|
||||
|
||||
try:
|
||||
@@ -628,13 +740,13 @@ class ModelManager(object):
|
||||
return SDLegacyType.UNKNOWN
|
||||
|
||||
def heuristic_import(
|
||||
self,
|
||||
path_url_or_repo: str,
|
||||
model_name: str = None,
|
||||
description: str = None,
|
||||
model_config_file: Path = None,
|
||||
commit_to_conf: Path = None,
|
||||
config_file_callback: Callable[[Path], Path] = None,
|
||||
self,
|
||||
path_url_or_repo: str,
|
||||
model_name: str = None,
|
||||
description: str = None,
|
||||
model_config_file: Path = None,
|
||||
commit_to_conf: Path = None,
|
||||
config_file_callback: Callable[[Path], Path] = None,
|
||||
) -> str:
|
||||
"""Accept a string which could be:
|
||||
- a HF diffusers repo_id
|
||||
@@ -738,8 +850,8 @@ class ModelManager(object):
|
||||
|
||||
# another round of heuristics to guess the correct config file.
|
||||
checkpoint = None
|
||||
if model_path.suffix in [".ckpt",".pt"]:
|
||||
self.scan_model(model_path,model_path)
|
||||
if model_path.suffix in [".ckpt", ".pt"]:
|
||||
self.scan_model(model_path, model_path)
|
||||
checkpoint = torch.load(model_path)
|
||||
else:
|
||||
checkpoint = safetensors.torch.load_file(model_path)
|
||||
@@ -761,19 +873,16 @@ class ModelManager(object):
|
||||
elif model_type == SDLegacyType.V1_INPAINT:
|
||||
print(" | SD-v1 inpainting model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
|
||||
Globals.root,
|
||||
"configs/stable-diffusion/v1-inpainting-inference.yaml",
|
||||
)
|
||||
elif model_type == SDLegacyType.V2_v:
|
||||
print(
|
||||
" | SD-v2-v model detected"
|
||||
)
|
||||
print(" | SD-v2-v model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
||||
)
|
||||
elif model_type == SDLegacyType.V2_e:
|
||||
print(
|
||||
" | SD-v2-e model detected"
|
||||
)
|
||||
print(" | SD-v2-e model detected")
|
||||
model_config_file = Path(
|
||||
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
||||
)
|
||||
@@ -820,16 +929,16 @@ class ModelManager(object):
|
||||
return model_name
|
||||
|
||||
def convert_and_import(
|
||||
self,
|
||||
ckpt_path: Path,
|
||||
diffusers_path: Path,
|
||||
model_name=None,
|
||||
model_description=None,
|
||||
vae:dict=None,
|
||||
vae_path:Path=None,
|
||||
original_config_file: Path = None,
|
||||
commit_to_conf: Path = None,
|
||||
scan_needed: bool=True,
|
||||
self,
|
||||
ckpt_path: Path,
|
||||
diffusers_path: Path,
|
||||
model_name=None,
|
||||
model_description=None,
|
||||
vae: dict = None,
|
||||
vae_path: Path = None,
|
||||
original_config_file: Path = None,
|
||||
commit_to_conf: Path = None,
|
||||
scan_needed: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Convert a legacy ckpt weights file to diffuser model and import
|
||||
@@ -857,10 +966,10 @@ class ModelManager(object):
|
||||
try:
|
||||
# By passing the specified VAE to the conversion function, the autoencoder
|
||||
# will be built into the model rather than tacked on afterward via the config file
|
||||
vae_model=None
|
||||
vae_model = None
|
||||
if vae:
|
||||
vae_model=self._load_vae(vae)
|
||||
vae_path=None
|
||||
vae_model = self._load_vae(vae)
|
||||
vae_path = None
|
||||
convert_ckpt_to_diffusers(
|
||||
ckpt_path,
|
||||
diffusers_path,
|
||||
@@ -976,16 +1085,16 @@ class ModelManager(object):
|
||||
legacy_locations = [
|
||||
Path(
|
||||
models_dir,
|
||||
"CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker"
|
||||
"CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker",
|
||||
),
|
||||
Path(models_dir, "bert-base-uncased/models--bert-base-uncased"),
|
||||
Path(
|
||||
models_dir,
|
||||
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14"
|
||||
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
|
||||
),
|
||||
]
|
||||
legacy_locations.extend(list(global_cache_dir("diffusers").glob('*')))
|
||||
|
||||
legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
|
||||
|
||||
legacy_layout = False
|
||||
for model in legacy_locations:
|
||||
legacy_layout = legacy_layout or model.exists()
|
||||
@@ -1003,7 +1112,7 @@ class ModelManager(object):
|
||||
>> make adjustments, please press ctrl-C now to abort and relaunch InvokeAI when you are ready.
|
||||
>> Otherwise press <enter> to continue."""
|
||||
)
|
||||
input('continue> ')
|
||||
input("continue> ")
|
||||
|
||||
# transformer files get moved into the hub directory
|
||||
if cls._is_huggingface_hub_directory_present():
|
||||
@@ -1090,12 +1199,12 @@ class ModelManager(object):
|
||||
print(
|
||||
f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
||||
)
|
||||
|
||||
|
||||
def _has_cuda(self) -> bool:
|
||||
return self.device.type == "cuda"
|
||||
|
||||
def _diffuser_sha256(
|
||||
self, name_or_path: Union[str, Path], chunksize=4096
|
||||
self, name_or_path: Union[str, Path], chunksize=16777216
|
||||
) -> Union[str, bytes]:
|
||||
path = None
|
||||
if isinstance(name_or_path, Path):
|
||||
|
||||
@@ -57,7 +57,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
self.concept_list.extend(list(local_concepts_to_add))
|
||||
return self.concept_list
|
||||
return self.concept_list
|
||||
else:
|
||||
elif Globals.internet_available is True:
|
||||
try:
|
||||
models = self.hf_api.list_models(
|
||||
filter=ModelFilter(model_name="sd-concepts-library/")
|
||||
@@ -73,6 +73,8 @@ class HuggingFaceConceptsLibrary(object):
|
||||
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||
)
|
||||
return self.concept_list
|
||||
else:
|
||||
return self.concept_list
|
||||
|
||||
def get_concept_model_path(self, concept_name: str) -> str:
|
||||
"""
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
import os
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union, List
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from picklescan.scanner import scan_file_path
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||
|
||||
@dataclass
|
||||
class EmbeddingInfo:
|
||||
name: str
|
||||
embedding: torch.Tensor
|
||||
num_vectors_per_token: int
|
||||
token_dim: int
|
||||
trained_steps: int = None
|
||||
trained_model_name: str = None
|
||||
trained_model_checksum: str = None
|
||||
|
||||
@dataclass
|
||||
class TextualInversion:
|
||||
@@ -72,66 +82,46 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
if str(ckpt_path).endswith(".DS_Store"):
|
||||
return
|
||||
|
||||
try:
|
||||
scan_result = scan_file_path(str(ckpt_path))
|
||||
if scan_result.infected_files == 1:
|
||||
embedding_list = self._parse_embedding(str(ckpt_path))
|
||||
for embedding_info in embedding_list:
|
||||
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
||||
print(
|
||||
f"\n### Security Issues Found in Model: {scan_result.issues_count}"
|
||||
f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
||||
)
|
||||
print("### For your safety, InvokeAI will not load this embed.")
|
||||
return
|
||||
except Exception:
|
||||
print(
|
||||
f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
|
||||
)
|
||||
return
|
||||
continue
|
||||
|
||||
embedding_info = self._parse_embedding(str(ckpt_path))
|
||||
|
||||
if embedding_info is None:
|
||||
# We've already put out an error message about the bad embedding in _parse_embedding, so just return.
|
||||
return
|
||||
elif (
|
||||
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
|
||||
!= embedding_info["token_dim"]
|
||||
):
|
||||
print(
|
||||
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
|
||||
)
|
||||
return
|
||||
|
||||
# Resolve the situation in which an earlier embedding has claimed the same
|
||||
# trigger string. We replace the trigger with '<source_file>', as we used to.
|
||||
trigger_str = embedding_info["name"]
|
||||
sourcefile = (
|
||||
f"{ckpt_path.parent.name}/{ckpt_path.name}"
|
||||
if ckpt_path.name == "learned_embeds.bin"
|
||||
else ckpt_path.name
|
||||
)
|
||||
|
||||
if trigger_str in self.trigger_to_sourcefile:
|
||||
replacement_trigger_str = (
|
||||
f"<{ckpt_path.parent.name}>"
|
||||
# Resolve the situation in which an earlier embedding has claimed the same
|
||||
# trigger string. We replace the trigger with '<source_file>', as we used to.
|
||||
trigger_str = embedding_info.name
|
||||
sourcefile = (
|
||||
f"{ckpt_path.parent.name}/{ckpt_path.name}"
|
||||
if ckpt_path.name == "learned_embeds.bin"
|
||||
else f"<{ckpt_path.stem}>"
|
||||
else ckpt_path.name
|
||||
)
|
||||
print(
|
||||
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||
)
|
||||
trigger_str = replacement_trigger_str
|
||||
|
||||
try:
|
||||
self._add_textual_inversion(
|
||||
trigger_str,
|
||||
embedding_info["embedding"],
|
||||
defer_injecting_tokens=defer_injecting_tokens,
|
||||
)
|
||||
# remember which source file claims this trigger
|
||||
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
||||
if trigger_str in self.trigger_to_sourcefile:
|
||||
replacement_trigger_str = (
|
||||
f"<{ckpt_path.parent.name}>"
|
||||
if ckpt_path.name == "learned_embeds.bin"
|
||||
else f"<{ckpt_path.stem}>"
|
||||
)
|
||||
print(
|
||||
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||
)
|
||||
trigger_str = replacement_trigger_str
|
||||
|
||||
except ValueError as e:
|
||||
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
|
||||
print(f" | The error was {str(e)}")
|
||||
try:
|
||||
self._add_textual_inversion(
|
||||
trigger_str,
|
||||
embedding_info.embedding,
|
||||
defer_injecting_tokens=defer_injecting_tokens,
|
||||
)
|
||||
# remember which source file claims this trigger
|
||||
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
||||
|
||||
except ValueError as e:
|
||||
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
|
||||
print(f" | The error was {str(e)}")
|
||||
|
||||
def _add_textual_inversion(
|
||||
self, trigger_str, embedding, defer_injecting_tokens=False
|
||||
@@ -309,111 +299,130 @@ class TextualInversionManager(BaseTextualInversionManager):
|
||||
|
||||
return token_id
|
||||
|
||||
def _parse_embedding(self, embedding_file: str):
|
||||
file_type = embedding_file.split(".")[-1]
|
||||
if file_type == "pt":
|
||||
return self._parse_embedding_pt(embedding_file)
|
||||
elif file_type == "bin":
|
||||
return self._parse_embedding_bin(embedding_file)
|
||||
|
||||
def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
|
||||
suffix = Path(embedding_file).suffix
|
||||
try:
|
||||
if suffix in [".pt",".ckpt",".bin"]:
|
||||
scan_result = scan_file_path(embedding_file)
|
||||
if scan_result.infected_files > 0:
|
||||
print(
|
||||
f" ** Security Issues Found in Model: {scan_result.issues_count}"
|
||||
)
|
||||
print(" ** For your safety, InvokeAI will not load this embed.")
|
||||
return list()
|
||||
ckpt = torch.load(embedding_file,map_location="cpu")
|
||||
else:
|
||||
ckpt = safetensors.torch.load_file(embedding_file)
|
||||
except Exception as e:
|
||||
print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
||||
return list()
|
||||
|
||||
# try to figure out what kind of embedding file it is and parse accordingly
|
||||
keys = list(ckpt.keys())
|
||||
if all(x in keys for x in ['string_to_token','string_to_param','name','step']):
|
||||
return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt
|
||||
|
||||
elif all(x in keys for x in ['string_to_token','string_to_param']):
|
||||
return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt
|
||||
|
||||
elif 'emb_params' in keys:
|
||||
return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors
|
||||
|
||||
else:
|
||||
print(f"** Notice: unrecognized embedding file format: {embedding_file}")
|
||||
return None
|
||||
return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file
|
||||
|
||||
def _parse_embedding_pt(self, embedding_file):
|
||||
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
|
||||
embedding_info = {}
|
||||
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||
basename = Path(file_path).stem
|
||||
print(f' | Loading v1 embedding file: {basename}')
|
||||
|
||||
# Check if valid embedding file
|
||||
if "string_to_token" and "string_to_param" in embedding_ckpt:
|
||||
# Catch variants that do not have the expected keys or values.
|
||||
try:
|
||||
embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
|
||||
os.path.splitext(embedding_file)[0]
|
||||
)
|
||||
embeddings = list()
|
||||
token_counter = -1
|
||||
for token,embedding in embedding_ckpt["string_to_param"].items():
|
||||
if token_counter < 0:
|
||||
trigger = embedding_ckpt["name"]
|
||||
elif token_counter == 0:
|
||||
trigger = f'<basename>'
|
||||
else:
|
||||
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
|
||||
token_counter += 1
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = trigger,
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = embedding.size()[0],
|
||||
token_dim = embedding.size()[1],
|
||||
trained_steps = embedding_ckpt["step"],
|
||||
trained_model_name = embedding_ckpt["sd_checkpoint_name"],
|
||||
trained_model_checksum = embedding_ckpt["sd_checkpoint"]
|
||||
)
|
||||
embeddings.append(embedding_info)
|
||||
return embeddings
|
||||
|
||||
# Check num of embeddings and warn user only the first will be used
|
||||
embedding_info["num_of_embeddings"] = len(
|
||||
embedding_ckpt["string_to_token"]
|
||||
)
|
||||
if embedding_info["num_of_embeddings"] > 1:
|
||||
print(">> More than 1 embedding found. Will use the first one")
|
||||
|
||||
embedding = list(embedding_ckpt["string_to_param"].values())[0]
|
||||
except (AttributeError, KeyError):
|
||||
return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
|
||||
|
||||
embedding_info["embedding"] = embedding
|
||||
embedding_info["num_vectors_per_token"] = embedding.size()[0]
|
||||
embedding_info["token_dim"] = embedding.size()[1]
|
||||
|
||||
try:
|
||||
embedding_info["trained_steps"] = embedding_ckpt["step"]
|
||||
embedding_info["trained_model_name"] = embedding_ckpt[
|
||||
"sd_checkpoint_name"
|
||||
]
|
||||
embedding_info["trained_model_checksum"] = embedding_ckpt[
|
||||
"sd_checkpoint"
|
||||
]
|
||||
except AttributeError:
|
||||
print(">> No Training Details Found. Passing ...")
|
||||
|
||||
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
|
||||
# They are actually .bin files
|
||||
elif len(embedding_ckpt.keys()) == 1:
|
||||
embedding_info = self._parse_embedding_bin(embedding_file)
|
||||
|
||||
else:
|
||||
print(">> Invalid embedding format")
|
||||
embedding_info = None
|
||||
|
||||
return embedding_info
|
||||
|
||||
def _parse_embedding_bin(self, embedding_file):
|
||||
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
|
||||
embedding_info = {}
|
||||
|
||||
if list(embedding_ckpt.keys()) == 0:
|
||||
print(">> Invalid concepts file")
|
||||
embedding_info = None
|
||||
else:
|
||||
for token in list(embedding_ckpt.keys()):
|
||||
embedding_info["name"] = (
|
||||
token
|
||||
or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
|
||||
)
|
||||
embedding_info["embedding"] = embedding_ckpt[token]
|
||||
embedding_info[
|
||||
"num_vectors_per_token"
|
||||
] = 1 # All Concepts seem to default to 1
|
||||
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
|
||||
|
||||
return embedding_info
|
||||
|
||||
def _handle_broken_pt_variants(
|
||||
self, embedding_ckpt: dict, embedding_file: str
|
||||
) -> dict:
|
||||
def _parse_embedding_v2 (
|
||||
self, embedding_ckpt: dict, file_path: str
|
||||
) -> List[EmbeddingInfo]:
|
||||
"""
|
||||
This handles the broken .pt file variants. We only know of one at present.
|
||||
This handles embedding .pt file variant #2.
|
||||
"""
|
||||
embedding_info = {}
|
||||
basename = Path(file_path).stem
|
||||
print(f' | Loading v2 embedding file: {basename}')
|
||||
embeddings = list()
|
||||
|
||||
if isinstance(
|
||||
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
|
||||
):
|
||||
for token in list(embedding_ckpt["string_to_token"].keys()):
|
||||
embedding_info["name"] = (
|
||||
token
|
||||
if token != "*"
|
||||
else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
|
||||
token_counter = 0
|
||||
for token,embedding in embedding_ckpt["string_to_param"].items():
|
||||
trigger = token if token != '*' \
|
||||
else f'<{basename}>' if token_counter == 0 \
|
||||
else f'<{basename}-{int(token_counter:=token_counter+1)}>'
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = trigger,
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = embedding.size()[0],
|
||||
token_dim = embedding.size()[1],
|
||||
)
|
||||
embedding_info["embedding"] = embedding_ckpt[
|
||||
"string_to_param"
|
||||
].state_dict()[token]
|
||||
embedding_info["num_vectors_per_token"] = embedding_info[
|
||||
"embedding"
|
||||
].shape[0]
|
||||
embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
|
||||
embeddings.append(embedding_info)
|
||||
else:
|
||||
print(">> Invalid embedding format")
|
||||
embedding_info = None
|
||||
print(f" ** {basename}: Unrecognized embedding format")
|
||||
|
||||
return embedding_info
|
||||
return embeddings
|
||||
|
||||
def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||
"""
|
||||
Parse 'version 3' of the .pt textual inversion embedding files.
|
||||
"""
|
||||
basename = Path(file_path).stem
|
||||
print(f' | Loading v3 embedding file: {basename}')
|
||||
embedding = embedding_ckpt['emb_params']
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = f'<{basename}>',
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = embedding.size()[0],
|
||||
token_dim = embedding.size()[1],
|
||||
)
|
||||
return [embedding_info]
|
||||
|
||||
def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
|
||||
"""
|
||||
Parse 'version 4' of the textual inversion embedding files. This one
|
||||
is usually associated with .bin files trained by HuggingFace diffusers.
|
||||
"""
|
||||
basename = Path(filepath).stem
|
||||
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
||||
|
||||
print(f' | Loading v4 embedding file: {short_path}')
|
||||
|
||||
embeddings = list()
|
||||
if list(embedding_ckpt.keys()) == 0:
|
||||
print(f" ** Invalid embeddings file: {short_path}")
|
||||
else:
|
||||
for token,embedding in embedding_ckpt.items():
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = token or f"<{basename}>",
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = 1, # All Concepts seem to default to 1
|
||||
token_dim = embedding.size()[0],
|
||||
)
|
||||
embeddings.append(embedding_info)
|
||||
return embeddings
|
||||
|
||||
@@ -158,14 +158,9 @@ def main():
|
||||
report_model_error(opt, e)
|
||||
|
||||
# try to autoconvert new models
|
||||
if path := opt.autoimport:
|
||||
gen.model_manager.heuristic_import(
|
||||
str(path), convert=False, commit_to_conf=opt.conf
|
||||
)
|
||||
|
||||
if path := opt.autoconvert:
|
||||
gen.model_manager.heuristic_import(
|
||||
str(path), convert=True, commit_to_conf=opt.conf
|
||||
str(path), commit_to_conf=opt.conf
|
||||
)
|
||||
|
||||
# web server loops forever
|
||||
@@ -581,6 +576,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
||||
|
||||
elif command.startswith("!replay"):
|
||||
file_path = command.replace("!replay", "", 1).strip()
|
||||
file_path = os.path.join(opt.outdir, file_path)
|
||||
if infile is None and os.path.isfile(file_path):
|
||||
infile = open(file_path, "r", encoding="utf-8")
|
||||
completer.add_history(command)
|
||||
|
||||
@@ -199,17 +199,6 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
relx=4,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.convert_models = self.add_widget_intelligent(
|
||||
npyscreen.TitleSelectOne,
|
||||
name="== CONVERT IMPORTED MODELS INTO DIFFUSERS==",
|
||||
values=["Keep original format", "Convert to diffusers"],
|
||||
value=0,
|
||||
begin_entry_at=4,
|
||||
max_height=4,
|
||||
hidden=True, # will appear when imported models box is edited
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.cancel = self.add_widget_intelligent(
|
||||
npyscreen.ButtonPress,
|
||||
name="CANCEL",
|
||||
@@ -244,8 +233,6 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
self.show_directory_fields.addVisibleWhenSelected(i)
|
||||
|
||||
self.show_directory_fields.when_value_edited = self._clear_scan_directory
|
||||
self.import_model_paths.when_value_edited = self._show_hide_convert
|
||||
self.autoload_directory.when_value_edited = self._show_hide_convert
|
||||
|
||||
def resize(self):
|
||||
super().resize()
|
||||
@@ -256,13 +243,6 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
if not self.show_directory_fields.value:
|
||||
self.autoload_directory.value = ""
|
||||
|
||||
def _show_hide_convert(self):
|
||||
model_paths = self.import_model_paths.value or ""
|
||||
autoload_directory = self.autoload_directory.value or ""
|
||||
self.convert_models.hidden = (
|
||||
len(model_paths) == 0 and len(autoload_directory) == 0
|
||||
)
|
||||
|
||||
def _get_starter_model_labels(self) -> List[str]:
|
||||
window_width, window_height = get_terminal_size()
|
||||
label_width = 25
|
||||
@@ -322,7 +302,6 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
.scan_directory: Path to a directory of models to scan and import
|
||||
.autoscan_on_startup: True if invokeai should scan and import at startup time
|
||||
.import_model_paths: list of URLs, repo_ids and file paths to import
|
||||
.convert_to_diffusers: if True, convert legacy checkpoints into diffusers
|
||||
"""
|
||||
# we're using a global here rather than storing the result in the parentapp
|
||||
# due to some bug in npyscreen that is causing attributes to be lost
|
||||
@@ -359,7 +338,6 @@ class addModelsForm(npyscreen.FormMultiPage):
|
||||
|
||||
# URLs and the like
|
||||
selections.import_model_paths = self.import_model_paths.value.split()
|
||||
selections.convert_to_diffusers = self.convert_models.value[0] == 1
|
||||
|
||||
|
||||
class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
@@ -372,7 +350,6 @@ class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
scan_directory=None,
|
||||
autoscan_on_startup=None,
|
||||
import_model_paths=None,
|
||||
convert_to_diffusers=None,
|
||||
)
|
||||
|
||||
def onStart(self):
|
||||
@@ -393,7 +370,6 @@ def process_and_execute(opt: Namespace, selections: Namespace):
|
||||
directory_to_scan = selections.scan_directory
|
||||
scan_at_startup = selections.autoscan_on_startup
|
||||
potential_models_to_install = selections.import_model_paths
|
||||
convert_to_diffusers = selections.convert_to_diffusers
|
||||
|
||||
install_requested_models(
|
||||
install_initial_models=models_to_install,
|
||||
@@ -401,7 +377,6 @@ def process_and_execute(opt: Namespace, selections: Namespace):
|
||||
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
|
||||
external_models=potential_models_to_install,
|
||||
scan_at_startup=scan_at_startup,
|
||||
convert_to_diffusers=convert_to_diffusers,
|
||||
precision="float32"
|
||||
if opt.full_precision
|
||||
else choose_precision(torch.device(choose_torch_device())),
|
||||
|
||||
188
invokeai/frontend/web/dist/assets/App-843b023b.js
vendored
188
invokeai/frontend/web/dist/assets/App-843b023b.js
vendored
File diff suppressed because one or more lines are too long
188
invokeai/frontend/web/dist/assets/App-af7ef809.js
vendored
Normal file
188
invokeai/frontend/web/dist/assets/App-af7ef809.js
vendored
Normal file
File diff suppressed because one or more lines are too long
@@ -1,4 +1,4 @@
|
||||
import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-f7f41e1f.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-eaf47ae3.js";var Or=`
|
||||
import{j as y,cO as Ie,r as _,cP as bt,q as Lr,cQ as o,cR as b,cS as v,cT as S,cU as Vr,cV as ut,cW as vt,cN as ft,cX as mt,n as gt,cY as ht,E as pt}from"./index-e53e8108.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-5cde7d31.js";var Or=`
|
||||
:root {
|
||||
--chakra-vh: 100vh;
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
2
invokeai/frontend/web/dist/index.html
vendored
2
invokeai/frontend/web/dist/index.html
vendored
@@ -12,7 +12,7 @@
|
||||
margin: 0;
|
||||
}
|
||||
</style>
|
||||
<script type="module" crossorigin src="./assets/index-f7f41e1f.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index-e53e8108.js"></script>
|
||||
<link rel="stylesheet" href="./assets/index-5483945c.css">
|
||||
</head>
|
||||
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/ar.json
vendored
1
invokeai/frontend/web/dist/locales/ar.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "داكن",
|
||||
"lightTheme": "فاتح",
|
||||
"greenTheme": "أخضر",
|
||||
"text2img": "نص إلى صورة",
|
||||
"img2img": "صورة إلى صورة",
|
||||
"unifiedCanvas": "لوحة موحدة",
|
||||
"nodes": "عقد",
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/de.json
vendored
1
invokeai/frontend/web/dist/locales/de.json
vendored
@@ -7,7 +7,6 @@
|
||||
"darkTheme": "Dunkel",
|
||||
"lightTheme": "Hell",
|
||||
"greenTheme": "Grün",
|
||||
"text2img": "Text zu Bild",
|
||||
"img2img": "Bild zu Bild",
|
||||
"nodes": "Knoten",
|
||||
"langGerman": "Deutsch",
|
||||
|
||||
4
invokeai/frontend/web/dist/locales/en.json
vendored
4
invokeai/frontend/web/dist/locales/en.json
vendored
@@ -505,7 +505,9 @@
|
||||
"info": "Info",
|
||||
"deleteImage": "Delete Image",
|
||||
"initialImage": "Initial Image",
|
||||
"showOptionsPanel": "Show Options Panel"
|
||||
"showOptionsPanel": "Show Options Panel",
|
||||
"hidePreview": "Hide Preview",
|
||||
"showPreview": "Show Preview"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Models",
|
||||
|
||||
12
invokeai/frontend/web/dist/locales/es.json
vendored
12
invokeai/frontend/web/dist/locales/es.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Oscuro",
|
||||
"lightTheme": "Claro",
|
||||
"greenTheme": "Verde",
|
||||
"text2img": "Texto a Imagen",
|
||||
"img2img": "Imagen a Imagen",
|
||||
"unifiedCanvas": "Lienzo Unificado",
|
||||
"nodes": "Nodos",
|
||||
@@ -70,7 +69,11 @@
|
||||
"langHebrew": "Hebreo",
|
||||
"pinOptionsPanel": "Pin del panel de opciones",
|
||||
"loading": "Cargando",
|
||||
"loadingInvokeAI": "Cargando invocar a la IA"
|
||||
"loadingInvokeAI": "Cargando invocar a la IA",
|
||||
"postprocessing": "Tratamiento posterior",
|
||||
"txt2img": "De texto a imagen",
|
||||
"accept": "Aceptar",
|
||||
"cancel": "Cancelar"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generaciones",
|
||||
@@ -404,7 +407,8 @@
|
||||
"none": "ninguno",
|
||||
"pickModelType": "Elige el tipo de modelo",
|
||||
"v2_768": "v2 (768px)",
|
||||
"addDifference": "Añadir una diferencia"
|
||||
"addDifference": "Añadir una diferencia",
|
||||
"scanForModels": "Buscar modelos"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Imágenes",
|
||||
@@ -574,7 +578,7 @@
|
||||
"autoSaveToGallery": "Guardar automáticamente en galería",
|
||||
"saveBoxRegionOnly": "Guardar solo región dentro de la caja",
|
||||
"limitStrokesToBox": "Limitar trazos a la caja",
|
||||
"showCanvasDebugInfo": "Mostrar información de depuración de lienzo",
|
||||
"showCanvasDebugInfo": "Mostrar la información adicional del lienzo",
|
||||
"clearCanvasHistory": "Limpiar historial de lienzo",
|
||||
"clearHistory": "Limpiar historial",
|
||||
"clearCanvasHistoryMessage": "Limpiar el historial de lienzo también restablece completamente el lienzo unificado. Esto incluye todo el historial de deshacer/rehacer, las imágenes en el área de preparación y la capa base del lienzo.",
|
||||
|
||||
25
invokeai/frontend/web/dist/locales/fr.json
vendored
25
invokeai/frontend/web/dist/locales/fr.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Sombre",
|
||||
"lightTheme": "Clair",
|
||||
"greenTheme": "Vert",
|
||||
"text2img": "Texte en image",
|
||||
"img2img": "Image en image",
|
||||
"unifiedCanvas": "Canvas unifié",
|
||||
"nodes": "Nœuds",
|
||||
@@ -47,7 +46,19 @@
|
||||
"statusLoadingModel": "Chargement du modèle",
|
||||
"statusModelChanged": "Modèle changé",
|
||||
"discordLabel": "Discord",
|
||||
"githubLabel": "Github"
|
||||
"githubLabel": "Github",
|
||||
"accept": "Accepter",
|
||||
"statusMergingModels": "Mélange des modèles",
|
||||
"loadingInvokeAI": "Chargement de Invoke AI",
|
||||
"cancel": "Annuler",
|
||||
"langEnglish": "Anglais",
|
||||
"statusConvertingModel": "Conversion du modèle",
|
||||
"statusModelConverted": "Modèle converti",
|
||||
"loading": "Chargement",
|
||||
"pinOptionsPanel": "Épingler la page d'options",
|
||||
"statusMergedModels": "Modèles mélangés",
|
||||
"txt2img": "Texte vers image",
|
||||
"postprocessing": "Post-Traitement"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Générations",
|
||||
@@ -518,5 +529,15 @@
|
||||
"betaDarkenOutside": "Assombrir à l'extérieur",
|
||||
"betaLimitToBox": "Limiter à la boîte",
|
||||
"betaPreserveMasked": "Conserver masqué"
|
||||
},
|
||||
"accessibility": {
|
||||
"uploadImage": "Charger une image",
|
||||
"reset": "Réinitialiser",
|
||||
"nextImage": "Image suivante",
|
||||
"previousImage": "Image précédente",
|
||||
"useThisParameter": "Utiliser ce paramètre",
|
||||
"zoomIn": "Zoom avant",
|
||||
"zoomOut": "Zoom arrière",
|
||||
"showOptionsPanel": "Montrer la page d'options"
|
||||
}
|
||||
}
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/he.json
vendored
1
invokeai/frontend/web/dist/locales/he.json
vendored
@@ -125,7 +125,6 @@
|
||||
"langSimplifiedChinese": "סינית",
|
||||
"langUkranian": "אוקראינית",
|
||||
"langSpanish": "ספרדית",
|
||||
"text2img": "טקסט לתמונה",
|
||||
"img2img": "תמונה לתמונה",
|
||||
"unifiedCanvas": "קנבס מאוחד",
|
||||
"nodes": "צמתים",
|
||||
|
||||
14
invokeai/frontend/web/dist/locales/it.json
vendored
14
invokeai/frontend/web/dist/locales/it.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Scuro",
|
||||
"lightTheme": "Chiaro",
|
||||
"greenTheme": "Verde",
|
||||
"text2img": "Testo a Immagine",
|
||||
"img2img": "Immagine a Immagine",
|
||||
"unifiedCanvas": "Tela unificata",
|
||||
"nodes": "Nodi",
|
||||
@@ -70,7 +69,11 @@
|
||||
"loading": "Caricamento in corso",
|
||||
"oceanTheme": "Oceano",
|
||||
"langHebrew": "Ebraico",
|
||||
"loadingInvokeAI": "Caricamento Invoke AI"
|
||||
"loadingInvokeAI": "Caricamento Invoke AI",
|
||||
"postprocessing": "Post Elaborazione",
|
||||
"txt2img": "Testo a Immagine",
|
||||
"accept": "Accetta",
|
||||
"cancel": "Annulla"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generazioni",
|
||||
@@ -404,7 +407,8 @@
|
||||
"v2_768": "v2 (768px)",
|
||||
"none": "niente",
|
||||
"addDifference": "Aggiungi differenza",
|
||||
"pickModelType": "Scegli il tipo di modello"
|
||||
"pickModelType": "Scegli il tipo di modello",
|
||||
"scanForModels": "Cerca modelli"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Immagini",
|
||||
@@ -574,7 +578,7 @@
|
||||
"autoSaveToGallery": "Salvataggio automatico nella Galleria",
|
||||
"saveBoxRegionOnly": "Salva solo l'area di selezione",
|
||||
"limitStrokesToBox": "Limita i tratti all'area di selezione",
|
||||
"showCanvasDebugInfo": "Mostra informazioni di debug della Tela",
|
||||
"showCanvasDebugInfo": "Mostra ulteriori informazioni sulla Tela",
|
||||
"clearCanvasHistory": "Cancella cronologia Tela",
|
||||
"clearHistory": "Cancella la cronologia",
|
||||
"clearCanvasHistoryMessage": "La cancellazione della cronologia della tela lascia intatta la tela corrente, ma cancella in modo irreversibile la cronologia degli annullamenti e dei ripristini.",
|
||||
@@ -612,7 +616,7 @@
|
||||
"copyMetadataJson": "Copia i metadati JSON",
|
||||
"exitViewer": "Esci dal visualizzatore",
|
||||
"zoomIn": "Zoom avanti",
|
||||
"zoomOut": "Zoom Indietro",
|
||||
"zoomOut": "Zoom indietro",
|
||||
"rotateCounterClockwise": "Ruotare in senso antiorario",
|
||||
"rotateClockwise": "Ruotare in senso orario",
|
||||
"flipHorizontally": "Capovolgi orizzontalmente",
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/ko.json
vendored
1
invokeai/frontend/web/dist/locales/ko.json
vendored
@@ -11,7 +11,6 @@
|
||||
"langArabic": "العربية",
|
||||
"langEnglish": "English",
|
||||
"langDutch": "Nederlands",
|
||||
"text2img": "텍스트->이미지",
|
||||
"unifiedCanvas": "통합 캔버스",
|
||||
"langFrench": "Français",
|
||||
"langGerman": "Deutsch",
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/nl.json
vendored
1
invokeai/frontend/web/dist/locales/nl.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Donker",
|
||||
"lightTheme": "Licht",
|
||||
"greenTheme": "Groen",
|
||||
"text2img": "Tekst naar afbeelding",
|
||||
"img2img": "Afbeelding naar afbeelding",
|
||||
"unifiedCanvas": "Centraal canvas",
|
||||
"nodes": "Knooppunten",
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/pl.json
vendored
1
invokeai/frontend/web/dist/locales/pl.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Ciemny",
|
||||
"lightTheme": "Jasny",
|
||||
"greenTheme": "Zielony",
|
||||
"text2img": "Tekst na obraz",
|
||||
"img2img": "Obraz na obraz",
|
||||
"unifiedCanvas": "Tryb uniwersalny",
|
||||
"nodes": "Węzły",
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/pt.json
vendored
1
invokeai/frontend/web/dist/locales/pt.json
vendored
@@ -20,7 +20,6 @@
|
||||
"langSpanish": "Espanhol",
|
||||
"langRussian": "Русский",
|
||||
"langUkranian": "Украї́нська",
|
||||
"text2img": "Texto para Imagem",
|
||||
"img2img": "Imagem para Imagem",
|
||||
"unifiedCanvas": "Tela Unificada",
|
||||
"nodes": "Nós",
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Noite",
|
||||
"lightTheme": "Dia",
|
||||
"greenTheme": "Verde",
|
||||
"text2img": "Texto Para Imagem",
|
||||
"img2img": "Imagem Para Imagem",
|
||||
"unifiedCanvas": "Tela Unificada",
|
||||
"nodes": "Nódulos",
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/ru.json
vendored
1
invokeai/frontend/web/dist/locales/ru.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Темная",
|
||||
"lightTheme": "Светлая",
|
||||
"greenTheme": "Зеленая",
|
||||
"text2img": "Изображение из текста (text2img)",
|
||||
"img2img": "Изображение в изображение (img2img)",
|
||||
"unifiedCanvas": "Универсальный холст",
|
||||
"nodes": "Ноды",
|
||||
|
||||
1
invokeai/frontend/web/dist/locales/uk.json
vendored
1
invokeai/frontend/web/dist/locales/uk.json
vendored
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "Темна",
|
||||
"lightTheme": "Світла",
|
||||
"greenTheme": "Зелена",
|
||||
"text2img": "Зображення із тексту (text2img)",
|
||||
"img2img": "Зображення із зображення (img2img)",
|
||||
"unifiedCanvas": "Універсальне полотно",
|
||||
"nodes": "Вузли",
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
"darkTheme": "暗色",
|
||||
"lightTheme": "亮色",
|
||||
"greenTheme": "绿色",
|
||||
"text2img": "文字到图像",
|
||||
"img2img": "图像到图像",
|
||||
"unifiedCanvas": "统一画布",
|
||||
"nodes": "节点",
|
||||
|
||||
@@ -33,7 +33,6 @@
|
||||
"langBrPortuguese": "巴西葡萄牙語",
|
||||
"langRussian": "俄語",
|
||||
"langSpanish": "西班牙語",
|
||||
"text2img": "文字到圖像",
|
||||
"unifiedCanvas": "統一畫布"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -505,7 +505,9 @@
|
||||
"info": "Info",
|
||||
"deleteImage": "Delete Image",
|
||||
"initialImage": "Initial Image",
|
||||
"showOptionsPanel": "Show Options Panel"
|
||||
"showOptionsPanel": "Show Options Panel",
|
||||
"hidePreview": "Hide Preview",
|
||||
"showPreview": "Show Preview"
|
||||
},
|
||||
"settings": {
|
||||
"models": "Models",
|
||||
|
||||
@@ -27,6 +27,7 @@ import {
|
||||
} from 'features/ui/store/uiSelectors';
|
||||
import {
|
||||
setActiveTab,
|
||||
setShouldHidePreview,
|
||||
setShouldShowImageDetails,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
@@ -38,6 +39,8 @@ import {
|
||||
FaDownload,
|
||||
FaExpand,
|
||||
FaExpandArrowsAlt,
|
||||
FaEye,
|
||||
FaEyeSlash,
|
||||
FaGrinStars,
|
||||
FaQuoteRight,
|
||||
FaSeedling,
|
||||
@@ -75,7 +78,7 @@ const currentImageButtonsSelector = createSelector(
|
||||
|
||||
const { isLightboxOpen } = lightbox;
|
||||
|
||||
const { shouldShowImageDetails } = ui;
|
||||
const { shouldShowImageDetails, shouldHidePreview } = ui;
|
||||
|
||||
const { intermediateImage, currentImage } = gallery;
|
||||
|
||||
@@ -91,6 +94,7 @@ const currentImageButtonsSelector = createSelector(
|
||||
shouldShowImageDetails,
|
||||
activeTabName,
|
||||
isLightboxOpen,
|
||||
shouldHidePreview,
|
||||
};
|
||||
},
|
||||
{
|
||||
@@ -120,6 +124,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
currentImage,
|
||||
isLightboxOpen,
|
||||
activeTabName,
|
||||
shouldHidePreview,
|
||||
} = useAppSelector(currentImageButtonsSelector);
|
||||
|
||||
const toast = useToast();
|
||||
@@ -188,6 +193,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
[currentImage]
|
||||
);
|
||||
|
||||
const handlePreviewVisibility = () => {
|
||||
dispatch(setShouldHidePreview(!shouldHidePreview));
|
||||
};
|
||||
|
||||
const handleClickUseAllParameters = () => {
|
||||
if (!currentImage) return;
|
||||
currentImage.metadata && dispatch(setAllParameters(currentImage.metadata));
|
||||
@@ -455,6 +464,21 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
</Link>
|
||||
</Flex>
|
||||
</IAIPopover>
|
||||
<IAIIconButton
|
||||
icon={shouldHidePreview ? <FaEyeSlash /> : <FaEye />}
|
||||
tooltip={
|
||||
!shouldHidePreview
|
||||
? t('parameters.hidePreview')
|
||||
: t('parameters.showPreview')
|
||||
}
|
||||
aria-label={
|
||||
!shouldHidePreview
|
||||
? t('parameters.hidePreview')
|
||||
: t('parameters.showPreview')
|
||||
}
|
||||
isChecked={shouldHidePreview}
|
||||
onClick={handlePreviewVisibility}
|
||||
/>
|
||||
<IAIIconButton
|
||||
icon={<FaExpand />}
|
||||
tooltip={
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
import { Flex } from '@chakra-ui/react';
|
||||
import { FaEyeSlash } from 'react-icons/fa';
|
||||
|
||||
const CurrentImageHidden = () => {
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
position: 'absolute',
|
||||
color: 'base.400',
|
||||
}}
|
||||
>
|
||||
<FaEyeSlash size={'30vh'} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default CurrentImageHidden;
|
||||
@@ -10,17 +10,19 @@ import { gallerySelector } from '../store/gallerySelectors';
|
||||
import CurrentImageFallback from './CurrentImageFallback';
|
||||
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||
import CurrentImageHidden from './CurrentImageHidden';
|
||||
|
||||
export const imagesSelector = createSelector(
|
||||
[gallerySelector, uiSelector],
|
||||
(gallery: GalleryState, ui) => {
|
||||
const { currentImage, intermediateImage } = gallery;
|
||||
const { shouldShowImageDetails } = ui;
|
||||
const { shouldShowImageDetails, shouldHidePreview } = ui;
|
||||
|
||||
return {
|
||||
imageToDisplay: intermediateImage ? intermediateImage : currentImage,
|
||||
isIntermediate: Boolean(intermediateImage),
|
||||
shouldShowImageDetails,
|
||||
shouldHidePreview,
|
||||
};
|
||||
},
|
||||
{
|
||||
@@ -31,8 +33,12 @@ export const imagesSelector = createSelector(
|
||||
);
|
||||
|
||||
export default function CurrentImagePreview() {
|
||||
const { shouldShowImageDetails, imageToDisplay, isIntermediate } =
|
||||
useAppSelector(imagesSelector);
|
||||
const {
|
||||
shouldShowImageDetails,
|
||||
imageToDisplay,
|
||||
isIntermediate,
|
||||
shouldHidePreview,
|
||||
} = useAppSelector(imagesSelector);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@@ -46,10 +52,16 @@ export default function CurrentImagePreview() {
|
||||
>
|
||||
{imageToDisplay && (
|
||||
<Image
|
||||
src={imageToDisplay.url}
|
||||
src={shouldHidePreview ? undefined : imageToDisplay.url}
|
||||
width={imageToDisplay.width}
|
||||
height={imageToDisplay.height}
|
||||
fallback={!isIntermediate ? <CurrentImageFallback /> : undefined}
|
||||
fallback={
|
||||
shouldHidePreview ? (
|
||||
<CurrentImageHidden />
|
||||
) : !isIntermediate ? (
|
||||
<CurrentImageFallback />
|
||||
) : undefined
|
||||
}
|
||||
sx={{
|
||||
objectFit: 'contain',
|
||||
maxWidth: '100%',
|
||||
|
||||
@@ -2,6 +2,7 @@ import { Flex, Image, Text, useToast } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import ImageUploaderIconButton from 'common/components/ImageUploaderIconButton';
|
||||
import CurrentImageHidden from 'features/gallery/components/CurrentImageHidden';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -10,6 +11,8 @@ export default function InitImagePreview() {
|
||||
(state: RootState) => state.generation.initialImage
|
||||
);
|
||||
|
||||
const { shouldHidePreview } = useAppSelector((state: RootState) => state.ui);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -66,8 +69,13 @@ export default function InitImagePreview() {
|
||||
position: 'absolute',
|
||||
}}
|
||||
src={
|
||||
typeof initialImage === 'string' ? initialImage : initialImage.url
|
||||
shouldHidePreview
|
||||
? undefined
|
||||
: typeof initialImage === 'string'
|
||||
? initialImage
|
||||
: initialImage.url
|
||||
}
|
||||
fallback={<CurrentImageHidden />}
|
||||
onError={alertMissingInitImage}
|
||||
/>
|
||||
</Flex>
|
||||
|
||||
@@ -16,6 +16,7 @@ const initialtabsState: UIState = {
|
||||
addNewModelUIOption: null,
|
||||
shouldPinGallery: true,
|
||||
shouldShowGallery: true,
|
||||
shouldHidePreview: false,
|
||||
};
|
||||
|
||||
const initialState: UIState = initialtabsState;
|
||||
@@ -53,6 +54,9 @@ export const uiSlice = createSlice({
|
||||
setShouldUseCanvasBetaLayout: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldUseCanvasBetaLayout = action.payload;
|
||||
},
|
||||
setShouldHidePreview: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldHidePreview = action.payload;
|
||||
},
|
||||
setShouldShowExistingModelsInSearch: (
|
||||
state,
|
||||
action: PayloadAction<boolean>
|
||||
@@ -106,6 +110,7 @@ export const {
|
||||
setShouldShowExistingModelsInSearch,
|
||||
setShouldUseSliders,
|
||||
setAddNewModelUIOption,
|
||||
setShouldHidePreview,
|
||||
setShouldPinGallery,
|
||||
setShouldShowGallery,
|
||||
togglePanels,
|
||||
|
||||
@@ -11,6 +11,7 @@ export interface UIState {
|
||||
shouldShowExistingModelsInSearch: boolean;
|
||||
shouldUseSliders: boolean;
|
||||
addNewModelUIOption: AddNewModelType;
|
||||
shouldHidePreview: boolean;
|
||||
shouldPinGallery: boolean;
|
||||
shouldShowGallery: boolean;
|
||||
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -38,9 +38,9 @@ dependencies = [
|
||||
"albumentations",
|
||||
"click",
|
||||
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"compel==1.0.4",
|
||||
"compel==1.0.5",
|
||||
"datasets",
|
||||
"diffusers[torch]~=0.14",
|
||||
"diffusers[torch]==0.14",
|
||||
"dnspython==2.2.1",
|
||||
"einops",
|
||||
"eventlet",
|
||||
@@ -63,6 +63,7 @@ dependencies = [
|
||||
"prompt-toolkit",
|
||||
"pypatchmatch",
|
||||
"pyreadline3",
|
||||
"python-multipart==0.0.6",
|
||||
"pytorch-lightning==1.7.7",
|
||||
"realesrgan",
|
||||
"requests==2.28.2",
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from .test_invoker import create_edge
|
||||
from .test_nodes import ImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation, PromptCollectionTestInvocation
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from invokeai.app.invocations.collections import RangeInvocation
|
||||
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
|
||||
from invokeai.app.services.processor import DefaultInvocationProcessor
|
||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
||||
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -21,13 +23,17 @@ def simple_graph():
|
||||
def mock_services():
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
return InvocationServices(
|
||||
model_manager = None,
|
||||
events = None,
|
||||
images = None,
|
||||
model_manager = None, # type: ignore
|
||||
events = None, # type: ignore
|
||||
images = None, # type: ignore
|
||||
latents = None, # type: ignore
|
||||
queue = MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=sqlite_memory, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor(),
|
||||
restoration = None,
|
||||
restoration = None, # type: ignore
|
||||
)
|
||||
|
||||
def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]:
|
||||
@@ -73,31 +79,23 @@ def test_graph_is_not_complete(simple_graph, mock_services):
|
||||
|
||||
def test_graph_state_expands_iterator(mock_services):
|
||||
graph = Graph()
|
||||
test_prompts = ["Banana sushi", "Cat sushi"]
|
||||
graph.add_node(PromptCollectionTestInvocation(id = "1", collection = list(test_prompts)))
|
||||
graph.add_node(IterateInvocation(id = "2"))
|
||||
graph.add_node(ImageTestInvocation(id = "3"))
|
||||
graph.add_edge(create_edge("1", "collection", "2", "collection"))
|
||||
graph.add_edge(create_edge("2", "item", "3", "prompt"))
|
||||
graph.add_node(RangeInvocation(id = "0", start = 0, stop = 3, step = 1))
|
||||
graph.add_node(IterateInvocation(id = "1"))
|
||||
graph.add_node(MultiplyInvocation(id = "2", b = 10))
|
||||
graph.add_node(AddInvocation(id = "3", b = 1))
|
||||
graph.add_edge(create_edge("0", "collection", "1", "collection"))
|
||||
graph.add_edge(create_edge("1", "item", "2", "a"))
|
||||
graph.add_edge(create_edge("2", "a", "3", "a"))
|
||||
|
||||
g = GraphExecutionState(graph = graph)
|
||||
n1 = invoke_next(g, mock_services)
|
||||
n2 = invoke_next(g, mock_services)
|
||||
n3 = invoke_next(g, mock_services)
|
||||
n4 = invoke_next(g, mock_services)
|
||||
n5 = invoke_next(g, mock_services)
|
||||
while not g.is_complete():
|
||||
invoke_next(g, mock_services)
|
||||
|
||||
prepared_add_nodes = g.source_prepared_mapping['3']
|
||||
results = set([g.results[n].a for n in prepared_add_nodes])
|
||||
expected = set([1, 11, 21])
|
||||
assert results == expected
|
||||
|
||||
assert g.prepared_source_mapping[n1[0].id] == "1"
|
||||
assert g.prepared_source_mapping[n2[0].id] == "2"
|
||||
assert g.prepared_source_mapping[n3[0].id] == "2"
|
||||
assert g.prepared_source_mapping[n4[0].id] == "3"
|
||||
assert g.prepared_source_mapping[n5[0].id] == "3"
|
||||
|
||||
assert isinstance(n4[0], ImageTestInvocation)
|
||||
assert isinstance(n5[0], ImageTestInvocation)
|
||||
|
||||
prompts = [n4[0].prompt, n5[0].prompt]
|
||||
assert sorted(prompts) == sorted(test_prompts)
|
||||
|
||||
def test_graph_state_collects(mock_services):
|
||||
graph = Graph()
|
||||
|
||||
@@ -5,7 +5,7 @@ from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
||||
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, LibraryGraph, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation, GraphExecutionState
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -24,10 +24,14 @@ def mock_services() -> InvocationServices:
|
||||
model_manager = None, # type: ignore
|
||||
events = TestEventService(),
|
||||
images = None, # type: ignore
|
||||
latents = None, # type: ignore
|
||||
queue = MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=sqlite_memory, table_name="graphs"
|
||||
),
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'),
|
||||
processor = DefaultInvocationProcessor(),
|
||||
restoration = None,
|
||||
restoration = None, # type: ignore
|
||||
)
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from invokeai.app.invocations.image import *
|
||||
|
||||
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
|
||||
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
||||
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
|
||||
from invokeai.app.invocations.upscale import UpscaleInvocation
|
||||
from invokeai.app.invocations.image import *
|
||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||
from invokeai.app.invocations.params import ParamIntInvocation
|
||||
from invokeai.app.services.default_graphs import create_text_to_image
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -417,6 +419,66 @@ def test_graph_gets_subgraph_node():
|
||||
assert result.id == '1'
|
||||
assert result == n1_1
|
||||
|
||||
|
||||
def test_graph_expands_subgraph():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id = "1")
|
||||
n1.graph = Graph()
|
||||
|
||||
n1_1 = AddInvocation(id = "1", a = 1, b = 2)
|
||||
n1_2 = SubtractInvocation(id = "2", b = 3)
|
||||
n1.graph.add_node(n1_1)
|
||||
n1.graph.add_node(n1_2)
|
||||
n1.graph.add_edge(create_edge("1","a","2","a"))
|
||||
|
||||
g.add_node(n1)
|
||||
|
||||
n2 = AddInvocation(id = "2", b = 5)
|
||||
g.add_node(n2)
|
||||
g.add_edge(create_edge("1.2","a","2","a"))
|
||||
|
||||
dg = g.nx_graph_flat()
|
||||
assert set(dg.nodes) == set(['1.1', '1.2', '2'])
|
||||
assert set(dg.edges) == set([('1.1', '1.2'), ('1.2', '2')])
|
||||
|
||||
|
||||
def test_graph_subgraph_t2i():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id = "1")
|
||||
|
||||
# Get text to image default graph
|
||||
lg = create_text_to_image()
|
||||
n1.graph = lg.graph
|
||||
|
||||
g.add_node(n1)
|
||||
|
||||
n2 = ParamIntInvocation(id = "2", a = 512)
|
||||
n3 = ParamIntInvocation(id = "3", a = 256)
|
||||
|
||||
g.add_node(n2)
|
||||
g.add_node(n3)
|
||||
|
||||
g.add_edge(create_edge("2","a","1.width","a"))
|
||||
g.add_edge(create_edge("3","a","1.height","a"))
|
||||
|
||||
n4 = ShowImageInvocation(id = "4")
|
||||
g.add_node(n4)
|
||||
g.add_edge(create_edge("1.5","image","4","image"))
|
||||
|
||||
# Validate
|
||||
dg = g.nx_graph_flat()
|
||||
assert set(dg.nodes) == set(['1.width', '1.height', '1.3', '1.4', '1.5', '2', '3', '4'])
|
||||
expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges]
|
||||
expected_edges.extend([
|
||||
('2','1.width'),
|
||||
('3','1.height'),
|
||||
('1.5','4')
|
||||
])
|
||||
print(expected_edges)
|
||||
print(list(dg.edges))
|
||||
assert set(dg.edges) == set(expected_edges)
|
||||
|
||||
|
||||
def test_graph_fails_to_get_missing_subgraph_node():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id = "1")
|
||||
|
||||
Reference in New Issue
Block a user