Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-22 16:28:01 -05:00)

Comparing branches dev/ci/upd... and dev/pytorc... (3 commits: 3c50448ccf, 5dec5b6f51, e158ad8534)
.github/CODEOWNERS (vendored, 14 lines changed)

@@ -1,16 +1,16 @@
 # continuous integration
-/.github/workflows/ @lstein @blessedcoolant
+/.github/workflows/ @mauwii @lstein @blessedcoolant

 # documentation
-/docs/ @lstein @tildebyte @blessedcoolant
-/mkdocs.yml @lstein @blessedcoolant
+/docs/ @lstein @mauwii @tildebyte @blessedcoolant
+/mkdocs.yml @lstein @mauwii @blessedcoolant

 # nodes
 /invokeai/app/ @Kyle0654 @blessedcoolant

 # installation and configuration
-/pyproject.toml @lstein @blessedcoolant
-/docker/ @lstein @blessedcoolant
+/pyproject.toml @mauwii @lstein @blessedcoolant
+/docker/ @mauwii @lstein @blessedcoolant
 /scripts/ @ebr @lstein
 /installer/ @lstein @ebr
 /invokeai/assets @lstein @ebr
@@ -22,11 +22,11 @@
 /invokeai/backend @blessedcoolant @psychedelicious @lstein

 # generation, model management, postprocessing
-/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
+/invokeai/backend @keturn @damian0815 @lstein @blessedcoolant @jpphoto

 # front ends
 /invokeai/frontend/CLI @lstein
-/invokeai/frontend/install @lstein @ebr
+/invokeai/frontend/install @lstein @ebr @mauwii
 /invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
 /invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
 /invokeai/frontend/web @psychedelicious @blessedcoolant
.github/stale.yaml (vendored, 19 lines changed; the file exists only on the base side of this compare)

@@ -1,19 +0,0 @@
-# Number of days of inactivity before an issue becomes stale
-daysUntilStale: 28
-# Number of days of inactivity before a stale issue is closed
-daysUntilClose: 14
-# Issues with these labels will never be considered stale
-exemptLabels:
-  - pinned
-  - security
-# Label to use when marking an issue as stale
-staleLabel: stale
-# Comment to post when marking an issue as stale. Set to `false` to disable
-markComment: >
-  This issue has been automatically marked as stale because it has not had
-  recent activity. It will be closed if no further activity occurs. Please
-  update the ticket if this is still a problem on the latest release.
-# Comment to post when closing a stale issue. Set to `false` to disable
-closeComment: >
-  Due to inactivity, this issue has been automatically closed. If this is
-  still a problem on the latest release, please recreate the issue.
.github/workflows/mkdocs-material.yml (vendored, 15 lines changed)

@@ -2,7 +2,8 @@ name: mkdocs-material
 on:
   push:
     branches:
-      - 'refs/heads/v2.3'
+      - 'main'
+      - 'development'

 permissions:
   contents: write
@@ -11,10 +12,6 @@ jobs:
   mkdocs-material:
     if: github.event.pull_request.draft == false
     runs-on: ubuntu-latest
-    env:
-      REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
-      REPO_NAME: '${{ github.repository }}'
-      SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
    steps:
      - name: checkout sources
        uses: actions/checkout@v3
@@ -25,15 +22,11 @@ jobs:
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'
-          cache: pip
-          cache-dependency-path: pyproject.toml

      - name: install requirements
-        env:
-          PIP_USE_PEP517: 1
        run: |
          python -m \
-            pip install ".[docs]"
+            pip install -r docs/requirements-mkdocs.txt

      - name: confirm buildability
        run: |
@@ -43,7 +36,7 @@ jobs:
          --verbose

      - name: deploy to gh-pages
-        if: ${{ github.ref == 'refs/heads/v2.3' }}
+        if: ${{ github.ref == 'refs/heads/main' }}
        run: |
          python -m \
            mkdocs gh-deploy \
.gitignore (vendored, 2 lines changed)

@@ -9,8 +9,6 @@ models/ldm/stable-diffusion-v1/model.ckpt
 configs/models.user.yaml
 config/models.user.yml
 invokeai.init
-.version
-.last_model

 # ignore the Anaconda/Miniconda installer used while building Docker image
 anaconda.sh
(next file; the header row was not captured in this mirror. The hunks below appear to be from the top-level README.)

@@ -33,8 +33,6 @@
 </div>

-_**Note: The UI is not fully functional on `main`. If you need a stable UI based on `main`, use the `pre-nodes` tag while we [migrate to a new backend](https://github.com/invoke-ai/InvokeAI/discussions/3246).**_
-
 InvokeAI is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. InvokeAI offers an industry leading Web Interface, interactive Command Line Interface, and also serves as the foundation for multiple commercial products.

 **Quick links**: [[How to Install](https://invoke-ai.github.io/InvokeAI/#installation)] [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>] [<a href="https://invoke-ai.github.io/InvokeAI/">Documentation and Tutorials</a>] [<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a href="https://github.com/invoke-ai/InvokeAI/issues">Bug Reports</a>] [<a href="https://github.com/invoke-ai/InvokeAI/discussions">Discussion, Ideas & Q&A</a>]
@@ -86,7 +84,7 @@ installing lots of models.
 6. Wait while the installer does its thing. After installing the software,
    the installer will launch a script that lets you configure InvokeAI and
-   select a set of starting image generation models.
+   select a set of starting image generaiton models.

 7. Find the folder that InvokeAI was installed into (it is not the
    same as the unpacked zip file directory!) The default location of this
@@ -150,11 +148,6 @@ not supported.
    pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
    ```

-   _For non-GPU systems:_
-   ```terminal
-   pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
-   ```
-
   _For Macintoshes, either Intel or M1/M2:_

   ```sh
(next file; header not captured. The hunks below appear to be from the Invocations documentation.)

@@ -1,18 +1,10 @@
 # Invocations

-Invocations represent a single operation, its inputs, and its outputs. These
-operations and their outputs can be chained together to generate and modify
-images.
+Invocations represent a single operation, its inputs, and its outputs. These operations and their outputs can be chained together to generate and modify images.

 ## Creating a new invocation

-To create a new invocation, either find the appropriate module file in
-`/ldm/invoke/app/invocations` to add your invocation to, or create a new one in
-that folder. All invocations in that folder will be discovered and made
-available to the CLI and API automatically. Invocations make use of
-[typing](https://docs.python.org/3/library/typing.html) and
-[pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration
-into the CLI and API.
+To create a new invocation, either find the appropriate module file in `/ldm/invoke/app/invocations` to add your invocation to, or create a new one in that folder. All invocations in that folder will be discovered and made available to the CLI and API automatically. Invocations make use of [typing](https://docs.python.org/3/library/typing.html) and [pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration into the CLI and API.

 An invocation looks like this:
@@ -49,54 +41,34 @@ class UpscaleInvocation(BaseInvocation):
 Each portion is important to implement correctly.

 ### Class definition and type

 ```py
 class UpscaleInvocation(BaseInvocation):
     """Upscales an image."""
     type: Literal['upscale'] = 'upscale'
 ```

-All invocations must derive from `BaseInvocation`. They should have a docstring
-that declares what they do in a single, short line. They should also have a
-`type` with a type hint that's `Literal["command_name"]`, where `command_name`
-is what the user will type on the CLI or use in the API to create this
-invocation. The `command_name` must be unique. The `type` must be assigned to
-the value of the literal in the type hint.
+All invocations must derive from `BaseInvocation`. They should have a docstring that declares what they do in a single, short line. They should also have a `type` with a type hint that's `Literal["command_name"]`, where `command_name` is what the user will type on the CLI or use in the API to create this invocation. The `command_name` must be unique. The `type` must be assigned to the value of the literal in the type hint.

 ### Inputs

 ```py
     # Inputs
     image: Union[ImageField,None] = Field(description="The input image")
     strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
     level: Literal[2,4] = Field(default=2, description="The upscale level")
 ```

-Inputs consist of three parts: a name, a type hint, and a `Field` with default,
-description, and validation information. For example:
+Inputs consist of three parts: a name, a type hint, and a `Field` with default, description, and validation information. For example:

 | Part | Value | Description |
-| --------- | ------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------- |
+| ---- | ----- | ----------- |
 | Name | `strength` | This field is referred to as `strength` |
 | Type Hint | `float` | This field must be of type `float` |
 | Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |

-Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this
-field to be parsed with `None` as a value, which enables linking to previous
-invocations. All fields should either provide a default value or allow `None` as
-a value, so that they can be overwritten with a linked output from another
-invocation.
+Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this field to be parsed with `None` as a value, which enables linking to previous invocations. All fields should either provide a default value or allow `None` as a value, so that they can be overwritten with a linked output from another invocation.

-The special type `ImageField` is also used here. All images are passed as
-`ImageField`, which protects them from pydantic validation errors (since images
-only ever come from links).
+The special type `ImageField` is also used here. All images are passed as `ImageField`, which protects them from pydantic validation errors (since images only ever come from links).

-Finally, note that for all linking, the `type` of the linked fields must match.
-If the `name` also matches, then the field can be **automatically linked** to a
-previous invocation by name and matching.
+Finally, note that for all linking, the `type` of the linked fields must match. If the `name` also matches, then the field can be **automatically linked** to a previous invocation by name and matching.

 ### Invoke Function

 ```py
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.services.images.get(self.image.image_type, self.image.image_name)
@@ -116,22 +88,13 @@ previous invocation by name and matching.
         image = ImageField(image_type = image_type, image_name = image_name)
     )
 ```

-The `invoke` function is the last portion of an invocation. It is provided an
-`InvocationContext` which contains services to perform work as well as a
-`session_id` for use as needed. It should return a class with output values that
-derives from `BaseInvocationOutput`.
+The `invoke` function is the last portion of an invocation. It is provided an `InvocationContext` which contains services to perform work as well as a `session_id` for use as needed. It should return a class with output values that derives from `BaseInvocationOutput`.

-Before being called, the invocation will have all of its fields set from
-defaults, inputs, and finally links (overriding in that order).
+Before being called, the invocation will have all of its fields set from defaults, inputs, and finally links (overriding in that order).

-Assume that this invocation may be running simultaneously with other
-invocations, may be running on another machine, or in other interesting
-scenarios. If you need functionality, please provide it as a service in the
-`InvocationServices` class, and make sure it can be overridden.
+Assume that this invocation may be running simultaneously with other invocations, may be running on another machine, or in other interesting scenarios. If you need functionality, please provide it as a service in the `InvocationServices` class, and make sure it can be overridden.

 ### Outputs

 ```py
 class ImageOutput(BaseInvocationOutput):
     """Base class for invocations that output an image"""
@@ -139,64 +102,4 @@ class ImageOutput(BaseInvocationOutput):

     image: ImageField = Field(default=None, description="The output image")
 ```

-Output classes look like an invocation class without the invoke method. Prefer
-to use an existing output class if available, and prefer to name inputs the same
-as outputs when possible, to promote automatic invocation linking.
+Output classes look like an invocation class without the invoke method. Prefer to use an existing output class if available, and prefer to name inputs the same as outputs when possible, to promote automatic invocation linking.
-## Schema Generation
-
-Invocation, output and related classes are used to generate an OpenAPI schema.
-
-### Required Properties
-
-The schema generation treat all properties with default values as optional. This
-makes sense internally, but when when using these classes via the generated
-schema, we end up with e.g. the `ImageOutput` class having its `image` property
-marked as optional.
-
-We know that this property will always be present, so the additional logic
-needed to always check if the property exists adds a lot of extraneous cruft.
-
-To fix this, we can leverage `pydantic`'s
-[schema customisation](https://docs.pydantic.dev/usage/schema/#schema-customization)
-to mark properties that we know will always be present as required.
-
-Here's that `ImageOutput` class, without the needed schema customisation:
-
-```python
-class ImageOutput(BaseInvocationOutput):
-    """Base class for invocations that output an image"""
-
-    type: Literal["image"] = "image"
-    image: ImageField = Field(default=None, description="The output image")
-```
-
-The generated OpenAPI schema, and all clients/types generated from it, will have
-the `type` and `image` properties marked as optional, even though we know they
-will always have a value by the time we can interact with them via the API.
-
-Here's the same class, but with the schema customisation added:
-
-```python
-class ImageOutput(BaseInvocationOutput):
-    """Base class for invocations that output an image"""
-
-    type: Literal["image"] = "image"
-    image: ImageField = Field(default=None, description="The output image")
-
-    class Config:
-        schema_extra = {
-            'required': [
-                'type',
-                'image',
-            ]
-        }
-```
-
-The resultant schema (and any API client or types generated from it) will now
-have see `type` as string literal `"image"` and `image` as an `ImageField`
-object.
-
-See this `pydantic` issue for discussion on this solution:
-<https://github.com/pydantic/pydantic/discussions/4577>
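A quick way to see the effect of the schema customisation described in the removed section above is to inspect the generated schema directly. The snippet below is a standalone sketch, not repository code: it assumes pydantic v1 (the version that documentation is written against) and stubs out `ImageField` with a trivial model.

```python
from typing import Literal, Optional

from pydantic import BaseModel, Field


class ImageField(BaseModel):
    """Stand-in for the real ImageField, just for this demonstration."""
    image_type: str = "results"
    image_name: str = ""


class ImageOutput(BaseModel):
    """Base class for invocations that output an image"""

    type: Literal["image"] = "image"
    image: Optional[ImageField] = Field(default=None, description="The output image")

    class Config:
        # Merged into the generated JSON/OpenAPI schema by pydantic v1.
        schema_extra = {"required": ["type", "image"]}


# Without the Config block, "required" would be absent (every property optional);
# with it, both properties are marked required in the emitted schema.
print(ImageOutput.schema()["required"])  # -> ['type', 'image']
```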
(next file; header not captured. The hunk context refers to the NSFW checker documentation.)

@@ -32,7 +32,7 @@ turned on and off on the command line using `--nsfw_checker` and
 At installation time, InvokeAI will ask whether the checker should be
 activated by default (neither argument given on the command line). The
 response is stored in the InvokeAI initialization file (usually
-`invokeai.init` in your home directory). You can change the default at any
+`.invokeai` in your home directory). You can change the default at any
 time by opening this file in a text editor and commenting or
 uncommenting the line `--nsfw_checker`.
(next file; header not captured. The hunk is from the Ubuntu installation instructions.)

@@ -89,7 +89,7 @@ experimental versions later.
 sudo apt update
 sudo apt install -y software-properties-common
 sudo add-apt-repository -y ppa:deadsnakes/ppa
-sudo apt install -y python3.10 python3-pip python3.10-venv
+sudo apt install python3.10 python3-pip python3.10-venv
 sudo update-alternatives --install /usr/local/bin/python python /usr/bin/python3.10 3
 ```
(next file; header not captured. The hunk is from the starter-models table in the model documentation.)

@@ -50,7 +50,7 @@ subset that are currently installed are found in
 |stable-diffusion-1.5|runwayml/stable-diffusion-v1-5|Stable Diffusion version 1.5 diffusers model (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-v1-5 |
 |sd-inpainting-1.5|runwayml/stable-diffusion-inpainting|RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-inpainting |
 |stable-diffusion-2.1|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
-|sd-inpainting-2.0|stabilityai/stable-diffusion-2-inpainting|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-inpainting |
+|sd-inpainting-2.0|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
 |analog-diffusion-1.0|wavymulder/Analog-Diffusion|An SD-1.5 model trained on diverse analog photographs (2.13 GB)|https://huggingface.co/wavymulder/Analog-Diffusion |
 |deliberate-1.0|XpucT/Deliberate|Versatile model that produces detailed images up to 768px (4.27 GB)|https://huggingface.co/XpucT/Deliberate |
 |d&d-diffusion-1.0|0xJustin/Dungeons-and-Diffusion|Dungeons & Dragons characters (2.13 GB)|https://huggingface.co/0xJustin/Dungeons-and-Diffusion |
(next file; header not captured. The hunk is from the installer's `get_torch_source()` helper.)

@@ -461,8 +461,7 @@ def get_torch_source() -> (Union[str, None],str):
         url = "https://download.pytorch.org/whl/cpu"

     if device == 'cuda':
-        url = 'https://download.pytorch.org/whl/cu117'
-        optional_modules = '[xformers]'
+        url = 'https://download.pytorch.org/whl/cu118'

     # in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
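Only a fragment of `get_torch_source()` is visible above. A simplified sketch of the device-to-index-URL selection it performs, with the device probe stubbed out, might look like the following. Treat it as illustrative only: the function name, the device detection, and the exact return contract are assumptions, while the index URLs and the `[xformers]` extra are taken from the hunks in this compare.

```python
from typing import Optional, Tuple


def pick_torch_source(device: str) -> Tuple[Optional[str], str]:
    """Return (extra index URL or None, optional pip extras) for installing torch.

    `device` is assumed to be one of 'cuda', 'rocm', 'cpu', or 'mps';
    the real installer detects this from the host system.
    """
    url: Optional[str] = None
    optional_modules = ""

    if device == "rocm":
        url = "https://download.pytorch.org/whl/rocm5.4.2"
    elif device == "cpu":
        url = "https://download.pytorch.org/whl/cpu"
    elif device == "cuda":
        url = "https://download.pytorch.org/whl/cu117"  # the compare flips this to cu118
        optional_modules = "[xformers]"

    # In all other cases (e.g. macOS/MPS), torch wheels come straight from PyPI.
    return url, optional_modules


print(pick_torch_source("cuda"))
```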
(next file; header not captured. The hunks below are from the API dependencies module.)

@@ -1,23 +1,20 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

 import os
+from argparse import Namespace

-import invokeai.backend.util.logging as logger
-from typing import types

-from ..services.default_graphs import create_system_graphs
 from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage

 from ...backend import Globals
 from ..services.model_manager_initializer import get_model_manager
 from ..services.restoration_services import RestorationServices
-from ..services.graph import GraphExecutionState, LibraryGraph
+from ..services.graph import GraphExecutionState
 from ..services.image_storage import DiskImageStorage
 from ..services.invocation_queue import MemoryInvocationQueue
 from ..services.invocation_services import InvocationServices
 from ..services.invoker import Invoker
 from ..services.processor import DefaultInvocationProcessor
 from ..services.sqlite import SqliteItemStorage
-from ..services.metadata import PngMetadataService
 from .events import FastAPIEventService
@@ -43,16 +40,15 @@ class ApiDependencies:
     invoker: Invoker = None

     @staticmethod
-    def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
+    def initialize(config, event_handler_id: int):
         Globals.try_patchmatch = config.patchmatch
         Globals.always_use_cpu = config.always_use_cpu
         Globals.internet_available = config.internet_available and check_internet()
         Globals.disable_xformers = not config.xformers
         Globals.ckpt_convert = config.ckpt_convert

-        # TO DO: Use the config to select the logger rather than use the default
-        # invokeai logging module
-        logger.info(f"Internet connectivity is {Globals.internet_available}")
+        # TODO: Use a logger
+        print(f">> Internet connectivity is {Globals.internet_available}")

         events = FastAPIEventService(event_handler_id)
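The base side of the hunk above routes its startup message through a project logging module instead of `print`. As a generic illustration of the substitution, using the standard library rather than the `invokeai.backend.util.logging` module referenced in the diff:

```python
import logging

logging.basicConfig(level=logging.INFO, format=">> %(message)s")
logger = logging.getLogger("InvokeAI")

internet_available = True  # stand-in for Globals.internet_available

# Head side of the compare: plain print with a ">>" prefix.
print(f">> Internet connectivity is {internet_available}")

# Base side of the compare: the same message, but through a logger that can be
# filtered, redirected, or swapped out by configuration.
logger.info(f"Internet connectivity is {internet_available}")
```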
@@ -62,33 +58,24 @@ class ApiDependencies:

         latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))

-        metadata = PngMetadataService()
-
-        images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
+        images = DiskImageStorage(f'{output_folder}/images')

         # TODO: build a file/path manager?
         db_location = os.path.join(output_folder, "invokeai.db")

         services = InvocationServices(
-            model_manager=get_model_manager(config,logger),
+            model_manager=get_model_manager(config),
             events=events,
-            logger=logger,
             latents=latents,
             images=images,
-            metadata=metadata,
             queue=MemoryInvocationQueue(),
-            graph_library=SqliteItemStorage[LibraryGraph](
-                filename=db_location, table_name="graphs"
-            ),
             graph_execution_manager=SqliteItemStorage[GraphExecutionState](
                 filename=db_location, table_name="graph_executions"
             ),
             processor=DefaultInvocationProcessor(),
-            restoration=RestorationServices(config,logger),
+            restoration=RestorationServices(config),
         )

-        create_system_graphs(services.graph_library)
-
         ApiDependencies.invoker = Invoker(services)

     @staticmethod
(next file; header not captured. The hunk is from the FastAPI event service.)

@@ -45,7 +45,7 @@ class FastAPIEventService(EventServiceBase):
                 )

             except Empty:
-                await asyncio.sleep(0.1)
+                await asyncio.sleep(0.001)
                 pass

             except asyncio.CancelledError as e:
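For context on why that sleep value matters: the service polls a thread-safe queue from a coroutine and yields briefly whenever the queue is empty. A self-contained sketch of that pattern (not the actual service code; queue contents and timings are illustrative) is shown below. The compare changes the pause from 0.1 s to 0.001 s, trading a little idle CPU for much lower event latency.

```python
import asyncio
import queue


async def dispatch_events(event_queue: "queue.Queue[dict]") -> None:
    """Poll a synchronous queue from async code, yielding while it is empty."""
    while True:
        try:
            event = event_queue.get(block=False)
            print("dispatching", event)  # stand-in for the real emit/dispatch
        except queue.Empty:
            # The value this compare touches: how long to yield when idle.
            await asyncio.sleep(0.001)
        except asyncio.CancelledError:
            break


if __name__ == "__main__":
    q: "queue.Queue[dict]" = queue.Queue()
    q.put({"event": "progress", "step": 1})

    async def main() -> None:
        task = asyncio.create_task(dispatch_events(q))
        await asyncio.sleep(0.05)  # let the dispatcher drain the queue
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass

    asyncio.run(main())
```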
(next file; header not captured. An API response-models module that exists only on the base side of this compare.)

@@ -1,40 +0,0 @@
-from typing import Optional
-from pydantic import BaseModel, Field
-
-from invokeai.app.models.image import ImageType
-from invokeai.app.services.metadata import InvokeAIMetadata
-
-
-class ImageResponseMetadata(BaseModel):
-    """An image's metadata. Used only in HTTP responses."""
-
-    created: int = Field(description="The creation timestamp of the image")
-    width: int = Field(description="The width of the image in pixels")
-    height: int = Field(description="The height of the image in pixels")
-    invokeai: Optional[InvokeAIMetadata] = Field(
-        description="The image's InvokeAI-specific metadata"
-    )
-
-
-class ImageResponse(BaseModel):
-    """The response type for images"""
-
-    image_type: ImageType = Field(description="The type of the image")
-    image_name: str = Field(description="The name of the image")
-    image_url: str = Field(description="The url of the image")
-    thumbnail_url: str = Field(description="The url of the image's thumbnail")
-    metadata: ImageResponseMetadata = Field(description="The image's metadata")
-
-
-class ProgressImage(BaseModel):
-    """The progress image sent intermittently during processing"""
-
-    width: int = Field(description="The effective width of the image in pixels")
-    height: int = Field(description="The effective height of the image in pixels")
-    dataURL: str = Field(description="The image data as a b64 data URL")
-
-
-class SavedImage(BaseModel):
-    image_name: str = Field(description="The name of the saved image")
-    thumbnail_name: str = Field(description="The name of the saved thumbnail")
-    created: int = Field(description="The created timestamp of the saved image")
(next file; header not captured. The hunks below are from the images API router.)

@@ -1,20 +1,11 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
-import io
-from datetime import datetime, timezone
-import json
-import os
-from typing import Any
-import uuid
+from datetime import datetime, timezone

-from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
+from fastapi import Path, Request, UploadFile
 from fastapi.responses import FileResponse, Response
 from fastapi.routing import APIRouter
 from PIL import Image
-from invokeai.app.api.models.images import (
-    ImageResponse,
-    ImageResponseMetadata,
-)
-from invokeai.app.services.item_storage import PaginatedResults

 from ...services.image_storage import ImageType
 from ..dependencies import ApiDependencies
@@ -26,123 +17,50 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
 async def get_image(
     image_type: ImageType = Path(description="The type of image to get"),
     image_name: str = Path(description="The name of the image to get"),
-) -> FileResponse:
-    """Gets an image"""
-
-    path = ApiDependencies.invoker.services.images.get_path(
-        image_type=image_type, image_name=image_name
-    )
-
-    if ApiDependencies.invoker.services.images.validate_path(path):
-        return FileResponse(path)
-    else:
-        raise HTTPException(status_code=404)
-
-
-@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
-async def delete_image(
-    image_type: ImageType = Path(description="The type of image to delete"),
-    image_name: str = Path(description="The name of the image to delete"),
-) -> None:
-    """Deletes an image and its thumbnail"""
-
-    ApiDependencies.invoker.services.images.delete(
-        image_type=image_type, image_name=image_name
-    )
-
-
-@images_router.get(
-    "/{thumbnail_type}/thumbnails/{thumbnail_name}", operation_id="get_thumbnail"
-)
+):
+    """Gets a result"""
+    # TODO: This is not really secure at all. At least make sure only output results are served
+    filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
+    return FileResponse(filename)
+
+@images_router.get("/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail")
 async def get_thumbnail(
-    thumbnail_type: ImageType = Path(description="The type of thumbnail to get"),
-    thumbnail_name: str = Path(description="The name of the thumbnail to get"),
-) -> FileResponse | Response:
+    image_type: ImageType = Path(description="The type of image to get"),
+    image_name: str = Path(description="The name of the image to get"),
+):
     """Gets a thumbnail"""
-
-    path = ApiDependencies.invoker.services.images.get_path(
-        image_type=thumbnail_type, image_name=thumbnail_name, is_thumbnail=True
-    )
-
-    if ApiDependencies.invoker.services.images.validate_path(path):
-        return FileResponse(path)
-    else:
-        raise HTTPException(status_code=404)
+    # TODO: This is not really secure at all. At least make sure only output results are served
+    filename = ApiDependencies.invoker.services.images.get_path(image_type, 'thumbnails/' + image_name)
+    return FileResponse(filename)


 @images_router.post(
     "/uploads/",
     operation_id="upload_image",
     responses={
-        201: {
-            "description": "The image was uploaded successfully",
-            "model": ImageResponse,
-        },
-        415: {"description": "Image upload failed"},
-    },
-    status_code=201,
+        201: {"description": "The image was uploaded successfully"},
+        404: {"description": "Session not found"},
+    },
 )
-async def upload_image(
-    file: UploadFile, request: Request, response: Response
-) -> ImageResponse:
+async def upload_image(file: UploadFile, request: Request):
     if not file.content_type.startswith("image"):
-        raise HTTPException(status_code=415, detail="Not an image")
+        return Response(status_code=415)

     contents = await file.read()

     try:
-        img = Image.open(io.BytesIO(contents))
+        im = Image.open(contents)
     except:
         # Error opening the image
-        raise HTTPException(status_code=415, detail="Failed to read image")
+        return Response(status_code=415)

-    filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
+    filename = f"{str(int(datetime.now(timezone.utc).timestamp()))}.png"
+    ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)

-    saved_image = ApiDependencies.invoker.services.images.save(
-        ImageType.UPLOAD, filename, img
-    )
-
-    invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
-
-    image_url = ApiDependencies.invoker.services.images.get_uri(
-        ImageType.UPLOAD, saved_image.image_name
-    )
-
-    thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
-        ImageType.UPLOAD, saved_image.image_name, True
-    )
-
-    res = ImageResponse(
-        image_type=ImageType.UPLOAD,
-        image_name=saved_image.image_name,
-        image_url=image_url,
-        thumbnail_url=thumbnail_url,
-        metadata=ImageResponseMetadata(
-            created=saved_image.created,
-            width=img.width,
-            height=img.height,
-            invokeai=invokeai_metadata,
-        ),
-    )
-
-    response.status_code = 201
-    response.headers["Location"] = image_url
-
-    return res
-
-
-@images_router.get(
-    "/",
-    operation_id="list_images",
-    responses={200: {"model": PaginatedResults[ImageResponse]}},
-)
-async def list_images(
-    image_type: ImageType = Query(
-        default=ImageType.RESULT, description="The type of images to get"
-    ),
-    page: int = Query(default=0, description="The page of images to get"),
-    per_page: int = Query(default=10, description="The number of images per page"),
-) -> PaginatedResults[ImageResponse]:
-    """Gets a list of images"""
-    result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
-    return result
+    return Response(
+        status_code=201,
+        headers={
+            "Location": request.url_for(
+                "get_image", image_type=ImageType.UPLOAD, image_name=filename
+            )
+        },
+    )
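The substantive pattern difference in this file is that the base side validates the resolved file path and raises `HTTPException` on failure, while the head side serves whatever path the storage layer produced and flags it with a "not really secure" TODO. A minimal, self-contained sketch of the validated pattern, independent of InvokeAI's services (the route shape and `IMAGE_ROOT` are assumptions for illustration), looks like this:

```python
from pathlib import Path as FilePath

from fastapi import FastAPI, HTTPException, Path
from fastapi.responses import FileResponse

app = FastAPI()
IMAGE_ROOT = FilePath("outputs/images").resolve()


@app.get("/v1/images/{image_name}")
async def get_image(
    image_name: str = Path(description="The name of the image to get"),
) -> FileResponse:
    """Serve an image, but only if the resolved path stays inside IMAGE_ROOT."""
    candidate = (IMAGE_ROOT / image_name).resolve()

    # Equivalent in spirit to images.validate_path(): refuse anything that
    # escapes the storage directory or does not exist, with a clean 404.
    if IMAGE_ROOT not in candidate.parents or not candidate.is_file():
        raise HTTPException(status_code=404)

    return FileResponse(candidate)
```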
(next file; header not captured. The hunks below are from the models API router.)

@@ -1,12 +1,10 @@
-# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
+# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

-import shutil
-import asyncio
 from typing import Annotated, Any, List, Literal, Optional, Union

-from fastapi.routing import APIRouter, HTTPException
+from fastapi.routing import APIRouter
 from pydantic import BaseModel, Field, parse_obj_as
-from pathlib import Path
 from ..dependencies import ApiDependencies

 models_router = APIRouter(prefix="/v1/models", tags=["models"])
@@ -17,9 +15,11 @@ class VaeRepo(BaseModel):
     path: Optional[str] = Field(description="The path to the VAE")
     subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")


 class ModelInfo(BaseModel):
     description: Optional[str] = Field(description="A description of the model")


 class CkptModelInfo(ModelInfo):
     format: Literal['ckpt'] = 'ckpt'
@@ -29,6 +29,7 @@ class CkptModelInfo(ModelInfo):
     width: Optional[int] = Field(description="The width of the model")
     height: Optional[int] = Field(description="The height of the model")


 class DiffusersModelInfo(ModelInfo):
     format: Literal['diffusers'] = 'diffusers'
@@ -36,29 +37,12 @@ class DiffusersModelInfo(ModelInfo):
     repo_id: Optional[str] = Field(description="The repo ID to use for this model")
     path: Optional[str] = Field(description="The path to the model")

-class CreateModelRequest(BaseModel):
-    name: str = Field(description="The name of the model")
-    info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
-
-class CreateModelResponse(BaseModel):
-    name: str = Field(description="The name of the new model")
-    info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
-    status: str = Field(description="The status of the API response")
-
-class ConversionRequest(BaseModel):
-    name: str = Field(description="The name of the new model")
-    info: CkptModelInfo = Field(description="The converted model info")
-    save_location: str = Field(description="The path to save the converted model weights")
-
-
-class ConvertedModelResponse(BaseModel):
-    name: str = Field(description="The name of the new model")
-    info: DiffusersModelInfo = Field(description="The converted model info")

 class ModelsList(BaseModel):
     models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]


 @models_router.get(
     "/",
     operation_id="list_models",
@@ -70,60 +54,106 @@ async def list_models() -> ModelsList:
     models = parse_obj_as(ModelsList, { "models": models_raw })
     return models

-@models_router.post(
-    "/",
-    operation_id="update_model",
-    responses={200: {"status": "success"}},
-)
-async def update_model(
-    model_request: CreateModelRequest
-) -> CreateModelResponse:
-    """ Add Model """
-    model_request_info = model_request.info
-    info_dict = model_request_info.dict()
-    model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
-
-    ApiDependencies.invoker.services.model_manager.add_model(
-        model_name=model_request.name,
-        model_attributes=info_dict,
-        clobber=True,
-    )
-
-    return model_response
-
-
-@models_router.delete(
-    "/{model_name}",
-    operation_id="del_model",
-    responses={
-        204: {
-            "description": "Model deleted successfully"
-        },
-        404: {
-            "description": "Model not found"
-        }
-    },
-)
-async def delete_model(model_name: str) -> None:
-    """Delete Model"""
-    model_names = ApiDependencies.invoker.services.model_manager.model_names()
-    logger = ApiDependencies.invoker.services.logger
-    model_exists = model_name in model_names
-
-    # check if model exists
-    logger.info(f"Checking for model {model_name}...")
-
-    if model_exists:
-        logger.info(f"Deleting Model: {model_name}")
-        ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
-        logger.info(f"Model Deleted: {model_name}")
-        raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
-
-    else:
-        logger.error(f"Model not found")
-        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
+# @socketio.on("requestSystemConfig")
+# def handle_request_capabilities():
+#     print(">> System config requested")
+#     config = self.get_system_config()
+#     config["model_list"] = self.generate.model_manager.list_models()
+#     config["infill_methods"] = infill_methods()
+#     socketio.emit("systemConfig", config)

+# @socketio.on("searchForModels")
+# def handle_search_models(search_folder: str):
+#     try:
+#         if not search_folder:
+#             socketio.emit(
+#                 "foundModels",
+#                 {"search_folder": None, "found_models": None},
+#             )
+#         else:
+#             (
+#                 search_folder,
+#                 found_models,
+#             ) = self.generate.model_manager.search_models(search_folder)
+#             socketio.emit(
+#                 "foundModels",
+#                 {"search_folder": search_folder, "found_models": found_models},
+#             )
+#     except Exception as e:
+#         self.handle_exceptions(e)
+#         print("\n")

+# @socketio.on("addNewModel")
+# def handle_add_model(new_model_config: dict):
+#     try:
+#         model_name = new_model_config["name"]
+#         del new_model_config["name"]
+#         model_attributes = new_model_config
+#         if len(model_attributes["vae"]) == 0:
+#             del model_attributes["vae"]
+#         update = False
+#         current_model_list = self.generate.model_manager.list_models()
+#         if model_name in current_model_list:
+#             update = True
+#
+#         print(f">> Adding New Model: {model_name}")
+#
+#         self.generate.model_manager.add_model(
+#             model_name=model_name,
+#             model_attributes=model_attributes,
+#             clobber=True,
+#         )
+#         self.generate.model_manager.commit(opt.conf)
+#
+#         new_model_list = self.generate.model_manager.list_models()
+#         socketio.emit(
+#             "newModelAdded",
+#             {
+#                 "new_model_name": model_name,
+#                 "model_list": new_model_list,
+#                 "update": update,
+#             },
+#         )
+#         print(f">> New Model Added: {model_name}")
+#     except Exception as e:
+#         self.handle_exceptions(e)

+# @socketio.on("deleteModel")
+# def handle_delete_model(model_name: str):
+#     try:
+#         print(f">> Deleting Model: {model_name}")
+#         self.generate.model_manager.del_model(model_name)
+#         self.generate.model_manager.commit(opt.conf)
+#         updated_model_list = self.generate.model_manager.list_models()
+#         socketio.emit(
+#             "modelDeleted",
+#             {
+#                 "deleted_model_name": model_name,
+#                 "model_list": updated_model_list,
+#             },
+#         )
+#         print(f">> Model Deleted: {model_name}")
+#     except Exception as e:
+#         self.handle_exceptions(e)

+# @socketio.on("requestModelChange")
+# def handle_set_model(model_name: str):
+#     try:
+#         print(f">> Model change requested: {model_name}")
+#         model = self.generate.set_model(model_name)
+#         model_list = self.generate.model_manager.list_models()
+#         if model is None:
+#             socketio.emit(
+#                 "modelChangeFailed",
+#                 {"model_name": model_name, "model_list": model_list},
+#             )
+#         else:
+#             socketio.emit(
+#                 "modelChanged",
+#                 {"model_name": model_name, "model_list": model_list},
+#             )
+#     except Exception as e:
+#         self.handle_exceptions(e)

 # @socketio.on("convertToDiffusers")
 # def convert_to_diffusers(model_to_convert: dict):
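On the side of the compare that has the REST endpoints above, adding and then deleting a model could be exercised roughly as follows. This is an illustrative client sketch only: the host, port, and any mount prefix in front of `models_router` are assumptions, and the payload uses only fields visible in this diff (name plus a `DiffusersModelInfo`-shaped `info`), with values borrowed from the starter-models table earlier in this compare.

```python
import requests  # any HTTP client works; requests is used here for brevity

BASE = "http://localhost:9090/api/v1/models"  # host, port, and "/api" prefix are assumptions

# POST / (operation_id="update_model"): body matches CreateModelRequest above.
payload = {
    "name": "analog-diffusion-1.0",
    "info": {
        "format": "diffusers",
        "repo_id": "wavymulder/Analog-Diffusion",
        "description": "An SD-1.5 model trained on diverse analog photographs (2.13 GB)",
    },
}
# The route declares a 200 response carrying name/info/status.
print(requests.post(f"{BASE}/", json=payload).status_code)

# DELETE /{model_name} (operation_id="del_model"): declared as 204 on success, 404 if unknown.
print(requests.delete(f"{BASE}/analog-diffusion-1.0").status_code)
```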
@@ -246,3 +276,4 @@ async def delete_model(model_name: str) -> None:
 #         print(f">> Models Merged: {models_to_merge}")
 #         print(f">> New Model Added: {model_merge_info['merged_model_name']}")
 #     except Exception as e:
+#         self.handle_exceptions(e)
(next file; header not captured. The hunks below are from the sessions API router.)

@@ -2,7 +2,8 @@

 from typing import Annotated, List, Optional, Union

-from fastapi import Body, HTTPException, Path, Query, Response
+from fastapi import Body, Path, Query
+from fastapi.responses import Response
 from fastapi.routing import APIRouter
 from pydantic.fields import Field
@@ -75,7 +76,7 @@ async def get_session(
     """Gets a session"""
     session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
-        raise HTTPException(status_code=404)
+        return Response(status_code=404)
     else:
         return session
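The same substitution (raise `HTTPException` on the base side, return a bare `Response` on the head side) repeats through the rest of this router. Raising `HTTPException` is the idiomatic FastAPI way to short-circuit with an error status, and it keeps the handler's annotated return type honest. A small standalone sketch of the pattern, not repository code, using an in-memory store in place of the graph execution manager:

```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()


class Session(BaseModel):
    """Toy stand-in for the session/graph-execution state returned by the real router."""
    id: str


_sessions: dict[str, Session] = {"abc123": Session(id="abc123")}


@app.get("/v1/sessions/{session_id}", response_model=Session)
async def get_session(session_id: str) -> Session:
    """Gets a session"""
    session = _sessions.get(session_id)
    if session is None:
        # Short-circuits with a 404 and a JSON error body ({"detail": ...}),
        # instead of returning a bare Response that sidesteps the declared model.
        raise HTTPException(status_code=404)
    return session
```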
@@ -98,7 +99,7 @@ async def add_node(
     """Adds a node to the graph"""
     session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
-        raise HTTPException(status_code=404)
+        return Response(status_code=404)

     try:
         session.add_node(node)
@@ -107,9 +108,9 @@ async def add_node(
         )  # TODO: can this be done automatically, or add node through an API?
         return session.id
     except NodeAlreadyExecutedError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
     except IndexError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)


 @session_router.put(
@@ -131,7 +132,7 @@ async def update_node(
     """Updates a node in the graph and removes all linked edges"""
     session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
-        raise HTTPException(status_code=404)
+        return Response(status_code=404)

     try:
         session.update_node(node_path, node)
@@ -140,9 +141,9 @@ async def update_node(
         )  # TODO: can this be done automatically, or add node through an API?
         return session
     except NodeAlreadyExecutedError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
     except IndexError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)


 @session_router.delete(
@@ -161,7 +162,7 @@ async def delete_node(
     """Deletes a node in the graph and removes all linked edges"""
     session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
-        raise HTTPException(status_code=404)
+        return Response(status_code=404)

     try:
         session.delete_node(node_path)
@@ -170,9 +171,9 @@ async def delete_node(
         )  # TODO: can this be done automatically, or add node through an API?
         return session
     except NodeAlreadyExecutedError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
     except IndexError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)


 @session_router.post(
@@ -191,7 +192,7 @@ async def add_edge(
     """Adds an edge to the graph"""
     session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
-        raise HTTPException(status_code=404)
+        return Response(status_code=404)

     try:
         session.add_edge(edge)
@@ -200,9 +201,9 @@ async def add_edge(
         )  # TODO: can this be done automatically, or add node through an API?
         return session
     except NodeAlreadyExecutedError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
     except IndexError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)


 # TODO: the edge being in the path here is really ugly, find a better solution
@@ -225,7 +226,7 @@ async def delete_edge(
     """Deletes an edge from the graph"""
     session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
-        raise HTTPException(status_code=404)
+        return Response(status_code=404)

     try:
         edge = Edge(
@@ -238,9 +239,9 @@ async def delete_edge(
         )  # TODO: can this be done automatically, or add node through an API?
         return session
     except NodeAlreadyExecutedError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)
     except IndexError:
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)


 @session_router.put(
@@ -258,14 +259,14 @@ async def invoke_session(
     all: bool = Query(
         default=False, description="Whether or not to invoke all remaining invocations"
     ),
-) -> Response:
+) -> None:
     """Invokes a session"""
     session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
     if session is None:
-        raise HTTPException(status_code=404)
+        return Response(status_code=404)

     if session.is_complete():
-        raise HTTPException(status_code=400)
+        return Response(status_code=400)

     ApiDependencies.invoker.invoke(session, invoke_all=all)
     return Response(status_code=202)
@@ -280,7 +281,7 @@ async def invoke_session(
|
|||||||
)
|
)
|
||||||
async def cancel_session_invoke(
|
async def cancel_session_invoke(
|
||||||
session_id: str = Path(description="The id of the session to cancel"),
|
session_id: str = Path(description="The id of the session to cancel"),
|
||||||
) -> Response:
|
) -> None:
|
||||||
"""Invokes a session"""
|
"""Invokes a session"""
|
||||||
ApiDependencies.invoker.cancel(session_id)
|
ApiDependencies.invoker.cancel(session_id)
|
||||||
return Response(status_code=202)
|
return Response(status_code=202)
|
||||||
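The session-router hunks above differ in how failures are reported: one side returns `Response(status_code=404)` / `Response(status_code=400)`, the other raises `HTTPException` with the same codes and annotates the handlers as `-> Response`. A minimal sketch of the raising style, using a stand-in in-memory store (`SESSIONS` below is illustrative and not part of the InvokeAI services):

```python
from fastapi import FastAPI, HTTPException, Path, Response

app = FastAPI()
SESSIONS: dict[str, dict] = {}  # stand-in for the graph execution manager


@app.put("/sessions/{session_id}/invoke")
async def invoke_session(
    session_id: str = Path(description="The id of the session to invoke"),
) -> Response:
    session = SESSIONS.get(session_id)
    if session is None:
        # Raising short-circuits the handler and lets FastAPI build the error body;
        # returning a bare Response would bypass normal response handling.
        raise HTTPException(status_code=404)
    if session.get("complete"):
        raise HTTPException(status_code=400)
    return Response(status_code=202)
```

With a declared response model, raising also avoids returning an object that skips validation, which is the more conventional FastAPI pattern for error paths.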
|
|||||||
@@ -3,7 +3,6 @@ import asyncio
|
|||||||
from inspect import signature
|
from inspect import signature
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||||
@@ -17,6 +16,7 @@ from ..backend import Args
|
|||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import images, sessions, models
|
from .api.routers import images, sessions, models
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
|
from .invocations import *
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
|
|
||||||
# Create the app
|
# Create the app
|
||||||
@@ -56,7 +56,7 @@ async def startup_event():
|
|||||||
config.parse_args()
|
config.parse_args()
|
||||||
|
|
||||||
ApiDependencies.initialize(
|
ApiDependencies.initialize(
|
||||||
config=config, event_handler_id=event_handler_id, logger=logger
|
config=config, event_handler_id=event_handler_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
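The api_app.py hunk above differs by an `import invokeai.backend.util.logging as logger` line and a `logger=logger` argument to `ApiDependencies.initialize(...)`. A rough sketch of that injection pattern, with illustrative names rather than the actual InvokeAI signatures:

```python
import logging

logger = logging.getLogger("invokeai")
logging.basicConfig(level=logging.INFO)


class Dependencies:
    """Illustrative dependency container: services share one injected logger."""
    logger: logging.Logger | None = None

    @classmethod
    def initialize(cls, config: dict, event_handler_id: int, logger: logging.Logger = logger) -> None:
        cls.logger = logger
        logger.info("initializing services (event handler %s)", event_handler_id)
        # ... construct model manager, queue, processor, etc., passing `logger` to each ...


Dependencies.initialize(config={}, event_handler_id=0, logger=logger)
```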
|
|||||||
@@ -2,46 +2,15 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import argparse
|
import argparse
|
||||||
from typing import Any, Callable, Iterable, Literal, Union, get_args, get_origin, get_type_hints
|
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from ..invocations.baseinvocation import BaseInvocation
|
|
||||||
from ..invocations.image import ImageField
|
from ..invocations.image import ImageField
|
||||||
from ..services.graph import GraphExecutionState, LibraryGraph, Edge
|
from ..services.graph import GraphExecutionState
|
||||||
from ..services.invoker import Invoker
|
from ..services.invoker import Invoker
|
||||||
|
|
||||||
|
|
||||||
def add_field_argument(command_parser, name: str, field, default_override = None):
|
|
||||||
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
|
||||||
if get_origin(field.type_) == Literal:
|
|
||||||
allowed_values = get_args(field.type_)
|
|
||||||
allowed_types = set()
|
|
||||||
for val in allowed_values:
|
|
||||||
allowed_types.add(type(val))
|
|
||||||
allowed_types_list = list(allowed_types)
|
|
||||||
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
|
||||||
|
|
||||||
command_parser.add_argument(
|
|
||||||
f"--{name}",
|
|
||||||
dest=name,
|
|
||||||
type=field_type,
|
|
||||||
default=default,
|
|
||||||
choices=allowed_values,
|
|
||||||
help=field.field_info.description,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
command_parser.add_argument(
|
|
||||||
f"--{name}",
|
|
||||||
dest=name,
|
|
||||||
type=field.type_,
|
|
||||||
default=default,
|
|
||||||
help=field.field_info.description,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def add_parsers(
|
def add_parsers(
|
||||||
subparsers,
|
subparsers,
|
||||||
commands: list[type],
|
commands: list[type],
|
||||||
@@ -66,26 +35,30 @@ def add_parsers(
|
|||||||
if name in exclude_fields:
|
if name in exclude_fields:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
add_field_argument(command_parser, name, field)
|
if get_origin(field.type_) == Literal:
|
||||||
|
allowed_values = get_args(field.type_)
|
||||||
|
allowed_types = set()
|
||||||
|
for val in allowed_values:
|
||||||
|
allowed_types.add(type(val))
|
||||||
|
allowed_types_list = list(allowed_types)
|
||||||
|
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
|
||||||
|
|
||||||
|
command_parser.add_argument(
|
||||||
def add_graph_parsers(
|
f"--{name}",
|
||||||
subparsers,
|
dest=name,
|
||||||
graphs: list[LibraryGraph],
|
type=field_type,
|
||||||
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
|
default=field.default if field.default_factory is None else field.default_factory(),
|
||||||
):
|
choices=allowed_values,
|
||||||
for graph in graphs:
|
help=field.field_info.description,
|
||||||
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
)
|
||||||
|
else:
|
||||||
if add_arguments is not None:
|
command_parser.add_argument(
|
||||||
add_arguments(command_parser)
|
f"--{name}",
|
||||||
|
dest=name,
|
||||||
# Add arguments for inputs
|
type=field.type_,
|
||||||
for exposed_input in graph.exposed_inputs:
|
default=field.default if field.default_factory is None else field.default_factory(),
|
||||||
node = graph.graph.get_node(exposed_input.node_path)
|
help=field.field_info.description,
|
||||||
field = node.__fields__[exposed_input.field]
|
)
|
||||||
default_override = getattr(node, exposed_input.field)
|
|
||||||
add_field_argument(command_parser, exposed_input.alias, field, default_override)
|
|
||||||
|
|
||||||
|
|
||||||
class CliContext:
|
class CliContext:
|
||||||
@@ -93,38 +66,17 @@ class CliContext:
|
|||||||
session: GraphExecutionState
|
session: GraphExecutionState
|
||||||
parser: argparse.ArgumentParser
|
parser: argparse.ArgumentParser
|
||||||
defaults: dict[str, Any]
|
defaults: dict[str, Any]
|
||||||
graph_nodes: dict[str, str]
|
|
||||||
nodes_added: list[str]
|
|
||||||
|
|
||||||
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
|
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
|
||||||
self.invoker = invoker
|
self.invoker = invoker
|
||||||
self.session = session
|
self.session = session
|
||||||
self.parser = parser
|
self.parser = parser
|
||||||
self.defaults = dict()
|
self.defaults = dict()
|
||||||
self.graph_nodes = dict()
|
|
||||||
self.nodes_added = list()
|
|
||||||
|
|
||||||
def get_session(self):
|
def get_session(self):
|
||||||
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
|
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
|
||||||
return self.session
|
return self.session
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.session = self.invoker.create_execution_state()
|
|
||||||
self.graph_nodes = dict()
|
|
||||||
self.nodes_added = list()
|
|
||||||
# Leave defaults unchanged
|
|
||||||
|
|
||||||
def add_node(self, node: BaseInvocation):
|
|
||||||
self.get_session()
|
|
||||||
self.session.graph.add_node(node)
|
|
||||||
self.nodes_added.append(node.id)
|
|
||||||
self.invoker.services.graph_execution_manager.set(self.session)
|
|
||||||
|
|
||||||
def add_edge(self, edge: Edge):
|
|
||||||
self.get_session()
|
|
||||||
self.session.add_edge(edge)
|
|
||||||
self.invoker.services.graph_execution_manager.set(self.session)
|
|
||||||
|
|
||||||
|
|
||||||
class ExitCli(Exception):
|
class ExitCli(Exception):
|
||||||
"""Exception to exit the CLI"""
|
"""Exception to exit the CLI"""
|
||||||
@@ -230,7 +182,7 @@ class HistoryCommand(BaseCommand):
|
|||||||
for i in range(min(self.count, len(history))):
|
for i in range(min(self.count, len(history))):
|
||||||
entry_id = history[-1 - i]
|
entry_id = history[-1 - i]
|
||||||
entry = context.get_session().graph.get_node(entry_id)
|
entry = context.get_session().graph.get_node(entry_id)
|
||||||
logger.info(f"{entry_id}: {get_invocation_command(entry)}")
|
print(f"{entry_id}: {get_invocation_command(entry)}")
|
||||||
|
|
||||||
|
|
||||||
class SetDefaultCommand(BaseCommand):
|
class SetDefaultCommand(BaseCommand):
|
||||||
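One column of the commands.py diff factors the per-field argparse wiring into an `add_field_argument` helper shared by command classes and exposed graph inputs; both variants turn a pydantic (v1) field into a `--flag`, using `choices` when the annotation is a `Literal`. A self-contained sketch of that mapping; `ExampleCommand` and its fields are illustrative, not InvokeAI classes:

```python
import argparse
from typing import Literal, Union, get_args, get_origin

from pydantic import BaseModel, Field  # pydantic v1 API, matching the code above


class ExampleCommand(BaseModel):
    """Illustrative command model."""
    mode: Literal["fast", "slow"] = Field(default="fast", description="Execution mode")
    count: int = Field(default=1, description="How many times to run")


parser = argparse.ArgumentParser()
for name, field in ExampleCommand.__fields__.items():
    default = field.default if field.default_factory is None else field.default_factory()
    if get_origin(field.type_) is Literal:
        # Literal["fast", "slow"] becomes an argument restricted to those values.
        allowed_values = get_args(field.type_)
        allowed_types = list({type(v) for v in allowed_values})
        field_type = allowed_types[0] if len(allowed_types) == 1 else Union[tuple(allowed_types)]  # type: ignore
        parser.add_argument(f"--{name}", dest=name, type=field_type, default=default,
                            choices=allowed_values, help=field.field_info.description)
    else:
        parser.add_argument(f"--{name}", dest=name, type=field.type_, default=default,
                            help=field.field_info.description)

print(parser.parse_args(["--mode", "slow", "--count", "3"]))  # Namespace(mode='slow', count=3)
```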
|
|||||||
@@ -10,7 +10,6 @@ import shlex
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
|
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from ...backend import ModelManager, Globals
|
from ...backend import ModelManager, Globals
|
||||||
from ..invocations.baseinvocation import BaseInvocation
|
from ..invocations.baseinvocation import BaseInvocation
|
||||||
from .commands import BaseCommand
|
from .commands import BaseCommand
|
||||||
@@ -161,8 +160,8 @@ def set_autocompleter(model_manager: ModelManager) -> Completer:
|
|||||||
pass
|
pass
|
||||||
except OSError: # file likely corrupted
|
except OSError: # file likely corrupted
|
||||||
newname = f"{histfile}.old"
|
newname = f"{histfile}.old"
|
||||||
logger.error(
|
print(
|
||||||
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
f"## Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||||
)
|
)
|
||||||
histfile.replace(Path(newname))
|
histfile.replace(Path(newname))
|
||||||
atexit.register(readline.write_history_file, histfile)
|
atexit.register(readline.write_history_file, histfile)
|
||||||
|
|||||||
@@ -13,21 +13,17 @@ from typing import (
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.metadata import PngMetadataService
|
|
||||||
from .services.default_graphs import create_system_graphs
|
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
|
|
||||||
from ..backend import Args
|
from ..backend import Args
|
||||||
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
|
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
|
||||||
from .cli.completer import set_autocompleter
|
from .cli.completer import set_autocompleter
|
||||||
|
from .invocations import *
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
from .services.events import EventServiceBase
|
from .services.events import EventServiceBase
|
||||||
from .services.model_manager_initializer import get_model_manager
|
from .services.model_manager_initializer import get_model_manager
|
||||||
from .services.restoration_services import RestorationServices
|
from .services.restoration_services import RestorationServices
|
||||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
from .services.graph import Edge, EdgeConnection, GraphExecutionState, are_connection_types_compatible
|
||||||
from .services.default_graphs import default_text_to_image_graph_id
|
|
||||||
from .services.image_storage import DiskImageStorage
|
from .services.image_storage import DiskImageStorage
|
||||||
from .services.invocation_queue import MemoryInvocationQueue
|
from .services.invocation_queue import MemoryInvocationQueue
|
||||||
from .services.invocation_services import InvocationServices
|
from .services.invocation_services import InvocationServices
|
||||||
@@ -62,7 +58,7 @@ def add_invocation_args(command_parser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
def get_command_parser() -> argparse.ArgumentParser:
|
||||||
# Create invocation parser
|
# Create invocation parser
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
@@ -80,72 +76,20 @@ def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
|||||||
commands = BaseCommand.get_all_subclasses()
|
commands = BaseCommand.get_all_subclasses()
|
||||||
add_parsers(subparsers, commands, exclude_fields=["type"])
|
add_parsers(subparsers, commands, exclude_fields=["type"])
|
||||||
|
|
||||||
# Create subparsers for exposed CLI graphs
|
|
||||||
# TODO: add a way to identify these graphs
|
|
||||||
text_to_image = services.graph_library.get(default_text_to_image_graph_id)
|
|
||||||
add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args)
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
class NodeField():
|
|
||||||
alias: str
|
|
||||||
node_path: str
|
|
||||||
field: str
|
|
||||||
field_type: type
|
|
||||||
|
|
||||||
def __init__(self, alias: str, node_path: str, field: str, field_type: type):
|
|
||||||
self.alias = alias
|
|
||||||
self.node_path = node_path
|
|
||||||
self.field = field
|
|
||||||
self.field_type = field_type
|
|
||||||
|
|
||||||
|
|
||||||
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
|
|
||||||
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
|
||||||
|
|
||||||
|
|
||||||
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
|
||||||
"""Gets the node field for the specified field alias"""
|
|
||||||
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
|
|
||||||
node_type = type(graph.graph.get_node(exposed_input.node_path))
|
|
||||||
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
|
|
||||||
|
|
||||||
|
|
||||||
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
|
||||||
"""Gets the node field for the specified field alias"""
|
|
||||||
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
|
|
||||||
node_type = type(graph.graph.get_node(exposed_output.node_path))
|
|
||||||
node_output_type = node_type.get_output_type()
|
|
||||||
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
|
|
||||||
|
|
||||||
|
|
||||||
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
|
||||||
"""Gets the inputs for the specified invocation from the context"""
|
|
||||||
node_type = type(invocation)
|
|
||||||
if node_type is not GraphInvocation:
|
|
||||||
return fields_from_type_hints(get_type_hints(node_type), invocation.id)
|
|
||||||
else:
|
|
||||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
|
||||||
return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs}
|
|
||||||
|
|
||||||
|
|
||||||
def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
|
||||||
"""Gets the outputs for the specified invocation from the context"""
|
|
||||||
node_type = type(invocation)
|
|
||||||
if node_type is not GraphInvocation:
|
|
||||||
return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id)
|
|
||||||
else:
|
|
||||||
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
|
|
||||||
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
|
|
||||||
|
|
||||||
|
|
||||||
def generate_matching_edges(
|
def generate_matching_edges(
|
||||||
a: BaseInvocation, b: BaseInvocation, context: CliContext
|
a: BaseInvocation, b: BaseInvocation
|
||||||
) -> list[Edge]:
|
) -> list[Edge]:
|
||||||
"""Generates all possible edges between two invocations"""
|
"""Generates all possible edges between two invocations"""
|
||||||
afields = get_node_outputs(a, context)
|
atype = type(a)
|
||||||
bfields = get_node_inputs(b, context)
|
btype = type(b)
|
||||||
|
|
||||||
|
aoutputtype = atype.get_output_type()
|
||||||
|
|
||||||
|
afields = get_type_hints(aoutputtype)
|
||||||
|
bfields = get_type_hints(btype)
|
||||||
|
|
||||||
matching_fields = set(afields.keys()).intersection(bfields.keys())
|
matching_fields = set(afields.keys()).intersection(bfields.keys())
|
||||||
|
|
||||||
@@ -154,14 +98,14 @@ def generate_matching_edges(
|
|||||||
matching_fields = matching_fields.difference(invalid_fields)
|
matching_fields = matching_fields.difference(invalid_fields)
|
||||||
|
|
||||||
# Validate types
|
# Validate types
|
||||||
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
|
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f], bfields[f])]
|
||||||
|
|
||||||
edges = [
|
edges = [
|
||||||
Edge(
|
Edge(
|
||||||
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
|
source=EdgeConnection(node_id=a.id, field=field),
|
||||||
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
|
destination=EdgeConnection(node_id=b.id, field=field)
|
||||||
)
|
)
|
||||||
for alias in matching_fields
|
for field in matching_fields
|
||||||
]
|
]
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
@@ -181,7 +125,7 @@ def invoke_all(context: CliContext):
|
|||||||
# Print any errors
|
# Print any errors
|
||||||
if context.session.has_error():
|
if context.session.has_error():
|
||||||
for n in context.session.errors:
|
for n in context.session.errors:
|
||||||
context.invoker.services.logger.error(
|
print(
|
||||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
|
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -191,18 +135,16 @@ def invoke_all(context: CliContext):
|
|||||||
def invoke_cli():
|
def invoke_cli():
|
||||||
config = Args()
|
config = Args()
|
||||||
config.parse_args()
|
config.parse_args()
|
||||||
model_manager = get_model_manager(config,logger=logger)
|
model_manager = get_model_manager(config)
|
||||||
|
|
||||||
# This initializes the autocompleter and returns it.
|
# This initializes the autocompleter and returns it.
|
||||||
# Currently nothing is done with the returned Completer
|
# Currently nothing is done with the returned Completer
|
||||||
# object, but the object can be used to change autocompletion
|
# object, but the object can be used to change autocompletion
|
||||||
# behavior on the fly, if desired.
|
# behavior on the fly, if desired.
|
||||||
set_autocompleter(model_manager)
|
completer = set_autocompleter(model_manager)
|
||||||
|
|
||||||
events = EventServiceBase()
|
events = EventServiceBase()
|
||||||
|
|
||||||
metadata = PngMetadataService()
|
|
||||||
|
|
||||||
output_folder = os.path.abspath(
|
output_folder = os.path.abspath(
|
||||||
os.path.join(os.path.dirname(__file__), "../../../outputs")
|
os.path.join(os.path.dirname(__file__), "../../../outputs")
|
||||||
)
|
)
|
||||||
@@ -214,26 +156,18 @@ def invoke_cli():
|
|||||||
model_manager=model_manager,
|
model_manager=model_manager,
|
||||||
events=events,
|
events=events,
|
||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||||
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
|
images=DiskImageStorage(f'{output_folder}/images'),
|
||||||
metadata=metadata,
|
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](
|
|
||||||
filename=db_location, table_name="graphs"
|
|
||||||
),
|
|
||||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
||||||
filename=db_location, table_name="graph_executions"
|
filename=db_location, table_name="graph_executions"
|
||||||
),
|
),
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
restoration=RestorationServices(config,logger=logger),
|
restoration=RestorationServices(config),
|
||||||
logger=logger,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
system_graphs = create_system_graphs(services.graph_library)
|
|
||||||
system_graph_names = set([g.name for g in system_graphs])
|
|
||||||
|
|
||||||
invoker = Invoker(services)
|
invoker = Invoker(services)
|
||||||
session: GraphExecutionState = invoker.create_execution_state()
|
session: GraphExecutionState = invoker.create_execution_state()
|
||||||
parser = get_command_parser(services)
|
parser = get_command_parser()
|
||||||
|
|
||||||
re_negid = re.compile('^-[0-9]+$')
|
re_negid = re.compile('^-[0-9]+$')
|
||||||
|
|
||||||
@@ -251,12 +185,11 @@ def invoke_cli():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Refresh the state of the session
|
# Refresh the state of the session
|
||||||
#history = list(get_graph_execution_history(context.session))
|
history = list(get_graph_execution_history(context.session))
|
||||||
history = list(reversed(context.nodes_added))
|
|
||||||
|
|
||||||
# Split the command for piping
|
# Split the command for piping
|
||||||
cmds = cmd_input.split("|")
|
cmds = cmd_input.split("|")
|
||||||
start_id = len(context.nodes_added)
|
start_id = len(history)
|
||||||
current_id = start_id
|
current_id = start_id
|
||||||
new_invocations = list()
|
new_invocations = list()
|
||||||
for cmd in cmds:
|
for cmd in cmds:
|
||||||
@@ -272,25 +205,9 @@ def invoke_cli():
|
|||||||
args[field_name] = field_default
|
args[field_name] = field_default
|
||||||
|
|
||||||
# Parse invocation
|
# Parse invocation
|
||||||
command: CliCommand = None # type:ignore
|
|
||||||
system_graph: LibraryGraph|None = None
|
|
||||||
if args['type'] in system_graph_names:
|
|
||||||
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
|
|
||||||
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
|
||||||
for exposed_input in system_graph.exposed_inputs:
|
|
||||||
if exposed_input.alias in args:
|
|
||||||
node = invocation.graph.get_node(exposed_input.node_path)
|
|
||||||
field = exposed_input.field
|
|
||||||
setattr(node, field, args[exposed_input.alias])
|
|
||||||
command = CliCommand(command = invocation)
|
|
||||||
context.graph_nodes[invocation.id] = system_graph.id
|
|
||||||
else:
|
|
||||||
args["id"] = current_id
|
args["id"] = current_id
|
||||||
command = CliCommand(command=args)
|
command = CliCommand(command=args)
|
||||||
|
|
||||||
if command is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Run any CLI commands immediately
|
# Run any CLI commands immediately
|
||||||
if isinstance(command.command, BaseCommand):
|
if isinstance(command.command, BaseCommand):
|
||||||
# Invoke all current nodes to preserve operation order
|
# Invoke all current nodes to preserve operation order
|
||||||
@@ -300,7 +217,6 @@ def invoke_cli():
|
|||||||
command.command.run(context)
|
command.command.run(context)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# TODO: handle linking with library graphs
|
|
||||||
# Pipe previous command output (if there was a previous command)
|
# Pipe previous command output (if there was a previous command)
|
||||||
edges: list[Edge] = list()
|
edges: list[Edge] = list()
|
||||||
if len(history) > 0 or current_id != start_id:
|
if len(history) > 0 or current_id != start_id:
|
||||||
@@ -313,7 +229,7 @@ def invoke_cli():
|
|||||||
else context.session.graph.get_node(from_id)
|
else context.session.graph.get_node(from_id)
|
||||||
)
|
)
|
||||||
matching_edges = generate_matching_edges(
|
matching_edges = generate_matching_edges(
|
||||||
from_node, command.command, context
|
from_node, command.command
|
||||||
)
|
)
|
||||||
edges.extend(matching_edges)
|
edges.extend(matching_edges)
|
||||||
|
|
||||||
@@ -326,7 +242,7 @@ def invoke_cli():
|
|||||||
|
|
||||||
link_node = context.session.graph.get_node(node_id)
|
link_node = context.session.graph.get_node(node_id)
|
||||||
matching_edges = generate_matching_edges(
|
matching_edges = generate_matching_edges(
|
||||||
link_node, command.command, context
|
link_node, command.command
|
||||||
)
|
)
|
||||||
matching_destinations = [e.destination for e in matching_edges]
|
matching_destinations = [e.destination for e in matching_edges]
|
||||||
edges = [e for e in edges if e.destination not in matching_destinations]
|
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||||
@@ -340,14 +256,12 @@ def invoke_cli():
|
|||||||
if re_negid.match(node_id):
|
if re_negid.match(node_id):
|
||||||
node_id = str(current_id + int(node_id))
|
node_id = str(current_id + int(node_id))
|
||||||
|
|
||||||
# TODO: handle missing input/output
|
|
||||||
node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]]
|
|
||||||
node_input = get_node_inputs(command.command, context)[link[2]]
|
|
||||||
|
|
||||||
edges.append(
|
edges.append(
|
||||||
Edge(
|
Edge(
|
||||||
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
|
source=EdgeConnection(node_id=node_id, field=link[1]),
|
||||||
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
|
destination=EdgeConnection(
|
||||||
|
node_id=command.command.id, field=link[2]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -356,22 +270,22 @@ def invoke_cli():
|
|||||||
current_id = current_id + 1
|
current_id = current_id + 1
|
||||||
|
|
||||||
# Add the node to the session
|
# Add the node to the session
|
||||||
context.add_node(command.command)
|
context.session.add_node(command.command)
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
print(edge)
|
print(edge)
|
||||||
context.add_edge(edge)
|
context.session.add_edge(edge)
|
||||||
|
|
||||||
# Execute all remaining nodes
|
# Execute all remaining nodes
|
||||||
invoke_all(context)
|
invoke_all(context)
|
||||||
|
|
||||||
except InvalidArgs:
|
except InvalidArgs:
|
||||||
invoker.services.logger.warning('Invalid command, use "help" to list commands')
|
print('Invalid command, use "help" to list commands')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except SessionError:
|
except SessionError:
|
||||||
# Start a new session
|
# Start a new session
|
||||||
invoker.services.logger.warning("Session error: creating a new session")
|
print("Session error: creating a new session")
|
||||||
context.reset()
|
context.session = context.invoker.create_execution_state()
|
||||||
|
|
||||||
except ExitCli:
|
except ExitCli:
|
||||||
break
|
break
|
||||||
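The cli_app.py hunks differ in whether `generate_matching_edges` resolves fields through `get_node_inputs` / `get_node_outputs` (so library graphs can expose aliased fields) or directly through `get_type_hints`, but the core idea is the same in both columns: intersect the field names of the upstream node's output with the downstream node's inputs and keep the type-compatible ones. A stripped-down sketch with illustrative models:

```python
from typing import get_type_hints

from pydantic import BaseModel


class UpscaleOutput(BaseModel):
    """Illustrative output model of an upstream node."""
    image: str
    width: int


class BlurInvocation(BaseModel):
    """Illustrative downstream node with its input fields."""
    image: str
    radius: float


def matching_fields(output_model: type[BaseModel], input_model: type[BaseModel]) -> list[str]:
    afields = get_type_hints(output_model)
    bfields = get_type_hints(input_model)
    shared = set(afields) & set(bfields)
    # Keep only fields whose declared types line up (the real code uses
    # are_connection_types_compatible for this check).
    return sorted(f for f in shared if afields[f] == bfields[f])


print(matching_fields(UpscaleOutput, BlurInvocation))  # ['image']
```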
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
|
from typing import get_args, get_type_hints
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -76,56 +76,3 @@ class BaseInvocation(ABC, BaseModel):
|
|||||||
#fmt: off
|
#fmt: off
|
||||||
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
id: str = Field(description="The id of this node. Must be unique among all nodes.")
|
||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
# TODO: figure out a better way to provide these hints
|
|
||||||
# TODO: when we can upgrade to python 3.11, we can use the `NotRequired` type instead of `total=False`
|
|
||||||
class UIConfig(TypedDict, total=False):
|
|
||||||
type_hints: Dict[
|
|
||||||
str,
|
|
||||||
Literal[
|
|
||||||
"integer",
|
|
||||||
"float",
|
|
||||||
"boolean",
|
|
||||||
"string",
|
|
||||||
"enum",
|
|
||||||
"image",
|
|
||||||
"latents",
|
|
||||||
"model",
|
|
||||||
],
|
|
||||||
]
|
|
||||||
tags: List[str]
|
|
||||||
title: str
|
|
||||||
|
|
||||||
class CustomisedSchemaExtra(TypedDict):
|
|
||||||
ui: UIConfig
|
|
||||||
|
|
||||||
|
|
||||||
class InvocationConfig(BaseModel.Config):
|
|
||||||
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
|
|
||||||
|
|
||||||
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
|
|
||||||
|
|
||||||
`tags`
|
|
||||||
- A list of strings, used to categorise invocations.
|
|
||||||
|
|
||||||
`type_hints`
|
|
||||||
- A dict of field types which override the types in the invocation definition.
|
|
||||||
- Each key should be the name of one of the invocation's fields.
|
|
||||||
- Each value should be one of the valid types:
|
|
||||||
- `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`
|
|
||||||
|
|
||||||
```python
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["stable-diffusion", "image"],
|
|
||||||
"type_hints": {
|
|
||||||
"initial_image": "image",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
schema_extra: CustomisedSchemaExtra
|
|
||||||
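The `UIConfig` / `InvocationConfig` block shown above rides on pydantic v1's `schema_extra`: whatever dict the `Config` class supplies is merged into the model's generated JSON schema, which is how the `ui` hints become visible to generated UIs via OpenAPI. A small sketch with a hypothetical invocation:

```python
from typing import Literal

from pydantic import BaseModel, Field  # pydantic v1


class ExampleInvocation(BaseModel):
    """Hypothetical invocation, used only to show where schema_extra ends up."""
    type: Literal["example"] = "example"
    initial_image: str = Field(default="", description="Name of the starting image")

    class Config:
        schema_extra = {
            "ui": {
                "tags": ["stable-diffusion", "image"],
                "type_hints": {"initial_image": "image"},
            },
        }


schema = ExampleInvocation.schema()
# The extra keys sit alongside "title", "type" and "properties" in the schema dict,
# so a client reading the OpenAPI document can recover the hints:
print(schema["ui"]["type_hints"]["initial_image"])  # -> "image"
```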
|
|||||||
@@ -1,17 +1,16 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Literal, Optional
|
from typing import Literal
|
||||||
|
|
||||||
|
import cv2 as cv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.random
|
import numpy.random
|
||||||
|
from PIL import Image, ImageOps
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from .baseinvocation import (
|
from ..services.image_storage import ImageType
|
||||||
BaseInvocation,
|
from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
|
||||||
InvocationConfig,
|
from .image import ImageField, ImageOutput
|
||||||
InvocationContext,
|
|
||||||
BaseInvocationOutput,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class IntCollectionOutput(BaseInvocationOutput):
|
class IntCollectionOutput(BaseInvocationOutput):
|
||||||
@@ -34,9 +33,7 @@ class RangeInvocation(BaseInvocation):
|
|||||||
step: int = Field(default=1, description="The step of the range")
|
step: int = Field(default=1, description="The step of the range")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||||
return IntCollectionOutput(
|
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
||||||
collection=list(range(self.start, self.stop, self.step))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RandomRangeInvocation(BaseInvocation):
|
class RandomRangeInvocation(BaseInvocation):
|
||||||
@@ -46,19 +43,8 @@ class RandomRangeInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
low: int = Field(default=0, description="The inclusive low value")
|
low: int = Field(default=0, description="The inclusive low value")
|
||||||
high: int = Field(
|
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
||||||
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
|
||||||
)
|
|
||||||
size: int = Field(default=1, description="The number of values to generate")
|
size: int = Field(default=1, description="The number of values to generate")
|
||||||
seed: Optional[int] = Field(
|
|
||||||
ge=0,
|
|
||||||
le=np.iinfo(np.int32).max,
|
|
||||||
description="The seed for the RNG",
|
|
||||||
default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
|
|
||||||
)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||||
rng = np.random.default_rng(self.seed)
|
return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size)))
|
||||||
return IntCollectionOutput(
|
|
||||||
collection=list(rng.integers(low=self.low, high=self.high, size=self.size))
|
|
||||||
)
|
|
||||||
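One column of the `RandomRangeInvocation` hunk adds a `seed` field and draws from a per-call `np.random.default_rng(self.seed)` generator, where the other uses the module-level `numpy.random.randint`. The practical difference is reproducibility: a given seed always yields the same collection, and nothing touches numpy's global RNG state. A minimal illustration:

```python
import numpy as np


def random_range(low: int, high: int, size: int, seed: int) -> list[int]:
    # A local Generator keyed on the seed: same inputs, same output, no global state.
    rng = np.random.default_rng(seed)
    return list(rng.integers(low=low, high=high, size=size))


a = random_range(0, np.iinfo(np.int32).max, 3, seed=42)
b = random_range(0, np.iinfo(np.int32).max, 3, seed=42)
assert a == b  # reproducible

# The module-level call depends on (and advances) numpy's global RNG:
c = list(np.random.randint(0, np.iinfo(np.int32).max, size=3))
d = list(np.random.randint(0, np.iinfo(np.int32).max, size=3))
assert c != d  # almost certainly different
```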
|
|||||||
@@ -1,245 +0,0 @@
|
|||||||
from typing import Literal, Optional, Union
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from invokeai.app.invocations.util.choose_model import choose_model
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
|
||||||
|
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
|
||||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
|
||||||
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
|
|
||||||
|
|
||||||
from compel import Compel
|
|
||||||
from compel.prompt_parser import (
|
|
||||||
Blend,
|
|
||||||
CrossAttentionControlSubstitute,
|
|
||||||
FlattenedPrompt,
|
|
||||||
Fragment,
|
|
||||||
)
|
|
||||||
|
|
||||||
from invokeai.backend.globals import Globals
|
|
||||||
|
|
||||||
|
|
||||||
class ConditioningField(BaseModel):
|
|
||||||
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
|
|
||||||
class Config:
|
|
||||||
schema_extra = {"required": ["conditioning_name"]}
|
|
||||||
|
|
||||||
|
|
||||||
class CompelOutput(BaseInvocationOutput):
|
|
||||||
"""Compel parser output"""
|
|
||||||
|
|
||||||
#fmt: off
|
|
||||||
type: Literal["compel_output"] = "compel_output"
|
|
||||||
|
|
||||||
conditioning: ConditioningField = Field(default=None, description="Conditioning")
|
|
||||||
#fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class CompelInvocation(BaseInvocation):
|
|
||||||
"""Parse prompt using compel package to conditioning."""
|
|
||||||
|
|
||||||
type: Literal["compel"] = "compel"
|
|
||||||
|
|
||||||
prompt: str = Field(default="", description="Prompt")
|
|
||||||
model: str = Field(default="", description="Model to use")
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Prompt (Compel)",
|
|
||||||
"tags": ["prompt", "compel"],
|
|
||||||
"type_hints": {
|
|
||||||
"model": "model"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
|
||||||
|
|
||||||
# TODO: load without model
|
|
||||||
model = choose_model(context.services.model_manager, self.model)
|
|
||||||
pipeline = model["model"]
|
|
||||||
tokenizer = pipeline.tokenizer
|
|
||||||
text_encoder = pipeline.text_encoder
|
|
||||||
|
|
||||||
# TODO: global? input?
|
|
||||||
#use_full_precision = precision == "float32" or precision == "autocast"
|
|
||||||
#use_full_precision = False
|
|
||||||
|
|
||||||
# TODO: redo TI when separate model loading implemented
|
|
||||||
#textual_inversion_manager = TextualInversionManager(
|
|
||||||
# tokenizer=tokenizer,
|
|
||||||
# text_encoder=text_encoder,
|
|
||||||
# full_precision=use_full_precision,
|
|
||||||
#)
|
|
||||||
|
|
||||||
def load_huggingface_concepts(concepts: list[str]):
|
|
||||||
pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)
|
|
||||||
|
|
||||||
# apply the concepts library to the prompt
|
|
||||||
prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
|
|
||||||
self.prompt,
|
|
||||||
lambda concepts: load_huggingface_concepts(concepts),
|
|
||||||
pipeline.textual_inversion_manager.get_all_trigger_strings(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# lazy-load any deferred textual inversions.
|
|
||||||
# this might take a couple of seconds the first time a textual inversion is used.
|
|
||||||
pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
|
|
||||||
prompt_str
|
|
||||||
)
|
|
||||||
|
|
||||||
compel = Compel(
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
text_encoder=text_encoder,
|
|
||||||
textual_inversion_manager=pipeline.textual_inversion_manager,
|
|
||||||
dtype_for_device_getter=torch_dtype,
|
|
||||||
truncate_long_prompts=True, # TODO:
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: support legacy blend?
|
|
||||||
|
|
||||||
prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)
|
|
||||||
|
|
||||||
if getattr(Globals, "log_tokenization", False):
|
|
||||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
|
||||||
|
|
||||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
|
||||||
|
|
||||||
# TODO: long prompt support
|
|
||||||
#if not self.truncate_long_prompts:
|
|
||||||
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
|
|
||||||
|
|
||||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
|
||||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
|
|
||||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
|
||||||
|
|
||||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
|
||||||
context.services.latents.set(conditioning_name, (c, ec))
|
|
||||||
|
|
||||||
return CompelOutput(
|
|
||||||
conditioning=ConditioningField(
|
|
||||||
conditioning_name=conditioning_name,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_max_token_count(
|
|
||||||
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
|
|
||||||
) -> int:
|
|
||||||
if type(prompt) is Blend:
|
|
||||||
blend: Blend = prompt
|
|
||||||
return max(
|
|
||||||
[
|
|
||||||
get_max_token_count(tokenizer, c, truncate_if_too_long)
|
|
||||||
for c in blend.prompts
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return len(
|
|
||||||
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_tokens_for_prompt_object(
|
|
||||||
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
|
|
||||||
) -> [str]:
|
|
||||||
if type(parsed_prompt) is Blend:
|
|
||||||
raise ValueError(
|
|
||||||
"Blend is not supported here - you need to get tokens for each of its .children"
|
|
||||||
)
|
|
||||||
|
|
||||||
text_fragments = [
|
|
||||||
x.text
|
|
||||||
if type(x) is Fragment
|
|
||||||
else (
|
|
||||||
" ".join([f.text for f in x.original])
|
|
||||||
if type(x) is CrossAttentionControlSubstitute
|
|
||||||
else str(x)
|
|
||||||
)
|
|
||||||
for x in parsed_prompt.children
|
|
||||||
]
|
|
||||||
text = " ".join(text_fragments)
|
|
||||||
tokens = tokenizer.tokenize(text)
|
|
||||||
if truncate_if_too_long:
|
|
||||||
max_tokens_length = tokenizer.model_max_length - 2 # typically 75
|
|
||||||
tokens = tokens[0:max_tokens_length]
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
|
|
||||||
def log_tokenization_for_prompt_object(
|
|
||||||
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
|
|
||||||
):
|
|
||||||
display_label_prefix = display_label_prefix or ""
|
|
||||||
if type(p) is Blend:
|
|
||||||
blend: Blend = p
|
|
||||||
for i, c in enumerate(blend.prompts):
|
|
||||||
log_tokenization_for_prompt_object(
|
|
||||||
c,
|
|
||||||
tokenizer,
|
|
||||||
display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
|
|
||||||
)
|
|
||||||
elif type(p) is FlattenedPrompt:
|
|
||||||
flattened_prompt: FlattenedPrompt = p
|
|
||||||
if flattened_prompt.wants_cross_attention_control:
|
|
||||||
original_fragments = []
|
|
||||||
edited_fragments = []
|
|
||||||
for f in flattened_prompt.children:
|
|
||||||
if type(f) is CrossAttentionControlSubstitute:
|
|
||||||
original_fragments += f.original
|
|
||||||
edited_fragments += f.edited
|
|
||||||
else:
|
|
||||||
original_fragments.append(f)
|
|
||||||
edited_fragments.append(f)
|
|
||||||
|
|
||||||
original_text = " ".join([x.text for x in original_fragments])
|
|
||||||
log_tokenization_for_text(
|
|
||||||
original_text,
|
|
||||||
tokenizer,
|
|
||||||
display_label=f"{display_label_prefix}(.swap originals)",
|
|
||||||
)
|
|
||||||
edited_text = " ".join([x.text for x in edited_fragments])
|
|
||||||
log_tokenization_for_text(
|
|
||||||
edited_text,
|
|
||||||
tokenizer,
|
|
||||||
display_label=f"{display_label_prefix}(.swap replacements)",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
text = " ".join([x.text for x in flattened_prompt.children])
|
|
||||||
log_tokenization_for_text(
|
|
||||||
text, tokenizer, display_label=display_label_prefix
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
|
|
||||||
"""shows how the prompt is tokenized
|
|
||||||
# usually tokens have '</w>' to indicate end-of-word,
|
|
||||||
# but for readability it has been replaced with ' '
|
|
||||||
"""
|
|
||||||
tokens = tokenizer.tokenize(text)
|
|
||||||
tokenized = ""
|
|
||||||
discarded = ""
|
|
||||||
usedTokens = 0
|
|
||||||
totalTokens = len(tokens)
|
|
||||||
|
|
||||||
for i in range(0, totalTokens):
|
|
||||||
token = tokens[i].replace("</w>", " ")
|
|
||||||
# alternate color
|
|
||||||
s = (usedTokens % 6) + 1
|
|
||||||
if truncate_if_too_long and i >= tokenizer.model_max_length:
|
|
||||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
|
||||||
else:
|
|
||||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
|
||||||
usedTokens += 1
|
|
||||||
|
|
||||||
if usedTokens > 0:
|
|
||||||
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
|
||||||
print(f"{tokenized}\x1b[0m")
|
|
||||||
|
|
||||||
if discarded != "":
|
|
||||||
print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
|
||||||
print(f"{discarded}\x1b[0m")
|
|
||||||
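The compel.py listing above (present on only one side of this comparison) ends with `log_tokenization_for_text`, which colour-codes each token with ANSI escape sequences, cycling six foreground colours on a black background, and reports anything past the model's token limit as discarded. A trimmed-down sketch of the same idea, using a whitespace split as a stand-in for the real CLIP tokenizer:

```python
def log_tokenization(text: str, max_tokens: int = 75) -> None:
    tokens = text.split()  # stand-in tokenizer; the real code uses tokenizer.tokenize(text)
    kept, discarded = "", ""
    for i, token in enumerate(tokens):
        colour = (i % 6) + 1  # cycle ANSI foreground colours 31..36 on a black background
        chunk = f"\x1b[0;3{colour};40m{token} "
        if i < max_tokens:
            kept += chunk
        else:
            discarded += chunk
    print(f">> Tokens ({min(len(tokens), max_tokens)}):\n{kept}\x1b[0m")
    if discarded:
        print(f">> Tokens discarded ({len(tokens) - max_tokens}):\n{discarded}\x1b[0m")


log_tokenization("a photograph of an astronaut riding a horse")
```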
@@ -5,26 +5,14 @@ from typing import Literal
|
|||||||
import cv2 as cv
|
import cv2 as cv
|
||||||
import numpy
|
import numpy
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageField, ImageType
|
from ..services.image_storage import ImageType
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageOutput, build_image_output
|
from .image import ImageField, ImageOutput
|
||||||
|
|
||||||
|
|
||||||
class CvInvocationConfig(BaseModel):
|
class CvInpaintInvocation(BaseInvocation):
|
||||||
"""Helper class to provide all OpenCV invocations with additional config"""
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["cv", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
|
||||||
"""Simple inpaint using opencv."""
|
"""Simple inpaint using opencv."""
|
||||||
#fmt: off
|
#fmt: off
|
||||||
type: Literal["cv_inpaint"] = "cv_inpaint"
|
type: Literal["cv_inpaint"] = "cv_inpaint"
|
||||||
@@ -56,14 +44,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
|
context.services.images.save(image_type, image_name, image_inpainted)
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, image_inpainted, metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=image_name,
|
|
||||||
image=image_inpainted,
|
|
||||||
)
|
)
|
||||||
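`CvInpaintInvocation` above ("Simple inpaint using opencv.") is a thin wrapper over OpenCV inpainting; the hunks only differ in how the result image and its metadata are saved. For orientation, a sketch of the underlying OpenCV call, with a synthetic image and mask that are purely illustrative:

```python
import cv2 as cv
import numpy as np

# Illustrative inputs: a white canvas with a black blemish, and a mask marking it.
image = np.full((64, 64, 3), 255, dtype=np.uint8)
cv.circle(image, (32, 32), 10, (0, 0, 0), -1)
mask = np.zeros((64, 64), dtype=np.uint8)
cv.circle(mask, (32, 32), 12, 255, -1)  # non-zero mask pixels are the region to repair

# Telea inpainting fills the masked region from the surrounding pixels.
image_inpainted = cv.inpaint(image, mask, 3, cv.INPAINT_TELEA)
print(image_inpainted.shape)  # (64, 64, 3)
```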
@@ -6,36 +6,21 @@ from typing import Literal, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageField, ImageType
|
from ..services.image_storage import ImageType
|
||||||
from invokeai.app.invocations.util.choose_model import choose_model
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .image import ImageField, ImageOutput
|
||||||
from .image import ImageOutput, build_image_output
|
|
||||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ..util.step_callback import stable_diffusion_step_callback
|
from ..util.util import diffusers_step_callback_adapter, CanceledException
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
|
||||||
|
|
||||||
|
|
||||||
class SDImageInvocation(BaseModel):
|
|
||||||
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["stable-diffusion", "image"],
|
|
||||||
"type_hints": {
|
|
||||||
"model": "model",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
SAMPLER_NAME_VALUES = Literal[
|
||||||
|
tuple(InvokeAIGenerator.schedulers())
|
||||||
|
]
|
||||||
|
|
||||||
# Text to image
|
# Text to image
|
||||||
class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
class TextToImageInvocation(BaseInvocation):
|
||||||
"""Generates an image using text2img."""
|
"""Generates an image using text2img."""
|
||||||
|
|
||||||
type: Literal["txt2img"] = "txt2img"
|
type: Literal["txt2img"] = "txt2img"
|
||||||
@@ -46,10 +31,10 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
|
||||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||||
@@ -57,31 +42,35 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
|
|
||||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self,
|
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||||
context: InvocationContext,
|
|
||||||
source_node_id: str,
|
|
||||||
intermediate_state: PipelineIntermediateState,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
stable_diffusion_step_callback(
|
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||||
context=context,
|
raise CanceledException
|
||||||
intermediate_state=intermediate_state,
|
|
||||||
node=self.dict(),
|
step = intermediate_state.step
|
||||||
source_node_id=source_node_id,
|
if intermediate_state.predicted_original is not None:
|
||||||
)
|
# Some schedulers report not only the noisy latents at the current timestep,
|
||||||
|
# but also their estimate so far of what the de-noised latents will be.
|
||||||
|
sample = intermediate_state.predicted_original
|
||||||
|
else:
|
||||||
|
sample = intermediate_state.latents
|
||||||
|
|
||||||
|
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
# def step_callback(state: PipelineIntermediateState):
|
||||||
|
# if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||||
|
# raise CanceledException
|
||||||
|
# self.dispatch_progress(context, state.latents, state.step)
|
||||||
|
|
||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
model = choose_model(context.services.model_manager, self.model)
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
|
# TODO: How to get the default model name now?
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# (right now uses whatever current model is set in model manager)
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(
|
model= context.services.model_manager.get_model()
|
||||||
context.graph_execution_state_id
|
|
||||||
)
|
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
|
||||||
|
|
||||||
outputs = Txt2Img(model).generate(
|
outputs = Txt2Img(model).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
step_callback=partial(self.dispatch_progress, context),
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt"}
|
exclude={"prompt"}
|
||||||
), # Shorthand for passing all of the parameters above manually
|
), # Shorthand for passing all of the parameters above manually
|
||||||
@@ -97,18 +86,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
|
context.services.images.save(image_type, image_name, generate_output.image)
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images.save(
|
|
||||||
image_type, image_name, generate_output.image, metadata
|
|
||||||
)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=image_name,
|
|
||||||
image=generate_output.image,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
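The save path in the left-hand column follows the same sequence in every node below: create a name for the result, build a metadata blob from the session id and the node, save the image together with that metadata, and return a build_image_output result. A toy sketch of that sequence (the two service classes are simplified stand-ins, not the real InvocationServices):

class ToyImageService:
    """Stores images in memory keyed by (type, name)."""
    def __init__(self):
        self.store = {}
    def create_name(self, session_id: str, node_id: str) -> str:
        return f"{session_id}_{node_id}.png"
    def save(self, image_type: str, image_name: str, image, metadata=None) -> None:
        self.store[(image_type, image_name)] = (image, metadata)

class ToyMetadataService:
    def build_metadata(self, session_id: str, node: dict) -> dict:
        return {"session_id": session_id, "node": node}

images, metadata_svc = ToyImageService(), ToyMetadataService()
image_name = images.create_name("session-1", "node-7")
metadata = metadata_svc.build_metadata(session_id="session-1", node={"type": "txt2img"})
images.save("results", image_name, image="<pixel data>", metadata=metadata)
print(images.store[("results", image_name)])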
@@ -128,17 +108,20 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self,
|
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||||
context: InvocationContext,
|
|
||||||
source_node_id: str,
|
|
||||||
intermediate_state: PipelineIntermediateState,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
stable_diffusion_step_callback(
|
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||||
context=context,
|
raise CanceledException
|
||||||
intermediate_state=intermediate_state,
|
|
||||||
node=self.dict(),
|
step = intermediate_state.step
|
||||||
source_node_id=source_node_id,
|
if intermediate_state.predicted_original is not None:
|
||||||
)
|
# Some schedulers report not only the noisy latents at the current timestep,
|
||||||
|
# but also their estimate so far of what the de-noised latents will be.
|
||||||
|
sample = intermediate_state.predicted_original
|
||||||
|
else:
|
||||||
|
sample = intermediate_state.latents
|
||||||
|
|
||||||
|
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = (
|
image = (
|
||||||
@@ -150,23 +133,15 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
)
|
)
|
||||||
mask = None
|
mask = None
|
||||||
|
|
||||||
if self.fit:
|
|
||||||
image = image.resize((self.width, self.height))
|
|
||||||
|
|
||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
model = choose_model(context.services.model_manager, self.model)
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
|
# TODO: How to get the default model name now?
|
||||||
# Get the source node id (we are invoking the prepared node)
|
model = context.services.model_manager.get_model()
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(
|
|
||||||
context.graph_execution_state_id
|
|
||||||
)
|
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
|
||||||
|
|
||||||
outputs = Img2Img(model).generate(
|
outputs = Img2Img(model).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
init_image=image,
|
init_image=image,
|
||||||
init_mask=mask,
|
init_mask=mask,
|
||||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
step_callback=partial(self.dispatch_progress, context),
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt", "image", "mask"}
|
exclude={"prompt", "image", "mask"}
|
||||||
), # Shorthand for passing all of the parameters above manually
|
), # Shorthand for passing all of the parameters above manually
|
||||||
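The fit branch present only in the left-hand column simply resizes the init image to the requested dimensions before generation; with PIL that is a plain Image.resize call (the sizes below are illustrative):

from PIL import Image

init_image = Image.new("RGB", (640, 480))
width, height = 512, 512

# Equivalent of: if self.fit: image = image.resize((self.width, self.height))
fitted = init_image.resize((width, height))
print(fitted.size)  # (512, 512)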
@@ -185,19 +160,11 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
|
context.services.images.save(image_type, image_name, result_image)
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, result_image, metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=image_name,
|
|
||||||
image=result_image,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class InpaintInvocation(ImageToImageInvocation):
|
class InpaintInvocation(ImageToImageInvocation):
|
||||||
"""Generates an image using inpaint."""
|
"""Generates an image using inpaint."""
|
||||||
|
|
||||||
@@ -213,17 +180,20 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self,
|
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
|
||||||
context: InvocationContext,
|
|
||||||
source_node_id: str,
|
|
||||||
intermediate_state: PipelineIntermediateState,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
stable_diffusion_step_callback(
|
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
|
||||||
context=context,
|
raise CanceledException
|
||||||
intermediate_state=intermediate_state,
|
|
||||||
node=self.dict(),
|
step = intermediate_state.step
|
||||||
source_node_id=source_node_id,
|
if intermediate_state.predicted_original is not None:
|
||||||
)
|
# Some schedulers report not only the noisy latents at the current timestep,
|
||||||
|
# but also their estimate so far of what the de-noised latents will be.
|
||||||
|
sample = intermediate_state.predicted_original
|
||||||
|
else:
|
||||||
|
sample = intermediate_state.latents
|
||||||
|
|
||||||
|
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = (
|
image = (
|
||||||
@@ -240,19 +210,14 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
model = choose_model(context.services.model_manager, self.model)
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
|
# TODO: How to get the default model name now?
|
||||||
# Get the source node id (we are invoking the prepared node)
|
model = context.services.model_manager.get_model()
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(
|
|
||||||
context.graph_execution_state_id
|
|
||||||
)
|
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
|
||||||
|
|
||||||
outputs = Inpaint(model).generate(
|
outputs = Inpaint(model).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
init_image=image,
|
init_img=image,
|
||||||
mask_image=mask,
|
init_mask=mask,
|
||||||
step_callback=partial(self.dispatch_progress, context, source_node_id),
|
step_callback=partial(self.dispatch_progress, context),
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt", "image", "mask"}
|
exclude={"prompt", "image", "mask"}
|
||||||
), # Shorthand for passing all of the parameters above manually
|
), # Shorthand for passing all of the parameters above manually
|
||||||
@@ -271,14 +236,7 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
|
context.services.images.save(image_type, image_name, result_image)
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, result_image, metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=image_name,
|
|
||||||
image=result_image,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,97 +1,70 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
from PIL import Image, ImageFilter, ImageOps
|
from PIL import Image, ImageFilter, ImageOps
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from ..models.image import ImageField, ImageType
|
from ..services.image_storage import ImageType
|
||||||
from .baseinvocation import (
|
from ..services.invocation_services import InvocationServices
|
||||||
BaseInvocation,
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||||
BaseInvocationOutput,
|
|
||||||
InvocationContext,
|
|
||||||
InvocationConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PILInvocationConfig(BaseModel):
|
class ImageField(BaseModel):
|
||||||
"""Helper class to provide all PIL invocations with additional config"""
|
"""An image field used for passing image objects between invocations"""
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
image_type: str = Field(
|
||||||
schema_extra = {
|
default=ImageType.RESULT, description="The type of the image"
|
||||||
"ui": {
|
)
|
||||||
"tags": ["PIL", "image"],
|
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ImageOutput(BaseInvocationOutput):
|
class ImageOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output an image"""
|
"""Base class for invocations that output an image"""
|
||||||
|
#fmt: off
|
||||||
# fmt: off
|
|
||||||
type: Literal["image"] = "image"
|
type: Literal["image"] = "image"
|
||||||
image: ImageField = Field(default=None, description="The output image")
|
image: ImageField = Field(default=None, description="The output image")
|
||||||
width: Optional[int] = Field(default=None, description="The width of the image in pixels")
|
#fmt: on
|
||||||
height: Optional[int] = Field(default=None, description="The height of the image in pixels")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"required": ["type", "image", "width", "height", "mode"]
|
'required': [
|
||||||
}
|
'type',
|
||||||
|
'image',
|
||||||
|
|
||||||
def build_image_output(
|
|
||||||
image_type: ImageType, image_name: str, image: Image.Image
|
|
||||||
) -> ImageOutput:
|
|
||||||
"""Builds an ImageOutput and its ImageField"""
|
|
||||||
image_field = ImageField(
|
|
||||||
image_name=image_name,
|
|
||||||
image_type=image_type,
|
|
||||||
)
|
|
||||||
return ImageOutput(
|
|
||||||
image=image_field,
|
|
||||||
width=image.width,
|
|
||||||
height=image.height,
|
|
||||||
mode=image.mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MaskOutput(BaseInvocationOutput):
|
|
||||||
"""Base class for invocations that output a mask"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["mask"] = "mask"
|
|
||||||
mask: ImageField = Field(default=None, description="The output mask")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {
|
|
||||||
"required": [
|
|
||||||
"type",
|
|
||||||
"mask",
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class MaskOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for invocations that output a mask"""
|
||||||
|
#fmt: off
|
||||||
|
type: Literal["mask"] = "mask"
|
||||||
|
mask: ImageField = Field(default=None, description="The output mask")
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
schema_extra = {
|
||||||
|
'required': [
|
||||||
|
'type',
|
||||||
|
'mask',
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: this isn't really necessary anymore
|
||||||
class LoadImageInvocation(BaseInvocation):
|
class LoadImageInvocation(BaseInvocation):
|
||||||
"""Load an image and provide it as output."""
|
"""Load an image from a filename and provide it as output."""
|
||||||
|
#fmt: off
|
||||||
# fmt: off
|
|
||||||
type: Literal["load_image"] = "load_image"
|
type: Literal["load_image"] = "load_image"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image_type: ImageType = Field(description="The type of the image")
|
image_type: ImageType = Field(description="The type of the image")
|
||||||
image_name: str = Field(description="The name of the image")
|
image_name: str = Field(description="The name of the image")
|
||||||
# fmt: on
|
#fmt: on
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
||||||
image = context.services.images.get(self.image_type, self.image_name)
|
|
||||||
|
|
||||||
return build_image_output(
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image_type=self.image_type,
|
return ImageOutput(
|
||||||
image_name=self.image_name,
|
image=ImageField(image_type=self.image_type, image_name=self.image_name)
|
||||||
image=image,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
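For reference, the build_image_output helper shown in the left-hand column above wraps the stored image's type and name in an ImageField and copies the decoded image's width, height and mode onto the output. A simplified, self-contained sketch of the same shape (plain dataclasses stand in for the pydantic models):

from dataclasses import dataclass
from PIL import Image

@dataclass
class ImageFieldSketch:
    image_type: str
    image_name: str

@dataclass
class ImageOutputSketch:
    image: ImageFieldSketch
    width: int
    height: int
    mode: str

def build_image_output_sketch(image_type: str, image_name: str, image: Image.Image) -> ImageOutputSketch:
    # Bundle the storage reference together with the pixel dimensions.
    return ImageOutputSketch(
        image=ImageFieldSketch(image_type=image_type, image_name=image_name),
        width=image.width,
        height=image.height,
        mode=image.mode,
    )

out = build_image_output_sketch("results", "example.png", Image.new("RGB", (512, 768)))
print(out.width, out.height, out.mode)  # 512 768 RGB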
@@ -112,17 +85,16 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# TODO: how to handle failure?
|
# TODO: how to handle failure?
|
||||||
|
|
||||||
return build_image_output(
|
return ImageOutput(
|
||||||
image_type=self.image.image_type,
|
image=ImageField(
|
||||||
image_name=self.image.image_name,
|
image_type=self.image.image_type, image_name=self.image.image_name
|
||||||
image=image,
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
class CropImageInvocation(BaseInvocation):
|
||||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||||
|
#fmt: off
|
||||||
# fmt: off
|
|
||||||
type: Literal["crop"] = "crop"
|
type: Literal["crop"] = "crop"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -131,7 +103,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
||||||
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
||||||
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
|
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
|
||||||
# fmt: on
|
#fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
@@ -147,23 +119,15 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
|
context.services.images.save(image_type, image_name, image_crop)
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, image_crop, metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=image_name,
|
|
||||||
image=image_crop,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
class PasteImageInvocation(BaseInvocation):
|
||||||
"""Pastes an image into another image."""
|
"""Pastes an image into another image."""
|
||||||
|
#fmt: off
|
||||||
# fmt: off
|
|
||||||
type: Literal["paste"] = "paste"
|
type: Literal["paste"] = "paste"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -172,7 +136,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
||||||
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
||||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
||||||
# fmt: on
|
#fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
base_image = context.services.images.get(
|
base_image = context.services.images.get(
|
||||||
@@ -185,7 +149,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
None
|
None
|
||||||
if self.mask is None
|
if self.mask is None
|
||||||
else ImageOps.invert(
|
else ImageOps.invert(
|
||||||
context.services.images.get(self.mask.image_type, self.mask.image_name)
|
services.images.get(self.mask.image_type, self.mask.image_name)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||||
@@ -205,29 +169,21 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
|
context.services.images.save(image_type, image_name, new_image)
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, new_image, metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=image_name,
|
|
||||||
image=new_image,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
class MaskFromAlphaInvocation(BaseInvocation):
|
||||||
"""Extracts the alpha channel of an image as a mask."""
|
"""Extracts the alpha channel of an image as a mask."""
|
||||||
|
#fmt: off
|
||||||
# fmt: off
|
|
||||||
type: Literal["tomask"] = "tomask"
|
type: Literal["tomask"] = "tomask"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to create the mask from")
|
image: ImageField = Field(default=None, description="The image to create the mask from")
|
||||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
||||||
# fmt: on
|
#fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
@@ -242,26 +198,21 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
|
context.services.images.save(image_type, image_name, image_mask)
|
||||||
metadata = context.services.metadata.build_metadata(
|
|
||||||
session_id=context.graph_execution_state_id, node=self
|
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, image_mask, metadata)
|
|
||||||
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
|
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
|
||||||
|
|
||||||
|
|
||||||
class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
class BlurInvocation(BaseInvocation):
|
||||||
"""Blurs an image"""
|
"""Blurs an image"""
|
||||||
|
|
||||||
# fmt: off
|
#fmt: off
|
||||||
type: Literal["blur"] = "blur"
|
type: Literal["blur"] = "blur"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to blur")
|
image: ImageField = Field(default=None, description="The image to blur")
|
||||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
||||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
||||||
# fmt: on
|
#fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
@@ -279,28 +230,22 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
|
context.services.images.save(image_type, image_name, blur_image)
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, blur_image, metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type, image_name=image_name, image=blur_image
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LerpInvocation(BaseInvocation, PILInvocationConfig):
|
class LerpInvocation(BaseInvocation):
|
||||||
"""Linear interpolation of all pixels of an image"""
|
"""Linear interpolation of all pixels of an image"""
|
||||||
|
#fmt: off
|
||||||
# fmt: off
|
|
||||||
type: Literal["lerp"] = "lerp"
|
type: Literal["lerp"] = "lerp"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to lerp")
|
image: ImageField = Field(default=None, description="The image to lerp")
|
||||||
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
||||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
||||||
# fmt: on
|
#fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
@@ -316,28 +261,22 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
|
context.services.images.save(image_type, image_name, lerp_image)
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, lerp_image, metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type, image_name=image_name, image=lerp_image
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
class InverseLerpInvocation(BaseInvocation):
|
||||||
"""Inverse linear interpolation of all pixels of an image"""
|
"""Inverse linear interpolation of all pixels of an image"""
|
||||||
|
#fmt: off
|
||||||
# fmt: off
|
|
||||||
type: Literal["ilerp"] = "ilerp"
|
type: Literal["ilerp"] = "ilerp"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to lerp")
|
image: ImageField = Field(default=None, description="The image to lerp")
|
||||||
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
||||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
||||||
# fmt: on
|
#fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
@@ -358,12 +297,7 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
|
context.services.images.save(image_type, image_name, ilerp_image)
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, ilerp_image, metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type, image_name=image_name, image=ilerp_image
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,26 +1,25 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import random
|
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from torch import Tensor
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from invokeai.app.invocations.util.choose_model import choose_model
|
|
||||||
|
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
|
||||||
|
|
||||||
from ...backend.model_management.model_manager import ModelManager
|
from ...backend.model_management.model_manager import ModelManager
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import CUDA_DEVICE, torch_dtype
|
||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
from ...backend.image_util.seamless import configure_model_padding
|
from ...backend.image_util.seamless import configure_model_padding
|
||||||
|
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from accelerate.utils import set_seed
|
||||||
from ..services.image_storage import ImageType
|
from ..services.image_storage import ImageType
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput, build_image_output
|
from .image import ImageField, ImageOutput
|
||||||
from .compel import ConditioningField
|
from ...backend.generator import Generator
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
|
from ...backend.util.util import image_to_dataURL
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
import diffusers
|
import diffusers
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
@@ -31,8 +30,6 @@ class LatentsField(BaseModel):
|
|||||||
|
|
||||||
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
||||||
|
|
||||||
class Config:
|
|
||||||
schema_extra = {"required": ["latents_name"]}
|
|
||||||
|
|
||||||
class LatentsOutput(BaseInvocationOutput):
|
class LatentsOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output latents"""
|
"""Base class for invocations that output latents"""
|
||||||
@@ -102,31 +99,18 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def random_seed():
|
|
||||||
return random.randint(0, np.iinfo(np.uint32).max)
|
|
||||||
|
|
||||||
|
|
||||||
class NoiseInvocation(BaseInvocation):
|
class NoiseInvocation(BaseInvocation):
|
||||||
"""Generates latent noise."""
|
"""Generates latent noise."""
|
||||||
|
|
||||||
type: Literal["noise"] = "noise"
|
type: Literal["noise"] = "noise"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
|
seed: int = Field(default=0, ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", )
|
||||||
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
|
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting noise", )
|
||||||
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
|
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting noise", )
|
||||||
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["latents", "noise"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||||
device = torch.device(choose_torch_device())
|
device = torch.device(CUDA_DEVICE)
|
||||||
noise = get_noise(self.width, self.height, device, self.seed)
|
noise = get_noise(self.width, self.height, device, self.seed)
|
||||||
|
|
||||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
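Two differences are visible in the NoiseInvocation fields above: one side randomises an unset seed via default_factory while the other uses a fixed default of 0, and the width/height constraint is multiple_of=8 on one side versus multiple_of=64 on the other. A standalone pydantic sketch of the default_factory variant (not the actual invocation class):

import random
import numpy as np
from pydantic import BaseModel, Field

def random_seed() -> int:
    return random.randint(0, np.iinfo(np.uint32).max)

class NoiseParamsSketch(BaseModel):
    seed: int = Field(default_factory=random_seed, ge=0, le=np.iinfo(np.uint32).max)
    width: int = Field(default=512, multiple_of=8, gt=0)
    height: int = Field(default=512, multiple_of=8, gt=0)

print(NoiseParamsSketch().seed)          # a fresh random seed each time
print(NoiseParamsSketch(seed=42).seed)   # 42, when supplied explicitly
# NoiseParamsSketch(width=513) would raise a ValidationError (not a multiple of 8).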
@@ -138,54 +122,60 @@ class NoiseInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Text to image
|
# Text to image
|
||||||
class TextToLatentsInvocation(BaseInvocation):
|
class TextToLatentsInvocation(BaseInvocation):
|
||||||
"""Generates latents from conditionings."""
|
"""Generates latents from a prompt."""
|
||||||
|
|
||||||
type: Literal["t2l"] = "t2l"
|
type: Literal["t2l"] = "t2l"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
|
# TODO: consider making prompt optional to enable providing prompt through a link
|
||||||
# fmt: off
|
# fmt: off
|
||||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
||||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||||
|
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
|
||||||
|
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
|
||||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
|
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
|
||||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||||
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["latents", "image"],
|
|
||||||
"type_hints": {
|
|
||||||
"model": "model"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
|
self, context: InvocationContext, sample: Tensor, step: int
|
||||||
) -> None:
|
) -> None:
|
||||||
stable_diffusion_step_callback(
|
# TODO: only output a preview image when requested
|
||||||
context=context,
|
image = Generator.sample_to_lowres_estimated_image(sample)
|
||||||
intermediate_state=intermediate_state,
|
|
||||||
node=self.dict(),
|
(width, height) = image.size
|
||||||
source_node_id=source_node_id,
|
width *= 8
|
||||||
|
height *= 8
|
||||||
|
|
||||||
|
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||||
|
|
||||||
|
context.services.events.emit_generator_progress(
|
||||||
|
context.graph_execution_state_id,
|
||||||
|
self.id,
|
||||||
|
{
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"dataURL": dataURL
|
||||||
|
},
|
||||||
|
step,
|
||||||
|
self.steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
|
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
|
||||||
model_info = choose_model(model_manager, self.model)
|
model_info = model_manager.get_model(self.model)
|
||||||
model_name = model_info['model_name']
|
model_name = model_info['model_name']
|
||||||
model_hash = model_info['hash']
|
model_hash = model_info['hash']
|
||||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||||
model.scheduler = get_scheduler(
|
model.scheduler = get_scheduler(
|
||||||
model=model,
|
model=model,
|
||||||
scheduler_name=self.scheduler
|
scheduler_name=self.sampler_name
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(model, DiffusionPipeline):
|
if isinstance(model, DiffusionPipeline):
|
||||||
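In the right-hand column of the hunk above, progress events carry a low-resolution preview decoded directly from the latents; since latents are one eighth of the output resolution in each dimension, the preview's width and height are multiplied by 8 before being emitted. A standalone sketch of just that size bookkeeping (the preview decode itself is elided):

import torch

def preview_pixel_size(latents: torch.Tensor) -> tuple:
    # latents shape: (batch, channels, height/8, width/8)
    _, _, latent_h, latent_w = latents.shape
    return (latent_w * 8, latent_h * 8)

latents = torch.randn(1, 4, 64, 96)
print(preview_pixel_size(latents))  # (768, 512)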
@@ -203,10 +193,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
||||||
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
|
||||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
|
||||||
|
|
||||||
conditioning_data = ConditioningData(
|
conditioning_data = ConditioningData(
|
||||||
uc,
|
uc,
|
||||||
c,
|
c,
|
||||||
@@ -225,15 +213,11 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
|
||||||
|
|
||||||
def step_callback(state: PipelineIntermediateState):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
self.dispatch_progress(context, source_node_id, state)
|
self.dispatch_progress(context, state.latents, state.step)
|
||||||
|
|
||||||
model = self.get_model(context.services.model_manager)
|
model = self.get_model(context.services.model_manager)
|
||||||
conditioning_data = self.get_conditioning_data(context, model)
|
conditioning_data = self.get_conditioning_data(model)
|
||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
# TODO: Verify the noise is the right size
|
||||||
|
|
||||||
@@ -260,17 +244,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
|
|
||||||
type: Literal["l2l"] = "l2l"
|
type: Literal["l2l"] = "l2l"
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["latents"],
|
|
||||||
"type_hints": {
|
|
||||||
"model": "model"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||||
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
strength: float = Field(default=0.5, description="The strength of the latents to use")
|
||||||
@@ -279,12 +252,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
noise = context.services.latents.get(self.noise.latents_name)
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
latent = context.services.latents.get(self.latents.latents_name)
|
latent = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
|
||||||
|
|
||||||
def step_callback(state: PipelineIntermediateState):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
self.dispatch_progress(context, source_node_id, state)
|
self.dispatch_progress(context, state.latents, state.step)
|
||||||
|
|
||||||
model = self.get_model(context.services.model_manager)
|
model = self.get_model(context.services.model_manager)
|
||||||
conditioning_data = self.get_conditioning_data(model)
|
conditioning_data = self.get_conditioning_data(model)
|
||||||
@@ -330,23 +299,12 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||||
model: str = Field(default="", description="The model to use")
|
model: str = Field(default="", description="The model to use")
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["latents", "image"],
|
|
||||||
"type_hints": {
|
|
||||||
"model": "model"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
# TODO: this only really needs the vae
|
# TODO: this only really needs the vae
|
||||||
model_info = choose_model(context.services.model_manager, self.model)
|
model_info = context.services.model_manager.get_model(self.model)
|
||||||
model: StableDiffusionGeneratorPipeline = model_info['model']
|
model: StableDiffusionGeneratorPipeline = model_info['model']
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
@@ -357,79 +315,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
|
context.services.images.save(image_type, image_name, image)
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, image, metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type, image_name=image_name, image=image
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
LATENTS_INTERPOLATION_MODE = Literal[
|
|
||||||
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class ResizeLatentsInvocation(BaseInvocation):
|
|
||||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
|
||||||
|
|
||||||
type: Literal["lresize"] = "lresize"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
|
||||||
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
|
||||||
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
|
||||||
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
|
|
||||||
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
|
||||||
|
|
||||||
resized_latents = torch.nn.functional.interpolate(
|
|
||||||
latents,
|
|
||||||
size=(self.height // 8, self.width // 8),
|
|
||||||
mode=self.mode,
|
|
||||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
|
||||||
context.services.latents.set(name, resized_latents)
|
|
||||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
|
||||||
|
|
||||||
|
|
||||||
class ScaleLatentsInvocation(BaseInvocation):
|
|
||||||
"""Scales latents by a given factor."""
|
|
||||||
|
|
||||||
type: Literal["lscale"] = "lscale"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
|
||||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
|
||||||
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
|
|
||||||
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
|
||||||
|
|
||||||
# resizing
|
|
||||||
resized_latents = torch.nn.functional.interpolate(
|
|
||||||
latents,
|
|
||||||
scale_factor=self.scale_factor,
|
|
||||||
mode=self.mode,
|
|
||||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
|
||||||
context.services.latents.set(name, resized_latents)
|
|
||||||
return LatentsOutput(latents=LatentsField(latents_name=name))
|
|
||||||
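The ResizeLatentsInvocation present only in the left-hand column above operates in latent space, where one cell corresponds to 8 output pixels, so the requested pixel dimensions are floor-divided by 8 before torch.nn.functional.interpolate is applied. A standalone sketch of the core call (shapes are illustrative):

import torch

latents = torch.randn(1, 4, 64, 64)        # roughly a 512x512 px image in latent space
target_height_px, target_width_px = 768, 512

resized = torch.nn.functional.interpolate(
    latents,
    size=(target_height_px // 8, target_width_px // 8),  # 96 x 64 latent cells
    mode="bilinear",
    antialias=False,
)
print(resized.shape)  # torch.Size([1, 4, 96, 64])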
|
|||||||
@@ -1,22 +1,15 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Literal
|
from datetime import datetime, timezone
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
from PIL import Image, ImageFilter, ImageOps
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from ..services.image_storage import ImageType
|
||||||
|
from ..services.invocation_services import InvocationServices
|
||||||
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||||
class MathInvocationConfig(BaseModel):
|
|
||||||
"""Helper class to provide all math invocations with additional config"""
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["math"],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class IntOutput(BaseInvocationOutput):
|
class IntOutput(BaseInvocationOutput):
|
||||||
@@ -27,7 +20,7 @@ class IntOutput(BaseInvocationOutput):
|
|||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
class AddInvocation(BaseInvocation, MathInvocationConfig):
|
class AddInvocation(BaseInvocation):
|
||||||
"""Adds two numbers"""
|
"""Adds two numbers"""
|
||||||
#fmt: off
|
#fmt: off
|
||||||
type: Literal["add"] = "add"
|
type: Literal["add"] = "add"
|
||||||
@@ -39,7 +32,7 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
|
|||||||
return IntOutput(a=self.a + self.b)
|
return IntOutput(a=self.a + self.b)
|
||||||
|
|
||||||
|
|
||||||
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
class SubtractInvocation(BaseInvocation):
|
||||||
"""Subtracts two numbers"""
|
"""Subtracts two numbers"""
|
||||||
#fmt: off
|
#fmt: off
|
||||||
type: Literal["sub"] = "sub"
|
type: Literal["sub"] = "sub"
|
||||||
@@ -51,7 +44,7 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
|||||||
return IntOutput(a=self.a - self.b)
|
return IntOutput(a=self.a - self.b)
|
||||||
|
|
||||||
|
|
||||||
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
class MultiplyInvocation(BaseInvocation):
|
||||||
"""Multiplies two numbers"""
|
"""Multiplies two numbers"""
|
||||||
#fmt: off
|
#fmt: off
|
||||||
type: Literal["mul"] = "mul"
|
type: Literal["mul"] = "mul"
|
||||||
@@ -63,7 +56,7 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
|||||||
return IntOutput(a=self.a * self.b)
|
return IntOutput(a=self.a * self.b)
|
||||||
|
|
||||||
|
|
||||||
class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
class DivideInvocation(BaseInvocation):
|
||||||
"""Divides two numbers"""
|
"""Divides two numbers"""
|
||||||
#fmt: off
|
#fmt: off
|
||||||
type: Literal["div"] = "div"
|
type: Literal["div"] = "div"
|
||||||
|
|||||||
@@ -1,18 +0,0 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
from pydantic import Field
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
|
||||||
from .math import IntOutput
|
|
||||||
|
|
||||||
# Pass-through parameter nodes - used by subgraphs
|
|
||||||
|
|
||||||
class ParamIntInvocation(BaseInvocation):
|
|
||||||
"""An integer parameter"""
|
|
||||||
#fmt: off
|
|
||||||
type: Literal["param_int"] = "param_int"
|
|
||||||
a: int = Field(default=0, description="The integer value")
|
|
||||||
#fmt: on
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
|
||||||
return IntOutput(a=self.a)
|
|
||||||
@@ -1,11 +1,12 @@
|
|||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Literal, Union
|
from typing import Literal, Union
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageField, ImageType
|
from ..services.image_storage import ImageType
|
||||||
|
from ..services.invocation_services import InvocationServices
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageOutput, build_image_output
|
from .image import ImageField, ImageOutput
|
||||||
|
|
||||||
class RestoreFaceInvocation(BaseInvocation):
|
class RestoreFaceInvocation(BaseInvocation):
|
||||||
"""Restores faces in an image."""
|
"""Restores faces in an image."""
|
||||||
@@ -17,14 +18,6 @@ class RestoreFaceInvocation(BaseInvocation):
|
|||||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
|
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
|
||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["restoration", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
@@ -43,14 +36,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
|
context.services.images.save(image_type, image_name, results[0][0])
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, results[0][0], metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=image_name,
|
|
||||||
image=results[0][0]
|
|
||||||
)
|
)
|
||||||
@@ -1,12 +1,14 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Literal, Union
|
from typing import Literal, Union
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageField, ImageType
|
from ..services.image_storage import ImageType
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from ..services.invocation_services import InvocationServices
|
||||||
from .image import ImageOutput, build_image_output
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
|
from .image import ImageField, ImageOutput
|
||||||
|
|
||||||
|
|
||||||
class UpscaleInvocation(BaseInvocation):
|
class UpscaleInvocation(BaseInvocation):
|
||||||
@@ -20,15 +22,6 @@ class UpscaleInvocation(BaseInvocation):
|
|||||||
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["upscaling", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
@@ -47,14 +40,7 @@ class UpscaleInvocation(BaseInvocation):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
|
context.services.images.save(image_type, image_name, results[0][0])
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, results[0][0], metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=image_name,
|
|
||||||
image=results[0][0]
|
|
||||||
)
|
)
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
from invokeai.backend.model_management.model_manager import ModelManager
|
|
||||||
|
|
||||||
|
|
||||||
def choose_model(model_manager: ModelManager, model_name: str):
|
|
||||||
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
|
|
||||||
logger = model_manager.logger
|
|
||||||
if model_manager.valid_model(model_name):
|
|
||||||
model = model_manager.get_model(model_name)
|
|
||||||
else:
|
|
||||||
model = model_manager.get_model()
|
|
||||||
logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
|
|
||||||
|
|
||||||
return model
|
|
||||||
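The choose_model helper shown above simply guards model lookup: if the requested name is not known to the model manager it logs a warning and falls back to the default model. A self-contained sketch of that fallback pattern with a toy manager (the model names are illustrative):

class ToyModelManager:
    """Minimal stand-in for ModelManager: a dict of models plus a default name."""
    def __init__(self, models: dict, default: str):
        self.models = models
        self.default = default
    def valid_model(self, name: str) -> bool:
        return name in self.models
    def get_model(self, name: str = None):
        return self.models[name if name is not None else self.default]

def choose_model_sketch(manager: ToyModelManager, model_name: str):
    if manager.valid_model(model_name):
        return manager.get_model(model_name)
    print(f"'{model_name}' is not a valid model name. Using default model instead.")
    return manager.get_model()

manager = ToyModelManager({"sd-1.5": "pipeline-a", "sd-2.1": "pipeline-b"}, default="sd-1.5")
print(choose_model_sketch(manager, "sd-2.1"))       # pipeline-b
print(choose_model_sketch(manager, "not-a-model"))  # falls back to pipeline-a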
@@ -1,3 +0,0 @@
|
|||||||
class CanceledException(Exception):
|
|
||||||
"""Execution canceled by user."""
|
|
||||||
pass
|
|
||||||
@@ -1,29 +0,0 @@
-from enum import Enum
-from typing import Optional
-from pydantic import BaseModel, Field
-
-
-class ImageType(str, Enum):
-RESULT = "results"
-INTERMEDIATE = "intermediates"
-UPLOAD = "uploads"
-
-
-def is_image_type(obj):
-try:
-ImageType(obj)
-except ValueError:
-return False
-return True
-
-
-class ImageField(BaseModel):
-"""An image field used for passing image objects between invocations"""
-
-image_type: ImageType = Field(
-default=ImageType.RESULT, description="The type of the image"
-)
-image_name: Optional[str] = Field(default=None, description="The name of the image")
-
-class Config:
-schema_extra = {"required": ["image_type", "image_name"]}
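The module removed above is small enough to exercise on its own. Below is a self-contained sketch that duplicates the definitions from the hunk to show how is_image_type and ImageField behave; the sample file name is a placeholder, not a real image.

# Sketch reproducing the removed models/image.py definitions for illustration only.
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field


class ImageType(str, Enum):
    RESULT = "results"
    INTERMEDIATE = "intermediates"
    UPLOAD = "uploads"


def is_image_type(obj) -> bool:
    # True only if obj is one of the ImageType values, e.g. "results".
    try:
        ImageType(obj)
    except ValueError:
        return False
    return True


class ImageField(BaseModel):
    """An image field used for passing image objects between invocations."""
    image_type: ImageType = Field(default=ImageType.RESULT)
    image_name: Optional[str] = Field(default=None)


assert is_image_type("uploads") and not is_image_type("nonsense")
field = ImageField(image_type=ImageType.RESULT, image_name="session_node_0001.png")
print(field.dict())  # {'image_type': <ImageType.RESULT: 'results'>, 'image_name': 'session_node_0001.png'}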
@@ -1,63 +0,0 @@
-from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
-from ..invocations.compel import CompelInvocation
-from ..invocations.params import ParamIntInvocation
-from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
-from .item_storage import ItemStorageABC
-
-
-default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
-
-
-def create_text_to_image() -> LibraryGraph:
-return LibraryGraph(
-id=default_text_to_image_graph_id,
-name='t2i',
-description='Converts text to an image',
-graph=Graph(
-nodes={
-'width': ParamIntInvocation(id='width', a=512),
-'height': ParamIntInvocation(id='height', a=512),
-'seed': ParamIntInvocation(id='seed', a=-1),
-'3': NoiseInvocation(id='3'),
-'4': CompelInvocation(id='4'),
-'5': CompelInvocation(id='5'),
-'6': TextToLatentsInvocation(id='6'),
-'7': LatentsToImageInvocation(id='7'),
-},
-edges=[
-Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
-Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
-Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
-Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
-Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
-Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
-Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
-]
-),
-exposed_inputs=[
-ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
-ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
-ExposedNodeInput(node_path='width', field='a', alias='width'),
-ExposedNodeInput(node_path='height', field='a', alias='height'),
-ExposedNodeInput(node_path='seed', field='a', alias='seed'),
-],
-exposed_outputs=[
-ExposedNodeOutput(node_path='7', field='image', alias='image')
-])
-
-
-def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
-"""Creates the default system graphs, or adds new versions if the old ones don't match"""
-
-graphs: list[LibraryGraph] = list()
-
-text_to_image = graph_library.get(default_text_to_image_graph_id)
-
-# TODO: Check if the graph is the same as the default one, and if not, update it
-#if text_to_image is None:
-text_to_image = create_text_to_image()
-graph_library.set(text_to_image)
-
-graphs.append(text_to_image)
-
-return graphs
@@ -1,9 +1,10 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
-from typing import Any
-from invokeai.app.api.models.images import ProgressImage
-from invokeai.app.util.misc import get_timestamp
+from typing import Any, Dict, TypedDict
+
+ProgressImage = TypedDict(
+"ProgressImage", {"dataURL": str, "width": int, "height": int}
+)
 class EventServiceBase:
 session_event: str = "session_event"
@@ -13,8 +14,7 @@ class EventServiceBase:
 def dispatch(self, event_name: str, payload: Any) -> None:
 pass
-def __emit_session_event(self, event_name: str, payload: dict) -> None:
-payload["timestamp"] = get_timestamp()
+def __emit_session_event(self, event_name: str, payload: Dict) -> None:
 self.dispatch(
 event_name=EventServiceBase.session_event,
 payload=dict(event=event_name, data=payload),
@@ -25,8 +25,7 @@ class EventServiceBase:
 def emit_generator_progress(
 self,
 graph_execution_state_id: str,
-node: dict,
-source_node_id: str,
+invocation_id: str,
 progress_image: ProgressImage | None,
 step: int,
 total_steps: int,
@@ -36,60 +35,48 @@ class EventServiceBase:
 event_name="generator_progress",
 payload=dict(
 graph_execution_state_id=graph_execution_state_id,
-node=node,
-source_node_id=source_node_id,
-progress_image=progress_image.dict() if progress_image is not None else None,
+invocation_id=invocation_id,
+progress_image=progress_image,
 step=step,
 total_steps=total_steps,
 ),
 )
 def emit_invocation_complete(
-self,
-graph_execution_state_id: str,
-result: dict,
-node: dict,
-source_node_id: str,
+self, graph_execution_state_id: str, invocation_id: str, result: Dict
 ) -> None:
 """Emitted when an invocation has completed"""
 self.__emit_session_event(
 event_name="invocation_complete",
 payload=dict(
 graph_execution_state_id=graph_execution_state_id,
-node=node,
-source_node_id=source_node_id,
+invocation_id=invocation_id,
 result=result,
 ),
 )
 def emit_invocation_error(
-self,
-graph_execution_state_id: str,
-node: dict,
-source_node_id: str,
-error: str,
+self, graph_execution_state_id: str, invocation_id: str, error: str
 ) -> None:
 """Emitted when an invocation has completed"""
 self.__emit_session_event(
 event_name="invocation_error",
 payload=dict(
 graph_execution_state_id=graph_execution_state_id,
-node=node,
-source_node_id=source_node_id,
+invocation_id=invocation_id,
 error=error,
 ),
 )
 def emit_invocation_started(
-self, graph_execution_state_id: str, node: dict, source_node_id: str
+self, graph_execution_state_id: str, invocation_id: str
 ) -> None:
 """Emitted when an invocation has started"""
 self.__emit_session_event(
 event_name="invocation_started",
 payload=dict(
 graph_execution_state_id=graph_execution_state_id,
-node=node,
-source_node_id=source_node_id,
+invocation_id=invocation_id,
 ),
 )
@@ -97,7 +84,5 @@ class EventServiceBase:
 """Emitted when a session has completed all invocations"""
 self.__emit_session_event(
 event_name="graph_execution_state_complete",
-payload=dict(
-graph_execution_state_id=graph_execution_state_id,
-),
+payload=dict(graph_execution_state_id=graph_execution_state_id),
 )
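The hunks above change only the event payloads; the dispatch pattern itself is unchanged. A minimal sketch of that pattern follows, using an in-memory stand-in rather than the project's real EventServiceBase: time.time() stands in for get_timestamp, the captured list is purely illustrative, and the payload keys follow the left-hand side of the diff.

# Sketch only: a subclass overrides dispatch(); the emit_* helpers wrap payloads
# in {"event": name, "data": payload} before dispatching them as a session event.
import time
from typing import Any, Dict, List


class InMemoryEventService:
    session_event: str = "session_event"

    def __init__(self) -> None:
        self.captured: List[Dict[str, Any]] = []

    def dispatch(self, event_name: str, payload: Any) -> None:
        # A real implementation would forward this to a socket or message bus.
        self.captured.append({"event_name": event_name, "payload": payload})

    def _emit_session_event(self, event_name: str, payload: dict) -> None:
        payload["timestamp"] = time.time()  # stands in for get_timestamp()
        self.dispatch(
            event_name=self.session_event,
            payload=dict(event=event_name, data=payload),
        )

    def emit_invocation_started(
        self, graph_execution_state_id: str, node: dict, source_node_id: str
    ) -> None:
        self._emit_session_event(
            event_name="invocation_started",
            payload=dict(
                graph_execution_state_id=graph_execution_state_id,
                node=node,
                source_node_id=source_node_id,
            ),
        )


events = InMemoryEventService()
events.emit_invocation_started("state-1", {"id": "n1", "type": "upscale"}, "n1")
print(events.captured[0]["payload"]["data"]["source_node_id"])  # n1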
@@ -2,6 +2,7 @@
 import copy
 import itertools
+import traceback
 import uuid
 from types import NoneType
@@ -16,7 +17,7 @@ from typing import (
 import networkx as nx
-from pydantic import BaseModel, root_validator, validator
+from pydantic import BaseModel, validator
 from pydantic.fields import Field
@@ -25,6 +26,7 @@ from ..invocations.baseinvocation import (
 BaseInvocationOutput,
 InvocationContext,
 )
+from .invocation_services import InvocationServices
 class EdgeConnection(BaseModel):
@@ -213,7 +215,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
 class Graph(BaseModel):
-id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
+id: str = Field(description="The id of this graph", default_factory=uuid.uuid4)
 # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
 nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
 description="The nodes in this graph", default_factory=dict
@@ -281,8 +283,7 @@ class Graph(BaseModel):
 :raises InvalidEdgeError: the provided edge is invalid.
 """
-self._validate_edge(edge)
-if edge not in self.edges:
+if self._is_edge_valid(edge) and edge not in self.edges:
 self.edges.append(edge)
 else:
 raise InvalidEdgeError()
@@ -353,7 +354,7 @@ class Graph(BaseModel):
 return True
-def _validate_edge(self, edge: Edge):
+def _is_edge_valid(self, edge: Edge) -> bool:
 """Validates that a new edge doesn't create a cycle in the graph"""
@@ -361,53 +362,54 @@ class Graph(BaseModel):
 except NodeNotFoundError:
-raise InvalidEdgeError("One or both nodes don't exist")
+return False
-raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
+return False
-raise InvalidEdgeError(f'Edge creates a cycle in the graph')
+return False
-raise InvalidEdgeError(f'Fields are incompatible')
+return False
-raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
+return False
-raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
+return False
-raise InvalidEdgeError(f'Collector output type does not match collector input type')
+return False
-raise InvalidEdgeError(f'Collector input type does not match collector output type')
+return False
+return True
 def has_node(self, node_path: str) -> bool:
 """Determines whether or not a node exists in the graph."""
@@ -731,7 +733,7 @@ class Graph(BaseModel):
-g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
+sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
 # TODO: figure out if iteration nodes need to be expanded
@@ -748,7 +750,9 @@
 class GraphExecutionState(BaseModel):
 """Tracks the state of a graph execution"""
-id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
+id: str = Field(
+description="The id of the execution state", default_factory=uuid.uuid4
+)
 # TODO: Store a reference to the graph instead of the actual graph?
 graph: Graph = Field(description="The graph being executed")
@@ -790,6 +794,9 @@ class GraphExecutionState(BaseModel):
 default_factory=dict,
 )
+# Declare all fields as required; necessary for OpenAPI schema generation build.
+# Technically only fields without a `default_factory` need to be listed here.
+# See: https://github.com/pydantic/pydantic/discussions/4577
 class Config:
 schema_extra = {
 'required': [
@@ -854,8 +861,7 @@ class GraphExecutionState(BaseModel):
 def is_complete(self) -> bool:
 """Returns true if the graph is complete"""
-node_ids = set(self.graph.nx_graph_flat().nodes)
-return self.has_error() or all((k in self.executed for k in node_ids))
+return self.has_error() or all((k in self.executed for k in self.graph.nodes))
 def has_error(self) -> bool:
 """Returns true if the graph has any errors"""
@@ -943,11 +949,11 @@ class GraphExecutionState(BaseModel):
 def _iterator_graph(self) -> nx.DiGraph:
 """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
-g = self.graph.nx_graph_flat()
+g = self.graph.nx_graph()
-if isinstance(self.graph.get_node(n), CollectInvocation)
+if isinstance(self.graph.nodes[n], CollectInvocation)
@@ -959,7 +965,7 @@ class GraphExecutionState(BaseModel):
-if isinstance(self.graph.get_node(n), IterateInvocation)
+if isinstance(self.graph.nodes[n], IterateInvocation)
@@ -1095,9 +1101,7 @@ class GraphExecutionState(BaseModel):
 # TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
 def _is_edge_valid(self, edge: Edge) -> bool:
-try:
-self.graph._validate_edge(edge)
-except InvalidEdgeError:
+if not self._is_edge_valid(edge):
 return False
 # Invalid if destination has already been prepared or executed
@@ -1143,52 +1147,4 @@ class GraphExecutionState(BaseModel):
 self.graph.delete_edge(edge)
-class ExposedNodeInput(BaseModel):
-node_path: str = Field(description="The node path to the node with the input")
-field: str = Field(description="The field name of the input")
-alias: str = Field(description="The alias of the input")
-
-
-class ExposedNodeOutput(BaseModel):
-node_path: str = Field(description="The node path to the node with the output")
-field: str = Field(description="The field name of the output")
-alias: str = Field(description="The alias of the output")
-
-class LibraryGraph(BaseModel):
-id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
-graph: Graph = Field(description="The graph")
-name: str = Field(description="The name of the graph")
-description: str = Field(description="The description of the graph")
-exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
-exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
-
-@validator('exposed_inputs', 'exposed_outputs')
-def validate_exposed_aliases(cls, v):
-if len(v) != len(set(i.alias for i in v)):
-raise ValueError("Duplicate exposed alias")
-return v
-
-@root_validator
-def validate_exposed_nodes(cls, values):
-graph = values['graph']
-
-# Validate exposed inputs
-for exposed_input in values['exposed_inputs']:
-if not graph.has_node(exposed_input.node_path):
-raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
-node = graph.get_node(exposed_input.node_path)
-if get_input_field(node, exposed_input.field) is None:
-raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
-
-# Validate exposed outputs
-for exposed_output in values['exposed_outputs']:
-if not graph.has_node(exposed_output.node_path):
-raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
-node = graph.get_node(exposed_output.node_path)
-if get_output_field(node, exposed_output.field) is None:
-raraise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
-
-return values
 GraphInvocation.update_forward_refs()
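The central change above is add_edge switching from an exception-raising _validate_edge to a boolean _is_edge_valid. A toy sketch of the trade-off between the two styles follows; the names and the single toy check are chosen for illustration only and the real Graph class is not used here.

# Sketch only: exception-based validation keeps the failure reason,
# boolean validation collapses every failure to False.
class InvalidEdgeError(ValueError):
    pass


def validate_edge(edges: set, edge: tuple) -> None:
    # Left-hand style: raise with a reason.
    if edge in edges:
        raise InvalidEdgeError(f"Edge {edge} already exists")


def is_edge_valid(edges: set, edge: tuple) -> bool:
    # Right-hand style: report only success or failure.
    try:
        validate_edge(edges, edge)
    except InvalidEdgeError:
        return False
    return True


edges = {("a", "b")}
print(is_edge_valid(edges, ("a", "b")))  # False, but the reason is lost
try:
    validate_edge(edges, ("a", "b"))
except InvalidEdgeError as e:
    print(e)  # Edge ('a', 'b') already exists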
@@ -1,29 +1,23 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+import datetime
 import os
-from glob import glob
 from abc import ABC, abstractmethod
+from enum import Enum
 from pathlib import Path
 from queue import Queue
-from typing import Dict, List
+from typing import Dict
 from PIL.Image import Image
-import PIL.Image as PILImage
-from send2trash import send2trash
-from invokeai.app.api.models.images import (
-ImageResponse,
-ImageResponseMetadata,
-SavedImage,
-)
-from invokeai.app.models.image import ImageType
-from invokeai.app.services.metadata import (
-InvokeAIMetadata,
-MetadataServiceBase,
-build_invokeai_metadata_pnginfo,
-)
-from invokeai.app.services.item_storage import PaginatedResults
-from invokeai.app.util.misc import get_timestamp
-from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
+from invokeai.app.util.save_thumbnail import save_thumbnail
+from invokeai.backend.image_util import PngWriter
+class ImageType(str, Enum):
+RESULT = "results"
+INTERMEDIATE = "intermediates"
+UPLOAD = "uploads"
 class ImageStorageBase(ABC):
@@ -31,74 +25,40 @@ class ImageStorageBase(ABC):
 @abstractmethod
 def get(self, image_type: ImageType, image_name: str) -> Image:
-"""Retrieves an image as PIL Image."""
-pass
-
-@abstractmethod
-def list(
-self, image_type: ImageType, page: int = 0, per_page: int = 10
-) -> PaginatedResults[ImageResponse]:
-"""Gets a paginated list of images."""
 pass
 # TODO: make this a bit more flexible for e.g. cloud storage
 @abstractmethod
-def get_path(
-self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
-) -> str:
-"""Gets the internal path to an image or its thumbnail."""
-pass
-
-# TODO: make this a bit more flexible for e.g. cloud storage
-@abstractmethod
-def get_uri(
-self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
-) -> str:
-"""Gets the external URI to an image or its thumbnail."""
-pass
-
-# TODO: make this a bit more flexible for e.g. cloud storage
-@abstractmethod
-def validate_path(self, path: str) -> bool:
-"""Validates an image path."""
+def get_path(self, image_type: ImageType, image_name: str) -> str:
 pass
 @abstractmethod
-def save(
-self,
-image_type: ImageType,
-image_name: str,
-image: Image,
-metadata: InvokeAIMetadata | None = None,
-) -> SavedImage:
-"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
+def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
 pass
 @abstractmethod
 def delete(self, image_type: ImageType, image_name: str) -> None:
-"""Deletes an image and its thumbnail (if one exists)."""
 pass
 def create_name(self, context_id: str, node_id: str) -> str:
-"""Creates a unique contextual image filename."""
-return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
+return f"{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png"
 class DiskImageStorage(ImageStorageBase):
 """Stores images on disk"""
 __output_folder: str
+__pngWriter: PngWriter
 __cache_ids: Queue # TODO: this is an incredibly naive cache
 __cache: Dict[str, Image]
 __max_cache_size: int
-__metadata_service: MetadataServiceBase
-def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
+def __init__(self, output_folder: str):
 self.__output_folder = output_folder
+self.__pngWriter = PngWriter(output_folder)
 self.__cache = dict()
 self.__cache_ids = Queue()
 self.__max_cache_size = 10 # TODO: get this from config
-self.__metadata_service = metadata_service
 Path(output_folder).mkdir(parents=True, exist_ok=True)
@@ -111,155 +71,43 @@ class DiskImageStorage(ImageStorageBase):
 parents=True, exist_ok=True
 )
-def list(
-self, image_type: ImageType, page: int = 0, per_page: int = 10
-) -> PaginatedResults[ImageResponse]:
-dir_path = os.path.join(self.__output_folder, image_type)
-image_paths = glob(f"{dir_path}/*.png")
-count = len(image_paths)
-
-sorted_image_paths = sorted(
-glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
-)
-
-page_of_image_paths = sorted_image_paths[
-page * per_page : (page + 1) * per_page
-]
-
-page_of_images: List[ImageResponse] = []
-
-for path in page_of_image_paths:
-filename = os.path.basename(path)
-img = PILImage.open(path)
-
-invokeai_metadata = self.__metadata_service.get_metadata(img)
-
-page_of_images.append(
-ImageResponse(
-image_type=image_type.value,
-image_name=filename,
-# TODO: DiskImageStorage should not be building URLs...?
-image_url=self.get_uri(image_type, filename),
-thumbnail_url=self.get_uri(image_type, filename, True),
-# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
-metadata=ImageResponseMetadata(
-created=int(os.path.getctime(path)),
-width=img.width,
-height=img.height,
-invokeai=invokeai_metadata,
-),
-)
-)
-
-page_count_trunc = int(count / per_page)
-page_count_mod = count % per_page
-page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
-
-return PaginatedResults[ImageResponse](
-items=page_of_images,
-page=page,
-pages=page_count,
-per_page=per_page,
-total=count,
-)
-
 def get(self, image_type: ImageType, image_name: str) -> Image:
 image_path = self.get_path(image_type, image_name)
 cache_item = self.__get_cache(image_path)
 if cache_item:
 return cache_item
-image = PILImage.open(image_path)
+image = Image.open(image_path)
 self.__set_cache(image_path, image)
 return image
 # TODO: make this a bit more flexible for e.g. cloud storage
-def get_path(
-self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
-) -> str:
-# strip out any relative path shenanigans
-basename = os.path.basename(image_name)
-
-if is_thumbnail:
-path = os.path.join(
-self.__output_folder, image_type, "thumbnails", basename
-)
-else:
-path = os.path.join(self.__output_folder, image_type, basename)
-
-abspath = os.path.abspath(path)
-
-return abspath
-
-def get_uri(
-self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
-) -> str:
-# strip out any relative path shenanigans
-basename = os.path.basename(image_name)
-
-if is_thumbnail:
-thumbnail_basename = get_thumbnail_name(basename)
-uri = f"api/v1/images/{image_type.value}/thumbnails/{thumbnail_basename}"
-else:
-uri = f"api/v1/images/{image_type.value}/{basename}"
-
-return uri
-
-def validate_path(self, path: str) -> bool:
-try:
-os.stat(path)
-return True
-except Exception:
-return False
-
-def save(
-self,
-image_type: ImageType,
-image_name: str,
-image: Image,
-metadata: InvokeAIMetadata | None = None,
-) -> SavedImage:
+def get_path(self, image_type: ImageType, image_name: str) -> str:
+path = os.path.join(self.__output_folder, image_type, image_name)
+return path
+
+def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
+image_subpath = os.path.join(image_type, image_name)
+self.__pngWriter.save_image_and_prompt_to_png(
+image, "", image_subpath, None
+) # TODO: just pass full path to png writer
+save_thumbnail(
+image=image,
+filename=image_name,
+path=os.path.join(self.__output_folder, image_type, "thumbnails"),
+)
 image_path = self.get_path(image_type, image_name)
-
-# TODO: Reading the image and then saving it strips the metadata...
-if metadata:
-pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
-image.save(image_path, "PNG", pnginfo=pnginfo)
-else:
-image.save(image_path) # this saved image has an empty info
-
-thumbnail_name = get_thumbnail_name(image_name)
-thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
-thumbnail_image = make_thumbnail(image)
-thumbnail_image.save(thumbnail_path)
-
 self.__set_cache(image_path, image)
-self.__set_cache(thumbnail_path, thumbnail_image)
-
-return SavedImage(
-image_name=image_name,
-thumbnail_name=thumbnail_name,
-created=int(os.path.getctime(image_path)),
-)
 def delete(self, image_type: ImageType, image_name: str) -> None:
-basename = os.path.basename(image_name)
-image_path = self.get_path(image_type, basename)
+image_path = self.get_path(image_type, image_name)
 if os.path.exists(image_path):
-send2trash(image_path)
+os.remove(image_path)
 if image_path in self.__cache:
 del self.__cache[image_path]
-
-thumbnail_name = get_thumbnail_name(image_name)
-thumbnail_path = self.get_path(image_type, thumbnail_name, True)
-
-if os.path.exists(thumbnail_path):
-send2trash(thumbnail_path)
-if thumbnail_path in self.__cache:
-del self.__cache[thumbnail_path]
-
-def __get_cache(self, image_name: str) -> Image | None:
+def __get_cache(self, image_name: str) -> Image:
 return None if image_name not in self.__cache else self.__cache[image_name]
 def __set_cache(self, image_name: str, image: Image):
@@ -1,17 +1,30 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
-import time
 from abc import ABC, abstractmethod
 from queue import Queue
-from pydantic import BaseModel, Field
+import time
-class InvocationQueueItem(BaseModel):
-graph_execution_state_id: str = Field(description="The ID of the graph execution state")
-invocation_id: str = Field(description="The ID of the node being invoked")
-invoke_all: bool = Field(default=False)
-timestamp: float = Field(default_factory=time.time)
+# TODO: make this serializable
+class InvocationQueueItem:
+# session_id: str
+graph_execution_state_id: str
+invocation_id: str
+invoke_all: bool
+timestamp: float
+
+def __init__(
+self,
+# session_id: str,
+graph_execution_state_id: str,
+invocation_id: str,
+invoke_all: bool = False,
+):
+# self.session_id = session_id
+self.graph_execution_state_id = graph_execution_state_id
+self.invocation_id = invocation_id
+self.invoke_all = invoke_all
+self.timestamp = time.time()
 class InvocationQueueABC(ABC):
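This hunk swaps a pydantic InvocationQueueItem for a hand-written class. Below is a sketch of both shapes side by side, with field names taken from the diff; the class names carry Model/Plain suffixes only to let them coexist in one file, and the sample IDs are placeholders.

# Sketch only: the pydantic version gets validation and serialization for free,
# the plain version avoids the pydantic dependency but must be serialized by hand.
import time
from pydantic import BaseModel, Field


class InvocationQueueItemModel(BaseModel):
    graph_execution_state_id: str = Field(description="The ID of the graph execution state")
    invocation_id: str = Field(description="The ID of the node being invoked")
    invoke_all: bool = Field(default=False)
    timestamp: float = Field(default_factory=time.time)


class InvocationQueueItemPlain:
    def __init__(self, graph_execution_state_id: str, invocation_id: str, invoke_all: bool = False):
        self.graph_execution_state_id = graph_execution_state_id
        self.invocation_id = invocation_id
        self.invoke_all = invoke_all
        self.timestamp = time.time()


a = InvocationQueueItemModel(graph_execution_state_id="state-1", invocation_id="n1")
b = InvocationQueueItemPlain("state-1", "n1")
print(a.dict()["invocation_id"], b.invocation_id)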
@@ -1,7 +1,4 @@
-# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
-from typing import types
-from invokeai.app.services.metadata import MetadataServiceBase
 from invokeai.backend import ModelManager
 from .events import EventServiceBase
@@ -17,13 +14,11 @@ class InvocationServices:
 events: EventServiceBase
 latents: LatentsStorageBase
 images: ImageStorageBase
-metadata: MetadataServiceBase
 queue: InvocationQueueABC
 model_manager: ModelManager
 restoration: RestorationServices
 # NOTE: we must forward-declare any types that include invocations, since invocations can use services
-graph_library: ItemStorageABC["LibraryGraph"]
 graph_execution_manager: ItemStorageABC["GraphExecutionState"]
 processor: "InvocationProcessorABC"
@@ -31,24 +26,18 @@ class InvocationServices:
 self,
 model_manager: ModelManager,
 events: EventServiceBase,
-logger: types.ModuleType,
 latents: LatentsStorageBase,
 images: ImageStorageBase,
-metadata: MetadataServiceBase,
 queue: InvocationQueueABC,
-graph_library: ItemStorageABC["LibraryGraph"],
 graph_execution_manager: ItemStorageABC["GraphExecutionState"],
 processor: "InvocationProcessorABC",
 restoration: RestorationServices,
 ):
 self.model_manager = model_manager
 self.events = events
-self.logger = logger
 self.latents = latents
 self.images = images
-self.metadata = metadata
 self.queue = queue
-self.graph_library = graph_library
 self.graph_execution_manager = graph_execution_manager
 self.processor = processor
 self.restoration = restoration
@@ -71,12 +71,18 @@ class Invoker:
 for service in vars(self.services):
 self.__start_service(getattr(self.services, service))
+
+for service in vars(self.services):
+self.__start_service(getattr(self.services, service))
 def stop(self) -> None:
 """Stops the invoker. A new invoker will have to be created to execute further."""
 # First stop all services
 for service in vars(self.services):
 self.__stop_service(getattr(self.services, service))
+
+for service in vars(self.services):
+self.__stop_service(getattr(self.services, service))
 self.services.queue.put(None)
@@ -1,96 +0,0 @@
-import json
-from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, TypedDict
-from PIL import Image, PngImagePlugin
-from pydantic import BaseModel
-
-from invokeai.app.models.image import ImageType, is_image_type
-
-
-class MetadataImageField(TypedDict):
-"""Pydantic-less ImageField, used for metadata parsing."""
-
-image_type: ImageType
-image_name: str
-
-
-class MetadataLatentsField(TypedDict):
-"""Pydantic-less LatentsField, used for metadata parsing."""
-
-latents_name: str
-
-
-# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
-NodeMetadata = Dict[
-str, str | int | float | bool | MetadataImageField | MetadataLatentsField
-]
-
-
-class InvokeAIMetadata(TypedDict, total=False):
-"""InvokeAI-specific metadata format."""
-
-session_id: Optional[str]
-node: Optional[NodeMetadata]
-
-
-def build_invokeai_metadata_pnginfo(
-metadata: InvokeAIMetadata | None,
-) -> PngImagePlugin.PngInfo:
-"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
-pnginfo = PngImagePlugin.PngInfo()
-
-if metadata is not None:
-pnginfo.add_text("invokeai", json.dumps(metadata))
-
-return pnginfo
-
-
-class MetadataServiceBase(ABC):
-@abstractmethod
-def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
-"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
-pass
-
-@abstractmethod
-def build_metadata(
-self, session_id: str, node: BaseModel
-) -> InvokeAIMetadata | None:
-"""Builds an InvokeAIMetadata object"""
-pass
-
-
-class PngMetadataService(MetadataServiceBase):
-"""Handles loading and building metadata for images."""
-
-# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
-def _load_metadata(self, image: Image.Image) -> dict | None:
-"""Loads a specific info entry from a PIL Image."""
-
-try:
-info = image.info.get("invokeai")
-
-if type(info) is not str:
-return None
-
-loaded_metadata = json.loads(info)
-
-if type(loaded_metadata) is not dict:
-return None
-
-if len(loaded_metadata.items()) == 0:
-return None
-
-return loaded_metadata
-except:
-return None
-
-def get_metadata(self, image: Image.Image) -> dict | None:
-"""Retrieves an image's metadata as a dict"""
-loaded_metadata = self._load_metadata(image)
-
-return loaded_metadata
-
-def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
-metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
-
-return metadata
@@ -5,7 +5,6 @@ from argparse import Namespace
 from invokeai.backend import Args
 from omegaconf import OmegaConf
 from pathlib import Path
-from typing import types
 import invokeai.version
 from ...backend import ModelManager
@@ -13,16 +12,16 @@ from ...backend.util import choose_precision, choose_torch_device
 from ...backend import Globals
 # TODO: Replace with an abstract class base ModelManagerBase
-def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
+def get_model_manager(config: Args) -> ModelManager:
 if not config.conf:
 config_file = os.path.join(Globals.root, "configs", "models.yaml")
 if not os.path.exists(config_file):
 report_model_error(
-config, FileNotFoundError(f"The file {config_file} could not be found."), logger
+config, FileNotFoundError(f"The file {config_file} could not be found.")
 )
-logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
-logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
+print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
+print(f'>> InvokeAI runtime directory is "{Globals.root}"')
 # these two lines prevent a horrible warning message from appearing
 # when the frozen CLIP tokenizer is imported
@@ -63,12 +62,11 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
 device_type=device,
 max_loaded_models=config.max_loaded_models,
 embedding_path = Path(embedding_path),
-logger = logger,
 )
 except (FileNotFoundError, TypeError, AssertionError) as e:
-report_model_error(config, e, logger)
+report_model_error(config, e)
 except (IOError, KeyError) as e:
-logger.error(f"{e}. Aborting.")
+print(f"{e}. Aborting.")
 sys.exit(-1)
 # try to autoconvert new models
@@ -78,18 +76,18 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
 conf_path=config.conf,
 weights_directory=path,
 )
-logger.info('Model manager initialized')
 return model_manager
-def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
-logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
-logger.error(
-"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
+def report_model_error(opt: Namespace, e: Exception):
+print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
+print(
+"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
 )
 yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
 if yes_to_all:
-logger.warning(
-"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
+print(
+"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
 )
 else:
 response = input(
@@ -98,12 +96,13 @@ def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
 if response.startswith(("n", "N")):
 return
-logger.info("invokeai-configure is launching....\n")
+print("invokeai-configure is launching....\n")
 # Match arguments that were set on the CLI
 # only the arguments accepted by the configuration script are parsed
 root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
 config = ["--config", opt.conf] if opt.conf is not None else []
+previous_config = sys.argv
 sys.argv = ["invokeai-configure"]
 sys.argv.extend(root_dir)
 sys.argv.extend(config.to_dict())
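The left-hand side of the hunks above routes messages through a logger argument where the right-hand side prints with ">> " and "** " prefixes. A hedged sketch of that print-to-logger migration using the standard-library logger follows; the project's own logger module is not shown in this diff, and the application name, version, and path below are placeholders.

# Sketch only: the log level replaces the ">> " / "** " prefixes.
import logging

logging.basicConfig(level=logging.INFO, format=">> %(levelname)s: %(message)s")
logger = logging.getLogger("invokeai")


def report_startup(app_name: str, version: str, root: str) -> None:
    # Equivalent to the print(f">> ...") calls on the right-hand side.
    logger.info(f"{app_name}, version {version}")
    logger.info(f'InvokeAI runtime directory is "{root}"')


report_startup("InvokeAI", "0.0.0-placeholder", "/path/to/invokeai")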
@@ -1,20 +1,17 @@
|
|||||||
import traceback
|
import traceback
|
||||||
from threading import Event, Thread, BoundedSemaphore
|
from threading import Event, Thread
|
||||||
|
|
||||||
 from ..invocations.baseinvocation import InvocationContext
 from .invocation_queue import InvocationQueueItem
 from .invoker import InvocationProcessorABC, Invoker
-from ..models.exceptions import CanceledException
+from ..util.util import CanceledException

 class DefaultInvocationProcessor(InvocationProcessorABC):
     __invoker_thread: Thread
     __stop_event: Event
     __invoker: Invoker
-    __threadLimit: BoundedSemaphore

     def start(self, invoker) -> None:
-        # if we do want multithreading at some point, we could make this configurable
-        self.__threadLimit = BoundedSemaphore(1)
         self.__invoker = invoker
         self.__stop_event = Event()
         self.__invoker_thread = Thread(
@@ -23,7 +20,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
             kwargs=dict(stop_event=self.__stop_event),
         )
         self.__invoker_thread.daemon = (
-            True  # TODO: make async and do not use threads
+            True  # TODO: probably better to just not use threads?
        )
         self.__invoker_thread.start()

@@ -32,7 +29,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):

    def __process(self, stop_event: Event):
        try:
-            self.__threadLimit.acquire()
            while not stop_event.is_set():
                queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
                if not queue_item:  # Probably stopping
@@ -47,14 +43,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                    queue_item.invocation_id
                )

-                # get the source node id to provide to clients (the prepared node id is not as useful)
-                source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
-
                # Send starting event
                self.__invoker.services.events.emit_invocation_started(
                    graph_execution_state_id=graph_execution_state.id,
-                    node=invocation.dict(),
-                    source_node_id=source_node_id
+                    invocation_id=invocation.id,
                )

                # Invoke
@@ -83,8 +75,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                    # Send complete event
                    self.__invoker.services.events.emit_invocation_complete(
                        graph_execution_state_id=graph_execution_state.id,
-                        node=invocation.dict(),
-                        source_node_id=source_node_id,
+                        invocation_id=invocation.id,
                        result=outputs.dict(),
                    )

@@ -108,8 +99,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                    # Send error event
                    self.__invoker.services.events.emit_invocation_error(
                        graph_execution_state_id=graph_execution_state.id,
-                        node=invocation.dict(),
-                        source_node_id=source_node_id,
+                        invocation_id=invocation.id,
                        error=error,
                    )

@@ -131,6 +121,4 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                    )

        except KeyboardInterrupt:
-            pass  # Log something? KeyboardInterrupt is probably not going to be seen by the processor
-        finally:
-            self.__threadLimit.release()
+            ...  # Log something?
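For orientation, a minimal self-contained sketch of the worker-thread pattern the processor relies on in both versions (a daemon Thread draining a Queue until a stop Event fires); the queue payload and names are placeholders, not the actual InvokeAI services:

from queue import Queue
from threading import Event, Thread

def run_worker(queue: Queue, stop_event: Event) -> None:
    # Block on the queue; a None item or the stop event ends the loop.
    while not stop_event.is_set():
        item = queue.get()
        if item is None:
            break
        print(f"processing {item}")

queue: Queue = Queue()
stop_event = Event()
worker = Thread(target=run_worker, kwargs=dict(queue=queue, stop_event=stop_event))
worker.daemon = True  # do not keep the interpreter alive for this thread
worker.start()
queue.put("example-item")
queue.put(None)  # signal shutdown
worker.join()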
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import torch
|
import torch
|
||||||
from typing import types
|
|
||||||
from ...backend.restoration import Restoration
|
from ...backend.restoration import Restoration
|
||||||
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||||||
|
|
||||||
@@ -11,7 +10,7 @@ from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
|||||||
class RestorationServices:
|
class RestorationServices:
|
||||||
'''Face restoration and upscaling'''
|
'''Face restoration and upscaling'''
|
||||||
|
|
||||||
def __init__(self,args,logger:types.ModuleType):
|
def __init__(self,args):
|
||||||
try:
|
try:
|
||||||
gfpgan, codeformer, esrgan = None, None, None
|
gfpgan, codeformer, esrgan = None, None, None
|
||||||
if args.restore or args.esrgan:
|
if args.restore or args.esrgan:
|
||||||
@@ -21,22 +20,20 @@ class RestorationServices:
|
|||||||
args.gfpgan_model_path
|
args.gfpgan_model_path
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Face restoration disabled")
|
print(">> Face restoration disabled")
|
||||||
if args.esrgan:
|
if args.esrgan:
|
||||||
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
|
||||||
else:
|
else:
|
||||||
logger.info("Upscaling disabled")
|
print(">> Upscaling disabled")
|
||||||
else:
|
else:
|
||||||
logger.info("Face restoration and upscaling disabled")
|
print(">> Face restoration and upscaling disabled")
|
||||||
except (ModuleNotFoundError, ImportError):
|
except (ModuleNotFoundError, ImportError):
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
||||||
self.device = torch.device(choose_torch_device())
|
self.device = torch.device(choose_torch_device())
|
||||||
self.gfpgan = gfpgan
|
self.gfpgan = gfpgan
|
||||||
self.codeformer = codeformer
|
self.codeformer = codeformer
|
||||||
self.esrgan = esrgan
|
self.esrgan = esrgan
|
||||||
self.logger = logger
|
|
||||||
self.logger.info('Face restoration initialized')
|
|
||||||
|
|
||||||
# note that this one method does gfpgan and codepath reconstruction, as well as
|
# note that this one method does gfpgan and codepath reconstruction, as well as
|
||||||
# esrgan upscaling
|
# esrgan upscaling
|
||||||
@@ -61,15 +58,15 @@ class RestorationServices:
|
|||||||
if self.gfpgan is not None or self.codeformer is not None:
|
if self.gfpgan is not None or self.codeformer is not None:
|
||||||
if facetool == "gfpgan":
|
if facetool == "gfpgan":
|
||||||
if self.gfpgan is None:
|
if self.gfpgan is None:
|
||||||
self.logger.info(
|
print(
|
||||||
"GFPGAN not found. Face restoration is disabled."
|
">> GFPGAN not found. Face restoration is disabled."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
image = self.gfpgan.process(image, strength, seed)
|
image = self.gfpgan.process(image, strength, seed)
|
||||||
if facetool == "codeformer":
|
if facetool == "codeformer":
|
||||||
if self.codeformer is None:
|
if self.codeformer is None:
|
||||||
self.logger.info(
|
print(
|
||||||
"CodeFormer not found. Face restoration is disabled."
|
">> CodeFormer not found. Face restoration is disabled."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cf_device = (
|
cf_device = (
|
||||||
@@ -83,7 +80,7 @@ class RestorationServices:
|
|||||||
fidelity=codeformer_fidelity,
|
fidelity=codeformer_fidelity,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.logger.info("Face Restoration is disabled.")
|
print(">> Face Restoration is disabled.")
|
||||||
if upscale is not None:
|
if upscale is not None:
|
||||||
if self.esrgan is not None:
|
if self.esrgan is not None:
|
||||||
if len(upscale) < 2:
|
if len(upscale) < 2:
|
||||||
@@ -96,10 +93,10 @@ class RestorationServices:
|
|||||||
denoise_str=upscale_denoise_str,
|
denoise_str=upscale_denoise_str,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.logger.info("ESRGAN is disabled. Image not upscaled.")
|
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.info(
|
print(
|
||||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if image_callback is not None:
|
if image_callback is not None:
|
||||||
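Most hunks in this file (and in the backend files further down) differ only in their logging style: one side of the compare routes messages through the invokeai.backend.util.logging module, the other uses ">>"-prefixed print() calls. A minimal sketch of the two styles, with the logger import path taken from the diff itself:

import invokeai.backend.util.logging as logger

# ad-hoc prefixes imply severity
print(">> Face restoration disabled")
print("** model could not be loaded")

# explicit severity, centralized formatting
logger.info("Face restoration disabled")
logger.warning("model could not be loaded")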
|
|||||||
@@ -1,5 +0,0 @@
-import datetime
-
-
-def get_timestamp():
-    return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
25
invokeai/app/util/save_thumbnail.py
Normal file
@@ -0,0 +1,25 @@
+import os
+from PIL import Image
+
+
+def save_thumbnail(
+    image: Image.Image,
+    filename: str,
+    path: str,
+    size: int = 256,
+) -> str:
+    """
+    Saves a thumbnail of an image, returning its path.
+    """
+    base_filename = os.path.splitext(filename)[0]
+    thumbnail_path = os.path.join(path, base_filename + ".webp")
+
+    if os.path.exists(thumbnail_path):
+        return thumbnail_path
+
+    image_copy = image.copy()
+    image_copy.thumbnail(size=(size, size))
+
+    image_copy.save(thumbnail_path, "WEBP")
+
+    return thumbnail_path
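A short usage sketch for the new helper (paths and file names are hypothetical; the function returns the .webp path and skips the work if the thumbnail already exists):

import os
from PIL import Image
from invokeai.app.util.save_thumbnail import save_thumbnail

os.makedirs("outputs/thumbnails", exist_ok=True)
image = Image.new("RGB", (512, 512), color="gray")
thumbnail_path = save_thumbnail(image, filename="000001.png", path="outputs/thumbnails", size=256)
print(thumbnail_path)  # outputs/thumbnails/000001.webp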
@@ -1,55 +0,0 @@
-from invokeai.app.api.models.images import ProgressImage
-from invokeai.app.models.exceptions import CanceledException
-from ..invocations.baseinvocation import InvocationContext
-from ...backend.util.util import image_to_dataURL
-from ...backend.generator.base import Generator
-from ...backend.stable_diffusion import PipelineIntermediateState
-
-
-def stable_diffusion_step_callback(
-    context: InvocationContext,
-    intermediate_state: PipelineIntermediateState,
-    node: dict,
-    source_node_id: str,
-):
-    if context.services.queue.is_canceled(context.graph_execution_state_id):
-        raise CanceledException
-
-    # Some schedulers report not only the noisy latents at the current timestep,
-    # but also their estimate so far of what the de-noised latents will be. Use
-    # that estimate if it is available.
-    if intermediate_state.predicted_original is not None:
-        sample = intermediate_state.predicted_original
-    else:
-        sample = intermediate_state.latents
-
-    # TODO: This does not seem to be needed any more?
-    # # txt2img provides a Tensor in the step_callback
-    # # img2img provides a PipelineIntermediateState
-    # if isinstance(sample, PipelineIntermediateState):
-    #     # this was an img2img
-    #     print('img2img')
-    #     latents = sample.latents
-    #     step = sample.step
-    # else:
-    #     print('txt2img')
-    #     latents = sample
-    #     step = intermediate_state.step
-
-    # TODO: only output a preview image when requested
-    image = Generator.sample_to_lowres_estimated_image(sample)
-
-    (width, height) = image.size
-    width *= 8
-    height *= 8
-
-    dataURL = image_to_dataURL(image, image_format="JPEG")
-
-    context.services.events.emit_generator_progress(
-        graph_execution_state_id=context.graph_execution_state_id,
-        node=node,
-        source_node_id=source_node_id,
-        progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
-        step=intermediate_state.step,
-        total_steps=node["steps"],
-    )
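The width *= 8 / height *= 8 lines above reflect the Stable Diffusion VAE's 8x spatial downsampling: the preview is estimated in latent space, so its size is multiplied by 8 to report pixel dimensions. A quick sanity check:

LATENT_SCALE_FACTOR = 8  # SD-1.x/2.x VAE downsamples each spatial dimension by 8

def latent_to_pixel_size(latent_width: int, latent_height: int) -> tuple[int, int]:
    return (latent_width * LATENT_SCALE_FACTOR, latent_height * LATENT_SCALE_FACTOR)

assert latent_to_pixel_size(64, 64) == (512, 512)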
@@ -1,15 +0,0 @@
-import os
-from PIL import Image
-
-
-def get_thumbnail_name(image_name: str) -> str:
-    """Formats given an image name, returns the appropriate thumbnail image name"""
-    thumbnail_name = os.path.splitext(image_name)[0] + ".webp"
-    return thumbnail_name
-
-
-def make_thumbnail(image: Image.Image, size: int = 256) -> Image.Image:
-    """Makes a thumbnail from a PIL Image"""
-    thumbnail = image.copy()
-    thumbnail.thumbnail(size=(size, size))
-    return thumbnail
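Taken together, the two removed helpers split naming from resizing. A short usage sketch, assuming both are imported from their module (the module path is not shown in this hunk) and the paths are hypothetical:

from PIL import Image

image = Image.open("outputs/000001.png")
thumbnail = make_thumbnail(image, size=256)        # resized copy, aspect ratio preserved
thumbnail_name = get_thumbnail_name("000001.png")  # "000001.webp"
thumbnail.save(f"outputs/thumbnails/{thumbnail_name}", "WEBP")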
42
invokeai/app/util/util.py
Normal file
@@ -0,0 +1,42 @@
+import torch
+from PIL import Image
+from ..invocations.baseinvocation import InvocationContext
+from ...backend.util.util import image_to_dataURL
+from ...backend.generator.base import Generator
+from ...backend.stable_diffusion import PipelineIntermediateState
+
+class CanceledException(Exception):
+    pass
+
+def fast_latents_step_callback(sample: torch.Tensor, step: int, steps: int, id: str, context: InvocationContext, ):
+    # TODO: only output a preview image when requested
+    image = Generator.sample_to_lowres_estimated_image(sample)
+
+    (width, height) = image.size
+    width *= 8
+    height *= 8
+
+    dataURL = image_to_dataURL(image, image_format="JPEG")
+
+    context.services.events.emit_generator_progress(
+        context.graph_execution_state_id,
+        id,
+        {
+            "width": width,
+            "height": height,
+            "dataURL": dataURL
+        },
+        step,
+        steps,
+    )
+
+def diffusers_step_callback_adapter(*cb_args, **kwargs):
+    """
+    txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
+    This adapter grabs the needed data and passes it along to the callback function.
+    """
+    if isinstance(cb_args[0], PipelineIntermediateState):
+        progress_state: PipelineIntermediateState = cb_args[0]
+        return fast_latents_step_callback(progress_state.latents, progress_state.step, **kwargs)
+    else:
+        return fast_latents_step_callback(*cb_args, **kwargs)
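The preview payload above is shipped to clients as a JPEG data URL. A self-contained sketch of what image_to_dataURL presumably does (the real helper lives in invokeai.backend.util.util; this stand-alone version is only for illustration):

import base64
from io import BytesIO
from PIL import Image

def image_to_data_url(image: Image.Image, image_format: str = "JPEG") -> str:
    buffer = BytesIO()
    image.save(buffer, format=image_format)
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/{image_format.lower()};base64,{encoded}"

preview = Image.new("RGB", (64, 64), color="gray")
print(image_to_data_url(preview)[:40])  # "data:image/jpeg;base64,..."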
@@ -10,7 +10,7 @@ from .generator import (
     Img2Img,
     Inpaint
 )
-from .model_management import ModelManager, SDModelComponent
+from .model_management import ModelManager
 from .safety_checker import SafetyChecker
 from .args import Args
 from .globals import Globals
|
|||||||
@@ -96,7 +96,6 @@ from pathlib import Path
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import invokeai.version
|
import invokeai.version
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.image_util import retrieve_metadata
|
from invokeai.backend.image_util import retrieve_metadata
|
||||||
|
|
||||||
from .globals import Globals
|
from .globals import Globals
|
||||||
@@ -190,7 +189,7 @@ class Args(object):
|
|||||||
print(f"{APP_NAME} {APP_VERSION}")
|
print(f"{APP_NAME} {APP_VERSION}")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
logger.info("Initializing, be patient...")
|
print("* Initializing, be patient...")
|
||||||
Globals.root = Path(os.path.abspath(switches.root_dir or Globals.root))
|
Globals.root = Path(os.path.abspath(switches.root_dir or Globals.root))
|
||||||
Globals.try_patchmatch = switches.patchmatch
|
Globals.try_patchmatch = switches.patchmatch
|
||||||
|
|
||||||
@@ -198,13 +197,14 @@ class Args(object):
|
|||||||
initfile = os.path.expanduser(os.path.join(Globals.root, Globals.initfile))
|
initfile = os.path.expanduser(os.path.join(Globals.root, Globals.initfile))
|
||||||
legacyinit = os.path.expanduser("~/.invokeai")
|
legacyinit = os.path.expanduser("~/.invokeai")
|
||||||
if os.path.exists(initfile):
|
if os.path.exists(initfile):
|
||||||
logger.info(
|
print(
|
||||||
f"Initialization file {initfile} found. Loading...",
|
f">> Initialization file {initfile} found. Loading...",
|
||||||
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
sysargs.insert(0, f"@{initfile}")
|
sysargs.insert(0, f"@{initfile}")
|
||||||
elif os.path.exists(legacyinit):
|
elif os.path.exists(legacyinit):
|
||||||
logger.warning(
|
print(
|
||||||
f"Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
|
f">> WARNING: Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
|
||||||
)
|
)
|
||||||
sysargs.insert(0, f"@{legacyinit}")
|
sysargs.insert(0, f"@{legacyinit}")
|
||||||
Globals.log_tokenization = self._arg_parser.parse_args(
|
Globals.log_tokenization = self._arg_parser.parse_args(
|
||||||
@@ -214,7 +214,7 @@ class Args(object):
|
|||||||
self._arg_switches = self._arg_parser.parse_args(sysargs)
|
self._arg_switches = self._arg_parser.parse_args(sysargs)
|
||||||
return self._arg_switches
|
return self._arg_switches
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"An exception has occurred: {e}")
|
print(f"An exception has occurred: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def parse_cmd(self, cmd_string):
|
def parse_cmd(self, cmd_string):
|
||||||
@@ -561,7 +561,7 @@ class Args(object):
|
|||||||
"--autoimport",
|
"--autoimport",
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help="(DEPRECATED - NONFUNCTIONAL). Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
|
help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--autoconvert",
|
"--autoconvert",
|
||||||
@@ -1154,7 +1154,7 @@ class Args(object):
|
|||||||
|
|
||||||
|
|
||||||
def format_metadata(**kwargs):
|
def format_metadata(**kwargs):
|
||||||
logger.warning("format_metadata() is deprecated. Please use metadata_dumps()")
|
print("format_metadata() is deprecated. Please use metadata_dumps()")
|
||||||
return metadata_dumps(kwargs)
|
return metadata_dumps(kwargs)
|
||||||
|
|
||||||
|
|
||||||
@@ -1326,7 +1326,7 @@ def metadata_loads(metadata) -> list:
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error("Could not read metadata")
|
print(">> could not read metadata", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ def install_requested_models(
|
|||||||
scan_directory: Path = None,
|
scan_directory: Path = None,
|
||||||
external_models: List[str] = None,
|
external_models: List[str] = None,
|
||||||
scan_at_startup: bool = False,
|
scan_at_startup: bool = False,
|
||||||
|
convert_to_diffusers: bool = False,
|
||||||
precision: str = "float16",
|
precision: str = "float16",
|
||||||
purge_deleted: bool = False,
|
purge_deleted: bool = False,
|
||||||
config_file_path: Path = None,
|
config_file_path: Path = None,
|
||||||
@@ -112,6 +113,7 @@ def install_requested_models(
|
|||||||
try:
|
try:
|
||||||
model_manager.heuristic_import(
|
model_manager.heuristic_import(
|
||||||
path_url_or_repo,
|
path_url_or_repo,
|
||||||
|
convert=convert_to_diffusers,
|
||||||
commit_to_conf=config_file_path,
|
commit_to_conf=config_file_path,
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@@ -120,7 +122,7 @@ def install_requested_models(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if scan_at_startup and scan_directory.is_dir():
|
if scan_at_startup and scan_directory.is_dir():
|
||||||
argument = "--autoconvert"
|
argument = "--autoconvert" if convert_to_diffusers else "--autoimport"
|
||||||
initfile = Path(Globals.root, Globals.initfile)
|
initfile = Path(Globals.root, Globals.initfile)
|
||||||
replacement = Path(Globals.root, f"{Globals.initfile}.new")
|
replacement = Path(Globals.root, f"{Globals.initfile}.new")
|
||||||
directory = str(scan_directory).replace("\\", "/")
|
directory = str(scan_directory).replace("\\", "/")
|
||||||
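The hunk above chooses between "--autoconvert" and "--autoimport" based on convert_to_diffusers and then records the scan directory in the init file. A rough, hedged sketch of the intent; the flag, paths, and the direct append are illustrative (the visible code stages a "{initfile}.new" replacement file rather than appending in place):

from pathlib import Path

convert_to_diffusers = False                  # illustrative flag
scan_directory = Path("models/to-import")     # hypothetical directory
initfile = Path("invokeai.init")              # hypothetical init file

argument = "--autoconvert" if convert_to_diffusers else "--autoimport"
directory = str(scan_directory).replace("\\", "/")
with open(initfile, "a") as f:
    f.write(f"{argument} {directory}\n")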
|
|||||||
@@ -27,7 +27,6 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from .args import metadata_from_png
|
from .args import metadata_from_png
|
||||||
from .generator import infill_methods
|
from .generator import infill_methods
|
||||||
from .globals import Globals, global_cache_dir
|
from .globals import Globals, global_cache_dir
|
||||||
@@ -196,12 +195,12 @@ class Generate:
|
|||||||
# device to Generate(). However the device was then ignored, so
|
# device to Generate(). However the device was then ignored, so
|
||||||
# it wasn't actually doing anything. This logic could be reinstated.
|
# it wasn't actually doing anything. This logic could be reinstated.
|
||||||
self.device = torch.device(choose_torch_device())
|
self.device = torch.device(choose_torch_device())
|
||||||
logger.info(f"Using device_type {self.device.type}")
|
print(f">> Using device_type {self.device.type}")
|
||||||
if full_precision:
|
if full_precision:
|
||||||
if self.precision != "auto":
|
if self.precision != "auto":
|
||||||
raise ValueError("Remove --full_precision / -F if using --precision")
|
raise ValueError("Remove --full_precision / -F if using --precision")
|
||||||
logger.warning("Please remove deprecated --full_precision / -F")
|
print("Please remove deprecated --full_precision / -F")
|
||||||
logger.warning("If auto config does not work you can use --precision=float32")
|
print("If auto config does not work you can use --precision=float32")
|
||||||
self.precision = "float32"
|
self.precision = "float32"
|
||||||
if self.precision == "auto":
|
if self.precision == "auto":
|
||||||
self.precision = choose_precision(self.device)
|
self.precision = choose_precision(self.device)
|
||||||
@@ -209,13 +208,13 @@ class Generate:
|
|||||||
|
|
||||||
if is_xformers_available():
|
if is_xformers_available():
|
||||||
if torch.cuda.is_available() and not Globals.disable_xformers:
|
if torch.cuda.is_available() and not Globals.disable_xformers:
|
||||||
logger.info("xformers memory-efficient attention is available and enabled")
|
print(">> xformers memory-efficient attention is available and enabled")
|
||||||
else:
|
else:
|
||||||
logger.info(
|
print(
|
||||||
"xformers memory-efficient attention is available but disabled"
|
">> xformers memory-efficient attention is available but disabled"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("xformers not installed")
|
print(">> xformers not installed")
|
||||||
|
|
||||||
# model caching system for fast switching
|
# model caching system for fast switching
|
||||||
self.model_manager = ModelManager(
|
self.model_manager = ModelManager(
|
||||||
@@ -230,8 +229,8 @@ class Generate:
|
|||||||
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
|
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
|
||||||
model = model or fallback
|
model = model or fallback
|
||||||
if not self.model_manager.valid_model(model):
|
if not self.model_manager.valid_model(model):
|
||||||
logger.warning(
|
print(
|
||||||
f'"{model}" is not a known model name; falling back to {fallback}.'
|
f'** "{model}" is not a known model name; falling back to {fallback}.'
|
||||||
)
|
)
|
||||||
model = None
|
model = None
|
||||||
self.model_name = model or fallback
|
self.model_name = model or fallback
|
||||||
@@ -247,10 +246,10 @@ class Generate:
|
|||||||
|
|
||||||
# load safety checker if requested
|
# load safety checker if requested
|
||||||
if safety_checker:
|
if safety_checker:
|
||||||
logger.info("Initializing NSFW checker")
|
print(">> Initializing NSFW checker")
|
||||||
self.safety_checker = SafetyChecker(self.device)
|
self.safety_checker = SafetyChecker(self.device)
|
||||||
else:
|
else:
|
||||||
logger.info("NSFW checker is disabled")
|
print(">> NSFW checker is disabled")
|
||||||
|
|
||||||
def prompt2png(self, prompt, outdir, **kwargs):
|
def prompt2png(self, prompt, outdir, **kwargs):
|
||||||
"""
|
"""
|
||||||
@@ -568,7 +567,7 @@ class Generate:
|
|||||||
self.clear_cuda_cache()
|
self.clear_cuda_cache()
|
||||||
|
|
||||||
if catch_interrupts:
|
if catch_interrupts:
|
||||||
logger.warning("Interrupted** Partial results will be returned.")
|
print("**Interrupted** Partial results will be returned.")
|
||||||
else:
|
else:
|
||||||
raise KeyboardInterrupt
|
raise KeyboardInterrupt
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
@@ -576,11 +575,11 @@ class Generate:
|
|||||||
self.clear_cuda_cache()
|
self.clear_cuda_cache()
|
||||||
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
logger.info("Could not generate image.")
|
print(">> Could not generate image.")
|
||||||
|
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
logger.info("Usage stats:")
|
print("\n>> Usage stats:")
|
||||||
logger.info(f"{len(results)} image(s) generated in "+"%4.2fs" % (toc - tic))
|
print(f">> {len(results)} image(s) generated in", "%4.2fs" % (toc - tic))
|
||||||
self.print_cuda_stats()
|
self.print_cuda_stats()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -610,16 +609,16 @@ class Generate:
|
|||||||
def print_cuda_stats(self):
|
def print_cuda_stats(self):
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
self.gather_cuda_stats()
|
self.gather_cuda_stats()
|
||||||
logger.info(
|
print(
|
||||||
"Max VRAM used for this generation: "+
|
">> Max VRAM used for this generation:",
|
||||||
"%4.2fG. " % (self.max_memory_allocated / 1e9)+
|
"%4.2fG." % (self.max_memory_allocated / 1e9),
|
||||||
"Current VRAM utilization: "+
|
"Current VRAM utilization:",
|
||||||
"%4.2fG" % (self.memory_allocated / 1e9)
|
"%4.2fG" % (self.memory_allocated / 1e9),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
print(
|
||||||
"Max VRAM used since script start: " +
|
">> Max VRAM used since script start: ",
|
||||||
"%4.2fG" % (self.session_peakmem / 1e9)
|
"%4.2fG" % (self.session_peakmem / 1e9),
|
||||||
)
|
)
|
||||||
|
|
||||||
# this needs to be generalized to all sorts of postprocessors, which should be wrapped
|
# this needs to be generalized to all sorts of postprocessors, which should be wrapped
|
||||||
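The print_cuda_stats hunk above formats values from torch's CUDA memory counters. A minimal sketch of where those numbers come from (only meaningful on a CUDA machine):

import torch

if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated() / 1e9       # current VRAM in GB
    peak = torch.cuda.max_memory_allocated() / 1e9        # high-water mark since last reset
    print(f"Current VRAM utilization: {allocated:4.2f}G, max used: {peak:4.2f}G")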
@@ -648,7 +647,7 @@ class Generate:
|
|||||||
seed = random.randrange(0, np.iinfo(np.uint32).max)
|
seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||||
|
|
||||||
prompt = opt.prompt or args.prompt or ""
|
prompt = opt.prompt or args.prompt or ""
|
||||||
logger.info(f'using seed {seed} and prompt "{prompt}" for {image_path}')
|
print(f'>> using seed {seed} and prompt "{prompt}" for {image_path}')
|
||||||
|
|
||||||
# try to reuse the same filename prefix as the original file.
|
# try to reuse the same filename prefix as the original file.
|
||||||
# we take everything up to the first period
|
# we take everything up to the first period
|
||||||
@@ -697,8 +696,8 @@ class Generate:
|
|||||||
try:
|
try:
|
||||||
extend_instructions[direction] = int(pixels)
|
extend_instructions[direction] = int(pixels)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning(
|
print(
|
||||||
'invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
|
'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
|
||||||
)
|
)
|
||||||
|
|
||||||
opt.seed = seed
|
opt.seed = seed
|
||||||
@@ -721,8 +720,8 @@ class Generate:
|
|||||||
# fetch the metadata from the image
|
# fetch the metadata from the image
|
||||||
generator = self.select_generator(embiggen=True)
|
generator = self.select_generator(embiggen=True)
|
||||||
opt.strength = opt.embiggen_strength or 0.40
|
opt.strength = opt.embiggen_strength or 0.40
|
||||||
logger.info(
|
print(
|
||||||
f"Setting img2img strength to {opt.strength} for happy embiggening"
|
f">> Setting img2img strength to {opt.strength} for happy embiggening"
|
||||||
)
|
)
|
||||||
generator.generate(
|
generator.generate(
|
||||||
prompt,
|
prompt,
|
||||||
@@ -749,12 +748,12 @@ class Generate:
|
|||||||
return restorer.process(opt, args, image_callback=callback, prefix=prefix)
|
return restorer.process(opt, args, image_callback=callback, prefix=prefix)
|
||||||
|
|
||||||
elif tool is None:
|
elif tool is None:
|
||||||
logger.warning(
|
print(
|
||||||
"please provide at least one postprocessing option, such as -G or -U"
|
"* please provide at least one postprocessing option, such as -G or -U"
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
logger.warning(f"postprocessing tool {tool} is not yet supported")
|
print(f"* postprocessing tool {tool} is not yet supported")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def select_generator(
|
def select_generator(
|
||||||
@@ -798,8 +797,8 @@ class Generate:
|
|||||||
image = self._load_img(img)
|
image = self._load_img(img)
|
||||||
|
|
||||||
if image.width < self.width and image.height < self.height:
|
if image.width < self.width and image.height < self.height:
|
||||||
logger.warning(
|
print(
|
||||||
f"img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
|
f">> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
|
||||||
)
|
)
|
||||||
|
|
||||||
# if image has a transparent area and no mask was provided, then try to generate mask
|
# if image has a transparent area and no mask was provided, then try to generate mask
|
||||||
@@ -810,8 +809,8 @@ class Generate:
|
|||||||
if (image.width * image.height) > (
|
if (image.width * image.height) > (
|
||||||
self.width * self.height
|
self.width * self.height
|
||||||
) and self.size_matters:
|
) and self.size_matters:
|
||||||
logger.info(
|
print(
|
||||||
"This input is larger than your defaults. If you run out of memory, please use a smaller image."
|
">> This input is larger than your defaults. If you run out of memory, please use a smaller image."
|
||||||
)
|
)
|
||||||
self.size_matters = False
|
self.size_matters = False
|
||||||
|
|
||||||
@@ -892,11 +891,11 @@ class Generate:
|
|||||||
try:
|
try:
|
||||||
model_data = cache.get_model(model_name)
|
model_data = cache.get_model(model_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"model {model_name} could not be loaded: {str(e)}")
|
print(f"** model {model_name} could not be loaded: {str(e)}")
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
if previous_model_name is None:
|
if previous_model_name is None:
|
||||||
raise e
|
raise e
|
||||||
logger.warning("trying to reload previous model")
|
print("** trying to reload previous model")
|
||||||
model_data = cache.get_model(previous_model_name) # load previous
|
model_data = cache.get_model(previous_model_name) # load previous
|
||||||
if model_data is None:
|
if model_data is None:
|
||||||
raise e
|
raise e
|
||||||
@@ -963,15 +962,15 @@ class Generate:
|
|||||||
if self.gfpgan is not None or self.codeformer is not None:
|
if self.gfpgan is not None or self.codeformer is not None:
|
||||||
if facetool == "gfpgan":
|
if facetool == "gfpgan":
|
||||||
if self.gfpgan is None:
|
if self.gfpgan is None:
|
||||||
logger.info(
|
print(
|
||||||
"GFPGAN not found. Face restoration is disabled."
|
">> GFPGAN not found. Face restoration is disabled."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
image = self.gfpgan.process(image, strength, seed)
|
image = self.gfpgan.process(image, strength, seed)
|
||||||
if facetool == "codeformer":
|
if facetool == "codeformer":
|
||||||
if self.codeformer is None:
|
if self.codeformer is None:
|
||||||
logger.info(
|
print(
|
||||||
"CodeFormer not found. Face restoration is disabled."
|
">> CodeFormer not found. Face restoration is disabled."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
cf_device = (
|
cf_device = (
|
||||||
@@ -985,7 +984,7 @@ class Generate:
|
|||||||
fidelity=codeformer_fidelity,
|
fidelity=codeformer_fidelity,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Face Restoration is disabled.")
|
print(">> Face Restoration is disabled.")
|
||||||
if upscale is not None:
|
if upscale is not None:
|
||||||
if self.esrgan is not None:
|
if self.esrgan is not None:
|
||||||
if len(upscale) < 2:
|
if len(upscale) < 2:
|
||||||
@@ -998,10 +997,10 @@ class Generate:
|
|||||||
denoise_str=upscale_denoise_str,
|
denoise_str=upscale_denoise_str,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("ESRGAN is disabled. Image not upscaled.")
|
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info(
|
print(
|
||||||
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if image_callback is not None:
|
if image_callback is not None:
|
||||||
@@ -1067,17 +1066,17 @@ class Generate:
|
|||||||
if self.sampler_name in scheduler_map:
|
if self.sampler_name in scheduler_map:
|
||||||
sampler_class = scheduler_map[self.sampler_name]
|
sampler_class = scheduler_map[self.sampler_name]
|
||||||
msg = (
|
msg = (
|
||||||
f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
|
f">> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
|
||||||
)
|
)
|
||||||
self.sampler = sampler_class.from_config(self.model.scheduler.config)
|
self.sampler = sampler_class.from_config(self.model.scheduler.config)
|
||||||
else:
|
else:
|
||||||
msg = (
|
msg = (
|
||||||
f" Unsupported Sampler: {self.sampler_name} "+
|
f">> Unsupported Sampler: {self.sampler_name} "
|
||||||
f"Defaulting to {default}"
|
f"Defaulting to {default}"
|
||||||
)
|
)
|
||||||
self.sampler = default
|
self.sampler = default
|
||||||
|
|
||||||
logger.info(msg)
|
print(msg)
|
||||||
|
|
||||||
if not hasattr(self.sampler, "uses_inpainting_model"):
|
if not hasattr(self.sampler, "uses_inpainting_model"):
|
||||||
# FIXME: terrible kludge!
|
# FIXME: terrible kludge!
|
||||||
@@ -1086,17 +1085,17 @@ class Generate:
|
|||||||
def _load_img(self, img) -> Image:
|
def _load_img(self, img) -> Image:
|
||||||
if isinstance(img, Image.Image):
|
if isinstance(img, Image.Image):
|
||||||
image = img
|
image = img
|
||||||
logger.info(f"using provided input image of size {image.width}x{image.height}")
|
print(f">> using provided input image of size {image.width}x{image.height}")
|
||||||
elif isinstance(img, str):
|
elif isinstance(img, str):
|
||||||
assert os.path.exists(img), f"{img}: File not found"
|
assert os.path.exists(img), f">> {img}: File not found"
|
||||||
|
|
||||||
image = Image.open(img)
|
image = Image.open(img)
|
||||||
logger.info(
|
print(
|
||||||
f"loaded input image of size {image.width}x{image.height} from {img}"
|
f">> loaded input image of size {image.width}x{image.height} from {img}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
image = Image.open(img)
|
image = Image.open(img)
|
||||||
logger.info(f"loaded input image of size {image.width}x{image.height}")
|
print(f">> loaded input image of size {image.width}x{image.height}")
|
||||||
image = ImageOps.exif_transpose(image)
|
image = ImageOps.exif_transpose(image)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
@@ -1184,14 +1183,14 @@ class Generate:
|
|||||||
|
|
||||||
def _transparency_check_and_warning(self, image, mask, force_outpaint=False):
|
def _transparency_check_and_warning(self, image, mask, force_outpaint=False):
|
||||||
if not mask:
|
if not mask:
|
||||||
logger.info(
|
print(
|
||||||
"Initial image has transparent areas. Will inpaint in these regions."
|
">> Initial image has transparent areas. Will inpaint in these regions."
|
||||||
)
|
)
|
||||||
if (not force_outpaint) and self._check_for_erasure(image):
|
if (not force_outpaint) and self._check_for_erasure(image):
|
||||||
logger.info(
|
print(
|
||||||
"Colors underneath the transparent region seem to have been erased.\n" +
|
">> WARNING: Colors underneath the transparent region seem to have been erased.\n",
|
||||||
"Inpainting will be suboptimal. Please preserve the colors when making\n" +
|
">> Inpainting will be suboptimal. Please preserve the colors when making\n",
|
||||||
"a transparency mask, or provide mask explicitly using --init_mask (-M)."
|
">> a transparency mask, or provide mask explicitly using --init_mask (-M).",
|
||||||
)
|
)
|
||||||
|
|
||||||
def _squeeze_image(self, image):
|
def _squeeze_image(self, image):
|
||||||
@@ -1202,11 +1201,11 @@ class Generate:
|
|||||||
|
|
||||||
def _fit_image(self, image, max_dimensions):
|
def _fit_image(self, image, max_dimensions):
|
||||||
w, h = max_dimensions
|
w, h = max_dimensions
|
||||||
logger.info(f"image will be resized to fit inside a box {w}x{h} in size.")
|
print(f">> image will be resized to fit inside a box {w}x{h} in size.")
|
||||||
# note that InitImageResizer does the multiple of 64 truncation internally
|
# note that InitImageResizer does the multiple of 64 truncation internally
|
||||||
image = InitImageResizer(image).resize(width=w, height=h)
|
image = InitImageResizer(image).resize(width=w, height=h)
|
||||||
logger.info(
|
print(
|
||||||
f"after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
|
f">> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
|
||||||
)
|
)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
@@ -1217,8 +1216,8 @@ class Generate:
|
|||||||
) # resize to integer multiple of 64
|
) # resize to integer multiple of 64
|
||||||
if h != height or w != width:
|
if h != height or w != width:
|
||||||
if log:
|
if log:
|
||||||
logger.info(
|
print(
|
||||||
f"Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
|
f">> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
|
||||||
)
|
)
|
||||||
height = h
|
height = h
|
||||||
width = w
|
width = w
|
||||||
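The auto-resize message above fires when the requested dimensions are not multiples of 64: the VAE downsamples by 8 and the UNet downsamples the latent further at its deeper levels, so image dimensions are truncated to multiples of 64. A minimal sketch of that rounding (the helper name is hypothetical; the real logic lives in the surrounding method):

def round_to_multiple_of_64(width: int, height: int) -> tuple[int, int]:
    return (width - width % 64, height - height % 64)

assert round_to_multiple_of_64(513, 769) == (512, 768)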
|
|||||||
@@ -25,7 +25,6 @@ from typing import Callable, List, Iterator, Optional, Type
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from ..image_util import configure_model_padding
|
from ..image_util import configure_model_padding
|
||||||
from ..util.util import rand_perlin_2d
|
from ..util.util import rand_perlin_2d
|
||||||
from ..safety_checker import SafetyChecker
|
from ..safety_checker import SafetyChecker
|
||||||
@@ -373,7 +372,7 @@ class Generator:
|
|||||||
try:
|
try:
|
||||||
x_T = self.get_noise(width, height)
|
x_T = self.get_noise(width, height)
|
||||||
except:
|
except:
|
||||||
logger.error("An error occurred while getting initial noise")
|
print("** An error occurred while getting initial noise **")
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
||||||
@@ -608,7 +607,7 @@ class Generator:
|
|||||||
image = self.sample_to_image(sample)
|
image = self.sample_to_image(sample)
|
||||||
dirname = os.path.dirname(filepath) or "."
|
dirname = os.path.dirname(filepath) or "."
|
||||||
if not os.path.exists(dirname):
|
if not os.path.exists(dirname):
|
||||||
logger.info(f"creating directory {dirname}")
|
print(f"** creating directory {dirname}")
|
||||||
os.makedirs(dirname, exist_ok=True)
|
os.makedirs(dirname, exist_ok=True)
|
||||||
image.save(filepath, "PNG")
|
image.save(filepath, "PNG")
|
||||||
|
|
||||||
|
|||||||
@@ -8,11 +8,10 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
from .base import Generator
|
from .base import Generator
|
||||||
from .img2img import Img2Img
|
from .img2img import Img2Img
|
||||||
|
|
||||||
|
|
||||||
class Embiggen(Generator):
|
class Embiggen(Generator):
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision):
|
||||||
super().__init__(model, precision)
|
super().__init__(model, precision)
|
||||||
@@ -73,22 +72,22 @@ class Embiggen(Generator):
|
|||||||
embiggen = [1.0] # If not specified, assume no scaling
|
embiggen = [1.0] # If not specified, assume no scaling
|
||||||
elif embiggen[0] < 0:
|
elif embiggen[0] < 0:
|
||||||
embiggen[0] = 1.0
|
embiggen[0] = 1.0
|
||||||
logger.warning(
|
print(
|
||||||
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
">> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
|
||||||
)
|
)
|
||||||
if len(embiggen) < 2:
|
if len(embiggen) < 2:
|
||||||
embiggen.append(0.75)
|
embiggen.append(0.75)
|
||||||
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
elif embiggen[1] > 1.0 or embiggen[1] < 0:
|
||||||
embiggen[1] = 0.75
|
embiggen[1] = 0.75
|
||||||
logger.warning(
|
print(
|
||||||
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
">> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
|
||||||
)
|
)
|
||||||
if len(embiggen) < 3:
|
if len(embiggen) < 3:
|
||||||
embiggen.append(0.25)
|
embiggen.append(0.25)
|
||||||
elif embiggen[2] < 0:
|
elif embiggen[2] < 0:
|
||||||
embiggen[2] = 0.25
|
embiggen[2] = 0.25
|
||||||
logger.warning(
|
print(
|
||||||
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
">> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
|
# Convert tiles from their user-freindly count-from-one to count-from-zero, because we need to do modulo math
|
||||||
@@ -98,8 +97,8 @@ class Embiggen(Generator):
|
|||||||
embiggen_tiles.sort()
|
embiggen_tiles.sort()
|
||||||
|
|
||||||
if strength >= 0.5:
|
if strength >= 0.5:
|
||||||
logger.warning(
|
print(
|
||||||
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
f"* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prep img2img generator, since we wrap over it
|
# Prep img2img generator, since we wrap over it
|
||||||
@@ -122,8 +121,8 @@ class Embiggen(Generator):
|
|||||||
from ..restoration.realesrgan import ESRGAN
|
from ..restoration.realesrgan import ESRGAN
|
||||||
|
|
||||||
esrgan = ESRGAN()
|
esrgan = ESRGAN()
|
||||||
logger.info(
|
print(
|
||||||
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
f">> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
|
||||||
)
|
)
|
||||||
if embiggen[0] > 2:
|
if embiggen[0] > 2:
|
||||||
initsuperimage = esrgan.process(
|
initsuperimage = esrgan.process(
|
||||||
@@ -313,10 +312,10 @@ class Embiggen(Generator):
|
|||||||
def make_image():
|
def make_image():
|
||||||
# Make main tiles -------------------------------------------------
|
# Make main tiles -------------------------------------------------
|
||||||
if embiggen_tiles:
|
if embiggen_tiles:
|
||||||
logger.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
|
print(f">> Making {len(embiggen_tiles)} Embiggen tiles...")
|
||||||
else:
|
else:
|
||||||
logger.info(
|
print(
|
||||||
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
f">> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
|
||||||
)
|
)
|
||||||
|
|
||||||
emb_tile_store = []
|
emb_tile_store = []
|
||||||
@@ -362,11 +361,11 @@ class Embiggen(Generator):
|
|||||||
# newinitimage.save(newinitimagepath)
|
# newinitimage.save(newinitimagepath)
|
||||||
|
|
||||||
if embiggen_tiles:
|
if embiggen_tiles:
|
||||||
logger.debug(
|
print(
|
||||||
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
|
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
print(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
|
||||||
|
|
||||||
# create a torch tensor from an Image
|
# create a torch tensor from an Image
|
||||||
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
|
||||||
@@ -548,8 +547,8 @@ class Embiggen(Generator):
|
|||||||
# Layer tile onto final image
|
# Layer tile onto final image
|
||||||
outputsuperimage.alpha_composite(intileimage, (left, top))
|
outputsuperimage.alpha_composite(intileimage, (left, top))
|
||||||
else:
|
else:
|
||||||
logger.error(
|
print(
|
||||||
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
"Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
|
||||||
)
|
)
|
||||||
|
|
||||||
# after internal loops and patching up return Embiggen image
|
# after internal loops and patching up return Embiggen image
|
||||||
|
|||||||
@@ -14,8 +14,6 @@ from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeli
|
|||||||
from ..stable_diffusion.diffusers_pipeline import ConditioningData
|
from ..stable_diffusion.diffusers_pipeline import ConditioningData
|
||||||
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
|
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
class Txt2Img2Img(Generator):
|
class Txt2Img2Img(Generator):
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision):
|
||||||
super().__init__(model, precision)
|
super().__init__(model, precision)
|
||||||
@@ -79,8 +77,8 @@ class Txt2Img2Img(Generator):
|
|||||||
# the message below is accurate.
|
# the message below is accurate.
|
||||||
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
|
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
|
||||||
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
|
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
|
||||||
logger.info(
|
print(
|
||||||
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
|
||||||
)
|
)
|
||||||
|
|
||||||
# resizing
|
# resizing
|
||||||
|
|||||||
@@ -5,9 +5,10 @@ wraps the actual patchmatch object. It respects the global
|
|||||||
be suppressed or deferred
|
be suppressed or deferred
|
||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
|
|
||||||
class PatchMatch:
|
class PatchMatch:
|
||||||
"""
|
"""
|
||||||
Thin class wrapper around the patchmatch function.
|
Thin class wrapper around the patchmatch function.
|
||||||
@@ -27,12 +28,12 @@ class PatchMatch:
|
|||||||
from patchmatch import patch_match as pm
|
from patchmatch import patch_match as pm
|
||||||
|
|
||||||
if pm.patchmatch_available:
|
if pm.patchmatch_available:
|
||||||
logger.info("Patchmatch initialized")
|
print(">> Patchmatch initialized")
|
||||||
else:
|
else:
|
||||||
logger.info("Patchmatch not loaded (nonfatal)")
|
print(">> Patchmatch not loaded (nonfatal)")
|
||||||
self.patch_match = pm
|
self.patch_match = pm
|
||||||
else:
|
else:
|
||||||
logger.info("Patchmatch loading disabled")
|
print(">> Patchmatch loading disabled")
|
||||||
self.tried_load = True
|
self.tried_load = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -30,9 +30,9 @@ work fine.
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
|
from torchvision import transforms
|
||||||
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
from transformers import AutoProcessor, CLIPSegForImageSegmentation
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.globals import global_cache_dir
|
from invokeai.backend.globals import global_cache_dir
|
||||||
|
|
||||||
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
||||||
@@ -83,7 +83,7 @@ class Txt2Mask(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, device="cpu", refined=False):
|
def __init__(self, device="cpu", refined=False):
|
||||||
logger.info("Initializing clipseg model for text to mask inference")
|
print(">> Initializing clipseg model for text to mask inference")
|
||||||
|
|
||||||
# BUG: we are not doing anything with the device option at this time
|
# BUG: we are not doing anything with the device option at this time
|
||||||
self.device = device
|
self.device = device
|
||||||
@@ -101,6 +101,18 @@ class Txt2Mask(object):
|
|||||||
provided image and returns a SegmentedGrayscale object in which the brighter
|
provided image and returns a SegmentedGrayscale object in which the brighter
|
||||||
pixels indicate where the object is inferred to be.
|
pixels indicate where the object is inferred to be.
|
||||||
"""
|
"""
|
||||||
|
transform = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(
|
||||||
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||||
|
),
|
||||||
|
transforms.Resize(
|
||||||
|
(CLIPSEG_SIZE, CLIPSEG_SIZE)
|
||||||
|
), # must be multiple of 64...
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
if type(image) is str:
|
if type(image) is str:
|
||||||
image = Image.open(image).convert("RGB")
|
image = Image.open(image).convert("RGB")
|
||||||
|
|
||||||
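The transforms.Compose added above normalizes with the standard ImageNet statistics and resizes to the CLIPSeg input resolution. A stand-alone sketch of applying such a pipeline to a PIL image; CLIPSEG_SIZE is assumed to be 352 here for illustration, the real constant is defined in the module:

import torch
from PIL import Image
from torchvision import transforms

CLIPSEG_SIZE = 352  # assumption for illustration

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)),
    ]
)

image = Image.new("RGB", (512, 512), color="gray")
tensor = transform(image).unsqueeze(0)  # shape: (1, 3, CLIPSEG_SIZE, CLIPSEG_SIZE)
print(tensor.shape)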
|
|||||||
@@ -5,7 +5,5 @@ from .convert_ckpt_to_diffusers import (
     convert_ckpt_to_diffusers,
     load_pipeline_from_original_stable_diffusion_ckpt,
 )
-from .model_manager import ModelManager,SDModelComponent
+from .model_manager import ModelManager
 
-
-
|
|||||||
@@ -25,7 +25,6 @@ from typing import Union
|
|||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.globals import global_cache_dir, global_config_dir
|
from invokeai.backend.globals import global_cache_dir, global_config_dir
|
||||||
|
|
||||||
from .model_manager import ModelManager, SDLegacyType
|
from .model_manager import ModelManager, SDLegacyType
|
||||||
@@ -373,9 +372,9 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
|||||||
unet_key = "model.diffusion_model."
|
unet_key = "model.diffusion_model."
|
||||||
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
||||||
if sum(k.startswith("model_ema") for k in keys) > 100:
|
if sum(k.startswith("model_ema") for k in keys) > 100:
|
||||||
logger.debug(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
|
||||||
if extract_ema:
|
if extract_ema:
|
||||||
logger.debug("Extracting EMA weights (usually better for inference)")
|
print(" | Extracting EMA weights (usually better for inference)")
|
||||||
for key in keys:
|
for key in keys:
|
||||||
if key.startswith("model.diffusion_model"):
|
if key.startswith("model.diffusion_model"):
|
||||||
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
||||||
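The flat_ema_key line above rebuilds the name under which the EMA copy of a UNet weight is stored in a legacy checkpoint: the leading "model" component is dropped and the remaining components are concatenated without dots under the "model_ema." prefix. A quick check of that transformation:

key = "model.diffusion_model.input_blocks.0.0.weight"
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
assert flat_ema_key == "model_ema.diffusion_modelinput_blocks00weight"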
@@ -393,8 +392,8 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
|
|||||||
key
|
key
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
print(
|
||||||
"Extracting only the non-EMA weights (usually better for fine-tuning)"
|
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
|
||||||
)
|
)
|
||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
@@ -1116,7 +1115,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
if "global_step" in checkpoint:
|
if "global_step" in checkpoint:
|
||||||
global_step = checkpoint["global_step"]
|
global_step = checkpoint["global_step"]
|
||||||
else:
|
else:
|
||||||
logger.debug("global_step key not found in model")
|
print(" | global_step key not found in model")
|
||||||
global_step = None
|
global_step = None
|
||||||
|
|
||||||
# sometimes there is a state_dict key and sometimes not
|
# sometimes there is a state_dict key and sometimes not
|
||||||
@@ -1230,15 +1229,15 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
|||||||
# If a replacement VAE path was specified, we'll incorporate that into
|
# If a replacement VAE path was specified, we'll incorporate that into
|
||||||
# the checkpoint model and then convert it
|
# the checkpoint model and then convert it
|
||||||
if vae_path:
|
if vae_path:
|
||||||
logger.debug(f"Converting VAE {vae_path}")
|
print(f" | Converting VAE {vae_path}")
|
||||||
replace_checkpoint_vae(checkpoint,vae_path)
|
replace_checkpoint_vae(checkpoint,vae_path)
|
||||||
# otherwise we use the original VAE, provided that
|
# otherwise we use the original VAE, provided that
|
||||||
# an externally loaded diffusers VAE was not passed
|
# an externally loaded diffusers VAE was not passed
|
||||||
elif not vae:
|
elif not vae:
|
||||||
logger.debug("Using checkpoint model's original VAE")
|
print(" | Using checkpoint model's original VAE")
|
||||||
|
|
||||||
if vae:
|
if vae:
|
||||||
logger.debug("Using replacement diffusers VAE")
|
print(" | Using replacement diffusers VAE")
|
||||||
else: # convert the original or replacement VAE
|
else: # convert the original or replacement VAE
|
||||||
vae_config = create_vae_diffusers_config(
|
vae_config = create_vae_diffusers_config(
|
||||||
original_config, image_size=image_size
|
original_config, image_size=image_size
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""enum
|
"""
|
||||||
Manage a cache of Stable Diffusion model files for fast switching.
|
Manage a cache of Stable Diffusion model files for fast switching.
|
||||||
They are moved between GPU and CPU as necessary. If CPU memory falls
|
They are moved between GPU and CPU as necessary. If CPU memory falls
|
||||||
below a preset minimum, the least recently used model will be
|
below a preset minimum, the least recently used model will be
|
||||||
@@ -15,22 +15,17 @@ import sys
|
|||||||
import textwrap
|
import textwrap
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from enum import Enum, auto
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import move, rmtree
|
from shutil import move, rmtree
|
||||||
from typing import Any, Optional, Union, Callable, types
|
from typing import Any, Optional, Union, Callable
|
||||||
|
|
||||||
import safetensors
|
import safetensors
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
import invokeai.backend.util.logging as logger
|
from diffusers import AutoencoderKL
|
||||||
from diffusers import (
|
from diffusers import logging as dlogging
|
||||||
AutoencoderKL,
|
|
||||||
UNet2DConditionModel,
|
|
||||||
SchedulerMixin,
|
|
||||||
logging as dlogging,
|
|
||||||
)
|
|
||||||
from huggingface_hub import scan_cache_dir
|
from huggingface_hub import scan_cache_dir
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from omegaconf.dictconfig import DictConfig
|
from omegaconf.dictconfig import DictConfig
|
||||||
@@ -38,55 +33,31 @@ from picklescan.scanner import scan_file_path
|
|||||||
|
|
||||||
from invokeai.backend.globals import Globals, global_cache_dir
|
from invokeai.backend.globals import Globals, global_cache_dir
|
||||||
|
|
||||||
from transformers import (
|
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||||
CLIPTextModel,
|
|
||||||
CLIPTokenizer,
|
|
||||||
CLIPFeatureExtractor,
|
|
||||||
)
|
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
|
||||||
StableDiffusionSafetyChecker,
|
|
||||||
)
|
|
||||||
from ..stable_diffusion import (
|
|
||||||
StableDiffusionGeneratorPipeline,
|
|
||||||
)
|
|
||||||
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
||||||
|
|
||||||
|
|
||||||
class SDLegacyType(Enum):
|
class SDLegacyType(Enum):
|
||||||
V1 = auto()
|
V1 = 1
|
||||||
V1_INPAINT = auto()
|
V1_INPAINT = 2
|
||||||
V2 = auto()
|
V2 = 3
|
||||||
V2_e = auto()
|
V2_e = 4
|
||||||
V2_v = auto()
|
V2_v = 5
|
||||||
UNKNOWN = auto()
|
UNKNOWN = 99
|
||||||
|
|
||||||
class SDModelComponent(Enum):
|
|
||||||
vae="vae"
|
|
||||||
text_encoder="text_encoder"
|
|
||||||
tokenizer="tokenizer"
|
|
||||||
unet="unet"
|
|
||||||
scheduler="scheduler"
|
|
||||||
safety_checker="safety_checker"
|
|
||||||
feature_extractor="feature_extractor"
|
|
||||||
|
|
||||||
DEFAULT_MAX_MODELS = 2
|
DEFAULT_MAX_MODELS = 2
|
||||||
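One side of the hunk above assigns the SDLegacyType members with enum.auto() while the other hard-codes integers; comparisons behave the same either way, only the concrete values differ. A compact illustration with two stand-in classes:

from enum import Enum, auto

class SDLegacyTypeAuto(Enum):
    V1 = auto()
    V1_INPAINT = auto()
    V2 = auto()
    UNKNOWN = auto()

class SDLegacyTypeExplicit(Enum):
    V1 = 1
    V1_INPAINT = 2
    V2 = 3
    UNKNOWN = 99

assert SDLegacyTypeAuto.V1.value == 1            # auto() numbers members from 1
assert SDLegacyTypeExplicit.UNKNOWN.value == 99  # explicit values can leave gaps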
|
|
||||||
class ModelManager(object):
|
class ModelManager(object):
|
||||||
"""
|
'''
|
||||||
Model manager handles loading, caching, importing, deleting, converting, and editing models.
|
Model manager handles loading, caching, importing, deleting, converting, and editing models.
|
||||||
"""
|
'''
|
||||||
|
|
||||||
logger: types.ModuleType = logger
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: OmegaConf | Path,
|
config: OmegaConf|Path,
|
||||||
device_type: torch.device = CUDA_DEVICE,
|
device_type: torch.device = CUDA_DEVICE,
|
||||||
precision: str = "float16",
|
precision: str = "float16",
|
||||||
max_loaded_models=DEFAULT_MAX_MODELS,
|
max_loaded_models=DEFAULT_MAX_MODELS,
|
||||||
sequential_offload=False,
|
sequential_offload=False,
|
||||||
embedding_path: Path = None,
|
embedding_path: Path=None,
|
||||||
logger: types.ModuleType = logger,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize with the path to the models.yaml config file or
|
Initialize with the path to the models.yaml config file or
|
||||||
@@ -108,7 +79,6 @@ class ModelManager(object):
|
|||||||
self.current_model = None
|
self.current_model = None
|
||||||
self.sequential_offload = sequential_offload
|
self.sequential_offload = sequential_offload
|
||||||
self.embedding_path = embedding_path
|
self.embedding_path = embedding_path
|
||||||
self.logger = logger
|
|
||||||
|
|
||||||
def valid_model(self, model_name: str) -> bool:
|
def valid_model(self, model_name: str) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -117,28 +87,18 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
return model_name in self.config
|
return model_name in self.config
|
||||||
|
|
||||||
def get_model(self, model_name: str = None) -> dict:
|
def get_model(self, model_name: str=None)->dict:
|
||||||
"""Given a model named identified in models.yaml, return a dict
|
"""
|
||||||
containing the model object and some of its key features. If
|
Given a model named identified in models.yaml, return
|
||||||
in RAM will load into GPU VRAM. If on disk, will load from
|
the model object. If in RAM will load into GPU VRAM.
|
||||||
there.
|
If on disk, will load from there.
|
||||||
The dict has the following keys:
|
|
||||||
'model': The StableDiffusionGeneratorPipeline object
|
|
||||||
'model_name': The name of the model in models.yaml
|
|
||||||
'width': The width of images trained by this model
|
|
||||||
'height': The height of images trained by this model
|
|
||||||
'hash': A unique hash of this model's files on disk.
|
|
||||||
"""
|
"""
|
||||||
if not model_name:
|
if not model_name:
|
||||||
return (
|
return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model())
|
||||||
self.get_model(self.current_model)
|
|
||||||
if self.current_model
|
|
||||||
else self.get_model(self.default_model())
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.valid_model(model_name):
|
if not self.valid_model(model_name):
|
||||||
self.logger.error(
|
print(
|
||||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
|
||||||
)
|
)
|
||||||
return self.current_model
|
return self.current_model
|
||||||
|
|
||||||
@@ -149,7 +109,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
if model_name in self.models:
|
if model_name in self.models:
|
||||||
requested_model = self.models[model_name]["model"]
|
requested_model = self.models[model_name]["model"]
|
||||||
self.logger.info(f"Retrieving model {model_name} from system RAM cache")
|
print(f">> Retrieving model {model_name} from system RAM cache")
|
||||||
requested_model.ready()
|
requested_model.ready()
|
||||||
width = self.models[model_name]["width"]
|
width = self.models[model_name]["width"]
|
||||||
height = self.models[model_name]["height"]
|
height = self.models[model_name]["height"]
|
||||||
@@ -175,81 +135,6 @@ class ModelManager(object):
|
|||||||
"hash": hash,
|
"hash": hash,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_model_vae(self, model_name: str=None)->AutoencoderKL:
|
|
||||||
"""Given a model name identified in models.yaml, load the model into
|
|
||||||
GPU if necessary and return its assigned VAE as an
|
|
||||||
AutoencoderKL object. If no model name is provided, return the
|
|
||||||
vae from the model currently in the GPU.
|
|
||||||
"""
|
|
||||||
return self._get_sub_model(model_name, SDModelComponent.vae)
|
|
||||||
|
|
||||||
def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
|
|
||||||
"""Given a model name identified in models.yaml, load the model into
|
|
||||||
GPU if necessary and return its assigned CLIPTokenizer. If no
|
|
||||||
model name is provided, return the tokenizer from the model
|
|
||||||
currently in the GPU.
|
|
||||||
"""
|
|
||||||
return self._get_sub_model(model_name, SDModelComponent.tokenizer)
|
|
||||||
|
|
||||||
def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
|
|
||||||
"""Given a model name identified in models.yaml, load the model into
|
|
||||||
GPU if necessary and return its assigned UNet2DConditionModel. If no model
|
|
||||||
name is provided, return the UNet from the model
|
|
||||||
currently in the GPU.
|
|
||||||
"""
|
|
||||||
return self._get_sub_model(model_name, SDModelComponent.unet)
|
|
||||||
|
|
||||||
def get_model_text_encoder(self, model_name: str=None)->CLIPTextModel:
|
|
||||||
"""Given a model name identified in models.yaml, load the model into
|
|
||||||
GPU if necessary and return its assigned CLIPTextModel. If no
|
|
||||||
model name is provided, return the text encoder from the model
|
|
||||||
currently in the GPU.
|
|
||||||
"""
|
|
||||||
return self._get_sub_model(model_name, SDModelComponent.text_encoder)
|
|
||||||
|
|
||||||
def get_model_feature_extractor(self, model_name: str=None)->CLIPFeatureExtractor:
|
|
||||||
"""Given a model name identified in models.yaml, load the model into
|
|
||||||
GPU if necessary and return its assigned CLIPFeatureExtractor. If no
|
|
||||||
model name is provided, return the text encoder from the model
|
|
||||||
currently in the GPU.
|
|
||||||
"""
|
|
||||||
return self._get_sub_model(model_name, SDModelComponent.feature_extractor)
|
|
||||||
|
|
||||||
def get_model_scheduler(self, model_name: str=None)->SchedulerMixin:
|
|
||||||
"""Given a model name identified in models.yaml, load the model into
|
|
||||||
GPU if necessary and return its assigned scheduler. If no
|
|
||||||
model name is provided, return the text encoder from the model
|
|
||||||
currently in the GPU.
|
|
||||||
"""
|
|
||||||
return self._get_sub_model(model_name, SDModelComponent.scheduler)
|
|
||||||
|
|
||||||
def _get_sub_model(
|
|
||||||
self,
|
|
||||||
model_name: str=None,
|
|
||||||
model_part: SDModelComponent=SDModelComponent.vae,
|
|
||||||
) -> Union[
|
|
||||||
AutoencoderKL,
|
|
||||||
CLIPTokenizer,
|
|
||||||
CLIPFeatureExtractor,
|
|
||||||
UNet2DConditionModel,
|
|
||||||
CLIPTextModel,
|
|
||||||
StableDiffusionSafetyChecker,
|
|
||||||
]:
|
|
||||||
"""Given a model name identified in models.yaml, and the part of the
|
|
||||||
model you wish to retrieve, return that part. Parts are in an Enum
|
|
||||||
class named SDModelComponent, and consist of:
|
|
||||||
SDModelComponent.vae
|
|
||||||
SDModelComponent.text_encoder
|
|
||||||
SDModelComponent.tokenizer
|
|
||||||
SDModelComponent.unet
|
|
||||||
SDModelComponent.scheduler
|
|
||||||
SDModelComponent.safety_checker
|
|
||||||
SDModelComponent.feature_extractor
|
|
||||||
"""
|
|
||||||
model_dict = self.get_model(model_name)
|
|
||||||
model = model_dict["model"]
|
|
||||||
return getattr(model, model_part.value)
|
|
||||||
|
|
||||||
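A hedged usage sketch of the cache and the new per-component getters added above; the import path, config path, and model name are illustrative assumptions, not taken from the diff:

from pathlib import Path
# Assumed module location; adjust to match the actual tree.
from invokeai.backend.model_management.model_manager import ModelManager

manager = ModelManager(config=Path("configs/models.yaml"))  # hypothetical config path

# get_model() returns a dict, per the docstring above.
info = manager.get_model("stable-diffusion-1.5")             # hypothetical model name
pipeline = info["model"]                                      # StableDiffusionGeneratorPipeline
width, height, digest = info["width"], info["height"], info["hash"]

# The new helpers pull a single component out of the loaded pipeline:
vae = manager.get_model_vae("stable-diffusion-1.5")
tokenizer = manager.get_model_tokenizer("stable-diffusion-1.5")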
def default_model(self) -> str | None:
|
def default_model(self) -> str | None:
|
||||||
"""
|
"""
|
||||||
Returns the name of the default model, or None
|
Returns the name of the default model, or None
|
||||||
@@ -384,7 +269,7 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
omega = self.config
|
omega = self.config
|
||||||
if model_name not in omega:
|
if model_name not in omega:
|
||||||
self.logger.error(f"Unknown model {model_name}")
|
print(f"** Unknown model {model_name}")
|
||||||
return
|
return
|
||||||
# save these for use in deletion later
|
# save these for use in deletion later
|
||||||
conf = omega[model_name]
|
conf = omega[model_name]
|
||||||
@@ -397,13 +282,13 @@ class ModelManager(object):
|
|||||||
self.stack.remove(model_name)
|
self.stack.remove(model_name)
|
||||||
if delete_files:
|
if delete_files:
|
||||||
if weights:
|
if weights:
|
||||||
self.logger.info(f"Deleting file {weights}")
|
print(f"** Deleting file {weights}")
|
||||||
Path(weights).unlink(missing_ok=True)
|
Path(weights).unlink(missing_ok=True)
|
||||||
elif path:
|
elif path:
|
||||||
self.logger.info(f"Deleting directory {path}")
|
print(f"** Deleting directory {path}")
|
||||||
rmtree(path, ignore_errors=True)
|
rmtree(path, ignore_errors=True)
|
||||||
elif repo_id:
|
elif repo_id:
|
||||||
self.logger.info(f"Deleting the cached model directory for {repo_id}")
|
print(f"** Deleting the cached model directory for {repo_id}")
|
||||||
self._delete_model_from_cache(repo_id)
|
self._delete_model_from_cache(repo_id)
|
||||||
|
|
||||||
def add_model(
|
def add_model(
|
||||||
@@ -444,7 +329,7 @@ class ModelManager(object):
|
|||||||
def _load_model(self, model_name: str):
|
def _load_model(self, model_name: str):
|
||||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||||
if model_name not in self.config:
|
if model_name not in self.config:
|
||||||
self.logger.error(
|
print(
|
||||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
@@ -462,7 +347,7 @@ class ModelManager(object):
|
|||||||
model_format = mconfig.get("format", "ckpt")
|
model_format = mconfig.get("format", "ckpt")
|
||||||
if model_format == "ckpt":
|
if model_format == "ckpt":
|
||||||
weights = mconfig.weights
|
weights = mconfig.weights
|
||||||
self.logger.info(f"Loading {model_name} from {weights}")
|
print(f">> Loading {model_name} from {weights}")
|
||||||
model, width, height, model_hash = self._load_ckpt_model(
|
model, width, height, model_hash = self._load_ckpt_model(
|
||||||
model_name, mconfig
|
model_name, mconfig
|
||||||
)
|
)
|
||||||
@@ -478,15 +363,13 @@ class ModelManager(object):
|
|||||||
|
|
||||||
# usage statistics
|
# usage statistics
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
self.logger.info("Model loaded in " + "%4.2fs" % (toc - tic))
|
print(">> Model loaded in", "%4.2fs" % (toc - tic))
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
self.logger.info(
|
print(
|
||||||
"Max VRAM used to load the model: "+
|
">> Max VRAM used to load the model:",
|
||||||
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)
|
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9),
|
||||||
)
|
"\n>> Current VRAM usage:"
|
||||||
self.logger.info(
|
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
||||||
"Current VRAM usage: "+
|
|
||||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
|
|
||||||
)
|
)
|
||||||
return model, width, height, model_hash
|
return model, width, height, model_hash
|
||||||
|
|
||||||
@@ -494,11 +377,11 @@ class ModelManager(object):
|
|||||||
name_or_path = self.model_name_or_path(mconfig)
|
name_or_path = self.model_name_or_path(mconfig)
|
||||||
using_fp16 = self.precision == "float16"
|
using_fp16 = self.precision == "float16"
|
||||||
|
|
||||||
self.logger.info(f"Loading diffusers model from {name_or_path}")
|
print(f">> Loading diffusers model from {name_or_path}")
|
||||||
if using_fp16:
|
if using_fp16:
|
||||||
self.logger.debug("Using faster float16 precision")
|
print(" | Using faster float16 precision")
|
||||||
else:
|
else:
|
||||||
self.logger.debug("Using more accurate float32 precision")
|
print(" | Using more accurate float32 precision")
|
||||||
|
|
||||||
# TODO: scan weights maybe?
|
# TODO: scan weights maybe?
|
||||||
pipeline_args: dict[str, Any] = dict(
|
pipeline_args: dict[str, Any] = dict(
|
||||||
@@ -530,8 +413,8 @@ class ModelManager(object):
|
|||||||
if str(e).startswith("fp16 is not a valid"):
|
if str(e).startswith("fp16 is not a valid"):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
self.logger.error(
|
print(
|
||||||
f"An unexpected error occurred while downloading the model: {e})"
|
f"** An unexpected error occurred while downloading the model: {e})"
|
||||||
)
|
)
|
||||||
if pipeline:
|
if pipeline:
|
||||||
break
|
break
|
||||||
@@ -549,7 +432,7 @@ class ModelManager(object):
|
|||||||
# square images???
|
# square images???
|
||||||
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
||||||
height = width
|
height = width
|
||||||
self.logger.debug(f"Default image dimensions = {width} x {height}")
|
print(f" | Default image dimensions = {width} x {height}")
|
||||||
|
|
||||||
return pipeline, width, height, model_hash
|
return pipeline, width, height, model_hash
|
||||||
|
|
||||||
@@ -566,23 +449,19 @@ class ModelManager(object):
|
|||||||
weights = os.path.normpath(os.path.join(Globals.root, weights))
|
weights = os.path.normpath(os.path.join(Globals.root, weights))
|
||||||
|
|
||||||
# Convert to diffusers and return a diffusers pipeline
|
# Convert to diffusers and return a diffusers pipeline
|
||||||
self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
|
print(f">> Converting legacy checkpoint {model_name} into a diffusers model...")
|
||||||
|
|
||||||
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self.list_models()[self.current_model]["status"] == "active":
|
if self.list_models()[self.current_model]['status'] == 'active':
|
||||||
self.offload_model(self.current_model)
|
self.offload_model(self.current_model)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
vae_path = None
|
vae_path = None
|
||||||
if vae:
|
if vae:
|
||||||
vae_path = (
|
vae_path = vae if os.path.isabs(vae) else os.path.normpath(os.path.join(Globals.root, vae))
|
||||||
vae
|
|
||||||
if os.path.isabs(vae)
|
|
||||||
else os.path.normpath(os.path.join(Globals.root, vae))
|
|
||||||
)
|
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
|
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||||
@@ -631,7 +510,7 @@ class ModelManager(object):
|
|||||||
if model_name not in self.models:
|
if model_name not in self.models:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.logger.info(f"Offloading {model_name} to CPU")
|
print(f">> Offloading {model_name} to CPU")
|
||||||
model = self.models[model_name]["model"]
|
model = self.models[model_name]["model"]
|
||||||
model.offload_all()
|
model.offload_all()
|
||||||
self.current_model = None
|
self.current_model = None
|
||||||
@@ -647,26 +526,30 @@ class ModelManager(object):
|
|||||||
and option to exit if an infected file is identified.
|
and option to exit if an infected file is identified.
|
||||||
"""
|
"""
|
||||||
# scan model
|
# scan model
|
||||||
self.logger.debug(f"Scanning Model: {model_name}")
|
print(f" | Scanning Model: {model_name}")
|
||||||
scan_result = scan_file_path(checkpoint)
|
scan_result = scan_file_path(checkpoint)
|
||||||
if scan_result.infected_files != 0:
|
if scan_result.infected_files != 0:
|
||||||
if scan_result.infected_files == 1:
|
if scan_result.infected_files == 1:
|
||||||
self.logger.critical(f"Issues Found In Model: {scan_result.issues_count}")
|
print(f"\n### Issues Found In Model: {scan_result.issues_count}")
|
||||||
self.logger.critical("The model you are trying to load seems to be infected.")
|
print(
|
||||||
self.logger.critical("For your safety, InvokeAI will not load this model.")
|
"### WARNING: The model you are trying to load seems to be infected."
|
||||||
self.logger.critical("Please use checkpoints from trusted sources.")
|
)
|
||||||
self.logger.critical("Exiting InvokeAI")
|
print("### For your safety, InvokeAI will not load this model.")
|
||||||
|
print("### Please use checkpoints from trusted sources.")
|
||||||
|
print("### Exiting InvokeAI")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
else:
|
else:
|
||||||
self.logger.warning("InvokeAI was unable to scan the model you are using.")
|
print(
|
||||||
|
"\n### WARNING: InvokeAI was unable to scan the model you are using."
|
||||||
|
)
|
||||||
model_safe_check_fail = ask_user(
|
model_safe_check_fail = ask_user(
|
||||||
"Do you want to to continue loading the model?", ["y", "n"]
|
"Do you want to to continue loading the model?", ["y", "n"]
|
||||||
)
|
)
|
||||||
if model_safe_check_fail.lower() != "y":
|
if model_safe_check_fail.lower() != "y":
|
||||||
self.logger.critical("Exiting InvokeAI")
|
print("### Exiting InvokeAI")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
else:
|
else:
|
||||||
self.logger.debug("Model scanned ok")
|
print(" | Model scanned ok")
|
||||||
|
|
||||||
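For reference, a standalone sketch of the picklescan check that scan_model() performs above; the checkpoint path is an assumption for illustration only:

import sys
from picklescan.scanner import scan_file_path

result = scan_file_path("models/ldm/stable-diffusion-v1/model.ckpt")  # hypothetical path
if result.infected_files != 0:
    # Mirrors the behaviour above: refuse to load a flagged checkpoint.
    print(f"Issues found in model: {result.issues_count}")
    sys.exit(1)
print("Model scanned ok")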
def import_diffuser_model(
|
def import_diffuser_model(
|
||||||
self,
|
self,
|
||||||
@@ -688,7 +571,9 @@ class ModelManager(object):
|
|||||||
models.yaml file.
|
models.yaml file.
|
||||||
"""
|
"""
|
||||||
model_name = model_name or Path(repo_or_path).stem
|
model_name = model_name or Path(repo_or_path).stem
|
||||||
model_description = description or f"Imported diffusers model {model_name}"
|
model_description = (
|
||||||
|
description or f"Imported diffusers model {model_name}"
|
||||||
|
)
|
||||||
new_config = dict(
|
new_config = dict(
|
||||||
description=model_description,
|
description=model_description,
|
||||||
vae=vae,
|
vae=vae,
|
||||||
@@ -717,7 +602,7 @@ class ModelManager(object):
|
|||||||
SDLegacyType.V2_v (V2 using 'v_prediction' prediction type)
|
SDLegacyType.V2_v (V2 using 'v_prediction' prediction type)
|
||||||
SDLegacyType.UNKNOWN
|
SDLegacyType.UNKNOWN
|
||||||
"""
|
"""
|
||||||
global_step = checkpoint.get("global_step")
|
global_step = checkpoint.get('global_step')
|
||||||
state_dict = checkpoint.get("state_dict") or checkpoint
|
state_dict = checkpoint.get("state_dict") or checkpoint
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -783,24 +668,26 @@ class ModelManager(object):
|
|||||||
model_path: Path = None
|
model_path: Path = None
|
||||||
thing = path_url_or_repo # to save typing
|
thing = path_url_or_repo # to save typing
|
||||||
|
|
||||||
self.logger.info(f"Probing {thing} for import")
|
print(f">> Probing {thing} for import")
|
||||||
|
|
||||||
if thing.startswith(("http:", "https:", "ftp:")):
|
if thing.startswith(("http:", "https:", "ftp:")):
|
||||||
self.logger.info(f"{thing} appears to be a URL")
|
print(f" | {thing} appears to be a URL")
|
||||||
model_path = self._resolve_path(
|
model_path = self._resolve_path(
|
||||||
thing, "models/ldm/stable-diffusion-v1"
|
thing, "models/ldm/stable-diffusion-v1"
|
||||||
) # _resolve_path does a download if needed
|
) # _resolve_path does a download if needed
|
||||||
|
|
||||||
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
|
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
|
||||||
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
|
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
|
||||||
self.logger.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
|
print(
|
||||||
|
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
self.logger.debug(f"{thing} appears to be a checkpoint file on disk")
|
print(f" | {thing} appears to be a checkpoint file on disk")
|
||||||
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
|
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
|
||||||
|
|
||||||
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
|
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
|
||||||
self.logger.debug(f"{thing} appears to be a diffusers file on disk")
|
print(f" | {thing} appears to be a diffusers file on disk")
|
||||||
model_name = self.import_diffuser_model(
|
model_name = self.import_diffuser_model(
|
||||||
thing,
|
thing,
|
||||||
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
|
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
|
||||||
@@ -811,30 +698,34 @@ class ModelManager(object):
|
|||||||
|
|
||||||
elif Path(thing).is_dir():
|
elif Path(thing).is_dir():
|
||||||
if (Path(thing) / "model_index.json").exists():
|
if (Path(thing) / "model_index.json").exists():
|
||||||
self.logger.debug(f"{thing} appears to be a diffusers model.")
|
print(f" | {thing} appears to be a diffusers model.")
|
||||||
model_name = self.import_diffuser_model(
|
model_name = self.import_diffuser_model(
|
||||||
thing, commit_to_conf=commit_to_conf
|
thing, commit_to_conf=commit_to_conf
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.logger.debug(f"{thing} appears to be a directory. Will scan for models to import")
|
print(
|
||||||
|
f" |{thing} appears to be a directory. Will scan for models to import"
|
||||||
|
)
|
||||||
for m in list(Path(thing).rglob("*.ckpt")) + list(
|
for m in list(Path(thing).rglob("*.ckpt")) + list(
|
||||||
Path(thing).rglob("*.safetensors")
|
Path(thing).rglob("*.safetensors")
|
||||||
):
|
):
|
||||||
if model_name := self.heuristic_import(
|
if model_name := self.heuristic_import(
|
||||||
str(m), commit_to_conf=commit_to_conf
|
str(m), commit_to_conf=commit_to_conf
|
||||||
):
|
):
|
||||||
self.logger.info(f"{model_name} successfully imported")
|
print(f" >> {model_name} successfully imported")
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
|
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
|
||||||
self.logger.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
|
print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
|
||||||
model_name = self.import_diffuser_model(
|
model_name = self.import_diffuser_model(
|
||||||
thing, commit_to_conf=commit_to_conf
|
thing, commit_to_conf=commit_to_conf
|
||||||
)
|
)
|
||||||
pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
|
pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
|
||||||
return model_name
|
return model_name
|
||||||
else:
|
else:
|
||||||
self.logger.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
|
print(
|
||||||
|
f"** {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id"
|
||||||
|
)
|
||||||
|
|
||||||
# Model_path is set in the event of a legacy checkpoint file.
|
# Model_path is set in the event of a legacy checkpoint file.
|
||||||
# If not set, we're all done
|
# If not set, we're all done
|
||||||
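The probing order used by heuristic_import() above (URL, checkpoint file, diffusers directory, plain directory, HuggingFace repo_id) can be sketched as a small dispatcher; the function name is illustrative, not part of the codebase:

import re
from pathlib import Path

def classify_import_source(thing: str) -> str:
    """Rough mirror of the probing order shown in the diff above."""
    if thing.startswith(("http:", "https:", "ftp:")):
        return "url"
    if Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
        return "checkpoint file"
    if Path(thing).is_dir() and Path(thing, "model_index.json").exists():
        return "diffusers directory"
    if Path(thing).is_dir():
        return "directory to scan recursively"
    if re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
        return "huggingface repo_id"
    return "unknown"

print(classify_import_source("stabilityai/sd-vae-ft-mse"))  # huggingface repo_id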
@@ -842,13 +733,13 @@ class ModelManager(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if model_path.stem in self.config: # already imported
|
if model_path.stem in self.config: # already imported
|
||||||
self.logger.debug("Already imported. Skipping")
|
print(" | Already imported. Skipping")
|
||||||
return model_path.stem
|
return model_path.stem
|
||||||
|
|
||||||
# another round of heuristics to guess the correct config file.
|
# another round of heuristics to guess the correct config file.
|
||||||
checkpoint = None
|
checkpoint = None
|
||||||
if model_path.suffix in [".ckpt", ".pt"]:
|
if model_path.suffix in [".ckpt",".pt"]:
|
||||||
self.scan_model(model_path, model_path)
|
self.scan_model(model_path,model_path)
|
||||||
checkpoint = torch.load(model_path)
|
checkpoint = torch.load(model_path)
|
||||||
else:
|
else:
|
||||||
checkpoint = safetensors.torch.load_file(model_path)
|
checkpoint = safetensors.torch.load_file(model_path)
|
||||||
@@ -858,39 +749,42 @@ class ModelManager(object):
|
|||||||
# look for a like-named .yaml file in same directory
|
# look for a like-named .yaml file in same directory
|
||||||
if model_path.with_suffix(".yaml").exists():
|
if model_path.with_suffix(".yaml").exists():
|
||||||
model_config_file = model_path.with_suffix(".yaml")
|
model_config_file = model_path.with_suffix(".yaml")
|
||||||
self.logger.debug(f"Using config file {model_config_file.name}")
|
print(f" | Using config file {model_config_file.name}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
model_type = self.probe_model_type(checkpoint)
|
model_type = self.probe_model_type(checkpoint)
|
||||||
if model_type == SDLegacyType.V1:
|
if model_type == SDLegacyType.V1:
|
||||||
self.logger.debug("SD-v1 model detected")
|
print(" | SD-v1 model detected")
|
||||||
model_config_file = Path(
|
model_config_file = Path(
|
||||||
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
|
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
|
||||||
)
|
)
|
||||||
elif model_type == SDLegacyType.V1_INPAINT:
|
elif model_type == SDLegacyType.V1_INPAINT:
|
||||||
self.logger.debug("SD-v1 inpainting model detected")
|
print(" | SD-v1 inpainting model detected")
|
||||||
model_config_file = Path(
|
model_config_file = Path(
|
||||||
Globals.root,
|
Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
|
||||||
"configs/stable-diffusion/v1-inpainting-inference.yaml",
|
|
||||||
)
|
)
|
||||||
elif model_type == SDLegacyType.V2_v:
|
elif model_type == SDLegacyType.V2_v:
|
||||||
self.logger.debug("SD-v2-v model detected")
|
print(
|
||||||
|
" | SD-v2-v model detected"
|
||||||
|
)
|
||||||
model_config_file = Path(
|
model_config_file = Path(
|
||||||
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
|
||||||
)
|
)
|
||||||
elif model_type == SDLegacyType.V2_e:
|
elif model_type == SDLegacyType.V2_e:
|
||||||
self.logger.debug("SD-v2-e model detected")
|
print(
|
||||||
|
" | SD-v2-e model detected"
|
||||||
|
)
|
||||||
model_config_file = Path(
|
model_config_file = Path(
|
||||||
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
|
||||||
)
|
)
|
||||||
elif model_type == SDLegacyType.V2:
|
elif model_type == SDLegacyType.V2:
|
||||||
self.logger.warning(
|
print(
|
||||||
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
self.logger.warning(
|
print(
|
||||||
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
|
f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -906,7 +800,7 @@ class ModelManager(object):
|
|||||||
for suffix in ["pt", "ckpt", "safetensors"]:
|
for suffix in ["pt", "ckpt", "safetensors"]:
|
||||||
if (model_path.with_suffix(f".vae.{suffix}")).exists():
|
if (model_path.with_suffix(f".vae.{suffix}")).exists():
|
||||||
vae_path = model_path.with_suffix(f".vae.{suffix}")
|
vae_path = model_path.with_suffix(f".vae.{suffix}")
|
||||||
self.logger.debug(f"Using VAE file {vae_path.name}")
|
print(f" | Using VAE file {vae_path.name}")
|
||||||
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
|
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
|
||||||
|
|
||||||
diffuser_path = Path(
|
diffuser_path = Path(
|
||||||
@@ -931,11 +825,11 @@ class ModelManager(object):
|
|||||||
diffusers_path: Path,
|
diffusers_path: Path,
|
||||||
model_name=None,
|
model_name=None,
|
||||||
model_description=None,
|
model_description=None,
|
||||||
vae: dict = None,
|
vae:dict=None,
|
||||||
vae_path: Path = None,
|
vae_path:Path=None,
|
||||||
original_config_file: Path = None,
|
original_config_file: Path = None,
|
||||||
commit_to_conf: Path = None,
|
commit_to_conf: Path = None,
|
||||||
scan_needed: bool = True,
|
scan_needed: bool=True,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Convert a legacy ckpt weights file to diffuser model and import
|
Convert a legacy ckpt weights file to diffuser model and import
|
||||||
@@ -952,21 +846,21 @@ class ModelManager(object):
|
|||||||
from . import convert_ckpt_to_diffusers
|
from . import convert_ckpt_to_diffusers
|
||||||
|
|
||||||
if diffusers_path.exists():
|
if diffusers_path.exists():
|
||||||
self.logger.error(
|
print(
|
||||||
f"The path {str(diffusers_path)} already exists. Please move or remove it and try again."
|
f"ERROR: The path {str(diffusers_path)} already exists. Please move or remove it and try again."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
model_name = model_name or diffusers_path.name
|
model_name = model_name or diffusers_path.name
|
||||||
model_description = model_description or f"Converted version of {model_name}"
|
model_description = model_description or f"Converted version of {model_name}"
|
||||||
self.logger.debug(f"Converting {model_name} to diffusers (30-60s)")
|
print(f" | Converting {model_name} to diffusers (30-60s)")
|
||||||
try:
|
try:
|
||||||
# By passing the specified VAE to the conversion function, the autoencoder
|
# By passing the specified VAE to the conversion function, the autoencoder
|
||||||
# will be built into the model rather than tacked on afterward via the config file
|
# will be built into the model rather than tacked on afterward via the config file
|
||||||
vae_model = None
|
vae_model=None
|
||||||
if vae:
|
if vae:
|
||||||
vae_model = self._load_vae(vae)
|
vae_model=self._load_vae(vae)
|
||||||
vae_path = None
|
vae_path=None
|
||||||
convert_ckpt_to_diffusers(
|
convert_ckpt_to_diffusers(
|
||||||
ckpt_path,
|
ckpt_path,
|
||||||
diffusers_path,
|
diffusers_path,
|
||||||
@@ -976,10 +870,10 @@ class ModelManager(object):
|
|||||||
vae_path=vae_path,
|
vae_path=vae_path,
|
||||||
scan_needed=scan_needed,
|
scan_needed=scan_needed,
|
||||||
)
|
)
|
||||||
self.logger.debug(
|
print(
|
||||||
f"Success. Converted model is now located at {str(diffusers_path)}"
|
f" | Success. Converted model is now located at {str(diffusers_path)}"
|
||||||
)
|
)
|
||||||
self.logger.debug(f"Writing new config file entry for {model_name}")
|
print(f" | Writing new config file entry for {model_name}")
|
||||||
new_config = dict(
|
new_config = dict(
|
||||||
path=str(diffusers_path),
|
path=str(diffusers_path),
|
||||||
description=model_description,
|
description=model_description,
|
||||||
@@ -990,17 +884,17 @@ class ModelManager(object):
|
|||||||
self.add_model(model_name, new_config, True)
|
self.add_model(model_name, new_config, True)
|
||||||
if commit_to_conf:
|
if commit_to_conf:
|
||||||
self.commit(commit_to_conf)
|
self.commit(commit_to_conf)
|
||||||
self.logger.debug("Conversion succeeded")
|
print(" | Conversion succeeded")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.warning(f"Conversion failed: {str(e)}")
|
print(f"** Conversion failed: {str(e)}")
|
||||||
self.logger.warning(
|
print(
|
||||||
"If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
"** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_name
|
return model_name
|
||||||
|
|
||||||
def search_models(self, search_folder):
|
def search_models(self, search_folder):
|
||||||
self.logger.info(f"Finding Models In: {search_folder}")
|
print(f">> Finding Models In: {search_folder}")
|
||||||
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
||||||
models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")
|
models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")
|
||||||
|
|
||||||
@@ -1024,8 +918,8 @@ class ModelManager(object):
|
|||||||
num_loaded_models = len(self.models)
|
num_loaded_models = len(self.models)
|
||||||
if num_loaded_models >= self.max_loaded_models:
|
if num_loaded_models >= self.max_loaded_models:
|
||||||
least_recent_model = self._pop_oldest_model()
|
least_recent_model = self._pop_oldest_model()
|
||||||
self.logger.info(
|
print(
|
||||||
f"Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
|
f">> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
|
||||||
)
|
)
|
||||||
if least_recent_model is not None:
|
if least_recent_model is not None:
|
||||||
del self.models[least_recent_model]
|
del self.models[least_recent_model]
|
||||||
@@ -1033,8 +927,8 @@ class ModelManager(object):
|
|||||||
|
|
||||||
def print_vram_usage(self) -> None:
|
def print_vram_usage(self) -> None:
|
||||||
if self._has_cuda():
|
if self._has_cuda():
|
||||||
self.logger.info(
|
print(
|
||||||
"Current VRAM usage:"+
|
">> Current VRAM usage: ",
|
||||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1082,15 +976,15 @@ class ModelManager(object):
|
|||||||
legacy_locations = [
|
legacy_locations = [
|
||||||
Path(
|
Path(
|
||||||
models_dir,
|
models_dir,
|
||||||
"CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker",
|
"CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker"
|
||||||
),
|
),
|
||||||
Path(models_dir, "bert-base-uncased/models--bert-base-uncased"),
|
Path(models_dir, "bert-base-uncased/models--bert-base-uncased"),
|
||||||
Path(
|
Path(
|
||||||
models_dir,
|
models_dir,
|
||||||
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
|
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14"
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
|
legacy_locations.extend(list(global_cache_dir("diffusers").glob('*')))
|
||||||
|
|
||||||
legacy_layout = False
|
legacy_layout = False
|
||||||
for model in legacy_locations:
|
for model in legacy_locations:
|
||||||
@@ -1109,7 +1003,7 @@ class ModelManager(object):
|
|||||||
>> make adjustments, please press ctrl-C now to abort and relaunch InvokeAI when you are ready.
|
>> make adjustments, please press ctrl-C now to abort and relaunch InvokeAI when you are ready.
|
||||||
>> Otherwise press <enter> to continue."""
|
>> Otherwise press <enter> to continue."""
|
||||||
)
|
)
|
||||||
input("continue> ")
|
input('continue> ')
|
||||||
|
|
||||||
# transformer files get moved into the hub directory
|
# transformer files get moved into the hub directory
|
||||||
if cls._is_huggingface_hub_directory_present():
|
if cls._is_huggingface_hub_directory_present():
|
||||||
@@ -1123,10 +1017,10 @@ class ModelManager(object):
|
|||||||
dest = hub / model.stem
|
dest = hub / model.stem
|
||||||
if dest.exists() and not source.exists():
|
if dest.exists() and not source.exists():
|
||||||
continue
|
continue
|
||||||
cls.logger.info(f"{source} => {dest}")
|
print(f"** {source} => {dest}")
|
||||||
if source.exists():
|
if source.exists():
|
||||||
if dest.is_symlink():
|
if dest.is_symlink():
|
||||||
logger.warning(f"Found symlink at {dest.name}. Not migrating.")
|
print(f"** Found symlink at {dest.name}. Not migrating.")
|
||||||
elif dest.exists():
|
elif dest.exists():
|
||||||
if source.is_dir():
|
if source.is_dir():
|
||||||
rmtree(source)
|
rmtree(source)
|
||||||
@@ -1143,7 +1037,7 @@ class ModelManager(object):
|
|||||||
]
|
]
|
||||||
for d in empty:
|
for d in empty:
|
||||||
os.rmdir(d)
|
os.rmdir(d)
|
||||||
cls.logger.info("Migration is done. Continuing...")
|
print("** Migration is done. Continuing...")
|
||||||
|
|
||||||
def _resolve_path(
|
def _resolve_path(
|
||||||
self, source: Union[str, Path], dest_directory: str
|
self, source: Union[str, Path], dest_directory: str
|
||||||
@@ -1186,22 +1080,22 @@ class ModelManager(object):
|
|||||||
|
|
||||||
def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
|
def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
|
||||||
if self.embedding_path is not None:
|
if self.embedding_path is not None:
|
||||||
self.logger.info(f"Loading embeddings from {self.embedding_path}")
|
print(f">> Loading embeddings from {self.embedding_path}")
|
||||||
for root, _, files in os.walk(self.embedding_path):
|
for root, _, files in os.walk(self.embedding_path):
|
||||||
for name in files:
|
for name in files:
|
||||||
ti_path = os.path.join(root, name)
|
ti_path = os.path.join(root, name)
|
||||||
model.textual_inversion_manager.load_textual_inversion(
|
model.textual_inversion_manager.load_textual_inversion(
|
||||||
ti_path, defer_injecting_tokens=True
|
ti_path, defer_injecting_tokens=True
|
||||||
)
|
)
|
||||||
self.logger.info(
|
print(
|
||||||
f'Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
|
||||||
)
|
)
|
||||||
|
|
||||||
def _has_cuda(self) -> bool:
|
def _has_cuda(self) -> bool:
|
||||||
return self.device.type == "cuda"
|
return self.device.type == "cuda"
|
||||||
|
|
||||||
def _diffuser_sha256(
|
def _diffuser_sha256(
|
||||||
self, name_or_path: Union[str, Path], chunksize=16777216
|
self, name_or_path: Union[str, Path], chunksize=4096
|
||||||
) -> Union[str, bytes]:
|
) -> Union[str, bytes]:
|
||||||
path = None
|
path = None
|
||||||
if isinstance(name_or_path, Path):
|
if isinstance(name_or_path, Path):
|
||||||
@@ -1216,7 +1110,7 @@ class ModelManager(object):
|
|||||||
with open(hashpath) as f:
|
with open(hashpath) as f:
|
||||||
hash = f.read()
|
hash = f.read()
|
||||||
return hash
|
return hash
|
||||||
self.logger.debug("Calculating sha256 hash of model files")
|
print(" | Calculating sha256 hash of model files")
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
sha = hashlib.sha256()
|
sha = hashlib.sha256()
|
||||||
count = 0
|
count = 0
|
||||||
@@ -1228,7 +1122,7 @@ class ModelManager(object):
|
|||||||
sha.update(chunk)
|
sha.update(chunk)
|
||||||
hash = sha.hexdigest()
|
hash = sha.hexdigest()
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
|
||||||
with open(hashpath, "w") as f:
|
with open(hashpath, "w") as f:
|
||||||
f.write(hash)
|
f.write(hash)
|
||||||
return hash
|
return hash
|
||||||
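The chunksize default above moves from 4096 bytes to 16777216 (16 MiB). A stdlib-only sketch of the same chunked hashing, not the exact implementation (file ordering and cache handling differ):

import hashlib
from pathlib import Path

def sha256_of_files(directory: Path, chunksize: int = 2**24) -> str:
    """Hash every regular file under `directory` in fixed-size chunks
    (2**24 = 16777216 bytes, matching the new default)."""
    sha = hashlib.sha256()
    for path in sorted(p for p in directory.rglob("*") if p.is_file()):
        with open(path, "rb") as f:
            while chunk := f.read(chunksize):
                sha.update(chunk)
    return sha.hexdigest()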
@@ -1246,13 +1140,13 @@ class ModelManager(object):
|
|||||||
hash = f.read()
|
hash = f.read()
|
||||||
return hash
|
return hash
|
||||||
|
|
||||||
self.logger.debug("Calculating sha256 hash of weights file")
|
print(" | Calculating sha256 hash of weights file")
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
sha = hashlib.sha256()
|
sha = hashlib.sha256()
|
||||||
sha.update(data)
|
sha.update(data)
|
||||||
hash = sha.hexdigest()
|
hash = sha.hexdigest()
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
self.logger.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic))
|
print(f">> sha256 = {hash}", "(%4.2fs)" % (toc - tic))
|
||||||
|
|
||||||
with open(hashpath, "w") as f:
|
with open(hashpath, "w") as f:
|
||||||
f.write(hash)
|
f.write(hash)
|
||||||
@@ -1273,12 +1167,12 @@ class ModelManager(object):
|
|||||||
local_files_only=not Globals.internet_available,
|
local_files_only=not Globals.internet_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
|
print(f" | Loading diffusers VAE from {name_or_path}")
|
||||||
if using_fp16:
|
if using_fp16:
|
||||||
vae_args.update(torch_dtype=torch.float16)
|
vae_args.update(torch_dtype=torch.float16)
|
||||||
fp_args_list = [{"revision": "fp16"}, {}]
|
fp_args_list = [{"revision": "fp16"}, {}]
|
||||||
else:
|
else:
|
||||||
self.logger.debug("Using more accurate float32 precision")
|
print(" | Using more accurate float32 precision")
|
||||||
fp_args_list = [{}]
|
fp_args_list = [{}]
|
||||||
|
|
||||||
vae = None
|
vae = None
|
||||||
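fp_args_list above encodes "try the fp16 revision first, then fall back to the default weights". A generic sketch of that fallback loop; `load_fn` stands in for AutoencoderKL.from_pretrained and the names are hypothetical:

def load_with_fallback(load_fn, name_or_path, using_fp16: bool):
    # Try {"revision": "fp16"} first when fp16 was requested, then plain weights.
    fp_args_list = [{"revision": "fp16"}, {}] if using_fp16 else [{}]
    deferred_error = None
    for fp_args in fp_args_list:
        try:
            return load_fn(name_or_path, **fp_args)
        except OSError as e:  # e.g. the repo publishes no fp16 revision
            deferred_error = e
    raise deferred_error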
@@ -1302,12 +1196,12 @@ class ModelManager(object):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if not vae and deferred_error:
|
if not vae and deferred_error:
|
||||||
self.logger.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}")
|
print(f"** Could not load VAE {name_or_path}: {str(deferred_error)}")
|
||||||
|
|
||||||
return vae
|
return vae
|
||||||
|
|
||||||
@classmethod
|
@staticmethod
|
||||||
def _delete_model_from_cache(cls,repo_id):
|
def _delete_model_from_cache(repo_id):
|
||||||
cache_info = scan_cache_dir(global_cache_dir("hub"))
|
cache_info = scan_cache_dir(global_cache_dir("hub"))
|
||||||
|
|
||||||
# I'm sure there is a way to do this with comprehensions
|
# I'm sure there is a way to do this with comprehensions
|
||||||
@@ -1318,8 +1212,8 @@ class ModelManager(object):
|
|||||||
for revision in repo.revisions:
|
for revision in repo.revisions:
|
||||||
hashes_to_delete.add(revision.commit_hash)
|
hashes_to_delete.add(revision.commit_hash)
|
||||||
strategy = cache_info.delete_revisions(*hashes_to_delete)
|
strategy = cache_info.delete_revisions(*hashes_to_delete)
|
||||||
cls.logger.warning(
|
print(
|
||||||
f"Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
f"** Deletion of this model is expected to free {strategy.expected_freed_size_str}"
|
||||||
)
|
)
|
||||||
strategy.execute()
|
strategy.execute()
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from compel.prompt_parser import (
|
|||||||
PromptParser,
|
PromptParser,
|
||||||
)
|
)
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
from ..stable_diffusion import InvokeAIDiffuserComponent
|
from ..stable_diffusion import InvokeAIDiffuserComponent
|
||||||
@@ -163,8 +162,8 @@ def log_tokenization(
|
|||||||
negative_prompt: Union[Blend, FlattenedPrompt],
|
negative_prompt: Union[Blend, FlattenedPrompt],
|
||||||
tokenizer,
|
tokenizer,
|
||||||
):
|
):
|
||||||
logger.info(f"[TOKENLOG] Parsed Prompt: {positive_prompt}")
|
print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
|
||||||
logger.info(f"[TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
|
print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
|
||||||
|
|
||||||
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
|
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
|
||||||
log_tokenization_for_prompt_object(
|
log_tokenization_for_prompt_object(
|
||||||
@@ -238,12 +237,12 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
|
|||||||
usedTokens += 1
|
usedTokens += 1
|
||||||
|
|
||||||
if usedTokens > 0:
|
if usedTokens > 0:
|
||||||
logger.info(f'[TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
|
||||||
logger.debug(f"{tokenized}\x1b[0m")
|
print(f"{tokenized}\x1b[0m")
|
||||||
|
|
||||||
if discarded != "":
|
if discarded != "":
|
||||||
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
|
||||||
logger.debug(f"{discarded}\x1b[0m")
|
print(f"{discarded}\x1b[0m")
|
||||||
|
|
||||||
|
|
||||||
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
|
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
|
||||||
@@ -296,8 +295,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
|
|||||||
return parsed_prompts
|
return parsed_prompts
|
||||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||||
if weight_sum == 0:
|
if weight_sum == 0:
|
||||||
logger.warning(
|
print(
|
||||||
"Subprompt weights add up to zero. Discarding and using even weights instead."
|
"* Warning: Subprompt weights add up to zero. Discarding and using even weights instead."
|
||||||
)
|
)
|
||||||
equal_weight = 1 / max(len(parsed_prompts), 1)
|
equal_weight = 1 / max(len(parsed_prompts), 1)
|
||||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||||
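When the parsed subprompt weights cancel out to zero, the code above falls back to even weights; a small sketch of that arithmetic (the helper name is illustrative):

def even_weights(parsed_prompts: list[tuple[str, float]]) -> list[tuple[str, float]]:
    """Fallback used when the parsed weights sum to zero."""
    equal_weight = 1 / max(len(parsed_prompts), 1)
    return [(text, equal_weight) for text, _ in parsed_prompts]

# e.g. weights +1 and -1 cancel out, so each subprompt gets 0.5 instead:
print(even_weights([("a forest", 1.0), ("a city", -1.0)]))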
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
class Restoration:
|
class Restoration:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
pass
|
pass
|
||||||
@@ -10,17 +8,17 @@ class Restoration:
|
|||||||
# Load GFPGAN
|
# Load GFPGAN
|
||||||
gfpgan = self.load_gfpgan(gfpgan_model_path)
|
gfpgan = self.load_gfpgan(gfpgan_model_path)
|
||||||
if gfpgan.gfpgan_model_exists:
|
if gfpgan.gfpgan_model_exists:
|
||||||
logger.info("GFPGAN Initialized")
|
print(">> GFPGAN Initialized")
|
||||||
else:
|
else:
|
||||||
logger.info("GFPGAN Disabled")
|
print(">> GFPGAN Disabled")
|
||||||
gfpgan = None
|
gfpgan = None
|
||||||
|
|
||||||
# Load CodeFormer
|
# Load CodeFormer
|
||||||
codeformer = self.load_codeformer()
|
codeformer = self.load_codeformer()
|
||||||
if codeformer.codeformer_model_exists:
|
if codeformer.codeformer_model_exists:
|
||||||
logger.info("CodeFormer Initialized")
|
print(">> CodeFormer Initialized")
|
||||||
else:
|
else:
|
||||||
logger.info("CodeFormer Disabled")
|
print(">> CodeFormer Disabled")
|
||||||
codeformer = None
|
codeformer = None
|
||||||
|
|
||||||
return gfpgan, codeformer
|
return gfpgan, codeformer
|
||||||
@@ -41,5 +39,5 @@ class Restoration:
|
|||||||
from .realesrgan import ESRGAN
|
from .realesrgan import ESRGAN
|
||||||
|
|
||||||
esrgan = ESRGAN(esrgan_bg_tile)
|
esrgan = ESRGAN(esrgan_bg_tile)
|
||||||
logger.info("ESRGAN Initialized")
|
print(">> ESRGAN Initialized")
|
||||||
return esrgan
|
return esrgan
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import warnings
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from ..globals import Globals
|
from ..globals import Globals
|
||||||
|
|
||||||
pretrained_model_url = (
|
pretrained_model_url = (
|
||||||
@@ -24,12 +23,12 @@ class CodeFormerRestoration:
|
|||||||
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
||||||
|
|
||||||
if not self.codeformer_model_exists:
|
if not self.codeformer_model_exists:
|
||||||
logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
|
print("## NOT FOUND: CodeFormer model not found at " + self.model_path)
|
||||||
sys.path.append(os.path.abspath(codeformer_dir))
|
sys.path.append(os.path.abspath(codeformer_dir))
|
||||||
|
|
||||||
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
logger.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
|
print(f">> CodeFormer - Restoring Faces for image seed:{seed}")
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||||
warnings.filterwarnings("ignore", category=UserWarning)
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
@@ -98,7 +97,7 @@ class CodeFormerRestoration:
|
|||||||
del output
|
del output
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
except RuntimeError as error:
|
except RuntimeError as error:
|
||||||
logger.error(f"Failed inference for CodeFormer: {error}.")
|
print(f"\tFailed inference for CodeFormer: {error}.")
|
||||||
restored_face = cropped_face
|
restored_face = cropped_face
|
||||||
|
|
||||||
restored_face = restored_face.astype("uint8")
|
restored_face = restored_face.astype("uint8")
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
|
|
||||||
class GFPGAN:
|
class GFPGAN:
|
||||||
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
|
||||||
if not os.path.isabs(gfpgan_model_path):
|
if not os.path.isabs(gfpgan_model_path):
|
||||||
@@ -19,7 +19,7 @@ class GFPGAN:
|
|||||||
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
||||||
|
|
||||||
if not self.gfpgan_model_exists:
|
if not self.gfpgan_model_exists:
|
||||||
logger.error("NOT FOUND: GFPGAN model not found at " + self.model_path)
|
print("## NOT FOUND: GFPGAN model not found at " + self.model_path)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def model_exists(self):
|
def model_exists(self):
|
||||||
@@ -27,7 +27,7 @@ class GFPGAN:
|
|||||||
|
|
||||||
def process(self, image, strength: float, seed: str = None):
|
def process(self, image, strength: float, seed: str = None):
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
|
print(f">> GFPGAN - Restoring Faces for image seed:{seed}")
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||||
@@ -47,14 +47,14 @@ class GFPGAN:
|
|||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error("Error loading GFPGAN:", file=sys.stderr)
|
print(">> Error loading GFPGAN:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
os.chdir(cwd)
|
os.chdir(cwd)
|
||||||
|
|
||||||
if self.gfpgan is None:
|
if self.gfpgan is None:
|
||||||
logger.warning("WARNING: GFPGAN not initialized.")
|
print(f">> WARNING: GFPGAN not initialized.")
|
||||||
logger.warning(
|
print(
|
||||||
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
|
f">> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
|
||||||
)
|
)
|
||||||
|
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
class Outcrop(object):
|
class Outcrop(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -82,7 +82,7 @@ class Outcrop(object):
|
|||||||
pixels = extents[direction]
|
pixels = extents[direction]
|
||||||
# round pixels up to the nearest 64
|
# round pixels up to the nearest 64
|
||||||
pixels = math.ceil(pixels / 64) * 64
|
pixels = math.ceil(pixels / 64) * 64
|
||||||
logger.info(f"extending image {direction}ward by {pixels} pixels")
|
print(f">> extending image {direction}ward by {pixels} pixels")
|
||||||
image = self._rotate(image, direction)
|
image = self._rotate(image, direction)
|
||||||
image = self._extend(image, pixels)
|
image = self._extend(image, pixels)
|
||||||
image = self._rotate(image, direction, reverse=True)
|
image = self._rotate(image, direction, reverse=True)
|
||||||
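Outcrop extents are rounded up to the nearest multiple of 64 before the image is extended; a one-line illustration of that rounding:

import math

def round_up_to_64(pixels: int) -> int:
    return math.ceil(pixels / 64) * 64

assert round_up_to_64(100) == 128
assert round_up_to_64(128) == 128
assert round_up_to_64(1) == 64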
|
|||||||
@@ -6,13 +6,18 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from PIL.Image import Image as ImageType
|
from PIL.Image import Image as ImageType
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
|
|
||||||
class ESRGAN:
|
class ESRGAN:
|
||||||
def __init__(self, bg_tile_size=400) -> None:
|
def __init__(self, bg_tile_size=400) -> None:
|
||||||
self.bg_tile_size = bg_tile_size
|
self.bg_tile_size = bg_tile_size
|
||||||
|
|
||||||
|
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||||
|
use_half_precision = False
|
||||||
|
else:
|
||||||
|
use_half_precision = True
|
||||||
|
|
||||||
def load_esrgan_bg_upsampler(self, denoise_str):
|
def load_esrgan_bg_upsampler(self, denoise_str):
|
||||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||||
use_half_precision = False
|
use_half_precision = False
|
||||||
@@ -69,16 +74,16 @@ class ESRGAN:
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error("Error loading Real-ESRGAN:")
|
print(">> Error loading Real-ESRGAN:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
if upsampler_scale == 0:
|
if upsampler_scale == 0:
|
||||||
logger.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
|
print(">> Real-ESRGAN: Invalid scaling option. Image not upscaled.")
|
||||||
return image
|
return image
|
||||||
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
logger.info(
|
print(
|
||||||
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
|
f">> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
|
||||||
)
|
)
|
||||||
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
|
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from PIL import Image, ImageFilter
|
|||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor
|
||||||
|
|
||||||
import invokeai.assets.web as web_assets
|
import invokeai.assets.web as web_assets
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from .globals import global_cache_dir
|
from .globals import global_cache_dir
|
||||||
from .util import CPU_DEVICE
|
from .util import CPU_DEVICE
|
||||||
|
|
||||||
@@ -41,8 +40,8 @@ class SafetyChecker(object):
|
|||||||
cache_dir=safety_model_path,
|
cache_dir=safety_model_path,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error(
|
print(
|
||||||
"An error was encountered while installing the safety checker:"
|
"** An error was encountered while installing the safety checker:"
|
||||||
)
|
)
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
@@ -66,8 +65,8 @@ class SafetyChecker(object):
|
|||||||
)
|
)
|
||||||
self.safety_checker.to(CPU_DEVICE) # offload
|
self.safety_checker.to(CPU_DEVICE) # offload
|
||||||
if has_nsfw_concept[0]:
|
if has_nsfw_concept[0]:
|
||||||
logger.warning(
|
print(
|
||||||
"An image with potential non-safe content has been detected. A blurred image will be returned."
|
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
|
||||||
)
|
)
|
||||||
return self.blur(image)
|
return self.blur(image)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from huggingface_hub import (
|
|||||||
hf_hub_url,
|
hf_hub_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
|
|
||||||
@@ -58,7 +57,7 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
self.concept_list.extend(list(local_concepts_to_add))
|
self.concept_list.extend(list(local_concepts_to_add))
|
||||||
return self.concept_list
|
return self.concept_list
|
||||||
return self.concept_list
|
return self.concept_list
|
||||||
elif Globals.internet_available is True:
|
else:
|
||||||
try:
|
try:
|
||||||
models = self.hf_api.list_models(
|
models = self.hf_api.list_models(
|
||||||
filter=ModelFilter(model_name="sd-concepts-library/")
|
filter=ModelFilter(model_name="sd-concepts-library/")
|
||||||
@@ -67,15 +66,13 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
# when init, add all in dir. when not init, add only concepts added between init and now
|
# when init, add all in dir. when not init, add only concepts added between init and now
|
||||||
self.concept_list.extend(list(local_concepts_to_add))
|
self.concept_list.extend(list(local_concepts_to_add))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
print(
|
||||||
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
f" ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
||||||
)
|
)
|
||||||
logger.warning(
|
print(
|
||||||
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||||
)
|
)
|
||||||
return self.concept_list
|
return self.concept_list
|
||||||
else:
|
|
||||||
return self.concept_list
|
|
||||||
|
|
||||||
def get_concept_model_path(self, concept_name: str) -> str:
|
def get_concept_model_path(self, concept_name: str) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -84,7 +81,7 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
be downloaded.
|
be downloaded.
|
||||||
"""
|
"""
|
||||||
if not concept_name in self.list_concepts():
|
if not concept_name in self.list_concepts():
|
||||||
logger.warning(
|
print(
|
||||||
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
@@ -222,7 +219,7 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
if chunk == 0:
|
if chunk == 0:
|
||||||
bytes += total
|
bytes += total
|
||||||
|
|
||||||
logger.info(f"Downloading {repo_id}...", end="")
|
print(f">> Downloading {repo_id}...", end="")
|
||||||
try:
|
try:
|
||||||
for file in (
|
for file in (
|
||||||
"README.md",
|
"README.md",
|
||||||
@@ -236,22 +233,22 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
)
|
)
|
||||||
except ul_error.HTTPError as e:
|
except ul_error.HTTPError as e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
logger.warning(
|
print(
|
||||||
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
print(
|
||||||
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
||||||
)
|
)
|
||||||
os.rmdir(dest)
|
os.rmdir(dest)
|
||||||
return False
|
return False
|
||||||
except ul_error.URLError as e:
|
except ul_error.URLError as e:
|
||||||
logger.error(
|
print(
|
||||||
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||||
)
|
)
|
||||||
os.rmdir(dest)
|
os.rmdir(dest)
|
||||||
return False
|
return False
|
||||||
logger.info("...{:.2f}Kb".format(bytes / 1024))
|
print("...{:.2f}Kb".format(bytes / 1024))
|
||||||
return succeeded
|
return succeeded
|
||||||
|
|
||||||
def _concept_id(self, concept_name: str) -> str:
|
def _concept_id(self, concept_name: str) -> str:
|
||||||
|
|||||||
@@ -445,15 +445,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
@property
|
@property
|
||||||
def _submodels(self) -> Sequence[torch.nn.Module]:
|
def _submodels(self) -> Sequence[torch.nn.Module]:
|
||||||
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
module_names, _, _ = self.extract_init_dict(dict(self.config))
|
||||||
submodels = []
|
values = [getattr(self, name) for name in module_names.keys()]
|
||||||
for name in module_names.keys():
|
return [m for m in values if isinstance(m, torch.nn.Module)]
|
||||||
if hasattr(self, name):
|
|
||||||
value = getattr(self, name)
|
|
||||||
else:
|
|
||||||
value = getattr(self.config, name)
|
|
||||||
if isinstance(value, torch.nn.Module):
|
|
||||||
submodels.append(value)
|
|
||||||
return submodels
|
|
||||||
|
|
||||||
def image_from_embeddings(
|
def image_from_embeddings(
|
||||||
self,
|
self,
|
||||||
@@ -538,7 +531,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
run_id: str = None,
|
run_id: str = None,
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
):
|
):
|
||||||
self._adjust_memory_efficient_attention(latents)
|
# FIXME: do we still use any slicing now that PyTorch 2.0 has scaled dot-product attention on all platforms?
|
||||||
|
# self._adjust_memory_efficient_attention(latents)
|
||||||
if run_id is None:
|
if run_id is None:
|
||||||
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
||||||
if additional_guidance is None:
|
if additional_guidance is None:
|
||||||
@@ -551,7 +545,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
yield PipelineIntermediateState(
|
yield PipelineIntermediateState(
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
step=-1,
|
step=-1,
|
||||||
timestep=self.scheduler.config.num_train_timesteps,
|
timestep=self.scheduler.num_train_timesteps,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -922,7 +916,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
@property
|
@property
|
||||||
def channels(self) -> int:
|
def channels(self) -> int:
|
||||||
"""Compatible with DiffusionWrapper"""
|
"""Compatible with DiffusionWrapper"""
|
||||||
return self.unet.config.in_channels
|
return self.unet.in_channels
|
||||||
|
|
||||||
def decode_latents(self, latents):
|
def decode_latents(self, latents):
|
||||||
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
|
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
|
||||||
|
|||||||
@@ -10,12 +10,13 @@ import diffusers
|
|||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
from compel.cross_attention_control import Arguments
|
from compel.cross_attention_control import Arguments
|
||||||
from diffusers.models.attention_processor import AttentionProcessor
|
from diffusers.models.cross_attention import AttnProcessor
|
||||||
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from ...util import torch_dtype
|
from ...util import torch_dtype
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionType(enum.Enum):
|
class CrossAttentionType(enum.Enum):
|
||||||
SELF = 1
|
SELF = 1
|
||||||
TOKENS = 2
|
TOKENS = 2
|
||||||
@@ -187,7 +188,7 @@ class Context:
|
|||||||
|
|
||||||
class InvokeAICrossAttentionMixin:
|
class InvokeAICrossAttentionMixin:
|
||||||
"""
|
"""
|
||||||
Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
|
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
|
||||||
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
|
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
|
||||||
and dynamic slicing strategy selection.
|
and dynamic slicing strategy selection.
|
||||||
"""
|
"""
|
||||||
@@ -208,7 +209,7 @@ class InvokeAICrossAttentionMixin:
|
|||||||
Set custom attention calculator to be called when attention is calculated
|
Set custom attention calculator to be called when attention is calculated
|
||||||
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
||||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||||
`module` is the current Attention module for which the callback is being invoked.
|
`module` is the current CrossAttention module for which the callback is being invoked.
|
||||||
`suggested_attention_slice` is the default-calculated attention slice
|
`suggested_attention_slice` is the default-calculated attention slice
|
||||||
`dim` is -1 if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
`dim` is -1 if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||||
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
||||||
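For illustration, a minimal attention_slice_wrangler consistent with the callback contract documented above: it returns the suggested slice unchanged and only reports the slice geometry it was handed. The setter name in the final line is an assumption (it is not shown in this hunk), so it is left as a comment.

# Hypothetical no-op wrangler matching the documented signature
# (module, suggested_attention_slice, dim, offset, slice_size) -> attention slice.
def passthrough_wrangler(module, suggested_attention_slice, dim, offset, slice_size):
    if dim >= 0:
        # dim 0 or 1 means the attention map arrived pre-sliced at this offset/length
        print(f"wrangling slice: dim={dim}, offset={offset}, slice_size={slice_size}")
    return suggested_attention_slice

# attn_module.set_attention_slice_wrangler(passthrough_wrangler)  # assumed setter name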
@@ -344,11 +345,11 @@ class InvokeAICrossAttentionMixin:
|
|||||||
def restore_default_cross_attention(
|
def restore_default_cross_attention(
|
||||||
model,
|
model,
|
||||||
is_running_diffusers: bool,
|
is_running_diffusers: bool,
|
||||||
restore_attention_processor: Optional[AttentionProcessor] = None,
|
restore_attention_processor: Optional[AttnProcessor] = None,
|
||||||
):
|
):
|
||||||
if is_running_diffusers:
|
if is_running_diffusers:
|
||||||
unet = model
|
unet = model
|
||||||
unet.set_attn_processor(restore_attention_processor or AttnProcessor())
|
unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
|
||||||
else:
|
else:
|
||||||
remove_attention_function(model)
|
remove_attention_function(model)
|
||||||
|
|
||||||
@@ -407,9 +408,12 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
|
|||||||
def get_cross_attention_modules(
|
def get_cross_attention_modules(
|
||||||
model, which: CrossAttentionType
|
model, which: CrossAttentionType
|
||||||
) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
|
) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
|
||||||
|
from ldm.modules.attention import CrossAttention # avoid circular import
|
||||||
|
|
||||||
cross_attention_class: type = (
|
cross_attention_class: type = (
|
||||||
InvokeAIDiffusersCrossAttention
|
InvokeAIDiffusersCrossAttention
|
||||||
|
if isinstance(model, UNet2DConditionModel)
|
||||||
|
else CrossAttention
|
||||||
)
|
)
|
||||||
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
||||||
attention_module_tuples = [
|
attention_module_tuples = [
|
||||||
@@ -421,13 +425,13 @@ def get_cross_attention_modules(
|
|||||||
expected_count = 16
|
expected_count = 16
|
||||||
if cross_attention_modules_in_model_count != expected_count:
|
if cross_attention_modules_in_model_count != expected_count:
|
||||||
# non-fatal error but .swap() won't work.
|
# non-fatal error but .swap() won't work.
|
||||||
logger.error(
|
print(
|
||||||
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
||||||
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
|
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
|
||||||
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
|
+ f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
|
||||||
+ f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
|
+ f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
|
||||||
+ "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
|
+ f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
|
||||||
+ "work properly until it is fixed."
|
+ f"work properly until it is fixed."
|
||||||
)
|
)
|
||||||
return attention_module_tuples
|
return attention_module_tuples
|
||||||
|
|
||||||
@@ -546,7 +550,7 @@ def get_mem_free_total(device):
|
|||||||
|
|
||||||
|
|
||||||
class InvokeAIDiffusersCrossAttention(
|
class InvokeAIDiffusersCrossAttention(
|
||||||
diffusers.models.attention.Attention, InvokeAICrossAttentionMixin
|
diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin
|
||||||
):
|
):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@@ -568,8 +572,8 @@ class InvokeAIDiffusersCrossAttention(
|
|||||||
"""
|
"""
|
||||||
# base implementation
|
# base implementation
|
||||||
|
|
||||||
class AttnProcessor:
|
class CrossAttnProcessor:
|
||||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||||
batch_size, sequence_length, _ = hidden_states.shape
|
batch_size, sequence_length, _ = hidden_states.shape
|
||||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||||
|
|
||||||
@@ -597,9 +601,9 @@ class AttnProcessor:
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models.attention_processor import (
|
from diffusers.models.cross_attention import (
|
||||||
Attention,
|
CrossAttention,
|
||||||
AttnProcessor,
|
CrossAttnProcessor,
|
||||||
SlicedAttnProcessor,
|
SlicedAttnProcessor,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -649,7 +653,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
|||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
attn: Attention,
|
attn: CrossAttention,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
|||||||
@@ -5,10 +5,9 @@ from typing import Any, Callable, Dict, Optional, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models.attention_processor import AttentionProcessor
|
from diffusers.models.cross_attention import AttnProcessor
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.globals import Globals
|
from invokeai.backend.globals import Globals
|
||||||
|
|
||||||
from .cross_attention_control import (
|
from .cross_attention_control import (
|
||||||
@@ -102,7 +101,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
def override_cross_attention(
|
def override_cross_attention(
|
||||||
self, conditioning: ExtraConditioningInfo, step_count: int
|
self, conditioning: ExtraConditioningInfo, step_count: int
|
||||||
) -> Dict[str, AttentionProcessor]:
|
) -> Dict[str, AttnProcessor]:
|
||||||
"""
|
"""
|
||||||
setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
||||||
the previous attention processor is returned so that the caller can restore it later.
|
the previous attention processor is returned so that the caller can restore it later.
|
||||||
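The pattern this docstring implies is override, run the edited denoising pass, then hand the returned processors back. A minimal sketch, with all argument names as placeholders rather than identifiers taken from this diff:

# Hypothetical wrapper around the swap/restore pair described above.
def run_with_cross_attention_swap(component, conditioning, step_count, denoise):
    # override_cross_attention() returns the processors that were active before the swap
    saved_processors = component.override_cross_attention(conditioning, step_count)
    try:
        return denoise()  # caller-supplied denoising loop runs with .swap control active
    finally:
        # put the original processors back, as the docstring above intends
        component.restore_default_cross_attention(saved_processors)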
@@ -119,7 +118,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def restore_default_cross_attention(
|
def restore_default_cross_attention(
|
||||||
self, restore_attention_processor: Optional["AttentionProcessor"] = None
|
self, restore_attention_processor: Optional["AttnProcessor"] = None
|
||||||
):
|
):
|
||||||
self.conditioning = None
|
self.conditioning = None
|
||||||
self.cross_attention_control_context = None
|
self.cross_attention_control_context = None
|
||||||
@@ -263,7 +262,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
# TODO remove when compvis codepath support is dropped
|
# TODO remove when compvis codepath support is dropped
|
||||||
if step_index is None and sigma is None:
|
if step_index is None and sigma is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Either step_index or sigma is required when doing cross attention control, but both are None."
|
f"Either step_index or sigma is required when doing cross attention control, but both are None."
|
||||||
)
|
)
|
||||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||||
return percent_through
|
return percent_through
|
||||||
@@ -467,14 +466,10 @@ class InvokeAIDiffuserComponent:
|
|||||||
outside = torch.count_nonzero(
|
outside = torch.count_nonzero(
|
||||||
(latents < -current_threshold) | (latents > current_threshold)
|
(latents < -current_threshold) | (latents > current_threshold)
|
||||||
)
|
)
|
||||||
logger.info(
|
print(
|
||||||
f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
|
f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
|
||||||
)
|
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
|
||||||
logger.debug(
|
f" | {outside / latents.numel() * 100:.2f}% values outside threshold"
|
||||||
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if maxval < current_threshold and minval > -current_threshold:
|
if maxval < current_threshold and minval > -current_threshold:
|
||||||
@@ -501,11 +496,9 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.debug_thresholding:
|
if self.debug_thresholding:
|
||||||
logger.debug(
|
print(
|
||||||
f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
|
f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
|
||||||
)
|
f" | {num_altered / latents.numel() * 100:.2f}% values altered"
|
||||||
logger.debug(
|
|
||||||
f"{num_altered / latents.numel() * 100:.2f}% values altered"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return latents
|
return latents
|
||||||
@@ -606,6 +599,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# below is fugly omg
|
# below is fugly omg
|
||||||
|
num_actual_conditionings = len(c_or_weighted_c_list)
|
||||||
conditionings = [uc] + [c for c, weight in weighted_cond_list]
|
conditionings = [uc] + [c for c, weight in weighted_cond_list]
|
||||||
weights = [1] + [weight for c, weight in weighted_cond_list]
|
weights = [1] + [weight for c, weight in weighted_cond_list]
|
||||||
chunk_count = ceil(len(conditionings) / 2)
|
chunk_count = ceil(len(conditionings) / 2)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from torchvision.utils import make_grid
|
|||||||
|
|
||||||
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||||
|
|
||||||
|
|
||||||
@@ -191,7 +191,7 @@ def mkdirs(paths):
|
|||||||
def mkdir_and_rename(path):
|
def mkdir_and_rename(path):
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
new_name = path + "_archived_" + get_timestamp()
|
new_name = path + "_archived_" + get_timestamp()
|
||||||
logger.error("Path already exists. Rename it to [{:s}]".format(new_name))
|
print("Path already exists. Rename it to [{:s}]".format(new_name))
|
||||||
os.replace(path, new_name)
|
os.replace(path, new_name)
|
||||||
os.makedirs(path)
|
os.makedirs(path)
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from compel.embeddings_provider import BaseTextualInversionManager
|
|||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -60,12 +59,12 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
or self.has_textual_inversion_for_trigger_string(concept_name)
|
or self.has_textual_inversion_for_trigger_string(concept_name)
|
||||||
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
|
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
|
||||||
): # in case a token with literal angle brackets encountered
|
): # in case a token with literal angle brackets encountered
|
||||||
logger.info(f"Loaded local embedding for trigger {concept_name}")
|
print(f">> Loaded local embedding for trigger {concept_name}")
|
||||||
continue
|
continue
|
||||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||||
if not bin_file:
|
if not bin_file:
|
||||||
continue
|
continue
|
||||||
logger.info(f"Loaded remote embedding for trigger {concept_name}")
|
print(f">> Loaded remote embedding for trigger {concept_name}")
|
||||||
self.load_textual_inversion(bin_file)
|
self.load_textual_inversion(bin_file)
|
||||||
self.hf_concepts_library.concepts_loaded[concept_name] = True
|
self.hf_concepts_library.concepts_loaded[concept_name] = True
|
||||||
|
|
||||||
@@ -86,8 +85,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
embedding_list = self._parse_embedding(str(ckpt_path))
|
embedding_list = self._parse_embedding(str(ckpt_path))
|
||||||
for embedding_info in embedding_list:
|
for embedding_info in embedding_list:
|
||||||
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
||||||
logger.warning(
|
print(
|
||||||
f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -106,8 +105,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
if ckpt_path.name == "learned_embeds.bin"
|
if ckpt_path.name == "learned_embeds.bin"
|
||||||
else f"<{ckpt_path.stem}>"
|
else f"<{ckpt_path.stem}>"
|
||||||
)
|
)
|
||||||
logger.info(
|
print(
|
||||||
f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||||
)
|
)
|
||||||
trigger_str = replacement_trigger_str
|
trigger_str = replacement_trigger_str
|
||||||
|
|
||||||
@@ -121,8 +120,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.debug(f'Ignoring incompatible embedding {embedding_info["name"]}')
|
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
|
||||||
logger.debug(f"The error was {str(e)}")
|
print(f" | The error was {str(e)}")
|
||||||
|
|
||||||
def _add_textual_inversion(
|
def _add_textual_inversion(
|
||||||
self, trigger_str, embedding, defer_injecting_tokens=False
|
self, trigger_str, embedding, defer_injecting_tokens=False
|
||||||
@@ -134,8 +133,8 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
:return: The token id for the added embedding, either existing or newly-added.
|
:return: The token id for the added embedding, either existing or newly-added.
|
||||||
"""
|
"""
|
||||||
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
||||||
logger.warning(
|
print(
|
||||||
f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
f"** TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
if not self.full_precision:
|
if not self.full_precision:
|
||||||
@@ -156,11 +155,11 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
if str(e).startswith("Warning"):
|
if str(e).startswith("Warning"):
|
||||||
logger.warning(f"{str(e)}")
|
print(f">> {str(e)}")
|
||||||
else:
|
else:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
logger.error(
|
print(
|
||||||
f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
f"** TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -220,16 +219,16 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
for ti in self.textual_inversions:
|
for ti in self.textual_inversions:
|
||||||
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
||||||
if ti.embedding_vector_length > 1:
|
if ti.embedding_vector_length > 1:
|
||||||
logger.info(
|
print(
|
||||||
f"Preparing tokens for textual inversion {ti.trigger_string}..."
|
f">> Preparing tokens for textual inversion {ti.trigger_string}..."
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
self._inject_tokens_and_assign_embeddings(ti)
|
self._inject_tokens_and_assign_embeddings(ti)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.debug(
|
print(
|
||||||
f"Ignoring incompatible embedding trigger {ti.trigger_string}"
|
f" | Ignoring incompatible embedding trigger {ti.trigger_string}"
|
||||||
)
|
)
|
||||||
logger.debug(f"The error was {str(e)}")
|
print(f" | The error was {str(e)}")
|
||||||
continue
|
continue
|
||||||
injected_token_ids.append(ti.trigger_token_id)
|
injected_token_ids.append(ti.trigger_token_id)
|
||||||
injected_token_ids.extend(ti.pad_token_ids)
|
injected_token_ids.extend(ti.pad_token_ids)
|
||||||
@@ -307,16 +306,16 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
if suffix in [".pt",".ckpt",".bin"]:
|
if suffix in [".pt",".ckpt",".bin"]:
|
||||||
scan_result = scan_file_path(embedding_file)
|
scan_result = scan_file_path(embedding_file)
|
||||||
if scan_result.infected_files > 0:
|
if scan_result.infected_files > 0:
|
||||||
logger.critical(
|
print(
|
||||||
f"Security Issues Found in Model: {scan_result.issues_count}"
|
f" ** Security Issues Found in Model: {scan_result.issues_count}"
|
||||||
)
|
)
|
||||||
logger.critical("For your safety, InvokeAI will not load this embed.")
|
print(" ** For your safety, InvokeAI will not load this embed.")
|
||||||
return list()
|
return list()
|
||||||
ckpt = torch.load(embedding_file,map_location="cpu")
|
ckpt = torch.load(embedding_file,map_location="cpu")
|
||||||
else:
|
else:
|
||||||
ckpt = safetensors.torch.load_file(embedding_file)
|
ckpt = safetensors.torch.load_file(embedding_file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
||||||
return list()
|
return list()
|
||||||
|
|
||||||
# try to figure out what kind of embedding file it is and parse accordingly
|
# try to figure out what kind of embedding file it is and parse accordingly
|
||||||
@@ -335,7 +334,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
|
|
||||||
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||||
basename = Path(file_path).stem
|
basename = Path(file_path).stem
|
||||||
logger.debug(f'Loading v1 embedding file: {basename}')
|
print(f' | Loading v1 embedding file: {basename}')
|
||||||
|
|
||||||
embeddings = list()
|
embeddings = list()
|
||||||
token_counter = -1
|
token_counter = -1
|
||||||
@@ -343,7 +342,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
if token_counter < 0:
|
if token_counter < 0:
|
||||||
trigger = embedding_ckpt["name"]
|
trigger = embedding_ckpt["name"]
|
||||||
elif token_counter == 0:
|
elif token_counter == 0:
|
||||||
trigger = '<basename>'
|
trigger = f'<basename>'
|
||||||
else:
|
else:
|
||||||
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
|
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
|
||||||
token_counter += 1
|
token_counter += 1
|
||||||
@@ -366,7 +365,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
This handles embedding .pt file variant #2.
|
This handles embedding .pt file variant #2.
|
||||||
"""
|
"""
|
||||||
basename = Path(file_path).stem
|
basename = Path(file_path).stem
|
||||||
logger.debug(f'Loading v2 embedding file: {basename}')
|
print(f' | Loading v2 embedding file: {basename}')
|
||||||
embeddings = list()
|
embeddings = list()
|
||||||
|
|
||||||
if isinstance(
|
if isinstance(
|
||||||
@@ -385,7 +384,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
)
|
)
|
||||||
embeddings.append(embedding_info)
|
embeddings.append(embedding_info)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"{basename}: Unrecognized embedding format")
|
print(f" ** {basename}: Unrecognized embedding format")
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
@@ -394,7 +393,7 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
Parse 'version 3' of the .pt textual inversion embedding files.
|
Parse 'version 3' of the .pt textual inversion embedding files.
|
||||||
"""
|
"""
|
||||||
basename = Path(file_path).stem
|
basename = Path(file_path).stem
|
||||||
logger.debug(f'Loading v3 embedding file: {basename}')
|
print(f' | Loading v3 embedding file: {basename}')
|
||||||
embedding = embedding_ckpt['emb_params']
|
embedding = embedding_ckpt['emb_params']
|
||||||
embedding_info = EmbeddingInfo(
|
embedding_info = EmbeddingInfo(
|
||||||
name = f'<{basename}>',
|
name = f'<{basename}>',
|
||||||
@@ -412,11 +411,11 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
basename = Path(filepath).stem
|
basename = Path(filepath).stem
|
||||||
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
||||||
|
|
||||||
logger.debug(f'Loading v4 embedding file: {short_path}')
|
print(f' | Loading v4 embedding file: {short_path}')
|
||||||
|
|
||||||
embeddings = list()
|
embeddings = list()
|
||||||
if len(embedding_ckpt.keys()) == 0:
|
if len(embedding_ckpt.keys()) == 0:
|
||||||
logger.warning(f"Invalid embeddings file: {short_path}")
|
print(f" ** Invalid embeddings file: {short_path}")
|
||||||
else:
|
else:
|
||||||
for token,embedding in embedding_ckpt.items():
|
for token,embedding in embedding_ckpt.items():
|
||||||
embedding_info = EmbeddingInfo(
|
embedding_info = EmbeddingInfo(
|
||||||
|
|||||||
@@ -1,109 +0,0 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
|
|
||||||
|
|
||||||
"""invokeai.util.logging
|
|
||||||
|
|
||||||
Logging class for InvokeAI that produces console messages that follow
|
|
||||||
the conventions established in InvokeAI 1.X through 2.X.
|
|
||||||
|
|
||||||
|
|
||||||
One way to use it:
|
|
||||||
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
|
||||||
|
|
||||||
logger = InvokeAILogger.getLogger(__name__)
|
|
||||||
logger.critical('this is critical')
|
|
||||||
logger.error('this is an error')
|
|
||||||
logger.warning('this is a warning')
|
|
||||||
logger.info('this is info')
|
|
||||||
logger.debug('this is debugging')
|
|
||||||
|
|
||||||
Console messages:
|
|
||||||
### this is critical
|
|
||||||
*** this is an error ***
|
|
||||||
** this is a warning
|
|
||||||
>> this is info
|
|
||||||
| this is debugging
|
|
||||||
|
|
||||||
Another way:
|
|
||||||
import invokeai.backend.util.logging as ialog
|
|
||||||
ialog.debug('this is a debugging message')
|
|
||||||
"""
|
|
||||||
import logging
|
|
||||||
|
|
||||||
# module level functions
|
|
||||||
def debug(msg, *args, **kwargs):
|
|
||||||
InvokeAILogger.getLogger().debug(msg, *args, **kwargs)
|
|
||||||
|
|
||||||
def info(msg, *args, **kwargs):
|
|
||||||
InvokeAILogger.getLogger().info(msg, *args, **kwargs)
|
|
||||||
|
|
||||||
def warning(msg, *args, **kwargs):
|
|
||||||
InvokeAILogger.getLogger().warning(msg, *args, **kwargs)
|
|
||||||
|
|
||||||
def error(msg, *args, **kwargs):
|
|
||||||
InvokeAILogger.getLogger().error(msg, *args, **kwargs)
|
|
||||||
|
|
||||||
def critical(msg, *args, **kwargs):
|
|
||||||
InvokeAILogger.getLogger().critical(msg, *args, **kwargs)
|
|
||||||
|
|
||||||
def log(level, msg, *args, **kwargs):
|
|
||||||
InvokeAILogger.getLogger().log(level, msg, *args, **kwargs)
|
|
||||||
|
|
||||||
def disable(level=logging.CRITICAL):
|
|
||||||
logging.disable(level)
|
|
||||||
|
|
||||||
def basicConfig(**kwargs):
|
|
||||||
logging.basicConfig(**kwargs)
|
|
||||||
|
|
||||||
def getLogger(name: str=None)->logging.Logger:
|
|
||||||
return InvokeAILogger.getLogger(name)
|
|
||||||
|
|
||||||
class InvokeAILogFormatter(logging.Formatter):
|
|
||||||
'''
|
|
||||||
Repurposed from:
|
|
||||||
https://stackoverflow.com/questions/14844970/modifying-logging-message-format-based-on-message-logging-level-in-python3
|
|
||||||
'''
|
|
||||||
crit_fmt = "### %(msg)s"
|
|
||||||
err_fmt = "*** %(msg)s"
|
|
||||||
warn_fmt = "** %(msg)s"
|
|
||||||
info_fmt = ">> %(msg)s"
|
|
||||||
dbg_fmt = " | %(msg)s"
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(fmt="%(levelno)d: %(msg)s", datefmt=None, style='%')
|
|
||||||
|
|
||||||
def format(self, record):
|
|
||||||
# Remember the format used when the logging module
|
|
||||||
# was installed (in the event that this formatter is
|
|
||||||
# used with the vanilla logging module).
|
|
||||||
format_orig = self._style._fmt
|
|
||||||
if record.levelno == logging.DEBUG:
|
|
||||||
self._style._fmt = InvokeAILogFormatter.dbg_fmt
|
|
||||||
if record.levelno == logging.INFO:
|
|
||||||
self._style._fmt = InvokeAILogFormatter.info_fmt
|
|
||||||
if record.levelno == logging.WARNING:
|
|
||||||
self._style._fmt = InvokeAILogFormatter.warn_fmt
|
|
||||||
if record.levelno == logging.ERROR:
|
|
||||||
self._style._fmt = InvokeAILogFormatter.err_fmt
|
|
||||||
if record.levelno == logging.CRITICAL:
|
|
||||||
self._style._fmt = InvokeAILogFormatter.crit_fmt
|
|
||||||
|
|
||||||
# parent class does the work
|
|
||||||
result = super().format(record)
|
|
||||||
self._style._fmt = format_orig
|
|
||||||
return result
|
|
||||||
|
|
||||||
class InvokeAILogger(object):
|
|
||||||
loggers = dict()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def getLogger(self, name:str='invokeai')->logging.Logger:
|
|
||||||
if name not in self.loggers:
|
|
||||||
logger = logging.getLogger(name)
|
|
||||||
logger.setLevel(logging.DEBUG)
|
|
||||||
ch = logging.StreamHandler()
|
|
||||||
fmt = InvokeAILogFormatter()
|
|
||||||
ch.setFormatter(fmt)
|
|
||||||
logger.addHandler(ch)
|
|
||||||
self.loggers[name] = logger
|
|
||||||
return self.loggers[name]
|
|
||||||
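The comment inside format() anticipates attaching this formatter to the vanilla logging module directly; a minimal sketch of that standalone use, assuming the import path used elsewhere in this diff:

import logging

from invokeai.backend.util.logging import InvokeAILogFormatter  # assumed import path

handler = logging.StreamHandler()
handler.setFormatter(InvokeAILogFormatter())

log = logging.getLogger("my_tool")
log.setLevel(logging.DEBUG)
log.addHandler(handler)

log.info("model loaded")  # rendered with the ">> " info prefix
log.debug("cache hit")    # rendered with the " | " debug prefix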
@@ -18,7 +18,6 @@ import torch
|
|||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from .devices import torch_dtype
|
from .devices import torch_dtype
|
||||||
|
|
||||||
|
|
||||||
@@ -39,7 +38,7 @@ def log_txt_as_img(wh, xc, size=10):
|
|||||||
try:
|
try:
|
||||||
draw.text((0, 0), lines, fill="black", font=font)
|
draw.text((0, 0), lines, fill="black", font=font)
|
||||||
except UnicodeEncodeError:
|
except UnicodeEncodeError:
|
||||||
logger.warning("Cant encode string for logging. Skipping.")
|
print("Cant encode string for logging. Skipping.")
|
||||||
|
|
||||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||||
txts.append(txt)
|
txts.append(txt)
|
||||||
@@ -81,8 +80,8 @@ def mean_flat(tensor):
|
|||||||
def count_params(model, verbose=False):
|
def count_params(model, verbose=False):
|
||||||
total_params = sum(p.numel() for p in model.parameters())
|
total_params = sum(p.numel() for p in model.parameters())
|
||||||
if verbose:
|
if verbose:
|
||||||
logger.debug(
|
print(
|
||||||
f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
|
f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
|
||||||
)
|
)
|
||||||
return total_params
|
return total_params
|
||||||
|
|
||||||
@@ -133,8 +132,8 @@ def parallel_data_prefetch(
|
|||||||
raise ValueError("list expected but function got ndarray.")
|
raise ValueError("list expected but function got ndarray.")
|
||||||
elif isinstance(data, abc.Iterable):
|
elif isinstance(data, abc.Iterable):
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
logger.warning(
|
print(
|
||||||
'"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
||||||
)
|
)
|
||||||
data = list(data.values())
|
data = list(data.values())
|
||||||
if target_data_type == "ndarray":
|
if target_data_type == "ndarray":
|
||||||
@@ -176,7 +175,7 @@ def parallel_data_prefetch(
|
|||||||
processes += [p]
|
processes += [p]
|
||||||
|
|
||||||
# start processes
|
# start processes
|
||||||
logger.info("Start prefetching...")
|
print("Start prefetching...")
|
||||||
import time
|
import time
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@@ -195,7 +194,7 @@ def parallel_data_prefetch(
|
|||||||
gather_res[res[0]] = res[1]
|
gather_res[res[0]] = res[1]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Exception: ", e)
|
print("Exception: ", e)
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.terminate()
|
p.terminate()
|
||||||
|
|
||||||
@@ -203,7 +202,7 @@ def parallel_data_prefetch(
|
|||||||
finally:
|
finally:
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.join()
|
p.join()
|
||||||
logger.info(f"Prefetching complete. [{time.time() - start} sec.]")
|
print(f"Prefetching complete. [{time.time() - start} sec.]")
|
||||||
|
|
||||||
if target_data_type == "ndarray":
|
if target_data_type == "ndarray":
|
||||||
if not isinstance(gather_res[0], np.ndarray):
|
if not isinstance(gather_res[0], np.ndarray):
|
||||||
@@ -319,23 +318,23 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
|||||||
resp = requests.get(url, headers=header, stream=True) # new request with range
|
resp = requests.get(url, headers=header, stream=True) # new request with range
|
||||||
|
|
||||||
if exist_size > content_length:
|
if exist_size > content_length:
|
||||||
logger.warning("corrupt existing file found. re-downloading")
|
print("* corrupt existing file found. re-downloading")
|
||||||
os.remove(dest)
|
os.remove(dest)
|
||||||
exist_size = 0
|
exist_size = 0
|
||||||
|
|
||||||
if resp.status_code == 416 or exist_size == content_length:
|
if resp.status_code == 416 or exist_size == content_length:
|
||||||
logger.warning(f"{dest}: complete file found. Skipping.")
|
print(f"* {dest}: complete file found. Skipping.")
|
||||||
return dest
|
return dest
|
||||||
elif resp.status_code == 206 or exist_size > 0:
|
elif resp.status_code == 206 or exist_size > 0:
|
||||||
logger.warning(f"{dest}: partial file found. Resuming...")
|
print(f"* {dest}: partial file found. Resuming...")
|
||||||
elif resp.status_code != 200:
|
elif resp.status_code != 200:
|
||||||
logger.error(f"An error occurred during downloading {dest}: {resp.reason}")
|
print(f"** An error occurred during downloading {dest}: {resp.reason}")
|
||||||
else:
|
else:
|
||||||
logger.error(f"{dest}: Downloading...")
|
print(f"* {dest}: Downloading...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if content_length < 2000:
|
if content_length < 2000:
|
||||||
logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
|
print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
with open(dest, open_mode) as file, tqdm(
|
with open(dest, open_mode) as file, tqdm(
|
||||||
@@ -350,7 +349,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
|||||||
size = file.write(data)
|
size = file.write(data)
|
||||||
bar.update(size)
|
bar.update(size)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"An error occurred while downloading {dest}: {str(e)}")
|
print(f"An error occurred while downloading {dest}: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return dest
|
return dest
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from PIL import Image
|
|||||||
from PIL.Image import Image as ImageType
|
from PIL.Image import Image as ImageType
|
||||||
from werkzeug.utils import secure_filename
|
from werkzeug.utils import secure_filename
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
import invokeai.frontend.web.dist as frontend
|
import invokeai.frontend.web.dist as frontend
|
||||||
|
|
||||||
from .. import Generate
|
from .. import Generate
|
||||||
@@ -78,6 +77,7 @@ class InvokeAIWebServer:
|
|||||||
mimetypes.add_type("application/javascript", ".js")
|
mimetypes.add_type("application/javascript", ".js")
|
||||||
mimetypes.add_type("text/css", ".css")
|
mimetypes.add_type("text/css", ".css")
|
||||||
# Socket IO
|
# Socket IO
|
||||||
|
logger = True if args.web_verbose else False
|
||||||
engineio_logger = True if args.web_verbose else False
|
engineio_logger = True if args.web_verbose else False
|
||||||
max_http_buffer_size = 10000000
|
max_http_buffer_size = 10000000
|
||||||
|
|
||||||
@@ -213,7 +213,7 @@ class InvokeAIWebServer:
|
|||||||
self.load_socketio_listeners(self.socketio)
|
self.load_socketio_listeners(self.socketio)
|
||||||
|
|
||||||
if args.gui:
|
if args.gui:
|
||||||
logger.info("Launching Invoke AI GUI")
|
print(">> Launching Invoke AI GUI")
|
||||||
try:
|
try:
|
||||||
from flaskwebgui import FlaskUI
|
from flaskwebgui import FlaskUI
|
||||||
|
|
||||||
@@ -231,17 +231,17 @@ class InvokeAIWebServer:
|
|||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
else:
|
else:
|
||||||
useSSL = args.certfile or args.keyfile
|
useSSL = args.certfile or args.keyfile
|
||||||
logger.info("Started Invoke AI Web Server")
|
print(">> Started Invoke AI Web Server")
|
||||||
if self.host == "0.0.0.0":
|
if self.host == "0.0.0.0":
|
||||||
logger.info(
|
print(
|
||||||
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
|
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
print(
|
||||||
"Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
|
">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
|
||||||
)
|
)
|
||||||
logger.info(
|
print(
|
||||||
f"Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
|
f">> Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
|
||||||
)
|
)
|
||||||
if not useSSL:
|
if not useSSL:
|
||||||
self.socketio.run(app=self.app, host=self.host, port=self.port)
|
self.socketio.run(app=self.app, host=self.host, port=self.port)
|
||||||
@@ -273,7 +273,7 @@ class InvokeAIWebServer:
|
|||||||
# path for thumbnail images
|
# path for thumbnail images
|
||||||
self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/")
|
self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/")
|
||||||
# txt log
|
# txt log
|
||||||
self.log_path = os.path.join(self.result_path, "invoke_logger.txt")
|
self.log_path = os.path.join(self.result_path, "invoke_log.txt")
|
||||||
# make all output paths
|
# make all output paths
|
||||||
[
|
[
|
||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
@@ -290,7 +290,7 @@ class InvokeAIWebServer:
|
|||||||
def load_socketio_listeners(self, socketio):
|
def load_socketio_listeners(self, socketio):
|
||||||
@socketio.on("requestSystemConfig")
|
@socketio.on("requestSystemConfig")
|
||||||
def handle_request_capabilities():
|
def handle_request_capabilities():
|
||||||
logger.info("System config requested")
|
print(">> System config requested")
|
||||||
config = self.get_system_config()
|
config = self.get_system_config()
|
||||||
config["model_list"] = self.generate.model_manager.list_models()
|
config["model_list"] = self.generate.model_manager.list_models()
|
||||||
config["infill_methods"] = infill_methods()
|
config["infill_methods"] = infill_methods()
|
||||||
@@ -330,7 +330,7 @@ class InvokeAIWebServer:
|
|||||||
if model_name in current_model_list:
|
if model_name in current_model_list:
|
||||||
update = True
|
update = True
|
||||||
|
|
||||||
logger.info(f"Adding New Model: {model_name}")
|
print(f">> Adding New Model: {model_name}")
|
||||||
|
|
||||||
self.generate.model_manager.add_model(
|
self.generate.model_manager.add_model(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@@ -348,14 +348,14 @@ class InvokeAIWebServer:
|
|||||||
"update": update,
|
"update": update,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
logger.info(f"New Model Added: {model_name}")
|
print(f">> New Model Added: {model_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@socketio.on("deleteModel")
|
@socketio.on("deleteModel")
|
||||||
def handle_delete_model(model_name: str):
|
def handle_delete_model(model_name: str):
|
||||||
try:
|
try:
|
||||||
logger.info(f"Deleting Model: {model_name}")
|
print(f">> Deleting Model: {model_name}")
|
||||||
self.generate.model_manager.del_model(model_name)
|
self.generate.model_manager.del_model(model_name)
|
||||||
self.generate.model_manager.commit(opt.conf)
|
self.generate.model_manager.commit(opt.conf)
|
||||||
updated_model_list = self.generate.model_manager.list_models()
|
updated_model_list = self.generate.model_manager.list_models()
|
||||||
@@ -366,14 +366,14 @@ class InvokeAIWebServer:
|
|||||||
"model_list": updated_model_list,
|
"model_list": updated_model_list,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
logger.info(f"Model Deleted: {model_name}")
|
print(f">> Model Deleted: {model_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@socketio.on("requestModelChange")
|
@socketio.on("requestModelChange")
|
||||||
def handle_set_model(model_name: str):
|
def handle_set_model(model_name: str):
|
||||||
try:
|
try:
|
||||||
logger.info(f"Model change requested: {model_name}")
|
print(f">> Model change requested: {model_name}")
|
||||||
model = self.generate.set_model(model_name)
|
model = self.generate.set_model(model_name)
|
||||||
model_list = self.generate.model_manager.list_models()
|
model_list = self.generate.model_manager.list_models()
|
||||||
if model is None:
|
if model is None:
|
||||||
@@ -454,7 +454,7 @@ class InvokeAIWebServer:
|
|||||||
"update": True,
|
"update": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
logger.info(f"Model Converted: {model_name}")
|
print(f">> Model Converted: {model_name}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@@ -490,7 +490,7 @@ class InvokeAIWebServer:
|
|||||||
if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
if vae := self.generate.model_manager.config[models_to_merge[0]].get(
|
||||||
"vae", None
|
"vae", None
|
||||||
):
|
):
|
||||||
logger.info(f"Using configured VAE assigned to {models_to_merge[0]}")
|
print(f">> Using configured VAE assigned to {models_to_merge[0]}")
|
||||||
merged_model_config.update(vae=vae)
|
merged_model_config.update(vae=vae)
|
||||||
|
|
||||||
self.generate.model_manager.import_diffuser_model(
|
self.generate.model_manager.import_diffuser_model(
|
||||||
@@ -507,8 +507,8 @@ class InvokeAIWebServer:
|
|||||||
"update": True,
|
"update": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
logger.info(f"Models Merged: {models_to_merge}")
|
print(f">> Models Merged: {models_to_merge}")
|
||||||
logger.info(f"New Model Added: {model_merge_info['merged_model_name']}")
|
print(f">> New Model Added: {model_merge_info['merged_model_name']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
@@ -698,7 +698,7 @@ class InvokeAIWebServer:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info(f"Unable to load {path}")
|
print(f">> Unable to load {path}")
|
||||||
socketio.emit(
|
socketio.emit(
|
||||||
"error", {"message": f"Unable to load {path}: {str(e)}"}
|
"error", {"message": f"Unable to load {path}: {str(e)}"}
|
||||||
)
|
)
|
||||||
@@ -735,9 +735,9 @@ class InvokeAIWebServer:
|
|||||||
printable_parameters["init_mask"][:64] + "..."
|
printable_parameters["init_mask"][:64] + "..."
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Image Generation Parameters:\n\n{printable_parameters}\n")
|
print(f"\n>> Image Generation Parameters:\n\n{printable_parameters}\n")
|
||||||
logger.info(f"ESRGAN Parameters: {esrgan_parameters}")
|
print(f">> ESRGAN Parameters: {esrgan_parameters}")
|
||||||
logger.info(f"Facetool Parameters: {facetool_parameters}")
|
print(f">> Facetool Parameters: {facetool_parameters}")
|
||||||
|
|
||||||
self.generate_images(
|
self.generate_images(
|
||||||
generation_parameters,
|
generation_parameters,
|
||||||
@@ -750,8 +750,8 @@ class InvokeAIWebServer:
|
|||||||
@socketio.on("runPostprocessing")
|
@socketio.on("runPostprocessing")
|
||||||
def handle_run_postprocessing(original_image, postprocessing_parameters):
|
def handle_run_postprocessing(original_image, postprocessing_parameters):
|
||||||
try:
|
try:
|
||||||
logger.info(
|
print(
|
||||||
f'Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
|
f'>> Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
|
||||||
)
|
)
|
||||||
|
|
||||||
progress = Progress()
|
progress = Progress()
|
||||||
@@ -861,14 +861,14 @@ class InvokeAIWebServer:
|
|||||||
|
|
||||||
@socketio.on("cancel")
|
@socketio.on("cancel")
|
||||||
def handle_cancel():
|
def handle_cancel():
|
||||||
logger.info("Cancel processing requested")
|
print(">> Cancel processing requested")
|
||||||
self.canceled.set()
|
self.canceled.set()
|
||||||
|
|
||||||
# TODO: I think this needs a safety mechanism.
|
# TODO: I think this needs a safety mechanism.
|
||||||
@socketio.on("deleteImage")
|
@socketio.on("deleteImage")
|
||||||
def handle_delete_image(url, thumbnail, uuid, category):
|
def handle_delete_image(url, thumbnail, uuid, category):
|
||||||
try:
|
try:
|
||||||
logger.info(f'Delete requested "{url}"')
|
print(f'>> Delete requested "{url}"')
|
||||||
from send2trash import send2trash
|
from send2trash import send2trash
|
||||||
|
|
||||||
path = self.get_image_path_from_url(url)
|
path = self.get_image_path_from_url(url)
|
||||||
@@ -1263,7 +1263,7 @@ class InvokeAIWebServer:
|
|||||||
image, os.path.basename(path), self.thumbnail_image_path
|
image, os.path.basename(path), self.thumbnail_image_path
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f'Image generated: "{path}"\n')
|
print(f'\n\n>> Image generated: "{path}"\n')
|
||||||
self.write_log_message(f'[Generated] "{path}": {command}')
|
self.write_log_message(f'[Generated] "{path}": {command}')
|
||||||
|
|
||||||
if progress.total_iterations > progress.current_iteration:
|
if progress.total_iterations > progress.current_iteration:
|
||||||
@@ -1329,7 +1329,7 @@ class InvokeAIWebServer:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Clear the CUDA cache on an exception
|
# Clear the CUDA cache on an exception
|
||||||
self.empty_cuda_cache()
|
self.empty_cuda_cache()
|
||||||
logger.error(e)
|
print(e)
|
||||||
self.handle_exceptions(e)
|
self.handle_exceptions(e)
|
||||||
|
|
||||||
def empty_cuda_cache(self):
|
def empty_cuda_cache(self):
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ if sys.platform == "darwin":
|
|||||||
import pyparsing # type: ignore
|
import pyparsing # type: ignore
|
||||||
|
|
||||||
import invokeai.version as invokeai
|
import invokeai.version as invokeai
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
|
|
||||||
from ...backend import Generate, ModelManager
|
from ...backend import Generate, ModelManager
|
||||||
from ...backend.args import Args, dream_cmd_from_png, metadata_dumps, metadata_from_png
|
from ...backend.args import Args, dream_cmd_from_png, metadata_dumps, metadata_from_png
|
||||||
@@ -70,7 +69,7 @@ def main():
|
|||||||
# run any post-install patches needed
|
# run any post-install patches needed
|
||||||
run_patches()
|
run_patches()
|
||||||
|
|
||||||
logger.info(f"Internet connectivity is {Globals.internet_available}")
|
print(f">> Internet connectivity is {Globals.internet_available}")
|
||||||
|
|
||||||
if not args.conf:
|
if not args.conf:
|
||||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||||
@@ -79,8 +78,8 @@ def main():
|
|||||||
opt, FileNotFoundError(f"The file {config_file} could not be found.")
|
opt, FileNotFoundError(f"The file {config_file} could not be found.")
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"{invokeai.__app_name__}, version {invokeai.__version__}")
|
print(f">> {invokeai.__app_name__}, version {invokeai.__version__}")
|
||||||
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
|
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
||||||
|
|
||||||
# loading here to avoid long delays on startup
|
# loading here to avoid long delays on startup
|
||||||
# these two lines prevent a horrible warning message from appearing
|
# these two lines prevent a horrible warning message from appearing
|
||||||
@@ -122,7 +121,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"{opt.infile} not found.")
|
raise FileNotFoundError(f"{opt.infile} not found.")
|
||||||
except (FileNotFoundError, IOError) as e:
|
except (FileNotFoundError, IOError) as e:
|
||||||
logger.critical('Aborted',exc_info=True)
|
print(f"{e}. Aborting.")
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
# creating a Generate object:
|
# creating a Generate object:
|
||||||
@@ -143,12 +142,12 @@ def main():
|
|||||||
)
|
)
|
||||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||||
report_model_error(opt, e)
|
report_model_error(opt, e)
|
||||||
except (IOError, KeyError):
|
except (IOError, KeyError) as e:
|
||||||
logger.critical("Aborted",exc_info=True)
|
print(f"{e}. Aborting.")
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
if opt.seamless:
|
if opt.seamless:
|
||||||
logger.info("Changed to seamless tiling mode")
|
print(">> changed to seamless tiling mode")
|
||||||
|
|
||||||
# preload the model
|
# preload the model
|
||||||
try:
|
try:
|
||||||
@@ -159,9 +158,14 @@ def main():
|
|||||||
report_model_error(opt, e)
|
report_model_error(opt, e)
|
||||||
|
|
||||||
# try to autoconvert new models
|
# try to autoconvert new models
|
||||||
|
if path := opt.autoimport:
|
||||||
|
gen.model_manager.heuristic_import(
|
||||||
|
str(path), convert=False, commit_to_conf=opt.conf
|
||||||
|
)
|
||||||
|
|
||||||
if path := opt.autoconvert:
|
if path := opt.autoconvert:
|
||||||
gen.model_manager.heuristic_import(
|
gen.model_manager.heuristic_import(
|
||||||
str(path), commit_to_conf=opt.conf
|
str(path), convert=True, commit_to_conf=opt.conf
|
||||||
)
|
)
|
||||||
|
|
||||||
# web server loops forever
|
# web server loops forever
|
||||||
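The `autoimport` branch added above relies on an assignment expression (`:=`, Python 3.8+) to bind and test the option in one step. A tiny standalone illustration, with invented names:

```python
# := binds the value and lets the if-test use it in one expression,
# so the body runs only when the option was actually supplied.
options = {"autoimport": "/some/models/dir", "autoconvert": None}

if path := options.get("autoimport"):
    print(f"would scan and import models from {path!r}")

if path := options.get("autoconvert"):
    print("not reached: autoconvert was not supplied")
```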
@@ -181,7 +185,9 @@ def main():
|
|||||||
f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
|
f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error("An error occurred",exc_info=True)
|
print(">> An error occurred:")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
|
||||||
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
# TODO: main_loop() has gotten busy. Needs to be refactored.
|
||||||
def main_loop(gen, opt):
|
def main_loop(gen, opt):
|
||||||
@@ -247,7 +253,7 @@ def main_loop(gen, opt):
|
|||||||
if not opt.prompt:
|
if not opt.prompt:
|
||||||
oldargs = metadata_from_png(opt.init_img)
|
oldargs = metadata_from_png(opt.init_img)
|
||||||
opt.prompt = oldargs.prompt
|
opt.prompt = oldargs.prompt
|
||||||
logger.info(f'Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
|
print(f'>> Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
|
||||||
except (OSError, AttributeError, KeyError):
|
except (OSError, AttributeError, KeyError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -264,9 +270,9 @@ def main_loop(gen, opt):
|
|||||||
if opt.init_img is not None and re.match("^-\\d+$", opt.init_img):
|
if opt.init_img is not None and re.match("^-\\d+$", opt.init_img):
|
||||||
try:
|
try:
|
||||||
opt.init_img = last_results[int(opt.init_img)][0]
|
opt.init_img = last_results[int(opt.init_img)][0]
|
||||||
logger.info(f"Reusing previous image {opt.init_img}")
|
print(f">> Reusing previous image {opt.init_img}")
|
||||||
except IndexError:
|
except IndexError:
|
||||||
logger.info(f"No previous initial image at position {opt.init_img} found")
|
print(f">> No previous initial image at position {opt.init_img} found")
|
||||||
opt.init_img = None
|
opt.init_img = None
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -287,9 +293,9 @@ def main_loop(gen, opt):
|
|||||||
if opt.seed is not None and opt.seed < 0 and operation != "postprocess":
|
if opt.seed is not None and opt.seed < 0 and operation != "postprocess":
|
||||||
try:
|
try:
|
||||||
opt.seed = last_results[opt.seed][1]
|
opt.seed = last_results[opt.seed][1]
|
||||||
logger.info(f"Reusing previous seed {opt.seed}")
|
print(f">> Reusing previous seed {opt.seed}")
|
||||||
except IndexError:
|
except IndexError:
|
||||||
logger.info(f"No previous seed at position {opt.seed} found")
|
print(f">> No previous seed at position {opt.seed} found")
|
||||||
opt.seed = None
|
opt.seed = None
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -308,7 +314,7 @@ def main_loop(gen, opt):
|
|||||||
subdir = subdir[: (path_max - 39 - len(os.path.abspath(opt.outdir)))]
|
subdir = subdir[: (path_max - 39 - len(os.path.abspath(opt.outdir)))]
|
||||||
current_outdir = os.path.join(opt.outdir, subdir)
|
current_outdir = os.path.join(opt.outdir, subdir)
|
||||||
|
|
||||||
logger.info('Writing files to directory: "' + current_outdir + '"')
|
print('Writing files to directory: "' + current_outdir + '"')
|
||||||
|
|
||||||
# make sure the output directory exists
|
# make sure the output directory exists
|
||||||
if not os.path.exists(current_outdir):
|
if not os.path.exists(current_outdir):
|
||||||
@@ -437,14 +443,15 @@ def main_loop(gen, opt):
|
|||||||
catch_interrupts=catch_ctrl_c,
|
catch_interrupts=catch_ctrl_c,
|
||||||
**vars(opt),
|
**vars(opt),
|
||||||
)
|
)
|
||||||
except (PromptParser.ParsingException, pyparsing.ParseException):
|
except (PromptParser.ParsingException, pyparsing.ParseException) as e:
|
||||||
logger.error("An error occurred while processing your prompt",exc_info=True)
|
print("** An error occurred while processing your prompt **")
|
||||||
|
print(f"** {str(e)} **")
|
||||||
elif operation == "postprocess":
|
elif operation == "postprocess":
|
||||||
logger.info(f"fixing {opt.prompt}")
|
print(f">> fixing {opt.prompt}")
|
||||||
opt.last_operation = do_postprocess(gen, opt, image_writer)
|
opt.last_operation = do_postprocess(gen, opt, image_writer)
|
||||||
|
|
||||||
elif operation == "mask":
|
elif operation == "mask":
|
||||||
logger.info(f"generating masks from {opt.prompt}")
|
print(f">> generating masks from {opt.prompt}")
|
||||||
do_textmask(gen, opt, image_writer)
|
do_textmask(gen, opt, image_writer)
|
||||||
|
|
||||||
if opt.grid and len(grid_images) > 0:
|
if opt.grid and len(grid_images) > 0:
|
||||||
@@ -467,12 +474,12 @@ def main_loop(gen, opt):
|
|||||||
)
|
)
|
||||||
results = [[path, formatted_dream_prompt]]
|
results = [[path, formatted_dream_prompt]]
|
||||||
|
|
||||||
except AssertionError:
|
except AssertionError as e:
|
||||||
logger.error(e)
|
print(e)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logger.error(e)
|
print(e)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print("Outputs:")
|
print("Outputs:")
|
||||||
@@ -511,7 +518,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
gen.set_model(model_name)
|
gen.set_model(model_name)
|
||||||
add_embedding_terms(gen, completer)
|
add_embedding_terms(gen, completer)
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
logger.error(e)
|
print(str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
report_model_error(opt, e)
|
report_model_error(opt, e)
|
||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
@@ -525,8 +532,8 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
elif command.startswith("!import"):
|
elif command.startswith("!import"):
|
||||||
path = shlex.split(command)
|
path = shlex.split(command)
|
||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
logger.warning(
|
print(
|
||||||
"please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1"
|
"** please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
@@ -539,7 +546,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
elif command.startswith(("!convert", "!optimize")):
|
elif command.startswith(("!convert", "!optimize")):
|
||||||
path = shlex.split(command)
|
path = shlex.split(command)
|
||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
logger.warning("please provide the path to a .ckpt or .safetensors model")
|
print("** please provide the path to a .ckpt or .safetensors model")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
convert_model(path[1], gen, opt, completer)
|
convert_model(path[1], gen, opt, completer)
|
||||||
@@ -551,7 +558,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
elif command.startswith("!edit"):
|
elif command.startswith("!edit"):
|
||||||
path = shlex.split(command)
|
path = shlex.split(command)
|
||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
logger.warning("please provide the name of a model")
|
print("** please provide the name of a model")
|
||||||
else:
|
else:
|
||||||
edit_model(path[1], gen, opt, completer)
|
edit_model(path[1], gen, opt, completer)
|
||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
@@ -560,7 +567,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
elif command.startswith("!del"):
|
elif command.startswith("!del"):
|
||||||
path = shlex.split(command)
|
path = shlex.split(command)
|
||||||
if len(path) < 2:
|
if len(path) < 2:
|
||||||
logger.warning("please provide the name of a model")
|
print("** please provide the name of a model")
|
||||||
else:
|
else:
|
||||||
del_config(path[1], gen, opt, completer)
|
del_config(path[1], gen, opt, completer)
|
||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
@@ -574,7 +581,6 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
|
|||||||
|
|
||||||
elif command.startswith("!replay"):
|
elif command.startswith("!replay"):
|
||||||
file_path = command.replace("!replay", "", 1).strip()
|
file_path = command.replace("!replay", "", 1).strip()
|
||||||
file_path = os.path.join(opt.outdir, file_path)
|
|
||||||
if infile is None and os.path.isfile(file_path):
|
if infile is None and os.path.isfile(file_path):
|
||||||
infile = open(file_path, "r", encoding="utf-8")
|
infile = open(file_path, "r", encoding="utf-8")
|
||||||
completer.add_history(command)
|
completer.add_history(command)
|
||||||
@@ -640,8 +646,8 @@ def import_model(model_path: str, gen, opt, completer):
|
|||||||
try:
|
try:
|
||||||
default_name = url_attachment_name(model_path)
|
default_name = url_attachment_name(model_path)
|
||||||
default_name = Path(default_name).stem
|
default_name = Path(default_name).stem
|
||||||
except Exception:
|
except Exception as e:
|
||||||
logger.warning(f"A problem occurred while assigning the name of the downloaded model",exc_info=True)
|
print(f"** URL: {str(e)}")
|
||||||
model_name, model_desc = _get_model_name_and_desc(
|
model_name, model_desc = _get_model_name_and_desc(
|
||||||
gen.model_manager,
|
gen.model_manager,
|
||||||
completer,
|
completer,
|
||||||
@@ -662,11 +668,11 @@ def import_model(model_path: str, gen, opt, completer):
|
|||||||
model_config_file=config_file,
|
model_config_file=config_file,
|
||||||
)
|
)
|
||||||
if not imported_name:
|
if not imported_name:
|
||||||
logger.error("Aborting import.")
|
print("** Aborting import.")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not _verify_load(imported_name, gen):
|
if not _verify_load(imported_name, gen):
|
||||||
logger.error("model failed to load. Discarding configuration entry")
|
print("** model failed to load. Discarding configuration entry")
|
||||||
gen.model_manager.del_model(imported_name)
|
gen.model_manager.del_model(imported_name)
|
||||||
return
|
return
|
||||||
if click.confirm("Make this the default model?", default=False):
|
if click.confirm("Make this the default model?", default=False):
|
||||||
@@ -674,7 +680,7 @@ def import_model(model_path: str, gen, opt, completer):
|
|||||||
|
|
||||||
gen.model_manager.commit(opt.conf)
|
gen.model_manager.commit(opt.conf)
|
||||||
completer.update_models(gen.model_manager.list_models())
|
completer.update_models(gen.model_manager.list_models())
|
||||||
logger.info(f"{imported_name} successfully installed")
|
print(f">> {imported_name} successfully installed")
|
||||||
|
|
||||||
def _pick_configuration_file(completer)->Path:
|
def _pick_configuration_file(completer)->Path:
|
||||||
print(
|
print(
|
||||||
@@ -718,21 +724,21 @@ Please select the type of this model:
|
|||||||
return choice
|
return choice
|
||||||
|
|
||||||
def _verify_load(model_name: str, gen) -> bool:
|
def _verify_load(model_name: str, gen) -> bool:
|
||||||
logger.info("Verifying that new model loads...")
|
print(">> Verifying that new model loads...")
|
||||||
current_model = gen.model_name
|
current_model = gen.model_name
|
||||||
try:
|
try:
|
||||||
if not gen.set_model(model_name):
|
if not gen.set_model(model_name):
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"model failed to load: {str(e)}")
|
print(f"** model failed to load: {str(e)}")
|
||||||
logger.warning(
|
print(
|
||||||
"** note that importing 2.X checkpoints is not supported. Please use !convert_model instead."
|
"** note that importing 2.X checkpoints is not supported. Please use !convert_model instead."
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
if click.confirm("Keep model loaded?", default=True):
|
if click.confirm("Keep model loaded?", default=True):
|
||||||
gen.set_model(model_name)
|
gen.set_model(model_name)
|
||||||
else:
|
else:
|
||||||
logger.info("Restoring previous model")
|
print(">> Restoring previous model")
|
||||||
gen.set_model(current_model)
|
gen.set_model(current_model)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -755,7 +761,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
|||||||
ckpt_path = None
|
ckpt_path = None
|
||||||
original_config_file = None
|
original_config_file = None
|
||||||
if model_name_or_path == gen.model_name:
|
if model_name_or_path == gen.model_name:
|
||||||
logger.warning("Can't convert the active model. !switch to another model first. **")
|
print("** Can't convert the active model. !switch to another model first. **")
|
||||||
return
|
return
|
||||||
elif model_info := manager.model_info(model_name_or_path):
|
elif model_info := manager.model_info(model_name_or_path):
|
||||||
if "weights" in model_info:
|
if "weights" in model_info:
|
||||||
@@ -765,7 +771,7 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
|||||||
model_description = model_info["description"]
|
model_description = model_info["description"]
|
||||||
vae_path = model_info.get("vae")
|
vae_path = model_info.get("vae")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"{model_name_or_path} is not a legacy .ckpt weights file")
|
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
|
||||||
return
|
return
|
||||||
model_name = manager.convert_and_import(
|
model_name = manager.convert_and_import(
|
||||||
ckpt_path,
|
ckpt_path,
|
||||||
@@ -786,16 +792,16 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
|
|||||||
manager.commit(opt.conf)
|
manager.commit(opt.conf)
|
||||||
if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
|
if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
|
||||||
ckpt_path.unlink(missing_ok=True)
|
ckpt_path.unlink(missing_ok=True)
|
||||||
logger.warning(f"{ckpt_path} deleted")
|
print(f"{ckpt_path} deleted")
|
||||||
|
|
||||||
|
|
||||||
def del_config(model_name: str, gen, opt, completer):
|
def del_config(model_name: str, gen, opt, completer):
|
||||||
current_model = gen.model_name
|
current_model = gen.model_name
|
||||||
if model_name == current_model:
|
if model_name == current_model:
|
||||||
logger.warning("Can't delete active model. !switch to another model first. **")
|
print("** Can't delete active model. !switch to another model first. **")
|
||||||
return
|
return
|
||||||
if model_name not in gen.model_manager.config:
|
if model_name not in gen.model_manager.config:
|
||||||
logger.warning(f"Unknown model {model_name}")
|
print(f"** Unknown model {model_name}")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not click.confirm(
|
if not click.confirm(
|
||||||
@@ -808,17 +814,17 @@ def del_config(model_name: str, gen, opt, completer):
|
|||||||
)
|
)
|
||||||
gen.model_manager.del_model(model_name, delete_files=delete_completely)
|
gen.model_manager.del_model(model_name, delete_files=delete_completely)
|
||||||
gen.model_manager.commit(opt.conf)
|
gen.model_manager.commit(opt.conf)
|
||||||
logger.warning(f"{model_name} deleted")
|
print(f"** {model_name} deleted")
|
||||||
completer.update_models(gen.model_manager.list_models())
|
completer.update_models(gen.model_manager.list_models())
|
||||||
|
|
||||||
|
|
||||||
def edit_model(model_name: str, gen, opt, completer):
|
def edit_model(model_name: str, gen, opt, completer):
|
||||||
manager = gen.model_manager
|
manager = gen.model_manager
|
||||||
if not (info := manager.model_info(model_name)):
|
if not (info := manager.model_info(model_name)):
|
||||||
logger.warning(f"** Unknown model {model_name}")
|
print(f"** Unknown model {model_name}")
|
||||||
return
|
return
|
||||||
print()
|
|
||||||
logger.info(f"Editing model {model_name} from configuration file {opt.conf}")
|
print(f"\n>> Editing model {model_name} from configuration file {opt.conf}")
|
||||||
new_name = _get_model_name(manager.list_models(), completer, model_name)
|
new_name = _get_model_name(manager.list_models(), completer, model_name)
|
||||||
|
|
||||||
for attribute in info.keys():
|
for attribute in info.keys():
|
||||||
@@ -856,7 +862,7 @@ def edit_model(model_name: str, gen, opt, completer):
|
|||||||
manager.set_default_model(new_name)
|
manager.set_default_model(new_name)
|
||||||
manager.commit(opt.conf)
|
manager.commit(opt.conf)
|
||||||
completer.update_models(manager.list_models())
|
completer.update_models(manager.list_models())
|
||||||
logger.info("Model successfully updated")
|
print(">> Model successfully updated")
|
||||||
|
|
||||||
|
|
||||||
def _get_model_name(existing_names, completer, default_name: str = "") -> str:
|
def _get_model_name(existing_names, completer, default_name: str = "") -> str:
|
||||||
@@ -867,11 +873,11 @@ def _get_model_name(existing_names, completer, default_name: str = "") -> str:
|
|||||||
if len(model_name) == 0:
|
if len(model_name) == 0:
|
||||||
model_name = default_name
|
model_name = default_name
|
||||||
if not re.match("^[\w._+:/-]+$", model_name):
|
if not re.match("^[\w._+:/-]+$", model_name):
|
||||||
logger.warning(
|
print(
|
||||||
'model name must contain only words, digits and the characters "._+:/-" **'
|
'** model name must contain only words, digits and the characters "._+:/-" **'
|
||||||
)
|
)
|
||||||
elif model_name != default_name and model_name in existing_names:
|
elif model_name != default_name and model_name in existing_names:
|
||||||
logger.warning(f"the name {model_name} is already in use. Pick another.")
|
print(f"** the name {model_name} is already in use. Pick another.")
|
||||||
else:
|
else:
|
||||||
done = True
|
done = True
|
||||||
return model_name
|
return model_name
|
||||||
@@ -938,10 +944,11 @@ def do_postprocess(gen, opt, callback):
|
|||||||
opt=opt,
|
opt=opt,
|
||||||
)
|
)
|
||||||
except OSError:
|
except OSError:
|
||||||
logger.error(f"{file_path}: file could not be read",exc_info=True)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
print(f"** {file_path}: file could not be read")
|
||||||
return
|
return
|
||||||
except (KeyError, AttributeError):
|
except (KeyError, AttributeError):
|
||||||
logger.error(f"an error occurred while applying the {tool} postprocessor",exc_info=True)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
return
|
return
|
||||||
return opt.last_operation
|
return opt.last_operation
|
||||||
|
|
||||||
@@ -996,13 +1003,13 @@ def prepare_image_metadata(
|
|||||||
try:
|
try:
|
||||||
filename = opt.fnformat.format(**wildcards)
|
filename = opt.fnformat.format(**wildcards)
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
logger.error(
|
print(
|
||||||
f"The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead"
|
f"** The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead"
|
||||||
)
|
)
|
||||||
filename = f"{prefix}.{seed}.png"
|
filename = f"{prefix}.{seed}.png"
|
||||||
except IndexError:
|
except IndexError:
|
||||||
logger.error(
|
print(
|
||||||
"The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead"
|
"** The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead"
|
||||||
)
|
)
|
||||||
filename = f"{prefix}.{seed}.png"
|
filename = f"{prefix}.{seed}.png"
|
||||||
|
|
||||||
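For context on the `KeyError`/`IndexError` handling above: `opt.fnformat` is a `str.format`-style template expanded against a dictionary of wildcards. A small self-contained illustration (the wildcard names are made up for the example):

```python
# A {key}-style filename template; an unknown key raises KeyError,
# which the CLI catches and falls back to "{prefix}.{seed}.png".
wildcards = {"prefix": "000012", "seed": 42}

print("{prefix}.{seed}.png".format(**wildcards))  # -> 000012.42.png

try:
    "{prefix}.{steps}.png".format(**wildcards)
except KeyError as e:
    print(f"unknown key {e.args[0]!r}; falling back to the default name")
```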
@@ -1091,14 +1098,14 @@ def split_variations(variations_string) -> list:
|
|||||||
for part in variations_string.split(","):
|
for part in variations_string.split(","):
|
||||||
seed_and_weight = part.split(":")
|
seed_and_weight = part.split(":")
|
||||||
if len(seed_and_weight) != 2:
|
if len(seed_and_weight) != 2:
|
||||||
logger.warning(f'Could not parse with_variation part "{part}"')
|
print(f'** Could not parse with_variation part "{part}"')
|
||||||
broken = True
|
broken = True
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
seed = int(seed_and_weight[0])
|
seed = int(seed_and_weight[0])
|
||||||
weight = float(seed_and_weight[1])
|
weight = float(seed_and_weight[1])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning(f'Could not parse with_variation part "{part}"')
|
print(f'** Could not parse with_variation part "{part}"')
|
||||||
broken = True
|
broken = True
|
||||||
break
|
break
|
||||||
parts.append([seed, weight])
|
parts.append([seed, weight])
|
||||||
@@ -1122,23 +1129,23 @@ def load_face_restoration(opt):
|
|||||||
opt.gfpgan_model_path
|
opt.gfpgan_model_path
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Face restoration disabled")
|
print(">> Face restoration disabled")
|
||||||
if opt.esrgan:
|
if opt.esrgan:
|
||||||
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
|
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
|
||||||
else:
|
else:
|
||||||
logger.info("Upscaling disabled")
|
print(">> Upscaling disabled")
|
||||||
else:
|
else:
|
||||||
logger.info("Face restoration and upscaling disabled")
|
print(">> Face restoration and upscaling disabled")
|
||||||
except (ModuleNotFoundError, ImportError):
|
except (ModuleNotFoundError, ImportError):
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
|
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
||||||
return gfpgan, codeformer, esrgan
|
return gfpgan, codeformer, esrgan
|
||||||
|
|
||||||
|
|
||||||
def make_step_callback(gen, opt, prefix):
|
def make_step_callback(gen, opt, prefix):
|
||||||
destination = os.path.join(opt.outdir, "intermediates", prefix)
|
destination = os.path.join(opt.outdir, "intermediates", prefix)
|
||||||
os.makedirs(destination, exist_ok=True)
|
os.makedirs(destination, exist_ok=True)
|
||||||
logger.info(f"Intermediate images will be written into {destination}")
|
print(f">> Intermediate images will be written into {destination}")
|
||||||
|
|
||||||
def callback(state: PipelineIntermediateState):
|
def callback(state: PipelineIntermediateState):
|
||||||
latents = state.latents
|
latents = state.latents
|
||||||
@@ -1180,20 +1187,21 @@ def retrieve_dream_command(opt, command, completer):
|
|||||||
try:
|
try:
|
||||||
cmd = dream_cmd_from_png(path)
|
cmd = dream_cmd_from_png(path)
|
||||||
except OSError:
|
except OSError:
|
||||||
logger.error(f"{tokens[0]}: file could not be read")
|
print(f"## {tokens[0]}: file could not be read")
|
||||||
except (KeyError, AttributeError, IndexError):
|
except (KeyError, AttributeError, IndexError):
|
||||||
logger.error(f"{tokens[0]}: file has no metadata")
|
print(f"## {tokens[0]}: file has no metadata")
|
||||||
except:
|
except:
|
||||||
logger.error(f"{tokens[0]}: file could not be processed")
|
print(f"## {tokens[0]}: file could not be processed")
|
||||||
if len(cmd) > 0:
|
if len(cmd) > 0:
|
||||||
completer.set_line(cmd)
|
completer.set_line(cmd)
|
||||||
|
|
||||||
|
|
||||||
def write_commands(opt, file_path: str, outfilepath: str):
|
def write_commands(opt, file_path: str, outfilepath: str):
|
||||||
dir, basename = os.path.split(file_path)
|
dir, basename = os.path.split(file_path)
|
||||||
try:
|
try:
|
||||||
paths = sorted(list(Path(dir).glob(basename)))
|
paths = sorted(list(Path(dir).glob(basename)))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.error(f'"{basename}": unacceptable pattern')
|
print(f'## "{basename}": unacceptable pattern')
|
||||||
return
|
return
|
||||||
|
|
||||||
commands = []
|
commands = []
|
||||||
@@ -1202,9 +1210,9 @@ def write_commands(opt, file_path: str, outfilepath: str):
|
|||||||
try:
|
try:
|
||||||
cmd = dream_cmd_from_png(path)
|
cmd = dream_cmd_from_png(path)
|
||||||
except (KeyError, AttributeError, IndexError):
|
except (KeyError, AttributeError, IndexError):
|
||||||
logger.error(f"{path}: file has no metadata")
|
print(f"## {path}: file has no metadata")
|
||||||
except:
|
except:
|
||||||
logger.error(f"{path}: file could not be processed")
|
print(f"## {path}: file could not be processed")
|
||||||
if cmd:
|
if cmd:
|
||||||
commands.append(f"# {path}")
|
commands.append(f"# {path}")
|
||||||
commands.append(cmd)
|
commands.append(cmd)
|
||||||
@@ -1214,18 +1222,18 @@ def write_commands(opt, file_path: str, outfilepath: str):
|
|||||||
outfilepath = os.path.join(opt.outdir, basename)
|
outfilepath = os.path.join(opt.outdir, basename)
|
||||||
with open(outfilepath, "w", encoding="utf-8") as f:
|
with open(outfilepath, "w", encoding="utf-8") as f:
|
||||||
f.write("\n".join(commands))
|
f.write("\n".join(commands))
|
||||||
logger.info(f"File {outfilepath} with commands created")
|
print(f">> File {outfilepath} with commands created")
|
||||||
|
|
||||||
|
|
||||||
def report_model_error(opt: Namespace, e: Exception):
|
def report_model_error(opt: Namespace, e: Exception):
|
||||||
logger.warning(f'An error occurred while attempting to initialize the model: "{str(e)}"')
|
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
|
||||||
logger.warning(
|
print(
|
||||||
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
|
||||||
)
|
)
|
||||||
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
|
||||||
if yes_to_all:
|
if yes_to_all:
|
||||||
logger.warning(
|
print(
|
||||||
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if not click.confirm(
|
if not click.confirm(
|
||||||
@@ -1234,7 +1242,7 @@ def report_model_error(opt: Namespace, e: Exception):
|
|||||||
):
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("invokeai-configure is launching....\n")
|
print("invokeai-configure is launching....\n")
|
||||||
|
|
||||||
# Match arguments that were set on the CLI
|
# Match arguments that were set on the CLI
|
||||||
# only the arguments accepted by the configuration script are parsed
|
# only the arguments accepted by the configuration script are parsed
|
||||||
@@ -1251,7 +1259,7 @@ def report_model_error(opt: Namespace, e: Exception):
|
|||||||
from ..install import invokeai_configure
|
from ..install import invokeai_configure
|
||||||
|
|
||||||
invokeai_configure()
|
invokeai_configure()
|
||||||
logger.warning("InvokeAI will now restart")
|
print("** InvokeAI will now restart")
|
||||||
sys.argv = previous_args
|
sys.argv = previous_args
|
||||||
main() # would rather do a os.exec(), but doesn't exist?
|
main() # would rather do a os.exec(), but doesn't exist?
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
'''
|
"""
|
||||||
Minimalist updater script. Prompts user for the tag or branch to update to and runs
|
Minimalist updater script. Prompts user for the tag or branch to update to and runs
|
||||||
pip install <path_to_git_source>.
|
pip install <path_to_git_source>.
|
||||||
'''
|
"""
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from rich import box, print
|
from rich import box, print
|
||||||
from rich.console import Console, Group, group
|
from rich.console import Console, Group, group
|
||||||
@@ -15,10 +16,8 @@ from rich.text import Text
|
|||||||
|
|
||||||
from invokeai.version import __version__
|
from invokeai.version import __version__
|
||||||
|
|
||||||
INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive"
|
INVOKE_AI_SRC = "https://github.com/invoke-ai/InvokeAI/archive"
|
||||||
INVOKE_AI_TAG="https://github.com/invoke-ai/InvokeAI/archive/refs/tags"
|
INVOKE_AI_REL = "https://api.github.com/repos/invoke-ai/InvokeAI/releases"
|
||||||
INVOKE_AI_BRANCH="https://github.com/invoke-ai/InvokeAI/archive/refs/heads"
|
|
||||||
INVOKE_AI_REL="https://api.github.com/repos/invoke-ai/InvokeAI/releases"
|
|
||||||
|
|
||||||
OS = platform.uname().system
|
OS = platform.uname().system
|
||||||
ARCH = platform.uname().machine
|
ARCH = platform.uname().machine
|
||||||
@@ -29,22 +28,22 @@ if OS == "Windows":
|
|||||||
else:
|
else:
|
||||||
console = Console(style=Style(color="grey74", bgcolor="grey19"))
|
console = Console(style=Style(color="grey74", bgcolor="grey19"))
|
||||||
|
|
||||||
def get_versions()->dict:
|
|
||||||
|
def get_versions() -> dict:
|
||||||
return requests.get(url=INVOKE_AI_REL).json()
|
return requests.get(url=INVOKE_AI_REL).json()
|
||||||
|
|
||||||
def welcome(versions: dict):
|
|
||||||
|
|
||||||
|
def welcome(versions: dict):
|
||||||
@group()
|
@group()
|
||||||
def text():
|
def text():
|
||||||
yield f'InvokeAI Version: [bold yellow]{__version__}'
|
yield f"InvokeAI Version: [bold yellow]{__version__}"
|
||||||
yield ''
|
yield ""
|
||||||
yield 'This script will update InvokeAI to the latest release, or to a development version of your choice.'
|
yield "This script will update InvokeAI to the latest release, or to a development version of your choice."
|
||||||
yield ''
|
yield ""
|
||||||
yield '[bold yellow]Options:'
|
yield "[bold yellow]Options:"
|
||||||
yield f'''[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])
|
yield f"""[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])
|
||||||
[2] Update to the bleeding-edge development version ([italic]main[/italic])
|
[2] Update to the bleeding-edge development version ([italic]main[/italic])
|
||||||
[3] Manually enter the [bold]tag name[/bold] for the version you wish to update to
|
[3] Manually enter the tag or branch name you wish to update"""
|
||||||
[4] Manually enter the [bold]branch name[/bold] for the version you wish to update to'''
|
|
||||||
|
|
||||||
console.rule()
|
console.rule()
|
||||||
print(
|
print(
|
||||||
@@ -60,41 +59,33 @@ def welcome(versions: dict):
|
|||||||
)
|
)
|
||||||
console.line()
|
console.line()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
versions = get_versions()
|
versions = get_versions()
|
||||||
welcome(versions)
|
welcome(versions)
|
||||||
|
|
||||||
tag = None
|
tag = None
|
||||||
branch = None
|
choice = Prompt.ask("Choice:", choices=["1", "2", "3"], default="1")
|
||||||
release = None
|
|
||||||
choice = Prompt.ask('Choice:',choices=['1','2','3','4'],default='1')
|
|
||||||
|
|
||||||
if choice=='1':
|
if choice == "1":
|
||||||
release = versions[0]['tag_name']
|
tag = versions[0]["tag_name"]
|
||||||
elif choice=='2':
|
elif choice == "2":
|
||||||
release = 'main'
|
tag = "main"
|
||||||
elif choice=='3':
|
elif choice == "3":
|
||||||
tag = Prompt.ask('Enter an InvokeAI tag name')
|
tag = Prompt.ask("Enter an InvokeAI tag or branch name")
|
||||||
elif choice=='4':
|
|
||||||
branch = Prompt.ask('Enter an InvokeAI branch name')
|
|
||||||
|
|
||||||
print(f':crossed_fingers: Upgrading to [yellow]{tag if tag else release}[/yellow]')
|
print(f":crossed_fingers: Upgrading to [yellow]{tag}[/yellow]")
|
||||||
if release:
|
cmd = f"pip install {INVOKE_AI_SRC}/{tag}.zip --use-pep517"
|
||||||
cmd = f'pip install {INVOKE_AI_SRC}/{release}.zip --use-pep517 --upgrade'
|
print("")
|
||||||
elif tag:
|
print("")
|
||||||
cmd = f'pip install {INVOKE_AI_TAG}/{tag}.zip --use-pep517 --upgrade'
|
if os.system(cmd) == 0:
|
||||||
|
print(f":heavy_check_mark: Upgrade successful")
|
||||||
else:
|
else:
|
||||||
cmd = f'pip install {INVOKE_AI_BRANCH}/{branch}.zip --use-pep517 --upgrade'
|
print(f":exclamation: [bold red]Upgrade failed[/red bold]")
|
||||||
print('')
|
|
||||||
print('')
|
|
||||||
if os.system(cmd)==0:
|
|
||||||
print(f':heavy_check_mark: Upgrade successful')
|
|
||||||
else:
|
|
||||||
print(f':exclamation: [bold red]Upgrade failed[/red bold]')
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
try:
|
||||||
main()
|
main()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
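The updater builds a single `pip install` command string and hands it to `os.system`, branching on the returned status. Purely as an illustration of the same upgrade step with explicit error handling (a sketch, not part of the change set):

```python
# Sketch only: run the equivalent pip upgrade via subprocess so a failing
# install is reported via an exception rather than a raw exit status.
import subprocess
import sys

INVOKE_AI_SRC = "https://github.com/invoke-ai/InvokeAI/archive"

def upgrade(tag: str) -> bool:
    cmd = [
        sys.executable, "-m", "pip", "install",
        f"{INVOKE_AI_SRC}/{tag}.zip", "--use-pep517", "--upgrade",
    ]
    try:
        subprocess.run(cmd, check=True)
        return True
    except subprocess.CalledProcessError:
        return False
```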
|
|||||||
@@ -22,7 +22,6 @@ import torch
|
|||||||
from npyscreen import widget
|
from npyscreen import widget
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.globals import Globals, global_config_dir
|
from invokeai.backend.globals import Globals, global_config_dir
|
||||||
|
|
||||||
from ...backend.config.model_install_backend import (
|
from ...backend.config.model_install_backend import (
|
||||||
@@ -200,6 +199,17 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
relx=4,
|
relx=4,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
|
self.nextrely += 1
|
||||||
|
self.convert_models = self.add_widget_intelligent(
|
||||||
|
npyscreen.TitleSelectOne,
|
||||||
|
name="== CONVERT IMPORTED MODELS INTO DIFFUSERS==",
|
||||||
|
values=["Keep original format", "Convert to diffusers"],
|
||||||
|
value=0,
|
||||||
|
begin_entry_at=4,
|
||||||
|
max_height=4,
|
||||||
|
hidden=True, # will appear when imported models box is edited
|
||||||
|
scroll_exit=True,
|
||||||
|
)
|
||||||
self.cancel = self.add_widget_intelligent(
|
self.cancel = self.add_widget_intelligent(
|
||||||
npyscreen.ButtonPress,
|
npyscreen.ButtonPress,
|
||||||
name="CANCEL",
|
name="CANCEL",
|
||||||
@@ -234,6 +244,8 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
self.show_directory_fields.addVisibleWhenSelected(i)
|
self.show_directory_fields.addVisibleWhenSelected(i)
|
||||||
|
|
||||||
self.show_directory_fields.when_value_edited = self._clear_scan_directory
|
self.show_directory_fields.when_value_edited = self._clear_scan_directory
|
||||||
|
self.import_model_paths.when_value_edited = self._show_hide_convert
|
||||||
|
self.autoload_directory.when_value_edited = self._show_hide_convert
|
||||||
|
|
||||||
def resize(self):
|
def resize(self):
|
||||||
super().resize()
|
super().resize()
|
||||||
@@ -244,6 +256,13 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
if not self.show_directory_fields.value:
|
if not self.show_directory_fields.value:
|
||||||
self.autoload_directory.value = ""
|
self.autoload_directory.value = ""
|
||||||
|
|
||||||
|
def _show_hide_convert(self):
|
||||||
|
model_paths = self.import_model_paths.value or ""
|
||||||
|
autoload_directory = self.autoload_directory.value or ""
|
||||||
|
self.convert_models.hidden = (
|
||||||
|
len(model_paths) == 0 and len(autoload_directory) == 0
|
||||||
|
)
|
||||||
|
|
||||||
def _get_starter_model_labels(self) -> List[str]:
|
def _get_starter_model_labels(self) -> List[str]:
|
||||||
window_width, window_height = get_terminal_size()
|
window_width, window_height = get_terminal_size()
|
||||||
label_width = 25
|
label_width = 25
|
||||||
@@ -303,6 +322,7 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
.scan_directory: Path to a directory of models to scan and import
|
.scan_directory: Path to a directory of models to scan and import
|
||||||
.autoscan_on_startup: True if invokeai should scan and import at startup time
|
.autoscan_on_startup: True if invokeai should scan and import at startup time
|
||||||
.import_model_paths: list of URLs, repo_ids and file paths to import
|
.import_model_paths: list of URLs, repo_ids and file paths to import
|
||||||
|
.convert_to_diffusers: if True, convert legacy checkpoints into diffusers
|
||||||
"""
|
"""
|
||||||
# we're using a global here rather than storing the result in the parentapp
|
# we're using a global here rather than storing the result in the parentapp
|
||||||
# due to some bug in npyscreen that is causing attributes to be lost
|
# due to some bug in npyscreen that is causing attributes to be lost
|
||||||
@@ -339,6 +359,7 @@ class addModelsForm(npyscreen.FormMultiPage):
|
|||||||
|
|
||||||
# URLs and the like
|
# URLs and the like
|
||||||
selections.import_model_paths = self.import_model_paths.value.split()
|
selections.import_model_paths = self.import_model_paths.value.split()
|
||||||
|
selections.convert_to_diffusers = self.convert_models.value[0] == 1
|
||||||
|
|
||||||
|
|
||||||
class AddModelApplication(npyscreen.NPSAppManaged):
|
class AddModelApplication(npyscreen.NPSAppManaged):
|
||||||
@@ -351,6 +372,7 @@ class AddModelApplication(npyscreen.NPSAppManaged):
|
|||||||
scan_directory=None,
|
scan_directory=None,
|
||||||
autoscan_on_startup=None,
|
autoscan_on_startup=None,
|
||||||
import_model_paths=None,
|
import_model_paths=None,
|
||||||
|
convert_to_diffusers=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def onStart(self):
|
def onStart(self):
|
||||||
@@ -371,6 +393,7 @@ def process_and_execute(opt: Namespace, selections: Namespace):
|
|||||||
directory_to_scan = selections.scan_directory
|
directory_to_scan = selections.scan_directory
|
||||||
scan_at_startup = selections.autoscan_on_startup
|
scan_at_startup = selections.autoscan_on_startup
|
||||||
potential_models_to_install = selections.import_model_paths
|
potential_models_to_install = selections.import_model_paths
|
||||||
|
convert_to_diffusers = selections.convert_to_diffusers
|
||||||
|
|
||||||
install_requested_models(
|
install_requested_models(
|
||||||
install_initial_models=models_to_install,
|
install_initial_models=models_to_install,
|
||||||
@@ -378,6 +401,7 @@ def process_and_execute(opt: Namespace, selections: Namespace):
|
|||||||
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
|
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
|
||||||
external_models=potential_models_to_install,
|
external_models=potential_models_to_install,
|
||||||
scan_at_startup=scan_at_startup,
|
scan_at_startup=scan_at_startup,
|
||||||
|
convert_to_diffusers=convert_to_diffusers,
|
||||||
precision="float32"
|
precision="float32"
|
||||||
if opt.full_precision
|
if opt.full_precision
|
||||||
else choose_precision(torch.device(choose_torch_device())),
|
else choose_precision(torch.device(choose_torch_device())),
|
||||||
@@ -456,8 +480,8 @@ def main():
|
|||||||
Globals.root = os.path.expanduser(get_root(opt.root) or "")
|
Globals.root = os.path.expanduser(get_root(opt.root) or "")
|
||||||
|
|
||||||
if not global_config_dir().exists():
|
if not global_config_dir().exists():
|
||||||
logger.info(
|
print(
|
||||||
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
">> Your InvokeAI root directory is not set up. Calling invokeai-configure."
|
||||||
)
|
)
|
||||||
from invokeai.frontend.install import invokeai_configure
|
from invokeai.frontend.install import invokeai_configure
|
||||||
|
|
||||||
@@ -467,18 +491,18 @@ def main():
|
|||||||
try:
|
try:
|
||||||
select_and_download_models(opt)
|
select_and_download_models(opt)
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
logger.error(e)
|
print(str(e))
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("Goodbye! Come back soon.")
|
print("\nGoodbye! Come back soon.")
|
||||||
except widget.NotEnoughSpaceForWidget as e:
|
except widget.NotEnoughSpaceForWidget as e:
|
||||||
if str(e).startswith("Height of 1 allocated"):
|
if str(e).startswith("Height of 1 allocated"):
|
||||||
logger.error(
|
print(
|
||||||
"Insufficient vertical space for the interface. Please make your window taller and try again"
|
"** Insufficient vertical space for the interface. Please make your window taller and try again"
|
||||||
)
|
)
|
||||||
elif str(e).startswith("addwstr"):
|
elif str(e).startswith("addwstr"):
|
||||||
logger.error(
|
print(
|
||||||
"Insufficient horizontal space for the interface. Please make your window wider and try again."
|
"** Insufficient horizontal space for the interface. Please make your window wider and try again."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -27,8 +27,6 @@ from ...backend.globals import (
|
|||||||
global_models_dir,
|
global_models_dir,
|
||||||
global_set_root,
|
global_set_root,
|
||||||
)
|
)
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from ...backend.model_management import ModelManager
|
from ...backend.model_management import ModelManager
|
||||||
from ...frontend.install.widgets import FloatTitleSlider
|
from ...frontend.install.widgets import FloatTitleSlider
|
||||||
|
|
||||||
@@ -115,7 +113,7 @@ def merge_diffusion_models_and_commit(
|
|||||||
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
|
model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
|
||||||
)
|
)
|
||||||
if vae := model_manager.config[models[0]].get("vae", None):
|
if vae := model_manager.config[models[0]].get("vae", None):
|
||||||
logger.info(f"Using configured VAE assigned to {models[0]}")
|
print(f">> Using configured VAE assigned to {models[0]}")
|
||||||
import_args.update(vae=vae)
|
import_args.update(vae=vae)
|
||||||
model_manager.import_diffuser_model(dump_path, **import_args)
|
model_manager.import_diffuser_model(dump_path, **import_args)
|
||||||
model_manager.commit(config_file)
|
model_manager.commit(config_file)
|
||||||
@@ -393,8 +391,10 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
|
|||||||
for name in self.model_manager.model_names()
|
for name in self.model_manager.model_names()
|
||||||
if self.model_manager.model_info(name).get("format") == "diffusers"
|
if self.model_manager.model_info(name).get("format") == "diffusers"
|
||||||
]
|
]
|
||||||
|
print(model_names)
|
||||||
return sorted(model_names)
|
return sorted(model_names)
|
||||||
|
|
||||||
|
|
||||||
class Mergeapp(npyscreen.NPSAppManaged):
|
class Mergeapp(npyscreen.NPSAppManaged):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -414,7 +414,7 @@ def run_gui(args: Namespace):
|
|||||||
|
|
||||||
args = mergeapp.merge_arguments
|
args = mergeapp.merge_arguments
|
||||||
merge_diffusion_models_and_commit(**args)
|
merge_diffusion_models_and_commit(**args)
|
||||||
logger.info(f'Models merged into new model: "{args["merged_model_name"]}".')
|
print(f'>> Models merged into new model: "{args["merged_model_name"]}".')
|
||||||
|
|
||||||
|
|
||||||
def run_cli(args: Namespace):
|
def run_cli(args: Namespace):
|
||||||
@@ -425,8 +425,8 @@ def run_cli(args: Namespace):
|
|||||||
|
|
||||||
if not args.merged_model_name:
|
if not args.merged_model_name:
|
||||||
args.merged_model_name = "+".join(args.models)
|
args.merged_model_name = "+".join(args.models)
|
||||||
logger.info(
|
print(
|
||||||
f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
f'>> No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
|
||||||
)
|
)
|
||||||
|
|
||||||
model_manager = ModelManager(OmegaConf.load(global_config_file()))
|
model_manager = ModelManager(OmegaConf.load(global_config_file()))
|
||||||
@@ -435,7 +435,7 @@ def run_cli(args: Namespace):
|
|||||||
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
|
||||||
|
|
||||||
merge_diffusion_models_and_commit(**vars(args))
|
merge_diffusion_models_and_commit(**vars(args))
|
||||||
logger.info(f'Models merged into new model: "{args.merged_model_name}".')
|
print(f'>> Models merged into new model: "{args.merged_model_name}".')
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -455,16 +455,17 @@ def main():
|
|||||||
run_cli(args)
|
run_cli(args)
|
||||||
except widget.NotEnoughSpaceForWidget as e:
|
except widget.NotEnoughSpaceForWidget as e:
|
||||||
if str(e).startswith("Height of 1 allocated"):
|
if str(e).startswith("Height of 1 allocated"):
|
||||||
logger.error(
|
print(
|
||||||
"You need to have at least two diffusers models defined in models.yaml in order to merge"
|
"** You need to have at least two diffusers models defined in models.yaml in order to merge"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(
|
print(
|
||||||
"Not enough room for the user interface. Try making this window larger."
|
"** Not enough room for the user interface. Try making this window larger."
|
||||||
)
|
)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.error(e)
|
print(">> An error occurred:")
|
||||||
|
traceback.print_exc()
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import npyscreen
|
|||||||
from npyscreen import widget
|
from npyscreen import widget
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.backend.globals import Globals, global_set_root
|
from invokeai.backend.globals import Globals, global_set_root
|
||||||
|
|
||||||
from ...backend.training import do_textual_inversion_training, parse_args
|
from ...backend.training import do_textual_inversion_training, parse_args
|
||||||
@@ -369,14 +368,14 @@ def copy_to_embeddings_folder(args: dict):
|
|||||||
dest_dir_name = args["placeholder_token"].strip("<>")
|
dest_dir_name = args["placeholder_token"].strip("<>")
|
||||||
destination = Path(Globals.root, "embeddings", dest_dir_name)
|
destination = Path(Globals.root, "embeddings", dest_dir_name)
|
||||||
os.makedirs(destination, exist_ok=True)
|
os.makedirs(destination, exist_ok=True)
|
||||||
logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}")
|
print(f">> Training completed. Copying learned_embeds.bin into {str(destination)}")
|
||||||
shutil.copy(source, destination)
|
shutil.copy(source, destination)
|
||||||
if (
|
if (
|
||||||
input("Delete training logs and intermediate checkpoints? [y] ") or "y"
|
input("Delete training logs and intermediate checkpoints? [y] ") or "y"
|
||||||
).startswith(("y", "Y")):
|
).startswith(("y", "Y")):
|
||||||
shutil.rmtree(Path(args["output_dir"]))
|
shutil.rmtree(Path(args["output_dir"]))
|
||||||
else:
|
else:
|
||||||
logger.info(f'Keeping {args["output_dir"]}')
|
print(f'>> Keeping {args["output_dir"]}')
|
||||||
|
|
||||||
|
|
||||||
def save_args(args: dict):
|
def save_args(args: dict):
|
||||||
@@ -423,10 +422,10 @@ def do_front_end(args: Namespace):
|
|||||||
do_textual_inversion_training(**args)
|
do_textual_inversion_training(**args)
|
||||||
copy_to_embeddings_folder(args)
|
copy_to_embeddings_folder(args)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("An exception occurred during training. The exception was:")
|
print("** An exception occurred during training. The exception was:")
|
||||||
logger.error(str(e))
|
print(str(e))
|
||||||
logger.error("DETAILS:")
|
print("** DETAILS:")
|
||||||
logger.error(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -438,21 +437,21 @@ def main():
|
|||||||
else:
|
else:
|
||||||
do_textual_inversion_training(**vars(args))
|
do_textual_inversion_training(**vars(args))
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
logger.error(e)
|
print(str(e))
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
except (widget.NotEnoughSpaceForWidget, Exception) as e:
|
except (widget.NotEnoughSpaceForWidget, Exception) as e:
|
||||||
if str(e).startswith("Height of 1 allocated"):
|
if str(e).startswith("Height of 1 allocated"):
|
||||||
logger.error(
|
print(
|
||||||
"You need to have at least one diffusers models defined in models.yaml in order to train"
|
"** You need to have at least one diffusers models defined in models.yaml in order to train"
|
||||||
)
|
)
|
||||||
elif str(e).startswith("addwstr"):
|
elif str(e).startswith("addwstr"):
|
||||||
logger.error(
|
print(
|
||||||
"Not enough window space for the interface. Please make your window larger and try again."
|
"** Not enough window space for the interface. Please make your window larger and try again."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(e)
|
print(f"** An error has occurred: {str(e)}")
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,5 +6,3 @@ stats.html
|
|||||||
index.html
|
index.html
|
||||||
.yarn/
|
.yarn/
|
||||||
*.scss
|
*.scss
|
||||||
src/services/api/
|
|
||||||
src/services/fixtures/*
|
|
||||||
|
|||||||
@@ -3,8 +3,4 @@ dist/
|
|||||||
node_modules/
|
node_modules/
|
||||||
patches/
|
patches/
|
||||||
stats.html
|
stats.html
|
||||||
index.html
|
|
||||||
.yarn/
|
.yarn/
|
||||||
*.scss
|
|
||||||
src/services/api/
|
|
||||||
src/services/fixtures/*
|
|
||||||
|
|||||||
@@ -1,16 +1,10 @@
|
|||||||
# InvokeAI Web UI
|
# InvokeAI Web UI
|
||||||
|
|
||||||
- [InvokeAI Web UI](#invokeai-web-ui)
|
|
||||||
- [Stack](#stack)
|
|
||||||
- [Contributing](#contributing)
|
|
||||||
- [Dev Environment](#dev-environment)
|
|
||||||
- [Production builds](#production-builds)
|
|
||||||
|
|
||||||
The UI is a fairly straightforward Typescript React app. The only really fancy stuff is the Unified Canvas.
|
The UI is a fairly straightforward Typescript React app. The only really fancy stuff is the Unified Canvas.
|
||||||
|
|
||||||
Code in `invokeai/frontend/web/` if you want to have a look.
|
Code in `invokeai/frontend/web/` if you want to have a look.
|
||||||
|
|
||||||
## Stack
|
## Details
|
||||||
|
|
||||||
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a custom redux middleware to help).
|
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a custom redux middleware to help).
|
||||||
|
|
||||||
@@ -38,7 +32,7 @@ Start everything in dev mode:
|
|||||||
|
|
||||||
1. Start the dev server: `yarn dev`
|
1. Start the dev server: `yarn dev`
|
||||||
2. Start the InvokeAI UI per usual: `invokeai --web`
|
2. Start the InvokeAI UI per usual: `invokeai --web`
|
||||||
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
|
3. Point your browser to the dev server address e.g. `http://localhost:5173/`
|
||||||
|
|
||||||
### Production builds
|
### Production builds
|
||||||
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
import react from '@vitejs/plugin-react-swc';
|
|
||||||
import { visualizer } from 'rollup-plugin-visualizer';
|
|
||||||
import { PluginOption, UserConfig } from 'vite';
|
|
||||||
import eslint from 'vite-plugin-eslint';
|
|
||||||
import tsconfigPaths from 'vite-tsconfig-paths';
|
|
||||||
|
|
||||||
export const appConfig: UserConfig = {
|
|
||||||
base: './',
|
|
||||||
plugins: [
|
|
||||||
react(),
|
|
||||||
eslint(),
|
|
||||||
tsconfigPaths(),
|
|
||||||
visualizer() as unknown as PluginOption,
|
|
||||||
],
|
|
||||||
build: {
|
|
||||||
chunkSizeWarningLimit: 1500,
|
|
||||||
},
|
|
||||||
server: {
|
|
||||||
// Proxy HTTP requests to the flask server
|
|
||||||
proxy: {
|
|
||||||
// Proxy socket.io to the nodes socketio server
|
|
||||||
'/ws/socket.io': {
|
|
||||||
target: 'ws://127.0.0.1:9090',
|
|
||||||
ws: true,
|
|
||||||
},
|
|
||||||
// Proxy openapi schema definiton
|
|
||||||
'/openapi.json': {
|
|
||||||
target: 'http://127.0.0.1:9090/openapi.json',
|
|
||||||
rewrite: (path) => path.replace(/^\/openapi.json/, ''),
|
|
||||||
changeOrigin: true,
|
|
||||||
},
|
|
||||||
// proxy nodes api
|
|
||||||
'/api/v1': {
|
|
||||||
target: 'http://127.0.0.1:9090/api/v1',
|
|
||||||
rewrite: (path) => path.replace(/^\/api\/v1/, ''),
|
|
||||||
changeOrigin: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
import react from '@vitejs/plugin-react-swc';
|
|
||||||
import path from 'path';
|
|
||||||
import { visualizer } from 'rollup-plugin-visualizer';
|
|
||||||
import { PluginOption, UserConfig } from 'vite';
|
|
||||||
import dts from 'vite-plugin-dts';
|
|
||||||
import eslint from 'vite-plugin-eslint';
|
|
||||||
import tsconfigPaths from 'vite-tsconfig-paths';
|
|
||||||
|
|
||||||
export const packageConfig: UserConfig = {
|
|
||||||
base: './',
|
|
||||||
plugins: [
|
|
||||||
react(),
|
|
||||||
eslint(),
|
|
||||||
tsconfigPaths(),
|
|
||||||
visualizer() as unknown as PluginOption,
|
|
||||||
dts({
|
|
||||||
insertTypesEntry: true,
|
|
||||||
}),
|
|
||||||
],
|
|
||||||
build: {
|
|
||||||
chunkSizeWarningLimit: 1500,
|
|
||||||
lib: {
|
|
||||||
entry: path.resolve(__dirname, '../src/index.ts'),
|
|
||||||
name: 'InvokeAIUI',
|
|
||||||
fileName: (format) => `invoke-ai-ui.${format}.js`,
|
|
||||||
},
|
|
||||||
rollupOptions: {
|
|
||||||
external: ['react', 'react-dom', '@emotion/react'],
|
|
||||||
output: {
|
|
||||||
globals: {
|
|
||||||
react: 'React',
|
|
||||||
'react-dom': 'ReactDOM',
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
resolve: {
|
|
||||||
alias: {
|
|
||||||
app: path.resolve(__dirname, '../src/app'),
|
|
||||||
assets: path.resolve(__dirname, '../src/assets'),
|
|
||||||
common: path.resolve(__dirname, '../src/common'),
|
|
||||||
features: path.resolve(__dirname, '../src/features'),
|
|
||||||
services: path.resolve(__dirname, '../src/services'),
|
|
||||||
theme: path.resolve(__dirname, '../src/theme'),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
188
invokeai/frontend/web/dist/assets/App-843b023b.js
vendored
Normal file
File diff suppressed because one or more lines are too long
188
invokeai/frontend/web/dist/assets/App-af7ef809.js
vendored
File diff suppressed because one or more lines are too long
@@ -1,4 +1,4 @@
-import{j as y,cO as Ie,r as _,cP as bt,q as Lr,cQ as o,cR as b,cS as v,cT as S,cU as Vr,cV as ut,cW as vt,cN as ft,cX as mt,n as gt,cY as ht,E as pt}from"./index-e53e8108.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-5cde7d31.js";var Or=`
+import{j as y,cN as Ie,r as _,cO as bt,q as Lr,cP as o,cQ as b,cR as v,cS as S,cT as Vr,cU as ut,cV as vt,cM as ft,cW as mt,n as gt,cX as ht,E as pt}from"./index-f7f41e1f.js";import{d as yt,i as St,T as xt,j as $t,h as kt}from"./storeHooks-eaf47ae3.js";var Or=`
 :root {
   --chakra-vh: 100vh;
 }
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
2 invokeai/frontend/web/dist/index.html vendored
@@ -12,7 +12,7 @@
       margin: 0;
     }
   </style>
-  <script type="module" crossorigin src="./assets/index-e53e8108.js"></script>
+  <script type="module" crossorigin src="./assets/index-f7f41e1f.js"></script>
   <link rel="stylesheet" href="./assets/index-5483945c.css">
 </head>
1 invokeai/frontend/web/dist/locales/ar.json vendored
@@ -8,6 +8,7 @@
     "darkTheme": "داكن",
     "lightTheme": "فاتح",
     "greenTheme": "أخضر",
+    "text2img": "نص إلى صورة",
     "img2img": "صورة إلى صورة",
     "unifiedCanvas": "لوحة موحدة",
     "nodes": "عقد",
1 invokeai/frontend/web/dist/locales/de.json vendored
@@ -7,6 +7,7 @@
     "darkTheme": "Dunkel",
     "lightTheme": "Hell",
     "greenTheme": "Grün",
+    "text2img": "Text zu Bild",
     "img2img": "Bild zu Bild",
     "nodes": "Knoten",
     "langGerman": "Deutsch",
4 invokeai/frontend/web/dist/locales/en.json vendored
@@ -505,9 +505,7 @@
     "info": "Info",
     "deleteImage": "Delete Image",
     "initialImage": "Initial Image",
-    "showOptionsPanel": "Show Options Panel",
-    "hidePreview": "Hide Preview",
-    "showPreview": "Show Preview"
+    "showOptionsPanel": "Show Options Panel"
   },
   "settings": {
     "models": "Models",
12 invokeai/frontend/web/dist/locales/es.json vendored
@@ -8,6 +8,7 @@
     "darkTheme": "Oscuro",
     "lightTheme": "Claro",
     "greenTheme": "Verde",
+    "text2img": "Texto a Imagen",
     "img2img": "Imagen a Imagen",
     "unifiedCanvas": "Lienzo Unificado",
     "nodes": "Nodos",
@@ -69,11 +70,7 @@
     "langHebrew": "Hebreo",
     "pinOptionsPanel": "Pin del panel de opciones",
     "loading": "Cargando",
-    "loadingInvokeAI": "Cargando invocar a la IA",
-    "postprocessing": "Tratamiento posterior",
-    "txt2img": "De texto a imagen",
-    "accept": "Aceptar",
-    "cancel": "Cancelar"
+    "loadingInvokeAI": "Cargando invocar a la IA"
   },
   "gallery": {
     "generations": "Generaciones",
@@ -407,8 +404,7 @@
     "none": "ninguno",
     "pickModelType": "Elige el tipo de modelo",
     "v2_768": "v2 (768px)",
-    "addDifference": "Añadir una diferencia",
-    "scanForModels": "Buscar modelos"
+    "addDifference": "Añadir una diferencia"
   },
   "parameters": {
     "images": "Imágenes",
@@ -578,7 +574,7 @@
     "autoSaveToGallery": "Guardar automáticamente en galería",
     "saveBoxRegionOnly": "Guardar solo región dentro de la caja",
     "limitStrokesToBox": "Limitar trazos a la caja",
-    "showCanvasDebugInfo": "Mostrar la información adicional del lienzo",
+    "showCanvasDebugInfo": "Mostrar información de depuración de lienzo",
     "clearCanvasHistory": "Limpiar historial de lienzo",
     "clearHistory": "Limpiar historial",
     "clearCanvasHistoryMessage": "Limpiar el historial de lienzo también restablece completamente el lienzo unificado. Esto incluye todo el historial de deshacer/rehacer, las imágenes en el área de preparación y la capa base del lienzo.",
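Taken together, the locale hunks above add a text2img caption in each language and drop keys the UI no longer uses; these JSON files are plain nested resources for the web client's i18n layer. A minimal, self-contained sketch of how such a nested key resolves is below, assuming a bare i18next setup; the real client is wired differently (language detection, namespaces), so treat this purely as an illustration. The nesting under "common" is an assumption based on the surrounding keys, not confirmed by this diff.

import i18next from 'i18next';

// Load one of the locale objects above as an in-memory resource and
// resolve the newly added key.
await i18next.init({
  lng: 'de',
  resources: {
    de: {
      translation: {
        common: {
          text2img: 'Text zu Bild',
          img2img: 'Bild zu Bild',
        },
      },
    },
  },
});

console.log(i18next.t('common.text2img')); // "Text zu Bild"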
Some files were not shown because too many files have changed in this diff