Compare commits


1 Commit

Author            SHA1        Message                              Date
Kyle Schouviller  510ae34bff  [nodes] Add cancelation to the API   2023-03-16 20:05:36 -07:00
634 changed files with 8026 additions and 26494 deletions

.coveragerc (new file, 6 lines)

@@ -0,0 +1,6 @@
[run]
omit='.env/*'
source='.'
[report]
show_missing = true

.github/CODEOWNERS (14 lines changed)

@@ -1,16 +1,16 @@
# continuous integration
/.github/workflows/ @lstein @blessedcoolant
/.github/workflows/ @mauwii @lstein
# documentation
/docs/ @lstein @tildebyte @blessedcoolant
/mkdocs.yml @lstein @blessedcoolant
/docs/ @lstein @mauwii @tildebyte
/mkdocs.yml @lstein @mauwii
# nodes
/invokeai/app/ @Kyle0654 @blessedcoolant
# installation and configuration
/pyproject.toml @lstein @blessedcoolant
/docker/ @lstein @blessedcoolant
/pyproject.toml @mauwii @lstein @blessedcoolant
/docker/ @mauwii @lstein
/scripts/ @ebr @lstein
/installer/ @lstein @ebr
/invokeai/assets @lstein @ebr
@@ -22,11 +22,11 @@
/invokeai/backend @blessedcoolant @psychedelicious @lstein
# generation, model management, postprocessing
/invokeai/backend @damian0815 @lstein @blessedcoolant @jpphoto @gregghelt2
/invokeai/backend @keturn @damian0815 @lstein @blessedcoolant @jpphoto
# front ends
/invokeai/frontend/CLI @lstein
/invokeai/frontend/install @lstein @ebr
/invokeai/frontend/install @lstein @ebr @mauwii
/invokeai/frontend/merge @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/training @lstein @blessedcoolant @hipsterusername
/invokeai/frontend/web @psychedelicious @blessedcoolant

.github/stale.yaml (deleted, 19 lines)

@@ -1,19 +0,0 @@
# Number of days of inactivity before an issue becomes stale
daysUntilStale: 28
# Number of days of inactivity before a stale issue is closed
daysUntilClose: 14
# Issues with these labels will never be considered stale
exemptLabels:
- pinned
- security
# Label to use when marking an issue as stale
staleLabel: stale
# Comment to post when marking an issue as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Please
update the ticket if this is still a problem on the latest release.
# Comment to post when closing a stale issue. Set to `false` to disable
closeComment: >
Due to inactivity, this issue has been automatically closed. If this is
still a problem on the latest release, please recreate the issue.


@@ -16,10 +16,6 @@ on:
- 'v*.*.*'
workflow_dispatch:
permissions:
contents: write
packages: write
jobs:
docker:
if: github.event.pull_request.draft == false


@@ -2,19 +2,13 @@ name: mkdocs-material
on:
push:
branches:
- 'refs/heads/v2.3'
permissions:
contents: write
- 'main'
- 'development'
jobs:
mkdocs-material:
if: github.event.pull_request.draft == false
runs-on: ubuntu-latest
env:
REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
REPO_NAME: '${{ github.repository }}'
SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
steps:
- name: checkout sources
uses: actions/checkout@v3
@@ -25,15 +19,11 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install requirements
env:
PIP_USE_PEP517: 1
run: |
python -m \
pip install ".[docs]"
pip install -r docs/requirements-mkdocs.txt
- name: confirm buildability
run: |
@@ -43,7 +33,7 @@ jobs:
--verbose
- name: deploy to gh-pages
if: ${{ github.ref == 'refs/heads/v2.3' }}
if: ${{ github.ref == 'refs/heads/main' }}
run: |
python -m \
mkdocs gh-deploy \


@@ -6,6 +6,7 @@ on:
- '!pyproject.toml'
- '!invokeai/**'
- 'invokeai/frontend/web/**'
- '!invokeai/frontend/web/dist/**'
merge_group:
workflow_dispatch:


@@ -7,11 +7,13 @@ on:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'invokeai/frontend/web/dist/**'
pull_request:
paths:
- 'pyproject.toml'
- 'invokeai/**'
- '!invokeai/frontend/web/**'
- 'invokeai/frontend/web/dist/**'
types:
- 'ready_for_review'
- 'opened'

.gitignore (4 lines changed)

@@ -9,8 +9,6 @@ models/ldm/stable-diffusion-v1/model.ckpt
configs/models.user.yaml
config/models.user.yml
invokeai.init
.version
.last_model
# ignore the Anaconda/Miniconda installer used while building Docker image
anaconda.sh
@@ -65,7 +63,6 @@ pip-delete-this-directory.txt
htmlcov/
.tox/
.nox/
.coveragerc
.coverage
.coverage.*
.cache
@@ -76,7 +73,6 @@ cov.xml
*.py,cover
.hypothesis/
.pytest_cache/
.pytest.ini
cover/
junit/

.pytest.ini (new file, 5 lines)

@@ -0,0 +1,5 @@
[pytest]
DJANGO_SETTINGS_MODULE = webtas.settings
; python_files = tests.py test_*.py *_tests.py
addopts = --cov=. --cov-config=.coveragerc --cov-report xml:cov.xml
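For reference, the suite can also be driven programmatically; a minimal sketch (when run from the repo root, pytest picks up `.pytest.ini`, so the `addopts` above are applied automatically):

```python
# Minimal sketch: run the test suite from Python. When executed from the repo
# root, pytest reads .pytest.ini, so the coverage addopts above are applied.
import sys

import pytest

sys.exit(pytest.main([]))  # extra CLI arguments could be passed in the list
```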


@@ -33,8 +33,6 @@
</div>
_**Note: The UI is not fully functional on `main`. If you need a stable UI based on `main`, use the `pre-nodes` tag while we [migrate to a new backend](https://github.com/invoke-ai/InvokeAI/discussions/3246).**_
InvokeAI is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. InvokeAI offers an industry leading Web Interface, interactive Command Line Interface, and also serves as the foundation for multiple commercial products.
**Quick links**: [[How to Install](https://invoke-ai.github.io/InvokeAI/#installation)] [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>] [<a href="https://invoke-ai.github.io/InvokeAI/">Documentation and Tutorials</a>] [<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a href="https://github.com/invoke-ai/InvokeAI/issues">Bug Reports</a>] [<a href="https://github.com/invoke-ai/InvokeAI/discussions">Discussion, Ideas & Q&A</a>]
@@ -86,7 +84,7 @@ installing lots of models.
6. Wait while the installer does its thing. After installing the software,
the installer will launch a script that lets you configure InvokeAI and
select a set of starting image generation models.
select a set of starting image generaiton models.
7. Find the folder that InvokeAI was installed into (it is not the
same as the unpacked zip file directory!) The default location of this
@@ -141,20 +139,15 @@ not supported.
_For Windows/Linux with an NVIDIA GPU:_
```terminal
pip install "InvokeAI[xformers]" --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
pip install InvokeAI[xformers] --use-pep517 --extra-index-url https://download.pytorch.org/whl/cu117
```
_For Linux with an AMD GPU:_
```sh
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
```
_For non-GPU systems:_
```terminal
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/cpu
```
_For Macintoshes, either Intel or M1/M2:_
```sh

coverage/.gitignore (deleted, 4 lines)

@@ -1,4 +0,0 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore

Binary file not shown (before: 470 KiB)

Binary file not shown (before: 457 KiB)


@@ -1,18 +1,10 @@
# Invocations
Invocations represent a single operation, its inputs, and its outputs. These
operations and their outputs can be chained together to generate and modify
images.
Invocations represent a single operation, its inputs, and its outputs. These operations and their outputs can be chained together to generate and modify images.
## Creating a new invocation
To create a new invocation, either find the appropriate module file in
`/ldm/invoke/app/invocations` to add your invocation to, or create a new one in
that folder. All invocations in that folder will be discovered and made
available to the CLI and API automatically. Invocations make use of
[typing](https://docs.python.org/3/library/typing.html) and
[pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration
into the CLI and API.
To create a new invocation, either find the appropriate module file in `/ldm/invoke/app/invocations` to add your invocation to, or create a new one in that folder. All invocations in that folder will be discovered and made available to the CLI and API automatically. Invocations make use of [typing](https://docs.python.org/3/library/typing.html) and [pydantic](https://pydantic-docs.helpmanual.io/) for validation and integration into the CLI and API.
An invocation looks like this:
@@ -49,54 +41,34 @@ class UpscaleInvocation(BaseInvocation):
Each portion is important to implement correctly.
### Class definition and type
```py
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
type: Literal['upscale'] = 'upscale'
```
All invocations must derive from `BaseInvocation`. They should have a docstring
that declares what they do in a single, short line. They should also have a
`type` with a type hint that's `Literal["command_name"]`, where `command_name`
is what the user will type on the CLI or use in the API to create this
invocation. The `command_name` must be unique. The `type` must be assigned to
the value of the literal in the type hint.
All invocations must derive from `BaseInvocation`. They should have a docstring that declares what they do in a single, short line. They should also have a `type` with a type hint that's `Literal["command_name"]`, where `command_name` is what the user will type on the CLI or use in the API to create this invocation. The `command_name` must be unique. The `type` must be assigned to the value of the literal in the type hint.
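For illustration, here is a minimal sketch of a hypothetical new invocation following those rules (the class and field names are invented, and the import paths are assumptions based on the code shown elsewhere on this page):

```python
# Hypothetical "blur" invocation, for illustration only. Import paths are
# assumed from the module layout described in this document.
from typing import Literal, Union

from pydantic import Field

from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput


class BlurInvocation(BaseInvocation):
    """Blurs an image."""

    # "blur" is what a user types on the CLI (or sends to the API) to create
    # this invocation; the literal value must match and be unique.
    type: Literal["blur"] = "blur"

    # Inputs: image allows None so it can be filled by a linked output.
    image: Union[ImageField, None] = Field(description="The input image")
    radius: float = Field(default=2.0, gt=0, description="The blur radius")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Fetch inputs via services, process, and return a BaseInvocationOutput
        # subclass (ImageOutput here). Body elided in this sketch.
        ...
```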
### Inputs
```py
# Inputs
image: Union[ImageField,None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2,4] = Field(default=2, description="The upscale level")
```
Inputs consist of three parts: a name, a type hint, and a `Field` with default, description, and validation information. For example:
| Part | Value | Description |
| ---- | ----- | ----------- |
| Name | `strength` | This field is referred to as `strength` |
| Type Hint | `float` | This field must be of type `float` |
| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
Inputs consist of three parts: a name, a type hint, and a `Field` with default,
description, and validation information. For example:
Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this field to be parsed with `None` as a value, which enables linking to previous invocations. All fields should either provide a default value or allow `None` as a value, so that they can be overwritten with a linked output from another invocation.
| Part | Value | Description |
| --------- | ------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------- |
| Name | `strength` | This field is referred to as `strength` |
| Type Hint | `float` | This field must be of type `float` |
| Field | `Field(default=0.75, gt=0, le=1, description="The strength")` | The default value is `0.75`, the value must be in the range (0,1], and help text will show "The strength" for this field. |
The special type `ImageField` is also used here. All images are passed as `ImageField`, which protects them from pydantic validation errors (since images only ever come from links).
Notice that `image` has type `Union[ImageField,None]`. The `Union` allows this
field to be parsed with `None` as a value, which enables linking to previous
invocations. All fields should either provide a default value or allow `None` as
a value, so that they can be overwritten with a linked output from another
invocation.
The special type `ImageField` is also used here. All images are passed as
`ImageField`, which protects them from pydantic validation errors (since images
only ever come from links).
Finally, note that for all linking, the `type` of the linked fields must match.
If the `name` also matches, then the field can be **automatically linked** to a
previous invocation by name and matching.
Finally, note that for all linking, the `type` of the linked fields must match. If the `name` also matches, then the field can be **automatically linked** to a previous invocation by name and matching.
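As a small sketch of that rule (field definitions only, imports and class bodies elided, `SomeNextInvocation` hypothetical): the `image` output below can be auto-linked to the `image` input because both the name and the `ImageField` type match.

```python
# Sketch of automatic linking: matching field type (required) and matching
# field name (enables auto-linking by name).
class ImageOutput(BaseInvocationOutput):
    image: ImageField = Field(default=None, description="The output image")


class SomeNextInvocation(BaseInvocation):
    # Same name ("image") and same type (ImageField) as the output above,
    # so this input can be linked automatically to a previous invocation.
    image: Union[ImageField, None] = Field(description="The input image")
```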
### Invoke Function
```py
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image.image_type, self.image.image_name)
@@ -116,22 +88,13 @@ previous invocation by name and matching.
image = ImageField(image_type = image_type, image_name = image_name)
)
```
The `invoke` function is the last portion of an invocation. It is provided an `InvocationContext` which contains services to perform work as well as a `session_id` for use as needed. It should return a class with output values that derives from `BaseInvocationOutput`.
The `invoke` function is the last portion of an invocation. It is provided an
`InvocationContext` which contains services to perform work as well as a
`session_id` for use as needed. It should return a class with output values that
derives from `BaseInvocationOutput`.
Before being called, the invocation will have all of its fields set from defaults, inputs, and finally links (overriding in that order).
Before being called, the invocation will have all of its fields set from
defaults, inputs, and finally links (overriding in that order).
Assume that this invocation may be running simultaneously with other
invocations, may be running on another machine, or in other interesting
scenarios. If you need functionality, please provide it as a service in the
`InvocationServices` class, and make sure it can be overridden.
Assume that this invocation may be running simultaneously with other invocations, may be running on another machine, or in other interesting scenarios. If you need functionality, please provide it as a service in the `InvocationServices` class, and make sure it can be overridden.
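A sketch of what "provide it as a service" can look like in practice (the service name here is invented): define the functionality behind an abstract interface so that a test or another deployment can override it, then reach it through `context.services` inside `invoke` rather than importing it directly.

```python
# Hypothetical service sketch: an abstract interface plus a default
# implementation that could be swapped out when services are assembled.
from abc import ABC, abstractmethod


class UriServiceBase(ABC):
    """Hypothetical service interface."""

    @abstractmethod
    def get_uri(self, image_name: str) -> str:
        ...


class LocalUriService(UriServiceBase):
    def get_uri(self, image_name: str) -> str:
        return f"file://{image_name}"
```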
### Outputs
```py
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
@@ -139,64 +102,4 @@ class ImageOutput(BaseInvocationOutput):
image: ImageField = Field(default=None, description="The output image")
```
Output classes look like an invocation class without the invoke method. Prefer
to use an existing output class if available, and prefer to name inputs the same
as outputs when possible, to promote automatic invocation linking.
## Schema Generation
Invocation, output and related classes are used to generate an OpenAPI schema.
### Required Properties
The schema generation treats all properties with default values as optional. This
makes sense internally, but when using these classes via the generated
schema, we end up with e.g. the `ImageOutput` class having its `image` property
marked as optional.
We know that this property will always be present, so the additional logic
needed to always check if the property exists adds a lot of extraneous cruft.
To fix this, we can leverage `pydantic`'s
[schema customisation](https://docs.pydantic.dev/usage/schema/#schema-customization)
to mark properties that we know will always be present as required.
Here's that `ImageOutput` class, without the needed schema customisation:
```python
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
```
The generated OpenAPI schema, and all clients/types generated from it, will have
the `type` and `image` properties marked as optional, even though we know they
will always have a value by the time we can interact with them via the API.
Here's the same class, but with the schema customisation added:
```python
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
class Config:
schema_extra = {
'required': [
'type',
'image',
]
}
```
The resultant schema (and any API client or types generated from it) will now
see `type` as the string literal `"image"` and `image` as an `ImageField`
object.
See this `pydantic` issue for discussion on this solution:
<https://github.com/pydantic/pydantic/discussions/4577>
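A quick way to confirm the customisation took effect (a sketch, assuming pydantic v1 as used elsewhere on this page):

```python
# Sketch: inspect the generated schema. With the Config.schema_extra above,
# both properties should be listed as required.
print(ImageOutput.schema()["required"])  # expected: ['type', 'image']
```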
Output classes look like an invocation class without the invoke method. Prefer to use an existing output class if available, and prefer to name inputs the same as outputs when possible, to promote automatic invocation linking.


@@ -1,83 +0,0 @@
# Local Development
If you are looking to contribute you will need to have a local development
environment. See the
[Developer Install](../installation/020_INSTALL_MANUAL.md#developer-install) for
full details.
Broadly this involves cloning the repository, installing the prerequisites, and
installing InvokeAI itself (in editable form). Assuming this is working, choose
your area of focus.
## Documentation
We use [mkdocs](https://www.mkdocs.org) for our documentation with the
[material theme](https://squidfunk.github.io/mkdocs-material/). Documentation is
written in markdown files under the `./docs` folder and then built into a static
website for hosting with GitHub Pages at
[invoke-ai.github.io/InvokeAI](https://invoke-ai.github.io/InvokeAI).
To contribute to the documentation you'll need to install the dependencies. Note
the use of `"`.
```zsh
pip install ".[docs]"
```
Now, run the documentation locally with hot-reloading for changes:
```zsh
mkdocs serve
```
You'll then be prompted to connect to `http://127.0.0.1:8080` in order to
access the local documentation site.
## Backend
The backend is contained within the `./invokeai/backend` folder structure. To
get started, however, please install the development dependencies.
From the root of the repository run the following command. Note the use of `"`.
```zsh
pip install ".[test]"
```
This is an optional group of packages which is defined within the
`pyproject.toml` and will be required for testing the changes you make to the
code.
### Running Tests
We use [pytest](https://docs.pytest.org/en/7.2.x/) for our test suite. Tests can
be found under the `./tests` folder and can be run with a single `pytest`
command. Optionally, to review test coverage you can append `--cov`.
```zsh
pytest --cov
```
Test outcomes and coverage will be reported in the terminal. In addition, a more
detailed report is created in both XML and HTML format in the `./coverage`
folder. The HTML report in particular can help identify missing statements
requiring tests to ensure coverage. It can be viewed by opening
`./coverage/html/index.html`.
For example.
```zsh
pytest --cov; open ./coverage/html/index.html
```
??? info "HTML coverage report output"
![html-overview](../assets/contributing/html-overview.png)
![html-detail](../assets/contributing/html-detail.png)
## Front End
<!--#TODO: get input from blessedcoolant here, for the moment inserted the frontend README via snippets extension.-->
--8<-- "invokeai/frontend/web/README.md"


@@ -168,15 +168,11 @@ used by Stable Diffusion 1.4 and 1.5.
After installation, your `models.yaml` should contain an entry that looks like
this one:
```yml
inpainting-1.5:
weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
description: SD inpainting v1.5
config: configs/stable-diffusion/v1-inpainting-inference.yaml
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
width: 512
height: 512
```
inpainting-1.5: weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
description: SD inpainting v1.5 config:
configs/stable-diffusion/v1-inpainting-inference.yaml vae:
models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt width: 512
height: 512
As shown in the example, you may include a VAE fine-tuning weights file as well.
This is strongly recommended.
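As a sketch of how this entry is structured when read back (the `configs/models.yaml` path is an assumption; PyYAML assumed available):

```python
# Sketch: load the entry shown above and inspect its fields.
import yaml  # PyYAML, assumed available

with open("configs/models.yaml") as f:
    models = yaml.safe_load(f)

entry = models["inpainting-1.5"]
print(entry["weights"])  # models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
print(entry["vae"])      # the optional VAE fine-tuning weights file
```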


@@ -32,7 +32,7 @@ turned on and off on the command line using `--nsfw_checker` and
At installation time, InvokeAI will ask whether the checker should be
activated by default (neither argument given on the command line). The
response is stored in the InvokeAI initialization file (usually
`invokeai.init` in your home directory). You can change the default at any
`.invokeai` in your home directory). You can change the default at any
time by opening this file in a text editor and commenting or
uncommenting the line `--nsfw_checker`.


@@ -268,7 +268,7 @@ model is so good at inpainting, a good substitute is to use the `clipseg` text
masking option:
```bash
invoke> a fluffy cat eating a hotdog
invoke> a fluffy cat eating a hotdot
Outputs:
[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat


@@ -17,7 +17,7 @@ notebooks.
You will need a GPU to perform training in a reasonable length of
time, and at least 12 GB of VRAM. We recommend using the [`xformers`
library](../installation/070_INSTALL_XFORMERS.md) to accelerate the
library](../installation/070_INSTALL_XFORMERS) to accelerate the
training process further. During training, about ~8 GB is temporarily
needed in order to store intermediate models, checkpoints and logs.


@@ -89,7 +89,7 @@ experimental versions later.
sudo apt update
sudo apt install -y software-properties-common
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt install -y python3.10 python3-pip python3.10-venv
sudo apt install python3.10 python3-pip python3.10-venv
sudo update-alternatives --install /usr/local/bin/python python /usr/bin/python3.10 3
```
@@ -417,7 +417,7 @@ Then type the following commands:
=== "AMD System"
```bash
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.2
```
### Corrupted configuration file


@@ -154,7 +154,7 @@ manager, please follow these steps:
=== "ROCm (AMD)"
```bash
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
```
=== "CPU (Intel Macs & non-GPU systems)"
@@ -315,7 +315,7 @@ installation protocol (important!)
=== "ROCm (AMD)"
```bash
pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.2
```
=== "CPU (Intel Macs & non-GPU systems)"


@@ -110,7 +110,7 @@ recipes are available
When installing torch and torchvision manually with `pip`, remember to provide
the argument `--extra-index-url
https://download.pytorch.org/whl/rocm5.4.2` as described in the [Manual
https://download.pytorch.org/whl/rocm5.2` as described in the [Manual
Installation Guide](020_INSTALL_MANUAL.md).
This will be done automatically for you if you use the installer


@@ -50,7 +50,7 @@ subset that are currently installed are found in
|stable-diffusion-1.5|runwayml/stable-diffusion-v1-5|Stable Diffusion version 1.5 diffusers model (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-v1-5 |
|sd-inpainting-1.5|runwayml/stable-diffusion-inpainting|RunwayML SD 1.5 model optimized for inpainting, diffusers version (4.27 GB)|https://huggingface.co/runwayml/stable-diffusion-inpainting |
|stable-diffusion-2.1|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.1 diffusers model, trained on 768 pixel images (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
|sd-inpainting-2.0|stabilityai/stable-diffusion-2-inpainting|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-inpainting |
|sd-inpainting-2.0|stabilityai/stable-diffusion-2-1|Stable Diffusion version 2.0 inpainting model (5.21 GB)|https://huggingface.co/stabilityai/stable-diffusion-2-1 |
|analog-diffusion-1.0|wavymulder/Analog-Diffusion|An SD-1.5 model trained on diverse analog photographs (2.13 GB)|https://huggingface.co/wavymulder/Analog-Diffusion |
|deliberate-1.0|XpucT/Deliberate|Versatile model that produces detailed images up to 768px (4.27 GB)|https://huggingface.co/XpucT/Deliberate |
|d&d-diffusion-1.0|0xJustin/Dungeons-and-Diffusion|Dungeons & Dragons characters (2.13 GB)|https://huggingface.co/0xJustin/Dungeons-and-Diffusion |


@@ -24,7 +24,7 @@ You need to have opencv installed so that pypatchmatch can be built:
brew install opencv
```
The next time you start `invoke`, after successfully installing opencv, pypatchmatch will be built.
The next time you start `invoke`, after sucesfully installing opencv, pypatchmatch will be built.
## Linux
@@ -56,7 +56,7 @@ Prior to installing PyPatchMatch, you need to take the following steps:
5. Confirm that pypatchmatch is installed. At the command-line prompt enter
`python`, and then at the `>>>` line type
`from patchmatch import patch_match`: It should look like the following:
`from patchmatch import patch_match`: It should look like the follwing:
```py
Python 3.9.5 (default, Nov 23 2021, 15:27:38)
@@ -108,4 +108,4 @@ Prior to installing PyPatchMatch, you need to take the following steps:
[**Next, Follow Steps 4-6 from the Debian Section above**](#linux)
If you see no errors you're ready to go!
If you see no errors, then you're ready to go!


@@ -456,7 +456,7 @@ def get_torch_source() -> (Union[str, None],str):
optional_modules = None
if OS == "Linux":
if device == "rocm":
url = "https://download.pytorch.org/whl/rocm5.4.2"
url = "https://download.pytorch.org/whl/rocm5.2"
elif device == "cpu":
url = "https://download.pytorch.org/whl/cpu"


@@ -24,9 +24,9 @@ if [ "$(uname -s)" == "Darwin" ]; then
export PYTORCH_ENABLE_MPS_FALLBACK=1
fi
while true
do
if [ "$0" != "bash" ]; then
while true
do
echo "Do you want to generate images using the"
echo "1. command-line interface"
echo "2. browser-based UI"
@@ -67,29 +67,29 @@ if [ "$0" != "bash" ]; then
;;
7)
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
;;
8)
echo "Developer Console:"
;;
8)
echo "Developer Console:"
file_name=$(basename "${BASH_SOURCE[0]}")
bash --init-file "$file_name"
;;
9)
echo "Update:"
echo "Update:"
invokeai-update
;;
10)
invokeai --help
;;
[qQ])
[qQ])
exit 0
;;
*)
echo "Invalid selection"
exit;;
esac
done
else # in developer console
python --version
echo "Press ^D to exit"
export PS1="(InvokeAI) \u@\h \w> "
fi
done


@@ -1,23 +1,18 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import os
from argparse import Namespace
import invokeai.backend.util.logging as logger
from typing import types
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ...backend import Globals
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.graph import GraphExecutionState
from ..services.image_storage import DiskImageStorage
from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.metadata import PngMetadataService
from .events import FastAPIEventService
@@ -43,16 +38,15 @@ class ApiDependencies:
invoker: Invoker = None
@staticmethod
def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
def initialize(config, event_handler_id: int):
Globals.try_patchmatch = config.patchmatch
Globals.always_use_cpu = config.always_use_cpu
Globals.internet_available = config.internet_available and check_internet()
Globals.disable_xformers = not config.xformers
Globals.ckpt_convert = config.ckpt_convert
# TO DO: Use the config to select the logger rather than use the default
# invokeai logging module
logger.info(f"Internet connectivity is {Globals.internet_available}")
# TODO: Use a logger
print(f">> Internet connectivity is {Globals.internet_available}")
events = FastAPIEventService(event_handler_id)
@@ -60,35 +54,23 @@ class ApiDependencies:
os.path.join(os.path.dirname(__file__), "../../../../outputs")
)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
metadata = PngMetadataService()
images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
images = DiskImageStorage(output_folder)
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
services = InvocationServices(
model_manager=get_model_manager(config,logger),
model_manager=get_model_manager(config),
events=events,
logger=logger,
latents=latents,
images=images,
metadata=metadata,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
),
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger),
restoration=RestorationServices(config),
)
create_system_graphs(services.graph_library)
ApiDependencies.invoker = Invoker(services)
@staticmethod


@@ -45,7 +45,7 @@ class FastAPIEventService(EventServiceBase):
)
except Empty:
await asyncio.sleep(0.1)
await asyncio.sleep(0.001)
pass
except asyncio.CancelledError as e:
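For context, the sleep changed above sits in a poll loop that drains a standard (thread-safe) queue from an asyncio task; a self-contained sketch of the pattern (names hypothetical) follows. The sleep length trades CPU usage against event latency.

```python
# Sketch of the polling pattern: a shorter sleep (0.001 s) reduces event
# latency at the cost of more wakeups; 0.1 s is cheaper but less responsive.
import asyncio
import queue


async def dispatch_events(q: "queue.Queue[str]") -> None:
    while True:
        try:
            event = q.get(block=False)
            print(event)  # hand the event to subscribers here
        except queue.Empty:
            await asyncio.sleep(0.001)
        except asyncio.CancelledError:
            break
```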


@@ -1,40 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import InvokeAIMetadata
class ImageResponseMetadata(BaseModel):
"""An image's metadata. Used only in HTTP responses."""
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
invokeai: Optional[InvokeAIMetadata] = Field(
description="The image's InvokeAI-specific metadata"
)
class ImageResponse(BaseModel):
"""The response type for images"""
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
image_url: str = Field(description="The url of the image")
thumbnail_url: str = Field(description="The url of the image's thumbnail")
metadata: ImageResponseMetadata = Field(description="The image's metadata")
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")
class SavedImage(BaseModel):
image_name: str = Field(description="The name of the saved image")
thumbnail_name: str = Field(description="The name of the saved thumbnail")
created: int = Field(description="The created timestamp of the saved image")


@@ -1,20 +1,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
import json
import os
from typing import Any
import uuid
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
from datetime import datetime, timezone
from fastapi import Path, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
)
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies
@@ -26,123 +17,40 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
async def get_image(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of the image to get"),
) -> FileResponse:
"""Gets an image"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=image_name
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image(
image_type: ImageType = Path(description="The type of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image and its thumbnail"""
ApiDependencies.invoker.services.images.delete(
image_type=image_type, image_name=image_name
)
@images_router.get(
"/{thumbnail_type}/thumbnails/{thumbnail_name}", operation_id="get_thumbnail"
)
async def get_thumbnail(
thumbnail_type: ImageType = Path(description="The type of thumbnail to get"),
thumbnail_name: str = Path(description="The name of the thumbnail to get"),
) -> FileResponse | Response:
"""Gets a thumbnail"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=thumbnail_type, image_name=thumbnail_name, is_thumbnail=True
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
):
"""Gets a result"""
# TODO: This is not really secure at all. At least make sure only output results are served
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
return FileResponse(filename)
@images_router.post(
"/uploads/",
operation_id="upload_image",
responses={
201: {
"description": "The image was uploaded successfully",
"model": ImageResponse,
},
415: {"description": "Image upload failed"},
201: {"description": "The image was uploaded successfully"},
404: {"description": "Session not found"},
},
status_code=201,
)
async def upload_image(
file: UploadFile, request: Request, response: Response
) -> ImageResponse:
async def upload_image(file: UploadFile, request: Request):
if not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
return Response(status_code=415)
contents = await file.read()
try:
img = Image.open(io.BytesIO(contents))
im = Image.open(contents)
except:
# Error opening the image
raise HTTPException(status_code=415, detail="Failed to read image")
return Response(status_code=415)
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
filename = f"{str(int(datetime.now(timezone.utc).timestamp()))}.png"
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
saved_image = ApiDependencies.invoker.services.images.save(
ImageType.UPLOAD, filename, img
return Response(
status_code=201,
headers={
"Location": request.url_for(
"get_image", image_type=ImageType.UPLOAD, image_name=filename
)
},
)
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
image_url = ApiDependencies.invoker.services.images.get_uri(
ImageType.UPLOAD, saved_image.image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
ImageType.UPLOAD, saved_image.image_name, True
)
res = ImageResponse(
image_type=ImageType.UPLOAD,
image_name=saved_image.image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
metadata=ImageResponseMetadata(
created=saved_image.created,
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
response.status_code = 201
response.headers["Location"] = image_url
return res
@images_router.get(
"/",
operation_id="list_images",
responses={200: {"model": PaginatedResults[ImageResponse]}},
)
async def list_images(
image_type: ImageType = Query(
default=ImageType.RESULT, description="The type of images to get"
),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
"""Gets a list of images"""
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
return result
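One detail worth calling out in the two `upload_image` variants above: `PIL.Image.open` expects a path or a file-like object, so the raw upload bytes must be wrapped in `io.BytesIO` first. A minimal standalone sketch of a working upload handler (persistence elided, endpoint shape simplified):

```python
# Sketch: a minimal FastAPI upload endpoint that opens the bytes correctly.
import io

from fastapi import FastAPI, HTTPException, UploadFile
from PIL import Image

app = FastAPI()


@app.post("/uploads/", status_code=201)
async def upload_image(file: UploadFile):
    if not file.content_type.startswith("image"):
        raise HTTPException(status_code=415, detail="Not an image")
    contents = await file.read()
    try:
        img = Image.open(io.BytesIO(contents))  # file-like object, not raw bytes
    except Exception:
        raise HTTPException(status_code=415, detail="Failed to read image")
    return {"width": img.width, "height": img.height}  # saving elided
```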


@@ -1,248 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
import shutil
import asyncio
from typing import Annotated, Any, List, Literal, Optional, Union
from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from pathlib import Path
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
class VaeRepo(BaseModel):
repo_id: str = Field(description="The repo ID to use for this VAE")
path: Optional[str] = Field(description="The path to the VAE")
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model")
class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt'
config: str = Field(description="The path to the model config")
weights: str = Field(description="The path to the model weights")
vae: str = Field(description="The path to the model VAE")
width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['diffusers'] = 'diffusers'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class CreateModelRequest(BaseModel):
name: str = Field(description="The name of the model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
class CreateModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
status: str = Field(description="The status of the API response")
class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model")
info: CkptModelInfo = Field(description="The converted model info")
save_location: str = Field(description="The path to save the converted model weights")
class ConvertedModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel):
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
@models_router.get(
"/",
operation_id="list_models",
responses={200: {"model": ModelsList }},
)
async def list_models() -> ModelsList:
"""Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models()
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
@models_router.post(
"/",
operation_id="update_model",
responses={200: {"status": "success"}},
)
async def update_model(
model_request: CreateModelRequest
) -> CreateModelResponse:
""" Add Model """
model_request_info = model_request.info
info_dict = model_request_info.dict()
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
ApiDependencies.invoker.services.model_manager.add_model(
model_name=model_request.name,
model_attributes=info_dict,
clobber=True,
)
return model_response
@models_router.delete(
"/{model_name}",
operation_id="del_model",
responses={
204: {
"description": "Model deleted successfully"
},
404: {
"description": "Model not found"
}
},
)
async def delete_model(model_name: str) -> None:
"""Delete Model"""
model_names = ApiDependencies.invoker.services.model_manager.model_names()
logger = ApiDependencies.invoker.services.logger
model_exists = model_name in model_names
# check if model exists
logger.info(f"Checking for model {model_name}...")
if model_exists:
logger.info(f"Deleting Model: {model_name}")
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
logger.info(f"Model Deleted: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
else:
logger.error(f"Model not found")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):
# try:
# if model_info := self.generate.model_manager.model_info(
# model_name=model_to_convert["model_name"]
# ):
# if "weights" in model_info:
# ckpt_path = Path(model_info["weights"])
# original_config_file = Path(model_info["config"])
# model_name = model_to_convert["model_name"]
# model_description = model_info["description"]
# else:
# self.socketio.emit(
# "error", {"message": "Model is not a valid checkpoint file"}
# )
# else:
# self.socketio.emit(
# "error", {"message": "Could not retrieve model info."}
# )
# if not ckpt_path.is_absolute():
# ckpt_path = Path(Globals.root, ckpt_path)
# if original_config_file and not original_config_file.is_absolute():
# original_config_file = Path(Globals.root, original_config_file)
# diffusers_path = Path(
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
# )
# if model_to_convert["save_location"] == "root":
# diffusers_path = Path(
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
# )
# if (
# model_to_convert["save_location"] == "custom"
# and model_to_convert["custom_location"] is not None
# ):
# diffusers_path = Path(
# model_to_convert["custom_location"], f"{model_name}_diffusers"
# )
# if diffusers_path.exists():
# shutil.rmtree(diffusers_path)
# self.generate.model_manager.convert_and_import(
# ckpt_path,
# diffusers_path,
# model_name=model_name,
# model_description=model_description,
# vae=None,
# original_config_file=original_config_file,
# commit_to_conf=opt.conf,
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelConverted",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Model Converted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("mergeDiffusersModels")
# def merge_diffusers_models(model_merge_info: dict):
# try:
# models_to_merge = model_merge_info["models_to_merge"]
# model_ids_or_paths = [
# self.generate.model_manager.model_name_or_path(x)
# for x in models_to_merge
# ]
# merged_pipe = merge_diffusion_models(
# model_ids_or_paths,
# model_merge_info["alpha"],
# model_merge_info["interp"],
# model_merge_info["force"],
# )
# dump_path = global_models_dir() / "merged_models"
# if model_merge_info["model_merge_save_path"] is not None:
# dump_path = Path(model_merge_info["model_merge_save_path"])
# os.makedirs(dump_path, exist_ok=True)
# dump_path = dump_path / model_merge_info["merged_model_name"]
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
# merged_model_config = dict(
# model_name=model_merge_info["merged_model_name"],
# description=f'Merge of models {", ".join(models_to_merge)}',
# commit_to_conf=opt.conf,
# )
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
# "vae", None
# ):
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
# merged_model_config.update(vae=vae)
# self.generate.model_manager.import_diffuser_model(
# dump_path, **merged_model_config
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelsMerged",
# {
# "merged_models": models_to_merge,
# "merged_model_name": model_merge_info["merged_model_name"],
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e:
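The `ModelsList` model above leans on pydantic's discriminated unions, with `format` selecting the concrete submodel; a small self-contained sketch of that mechanism (simplified fields, pydantic v1):

```python
# Sketch of the discriminated-union pattern used by ModelsList (pydantic v1).
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, parse_obj_as


class CkptModelInfo(BaseModel):
    format: Literal["ckpt"] = "ckpt"
    weights: str


class DiffusersModelInfo(BaseModel):
    format: Literal["diffusers"] = "diffusers"
    repo_id: str


ModelInfo = Annotated[
    Union[CkptModelInfo, DiffusersModelInfo], Field(discriminator="format")
]


class ModelsList(BaseModel):
    models: dict[str, ModelInfo]


raw = {"models": {"sd-1.5": {"format": "diffusers", "repo_id": "runwayml/stable-diffusion-v1-5"}}}
models = parse_obj_as(ModelsList, raw)
print(type(models.models["sd-1.5"]).__name__)  # DiffusersModelInfo
```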


@@ -2,7 +2,8 @@
from typing import Annotated, List, Optional, Union
from fastapi import Body, HTTPException, Path, Query, Response
from fastapi import Body, Path, Query
from fastapi.responses import Response
from fastapi.routing import APIRouter
from pydantic.fields import Field
@@ -50,7 +51,7 @@ async def list_sessions(
query: str = Query(default="", description="The query string to search for"),
) -> PaginatedResults[GraphExecutionState]:
"""Gets a list of sessions, optionally searching"""
if query == "":
if filter == "":
result = ApiDependencies.invoker.services.graph_execution_manager.list(
page, per_page
)
@@ -75,7 +76,7 @@ async def get_session(
"""Gets a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
raise HTTPException(status_code=404)
return Response(status_code=404)
else:
return session
@@ -98,7 +99,7 @@ async def add_node(
"""Adds a node to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
raise HTTPException(status_code=404)
return Response(status_code=404)
try:
session.add_node(node)
@@ -107,9 +108,9 @@ async def add_node(
) # TODO: can this be done automatically, or add node through an API?
return session.id
except NodeAlreadyExecutedError:
raise HTTPException(status_code=400)
return Response(status_code=400)
except IndexError:
raise HTTPException(status_code=400)
return Response(status_code=400)
@session_router.put(
@@ -131,7 +132,7 @@ async def update_node(
"""Updates a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
raise HTTPException(status_code=404)
return Response(status_code=404)
try:
session.update_node(node_path, node)
@@ -140,9 +141,9 @@ async def update_node(
) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
raise HTTPException(status_code=400)
return Response(status_code=400)
except IndexError:
raise HTTPException(status_code=400)
return Response(status_code=400)
@session_router.delete(
@@ -161,7 +162,7 @@ async def delete_node(
"""Deletes a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
raise HTTPException(status_code=404)
return Response(status_code=404)
try:
session.delete_node(node_path)
@@ -170,9 +171,9 @@ async def delete_node(
) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
raise HTTPException(status_code=400)
return Response(status_code=400)
except IndexError:
raise HTTPException(status_code=400)
return Response(status_code=400)
@session_router.post(
@@ -191,7 +192,7 @@ async def add_edge(
"""Adds an edge to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
raise HTTPException(status_code=404)
return Response(status_code=404)
try:
session.add_edge(edge)
@@ -200,9 +201,9 @@ async def add_edge(
) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
raise HTTPException(status_code=400)
return Response(status_code=400)
except IndexError:
raise HTTPException(status_code=400)
return Response(status_code=400)
# TODO: the edge being in the path here is really ugly, find a better solution
@@ -225,7 +226,7 @@ async def delete_edge(
"""Deletes an edge from the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
raise HTTPException(status_code=404)
return Response(status_code=404)
try:
edge = Edge(
@@ -238,9 +239,9 @@ async def delete_edge(
) # TODO: can this be done automatically, or add node through an API?
return session
except NodeAlreadyExecutedError:
raise HTTPException(status_code=400)
return Response(status_code=400)
except IndexError:
raise HTTPException(status_code=400)
return Response(status_code=400)
@session_router.put(
@@ -258,14 +259,14 @@ async def invoke_session(
all: bool = Query(
default=False, description="Whether or not to invoke all remaining invocations"
),
) -> Response:
) -> None:
"""Invokes a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
if session is None:
raise HTTPException(status_code=404)
return Response(status_code=404)
if session.is_complete():
raise HTTPException(status_code=400)
return Response(status_code=400)
ApiDependencies.invoker.invoke(session, invoke_all=all)
return Response(status_code=202)
@@ -280,7 +281,7 @@ async def invoke_session(
)
async def cancel_session_invoke(
session_id: str = Path(description="The id of the session to cancel"),
) -> Response:
) -> None:
"""Invokes a session"""
ApiDependencies.invoker.cancel(session_id)
return Response(status_code=202)
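A note on the pattern being changed throughout this file: raising `HTTPException` and returning a bare `Response` both produce the status code, but they behave differently, as this minimal sketch shows.

```python
# Sketch: two ways to signal an error status in FastAPI. Raising HTTPException
# short-circuits the handler and returns a JSON body like {"detail": "Not Found"};
# returning Response(status_code=404) sends an empty body and bypasses any
# declared response_model.
from fastapi import FastAPI, HTTPException, Response

app = FastAPI()


@app.get("/raise/{item_id}")
async def via_exception(item_id: str):
    raise HTTPException(status_code=404)


@app.get("/return/{item_id}")
async def via_response(item_id: str):
    return Response(status_code=404)
```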


@@ -3,7 +3,6 @@ import asyncio
from inspect import signature
import uvicorn
import invokeai.backend.util.logging as logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
@@ -15,8 +14,9 @@ from pydantic.schema import schema
from ..backend import Args
from .api.dependencies import ApiDependencies
from .api.routers import images, sessions, models
from .api.routers import images, sessions
from .api.sockets import SocketIO
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
# Create the app
@@ -56,7 +56,7 @@ async def startup_event():
config.parse_args()
ApiDependencies.initialize(
config=config, event_handler_id=event_handler_id, logger=logger
config=config, event_handler_id=event_handler_id
)
@@ -76,8 +76,6 @@ app.include_router(sessions.session_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?


@@ -2,46 +2,14 @@
from abc import ABC, abstractmethod
import argparse
from typing import Any, Callable, Iterable, Literal, Union, get_args, get_origin, get_type_hints
from typing import Any, Callable, Iterable, Literal, get_args, get_origin, get_type_hints
from pydantic import BaseModel, Field
import networkx as nx
import matplotlib.pyplot as plt
import invokeai.backend.util.logging as logger
from ..invocations.baseinvocation import BaseInvocation
from ..invocations.image import ImageField
from ..services.graph import GraphExecutionState, LibraryGraph, Edge
from ..services.graph import GraphExecutionState
from ..services.invoker import Invoker
def add_field_argument(command_parser, name: str, field, default_override = None):
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
command_parser.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=default,
choices=allowed_values,
help=field.field_info.description,
)
else:
command_parser.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=default,
help=field.field_info.description,
)
def add_parsers(
subparsers,
commands: list[type],
@@ -66,26 +34,30 @@ def add_parsers(
if name in exclude_fields:
continue
add_field_argument(command_parser, name, field)
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
for val in allowed_values:
allowed_types.add(type(val))
allowed_types_list = list(allowed_types)
field_type = allowed_types_list[0] if len(allowed_types) == 1 else Union[allowed_types_list] # type: ignore
def add_graph_parsers(
subparsers,
graphs: list[LibraryGraph],
add_arguments: Callable[[argparse.ArgumentParser], None]|None = None
):
for graph in graphs:
command_parser = subparsers.add_parser(graph.name, help=graph.description)
if add_arguments is not None:
add_arguments(command_parser)
# Add arguments for inputs
for exposed_input in graph.exposed_inputs:
node = graph.graph.get_node(exposed_input.node_path)
field = node.__fields__[exposed_input.field]
default_override = getattr(node, exposed_input.field)
add_field_argument(command_parser, exposed_input.alias, field, default_override)
command_parser.add_argument(
f"--{name}",
dest=name,
type=field_type,
default=field.default,
choices=allowed_values,
help=field.field_info.description,
)
else:
command_parser.add_argument(
f"--{name}",
dest=name,
type=field.type_,
default=field.default,
help=field.field_info.description,
)
class CliContext:
@@ -93,38 +65,17 @@ class CliContext:
session: GraphExecutionState
parser: argparse.ArgumentParser
defaults: dict[str, Any]
graph_nodes: dict[str, str]
nodes_added: list[str]
def __init__(self, invoker: Invoker, session: GraphExecutionState, parser: argparse.ArgumentParser):
self.invoker = invoker
self.session = session
self.parser = parser
self.defaults = dict()
self.graph_nodes = dict()
self.nodes_added = list()
def get_session(self):
self.session = self.invoker.services.graph_execution_manager.get(self.session.id)
return self.session
def reset(self):
self.session = self.invoker.create_execution_state()
self.graph_nodes = dict()
self.nodes_added = list()
# Leave defaults unchanged
def add_node(self, node: BaseInvocation):
self.get_session()
self.session.graph.add_node(node)
self.nodes_added.append(node.id)
self.invoker.services.graph_execution_manager.set(self.session)
def add_edge(self, edge: Edge):
self.get_session()
self.session.add_edge(edge)
self.invoker.services.graph_execution_manager.set(self.session)
class ExitCli(Exception):
"""Exception to exit the CLI"""
@@ -230,7 +181,7 @@ class HistoryCommand(BaseCommand):
for i in range(min(self.count, len(history))):
entry_id = history[-1 - i]
entry = context.get_session().graph.get_node(entry_id)
logger.info(f"{entry_id}: {get_invocation_command(entry)}")
print(f"{entry_id}: {get_invocation_command(entry)}")
class SetDefaultCommand(BaseCommand):
@@ -249,39 +200,3 @@ class SetDefaultCommand(BaseCommand):
del context.defaults[self.field]
else:
context.defaults[self.field] = self.value
class DrawGraphCommand(BaseCommand):
"""Debugs a graph"""
type: Literal['draw_graph'] = 'draw_graph'
def run(self, context: CliContext) -> None:
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
nxgraph = session.graph.nx_graph_flat()
# Draw the networkx graph
plt.figure(figsize=(20, 20))
pos = nx.spectral_layout(nxgraph)
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
nx.draw_networkx_edges(nxgraph, pos, width=2)
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
plt.axis("off")
plt.show()
class DrawExecutionGraphCommand(BaseCommand):
"""Debugs an execution graph"""
type: Literal['draw_xgraph'] = 'draw_xgraph'
def run(self, context: CliContext) -> None:
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
nxgraph = session.execution_graph.nx_graph_flat()
# Draw the networkx graph
plt.figure(figsize=(20, 20))
pos = nx.spectral_layout(nxgraph)
nx.draw_networkx_nodes(nxgraph, pos, node_size=1000)
nx.draw_networkx_edges(nxgraph, pos, width=2)
nx.draw_networkx_labels(nxgraph, pos, font_size=20, font_family="sans-serif")
plt.axis("off")
plt.show()


@@ -1,168 +0,0 @@
"""
Readline helper functions for cli_app.py
You may import the global singleton `completer` to get access to the
completer object.
"""
import atexit
import readline
import shlex
from pathlib import Path
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin
import invokeai.backend.util.logging as logger
from ...backend import ModelManager, Globals
from ..invocations.baseinvocation import BaseInvocation
from .commands import BaseCommand
# singleton object, class variable
completer = None
class Completer(object):
def __init__(self, model_manager: ModelManager):
self.commands = self.get_commands()
self.matches = None
self.linebuffer = None
self.manager = model_manager
return
def complete(self, text, state):
"""
Complete commands and switches from the node CLI command line.
Switches are determined in a context-specific manner.
"""
buffer = readline.get_line_buffer()
if state == 0:
options = None
try:
current_command, current_switch = self.get_current_command(buffer)
options = self.get_command_options(current_command, current_switch)
except IndexError:
pass
options = options or list(self.parse_commands().keys())
if not text: # first time
self.matches = options
else:
self.matches = [s for s in options if s and s.startswith(text)]
try:
match = self.matches[state]
except IndexError:
match = None
return match
@classmethod
def get_commands(cls)->List[object]:
"""
Return a list of all the client commands and invocations.
"""
return BaseCommand.get_commands() + BaseInvocation.get_invocations()
def get_current_command(self, buffer: str)->tuple[str, str]:
"""
Parse the readline buffer to find the most recent command and its switch.
"""
if len(buffer)==0:
return None, None
tokens = shlex.split(buffer)
command = None
switch = None
for t in tokens:
if t[0].isalpha():
if switch is None:
command = t
else:
switch = t
# don't try to autocomplete switches that are already complete
if switch and buffer.endswith(' '):
switch=None
return command or '', switch or ''
def parse_commands(self)->Dict[str, List[str]]:
"""
Return a dict in which the keys are the command name
and the values are the parameters the command takes.
"""
result = dict()
for command in self.commands:
hints = get_type_hints(command)
name = get_args(hints['type'])[0]
result.update({name:hints})
return result
def get_command_options(self, command: str, switch: str)->List[str]:
"""
Return all the parameters that can be passed to the command as
command-line switches. Returns None if the command is unrecognized.
"""
parsed_commands = self.parse_commands()
if command not in parsed_commands:
return None
# handle switches in the format "-foo=bar"
argument = None
if switch and '=' in switch:
switch, argument = switch.split('=', 1)
parameter = switch.strip('-')
if parameter in parsed_commands[command]:
if argument is None:
return self.get_parameter_options(parameter, parsed_commands[command][parameter])
else:
return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])]
else:
return [f"--{x}" for x in parsed_commands[command].keys()]
def get_parameter_options(self, parameter: str, typehint)->List[str]:
"""
Given a parameter type (such as Literal), offers autocompletions.
"""
if get_origin(typehint) == Literal:
return get_args(typehint)
if parameter == 'model':
return self.manager.model_names()
def _pre_input_hook(self):
if self.linebuffer:
readline.insert_text(self.linebuffer)
readline.redisplay()
self.linebuffer = None
def set_autocompleter(model_manager: ModelManager) -> Completer:
global completer
if completer:
return completer
completer = Completer(model_manager)
readline.set_completer(completer.complete)
# pyreadline3 does not have a set_auto_history() method
try:
readline.set_auto_history(True)
except AttributeError:
pass
readline.set_pre_input_hook(completer._pre_input_hook)
readline.set_completer_delims(" ")
readline.parse_and_bind("tab: complete")
readline.parse_and_bind("set print-completions-horizontally off")
readline.parse_and_bind("set page-completions on")
readline.parse_and_bind("set skip-completed-text on")
readline.parse_and_bind("set show-all-if-ambiguous on")
histfile = Path(Globals.root, ".invoke_history")
try:
readline.read_history_file(histfile)
readline.set_history_length(1000)
except FileNotFoundError:
pass
except OSError: # file likely corrupted
newname = f"{histfile}.old"
logger.error(
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
)
histfile.replace(Path(newname))
atexit.register(readline.write_history_file, histfile)
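The same readline pattern can be reproduced outside InvokeAI in a few lines; this sketch completes against a static word list instead of a ModelManager:

```python
import readline

WORDS = ["txt2img", "img2img", "history", "help"]

def complete(text: str, state: int):
    matches = [w for w in WORDS if w.startswith(text)]
    return matches[state] if state < len(matches) else None

readline.set_completer(complete)
readline.set_completer_delims(" ")
readline.parse_and_bind("tab: complete")
# input() calls will now tab-complete against WORDS
```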

View File

@@ -2,7 +2,6 @@
import argparse
import os
import re
import shlex
import time
from typing import (
@@ -13,21 +12,14 @@ from typing import (
from pydantic import BaseModel
from pydantic.fields import Field
import invokeai.backend.util.logging as logger
from invokeai.app.services.metadata import PngMetadataService
from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..backend import Args
from .cli.commands import BaseCommand, CliContext, ExitCli, add_graph_parsers, add_parsers
from .cli.completer import set_autocompleter
from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_graph_execution_history
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.default_graphs import default_text_to_image_graph_id
from .services.graph import Edge, EdgeConnection, GraphExecutionState
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
@@ -51,7 +43,7 @@ def add_invocation_args(command_parser):
"-l",
action="append",
nargs=3,
help="A link in the format 'source_node source_field dest_field'. source_node can be relative to history (e.g. -1)",
help="A link in the format 'dest_field source_node source_field'. source_node can be relative to history (e.g. -1)",
)
command_parser.add_argument(
@@ -62,7 +54,7 @@ def add_invocation_args(command_parser):
)
def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
def get_command_parser() -> argparse.ArgumentParser:
# Create invocation parser
parser = argparse.ArgumentParser()
@@ -80,72 +72,20 @@ def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
commands = BaseCommand.get_all_subclasses()
add_parsers(subparsers, commands, exclude_fields=["type"])
# Create subparsers for exposed CLI graphs
# TODO: add a way to identify these graphs
text_to_image = services.graph_library.get(default_text_to_image_graph_id)
add_graph_parsers(subparsers, [text_to_image], add_arguments=add_invocation_args)
return parser
class NodeField():
alias: str
node_path: str
field: str
field_type: type
def __init__(self, alias: str, node_path: str, field: str, field_type: type):
self.alias = alias
self.node_path = node_path
self.field = field
self.field_type = field_type
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
"""Gets the node field for the specified field alias"""
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_input.node_path))
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
"""Gets the node field for the specified field alias"""
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_output.node_path))
node_output_type = node_type.get_output_type()
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
"""Gets the inputs for the specified invocation from the context"""
node_type = type(invocation)
if node_type is not GraphInvocation:
return fields_from_type_hints(get_type_hints(node_type), invocation.id)
else:
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
return {e.alias: get_node_input_field(graph, e.alias, invocation.id) for e in graph.exposed_inputs}
def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
"""Gets the outputs for the specified invocation from the context"""
node_type = type(invocation)
if node_type is not GraphInvocation:
return fields_from_type_hints(get_type_hints(node_type.get_output_type()), invocation.id)
else:
graph: LibraryGraph = context.invoker.services.graph_library.get(context.graph_nodes[invocation.id])
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
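For a plain (non-graph) invocation these helpers reduce to `get_type_hints` over the node class; a toy illustration with a stand-in pydantic model:

```python
from typing import get_type_hints
from pydantic import BaseModel, Field

class ToyInvocation(BaseModel):  # stand-in for a BaseInvocation subclass
    id: str = Field(default="0")
    prompt: str = Field(default="")
    steps: int = Field(default=10)

# fields_from_type_hints builds one NodeField per entry of this mapping:
print(get_type_hints(ToyInvocation))  # {'id': str, 'prompt': str, 'steps': int}
```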
def generate_matching_edges(
a: BaseInvocation, b: BaseInvocation, context: CliContext
a: BaseInvocation, b: BaseInvocation
) -> list[Edge]:
"""Generates all possible edges between two invocations"""
afields = get_node_outputs(a, context)
bfields = get_node_inputs(b, context)
atype = type(a)
btype = type(b)
aoutputtype = atype.get_output_type()
afields = get_type_hints(aoutputtype)
bfields = get_type_hints(btype)
matching_fields = set(afields.keys()).intersection(bfields.keys())
@@ -153,15 +93,12 @@ def generate_matching_edges(
invalid_fields = set(["type", "id"])
matching_fields = matching_fields.difference(invalid_fields)
# Validate types
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
edges = [
Edge(
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
source=EdgeConnection(node_id=a.id, field=field),
destination=EdgeConnection(node_id=b.id, field=field)
)
for alias in matching_fields
for field in matching_fields
]
return edges
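Edge generation is a set intersection over field names (minus `type` and `id`), filtered by connection-type compatibility. A reduced sketch, with plain type equality standing in for `are_connection_types_compatible`:

```python
def matching_field_names(out_hints: dict, in_hints: dict) -> list[str]:
    names = (set(out_hints) & set(in_hints)) - {"type", "id"}
    # Simplified compatibility check; the real code is more permissive.
    return [f for f in names if out_hints[f] == in_hints[f]]

out_hints = {"type": str, "id": str, "image": "ImageField", "width": int}
in_hints = {"type": str, "id": str, "image": "ImageField", "strength": float}
print(matching_field_names(out_hints, in_hints))  # ['image']
```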
@@ -181,7 +118,7 @@ def invoke_all(context: CliContext):
# Print any errors
if context.session.has_error():
for n in context.session.errors:
context.invoker.services.logger.error(
print(
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
)
@@ -191,18 +128,10 @@ def invoke_all(context: CliContext):
def invoke_cli():
config = Args()
config.parse_args()
model_manager = get_model_manager(config,logger=logger)
# This initializes the autocompleter and returns it.
# Currently nothing is done with the returned Completer
# object, but the object can be used to change autocompletion
# behavior on the fly, if desired.
set_autocompleter(model_manager)
model_manager = get_model_manager(config)
events = EventServiceBase()
metadata = PngMetadataService()
output_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../../outputs")
)
@@ -213,29 +142,18 @@ def invoke_cli():
services = InvocationServices(
model_manager=model_manager,
events=events,
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
metadata=metadata,
images=DiskImageStorage(output_folder),
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
),
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger=logger),
logger=logger,
restoration=RestorationServices(config),
)
system_graphs = create_system_graphs(services.graph_library)
system_graph_names = set([g.name for g in system_graphs])
invoker = Invoker(services)
session: GraphExecutionState = invoker.create_execution_state()
parser = get_command_parser(services)
re_negid = re.compile('^-[0-9]+$')
parser = get_command_parser()
# Uncomment to print out previous sessions at startup
# print(services.session_manager.list())
@@ -244,19 +162,18 @@ def invoke_cli():
while True:
try:
cmd_input = input("invoke> ")
except (KeyboardInterrupt, EOFError):
cmd_input = input("> ")
except KeyboardInterrupt:
# Ctrl-c exits
break
try:
# Refresh the state of the session
#history = list(get_graph_execution_history(context.session))
history = list(reversed(context.nodes_added))
history = list(get_graph_execution_history(context.session))
# Split the command for piping
cmds = cmd_input.split("|")
start_id = len(context.nodes_added)
start_id = len(history)
current_id = start_id
new_invocations = list()
for cmd in cmds:
@@ -272,24 +189,8 @@ def invoke_cli():
args[field_name] = field_default
# Parse invocation
command: CliCommand = None # type:ignore
system_graph: LibraryGraph|None = None
if args['type'] in system_graph_names:
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
for exposed_input in system_graph.exposed_inputs:
if exposed_input.alias in args:
node = invocation.graph.get_node(exposed_input.node_path)
field = exposed_input.field
setattr(node, field, args[exposed_input.alias])
command = CliCommand(command = invocation)
context.graph_nodes[invocation.id] = system_graph.id
else:
args["id"] = current_id
command = CliCommand(command=args)
if command is None:
continue
args["id"] = current_id
command = CliCommand(command=args)
# Run any CLI commands immediately
if isinstance(command.command, BaseCommand):
@@ -300,7 +201,6 @@ def invoke_cli():
command.command.run(context)
continue
# TODO: handle linking with library graphs
# Pipe previous command output (if there was a previous command)
edges: list[Edge] = list()
if len(history) > 0 or current_id != start_id:
@@ -313,20 +213,16 @@ def invoke_cli():
else context.session.graph.get_node(from_id)
)
matching_edges = generate_matching_edges(
from_node, command.command, context
from_node, command.command
)
edges.extend(matching_edges)
# Parse provided links
if "link_node" in args and args["link_node"]:
for link in args["link_node"]:
node_id = link
if re_negid.match(node_id):
node_id = str(current_id + int(node_id))
link_node = context.session.graph.get_node(node_id)
link_node = context.session.graph.get_node(link)
matching_edges = generate_matching_edges(
link_node, command.command, context
link_node, command.command
)
matching_destinations = [e.destination for e in matching_edges]
edges = [e for e in edges if e.destination not in matching_destinations]
@@ -334,20 +230,13 @@ def invoke_cli():
if "link" in args and args["link"]:
for link in args["link"]:
edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
node_id = link[0]
if re_negid.match(node_id):
node_id = str(current_id + int(node_id))
# TODO: handle missing input/output
node_output = get_node_outputs(context.session.graph.get_node(node_id), context)[link[1]]
node_input = get_node_inputs(command.command, context)[link[2]]
edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
edges.append(
Edge(
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
source=EdgeConnection(node_id=link[1], field=link[0]),
destination=EdgeConnection(
node_id=command.command.id, field=link[2]
)
)
)
@@ -356,22 +245,22 @@ def invoke_cli():
current_id = current_id + 1
# Add the node to the session
context.add_node(command.command)
context.session.add_node(command.command)
for edge in edges:
print(edge)
context.add_edge(edge)
context.session.add_edge(edge)
# Execute all remaining nodes
invoke_all(context)
except InvalidArgs:
invoker.services.logger.warning('Invalid command, use "help" to list commands')
print('Invalid command, use "help" to list commands')
continue
except SessionError:
# Start a new session
invoker.services.logger.warning("Session error: creating a new session")
context.reset()
print("Session error: creating a new session")
context.session = context.invoker.create_execution_state()
except ExitCli:
break
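The `--link`/`--link_node` handling above resolves negative node ids relative to the id about to be assigned, so `-1` always refers to the previously added node. A minimal sketch of that resolution, using the same `re_negid` pattern:

```python
import re

re_negid = re.compile('^-[0-9]+$')

def resolve_node_id(node_id: str, current_id: int) -> str:
    """Map relative ids like '-1' onto absolute node ids."""
    if re_negid.match(node_id):
        return str(current_id + int(node_id))
    return node_id

assert resolve_node_id("-1", 5) == "4"  # the most recently added node
assert resolve_node_id("3", 5) == "3"   # absolute ids pass through unchanged
```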

View File

@@ -2,7 +2,7 @@
from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
from typing import get_args, get_type_hints
from pydantic import BaseModel, Field
@@ -76,56 +76,3 @@ class BaseInvocation(ABC, BaseModel):
#fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.")
#fmt: on
# TODO: figure out a better way to provide these hints
# TODO: when we can upgrade to python 3.11, we can use the `NotRequired` type instead of `total=False`
class UIConfig(TypedDict, total=False):
type_hints: Dict[
str,
Literal[
"integer",
"float",
"boolean",
"string",
"enum",
"image",
"latents",
"model",
],
]
tags: List[str]
title: str
class CustomisedSchemaExtra(TypedDict):
ui: UIConfig
class InvocationConfig(BaseModel.Config):
"""Customizes pydantic's BaseModel.Config class for use by Invocations.
Provide `schema_extra` a `ui` dict to add hints for generated UIs.
`tags`
- A list of strings, used to categorise invocations.
`type_hints`
- A dict of field types which override the types in the invocation definition.
- Each key should be the name of one of the invocation's fields.
- Each value should be one of the valid types:
- `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model`
```python
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["stable-diffusion", "image"],
"type_hints": {
"initial_image": "image",
},
},
}
```
"""
schema_extra: CustomisedSchemaExtra

View File

@@ -1,64 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal, Optional
import numpy as np
import numpy.random
from pydantic import Field
from .baseinvocation import (
BaseInvocation,
InvocationConfig,
InvocationContext,
BaseInvocationOutput,
)
class IntCollectionOutput(BaseInvocationOutput):
"""A collection of integers"""
type: Literal["int_collection"] = "int_collection"
# Outputs
collection: list[int] = Field(default=[], description="The int collection")
class RangeInvocation(BaseInvocation):
"""Creates a range"""
type: Literal["range"] = "range"
# Inputs
start: int = Field(default=0, description="The start of the range")
stop: int = Field(default=10, description="The stop of the range")
step: int = Field(default=1, description="The step of the range")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(
collection=list(range(self.start, self.stop, self.step))
)
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""
type: Literal["random_range"] = "random_range"
# Inputs
low: int = Field(default=0, description="The inclusive low value")
high: int = Field(
default=np.iinfo(np.int32).max, description="The exclusive high value"
)
size: int = Field(default=1, description="The number of values to generate")
seed: Optional[int] = Field(
ge=0,
le=np.iinfo(np.int32).max,
description="The seed for the RNG",
default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
)
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
rng = np.random.default_rng(self.seed)
return IntCollectionOutput(
collection=list(rng.integers(low=self.low, high=self.high, size=self.size))
)
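Both invocations are pure functions of their fields and ignore the passed context, so they can be exercised directly; passing None for the context is an illustration-only shortcut:

```python
out = RangeInvocation(id="1", start=0, stop=10, step=2).invoke(None)
print(out.collection)  # [0, 2, 4, 6, 8]

rnd = RandomRangeInvocation(id="2", low=0, high=100, size=3, seed=42).invoke(None)
print(rnd.collection)  # three values, reproducible for a fixed seed
```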

View File

@@ -1,245 +0,0 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
from compel import Compel
from compel.prompt_parser import (
Blend,
CrossAttentionControlSubstitute,
FlattenedPrompt,
Fragment,
)
from invokeai.backend.globals import Globals
class ConditioningField(BaseModel):
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
class Config:
schema_extra = {"required": ["conditioning_name"]}
class CompelOutput(BaseInvocationOutput):
"""Compel parser output"""
#fmt: off
type: Literal["compel_output"] = "compel_output"
conditioning: ConditioningField = Field(default=None, description="Conditioning")
#fmt: on
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
type: Literal["compel"] = "compel"
prompt: str = Field(default="", description="Prompt")
model: str = Field(default="", description="Model to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Prompt (Compel)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
}
def invoke(self, context: InvocationContext) -> CompelOutput:
# TODO: load without model
model = choose_model(context.services.model_manager, self.model)
pipeline = model["model"]
tokenizer = pipeline.tokenizer
text_encoder = pipeline.text_encoder
# TODO: global? input?
#use_full_precision = precision == "float32" or precision == "autocast"
#use_full_precision = False
# TODO: redo TI when separate model loading is implemented
#textual_inversion_manager = TextualInversionManager(
# tokenizer=tokenizer,
# text_encoder=text_encoder,
# full_precision=use_full_precision,
#)
def load_huggingface_concepts(concepts: list[str]):
pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)
# apply the concepts library to the prompt
prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
self.prompt,
lambda concepts: load_huggingface_concepts(concepts),
pipeline.textual_inversion_manager.get_all_trigger_strings(),
)
# lazy-load any deferred textual inversions.
# this might take a couple of seconds the first time a textual inversion is used.
pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
prompt_str
)
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=pipeline.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
)
# TODO: support legacy blend?
prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)
if getattr(Globals, "log_tokenization", False):
log_tokenization_for_prompt_object(prompt, tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
# TODO: long prompt support
#if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
cross_attention_control_args=options.get("cross_attention_control", None),
)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.set(conditioning_name, (c, ec))
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
) -> int:
if type(prompt) is Blend:
blend: Blend = prompt
return max(
[
get_max_token_count(tokenizer, c, truncate_if_too_long)
for c in blend.prompts
]
)
else:
return len(
get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
)
def get_tokens_for_prompt_object(
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> list[str]:
if type(parsed_prompt) is Blend:
raise ValueError(
"Blend is not supported here - you need to get tokens for each of its .children"
)
text_fragments = [
x.text
if type(x) is Fragment
else (
" ".join([f.text for f in x.original])
if type(x) is CrossAttentionControlSubstitute
else str(x)
)
for x in parsed_prompt.children
]
text = " ".join(text_fragments)
tokens = tokenizer.tokenize(text)
if truncate_if_too_long:
max_tokens_length = tokenizer.model_max_length - 2 # typically 75
tokens = tokens[0:max_tokens_length]
return tokens
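The token budget is the tokenizer's model_max_length (77 for CLIP) minus the two slots reserved for BOS/EOS, i.e. the familiar 75. A sketch of the truncation arithmetic with a stand-in whitespace tokenizer:

```python
class ToyTokenizer:  # stand-in for a CLIP tokenizer
    model_max_length = 77

    def tokenize(self, text: str) -> list[str]:
        return text.split()

tokenizer = ToyTokenizer()
tokens = tokenizer.tokenize("a " * 100)             # 100 tokens, over budget
max_tokens_length = tokenizer.model_max_length - 2  # 75, leaving room for BOS/EOS
print(len(tokens[0:max_tokens_length]))             # 75
```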
def log_tokenization_for_prompt_object(
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):
display_label_prefix = display_label_prefix or ""
if type(p) is Blend:
blend: Blend = p
for i, c in enumerate(blend.prompts):
log_tokenization_for_prompt_object(
c,
tokenizer,
display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
)
elif type(p) is FlattenedPrompt:
flattened_prompt: FlattenedPrompt = p
if flattened_prompt.wants_cross_attention_control:
original_fragments = []
edited_fragments = []
for f in flattened_prompt.children:
if type(f) is CrossAttentionControlSubstitute:
original_fragments += f.original
edited_fragments += f.edited
else:
original_fragments.append(f)
edited_fragments.append(f)
original_text = " ".join([x.text for x in original_fragments])
log_tokenization_for_text(
original_text,
tokenizer,
display_label=f"{display_label_prefix}(.swap originals)",
)
edited_text = " ".join([x.text for x in edited_fragments])
log_tokenization_for_text(
edited_text,
tokenizer,
display_label=f"{display_label_prefix}(.swap replacements)",
)
else:
text = " ".join([x.text for x in flattened_prompt.children])
log_tokenization_for_text(
text, tokenizer, display_label=display_label_prefix
)
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
"""shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
"""
tokens = tokenizer.tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace("</w>", " ")
# alternate color
s = (usedTokens % 6) + 1
if truncate_if_too_long and i >= tokenizer.model_max_length:
discarded = discarded + f"\x1b[0;3{s};40m{token}"
else:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
if usedTokens > 0:
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
print(f"{tokenized}\x1b[0m")
if discarded != "":
print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
print(f"{discarded}\x1b[0m")

View File

@@ -5,26 +5,14 @@ from typing import Literal
import cv2 as cv
import numpy
from PIL import Image, ImageOps
from pydantic import BaseModel, Field
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
class CvInvocationConfig(BaseModel):
"""Helper class to provide all OpenCV invocations with additional config"""
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["cv", "image"],
},
}
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
class CvInpaintInvocation(BaseInvocation):
"""Simple inpaint using opencv."""
#fmt: off
type: Literal["cv_inpaint"] = "cv_inpaint"
@@ -56,14 +44,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
context.services.images.save(image_type, image_name, image_inpainted)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
context.services.images.save(image_type, image_name, image_inpainted, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image_inpainted,
)

View File

@@ -1,41 +1,29 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from functools import partial
from typing import Literal, Optional, Union
from datetime import datetime, timezone
from typing import Any, Literal, Optional, Union
import numpy as np
from torch import Tensor
from PIL import Image
from pydantic import Field
from skimage.exposure.histogram_matching import match_histograms
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator, Generator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.step_callback import stable_diffusion_step_callback
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
class SDImageInvocation(BaseModel):
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["stable-diffusion", "image"],
"type_hints": {
"model": "model",
},
},
}
from ...backend.util.util import image_to_dataURL
SAMPLER_NAME_VALUES = Literal[
tuple(InvokeAIGenerator.schedulers())
]
# Text to image
class TextToImageInvocation(BaseInvocation, SDImageInvocation):
class TextToImageInvocation(BaseInvocation):
"""Generates an image using text2img."""
type: Literal["txt2img"] = "txt2img"
@@ -46,10 +34,10 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
prompt: Optional[str] = Field(description="The prompt to generate an image from")
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
width: int = Field(default=512, multiple_of=64, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=64, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
sampler_name: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The sampler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
@@ -57,31 +45,41 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
self, context: InvocationContext, sample: Tensor, step: int
) -> None:
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
self.id,
{
"width": width,
"height": height,
"dataURL": dataURL
},
step,
self.steps,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state.latents, state.step)
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
# (right now uses whatever current model is set in model manager)
model= context.services.model_manager.get_model()
outputs = Txt2Img(model).generate(
prompt=self.prompt,
step_callback=partial(self.dispatch_progress, context, source_node_id),
step_callback=step_callback,
**self.dict(
exclude={"prompt"}
), # Shorthand for passing all of the parameters above manually
@@ -97,18 +95,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(
image_type, image_name, generate_output.image, metadata
)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=generate_output.image,
context.services.images.save(image_type, image_name, generate_output.image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
@@ -127,19 +116,6 @@ class ImageToImageInvocation(TextToImageInvocation):
description="Whether or not the result should be fit to the aspect ratio of the input image",
)
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
None
@@ -150,31 +126,24 @@ class ImageToImageInvocation(TextToImageInvocation):
)
mask = None
if self.fit:
image = image.resize((self.width, self.height))
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
model = context.services.model_manager.get_model()
generator_output = next(
Img2Img(model).generate(
prompt=self.prompt,
init_image=image,
init_mask=mask,
step_callback=step_callback,
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Img2Img(model).generate(
prompt=self.prompt,
init_image=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generator_output = next(outputs)
result_image = generator_output.image
@@ -185,19 +154,11 @@ class ImageToImageInvocation(TextToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
context.services.images.save(image_type, image_name, result_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image,
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
@@ -212,19 +173,6 @@ class InpaintInvocation(ImageToImageInvocation):
description="The amount by which to replace masked areas with latent noise",
)
def dispatch_progress(
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
None
@@ -239,28 +187,24 @@ class InpaintInvocation(ImageToImageInvocation):
else context.services.images.get(self.mask.image_type, self.mask.image_name)
)
def step_callback(sample, step=0):
self.dispatch_progress(context, sample, step)
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
model = context.services.model_manager.get_model()
generator_output = next(
Inpaint(model).generate(
prompt=self.prompt,
init_image=image,
mask_image=mask,
step_callback=step_callback,
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Inpaint(model).generate(
prompt=self.prompt,
init_image=image,
mask_image=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generator_output = next(outputs)
result_image = generator_output.image
@@ -271,14 +215,7 @@ class InpaintInvocation(ImageToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, result_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image,
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
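The `generator_output = next(outputs)` pattern used throughout this file works because generate() returns an endless iterator of outputs; next() pulls exactly one. A minimal stand-in:

```python
import itertools

def generate():  # stand-in for Img2Img(model).generate(...)
    for i in itertools.count():
        yield f"output-{i}"  # the real iterator yields InvokeAIGeneratorOutput objects

outputs = generate()
print(next(outputs))  # 'output-0' -- only the first result is consumed
```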

View File

@@ -1,97 +1,54 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Literal, Optional
import numpy
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
class PILInvocationConfig(BaseModel):
"""Helper class to provide all PIL invocations with additional config"""
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["PIL", "image"],
},
}
image_type: str = Field(
default=ImageType.RESULT, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image")
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
#fmt: off
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
width: Optional[int] = Field(default=None, description="The width of the image in pixels")
height: Optional[int] = Field(default=None, description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {
"required": ["type", "image", "width", "height", "mode"]
}
def build_image_output(
image_type: ImageType, image_name: str, image: Image.Image
) -> ImageOutput:
"""Builds an ImageOutput and its ImageField"""
image_field = ImageField(
image_name=image_name,
image_type=image_type,
)
return ImageOutput(
image=image_field,
width=image.width,
height=image.height,
mode=image.mode,
)
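A usage sketch for the helper; the ImageType.RESULT member and the image name are illustrative, and only PIL is assumed:

```python
from PIL import Image

image = Image.new("RGB", (512, 512))
output = build_image_output(
    image_type=ImageType.RESULT,        # assumed enum member
    image_name="session123_node0.png",  # illustrative name
    image=image,
)
print(output.width, output.height)  # 512 512
```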
#fmt: on
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
#fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
# fmt: on
class Config:
schema_extra = {
"required": [
"type",
"mask",
]
}
#fmt: on
# TODO: this isn't really necessary anymore
class LoadImageInvocation(BaseInvocation):
"""Load an image and provide it as output."""
# fmt: off
"""Load an image from a filename and provide it as output."""
#fmt: off
type: Literal["load_image"] = "load_image"
# Inputs
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image_type, self.image_name)
#fmt: on
return build_image_output(
image_type=self.image_type,
image_name=self.image_name,
image=image,
def invoke(self, context: InvocationContext) -> ImageOutput:
return ImageOutput(
image=ImageField(image_type=self.image_type, image_name=self.image_name)
)
@@ -112,17 +69,16 @@ class ShowImageInvocation(BaseInvocation):
# TODO: how to handle failure?
return build_image_output(
image_type=self.image.image_type,
image_name=self.image.image_name,
image=image,
return ImageOutput(
image=ImageField(
image_type=self.image.image_type, image_name=self.image.image_name
)
)
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
class CropImageInvocation(BaseInvocation):
"""Crops an image to a specified box. The box can be outside of the image."""
# fmt: off
#fmt: off
type: Literal["crop"] = "crop"
# Inputs
@@ -131,7 +87,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@@ -147,23 +103,15 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_crop, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image_crop,
context.services.images.save(image_type, image_name, image_crop)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
class PasteImageInvocation(BaseInvocation):
"""Pastes an image into another image."""
# fmt: off
#fmt: off
type: Literal["paste"] = "paste"
# Inputs
@@ -172,7 +120,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get(
@@ -185,7 +133,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
None
if self.mask is None
else ImageOps.invert(
context.services.images.get(self.mask.image_type, self.mask.image_name)
services.images.get(self.mask.image_type, self.mask.image_name)
)
)
# TODO: probably shouldn't invert mask here... should user be required to do it?
@@ -205,29 +153,21 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, new_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=new_image,
context.services.images.save(image_type, image_name, new_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
class MaskFromAlphaInvocation(BaseInvocation):
"""Extracts the alpha channel of an image as a mask."""
# fmt: off
#fmt: off
type: Literal["tomask"] = "tomask"
# Inputs
image: ImageField = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get(
@@ -242,27 +182,22 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_mask, metadata)
context.services.images.save(image_type, image_name, image_mask)
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
class BlurInvocation(BaseInvocation, PILInvocationConfig):
class BlurInvocation(BaseInvocation):
"""Blurs an image"""
# fmt: off
#fmt: off
type: Literal["blur"] = "blur"
# Inputs
image: ImageField = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
@@ -279,28 +214,22 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, blur_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=blur_image
context.services.images.save(image_type, image_name, blur_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class LerpInvocation(BaseInvocation, PILInvocationConfig):
class LerpInvocation(BaseInvocation):
"""Linear interpolation of all pixels of an image"""
# fmt: off
#fmt: off
type: Literal["lerp"] = "lerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@@ -316,29 +245,23 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, lerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=lerp_image
context.services.images.save(image_type, image_name, lerp_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
class InverseLerpInvocation(BaseInvocation):
"""Inverse linear interpolation of all pixels of an image"""
# fmt: off
#fmt: off
type: Literal["ilerp"] = "ilerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
# fmt: on
#fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
@@ -358,12 +281,7 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, ilerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=ilerp_image
context.services.images.save(image_type, image_name, ilerp_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)

View File

@@ -1,435 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import random
from typing import Literal, Optional
from pydantic import BaseModel, Field
import torch
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput, build_image_output
from .compel import ConditioningField
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
from diffusers import DiffusionPipeline
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
class Config:
schema_extra = {"required": ["latents_name"]}
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
#fmt: off
type: Literal["latent_output"] = "latent_output"
latents: LatentsField = Field(default=None, description="The output latents")
#fmt: on
class NoiseOutput(BaseInvocationOutput):
"""Invocation noise output"""
#fmt: off
type: Literal["noise_output"] = "noise_output"
noise: LatentsField = Field(default=None, description="The output noise")
#fmt: on
# TODO: this seems like a hack
scheduler_map = dict(
ddim=diffusers.DDIMScheduler,
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_euler=diffusers.EulerDiscreteScheduler,
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
k_heun=diffusers.HeunDiscreteScheduler,
k_lms=diffusers.LMSDiscreteScheduler,
plms=diffusers.PNDMScheduler,
)
SAMPLER_NAME_VALUES = Literal[
tuple(list(scheduler_map.keys()))
]
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = scheduler_map.get(scheduler_name, diffusers.DDIMScheduler)
scheduler = scheduler_class.from_config(model.scheduler.config)
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
return scheduler
def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_channels:int=4, use_mps_noise:bool=False, downsampling_factor:int = 8):
# limit noise to only the diffusion image channels, not the mask channels
input_channels = min(latent_channels, 4)
use_device = "cpu" if (use_mps_noise or device.type == "mps") else device
generator = torch.Generator(device=use_device).manual_seed(seed)
x = torch.randn(
[
1,
input_channels,
height // downsampling_factor,
width // downsampling_factor,
],
dtype=torch_dtype(device),
device=use_device,
generator=generator,
).to(device)
# if self.perlin > 0.0:
# perlin_noise = self.get_perlin_noise(
# width // self.downsampling_factor, height // self.downsampling_factor
# )
# x = (1 - self.perlin) * x + self.perlin * perlin_noise
return x
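With the default latent_channels=4 and downsampling_factor=8, a 512x512 request produces a [1, 4, 64, 64] tensor, and the seeded generator makes the result reproducible; a quick check:

```python
import torch

noise = get_noise(512, 512, torch.device("cpu"), seed=42)
print(tuple(noise.shape))  # (1, 4, 64, 64)

again = get_noise(512, 512, torch.device("cpu"), seed=42)
print(torch.equal(noise, again))  # True -- same seed, same noise
```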
def random_seed():
return random.randint(0, np.iinfo(np.uint32).max)
class NoiseInvocation(BaseInvocation):
"""Generates latent noise."""
type: Literal["noise"] = "noise"
# Inputs
seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "noise"],
},
}
def invoke(self, context: InvocationContext) -> NoiseOutput:
device = torch.device(choose_torch_device())
noise = get_noise(self.width, self.height, device, self.seed)
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, noise)
return NoiseOutput(
noise=LatentsField(latents_name=name)
)
# Text to image
class TextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
type: Literal["t2l"] = "t2l"
# Inputs
# fmt: off
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
# fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
"type_hints": {
"model": "model"
}
},
}
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
model_info = choose_model(model_manager, self.model)
model_name = model_info['model_name']
model_hash = model_info['hash']
model: StableDiffusionGeneratorPipeline = model_info['model']
model.scheduler = get_scheduler(
model=model,
scheduler_name=self.scheduler
)
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
self.seamless,
self.seamless_axes
)
else:
configure_model_padding(model,
self.seamless,
self.seamless_axes
)
return model
def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
conditioning_data = ConditioningData(
uc,
c,
self.cfg_scale,
extra_conditioning_info,
postprocessing_settings=PostprocessingSettings(
threshold=0.0,  # threshold
warmup=0.2,  # warmup
h_symmetry_time_pct=None,  # h_symmetry_time_pct
v_symmetry_time_pct=None,  # v_symmetry_time_pct
),
).add_scheduler_args_if_applicable(model.scheduler, eta=None)  # ddim_eta
return conditioning_data
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(context, model)
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
class LatentsToLatentsInvocation(TextToLatentsInvocation):
"""Generates latents using latents as base image."""
type: Literal["l2l"] = "l2l"
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model"
}
},
}
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.5, description="The strength of the latents to use")
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(context, model)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=model.device, dtype=latent.dtype
)
timesteps, _ = model.get_img2img_timesteps(
self.steps,
self.strength,
device=model.device,
)
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
return LatentsOutput(
latents=LatentsField(latents_name=name)
)
# Latent to image
class LatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
type: Literal["l2i"] = "l2i"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
model: str = Field(default="", description="The model to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
"type_hints": {
"model": "model"
}
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
# TODO: this only really needs the vae
model_info = choose_model(context.services.model_manager, self.model)
model: StableDiffusionGeneratorPipeline = model_info['model']
with torch.inference_mode():
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
torch.cuda.empty_cache()
context.services.images.save(image_type, image_name, image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=image
)
LATENTS_INTERPOLATION_MODE = Literal[
"nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
]
class ResizeLatentsInvocation(BaseInvocation):
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
type: Literal["lresize"] = "lresize"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
resized_latents = torch.nn.functional.interpolate(
latents,
size=(self.height // 8, self.width // 8),
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
return LatentsOutput(latents=LatentsField(latents_name=name))
class ScaleLatentsInvocation(BaseInvocation):
"""Scales latents by a given factor."""
type: Literal["lscale"] = "lscale"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
# resizing
resized_latents = torch.nn.functional.interpolate(
latents,
scale_factor=self.scale_factor,
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
return LatentsOutput(latents=LatentsField(latents_name=name))
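# The two nodes above are thin wrappers over torch.nn.functional.interpolate;
# a hedged sketch of the equivalent direct calls (the 64x64 tensor stands in
# for a 512x512 image's latents):
import torch

latents = torch.randn(1, 4, 64, 64)  # latents for a 512x512 image
# lresize with width=768, height=768 (note the floor-division by 8):
resized = torch.nn.functional.interpolate(
    latents, size=(96, 96), mode="bilinear", antialias=True
)
# lscale with scale_factor=2.0; antialias is forced off outside bilinear/bicubic:
scaled = torch.nn.functional.interpolate(latents, scale_factor=2.0, mode="nearest")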

View File

@@ -1,75 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
from pydantic import BaseModel, Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
class MathInvocationConfig(BaseModel):
"""Helper class to provide all math invocations with additional config"""
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["math"],
}
}
class IntOutput(BaseInvocationOutput):
"""An integer output"""
#fmt: off
type: Literal["int_output"] = "int_output"
a: int = Field(default=None, description="The output integer")
#fmt: on
class AddInvocation(BaseInvocation, MathInvocationConfig):
"""Adds two numbers"""
#fmt: off
type: Literal["add"] = "add"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a + self.b)
class SubtractInvocation(BaseInvocation, MathInvocationConfig):
"""Subtracts two numbers"""
#fmt: off
type: Literal["sub"] = "sub"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a - self.b)
class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
"""Multiplies two numbers"""
#fmt: off
type: Literal["mul"] = "mul"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a * self.b)
class DivideInvocation(BaseInvocation, MathInvocationConfig):
"""Divides two numbers"""
#fmt: off
type: Literal["div"] = "div"
a: int = Field(default=0, description="The first number")
b: int = Field(default=0, description="The second number")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=int(self.a / self.b))
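# A hedged usage sketch: the math nodes never read their context, so passing
# None is enough for a quick local check (in normal operation the executor
# supplies a real InvocationContext).
total = AddInvocation(id="add1", a=2, b=3).invoke(None)               # IntOutput(a=5)
product = MultiplyInvocation(id="mul1", a=total.a, b=4).invoke(None)  # IntOutput(a=20)
quotient = DivideInvocation(id="div1", a=7, b=2).invoke(None)         # IntOutput(a=3), truncated via int()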

View File

@@ -1,18 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
from pydantic import Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from .math import IntOutput
# Pass-through parameter nodes - used by subgraphs
class ParamIntInvocation(BaseInvocation):
"""An integer parameter"""
#fmt: off
type: Literal["param_int"] = "param_int"
a: int = Field(default=0, description="The integer value")
#fmt: on
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a)
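# Pass-through in action, a minimal check: the node simply republishes its
# value, which is how the library graph later in this diff exposes width,
# height, and seed.
out = ParamIntInvocation(id="width", a=512).invoke(None)  # context is unused here
assert out.a == 512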

View File

@@ -12,11 +12,3 @@ class PromptOutput(BaseInvocationOutput):
prompt: str = Field(default=None, description="The output prompt")
#fmt: on
class Config:
schema_extra = {
'required': [
'type',
'prompt',
]
}

View File

@@ -1,11 +1,12 @@
from datetime import datetime, timezone
from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
@@ -17,14 +18,6 @@ class RestoreFaceInvocation(BaseInvocation):
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
#fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["restoration", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
@@ -43,14 +36,7 @@ class RestoreFaceInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@@ -1,12 +1,14 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
class UpscaleInvocation(BaseInvocation):
@@ -20,15 +22,6 @@ class UpscaleInvocation(BaseInvocation):
level: Literal[2, 4] = Field(default=2, description="The upscale level")
#fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["upscaling", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
@@ -47,14 +40,7 @@ class UpscaleInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@@ -1,13 +0,0 @@
from invokeai.backend.model_management.model_manager import ModelManager
def choose_model(model_manager: ModelManager, model_name: str):
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
logger = model_manager.logger
if model_manager.valid_model(model_name):
model = model_manager.get_model(model_name)
else:
model = model_manager.get_model()
logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
return model

View File

@@ -1,3 +0,0 @@
class CanceledException(Exception):
"""Execution canceled by user."""
pass
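# Its intended flow, as wired up elsewhere in this diff: the step callback
# raises it when the queue reports a canceled session, and the processor
# swallows it as a normal stop. A minimal sketch (names assumed from the
# stable_diffusion_step_callback shown later):
def check_cancel(queue, graph_execution_state_id: str) -> None:
    if queue.is_canceled(graph_execution_state_id):
        raise CanceledException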

View File

@@ -1,29 +0,0 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class ImageType(str, Enum):
RESULT = "results"
INTERMEDIATE = "intermediates"
UPLOAD = "uploads"
def is_image_type(obj):
try:
ImageType(obj)
except ValueError:
return False
return True
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_type: ImageType = Field(
default=ImageType.RESULT, description="The type of the image"
)
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_type", "image_name"]}

View File

@@ -1,63 +0,0 @@
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
from ..invocations.compel import CompelInvocation
from ..invocations.params import ParamIntInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
from .item_storage import ItemStorageABC
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
def create_text_to_image() -> LibraryGraph:
return LibraryGraph(
id=default_text_to_image_graph_id,
name='t2i',
description='Converts text to an image',
graph=Graph(
nodes={
'width': ParamIntInvocation(id='width', a=512),
'height': ParamIntInvocation(id='height', a=512),
'seed': ParamIntInvocation(id='seed', a=-1),
'3': NoiseInvocation(id='3'),
'4': CompelInvocation(id='4'),
'5': CompelInvocation(id='5'),
'6': TextToLatentsInvocation(id='6'),
'7': LatentsToImageInvocation(id='7'),
},
edges=[
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
]
),
exposed_inputs=[
ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
ExposedNodeInput(node_path='width', field='a', alias='width'),
ExposedNodeInput(node_path='height', field='a', alias='height'),
ExposedNodeInput(node_path='seed', field='a', alias='seed'),
],
exposed_outputs=[
ExposedNodeOutput(node_path='7', field='image', alias='image')
])
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
graphs: list[LibraryGraph] = list()
text_to_image = graph_library.get(default_text_to_image_graph_id)
# TODO: Check if the graph is the same as the default one, and if not, update it
#if text_to_image is None:
text_to_image = create_text_to_image()
graph_library.set(text_to_image)
graphs.append(text_to_image)
return graphs
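# A hedged sketch of the same wiring pattern on a smaller scale, reusing the
# math nodes from earlier in this diff (import path assumed). The edge only
# takes effect when a GraphExecutionState runs the graph, not at construction.
from ..invocations.math import AddInvocation, MultiplyInvocation

g = Graph(
    nodes={
        "1": AddInvocation(id="1", a=2, b=3),
        "2": MultiplyInvocation(id="2", b=10),
    },
    edges=[
        Edge(
            source=EdgeConnection(node_id="1", field="a"),
            destination=EdgeConnection(node_id="2", field="a"),
        )
    ],
)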

View File

@@ -1,9 +1,10 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.util.misc import get_timestamp
from typing import Any, Dict, TypedDict
ProgressImage = TypedDict(
"ProgressImage", {"dataURL": str, "width": int, "height": int}
)
class EventServiceBase:
session_event: str = "session_event"
@@ -13,8 +14,7 @@ class EventServiceBase:
def dispatch(self, event_name: str, payload: Any) -> None:
pass
def __emit_session_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
def __emit_session_event(self, event_name: str, payload: Dict) -> None:
self.dispatch(
event_name=EventServiceBase.session_event,
payload=dict(event=event_name, data=payload),
@@ -25,8 +25,7 @@ class EventServiceBase:
def emit_generator_progress(
self,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
invocation_id: str,
progress_image: ProgressImage | None,
step: int,
total_steps: int,
@@ -36,60 +35,48 @@ class EventServiceBase:
event_name="generator_progress",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
progress_image=progress_image.dict() if progress_image is not None else None,
invocation_id=invocation_id,
progress_image=progress_image,
step=step,
total_steps=total_steps,
),
)
def emit_invocation_complete(
self,
graph_execution_state_id: str,
result: dict,
node: dict,
source_node_id: str,
self, graph_execution_state_id: str, invocation_id: str, result: Dict
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_complete",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
invocation_id=invocation_id,
result=result,
),
)
def emit_invocation_error(
self,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
error: str,
self, graph_execution_state_id: str, invocation_id: str, error: str
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_error",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
invocation_id=invocation_id,
error=error,
),
)
def emit_invocation_started(
self, graph_execution_state_id: str, node: dict, source_node_id: str
self, graph_execution_state_id: str, invocation_id: str
) -> None:
"""Emitted when an invocation has started"""
self.__emit_session_event(
event_name="invocation_started",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
invocation_id=invocation_id,
),
)
@@ -97,7 +84,5 @@ class EventServiceBase:
"""Emitted when a session has completed all invocations"""
self.__emit_session_event(
event_name="graph_execution_state_complete",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
),
payload=dict(graph_execution_state_id=graph_execution_state_id),
)

View File

@@ -2,6 +2,7 @@
import copy
import itertools
import traceback
import uuid
from types import NoneType
from typing import (
@@ -16,7 +17,7 @@ from typing import (
)
import networkx as nx
from pydantic import BaseModel, root_validator, validator
from pydantic import BaseModel, validator
from pydantic.fields import Field
from ..invocations import *
@@ -25,6 +26,7 @@ from ..invocations.baseinvocation import (
BaseInvocationOutput,
InvocationContext,
)
from .invocation_services import InvocationServices
class EdgeConnection(BaseModel):
@@ -125,13 +127,6 @@ class NodeAlreadyExecutedError(Exception):
class GraphInvocationOutput(BaseInvocationOutput):
type: Literal["graph_output"] = "graph_output"
class Config:
schema_extra = {
'required': [
'type',
'image',
]
}
# TODO: Fill this out and move to invocations
class GraphInvocation(BaseInvocation):
@@ -152,13 +147,6 @@ class IterateInvocationOutput(BaseInvocationOutput):
item: Any = Field(description="The item being iterated over")
class Config:
schema_extra = {
'required': [
'type',
'item',
]
}
# TODO: Fill this out and move to invocations
class IterateInvocation(BaseInvocation):
@@ -181,13 +169,6 @@ class CollectInvocationOutput(BaseInvocationOutput):
collection: list[Any] = Field(description="The collection of input items")
class Config:
schema_extra = {
'required': [
'type',
'collection',
]
}
class CollectInvocation(BaseInvocation):
"""Collects values into a collection"""
@@ -213,7 +194,7 @@ InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]
class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
id: str = Field(description="The id of this graph", default_factory=uuid.uuid4)
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
description="The nodes in this graph", default_factory=dict
@@ -281,8 +262,7 @@ class Graph(BaseModel):
:raises InvalidEdgeError: the provided edge is invalid.
"""
self._validate_edge(edge)
if edge not in self.edges:
if self._is_edge_valid(edge) and edge not in self.edges:
self.edges.append(edge)
else:
raise InvalidEdgeError()
@@ -353,7 +333,7 @@ class Graph(BaseModel):
return True
def _validate_edge(self, edge: Edge):
def _is_edge_valid(self, edge: Edge) -> bool:
"""Validates that a new edge doesn't create a cycle in the graph"""
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
@@ -361,53 +341,54 @@ class Graph(BaseModel):
from_node = self.get_node(edge.source.node_id)
to_node = self.get_node(edge.destination.node_id)
except NodeNotFoundError:
raise InvalidEdgeError("One or both nodes don't exist")
return False
# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
return False
# Validate that no cycles would be created
g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g):
raise InvalidEdgeError(f'Edge creates a cycle in the graph')
return False
# Validate that the field types are compatible
if not are_connections_compatible(
from_node, edge.source.field, to_node, edge.destination.field
):
raise InvalidEdgeError(f'Fields are incompatible')
return False
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid(
edge.destination.node_id, new_input=edge.source
):
raise InvalidEdgeError(f'Iterator input type does not match iterator output type')
return False
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid(
edge.source.node_id, new_output=edge.destination
):
raise InvalidEdgeError(f'Iterator output type does not match iterator input type')
return False
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid(
edge.destination.node_id, new_input=edge.source
):
raise InvalidEdgeError(f'Collector output type does not match collector input type')
return False
# Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
if not self._is_collector_connection_valid(
edge.source.node_id, new_output=edge.destination
):
raise InvalidEdgeError(f'Collector input type does not match collector output type')
return False
return True
def has_node(self, node_path: str) -> bool:
"""Determines whether or not a node exists in the graph."""
@@ -731,7 +712,7 @@ class Graph(BaseModel):
for sgn in (
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
):
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
# TODO: figure out if iteration nodes need to be expanded
@@ -748,7 +729,9 @@ class Graph(BaseModel):
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
id: str = Field(
description="The id of the execution state", default_factory=uuid.uuid4
)
# TODO: Store a reference to the graph instead of the actual graph?
graph: Graph = Field(description="The graph being executed")
@@ -790,6 +773,9 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
# Declare all fields as required; necessary for OpenAPI schema generation build.
# Technically only fields without a `default_factory` need to be listed here.
# See: https://github.com/pydantic/pydantic/discussions/4577
class Config:
schema_extra = {
'required': [
@@ -854,8 +840,7 @@ class GraphExecutionState(BaseModel):
def is_complete(self) -> bool:
"""Returns true if the graph is complete"""
node_ids = set(self.graph.nx_graph_flat().nodes)
return self.has_error() or all((k in self.executed for k in node_ids))
return self.has_error() or all((k in self.executed for k in self.graph.nodes))
def has_error(self) -> bool:
"""Returns true if the graph has any errors"""
@@ -943,11 +928,11 @@ class GraphExecutionState(BaseModel):
def _iterator_graph(self) -> nx.DiGraph:
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
g = self.graph.nx_graph_flat()
g = self.graph.nx_graph()
collectors = (
n
for n in self.graph.nodes
if isinstance(self.graph.get_node(n), CollectInvocation)
if isinstance(self.graph.nodes[n], CollectInvocation)
)
for c in collectors:
g.remove_edges_from(list(g.in_edges(c)))
@@ -959,7 +944,7 @@ class GraphExecutionState(BaseModel):
iterators = [
n
for n in nx.ancestors(g, node_id)
if isinstance(self.graph.get_node(n), IterateInvocation)
if isinstance(self.graph.nodes[n], IterateInvocation)
]
return iterators
@@ -1063,8 +1048,9 @@ class GraphExecutionState(BaseModel):
n
for n in prepared_nodes
if all(
nx.has_path(execution_graph, pit[0], n)
pit
for pit in parent_iterators
if nx.has_path(execution_graph, pit[0], n)
)
),
None,
@@ -1095,9 +1081,7 @@ class GraphExecutionState(BaseModel):
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
def _is_edge_valid(self, edge: Edge) -> bool:
try:
self.graph._validate_edge(edge)
except InvalidEdgeError:
if not self.graph._is_edge_valid(edge):
return False
# Invalid if destination has already been prepared or executed
@@ -1143,52 +1127,4 @@ class GraphExecutionState(BaseModel):
self.graph.delete_edge(edge)
class ExposedNodeInput(BaseModel):
node_path: str = Field(description="The node path to the node with the input")
field: str = Field(description="The field name of the input")
alias: str = Field(description="The alias of the input")
class ExposedNodeOutput(BaseModel):
node_path: str = Field(description="The node path to the node with the output")
field: str = Field(description="The field name of the output")
alias: str = Field(description="The alias of the output")
class LibraryGraph(BaseModel):
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
graph: Graph = Field(description="The graph")
name: str = Field(description="The name of the graph")
description: str = Field(description="The description of the graph")
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
@validator('exposed_inputs', 'exposed_outputs')
def validate_exposed_aliases(cls, v):
if len(v) != len(set(i.alias for i in v)):
raise ValueError("Duplicate exposed alias")
return v
@root_validator
def validate_exposed_nodes(cls, values):
graph = values['graph']
# Validate exposed inputs
for exposed_input in values['exposed_inputs']:
if not graph.has_node(exposed_input.node_path):
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
node = graph.get_node(exposed_input.node_path)
if get_input_field(node, exposed_input.field) is None:
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
# Validate exposed outputs
for exposed_output in values['exposed_outputs']:
if not graph.has_node(exposed_output.node_path):
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
node = graph.get_node(exposed_output.node_path)
if get_output_field(node, exposed_output.field) is None:
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
return values
GraphInvocation.update_forward_refs()

View File

@@ -1,29 +1,22 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import datetime
import os
from glob import glob
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from queue import Queue
from typing import Dict, List
from typing import Dict
from PIL.Image import Image
import PIL.Image as PILImage
from send2trash import send2trash
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
SavedImage,
)
from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import (
InvokeAIMetadata,
MetadataServiceBase,
build_invokeai_metadata_pnginfo,
)
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.misc import get_timestamp
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
from invokeai.backend.image_util import PngWriter
class ImageType(str, Enum):
RESULT = "results"
INTERMEDIATE = "intermediates"
UPLOAD = "uploads"
class ImageStorageBase(ABC):
@@ -31,74 +24,40 @@ class ImageStorageBase(ABC):
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> Image:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
"""Gets a paginated list of images."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
"""Gets the internal path to an image or its thumbnail."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_uri(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
"""Gets the external URI to an image or its thumbnail."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates an image path."""
def get_path(self, image_type: ImageType, image_name: str) -> str:
pass
@abstractmethod
def save(
self,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> SavedImage:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
def create_name(self, context_id: str, node_id: str) -> str:
"""Creates a unique contextual image filename."""
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
return f"{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png"
class DiskImageStorage(ImageStorageBase):
"""Stores images on disk"""
__output_folder: str
__pngWriter: PngWriter
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, Image]
__max_cache_size: int
__metadata_service: MetadataServiceBase
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
def __init__(self, output_folder: str):
self.__output_folder = output_folder
self.__pngWriter = PngWriter(output_folder)
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
self.__metadata_service = metadata_service
Path(output_folder).mkdir(parents=True, exist_ok=True)
@@ -107,61 +66,6 @@ class DiskImageStorage(ImageStorageBase):
Path(os.path.join(output_folder, image_type)).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
dir_path = os.path.join(self.__output_folder, image_type)
image_paths = glob(f"{dir_path}/*.png")
count = len(image_paths)
sorted_image_paths = sorted(
image_paths, key=os.path.getctime, reverse=True
)
page_of_image_paths = sorted_image_paths[
page * per_page : (page + 1) * per_page
]
page_of_images: List[ImageResponse] = []
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
invokeai_metadata = self.__metadata_service.get_metadata(img)
page_of_images.append(
ImageResponse(
image_type=image_type.value,
image_name=filename,
# TODO: DiskImageStorage should not be building URLs...?
image_url=self.get_uri(image_type, filename),
thumbnail_url=self.get_uri(image_type, filename, True),
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
metadata=ImageResponseMetadata(
created=int(os.path.getctime(path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
)
page_count_trunc = int(count / per_page)
page_count_mod = count % per_page
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
return PaginatedResults[ImageResponse](
items=page_of_images,
page=page,
pages=page_count,
per_page=per_page,
total=count,
)
def get(self, image_type: ImageType, image_name: str) -> Image:
image_path = self.get_path(image_type, image_name)
@@ -169,97 +73,33 @@ class DiskImageStorage(ImageStorageBase):
if cache_item:
return cache_item
image = PILImage.open(image_path)
image = Image.open(image_path)
self.__set_cache(image_path, image)
return image
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
def get_path(self, image_type: ImageType, image_name: str) -> str:
path = os.path.join(self.__output_folder, image_type, image_name)
return path
if is_thumbnail:
path = os.path.join(
self.__output_folder, image_type, "thumbnails", basename
)
else:
path = os.path.join(self.__output_folder, image_type, basename)
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
image_subpath = os.path.join(image_type, image_name)
self.__pngWriter.save_image_and_prompt_to_png(
image, "", image_subpath, None
) # TODO: just pass full path to png writer
abspath = os.path.abspath(path)
return abspath
def get_uri(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if is_thumbnail:
thumbnail_basename = get_thumbnail_name(basename)
uri = f"api/v1/images/{image_type.value}/thumbnails/{thumbnail_basename}"
else:
uri = f"api/v1/images/{image_type.value}/{basename}"
return uri
def validate_path(self, path: str) -> bool:
try:
os.stat(path)
return True
except Exception:
return False
def save(
self,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> SavedImage:
image_path = self.get_path(image_type, image_name)
# TODO: Reading the image and then saving it strips the metadata...
if metadata:
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path) # this saved image has an empty info
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
thumbnail_image = make_thumbnail(image)
thumbnail_image.save(thumbnail_path)
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image)
return SavedImage(
image_name=image_name,
thumbnail_name=thumbnail_name,
created=int(os.path.getctime(image_path)),
)
def delete(self, image_type: ImageType, image_name: str) -> None:
basename = os.path.basename(image_name)
image_path = self.get_path(image_type, basename)
image_path = self.get_path(image_type, image_name)
if os.path.exists(image_path):
send2trash(image_path)
os.remove(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
if os.path.exists(thumbnail_path):
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
def __get_cache(self, image_name: str) -> Image | None:
def __get_cache(self, image_name: str) -> Image:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: Image):

View File

@@ -1,17 +1,30 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import time
from abc import ABC, abstractmethod
from queue import Queue
from pydantic import BaseModel, Field
import time
class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
invocation_id: str = Field(description="The ID of the node being invoked")
invoke_all: bool = Field(default=False)
timestamp: float = Field(default_factory=time.time)
# TODO: make this serializable
class InvocationQueueItem:
# session_id: str
graph_execution_state_id: str
invocation_id: str
invoke_all: bool
timestamp: float
def __init__(
self,
# session_id: str,
graph_execution_state_id: str,
invocation_id: str,
invoke_all: bool = False,
):
# self.session_id = session_id
self.graph_execution_state_id = graph_execution_state_id
self.invocation_id = invocation_id
self.invoke_all = invoke_all
self.timestamp = time.time()
class InvocationQueueABC(ABC):

View File

@@ -1,11 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import types
from invokeai.app.services.metadata import MetadataServiceBase
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from invokeai.backend import ModelManager
from .events import EventServiceBase
from .latent_storage import LatentsStorageBase
from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
@@ -15,15 +11,12 @@ class InvocationServices:
"""Services that can be used by invocations"""
events: EventServiceBase
latents: LatentsStorageBase
images: ImageStorageBase
metadata: MetadataServiceBase
queue: InvocationQueueABC
model_manager: ModelManager
restoration: RestorationServices
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"]
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
processor: "InvocationProcessorABC"
@@ -31,24 +24,16 @@ class InvocationServices:
self,
model_manager: ModelManager,
events: EventServiceBase,
logger: types.ModuleType,
latents: LatentsStorageBase,
images: ImageStorageBase,
metadata: MetadataServiceBase,
queue: InvocationQueueABC,
graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
):
self.model_manager = model_manager
self.events = events
self.logger = logger
self.latents = latents
self.images = images
self.metadata = metadata
self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager
self.processor = processor
self.restoration = restoration

View File

@@ -33,6 +33,7 @@ class Invoker:
self.services.graph_execution_manager.set(graph_execution_state)
# Queue the invocation
print(f"queueing item {invocation.id}")
self.services.queue.put(
InvocationQueueItem(
# session_id = session.id,
@@ -49,7 +50,7 @@ class Invoker:
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
self.services.graph_execution_manager.set(new_state)
return new_state
def cancel(self, graph_execution_state_id: str) -> None:
"""Cancels the given execution state"""
self.services.queue.cancel(graph_execution_state_id)
@@ -71,12 +72,18 @@ class Invoker:
for service in vars(self.services):
self.__start_service(getattr(self.services, service))
for service in vars(self.services):
self.__start_service(getattr(self.services, service))
def stop(self) -> None:
"""Stops the invoker. A new invoker will have to be created to execute further."""
# First stop all services
for service in vars(self.services):
self.__stop_service(getattr(self.services, service))
for service in vars(self.services):
self.__stop_service(getattr(self.services, service))
self.services.queue.put(None)

View File

@@ -1,93 +0,0 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict
import torch
class LatentsStorageBase(ABC):
"""Responsible for storing and retrieving latents."""
@abstractmethod
def get(self, name: str) -> torch.Tensor:
pass
@abstractmethod
def set(self, name: str, data: torch.Tensor) -> None:
pass
@abstractmethod
def delete(self, name: str) -> None:
pass
class ForwardCacheLatentsStorage(LatentsStorageBase):
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
__cache: Dict[str, torch.Tensor]
__cache_ids: Queue
__max_cache_size: int
__underlying_storage: LatentsStorageBase
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
self.__underlying_storage = underlying_storage
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = max_cache_size
def get(self, name: str) -> torch.Tensor:
cache_item = self.__get_cache(name)
if cache_item is not None:
return cache_item
latent = self.__underlying_storage.get(name)
self.__set_cache(name, latent)
return latent
def set(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.set(name, data)
self.__set_cache(name, data)
def delete(self, name: str) -> None:
self.__underlying_storage.delete(name)
if name in self.__cache:
del self.__cache[name]
def __get_cache(self, name: str) -> torch.Tensor|None:
return None if name not in self.__cache else self.__cache[name]
def __set_cache(self, name: str, data: torch.Tensor):
if name not in self.__cache:
self.__cache[name] = data
self.__cache_ids.put(name)
if self.__cache_ids.qsize() > self.__max_cache_size:
self.__cache.pop(self.__cache_ids.get())
class DiskLatentsStorage(LatentsStorageBase):
"""Stores latents in a folder on disk without caching"""
__output_folder: str
def __init__(self, output_folder: str):
self.__output_folder = output_folder
Path(output_folder).mkdir(parents=True, exist_ok=True)
def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name)
return torch.load(latent_path)
def set(self, name: str, data: torch.Tensor) -> None:
latent_path = self.get_path(name)
torch.save(data, latent_path)
def delete(self, name: str) -> None:
latent_path = self.get_path(name)
os.remove(latent_path)
def get_path(self, name: str) -> str:
return os.path.join(self.__output_folder, name)
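# Composing the two classes above gives the presumable production setup: a
# small write-through cache in front of disk (the path is illustrative; torch
# is already imported at the top of this file).
storage = ForwardCacheLatentsStorage(DiskLatentsStorage("/tmp/invokeai-latents"), max_cache_size=4)
storage.set("demo", torch.randn(1, 4, 64, 64))  # written to disk and cached
latent = storage.get("demo")                    # served from the in-memory cache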

View File

@@ -1,96 +0,0 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, TypedDict
from PIL import Image, PngImagePlugin
from pydantic import BaseModel
from invokeai.app.models.image import ImageType, is_image_type
class MetadataImageField(TypedDict):
"""Pydantic-less ImageField, used for metadata parsing."""
image_type: ImageType
image_name: str
class MetadataLatentsField(TypedDict):
"""Pydantic-less LatentsField, used for metadata parsing."""
latents_name: str
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[
str, str | int | float | bool | MetadataImageField | MetadataLatentsField
]
class InvokeAIMetadata(TypedDict, total=False):
"""InvokeAI-specific metadata format."""
session_id: Optional[str]
node: Optional[NodeMetadata]
def build_invokeai_metadata_pnginfo(
metadata: InvokeAIMetadata | None,
) -> PngImagePlugin.PngInfo:
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
pnginfo = PngImagePlugin.PngInfo()
if metadata is not None:
pnginfo.add_text("invokeai", json.dumps(metadata))
return pnginfo
class MetadataServiceBase(ABC):
@abstractmethod
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
pass
@abstractmethod
def build_metadata(
self, session_id: str, node: BaseModel
) -> InvokeAIMetadata | None:
"""Builds an InvokeAIMetadata object"""
pass
class PngMetadataService(MetadataServiceBase):
"""Handles loading and building metadata for images."""
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
def _load_metadata(self, image: Image.Image) -> dict | None:
"""Loads a specific info entry from a PIL Image."""
try:
info = image.info.get("invokeai")
if type(info) is not str:
return None
loaded_metadata = json.loads(info)
if type(loaded_metadata) is not dict:
return None
if len(loaded_metadata.items()) == 0:
return None
return loaded_metadata
except Exception:
return None
def get_metadata(self, image: Image.Image) -> dict | None:
"""Retrieves an image's metadata as a dict"""
loaded_metadata = self._load_metadata(image)
return loaded_metadata
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
return metadata
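# A hedged round trip through the helpers above (the path and field values are
# illustrative only):
metadata = InvokeAIMetadata(session_id="abc123", node={"type": "txt2img", "steps": 30})
img = Image.new("RGB", (8, 8))
img.save("/tmp/meta.png", "PNG", pnginfo=build_invokeai_metadata_pnginfo(metadata))
assert PngMetadataService().get_metadata(Image.open("/tmp/meta.png")) == metadata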

View File

@@ -5,7 +5,6 @@ from argparse import Namespace
from invokeai.backend import Args
from omegaconf import OmegaConf
from pathlib import Path
from typing import types
import invokeai.version
from ...backend import ModelManager
@@ -13,16 +12,16 @@ from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals
# TODO: Replace with an abstract class base ModelManagerBase
def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
def get_model_manager(config: Args) -> ModelManager:
if not config.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml")
if not os.path.exists(config_file):
report_model_error(
config, FileNotFoundError(f"The file {config_file} could not be found."), logger
config, FileNotFoundError(f"The file {config_file} could not be found.")
)
logger.info(f"{invokeai.version.__app_name__}, version {invokeai.version.__version__}")
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
print(f">> {invokeai.version.__app_name__}, version {invokeai.version.__version__}")
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
@@ -63,12 +62,11 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
device_type=device,
max_loaded_models=config.max_loaded_models,
embedding_path = Path(embedding_path),
logger = logger,
)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(config, e, logger)
report_model_error(config, e)
except (IOError, KeyError) as e:
logger.error(f"{e}. Aborting.")
print(f"{e}. Aborting.")
sys.exit(-1)
# try to autoconvert new models
@@ -78,18 +76,18 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
conf_path=config.conf,
weights_directory=path,
)
logger.info('Model manager initialized')
return model_manager
def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
logger.error(f'An error occurred while attempting to initialize the model: "{str(e)}"')
logger.error(
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
def report_model_error(opt: Namespace, e: Exception):
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
print(
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
)
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
if yes_to_all:
logger.warning(
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
print(
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
)
else:
response = input(
@@ -98,12 +96,13 @@ def report_model_error(opt: Namespace, e: Exception, logger: types.ModuleType):
if response.startswith(("n", "N")):
return
logger.info("invokeai-configure is launching....\n")
print("invokeai-configure is launching....\n")
# Match arguments that were set on the CLI
# only the arguments accepted by the configuration script are parsed
root_dir = ["--root", opt.root_dir] if opt.root_dir is not None else []
config = ["--config", opt.conf] if opt.conf is not None else []
previous_config = sys.argv
sys.argv = ["invokeai-configure"]
sys.argv.extend(root_dir)
sys.argv.extend(config)

View File

@@ -1,20 +1,17 @@
import traceback
from threading import Event, Thread, BoundedSemaphore
from threading import Event, Thread
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
from ..models.exceptions import CanceledException
class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
__stop_event: Event
__invoker: Invoker
__threadLimit: BoundedSemaphore
def start(self, invoker) -> None:
# if we do want multithreading at some point, we could make this configurable
self.__threadLimit = BoundedSemaphore(1)
self.__invoker = invoker
self.__stop_event = Event()
self.__invoker_thread = Thread(
@@ -23,7 +20,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
kwargs=dict(stop_event=self.__stop_event),
)
self.__invoker_thread.daemon = (
True # TODO: make async and do not use threads
True # TODO: probably better to just not use threads?
)
self.__invoker_thread.start()
@@ -32,7 +29,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
def __process(self, stop_event: Event):
try:
self.__threadLimit.acquire()
while not stop_event.is_set():
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
if not queue_item: # Probably stopping
@@ -47,14 +43,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
queue_item.invocation_id
)
# get the source node id to provide to clients (the prepared node id is not as useful)
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
# Send starting event
self.__invoker.services.events.emit_invocation_started(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id
invocation_id=invocation.id,
)
# Invoke
@@ -83,17 +75,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
invocation_id=invocation.id,
result=outputs.dict(),
)
except KeyboardInterrupt:
pass
except CanceledException:
pass
except Exception as e:
error = traceback.format_exc()
@@ -108,13 +96,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send error event
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
invocation_id=invocation.id,
error=error,
)
pass
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(
graph_execution_state.id
@@ -131,6 +118,4 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
)
except KeyboardInterrupt:
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
finally:
self.__threadLimit.release()
... # Log something?

View File

@@ -1,7 +1,6 @@
import sys
import traceback
import torch
from typing import types
from ...backend.restoration import Restoration
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
@@ -11,7 +10,7 @@ from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
class RestorationServices:
'''Face restoration and upscaling'''
def __init__(self,args,logger:types.ModuleType):
def __init__(self,args):
try:
gfpgan, codeformer, esrgan = None, None, None
if args.restore or args.esrgan:
@@ -21,22 +20,20 @@ class RestorationServices:
args.gfpgan_model_path
)
else:
logger.info("Face restoration disabled")
print(">> Face restoration disabled")
if args.esrgan:
esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
else:
logger.info("Upscaling disabled")
print(">> Upscaling disabled")
else:
logger.info("Face restoration and upscaling disabled")
print(">> Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
self.device = torch.device(choose_torch_device())
self.gfpgan = gfpgan
self.codeformer = codeformer
self.esrgan = esrgan
self.logger = logger
self.logger.info('Face restoration initialized')
# note that this one method does gfpgan and codeformer reconstruction, as well as
# esrgan upscaling
@@ -61,15 +58,15 @@ class RestorationServices:
if self.gfpgan is not None or self.codeformer is not None:
if facetool == "gfpgan":
if self.gfpgan is None:
self.logger.info(
"GFPGAN not found. Face restoration is disabled."
print(
">> GFPGAN not found. Face restoration is disabled."
)
else:
image = self.gfpgan.process(image, strength, seed)
if facetool == "codeformer":
if self.codeformer is None:
self.logger.info(
"CodeFormer not found. Face restoration is disabled."
print(
">> CodeFormer not found. Face restoration is disabled."
)
else:
cf_device = (
@@ -83,7 +80,7 @@ class RestorationServices:
fidelity=codeformer_fidelity,
)
else:
self.logger.info("Face Restoration is disabled.")
print(">> Face Restoration is disabled.")
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
@@ -96,10 +93,10 @@ class RestorationServices:
denoise_str=upscale_denoise_str,
)
else:
self.logger.info("ESRGAN is disabled. Image not upscaled.")
print(">> ESRGAN is disabled. Image not upscaled.")
except Exception as e:
self.logger.info(
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
print(
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
)
if image_callback is not None:

View File

@@ -59,7 +59,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(item.json(),),
)
self._conn.commit()
finally:
self._lock.release()
self._on_changed(item)
@@ -85,7 +84,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._cursor.execute(
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
self._conn.commit()
finally:
self._lock.release()
self._on_deleted(id)
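One side of this hunk pairs every write with an explicit self._conn.commit(). That detail matters with sqlite3: the stdlib driver opens an implicit transaction around data-modifying statements, so un-committed inserts and deletes are invisible to other connections and are rolled back when the process exits. A minimal illustration of both styles (table layout is made up for the example):

import sqlite3

# Option 1: explicit commit after each write, as in the commit() calls above
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE items (id TEXT PRIMARY KEY, item TEXT)")
conn.execute("INSERT OR REPLACE INTO items (id, item) VALUES (?, ?)", ("1", "{}"))
conn.commit()  # flush the implicit transaction

# Option 2: autocommit mode, where no commit() is needed at all
auto = sqlite3.connect(":memory:", isolation_level=None)
auto.execute("CREATE TABLE items (id TEXT PRIMARY KEY, item TEXT)")
auto.execute("INSERT OR REPLACE INTO items (id, item) VALUES (?, ?)", ("2", "{}"))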

View File

@@ -1,5 +0,0 @@
import datetime
def get_timestamp():
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
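The helper returns a timezone-independent UNIX timestamp truncated to whole seconds; round-tripping it through fromtimestamp recovers the UTC wall-clock time:

ts = get_timestamp()  # e.g. 1678996736 (illustrative value)
print(datetime.datetime.fromtimestamp(ts, tz=datetime.timezone.utc))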

View File

@@ -1,55 +0,0 @@
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.models.exceptions import CanceledException
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
def stable_diffusion_step_callback(
context: InvocationContext,
intermediate_state: PipelineIntermediateState,
node: dict,
source_node_id: str,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be. Use
# that estimate if it is available.
if intermediate_state.predicted_original is not None:
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
# TODO: This does not seem to be needed any more?
# # txt2img provides a Tensor in the step_callback
# # img2img provides a PipelineIntermediateState
# if isinstance(sample, PipelineIntermediateState):
# # this was an img2img
# print('img2img')
# latents = sample.latents
# step = sample.step
# else:
# print('txt2img')
# latents = sample
# step = intermediate_state.step
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
(width, height) = image.size
width *= 8
height *= 8
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
graph_execution_state_id=context.graph_execution_state_id,
node=node,
source_node_id=source_node_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step,
total_steps=node["steps"],
)
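The arithmetic above is worth spelling out: latents live at 1/8 of pixel resolution, so the estimated preview's dimensions are multiplied by 8 before being reported, and the JPEG itself travels inline as a data URL. A plausible stand-in for image_to_dataURL (the real helper lives in ...backend.util.util; this version exists only to make the payload concrete):

import base64
import io
from PIL import Image

def image_to_dataURL(image: Image.Image, image_format: str = "JPEG") -> str:
    # serialize the PIL image and wrap it in a data: URL for the web client
    buf = io.BytesIO()
    image.save(buf, format=image_format)
    b64 = base64.b64encode(buf.getvalue()).decode("ascii")
    return f"data:image/{image_format.lower()};base64,{b64}"

preview = Image.new("RGB", (64, 64))           # stand-in for the low-res estimate
width, height = (d * 8 for d in preview.size)  # report dimensions in pixel space
url = image_to_dataURL(preview)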

View File

@@ -1,15 +0,0 @@
import os
from PIL import Image
def get_thumbnail_name(image_name: str) -> str:
"""Formats given an image name, returns the appropriate thumbnail image name"""
thumbnail_name = os.path.splitext(image_name)[0] + ".webp"
return thumbnail_name
def make_thumbnail(image: Image.Image, size: int = 256) -> Image.Image:
"""Makes a thumbnail from a PIL Image"""
thumbnail = image.copy()
thumbnail.thumbnail(size=(size, size))
return thumbnail
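Usage is a two-liner; PIL's thumbnail() preserves aspect ratio, shrinking the copy until it fits inside the size×size box (assuming a Pillow build with WebP support for the save; the file name is hypothetical):

from PIL import Image

image = Image.open("some_image.png")              # hypothetical input file
thumb = make_thumbnail(image, size=256)           # longest side becomes <= 256px
thumb.save(get_thumbnail_name("some_image.png"))  # -> "some_image.webp"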

View File

@@ -10,7 +10,7 @@ from .generator import (
Img2Img,
Inpaint
)
from .model_management import ModelManager, SDModelComponent
from .model_management import ModelManager
from .safety_checker import SafetyChecker
from .args import Args
from .globals import Globals

View File

@@ -96,7 +96,6 @@ from pathlib import Path
from typing import List
import invokeai.version
import invokeai.backend.util.logging as logger
from invokeai.backend.image_util import retrieve_metadata
from .globals import Globals
@@ -190,7 +189,7 @@ class Args(object):
print(f"{APP_NAME} {APP_VERSION}")
sys.exit(0)
logger.info("Initializing, be patient...")
print("* Initializing, be patient...")
Globals.root = Path(os.path.abspath(switches.root_dir or Globals.root))
Globals.try_patchmatch = switches.patchmatch
@@ -198,13 +197,14 @@ class Args(object):
initfile = os.path.expanduser(os.path.join(Globals.root, Globals.initfile))
legacyinit = os.path.expanduser("~/.invokeai")
if os.path.exists(initfile):
logger.info(
f"Initialization file {initfile} found. Loading...",
print(
f">> Initialization file {initfile} found. Loading...",
file=sys.stderr,
)
sysargs.insert(0, f"@{initfile}")
elif os.path.exists(legacyinit):
logger.warning(
f"Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
print(
f">> WARNING: Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init."
)
sysargs.insert(0, f"@{legacyinit}")
Globals.log_tokenization = self._arg_parser.parse_args(
@@ -214,7 +214,7 @@ class Args(object):
self._arg_switches = self._arg_parser.parse_args(sysargs)
return self._arg_switches
except Exception as e:
logger.error(f"An exception has occurred: {e}")
print(f"An exception has occurred: {e}")
return None
def parse_cmd(self, cmd_string):
@@ -561,7 +561,7 @@ class Args(object):
"--autoimport",
default=None,
type=str,
help="(DEPRECATED - NONFUNCTIONAL). Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
)
model_group.add_argument(
"--autoconvert",
@@ -1154,7 +1154,7 @@ class Args(object):
def format_metadata(**kwargs):
logger.warning("format_metadata() is deprecated. Please use metadata_dumps()")
print("format_metadata() is deprecated. Please use metadata_dumps()")
return metadata_dumps(kwargs)
@@ -1326,7 +1326,7 @@ def metadata_loads(metadata) -> list:
import sys
import traceback
logger.error("Could not read metadata")
print(">> could not read metadata", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return results
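The sysargs.insert(0, f"@{initfile}") trick earlier in this file leans on argparse's @-file mechanism: when a parser is constructed with fromfile_prefix_chars="@", any argument starting with @ is replaced by the contents of the named file, one token per line. A self-contained sketch under that assumption (file name and option are illustrative):

import argparse
import pathlib

# one token per line, the format argparse's default file reader expects
pathlib.Path("invokeai.init").write_text("--outdir\n/data/images\n")

parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
parser.add_argument("--outdir", default="outputs")
args = parser.parse_args(["@invokeai.init"])  # file contents spliced in place
assert args.outdir == "/data/images"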

View File

@@ -67,6 +67,7 @@ def install_requested_models(
scan_directory: Path = None,
external_models: List[str] = None,
scan_at_startup: bool = False,
convert_to_diffusers: bool = False,
precision: str = "float16",
purge_deleted: bool = False,
config_file_path: Path = None,
@@ -112,6 +113,7 @@ def install_requested_models(
try:
model_manager.heuristic_import(
path_url_or_repo,
convert=convert_to_diffusers,
commit_to_conf=config_file_path,
)
except KeyboardInterrupt:
@@ -120,7 +122,7 @@ def install_requested_models(
pass
if scan_at_startup and scan_directory.is_dir():
argument = "--autoconvert"
argument = "--autoconvert" if convert_to_diffusers else "--autoimport"
initfile = Path(Globals.root, Globals.initfile)
replacement = Path(Globals.root, f"{Globals.initfile}.new")
directory = str(scan_directory).replace("\\", "/")
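The tail of this hunk prepares an initfile rewrite: an --autoimport/--autoconvert line pointing at scan_directory is written into a fresh {initfile}.new, which then replaces the original. A sketch of that swap (the helper name and exact filtering are illustrative; the diff shows only the setup):

from pathlib import Path

def set_startup_argument(initfile: Path, argument: str, directory: str) -> None:
    # drop any stale autoimport/autoconvert line, append the requested one,
    # then atomically swap the rewritten file into place
    replacement = initfile.with_name(initfile.name + ".new")
    lines = initfile.read_text().splitlines() if initfile.exists() else []
    kept = [l for l in lines if not l.startswith(("--autoimport", "--autoconvert"))]
    kept.append(f"{argument} {directory}")
    replacement.write_text("\n".join(kept) + "\n")
    replacement.replace(initfile)  # os.replace semantics: atomic on POSIX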

View File

@@ -27,7 +27,6 @@ from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf
from pathlib import Path
import invokeai.backend.util.logging as logger
from .args import metadata_from_png
from .generator import infill_methods
from .globals import Globals, global_cache_dir
@@ -196,12 +195,12 @@ class Generate:
# device to Generate(). However the device was then ignored, so
# it wasn't actually doing anything. This logic could be reinstated.
self.device = torch.device(choose_torch_device())
logger.info(f"Using device_type {self.device.type}")
print(f">> Using device_type {self.device.type}")
if full_precision:
if self.precision != "auto":
raise ValueError("Remove --full_precision / -F if using --precision")
logger.warning("Please remove deprecated --full_precision / -F")
logger.warning("If auto config does not work you can use --precision=float32")
print("Please remove deprecated --full_precision / -F")
print("If auto config does not work you can use --precision=float32")
self.precision = "float32"
if self.precision == "auto":
self.precision = choose_precision(self.device)
@@ -209,13 +208,13 @@ class Generate:
if is_xformers_available():
if torch.cuda.is_available() and not Globals.disable_xformers:
logger.info("xformers memory-efficient attention is available and enabled")
print(">> xformers memory-efficient attention is available and enabled")
else:
logger.info(
"xformers memory-efficient attention is available but disabled"
print(
">> xformers memory-efficient attention is available but disabled"
)
else:
logger.info("xformers not installed")
print(">> xformers not installed")
# model caching system for fast switching
self.model_manager = ModelManager(
@@ -230,8 +229,8 @@ class Generate:
fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
model = model or fallback
if not self.model_manager.valid_model(model):
logger.warning(
f'"{model}" is not a known model name; falling back to {fallback}.'
print(
f'** "{model}" is not a known model name; falling back to {fallback}.'
)
model = None
self.model_name = model or fallback
@@ -247,10 +246,10 @@ class Generate:
# load safety checker if requested
if safety_checker:
logger.info("Initializing NSFW checker")
print(">> Initializing NSFW checker")
self.safety_checker = SafetyChecker(self.device)
else:
logger.info("NSFW checker is disabled")
print(">> NSFW checker is disabled")
def prompt2png(self, prompt, outdir, **kwargs):
"""
@@ -568,7 +567,7 @@ class Generate:
self.clear_cuda_cache()
if catch_interrupts:
logger.warning("Interrupted** Partial results will be returned.")
print("**Interrupted** Partial results will be returned.")
else:
raise KeyboardInterrupt
except RuntimeError:
@@ -576,11 +575,11 @@ class Generate:
self.clear_cuda_cache()
print(traceback.format_exc(), file=sys.stderr)
logger.info("Could not generate image.")
print(">> Could not generate image.")
toc = time.time()
logger.info("Usage stats:")
logger.info(f"{len(results)} image(s) generated in "+"%4.2fs" % (toc - tic))
print("\n>> Usage stats:")
print(f">> {len(results)} image(s) generated in", "%4.2fs" % (toc - tic))
self.print_cuda_stats()
return results
@@ -610,16 +609,16 @@ class Generate:
def print_cuda_stats(self):
if self._has_cuda():
self.gather_cuda_stats()
logger.info(
"Max VRAM used for this generation: "+
"%4.2fG. " % (self.max_memory_allocated / 1e9)+
"Current VRAM utilization: "+
"%4.2fG" % (self.memory_allocated / 1e9)
print(
">> Max VRAM used for this generation:",
"%4.2fG." % (self.max_memory_allocated / 1e9),
"Current VRAM utilization:",
"%4.2fG" % (self.memory_allocated / 1e9),
)
logger.info(
"Max VRAM used since script start: " +
"%4.2fG" % (self.session_peakmem / 1e9)
print(
">> Max VRAM used since script start: ",
"%4.2fG" % (self.session_peakmem / 1e9),
)
# this needs to be generalized to all sorts of postprocessors, which should be wrapped
@@ -648,7 +647,7 @@ class Generate:
seed = random.randrange(0, np.iinfo(np.uint32).max)
prompt = opt.prompt or args.prompt or ""
logger.info(f'using seed {seed} and prompt "{prompt}" for {image_path}')
print(f'>> using seed {seed} and prompt "{prompt}" for {image_path}')
# try to reuse the same filename prefix as the original file.
# we take everything up to the first period
@@ -697,8 +696,8 @@ class Generate:
try:
extend_instructions[direction] = int(pixels)
except ValueError:
logger.warning(
'invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
print(
'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"'
)
opt.seed = seed
@@ -721,8 +720,8 @@ class Generate:
# fetch the metadata from the image
generator = self.select_generator(embiggen=True)
opt.strength = opt.embiggen_strength or 0.40
logger.info(
f"Setting img2img strength to {opt.strength} for happy embiggening"
print(
f">> Setting img2img strength to {opt.strength} for happy embiggening"
)
generator.generate(
prompt,
@@ -749,12 +748,12 @@ class Generate:
return restorer.process(opt, args, image_callback=callback, prefix=prefix)
elif tool is None:
logger.warning(
"please provide at least one postprocessing option, such as -G or -U"
print(
"* please provide at least one postprocessing option, such as -G or -U"
)
return None
else:
logger.warning(f"postprocessing tool {tool} is not yet supported")
print(f"* postprocessing tool {tool} is not yet supported")
return None
def select_generator(
@@ -798,8 +797,8 @@ class Generate:
image = self._load_img(img)
if image.width < self.width and image.height < self.height:
logger.warning(
f"img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
print(
f">> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions"
)
# if image has a transparent area and no mask was provided, then try to generate mask
@@ -810,8 +809,8 @@ class Generate:
if (image.width * image.height) > (
self.width * self.height
) and self.size_matters:
logger.info(
"This input is larger than your defaults. If you run out of memory, please use a smaller image."
print(
">> This input is larger than your defaults. If you run out of memory, please use a smaller image."
)
self.size_matters = False
@@ -892,11 +891,11 @@ class Generate:
try:
model_data = cache.get_model(model_name)
except Exception as e:
logger.warning(f"model {model_name} could not be loaded: {str(e)}")
print(f"** model {model_name} could not be loaded: {str(e)}")
print(traceback.format_exc(), file=sys.stderr)
if previous_model_name is None:
raise e
logger.warning("trying to reload previous model")
print("** trying to reload previous model")
model_data = cache.get_model(previous_model_name) # load previous
if model_data is None:
raise e
@@ -963,15 +962,15 @@ class Generate:
if self.gfpgan is not None or self.codeformer is not None:
if facetool == "gfpgan":
if self.gfpgan is None:
logger.info(
"GFPGAN not found. Face restoration is disabled."
print(
">> GFPGAN not found. Face restoration is disabled."
)
else:
image = self.gfpgan.process(image, strength, seed)
if facetool == "codeformer":
if self.codeformer is None:
logger.info(
"CodeFormer not found. Face restoration is disabled."
print(
">> CodeFormer not found. Face restoration is disabled."
)
else:
cf_device = (
@@ -985,7 +984,7 @@ class Generate:
fidelity=codeformer_fidelity,
)
else:
logger.info("Face Restoration is disabled.")
print(">> Face Restoration is disabled.")
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
@@ -998,10 +997,10 @@ class Generate:
denoise_str=upscale_denoise_str,
)
else:
logger.info("ESRGAN is disabled. Image not upscaled.")
print(">> ESRGAN is disabled. Image not upscaled.")
except Exception as e:
logger.info(
f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
print(
f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
)
if image_callback is not None:
@@ -1067,17 +1066,17 @@ class Generate:
if self.sampler_name in scheduler_map:
sampler_class = scheduler_map[self.sampler_name]
msg = (
f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
f">> Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
)
self.sampler = sampler_class.from_config(self.model.scheduler.config)
else:
msg = (
f" Unsupported Sampler: {self.sampler_name} "+
f">> Unsupported Sampler: {self.sampler_name} "
f"Defaulting to {default}"
)
self.sampler = default
logger.info(msg)
print(msg)
if not hasattr(self.sampler, "uses_inpainting_model"):
# FIXME: terrible kludge!
@@ -1086,17 +1085,17 @@ class Generate:
def _load_img(self, img) -> Image:
if isinstance(img, Image.Image):
image = img
logger.info(f"using provided input image of size {image.width}x{image.height}")
print(f">> using provided input image of size {image.width}x{image.height}")
elif isinstance(img, str):
assert os.path.exists(img), f"{img}: File not found"
assert os.path.exists(img), f">> {img}: File not found"
image = Image.open(img)
logger.info(
f"loaded input image of size {image.width}x{image.height} from {img}"
print(
f">> loaded input image of size {image.width}x{image.height} from {img}"
)
else:
image = Image.open(img)
logger.info(f"loaded input image of size {image.width}x{image.height}")
print(f">> loaded input image of size {image.width}x{image.height}")
image = ImageOps.exif_transpose(image)
return image
@@ -1184,14 +1183,14 @@ class Generate:
def _transparency_check_and_warning(self, image, mask, force_outpaint=False):
if not mask:
logger.info(
"Initial image has transparent areas. Will inpaint in these regions."
print(
">> Initial image has transparent areas. Will inpaint in these regions."
)
if (not force_outpaint) and self._check_for_erasure(image):
logger.info(
"Colors underneath the transparent region seem to have been erased.\n" +
"Inpainting will be suboptimal. Please preserve the colors when making\n" +
"a transparency mask, or provide mask explicitly using --init_mask (-M)."
if (not force_outpaint) and self._check_for_erasure(image):
print(
">> WARNING: Colors underneath the transparent region seem to have been erased.\n",
">> Inpainting will be suboptimal. Please preserve the colors when making\n",
">> a transparency mask, or provide mask explicitly using --init_mask (-M).",
)
def _squeeze_image(self, image):
@@ -1202,11 +1201,11 @@ class Generate:
def _fit_image(self, image, max_dimensions):
w, h = max_dimensions
logger.info(f"image will be resized to fit inside a box {w}x{h} in size.")
print(f">> image will be resized to fit inside a box {w}x{h} in size.")
# note that InitImageResizer does the multiple of 64 truncation internally
image = InitImageResizer(image).resize(width=w, height=h)
logger.info(
f"after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
print(
f">> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}"
)
return image
@@ -1217,8 +1216,8 @@ class Generate:
) # resize to integer multiple of 64
if h != height or w != width:
if log:
logger.info(
f"Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
print(
f">> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}"
)
height = h
width = w
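The auto-resize in the final hunk exists because the UNet's downsampling stack needs dimensions divisible by 64; the adjustment simply truncates each side down to the nearest multiple:

def floor_to_multiple_of_64(width: int, height: int) -> tuple:
    # truncate each dimension down to the nearest multiple of 64
    return width - width % 64, height - height % 64

assert floor_to_multiple_of_64(1000, 600) == (960, 576)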

View File

@@ -21,11 +21,10 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed
from diffusers import DiffusionPipeline
from tqdm import trange
from typing import Callable, List, Iterator, Optional, Type
from typing import List, Iterator, Type
from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler
import invokeai.backend.util.logging as logger
from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
@@ -36,23 +35,23 @@ downsampling = 8
@dataclass
class InvokeAIGeneratorBasicParams:
seed: Optional[int]=None
seed: int=None
width: int=512
height: int=512
cfg_scale: float=7.5
cfg_scale: int=7.5
steps: int=20
ddim_eta: float=0.0
scheduler: str='ddim'
scheduler: int='ddim'
precision: str='float16'
perlin: float=0.0
threshold: float=0.0
threshold: int=0.0
seamless: bool=False
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
h_symmetry_time_pct: Optional[float]=None
v_symmetry_time_pct: Optional[float]=None
h_symmetry_time_pct: float=None
v_symmetry_time_pct: float=None
variation_amount: float = 0.0
with_variations: list=field(default_factory=list)
safety_checker: Optional[SafetyChecker]=None
safety_checker: SafetyChecker=None
@dataclass
class InvokeAIGeneratorOutput:
@@ -62,10 +61,10 @@ class InvokeAIGeneratorOutput:
and the model hash, as well as all the generate() parameters that went into
generating the image (in .params, also available as attributes)
'''
image: Image.Image
image: Image
seed: int
model_hash: str
attention_maps_images: List[Image.Image]
attention_maps_images: List[Image]
params: Namespace
# we are interposing a wrapper around the original Generator classes so that
@@ -93,8 +92,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def generate(self,
prompt: str='',
callback: Optional[Callable]=None,
step_callback: Optional[Callable]=None,
callback: callable=None,
step_callback: callable=None,
iterations: int=1,
**keyword_args,
)->Iterator[InvokeAIGeneratorOutput]:
@@ -155,7 +154,6 @@ class InvokeAIGenerator(metaclass=ABCMeta):
for i in iteration_count:
results = generator.generate(prompt,
conditioning=(uc, c, extra_conditioning_info),
step_callback=step_callback,
sampler=scheduler,
**generator_args,
)
@@ -207,10 +205,10 @@ class Txt2Img(InvokeAIGenerator):
# ------------------------------------
class Img2Img(InvokeAIGenerator):
def generate(self,
init_image: Image.Image | torch.FloatTensor,
init_image: Image | torch.FloatTensor,
strength: float=0.75,
**keyword_args
)->Iterator[InvokeAIGeneratorOutput]:
)->List[InvokeAIGeneratorOutput]:
return super().generate(init_image=init_image,
strength=strength,
**keyword_args
@@ -224,7 +222,7 @@ class Img2Img(InvokeAIGenerator):
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img):
def generate(self,
mask_image: Image.Image | torch.FloatTensor,
mask_image: Image | torch.FloatTensor,
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
@@ -237,7 +235,7 @@ class Inpaint(Img2Img):
inpaint_height=None,
inpaint_fill: tuple = (0x7F, 0x7F, 0x7F, 0xFF),
**keyword_args
)->Iterator[InvokeAIGeneratorOutput]:
)->List[InvokeAIGeneratorOutput]:
return super().generate(
mask_image=mask_image,
seam_size=seam_size,
@@ -264,7 +262,7 @@ class Embiggen(Txt2Img):
embiggen: list=None,
embiggen_tiles: list = None,
strength: float=0.75,
**kwargs)->Iterator[InvokeAIGeneratorOutput]:
**kwargs)->List[InvokeAIGeneratorOutput]:
return super().generate(embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
strength=strength,
@@ -373,7 +371,7 @@ class Generator:
try:
x_T = self.get_noise(width, height)
except:
logger.error("An error occurred while getting initial noise")
print("** An error occurred while getting initial noise **")
print(traceback.format_exc())
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
@@ -608,7 +606,7 @@ class Generator:
image = self.sample_to_image(sample)
dirname = os.path.dirname(filepath) or "."
if not os.path.exists(dirname):
logger.info(f"creating directory {dirname}")
print(f"** creating directory {dirname}")
os.makedirs(dirname, exist_ok=True)
image.save(filepath, "PNG")
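The InvokeAIGeneratorBasicParams hunk near the top of this file is a typing cleanup worth noting: fields that default to None should be declared Optional[...], and mutable defaults must go through field(default_factory=...). A trimmed-down restatement of the corrected shape:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class BasicParams:
    # `seed: int = None` runs, but is flagged by mypy/pyright; Optional is honest
    seed: Optional[int] = None
    cfg_scale: float = 7.5
    scheduler: str = "ddim"
    # a bare `= ['x', 'y']` default would be shared across all instances
    seamless_axes: List[str] = field(default_factory=lambda: ["x", "y"])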

View File

@@ -8,11 +8,10 @@ import torch
from PIL import Image
from tqdm import trange
import invokeai.backend.util.logging as logger
from .base import Generator
from .img2img import Img2Img
class Embiggen(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
@@ -73,22 +72,22 @@ class Embiggen(Generator):
embiggen = [1.0] # If not specified, assume no scaling
elif embiggen[0] < 0:
embiggen[0] = 1.0
logger.warning(
"Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
print(
">> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !"
)
if len(embiggen) < 2:
embiggen.append(0.75)
elif embiggen[1] > 1.0 or embiggen[1] < 0:
embiggen[1] = 0.75
logger.warning(
"Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
print(
">> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !"
)
if len(embiggen) < 3:
embiggen.append(0.25)
elif embiggen[2] < 0:
embiggen[2] = 0.25
logger.warning(
"Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
print(
">> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !"
)
# Convert tiles from their user-friendly count-from-one to count-from-zero, because we need to do modulo math
@@ -98,8 +97,8 @@ class Embiggen(Generator):
embiggen_tiles.sort()
if strength >= 0.5:
logger.warning(
f"Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
print(
f"* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45."
)
# Prep img2img generator, since we wrap over it
@@ -122,8 +121,8 @@ class Embiggen(Generator):
from ..restoration.realesrgan import ESRGAN
esrgan = ESRGAN()
logger.info(
f"ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
print(
f">> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}"
)
if embiggen[0] > 2:
initsuperimage = esrgan.process(
@@ -313,10 +312,10 @@ class Embiggen(Generator):
def make_image():
# Make main tiles -------------------------------------------------
if embiggen_tiles:
logger.info(f"Making {len(embiggen_tiles)} Embiggen tiles...")
print(f">> Making {len(embiggen_tiles)} Embiggen tiles...")
else:
logger.info(
f"Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
print(
f">> Making {(emb_tiles_x * emb_tiles_y)} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})..."
)
emb_tile_store = []
@@ -362,11 +361,11 @@ class Embiggen(Generator):
# newinitimage.save(newinitimagepath)
if embiggen_tiles:
logger.debug(
print(
f"Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)"
)
else:
logger.debug(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
print(f"Starting {tile + 1} of {(emb_tiles_x * emb_tiles_y)} tiles")
# create a torch tensor from an Image
newinitimage = np.array(newinitimage).astype(np.float32) / 255.0
@@ -548,8 +547,8 @@ class Embiggen(Generator):
# Layer tile onto final image
outputsuperimage.alpha_composite(intileimage, (left, top))
else:
logger.error(
"Could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
print(
"Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation."
)
# after internal loops and patching up return Embiggen image
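The clamping logic at the top of this file condenses to a small normalizer: embiggen is a [scale, esrgan_strength, overlap] triple, padded with defaults and clamped to sane ranges. A restatement of that logic as a standalone function (the name is illustrative):

from typing import List, Optional

def normalize_embiggen(embiggen: Optional[List[float]]) -> List[float]:
    embiggen = list(embiggen or [1.0])  # no scaling if unspecified
    if embiggen[0] < 0:
        embiggen[0] = 1.0               # scale must be non-negative
    if len(embiggen) < 2:
        embiggen.append(0.75)
    elif not 0 <= embiggen[1] <= 1.0:
        embiggen[1] = 0.75              # ESRGAN strength lives in [0, 1]
    if len(embiggen) < 3:
        embiggen.append(0.25)
    elif embiggen[2] < 0:
        embiggen[2] = 0.25              # overlap is a ratio in [0, 1] or pixels
    return embiggen

assert normalize_embiggen(None) == [1.0, 0.75, 0.25]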

View File

@@ -14,8 +14,6 @@ from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeli
from ..stable_diffusion.diffusers_pipeline import ConditioningData
from ..stable_diffusion.diffusers_pipeline import trim_to_multiple_of
import invokeai.backend.util.logging as logger
class Txt2Img2Img(Generator):
def __init__(self, model, precision):
super().__init__(model, precision)
@@ -79,8 +77,8 @@ class Txt2Img2Img(Generator):
# the message below is accurate.
init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
logger.info(
f"Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
print(
f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
)
# resizing

View File

@@ -5,9 +5,10 @@ wraps the actual patchmatch object. It respects the global
be suppressed or deferred
"""
import numpy as np
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
class PatchMatch:
"""
Thin class wrapper around the patchmatch function.
@@ -27,12 +28,12 @@ class PatchMatch:
from patchmatch import patch_match as pm
if pm.patchmatch_available:
logger.info("Patchmatch initialized")
print(">> Patchmatch initialized")
else:
logger.info("Patchmatch not loaded (nonfatal)")
print(">> Patchmatch not loaded (nonfatal)")
self.patch_match = pm
else:
logger.info("Patchmatch loading disabled")
print(">> Patchmatch loading disabled")
self.tried_load = True
@classmethod
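The class is a lazy singleton around an optional dependency: nothing is imported until the first query, and the outcome is remembered so the import is attempted only once. Its shape, condensed into a sketch (PatchMatchProxy is a stand-in name; error handling simplified):

class PatchMatchProxy:
    tried_load = False
    patch_match = None

    @classmethod
    def _load_patch_match(cls) -> None:
        if cls.tried_load:
            return  # only ever attempt the import once
        try:
            from patchmatch import patch_match as pm  # optional dependency
            cls.patch_match = pm if pm.patchmatch_available else None
        except ImportError:
            cls.patch_match = None
        cls.tried_load = True

    @classmethod
    def patchmatch_available(cls) -> bool:
        cls._load_patch_match()
        return cls.patch_match is not None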

View File

@@ -30,9 +30,9 @@ work fine.
import numpy as np
import torch
from PIL import Image, ImageOps
from torchvision import transforms
from transformers import AutoProcessor, CLIPSegForImageSegmentation
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import global_cache_dir
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
@@ -83,7 +83,7 @@ class Txt2Mask(object):
"""
def __init__(self, device="cpu", refined=False):
logger.info("Initializing clipseg model for text to mask inference")
print(">> Initializing clipseg model for text to mask inference")
# BUG: we are not doing anything with the device option at this time
self.device = device
@@ -101,6 +101,18 @@ class Txt2Mask(object):
provided image and returns a SegmentedGrayscale object in which the brighter
pixels indicate where the object is inferred to be.
"""
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
),
transforms.Resize(
(CLIPSEG_SIZE, CLIPSEG_SIZE)
), # must be multiple of 64...
]
)
if type(image) is str:
image = Image.open(image).convert("RGB")

View File

@@ -5,7 +5,5 @@ from .convert_ckpt_to_diffusers import (
convert_ckpt_to_diffusers,
load_pipeline_from_original_stable_diffusion_ckpt,
)
from .model_manager import ModelManager,SDModelComponent
from .model_manager import ModelManager

View File

@@ -25,7 +25,6 @@ from typing import Union
import torch
from safetensors.torch import load_file
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import global_cache_dir, global_config_dir
from .model_manager import ModelManager, SDLegacyType
@@ -373,32 +372,22 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
unet_key = "model.diffusion_model."
# at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100:
logger.debug(f"Checkpoint {path} has both EMA and non-EMA weights.")
print(f" | Checkpoint {path} has both EMA and non-EMA weights.")
if extract_ema:
logger.debug("Extracting EMA weights (usually better for inference)")
print(" | Extracting EMA weights (usually better for inference)")
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
flat_ema_key_alt = "model_ema." + "".join(key.split(".")[2:])
if flat_ema_key in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key
)
elif flat_ema_key_alt in checkpoint:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key_alt
)
else:
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
key
)
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
flat_ema_key
)
else:
logger.debug(
"Extracting only the non-EMA weights (usually better for fine-tuning)"
print(
" | Extracting only the non-EMA weights (usually better for fine-tuning)"
)
for key in keys:
if key.startswith("model.diffusion_model") and key in checkpoint:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
@@ -1037,15 +1026,6 @@ def convert_open_clip_checkpoint(checkpoint):
return text_model
def replace_checkpoint_vae(checkpoint, vae_path:str):
if vae_path.endswith(".safetensors"):
vae_ckpt = load_file(vae_path)
else:
vae_ckpt = torch.load(vae_path, map_location="cpu")
state_dict = vae_ckpt['state_dict'] if "state_dict" in vae_ckpt else vae_ckpt
for vae_key in state_dict:
new_key = f'first_stage_model.{vae_key}'
checkpoint[new_key] = state_dict[vae_key]
def load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path: str,
@@ -1058,10 +1038,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
extract_ema: bool = True,
upcast_attn: bool = False,
vae: AutoencoderKL = None,
vae_path: str = None,
precision: torch.dtype = torch.float32,
return_generator_pipeline: bool = False,
scan_needed:bool=True,
) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
"""
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -1089,8 +1067,6 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
running stable diffusion 2.1.
:param vae: A diffusers VAE to load into the pipeline.
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
"""
with warnings.catch_warnings():
@@ -1098,13 +1074,12 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
if Path(checkpoint_path).suffix == '.ckpt':
if scan_needed:
ModelManager.scan_model(checkpoint_path,checkpoint_path)
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = load_file(checkpoint_path)
checkpoint = (
torch.load(checkpoint_path)
if Path(checkpoint_path).suffix == ".ckpt"
else load_file(checkpoint_path)
)
cache_dir = global_cache_dir("hub")
pipeline_class = (
StableDiffusionGeneratorPipeline
@@ -1116,7 +1091,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
if "global_step" in checkpoint:
global_step = checkpoint["global_step"]
else:
logger.debug("global_step key not found in model")
print(" | global_step key not found in model")
global_step = None
# sometimes there is a state_dict key and sometimes not
@@ -1227,19 +1202,9 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
unet.load_state_dict(converted_unet_checkpoint)
# If a replacement VAE path was specified, we'll incorporate that into
# the checkpoint model and then convert it
if vae_path:
logger.debug(f"Converting VAE {vae_path}")
replace_checkpoint_vae(checkpoint,vae_path)
# otherwise we use the original VAE, provided that
# an externally loaded diffusers VAE was not passed
elif not vae:
logger.debug("Using checkpoint model's original VAE")
if vae:
logger.debug("Using replacement diffusers VAE")
else: # convert the original or replacement VAE
# Convert the VAE model, or use the one passed
if not vae:
print(" | Using checkpoint model's original VAE")
vae_config = create_vae_diffusers_config(
original_config, image_size=image_size
)
@@ -1249,6 +1214,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
else:
print(" | Using external VAE specified in config")
# Convert the text model.
model_type = pipeline_type
@@ -1265,10 +1232,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
cache_dir=cache_dir,
)
pipe = pipeline_class(
vae=vae.to(precision),
text_encoder=text_model.to(precision),
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet.to(precision),
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
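The EMA hunk earlier in this file deserves a gloss: checkpoints that carry exponential-moving-average weights store them under flattened model_ema.* keys, and extraction prefers those copies because they usually infer better. A condensed restatement of that loop (a readability sketch, not a drop-in replacement for the converter):

def extract_unet_state_dict(checkpoint: dict, extract_ema: bool) -> dict:
    unet_key = "model.diffusion_model."
    out = {}
    for key in list(checkpoint):
        if not key.startswith(unet_key):
            continue
        src = key
        if extract_ema:
            # "model.diffusion_model.a.b" -> "model_ema.diffusion_modelab"
            flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
            if flat_ema_key in checkpoint:
                src = flat_ema_key  # prefer the EMA copy when present
        out[key[len(unet_key):]] = checkpoint.pop(src)
    return out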

View File

@@ -1,4 +1,4 @@
"""enum
"""
Manage a cache of Stable Diffusion model files for fast switching.
They are moved between GPU and CPU as necessary. If CPU memory falls
below a preset minimum, the least recently used model will be
@@ -15,22 +15,17 @@ import sys
import textwrap
import time
import warnings
from enum import Enum, auto
from enum import Enum
from pathlib import Path
from shutil import move, rmtree
from typing import Any, Optional, Union, Callable, types
from typing import Any, Optional, Union
import safetensors
import safetensors.torch
import torch
import transformers
import invokeai.backend.util.logging as logger
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
SchedulerMixin,
logging as dlogging,
)
from diffusers import AutoencoderKL
from diffusers import logging as dlogging
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
@@ -38,61 +33,40 @@ from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir
from transformers import (
CLIPTextModel,
CLIPTokenizer,
CLIPFeatureExtractor,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from ..stable_diffusion import (
StableDiffusionGeneratorPipeline,
)
from ..util import CUDA_DEVICE, ask_user, download_with_resume
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
V1 = auto()
V1_INPAINT = auto()
V2 = auto()
V2_e = auto()
V2_v = auto()
UNKNOWN = auto()
V1 = 1
V1_INPAINT = 2
V2 = 3
V2_e = 4
V2_v = 5
UNKNOWN = 99
class SDModelComponent(Enum):
vae="vae"
text_encoder="text_encoder"
tokenizer="tokenizer"
unet="unet"
scheduler="scheduler"
safety_checker="safety_checker"
feature_extractor="feature_extractor"
DEFAULT_MAX_MODELS = 2
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
"vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse",
}
class ModelManager(object):
"""
'''
Model manager handles loading, caching, importing, deleting, converting, and editing models.
"""
logger: types.ModuleType = logger
'''
def __init__(
self,
config: OmegaConf | Path,
device_type: torch.device = CUDA_DEVICE,
precision: str = "float16",
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload=False,
embedding_path: Path = None,
logger: types.ModuleType = logger,
self,
config: OmegaConf|Path,
device_type: torch.device = CUDA_DEVICE,
precision: str = "float16",
max_loaded_models=DEFAULT_MAX_MODELS,
sequential_offload=False,
embedding_path: Path=None,
):
"""
Initialize with the path to the models.yaml config file or
an initialized OmegaConf dictionary. Optional parameters
are the torch device type, precision, max_loaded_models,
and sequential_offload boolean. Note that the default device
and sequential_offload boolean. Note that the default device
type and precision are set up for a CUDA system running at half precision.
"""
# prevent nasty-looking CLIP log message
@@ -108,7 +82,6 @@ class ModelManager(object):
self.current_model = None
self.sequential_offload = sequential_offload
self.embedding_path = embedding_path
self.logger = logger
def valid_model(self, model_name: str) -> bool:
"""
@@ -117,28 +90,18 @@ class ModelManager(object):
"""
return model_name in self.config
def get_model(self, model_name: str = None) -> dict:
"""Given a model named identified in models.yaml, return a dict
containing the model object and some of its key features. If
in RAM will load into GPU VRAM. If on disk, will load from
there.
The dict has the following keys:
'model': The StableDiffusionGeneratorPipeline object
'model_name': The name of the model in models.yaml
'width': The width of images trained by this model
'height': The height of images trained by this model
'hash': A unique hash of this model's files on disk.
def get_model(self, model_name: str=None)->dict:
"""
Given a model name identified in models.yaml, return
the model object. If in RAM will load into GPU VRAM.
If on disk, will load from there.
"""
if not model_name:
return (
self.get_model(self.current_model)
if self.current_model
else self.get_model(self.default_model())
)
return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model())
if not self.valid_model(model_name):
self.logger.error(
f'"{model_name}" is not a known model name. Please check your models.yaml file'
print(
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
)
return self.current_model
@@ -149,7 +112,7 @@ class ModelManager(object):
if model_name in self.models:
requested_model = self.models[model_name]["model"]
self.logger.info(f"Retrieving model {model_name} from system RAM cache")
print(f">> Retrieving model {model_name} from system RAM cache")
requested_model.ready()
width = self.models[model_name]["width"]
height = self.models[model_name]["height"]
@@ -175,81 +138,6 @@ class ModelManager(object):
"hash": hash,
}
def get_model_vae(self, model_name: str=None)->AutoencoderKL:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned VAE as an
AutoencoderKL object. If no model name is provided, return the
vae from the model currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.vae)
def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned CLIPTokenizer. If no
model name is provided, return the tokenizer from the model
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.tokenizer)
def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned UNet2DConditionModel. If no model
name is provided, return the UNet from the model
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.unet)
def get_model_text_encoder(self, model_name: str=None)->CLIPTextModel:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned CLIPTextModel. If no
model name is provided, return the text encoder from the model
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.text_encoder)
def get_model_feature_extractor(self, model_name: str=None)->CLIPFeatureExtractor:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned CLIPFeatureExtractor. If no
model name is provided, return the feature extractor from the model
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.feature_extractor)
def get_model_scheduler(self, model_name: str=None)->SchedulerMixin:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned scheduler. If no
model name is provided, return the scheduler from the model
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.scheduler)
def _get_sub_model(
self,
model_name: str=None,
model_part: SDModelComponent=SDModelComponent.vae,
) -> Union[
AutoencoderKL,
CLIPTokenizer,
CLIPFeatureExtractor,
UNet2DConditionModel,
CLIPTextModel,
StableDiffusionSafetyChecker,
]:
"""Given a model name identified in models.yaml, and the part of the
model you wish to retrieve, return that part. Parts are in an Enum
class named SDModelComponent, and consist of:
SDModelComponent.vae
SDModelComponent.text_encoder
SDModelComponent.tokenizer
SDModelComponent.unet
SDModelComponent.scheduler
SDModelComponent.safety_checker
SDModelComponent.feature_extractor
"""
model_dict = self.get_model(model_name)
model = model_dict["model"]
return getattr(model, model_part.value)
def default_model(self) -> str | None:
"""
Returns the name of the default model, or None
@@ -384,7 +272,7 @@ class ModelManager(object):
"""
omega = self.config
if model_name not in omega:
self.logger.error(f"Unknown model {model_name}")
print(f"** Unknown model {model_name}")
return
# save these for use in deletion later
conf = omega[model_name]
@@ -397,13 +285,13 @@ class ModelManager(object):
self.stack.remove(model_name)
if delete_files:
if weights:
self.logger.info(f"Deleting file {weights}")
print(f"** deleting file {weights}")
Path(weights).unlink(missing_ok=True)
elif path:
self.logger.info(f"Deleting directory {path}")
print(f"** deleting directory {path}")
rmtree(path, ignore_errors=True)
elif repo_id:
self.logger.info(f"Deleting the cached model directory for {repo_id}")
print(f"** deleting the cached model directory for {repo_id}")
self._delete_model_from_cache(repo_id)
def add_model(
@@ -444,7 +332,7 @@ class ModelManager(object):
def _load_model(self, model_name: str):
"""Load and initialize the model from configuration variables passed at object creation time"""
if model_name not in self.config:
self.logger.error(
print(
f'"{model_name}" is not a known model name. Please check your models.yaml file'
)
return
@@ -462,7 +350,7 @@ class ModelManager(object):
model_format = mconfig.get("format", "ckpt")
if model_format == "ckpt":
weights = mconfig.weights
self.logger.info(f"Loading {model_name} from {weights}")
print(f">> Loading {model_name} from {weights}")
model, width, height, model_hash = self._load_ckpt_model(
model_name, mconfig
)
@@ -474,19 +362,16 @@ class ModelManager(object):
raise NotImplementedError(
f"Unknown model format {model_name}: {model_format}"
)
self._add_embeddings_to_model(model)
# usage statistics
toc = time.time()
self.logger.info("Model loaded in " + "%4.2fs" % (toc - tic))
print(">> Model loaded in", "%4.2fs" % (toc - tic))
if self._has_cuda():
self.logger.info(
"Max VRAM used to load the model: "+
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)
)
self.logger.info(
"Current VRAM usage: "+
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
print(
">> Max VRAM used to load the model:",
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9),
"\n>> Current VRAM usage:"
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
)
return model, width, height, model_hash
@@ -494,11 +379,11 @@ class ModelManager(object):
name_or_path = self.model_name_or_path(mconfig)
using_fp16 = self.precision == "float16"
self.logger.info(f"Loading diffusers model from {name_or_path}")
print(f">> Loading diffusers model from {name_or_path}")
if using_fp16:
self.logger.debug("Using faster float16 precision")
print(" | Using faster float16 precision")
else:
self.logger.debug("Using more accurate float32 precision")
print(" | Using more accurate float32 precision")
# TODO: scan weights maybe?
pipeline_args: dict[str, Any] = dict(
@@ -530,8 +415,8 @@ class ModelManager(object):
if str(e).startswith("fp16 is not a valid"):
pass
else:
self.logger.error(
f"An unexpected error occurred while downloading the model: {e})"
print(
f"** An unexpected error occurred while downloading the model: {e})"
)
if pipeline:
break
@@ -549,7 +434,9 @@ class ModelManager(object):
# square images???
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
height = width
self.logger.debug(f"Default image dimensions = {width} x {height}")
print(f" | Default image dimensions = {width} x {height}")
self._add_embeddings_to_model(pipeline)
return pipeline, width, height, model_hash
@@ -566,29 +453,19 @@ class ModelManager(object):
weights = os.path.normpath(os.path.join(Globals.root, weights))
# Convert to diffusers and return a diffusers pipeline
self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
print(f">> Converting legacy checkpoint {model_name} into a diffusers model...")
from . import load_pipeline_from_original_stable_diffusion_ckpt
try:
if self.list_models()[self.current_model]["status"] == "active":
self.offload_model(self.current_model)
except Exception:
pass
vae_path = None
if vae:
vae_path = (
vae
if os.path.isabs(vae)
else os.path.normpath(os.path.join(Globals.root, vae))
)
self.offload_model(self.current_model)
if vae_config := self._choose_diffusers_vae(model_name):
vae = self._load_vae(vae_config)
if self._has_cuda():
torch.cuda.empty_cache()
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=weights,
original_config_file=config,
vae_path=vae_path,
vae=vae,
return_generator_pipeline=True,
precision=torch.float16 if self.precision == "float16" else torch.float32,
)
@@ -596,6 +473,7 @@ class ModelManager(object):
pipeline.enable_offload_submodels(self.device)
else:
pipeline.to(self.device)
return (
pipeline,
width,
@@ -631,42 +509,44 @@ class ModelManager(object):
if model_name not in self.models:
return
self.logger.info(f"Offloading {model_name} to CPU")
print(f">> Offloading {model_name} to CPU")
model = self.models[model_name]["model"]
model.offload_all()
self.current_model = None
gc.collect()
if self._has_cuda():
torch.cuda.empty_cache()
@classmethod
def scan_model(self, model_name, checkpoint):
"""
Apply picklescanner to the indicated checkpoint and issue a warning
and option to exit if an infected file is identified.
"""
# scan model
self.logger.debug(f"Scanning Model: {model_name}")
print(f">> Scanning Model: {model_name}")
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
if scan_result.infected_files == 1:
self.logger.critical(f"Issues Found In Model: {scan_result.issues_count}")
self.logger.critical("The model you are trying to load seems to be infected.")
self.logger.critical("For your safety, InvokeAI will not load this model.")
self.logger.critical("Please use checkpoints from trusted sources.")
self.logger.critical("Exiting InvokeAI")
print(f"\n### Issues Found In Model: {scan_result.issues_count}")
print(
"### WARNING: The model you are trying to load seems to be infected."
)
print("### For your safety, InvokeAI will not load this model.")
print("### Please use checkpoints from trusted sources.")
print("### Exiting InvokeAI")
sys.exit()
else:
self.logger.warning("InvokeAI was unable to scan the model you are using.")
print(
"\n### WARNING: InvokeAI was unable to scan the model you are using."
)
model_safe_check_fail = ask_user(
"Do you want to to continue loading the model?", ["y", "n"]
)
if model_safe_check_fail.lower() != "y":
self.logger.critical("Exiting InvokeAI")
print("### Exiting InvokeAI")
sys.exit()
else:
self.logger.debug("Model scanned ok")
print(">> Model scanned ok")
def import_diffuser_model(
self,
@@ -688,7 +568,9 @@ class ModelManager(object):
models.yaml file.
"""
model_name = model_name or Path(repo_or_path).stem
model_description = description or f"Imported diffusers model {model_name}"
model_description = (
description or f"Imported diffusers model {model_name}"
)
new_config = dict(
description=model_description,
vae=vae,
@@ -717,7 +599,7 @@ class ModelManager(object):
SDLegacyType.V2_v (V2 using 'v_prediction' prediction type)
SDLegacyType.UNKNOWN
"""
global_step = checkpoint.get("global_step")
global_step = checkpoint.get('global_step')
state_dict = checkpoint.get("state_dict") or checkpoint
try:
@@ -743,15 +625,16 @@ class ModelManager(object):
return SDLegacyType.UNKNOWN
def heuristic_import(
self,
path_url_or_repo: str,
model_name: str = None,
description: str = None,
model_config_file: Path = None,
commit_to_conf: Path = None,
config_file_callback: Callable[[Path], Path] = None,
self,
path_url_or_repo: str,
convert: bool = True,
model_name: str = None,
description: str = None,
model_config_file: Path = None,
commit_to_conf: Path = None,
) -> str:
"""Accept a string which could be:
"""
Accept a string which could be:
- a HF diffusers repo_id
- a URL pointing to a legacy .ckpt or .safetensors file
- a local path pointing to a legacy .ckpt or .safetensors file
@@ -765,42 +648,40 @@ class ModelManager(object):
The model_name and/or description can be provided. If not, they will
be generated automatically.
If convert is true, legacy models will be converted to diffusers
before importing.
If commit_to_conf is provided, the newly loaded model will be written
to the `models.yaml` file at the indicated path. Otherwise, the changes
will only remain in memory.
The routine will do its best to figure out the config file
needed to convert legacy checkpoint file, but if it can't it
will call the config_file_callback routine, if provided. The
callback accepts a single argument, the Path to the checkpoint
file, and returns a Path to the config file to use.
The (potentially derived) name of the model is returned on
success, or None on failure. When multiple models are added
from a directory, only the last imported one is returned.
The (potentially derived) name of the model is returned on success, or None
on failure. When multiple models are added from a directory, only the last
imported one is returned.
"""
model_path: Path = None
thing = path_url_or_repo # to save typing
self.logger.info(f"Probing {thing} for import")
print(f">> Probing {thing} for import")
if thing.startswith(("http:", "https:", "ftp:")):
self.logger.info(f"{thing} appears to be a URL")
print(f" | {thing} appears to be a URL")
model_path = self._resolve_path(
thing, "models/ldm/stable-diffusion-v1"
) # _resolve_path does a download if needed
elif Path(thing).is_file() and thing.endswith((".ckpt", ".safetensors")):
if Path(thing).stem in ["model", "diffusion_pytorch_model"]:
self.logger.debug(f"{Path(thing).name} appears to be part of a diffusers model. Skipping import")
print(
f" | {Path(thing).name} appears to be part of a diffusers model. Skipping import"
)
return
else:
self.logger.debug(f"{thing} appears to be a checkpoint file on disk")
print(f" | {thing} appears to be a checkpoint file on disk")
model_path = self._resolve_path(thing, "models/ldm/stable-diffusion-v1")
elif Path(thing).is_dir() and Path(thing, "model_index.json").exists():
self.logger.debug(f"{thing} appears to be a diffusers file on disk")
print(f" | {thing} appears to be a diffusers file on disk")
model_name = self.import_diffuser_model(
thing,
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
@@ -811,30 +692,34 @@ class ModelManager(object):
elif Path(thing).is_dir():
if (Path(thing) / "model_index.json").exists():
self.logger.debug(f"{thing} appears to be a diffusers model.")
print(f" | {thing} appears to be a diffusers model.")
model_name = self.import_diffuser_model(
thing, commit_to_conf=commit_to_conf
)
else:
self.logger.debug(f"{thing} appears to be a directory. Will scan for models to import")
print(
f" |{thing} appears to be a directory. Will scan for models to import"
)
for m in list(Path(thing).rglob("*.ckpt")) + list(
Path(thing).rglob("*.safetensors")
):
if model_name := self.heuristic_import(
str(m), commit_to_conf=commit_to_conf
str(m), convert, commit_to_conf=commit_to_conf
):
self.logger.info(f"{model_name} successfully imported")
print(f" >> {model_name} successfully imported")
return model_name
elif re.match(r"^[\w.+-]+/[\w.+-]+$", thing):
self.logger.debug(f"{thing} appears to be a HuggingFace diffusers repo_id")
print(f" | {thing} appears to be a HuggingFace diffusers repo_id")
model_name = self.import_diffuser_model(
thing, commit_to_conf=commit_to_conf
)
pipeline, _, _, _ = self._load_diffusers_model(self.config[model_name])
return model_name
else:
self.logger.warning(f"{thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id")
print(
f"** {thing}: Unknown thing. Please provide a URL, file path, directory or HuggingFace repo_id"
)
# Model_path is set in the event of a legacy checkpoint file.
# If not set, we're all done
@@ -842,72 +727,55 @@ class ModelManager(object):
return
if model_path.stem in self.config: # already imported
self.logger.debug("Already imported. Skipping")
print(" | Already imported. Skipping")
return model_path.stem
# another round of heuristics to guess the correct config file.
checkpoint = None
if model_path.suffix in [".ckpt", ".pt"]:
self.scan_model(model_path, model_path)
checkpoint = torch.load(model_path)
else:
checkpoint = safetensors.torch.load_file(model_path)
checkpoint = (
torch.load(model_path)
if model_path.suffix == ".ckpt"
else safetensors.torch.load_file(model_path)
)
# additional probing needed if no config file provided
if model_config_file is None:
# look for a like-named .yaml file in same directory
if model_path.with_suffix(".yaml").exists():
model_config_file = model_path.with_suffix(".yaml")
self.logger.debug(f"Using config file {model_config_file.name}")
model_type = self.probe_model_type(checkpoint)
if model_type == SDLegacyType.V1:
print(" | SD-v1 model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
)
elif model_type == SDLegacyType.V1_INPAINT:
print(" | SD-v1 inpainting model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
)
elif model_type == SDLegacyType.V2_v:
print(
" | SD-v2-v model detected; model will be converted to diffusers format"
)
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
)
convert = True
elif model_type == SDLegacyType.V2_e:
print(
" | SD-v2-e model detected; model will be converted to diffusers format"
)
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
)
convert = True
elif model_type == SDLegacyType.V2:
print(
f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
)
return
else:
model_type = self.probe_model_type(checkpoint)
if model_type == SDLegacyType.V1:
self.logger.debug("SD-v1 model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
)
elif model_type == SDLegacyType.V1_INPAINT:
self.logger.debug("SD-v1 inpainting model detected")
model_config_file = Path(
Globals.root,
"configs/stable-diffusion/v1-inpainting-inference.yaml",
)
elif model_type == SDLegacyType.V2_v:
self.logger.debug("SD-v2-v model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
)
elif model_type == SDLegacyType.V2_e:
self.logger.debug("SD-v2-e model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
)
elif model_type == SDLegacyType.V2:
self.logger.warning(
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
)
return
else:
self.logger.warning(
f"{thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
)
return
if not model_config_file and config_file_callback:
model_config_file = config_file_callback(model_path)
# despite our best efforts, we could not find a model config file, so give up
if not model_config_file:
return
# look for a custom vae, a like-named file ending with .vae in the same directory
vae_path = None
for suffix in ["pt", "ckpt", "safetensors"]:
if (model_path.with_suffix(f".vae.{suffix}")).exists():
vae_path = model_path.with_suffix(f".vae.{suffix}")
self.logger.debug(f"Using VAE file {vae_path.name}")
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
print(
f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
)
return
diffuser_path = Path(
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
@@ -915,13 +783,11 @@ class ModelManager(object):
model_name = self.convert_and_import(
model_path,
diffusers_path=diffuser_path,
vae=vae,
vae_path=str(vae_path),
vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
model_name=model_name,
model_description=description,
original_config_file=model_config_file,
commit_to_conf=commit_to_conf,
scan_needed=False,
)
return model_name
@@ -931,11 +797,9 @@ class ModelManager(object):
diffusers_path: Path,
model_name=None,
model_description=None,
vae: dict = None,
vae_path: Path = None,
vae=None,
original_config_file: Path = None,
commit_to_conf: Path = None,
scan_needed: bool = True,
) -> str:
"""
Convert a legacy ckpt weights file to diffuser model and import
@@ -952,34 +816,29 @@ class ModelManager(object):
from . import convert_ckpt_to_diffusers
if diffusers_path.exists():
self.logger.error(
f"The path {str(diffusers_path)} already exists. Please move or remove it and try again."
print(
f"ERROR: The path {str(diffusers_path)} already exists. Please move or remove it and try again."
)
return
model_name = model_name or diffusers_path.name
model_description = model_description or f"Converted version of {model_name}"
self.logger.debug(f"Converting {model_name} to diffusers (30-60s)")
model_description = model_description or f"Optimized version of {model_name}"
print(f">> Optimizing {model_name} (30-60s)")
try:
# By passing the specified VAE to the conversion function, the autoencoder
# will be built into the model rather than tacked on afterward via the config file
vae_model = None
if vae:
vae_model = self._load_vae(vae)
vae_path = None
vae_model = self._load_vae(vae) if vae else None
convert_ckpt_to_diffusers(
ckpt_path,
diffusers_path,
extract_ema=True,
original_config_file=original_config_file,
vae=vae_model,
vae_path=vae_path,
scan_needed=scan_needed,
)
self.logger.debug(
f"Success. Converted model is now located at {str(diffusers_path)}"
print(
f" | Success. Optimized model is now located at {str(diffusers_path)}"
)
self.logger.debug(f"Writing new config file entry for {model_name}")
print(f" | Writing new config file entry for {model_name}")
new_config = dict(
path=str(diffusers_path),
description=model_description,
@@ -990,17 +849,17 @@ class ModelManager(object):
self.add_model(model_name, new_config, True)
if commit_to_conf:
self.commit(commit_to_conf)
self.logger.debug("Conversion succeeded")
print(">> Conversion succeeded")
except Exception as e:
self.logger.warning(f"Conversion failed: {str(e)}")
self.logger.warning(
"If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
print(f"** Conversion failed: {str(e)}")
print(
"** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
)
return model_name
def search_models(self, search_folder):
self.logger.info(f"Finding Models In: {search_folder}")
print(f">> Finding Models In: {search_folder}")
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
models_folder_safetensors = Path(search_folder).glob("**/*.safetensors")
@@ -1020,12 +879,42 @@ class ModelManager(object):
return search_folder, found_models
def _choose_diffusers_vae(
self, model_name: str, vae: str = None
) -> Union[dict, str]:
# In the event that the original entry is using a custom ckpt VAE, we try to
# map that VAE onto a diffuser VAE using a hard-coded dictionary.
# I would prefer to do this differently: We load the ckpt model into memory, swap the
# VAE in memory, and then pass that to convert_ckpt_to_diffuser() so that the swapped
# VAE is built into the model. However, when I tried this I got obscure key errors.
if vae:
return vae
if model_name in self.config and (
vae_ckpt_path := self.model_info(model_name).get("vae", None)
):
vae_basename = Path(vae_ckpt_path).stem
diffusers_vae = None
if diffusers_vae := VAE_TO_REPO_ID.get(vae_basename, None):
print(
f">> {vae_basename} VAE corresponds to known {diffusers_vae} diffusers version"
)
vae = {"repo_id": diffusers_vae}
else:
print(
f'** Custom VAE "{vae_basename}" found, but corresponding diffusers model unknown'
)
print(
'** Using "stabilityai/sd-vae-ft-mse"; If this isn\'t right, please edit the model config'
)
vae = {"repo_id": "stabilityai/sd-vae-ft-mse"}
return vae
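A rough sketch of the hard-coded lookup that _choose_diffusers_vae consults. The dictionary entry shown is an illustrative assumption; only the "stabilityai/sd-vae-ft-mse" fallback is taken from the code above:

    from pathlib import Path

    # Hypothetical contents; the real VAE_TO_REPO_ID is defined elsewhere in this module.
    VAE_TO_REPO_ID = {
        "vae-ft-mse-840000-ema-pruned": "stabilityai/sd-vae-ft-mse",
    }

    def choose_vae(vae_ckpt_path: str) -> dict:
        basename = Path(vae_ckpt_path).stem
        repo_id = VAE_TO_REPO_ID.get(basename, "stabilityai/sd-vae-ft-mse")
        return {"repo_id": repo_id}

    print(choose_vae("models/vae-ft-mse-840000-ema-pruned.ckpt"))  # {'repo_id': 'stabilityai/sd-vae-ft-mse'}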
def _make_cache_room(self) -> None:
num_loaded_models = len(self.models)
if num_loaded_models >= self.max_loaded_models:
least_recent_model = self._pop_oldest_model()
self.logger.info(
f"Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
print(
f">> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
)
if least_recent_model is not None:
del self.models[least_recent_model]
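The purge above is a least-recently-used eviction. A toy sketch, assuming the recency bookkeeping that the real class keeps elsewhere (an OrderedDict stands in for it here):

    from collections import OrderedDict

    MAX_LOADED_MODELS = 2
    models = OrderedDict()  # insertion order stands in for recency order here

    def make_cache_room() -> None:
        while len(models) >= MAX_LOADED_MODELS:
            oldest, _ = models.popitem(last=False)  # pop the least recently used entry
            print(f"Cache limit (max={MAX_LOADED_MODELS}) reached. Purging {oldest}")

    for name in ("sd-1.5", "sd-2.1", "inpainting-1.5"):
        make_cache_room()
        models[name] = object()  # stand-in for a loaded pipeline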
@@ -1033,8 +922,8 @@ class ModelManager(object):
def print_vram_usage(self) -> None:
if self._has_cuda:
self.logger.info(
"Current VRAM usage:"+
print(
">> Current VRAM usage: ",
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
)
@@ -1082,16 +971,16 @@ class ModelManager(object):
legacy_locations = [
Path(
models_dir,
"CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker",
"CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker"
),
Path(models_dir, "bert-base-uncased/models--bert-base-uncased"),
Path(
models_dir,
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14"
),
]
legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
legacy_locations.extend(list(global_cache_dir("diffusers").glob('*')))
legacy_layout = False
for model in legacy_locations:
legacy_layout = legacy_layout or model.exists()
@@ -1109,7 +998,7 @@ class ModelManager(object):
>> make adjustments, please press ctrl-C now to abort and relaunch InvokeAI when you are ready.
>> Otherwise press <enter> to continue."""
)
input("continue> ")
input('continue> ')
# transformer files get moved into the hub directory
if cls._is_huggingface_hub_directory_present():
@@ -1123,10 +1012,10 @@ class ModelManager(object):
dest = hub / model.stem
if dest.exists() and not source.exists():
continue
cls.logger.info(f"{source} => {dest}")
print(f"** {source} => {dest}")
if source.exists():
if dest.is_symlink():
logger.warning(f"Found symlink at {dest.name}. Not migrating.")
print(f"** Found symlink at {dest.name}. Not migrating.")
elif dest.exists():
if source.is_dir():
rmtree(source)
@@ -1143,7 +1032,7 @@ class ModelManager(object):
]
for d in empty:
os.rmdir(d)
cls.logger.info("Migration is done. Continuing...")
print("** Migration is done. Continuing...")
def _resolve_path(
self, source: Union[str, Path], dest_directory: str
@@ -1186,22 +1075,22 @@ class ModelManager(object):
def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
if self.embedding_path is not None:
self.logger.info(f"Loading embeddings from {self.embedding_path}")
print(f">> Loading embeddings from {self.embedding_path}")
for root, _, files in os.walk(self.embedding_path):
for name in files:
ti_path = os.path.join(root, name)
model.textual_inversion_manager.load_textual_inversion(
ti_path, defer_injecting_tokens=True
)
self.logger.info(
f'Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
print(
f'>> Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
)
def _has_cuda(self) -> bool:
return self.device.type == "cuda"
def _diffuser_sha256(
self, name_or_path: Union[str, Path], chunksize=16777216
self, name_or_path: Union[str, Path], chunksize=4096
) -> Union[str, bytes]:
path = None
if isinstance(name_or_path, Path):
@@ -1216,7 +1105,7 @@ class ModelManager(object):
with open(hashpath) as f:
hash = f.read()
return hash
self.logger.debug("Calculating sha256 hash of model files")
print(" | Calculating sha256 hash of model files")
tic = time.time()
sha = hashlib.sha256()
count = 0
@@ -1228,7 +1117,7 @@ class ModelManager(object):
sha.update(chunk)
hash = sha.hexdigest()
toc = time.time()
self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
print(f" | sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
with open(hashpath, "w") as f:
f.write(hash)
return hash
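The scheme above (hash every file under the model directory in fixed-size chunks, then memoize the digest on disk) can be sketched on its own; the cache file name here is an assumption:

    import hashlib
    from pathlib import Path

    def dir_sha256(model_dir: Path, chunksize: int = 16_777_216) -> str:
        sha = hashlib.sha256()
        for file in sorted(model_dir.rglob("*")):
            if file.is_file():
                with open(file, "rb") as f:
                    while chunk := f.read(chunksize):
                        sha.update(chunk)
        return sha.hexdigest()

    def cached_dir_sha256(model_dir: Path) -> str:
        hashpath = model_dir / "checksum.sha256"  # hypothetical cache location
        if hashpath.exists():
            return hashpath.read_text()
        digest = dir_sha256(model_dir)
        hashpath.write_text(digest)
        return digest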
@@ -1246,13 +1135,13 @@ class ModelManager(object):
hash = f.read()
return hash
self.logger.debug("Calculating sha256 hash of weights file")
print(" | Calculating sha256 hash of weights file")
tic = time.time()
sha = hashlib.sha256()
sha.update(data)
hash = sha.hexdigest()
toc = time.time()
self.logger.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic))
print(f">> sha256 = {hash}", "(%4.2fs)" % (toc - tic))
with open(hashpath, "w") as f:
f.write(hash)
@@ -1273,12 +1162,12 @@ class ModelManager(object):
local_files_only=not Globals.internet_available,
)
self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
print(f" | Loading diffusers VAE from {name_or_path}")
if using_fp16:
vae_args.update(torch_dtype=torch.float16)
fp_args_list = [{"revision": "fp16"}, {}]
else:
self.logger.debug("Using more accurate float32 precision")
print(" | Using more accurate float32 precision")
fp_args_list = [{}]
vae = None
@@ -1302,12 +1191,12 @@ class ModelManager(object):
break
if not vae and deferred_error:
self.logger.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}")
print(f"** Could not load VAE {name_or_path}: {str(deferred_error)}")
return vae
@classmethod
def _delete_model_from_cache(cls,repo_id):
@staticmethod
def _delete_model_from_cache(repo_id):
cache_info = scan_cache_dir(global_cache_dir("hub"))
# I'm sure there is a way to do this with comprehensions
@@ -1318,8 +1207,8 @@ class ModelManager(object):
for revision in repo.revisions:
hashes_to_delete.add(revision.commit_hash)
strategy = cache_info.delete_revisions(*hashes_to_delete)
cls.logger.warning(
f"Deletion of this model is expected to free {strategy.expected_freed_size_str}"
print(
f"** deletion of this model is expected to free {strategy.expected_freed_size_str}"
)
strategy.execute()

View File

@@ -18,7 +18,6 @@ from compel.prompt_parser import (
PromptParser,
)
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from ..stable_diffusion import InvokeAIDiffuserComponent
@@ -163,8 +162,8 @@ def log_tokenization(
negative_prompt: Union[Blend, FlattenedPrompt],
tokenizer,
):
logger.info(f"[TOKENLOG] Parsed Prompt: {positive_prompt}")
logger.info(f"[TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
log_tokenization_for_prompt_object(
@@ -238,12 +237,12 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
usedTokens += 1
if usedTokens > 0:
logger.info(f'[TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
logger.debug(f"{tokenized}\x1b[0m")
print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
print(f"{tokenized}\x1b[0m")
if discarded != "":
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
logger.debug(f"{discarded}\x1b[0m")
print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
print(f"{discarded}\x1b[0m")
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
@@ -296,8 +295,8 @@ def split_weighted_subprompts(text, skip_normalize=False) -> list:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
if weight_sum == 0:
logger.warning(
"Subprompt weights add up to zero. Discarding and using even weights instead."
print(
"* Warning: Subprompt weights add up to zero. Discarding and using even weights instead."
)
equal_weight = 1 / max(len(parsed_prompts), 1)
return [(x[0], equal_weight) for x in parsed_prompts]

View File

@@ -1,5 +1,3 @@
import invokeai.backend.util.logging as logger
class Restoration:
def __init__(self) -> None:
pass
@@ -10,17 +8,17 @@ class Restoration:
# Load GFPGAN
gfpgan = self.load_gfpgan(gfpgan_model_path)
if gfpgan.gfpgan_model_exists:
logger.info("GFPGAN Initialized")
print(">> GFPGAN Initialized")
else:
logger.info("GFPGAN Disabled")
print(">> GFPGAN Disabled")
gfpgan = None
# Load CodeFormer
codeformer = self.load_codeformer()
if codeformer.codeformer_model_exists:
logger.info("CodeFormer Initialized")
print(">> CodeFormer Initialized")
else:
logger.info("CodeFormer Disabled")
print(">> CodeFormer Disabled")
codeformer = None
return gfpgan, codeformer
@@ -41,5 +39,5 @@ class Restoration:
from .realesrgan import ESRGAN
esrgan = ESRGAN(esrgan_bg_tile)
logger.info("ESRGAN Initialized")
print(">> ESRGAN Initialized")
return esrgan

View File

@@ -5,7 +5,6 @@ import warnings
import numpy as np
import torch
import invokeai.backend.util.logging as logger
from ..globals import Globals
pretrained_model_url = (
@@ -24,12 +23,12 @@ class CodeFormerRestoration:
self.codeformer_model_exists = os.path.isfile(self.model_path)
if not self.codeformer_model_exists:
logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
print("## NOT FOUND: CodeFormer model not found at " + self.model_path)
sys.path.append(os.path.abspath(codeformer_dir))
def process(self, image, strength, device, seed=None, fidelity=0.75):
if seed is not None:
logger.info(f"CodeFormer - Restoring Faces for image seed:{seed}")
print(f">> CodeFormer - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
@@ -98,7 +97,7 @@ class CodeFormerRestoration:
del output
torch.cuda.empty_cache()
except RuntimeError as error:
logger.error(f"Failed inference for CodeFormer: {error}.")
print(f"\tFailed inference for CodeFormer: {error}.")
restored_face = cropped_face
restored_face = restored_face.astype("uint8")

View File

@@ -6,9 +6,9 @@ import numpy as np
import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
class GFPGAN:
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
if not os.path.isabs(gfpgan_model_path):
@@ -19,7 +19,7 @@ class GFPGAN:
self.gfpgan_model_exists = os.path.isfile(self.model_path)
if not self.gfpgan_model_exists:
logger.error("NOT FOUND: GFPGAN model not found at " + self.model_path)
print("## NOT FOUND: GFPGAN model not found at " + self.model_path)
return None
def model_exists(self):
@@ -27,7 +27,7 @@ class GFPGAN:
def process(self, image, strength: float, seed: str = None):
if seed is not None:
logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")
print(f">> GFPGAN - Restoring Faces for image seed:{seed}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
@@ -47,14 +47,14 @@ class GFPGAN:
except Exception:
import traceback
logger.error("Error loading GFPGAN:", file=sys.stderr)
print(">> Error loading GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
os.chdir(cwd)
if self.gfpgan is None:
logger.warning("WARNING: GFPGAN not initialized.")
logger.warning(
f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
print(f">> WARNING: GFPGAN not initialized.")
print(
f">> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
)
image = image.convert("RGB")

View File

@@ -1,7 +1,7 @@
import math
from PIL import Image
import invokeai.backend.util.logging as logger
class Outcrop(object):
def __init__(
@@ -82,7 +82,7 @@ class Outcrop(object):
pixels = extents[direction]
# round pixels up to the nearest 64
pixels = math.ceil(pixels / 64) * 64
logger.info(f"extending image {direction}ward by {pixels} pixels")
print(f">> extending image {direction}ward by {pixels} pixels")
image = self._rotate(image, direction)
image = self._extend(image, pixels)
image = self._rotate(image, direction, reverse=True)

View File

@@ -6,13 +6,18 @@ import torch
from PIL import Image
from PIL.Image import Image as ImageType
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
class ESRGAN:
def __init__(self, bg_tile_size=400) -> None:
self.bg_tile_size = bg_tile_size
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
def load_esrgan_bg_upsampler(self, denoise_str):
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
@@ -69,16 +74,16 @@ class ESRGAN:
import sys
import traceback
logger.error("Error loading Real-ESRGAN:")
print(">> Error loading Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if upsampler_scale == 0:
logger.warning("Real-ESRGAN: Invalid scaling option. Image not upscaled.")
print(">> Real-ESRGAN: Invalid scaling option. Image not upscaled.")
return image
if seed is not None:
logger.info(
f"Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
print(
f">> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
)
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
image = image.convert("RGB")

View File

@@ -14,7 +14,6 @@ from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets
import invokeai.backend.util.logging as logger
from .globals import global_cache_dir
from .util import CPU_DEVICE
@@ -41,8 +40,8 @@ class SafetyChecker(object):
cache_dir=safety_model_path,
)
except Exception:
logger.error(
"An error was encountered while installing the safety checker:"
print(
"** An error was encountered while installing the safety checker:"
)
print(traceback.format_exc())
@@ -66,8 +65,8 @@ class SafetyChecker(object):
)
self.safety_checker.to(CPU_DEVICE) # offload
if has_nsfw_concept[0]:
logger.warning(
"An image with potential non-safe content has been detected. A blurred image will be returned."
print(
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
)
return self.blur(image)
else:

View File

@@ -6,6 +6,7 @@ The interface is through the Concepts() object.
"""
import os
import re
import traceback
from typing import Callable
from urllib import error as ul_error
from urllib import request
@@ -14,10 +15,10 @@ from huggingface_hub import (
HfApi,
HfFolder,
ModelFilter,
ModelSearchArguments,
hf_hub_url,
)
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
@@ -58,7 +59,7 @@ class HuggingFaceConceptsLibrary(object):
self.concept_list.extend(list(local_concepts_to_add))
return self.concept_list
return self.concept_list
elif Globals.internet_available is True:
else:
try:
models = self.hf_api.list_models(
filter=ModelFilter(model_name="sd-concepts-library/")
@@ -67,15 +68,13 @@ class HuggingFaceConceptsLibrary(object):
# when init, add all in dir. when not init, add only concepts added between init and now
self.concept_list.extend(list(local_concepts_to_add))
except Exception as e:
logger.warning(
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
print(
f" ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
)
logger.warning(
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
print(
" ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
)
return self.concept_list
else:
return self.concept_list
def get_concept_model_path(self, concept_name: str) -> str:
"""
@@ -84,8 +83,8 @@ class HuggingFaceConceptsLibrary(object):
be downloaded.
"""
if not concept_name in self.list_concepts():
logger.warning(
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
print(
f"This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
)
return None
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
@@ -222,7 +221,7 @@ class HuggingFaceConceptsLibrary(object):
if chunk == 0:
bytes += total
logger.info(f"Downloading {repo_id}...", end="")
print(f">> Downloading {repo_id}...", end="")
try:
for file in (
"README.md",
@@ -236,22 +235,22 @@ class HuggingFaceConceptsLibrary(object):
)
except ul_error.HTTPError as e:
if e.code == 404:
logger.warning(
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
print(
f"This concept is not known to the Hugging Face library. Generation will continue without the concept."
)
else:
logger.warning(
print(
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
)
os.rmdir(dest)
return False
except ul_error.URLError as e:
logger.error(
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
print(
f"ERROR: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
)
os.rmdir(dest)
return False
logger.info("...{:.2f}Kb".format(bytes / 1024))
print("...{:.2f}Kb".format(bytes / 1024))
return succeeded
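A loose sketch of the fetch step, assuming each concept lives in an sd-concepts-library/<name> repo (consistent with the model filter above) and using hf_hub_url from the imports to build the file URL. The concept name is invented:

    from pathlib import Path
    from urllib import request

    from huggingface_hub import hf_hub_url

    def fetch_concept_file(concept_name: str, filename: str, dest_dir: Path) -> Path:
        url = hf_hub_url(f"sd-concepts-library/{concept_name}", filename)
        dest_dir.mkdir(parents=True, exist_ok=True)
        dest = dest_dir / filename
        request.urlretrieve(url, dest)  # raises urllib.error.HTTPError on 404, as handled above
        return dest

    # fetch_concept_file("midjourney-style", "learned_embeds.bin", Path("concepts/midjourney-style"))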
def _concept_id(self, concept_name: str) -> str:

View File

@@ -445,15 +445,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
@property
def _submodels(self) -> Sequence[torch.nn.Module]:
module_names, _, _ = self.extract_init_dict(dict(self.config))
submodels = []
for name in module_names.keys():
if hasattr(self, name):
value = getattr(self, name)
else:
value = getattr(self.config, name)
if isinstance(value, torch.nn.Module):
submodels.append(value)
return submodels
values = [getattr(self, name) for name in module_names.keys()]
return [m for m in values if isinstance(m, torch.nn.Module)]
def image_from_embeddings(
self,
@@ -551,7 +544,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
yield PipelineIntermediateState(
run_id=run_id,
step=-1,
timestep=self.scheduler.config.num_train_timesteps,
timestep=self.scheduler.num_train_timesteps,
latents=latents,
)
@@ -922,7 +915,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
@property
def channels(self) -> int:
"""Compatible with DiffusionWrapper"""
return self.unet.config.in_channels
return self.unet.in_channels
def decode_latents(self, latents):
# Explicit call to get the vae loaded, since `decode` isn't the forward method.

View File

@@ -10,12 +10,13 @@ import diffusers
import psutil
import torch
from compel.cross_attention_control import Arguments
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.cross_attention import AttnProcessor
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from torch import nn
import invokeai.backend.util.logging as logger
from ...util import torch_dtype
class CrossAttentionType(enum.Enum):
SELF = 1
TOKENS = 2
@@ -187,7 +188,7 @@ class Context:
class InvokeAICrossAttentionMixin:
"""
Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
and dynamic slicing strategy selection.
"""
@@ -208,7 +209,7 @@ class InvokeAICrossAttentionMixin:
Set custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
`module` is the current Attention module for which the callback is being invoked.
`module` is the current CrossAttention module for which the callback is being invoked.
`suggested_attention_slice` is the default-calculated attention slice
`dim` is -1 if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
@@ -344,11 +345,11 @@ class InvokeAICrossAttentionMixin:
def restore_default_cross_attention(
model,
is_running_diffusers: bool,
restore_attention_processor: Optional[AttentionProcessor] = None,
restore_attention_processor: Optional[AttnProcessor] = None,
):
if is_running_diffusers:
unet = model
unet.set_attn_processor(restore_attention_processor or AttnProcessor())
unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
else:
remove_attention_function(model)
@@ -407,9 +408,12 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
def get_cross_attention_modules(
model, which: CrossAttentionType
) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
from ldm.modules.attention import CrossAttention # avoid circular import
cross_attention_class: type = (
InvokeAIDiffusersCrossAttention
if isinstance(model, UNet2DConditionModel)
else CrossAttention
)
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
attention_module_tuples = [
@@ -421,13 +425,13 @@ def get_cross_attention_modules(
expected_count = 16
if cross_attention_modules_in_model_count != expected_count:
# non-fatal error but .swap() won't work.
logger.error(
print(
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
+ f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
+ f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
+ "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+ "work properly until it is fixed."
+ f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+ f"work properly until it is fixed."
)
return attention_module_tuples
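The expected-count check can be reproduced with a plain walk over named modules; the attn1/attn2 suffix convention follows the code above, and the count of 16 is the assumption about SD UNets that the warning describes:

    def find_attention_modules(model, which_attn: str):
        # which_attn is "attn1" for self-attention, "attn2" for tokens (cross) attention
        return [
            (name, module)
            for name, module in model.named_modules()
            if name.endswith(which_attn)
        ]

    # modules = find_attention_modules(unet, "attn1")
    # assert len(modules) == 16  # the expected_count the error message mentions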
@@ -546,7 +550,7 @@ def get_mem_free_total(device):
class InvokeAIDiffusersCrossAttention(
diffusers.models.attention.Attention, InvokeAICrossAttentionMixin
diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin
):
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -568,8 +572,8 @@ class InvokeAIDiffusersCrossAttention(
"""
# base implementation
class AttnProcessor:
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
class CrossAttnProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
@@ -597,9 +601,9 @@ class AttnProcessor:
from dataclasses import dataclass, field
import torch
from diffusers.models.attention_processor import (
Attention,
AttnProcessor,
from diffusers.models.cross_attention import (
CrossAttention,
CrossAttnProcessor,
SlicedAttnProcessor,
)
@@ -649,7 +653,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
def __call__(
self,
attn: Attention,
attn: CrossAttention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,

View File

@@ -5,10 +5,9 @@ from typing import Any, Callable, Dict, Optional, Union
import numpy as np
import torch
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.cross_attention import AttnProcessor
from typing_extensions import TypeAlias
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from .cross_attention_control import (
@@ -102,7 +101,7 @@ class InvokeAIDiffuserComponent:
def override_cross_attention(
self, conditioning: ExtraConditioningInfo, step_count: int
) -> Dict[str, AttentionProcessor]:
) -> Dict[str, AttnProcessor]:
"""
setup cross attention .swap control. for diffusers this replaces the attention processor, so
the previous attention processor is returned so that the caller can restore it later.
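The contract described here is the usual save/replace/restore pattern. A hedged sketch: set_attn_processor appears elsewhere in this diff, while the attn_processors accessor is assumed from the diffusers UNet API:

    from contextlib import contextmanager

    @contextmanager
    def swapped_attention(unet, new_processor):
        previous = unet.attn_processors  # mapping of module name -> processor
        unet.set_attn_processor(new_processor)
        try:
            yield previous
        finally:
            unet.set_attn_processor(previous)  # restore, as restore_default_cross_attention does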
@@ -119,7 +118,7 @@ class InvokeAIDiffuserComponent:
)
def restore_default_cross_attention(
self, restore_attention_processor: Optional["AttentionProcessor"] = None
self, restore_attention_processor: Optional["AttnProcessor"] = None
):
self.conditioning = None
self.cross_attention_control_context = None
@@ -263,7 +262,7 @@ class InvokeAIDiffuserComponent:
# TODO remove when compvis codepath support is dropped
if step_index is None and sigma is None:
raise ValueError(
"Either step_index or sigma is required when doing cross attention control, but both are None."
f"Either step_index or sigma is required when doing cross attention control, but both are None."
)
percent_through = self.estimate_percent_through(step_index, sigma)
return percent_through
@@ -467,14 +466,10 @@ class InvokeAIDiffuserComponent:
outside = torch.count_nonzero(
(latents < -current_threshold) | (latents > current_threshold)
)
logger.info(
f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
)
logger.debug(
f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
)
logger.debug(
f"{outside / latents.numel() * 100:.2f}% values outside threshold"
print(
f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
f" | {outside / latents.numel() * 100:.2f}% values outside threshold"
)
if maxval < current_threshold and minval > -current_threshold:
@@ -501,11 +496,9 @@ class InvokeAIDiffuserComponent:
)
if self.debug_thresholding:
logger.debug(
f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
)
logger.debug(
f"{num_altered / latents.numel() * 100:.2f}% values altered"
print(
f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
f" | {num_altered / latents.numel() * 100:.2f}% values altered"
)
return latents
@@ -606,6 +599,7 @@ class InvokeAIDiffuserComponent:
)
# below is fugly omg
num_actual_conditionings = len(c_or_weighted_c_list)
conditionings = [uc] + [c for c, weight in weighted_cond_list]
weights = [1] + [weight for c, weight in weighted_cond_list]
chunk_count = ceil(len(conditionings) / 2)

View File

@@ -10,7 +10,7 @@ from torchvision.utils import make_grid
# import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
import invokeai.backend.util.logging as logger
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
@@ -191,7 +191,7 @@ def mkdirs(paths):
def mkdir_and_rename(path):
if os.path.exists(path):
new_name = path + "_archived_" + get_timestamp()
logger.error("Path already exists. Rename it to [{:s}]".format(new_name))
print("Path already exists. Rename it to [{:s}]".format(new_name))
os.replace(path, new_name)
os.makedirs(path)

View File

@@ -1,27 +1,16 @@
import os
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union, List
from typing import Optional, Union
import safetensors.torch
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from picklescan.scanner import scan_file_path
from transformers import CLIPTextModel, CLIPTokenizer
import invokeai.backend.util.logging as logger
from .concepts_lib import HuggingFaceConceptsLibrary
@dataclass
class EmbeddingInfo:
name: str
embedding: torch.Tensor
num_vectors_per_token: int
token_dim: int
trained_steps: int = None
trained_model_name: str = None
trained_model_checksum: str = None
@dataclass
class TextualInversion:
@@ -60,12 +49,12 @@ class TextualInversionManager(BaseTextualInversionManager):
or self.has_textual_inversion_for_trigger_string(concept_name)
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
): # in case a token with literal angle brackets encountered
logger.info(f"Loaded local embedding for trigger {concept_name}")
print(f">> Loaded local embedding for trigger {concept_name}")
continue
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
if not bin_file:
continue
logger.info(f"Loaded remote embedding for trigger {concept_name}")
print(f">> Loaded remote embedding for trigger {concept_name}")
self.load_textual_inversion(bin_file)
self.hf_concepts_library.concepts_loaded[concept_name] = True
@@ -83,46 +72,66 @@ class TextualInversionManager(BaseTextualInversionManager):
if str(ckpt_path).endswith(".DS_Store"):
return
embedding_list = self._parse_embedding(str(ckpt_path))
for embedding_info in embedding_list:
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
logger.warning(
f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
try:
scan_result = scan_file_path(str(ckpt_path))
if scan_result.infected_files == 1:
print(
f"\n### Security Issues Found in Model: {scan_result.issues_count}"
)
continue
# Resolve the situation in which an earlier embedding has claimed the same
# trigger string. We replace the trigger with '<source_file>', as we used to.
trigger_str = embedding_info.name
sourcefile = (
f"{ckpt_path.parent.name}/{ckpt_path.name}"
if ckpt_path.name == "learned_embeds.bin"
else ckpt_path.name
print("### For your safety, InvokeAI will not load this embed.")
return
except Exception:
print(
f"### {ckpt_path.parents[0].name}/{ckpt_path.name} is damaged or corrupt."
)
return
if trigger_str in self.trigger_to_sourcefile:
replacement_trigger_str = (
f"<{ckpt_path.parent.name}>"
if ckpt_path.name == "learned_embeds.bin"
else f"<{ckpt_path.stem}>"
)
logger.info(
f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
)
trigger_str = replacement_trigger_str
embedding_info = self._parse_embedding(str(ckpt_path))
try:
self._add_textual_inversion(
trigger_str,
embedding_info.embedding,
defer_injecting_tokens=defer_injecting_tokens,
)
# remember which source file claims this trigger
self.trigger_to_sourcefile[trigger_str] = sourcefile
if embedding_info is None:
# We've already put out an error message about the bad embedding in _parse_embedding, so just return.
return
elif (
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
!= embedding_info["token_dim"]
):
print(
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
)
return
except ValueError as e:
logger.debug(f'Ignoring incompatible embedding {embedding_info.name}')
logger.debug(f"The error was {str(e)}")
# Resolve the situation in which an earlier embedding has claimed the same
# trigger string. We replace the trigger with '<source_file>', as we used to.
trigger_str = embedding_info["name"]
sourcefile = (
f"{ckpt_path.parent.name}/{ckpt_path.name}"
if ckpt_path.name == "learned_embeds.bin"
else ckpt_path.name
)
if trigger_str in self.trigger_to_sourcefile:
replacement_trigger_str = (
f"<{ckpt_path.parent.name}>"
if ckpt_path.name == "learned_embeds.bin"
else f"<{ckpt_path.stem}>"
)
print(
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
)
trigger_str = replacement_trigger_str
try:
self._add_textual_inversion(
trigger_str,
embedding_info["embedding"],
defer_injecting_tokens=defer_injecting_tokens,
)
# remember which source file claims this trigger
self.trigger_to_sourcefile[trigger_str] = sourcefile
except ValueError as e:
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
print(f" | The error was {str(e)}")
def _add_textual_inversion(
self, trigger_str, embedding, defer_injecting_tokens=False
@@ -134,8 +143,8 @@ class TextualInversionManager(BaseTextualInversionManager):
:return: The token id for the added embedding, either existing or newly-added.
"""
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
logger.warning(
f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
print(
f"** TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
)
return
if not self.full_precision:
@@ -156,11 +165,11 @@ class TextualInversionManager(BaseTextualInversionManager):
except ValueError as e:
if str(e).startswith("Warning"):
logger.warning(f"{str(e)}")
print(f">> {str(e)}")
else:
traceback.print_exc()
logger.error(
f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
print(
f"** TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
)
raise
@@ -220,16 +229,16 @@ class TextualInversionManager(BaseTextualInversionManager):
for ti in self.textual_inversions:
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
if ti.embedding_vector_length > 1:
logger.info(
f"Preparing tokens for textual inversion {ti.trigger_string}..."
print(
f">> Preparing tokens for textual inversion {ti.trigger_string}..."
)
try:
self._inject_tokens_and_assign_embeddings(ti)
except ValueError as e:
logger.debug(
f"Ignoring incompatible embedding trigger {ti.trigger_string}"
print(
f" | Ignoring incompatible embedding trigger {ti.trigger_string}"
)
logger.debug(f"The error was {str(e)}")
print(f" | The error was {str(e)}")
continue
injected_token_ids.append(ti.trigger_token_id)
injected_token_ids.extend(ti.pad_token_ids)
@@ -300,130 +309,111 @@ class TextualInversionManager(BaseTextualInversionManager):
return token_id
def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
suffix = Path(embedding_file).suffix
try:
if suffix in [".pt",".ckpt",".bin"]:
scan_result = scan_file_path(embedding_file)
if scan_result.infected_files > 0:
logger.critical(
f"Security Issues Found in Model: {scan_result.issues_count}"
)
logger.critical("For your safety, InvokeAI will not load this embed.")
return list()
ckpt = torch.load(embedding_file,map_location="cpu")
else:
ckpt = safetensors.torch.load_file(embedding_file)
except Exception as e:
logger.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
return list()
# try to figure out what kind of embedding file it is and parse accordingly
keys = list(ckpt.keys())
if all(x in keys for x in ['string_to_token','string_to_param','name','step']):
return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt
elif all(x in keys for x in ['string_to_token','string_to_param']):
return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt
elif 'emb_params' in keys:
return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors
def _parse_embedding(self, embedding_file: str):
file_type = embedding_file.split(".")[-1]
if file_type == "pt":
return self._parse_embedding_pt(embedding_file)
elif file_type == "bin":
return self._parse_embedding_bin(embedding_file)
else:
return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file
print(f"** Notice: unrecognized embedding file format: {embedding_file}")
return None
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
basename = Path(file_path).stem
logger.debug(f'Loading v1 embedding file: {basename}')
def _parse_embedding_pt(self, embedding_file):
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
embedding_info = {}
embeddings = list()
token_counter = -1
for token,embedding in embedding_ckpt["string_to_param"].items():
if token_counter < 0:
trigger = embedding_ckpt["name"]
elif token_counter == 0:
trigger = f'<{basename}>'
else:
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
token_counter += 1
embedding_info = EmbeddingInfo(
name = trigger,
embedding = embedding,
num_vectors_per_token = embedding.size()[0],
token_dim = embedding.size()[1],
trained_steps = embedding_ckpt["step"],
trained_model_name = embedding_ckpt["sd_checkpoint_name"],
trained_model_checksum = embedding_ckpt["sd_checkpoint"]
)
embeddings.append(embedding_info)
return embeddings
# Check if valid embedding file
if "string_to_token" and "string_to_param" in embedding_ckpt:
# Catch variants that do not have the expected keys or values.
try:
embedding_info["name"] = embedding_ckpt["name"] or os.path.basename(
os.path.splitext(embedding_file)[0]
)
def _parse_embedding_v2 (
self, embedding_ckpt: dict, file_path: str
) -> List[EmbeddingInfo]:
# Check num of embeddings and warn user only the first will be used
embedding_info["num_of_embeddings"] = len(
embedding_ckpt["string_to_token"]
)
if embedding_info["num_of_embeddings"] > 1:
print(">> More than 1 embedding found. Will use the first one")
embedding = list(embedding_ckpt["string_to_param"].values())[0]
except (AttributeError, KeyError):
return self._handle_broken_pt_variants(embedding_ckpt, embedding_file)
embedding_info["embedding"] = embedding
embedding_info["num_vectors_per_token"] = embedding.size()[0]
embedding_info["token_dim"] = embedding.size()[1]
try:
embedding_info["trained_steps"] = embedding_ckpt["step"]
embedding_info["trained_model_name"] = embedding_ckpt[
"sd_checkpoint_name"
]
embedding_info["trained_model_checksum"] = embedding_ckpt[
"sd_checkpoint"
]
except AttributeError:
print(">> No Training Details Found. Passing ...")
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
# They are actually .bin files
elif len(embedding_ckpt.keys()) == 1:
embedding_info = self._parse_embedding_bin(embedding_file)
else:
print(">> Invalid embedding format")
embedding_info = None
return embedding_info
def _parse_embedding_bin(self, embedding_file):
embedding_ckpt = torch.load(embedding_file, map_location="cpu")
embedding_info = {}
if len(embedding_ckpt.keys()) == 0:
print(">> Invalid concepts file")
embedding_info = None
else:
for token in list(embedding_ckpt.keys()):
embedding_info["name"] = (
token
or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
)
embedding_info["embedding"] = embedding_ckpt[token]
embedding_info[
"num_vectors_per_token"
] = 1 # All Concepts seem to default to 1
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
return embedding_info
def _handle_broken_pt_variants(
self, embedding_ckpt: dict, embedding_file: str
) -> dict:
"""
This handles embedding .pt file variant #2.
This handles the broken .pt file variants. We only know of one at present.
"""
basename = Path(file_path).stem
logger.debug(f'Loading v2 embedding file: {basename}')
embeddings = list()
embedding_info = {}
if isinstance(
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
):
token_counter = 0
for token,embedding in embedding_ckpt["string_to_param"].items():
trigger = token if token != '*' \
else f'<{basename}>' if token_counter == 0 \
else f'<{basename}-{int(token_counter:=token_counter+1)}>'
embedding_info = EmbeddingInfo(
name = trigger,
embedding = embedding,
num_vectors_per_token = embedding.size()[0],
token_dim = embedding.size()[1],
for token in list(embedding_ckpt["string_to_token"].keys()):
embedding_info["name"] = (
token
if token != "*"
else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
)
embeddings.append(embedding_info)
embedding_info["embedding"] = embedding_ckpt[
"string_to_param"
].state_dict()[token]
embedding_info["num_vectors_per_token"] = embedding_info[
"embedding"
].shape[0]
embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
else:
logger.warning(f"{basename}: Unrecognized embedding format")
print(">> Invalid embedding format")
embedding_info = None
return embeddings
def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
"""
Parse 'version 3' of the .pt textual inversion embedding files.
"""
basename = Path(file_path).stem
logger.debug(f'Loading v3 embedding file: {basename}')
embedding = embedding_ckpt['emb_params']
embedding_info = EmbeddingInfo(
name = f'<{basename}>',
embedding = embedding,
num_vectors_per_token = embedding.size()[0],
token_dim = embedding.size()[1],
)
return [embedding_info]
def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
"""
Parse 'version 4' of the textual inversion embedding files. This one
is usually associated with .bin files trained by HuggingFace diffusers.
"""
basename = Path(filepath).stem
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
logger.debug(f'Loading v4 embedding file: {short_path}')
embeddings = list()
if len(embedding_ckpt.keys()) == 0:
logger.warning(f"Invalid embeddings file: {short_path}")
else:
for token,embedding in embedding_ckpt.items():
embedding_info = EmbeddingInfo(
name = token or f"<{basename}>",
embedding = embedding,
num_vectors_per_token = 1, # All Concepts seem to default to 1
token_dim = embedding.size()[0],
)
embeddings.append(embedding_info)
return embeddings
return embedding_info

View File

@@ -1,109 +0,0 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
"""invokeai.util.logging
Logging class for InvokeAI that produces console messages that follow
the conventions established in InvokeAI 1.X through 2.X.
One way to use it:
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.getLogger(__name__)
logger.critical('this is critical')
logger.error('this is an error')
logger.warning('this is a warning')
logger.info('this is info')
logger.debug('this is debugging')
Console messages:
### this is critical
*** this is an error ***
** this is a warning
>> this is info
| this is debugging
Another way:
import invokeai.backend.util.logging as ialog
ialog.debug('this is a debugging message')
"""
import logging
# module level functions
def debug(msg, *args, **kwargs):
InvokeAILogger.getLogger().debug(msg, *args, **kwargs)
def info(msg, *args, **kwargs):
InvokeAILogger.getLogger().info(msg, *args, **kwargs)
def warning(msg, *args, **kwargs):
InvokeAILogger.getLogger().warning(msg, *args, **kwargs)
def error(msg, *args, **kwargs):
InvokeAILogger.getLogger().error(msg, *args, **kwargs)
def critical(msg, *args, **kwargs):
InvokeAILogger.getLogger().critical(msg, *args, **kwargs)
def log(level, msg, *args, **kwargs):
InvokeAILogger.getLogger().log(level, msg, *args, **kwargs)
def disable(level=logging.CRITICAL):
logging.disable(level)
def basicConfig(**kwargs):
logging.basicConfig(**kwargs)
def getLogger(name: str=None)->logging.Logger:
return InvokeAILogger.getLogger(name)
class InvokeAILogFormatter(logging.Formatter):
'''
Repurposed from:
https://stackoverflow.com/questions/14844970/modifying-logging-message-format-based-on-message-logging-level-in-python3
'''
crit_fmt = "### %(msg)s"
err_fmt = "*** %(msg)s"
warn_fmt = "** %(msg)s"
info_fmt = ">> %(msg)s"
dbg_fmt = " | %(msg)s"
def __init__(self):
super().__init__(fmt="%(levelno)d: %(msg)s", datefmt=None, style='%')
def format(self, record):
# Remember the format used when the logging module
# was installed (in the event that this formatter is
# used with the vanilla logging module).
format_orig = self._style._fmt
if record.levelno == logging.DEBUG:
self._style._fmt = InvokeAILogFormatter.dbg_fmt
if record.levelno == logging.INFO:
self._style._fmt = InvokeAILogFormatter.info_fmt
if record.levelno == logging.WARNING:
self._style._fmt = InvokeAILogFormatter.warn_fmt
if record.levelno == logging.ERROR:
self._style._fmt = InvokeAILogFormatter.err_fmt
if record.levelno == logging.CRITICAL:
self._style._fmt = InvokeAILogFormatter.crit_fmt
# parent class does the work
result = super().format(record)
self._style._fmt = format_orig
return result
class InvokeAILogger(object):
loggers = dict()
@classmethod
def getLogger(cls, name: str = 'invokeai') -> logging.Logger:
if name not in cls.loggers:
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
fmt = InvokeAILogFormatter()
ch.setFormatter(fmt)
logger.addHandler(ch)
cls.loggers[name] = logger
return cls.loggers[name]
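Although this diff removes the module, its behavior is simple to demonstrate. A sketch that attaches the formatter above to an ordinary stdlib logger:

    import logging

    handler = logging.StreamHandler()
    handler.setFormatter(InvokeAILogFormatter())  # the formatter class defined above
    demo = logging.getLogger("demo")
    demo.setLevel(logging.DEBUG)
    demo.addHandler(handler)

    demo.warning("half-precision not available")  # renders as: ** half-precision not available
    demo.debug("loading weights")                 # renders as:  | loading weights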

View File

@@ -18,7 +18,6 @@ import torch
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
import invokeai.backend.util.logging as logger
from .devices import torch_dtype
@@ -39,7 +38,7 @@ def log_txt_as_img(wh, xc, size=10):
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
logger.warning("Cant encode string for logging. Skipping.")
print("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@@ -81,8 +80,8 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
logger.debug(
f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
print(
f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
)
return total_params
@@ -133,8 +132,8 @@ def parallel_data_prefetch(
raise ValueError("list expected but function got ndarray.")
elif isinstance(data, abc.Iterable):
if isinstance(data, dict):
logger.warning(
'"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
print(
'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
)
data = list(data.values())
if target_data_type == "ndarray":
@@ -176,7 +175,7 @@ def parallel_data_prefetch(
processes += [p]
# start processes
logger.info("Start prefetching...")
print("Start prefetching...")
import time
start = time.time()
@@ -195,7 +194,7 @@ def parallel_data_prefetch(
gather_res[res[0]] = res[1]
except Exception as e:
logger.error("Exception: ", e)
print("Exception: ", e)
for p in processes:
p.terminate()
@@ -203,7 +202,7 @@ def parallel_data_prefetch(
finally:
for p in processes:
p.join()
logger.info(f"Prefetching complete. [{time.time() - start} sec.]")
print(f"Prefetching complete. [{time.time() - start} sec.]")
if target_data_type == "ndarray":
if not isinstance(gather_res[0], np.ndarray):
@@ -319,23 +318,23 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
resp = requests.get(url, headers=header, stream=True) # new request with range
if exist_size > content_length:
logger.warning("corrupt existing file found. re-downloading")
print("* corrupt existing file found. re-downloading")
os.remove(dest)
exist_size = 0
if resp.status_code == 416 or exist_size == content_length:
logger.warning(f"{dest}: complete file found. Skipping.")
print(f"* {dest}: complete file found. Skipping.")
return dest
elif resp.status_code == 206 or exist_size > 0:
logger.warning(f"{dest}: partial file found. Resuming...")
print(f"* {dest}: partial file found. Resuming...")
elif resp.status_code != 200:
logger.error(f"An error occurred during downloading {dest}: {resp.reason}")
print(f"** An error occurred during downloading {dest}: {resp.reason}")
else:
logger.error(f"{dest}: Downloading...")
print(f"* {dest}: Downloading...")
try:
if content_length < 2000:
logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
return None
with open(dest, open_mode) as file, tqdm(
@@ -350,7 +349,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
size = file.write(data)
bar.update(size)
except Exception as e:
logger.error(f"An error occurred while downloading {dest}: {str(e)}")
print(f"An error occurred while downloading {dest}: {str(e)}")
return None
return dest
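The resume logic above follows standard HTTP range semantics: 206 Partial Content when the server honors the Range header, 416 when the requested range is unsatisfiable because the local copy is already complete. A compact sketch with requests, minus the helper's logging and progress bar:

    from pathlib import Path

    import requests

    def resume_download(url: str, dest: Path) -> Path:
        exist_size = dest.stat().st_size if dest.exists() else 0
        headers = {"Range": f"bytes={exist_size}-"} if exist_size else {}
        resp = requests.get(url, headers=headers, stream=True)
        if resp.status_code == 416:  # range unsatisfiable: file already complete
            return dest
        mode = "ab" if resp.status_code == 206 else "wb"  # append on resume, else restart
        with open(dest, mode) as f:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                f.write(chunk)
        return dest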

View File

@@ -19,7 +19,6 @@ from PIL import Image
from PIL.Image import Image as ImageType
from werkzeug.utils import secure_filename
import invokeai.backend.util.logging as logger
import invokeai.frontend.web.dist as frontend
from .. import Generate
@@ -78,6 +77,7 @@ class InvokeAIWebServer:
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
# Socket IO
logger = True if args.web_verbose else False
engineio_logger = True if args.web_verbose else False
max_http_buffer_size = 10000000
@@ -213,7 +213,7 @@ class InvokeAIWebServer:
self.load_socketio_listeners(self.socketio)
if args.gui:
logger.info("Launching Invoke AI GUI")
print(">> Launching Invoke AI GUI")
try:
from flaskwebgui import FlaskUI
@@ -231,17 +231,17 @@ class InvokeAIWebServer:
sys.exit(0)
else:
useSSL = args.certfile or args.keyfile
logger.info("Started Invoke AI Web Server")
print(">> Started Invoke AI Web Server")
if self.host == "0.0.0.0":
logger.info(
print(
f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address."
)
else:
logger.info(
"Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
print(
">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address."
)
logger.info(
f"Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
print(
f">> Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}"
)
if not useSSL:
self.socketio.run(app=self.app, host=self.host, port=self.port)
@@ -273,7 +273,7 @@ class InvokeAIWebServer:
# path for thumbnail images
self.thumbnail_image_path = os.path.join(self.result_path, "thumbnails/")
# txt log
self.log_path = os.path.join(self.result_path, "invoke_logger.txt")
self.log_path = os.path.join(self.result_path, "invoke_log.txt")
# make all output paths
[
os.makedirs(path, exist_ok=True)
@@ -290,7 +290,7 @@ class InvokeAIWebServer:
def load_socketio_listeners(self, socketio):
@socketio.on("requestSystemConfig")
def handle_request_capabilities():
logger.info("System config requested")
print(">> System config requested")
config = self.get_system_config()
config["model_list"] = self.generate.model_manager.list_models()
config["infill_methods"] = infill_methods()
@@ -330,7 +330,7 @@ class InvokeAIWebServer:
if model_name in current_model_list:
update = True
logger.info(f"Adding New Model: {model_name}")
print(f">> Adding New Model: {model_name}")
self.generate.model_manager.add_model(
model_name=model_name,
@@ -348,14 +348,14 @@ class InvokeAIWebServer:
"update": update,
},
)
logger.info(f"New Model Added: {model_name}")
print(f">> New Model Added: {model_name}")
except Exception as e:
self.handle_exceptions(e)
@socketio.on("deleteModel")
def handle_delete_model(model_name: str):
try:
logger.info(f"Deleting Model: {model_name}")
print(f">> Deleting Model: {model_name}")
self.generate.model_manager.del_model(model_name)
self.generate.model_manager.commit(opt.conf)
updated_model_list = self.generate.model_manager.list_models()
@@ -366,14 +366,14 @@ class InvokeAIWebServer:
"model_list": updated_model_list,
},
)
logger.info(f"Model Deleted: {model_name}")
print(f">> Model Deleted: {model_name}")
except Exception as e:
self.handle_exceptions(e)
@socketio.on("requestModelChange")
def handle_set_model(model_name: str):
try:
logger.info(f"Model change requested: {model_name}")
print(f">> Model change requested: {model_name}")
model = self.generate.set_model(model_name)
model_list = self.generate.model_manager.list_models()
if model is None:
@@ -454,7 +454,7 @@ class InvokeAIWebServer:
"update": True,
},
)
logger.info(f"Model Converted: {model_name}")
print(f">> Model Converted: {model_name}")
except Exception as e:
self.handle_exceptions(e)
@@ -490,7 +490,7 @@ class InvokeAIWebServer:
if vae := self.generate.model_manager.config[models_to_merge[0]].get(
"vae", None
):
logger.info(f"Using configured VAE assigned to {models_to_merge[0]}")
print(f">> Using configured VAE assigned to {models_to_merge[0]}")
merged_model_config.update(vae=vae)
self.generate.model_manager.import_diffuser_model(
@@ -507,8 +507,8 @@ class InvokeAIWebServer:
"update": True,
},
)
logger.info(f"Models Merged: {models_to_merge}")
logger.info(f"New Model Added: {model_merge_info['merged_model_name']}")
print(f">> Models Merged: {models_to_merge}")
print(f">> New Model Added: {model_merge_info['merged_model_name']}")
except Exception as e:
self.handle_exceptions(e)
@@ -698,7 +698,7 @@ class InvokeAIWebServer:
}
)
except Exception as e:
logger.info(f"Unable to load {path}")
print(f">> Unable to load {path}")
socketio.emit(
"error", {"message": f"Unable to load {path}: {str(e)}"}
)
@@ -735,9 +735,9 @@ class InvokeAIWebServer:
printable_parameters["init_mask"][:64] + "..."
)
logger.info(f"Image Generation Parameters:\n\n{printable_parameters}\n")
logger.info(f"ESRGAN Parameters: {esrgan_parameters}")
logger.info(f"Facetool Parameters: {facetool_parameters}")
print(f"\n>> Image Generation Parameters:\n\n{printable_parameters}\n")
print(f">> ESRGAN Parameters: {esrgan_parameters}")
print(f">> Facetool Parameters: {facetool_parameters}")
self.generate_images(
generation_parameters,
@@ -750,8 +750,8 @@ class InvokeAIWebServer:
@socketio.on("runPostprocessing")
def handle_run_postprocessing(original_image, postprocessing_parameters):
try:
logger.info(
f'Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
print(
f'>> Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}'
)
progress = Progress()
@@ -861,14 +861,14 @@ class InvokeAIWebServer:
@socketio.on("cancel")
def handle_cancel():
logger.info("Cancel processing requested")
print(">> Cancel processing requested")
self.canceled.set()
# TODO: I think this needs a safety mechanism.
@socketio.on("deleteImage")
def handle_delete_image(url, thumbnail, uuid, category):
try:
logger.info(f'Delete requested "{url}"')
print(f'>> Delete requested "{url}"')
from send2trash import send2trash
path = self.get_image_path_from_url(url)
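
The cancel path above pairs an event flag with a per-step check: the "cancel" socket handler sets the flag, and the step callback raises CanceledException on the next denoising step to unwind generation. A condensed sketch of that pattern, assuming self.canceled is a threading.Event (consistent with the .set() and .is_set() calls in this file):

import threading

class CanceledException(Exception):
    # raised inside the step callback to abort the generation loop
    pass

canceled = threading.Event()

def handle_cancel():
    # invoked by the "cancel" socket event
    canceled.set()

def step_callback(sample, step):
    # runs once per denoising step; raising here aborts generation
    if canceled.is_set():
        raise CanceledException
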
@@ -1022,7 +1022,7 @@ class InvokeAIWebServer:
"RGB"
)
def image_progress(intermediate_state: PipelineIntermediateState):
def image_progress(sample, step):
if self.canceled.is_set():
raise CanceledException
@@ -1030,14 +1030,6 @@ class InvokeAIWebServer:
nonlocal generation_parameters
nonlocal progress
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
generation_messages = {
"txt2img": "common.statusGeneratingTextToImage",
"img2img": "common.statusGeneratingImageToImage",
@@ -1263,7 +1255,7 @@ class InvokeAIWebServer:
image, os.path.basename(path), self.thumbnail_image_path
)
logger.info(f'Image generated: "{path}"\n')
print(f'\n\n>> Image generated: "{path}"\n')
self.write_log_message(f'[Generated] "{path}": {command}')
if progress.total_iterations > progress.current_iteration:
@@ -1310,9 +1302,16 @@ class InvokeAIWebServer:
progress.set_current_iteration(progress.current_iteration + 1)
def diffusers_step_callback_adapter(*cb_args, **kwargs):
if isinstance(cb_args[0], PipelineIntermediateState):
progress_state: PipelineIntermediateState = cb_args[0]
return image_progress(progress_state.latents, progress_state.step)
else:
return image_progress(*cb_args, **kwargs)
self.generate.prompt2image(
**generation_parameters,
step_callback=image_progress,
step_callback=diffusers_step_callback_adapter,
image_callback=image_done,
)
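
The adapter restored here lets prompt2image serve both callback conventions: diffusers pipelines pass a single PipelineIntermediateState, while legacy callers pass (sample, step) positionally. A generic sketch of the shim, with names following the hunk above:

def make_step_callback_adapter(legacy_callback):
    # wrap a (sample, step) callback so it also accepts a
    # PipelineIntermediateState from a diffusers pipeline
    def adapter(*cb_args, **kwargs):
        if isinstance(cb_args[0], PipelineIntermediateState):
            state = cb_args[0]
            return legacy_callback(state.latents, state.step)
        return legacy_callback(*cb_args, **kwargs)
    return adapter
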
@@ -1329,7 +1328,7 @@ class InvokeAIWebServer:
except Exception as e:
# Clear the CUDA cache on an exception
self.empty_cuda_cache()
logger.error(e)
print(e)
self.handle_exceptions(e)
def empty_cuda_cache(self):

View File

@@ -16,7 +16,6 @@ if sys.platform == "darwin":
import pyparsing # type: ignore
import invokeai.version as invokeai
import invokeai.backend.util.logging as logger
from ...backend import Generate, ModelManager
from ...backend.args import Args, dream_cmd_from_png, metadata_dumps, metadata_from_png
@@ -70,7 +69,7 @@ def main():
# run any post-install patches needed
run_patches()
logger.info(f"Internet connectivity is {Globals.internet_available}")
print(f">> Internet connectivity is {Globals.internet_available}")
if not args.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml")
@@ -79,8 +78,8 @@ def main():
opt, FileNotFoundError(f"The file {config_file} could not be found.")
)
logger.info(f"{invokeai.__app_name__}, version {invokeai.__version__}")
logger.info(f'InvokeAI runtime directory is "{Globals.root}"')
print(f">> {invokeai.__app_name__}, version {invokeai.__version__}")
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
# loading here to avoid long delays on startup
# these two lines prevent a horrible warning message from appearing
@@ -122,7 +121,7 @@ def main():
else:
raise FileNotFoundError(f"{opt.infile} not found.")
except (FileNotFoundError, IOError) as e:
logger.critical('Aborted',exc_info=True)
print(f"{e}. Aborting.")
sys.exit(-1)
# creating a Generate object:
@@ -143,12 +142,12 @@ def main():
)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(opt, e)
except (IOError, KeyError):
logger.critical("Aborted",exc_info=True)
except (IOError, KeyError) as e:
print(f"{e}. Aborting.")
sys.exit(-1)
if opt.seamless:
logger.info("Changed to seamless tiling mode")
print(">> changed to seamless tiling mode")
# preload the model
try:
@@ -159,9 +158,14 @@ def main():
report_model_error(opt, e)
# try to autoconvert new models
if path := opt.autoimport:
gen.model_manager.heuristic_import(
str(path), convert=False, commit_to_conf=opt.conf
)
if path := opt.autoconvert:
gen.model_manager.heuristic_import(
str(path), commit_to_conf=opt.conf
str(path), convert=True, commit_to_conf=opt.conf
)
# web server loops forever
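
Both startup options funnel into the same heuristic_import call; only the convert flag differs. A sketch of the split as wired above, assuming opt carries the autoimport and autoconvert path options:

# --autoimport: register models in place, keeping their original format
if path := opt.autoimport:
    gen.model_manager.heuristic_import(
        str(path), convert=False, commit_to_conf=opt.conf
    )

# --autoconvert: convert legacy checkpoints to diffusers on import
if path := opt.autoconvert:
    gen.model_manager.heuristic_import(
        str(path), convert=True, commit_to_conf=opt.conf
    )
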
@@ -181,7 +185,9 @@ def main():
f'\nGoodbye!\nYou can start InvokeAI again by running the "invoke.bat" (or "invoke.sh") script from {Globals.root}'
)
except Exception:
logger.error("An error occurred",exc_info=True)
print(">> An error occurred:")
traceback.print_exc()
# TODO: main_loop() has gotten busy. Needs to be refactored.
def main_loop(gen, opt):
@@ -247,7 +253,7 @@ def main_loop(gen, opt):
if not opt.prompt:
oldargs = metadata_from_png(opt.init_img)
opt.prompt = oldargs.prompt
logger.info(f'Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
print(f'>> Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
except (OSError, AttributeError, KeyError):
pass
@@ -264,9 +270,9 @@ def main_loop(gen, opt):
if opt.init_img is not None and re.match("^-\\d+$", opt.init_img):
try:
opt.init_img = last_results[int(opt.init_img)][0]
logger.info(f"Reusing previous image {opt.init_img}")
print(f">> Reusing previous image {opt.init_img}")
except IndexError:
logger.info(f"No previous initial image at position {opt.init_img} found")
print(f">> No previous initial image at position {opt.init_img} found")
opt.init_img = None
continue
@@ -287,9 +293,9 @@ def main_loop(gen, opt):
if opt.seed is not None and opt.seed < 0 and operation != "postprocess":
try:
opt.seed = last_results[opt.seed][1]
logger.info(f"Reusing previous seed {opt.seed}")
print(f">> Reusing previous seed {opt.seed}")
except IndexError:
logger.info(f"No previous seed at position {opt.seed} found")
print(f">> No previous seed at position {opt.seed} found")
opt.seed = None
continue
@@ -308,7 +314,7 @@ def main_loop(gen, opt):
subdir = subdir[: (path_max - 39 - len(os.path.abspath(opt.outdir)))]
current_outdir = os.path.join(opt.outdir, subdir)
logger.info('Writing files to directory: "' + current_outdir + '"')
print('Writing files to directory: "' + current_outdir + '"')
# make sure the output directory exists
if not os.path.exists(current_outdir):
@@ -437,14 +443,15 @@ def main_loop(gen, opt):
catch_interrupts=catch_ctrl_c,
**vars(opt),
)
except (PromptParser.ParsingException, pyparsing.ParseException):
logger.error("An error occurred while processing your prompt",exc_info=True)
except (PromptParser.ParsingException, pyparsing.ParseException) as e:
print("** An error occurred while processing your prompt **")
print(f"** {str(e)} **")
elif operation == "postprocess":
logger.info(f"fixing {opt.prompt}")
print(f">> fixing {opt.prompt}")
opt.last_operation = do_postprocess(gen, opt, image_writer)
elif operation == "mask":
logger.info(f"generating masks from {opt.prompt}")
print(f">> generating masks from {opt.prompt}")
do_textmask(gen, opt, image_writer)
if opt.grid and len(grid_images) > 0:
@@ -467,12 +474,12 @@ def main_loop(gen, opt):
)
results = [[path, formatted_dream_prompt]]
except AssertionError:
logger.error(e)
except AssertionError as e:
print(e)
continue
except OSError as e:
logger.error(e)
print(e)
continue
print("Outputs:")
@@ -511,7 +518,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
gen.set_model(model_name)
add_embedding_terms(gen, completer)
except KeyError as e:
logger.error(e)
print(str(e))
except Exception as e:
report_model_error(opt, e)
completer.add_history(command)
@@ -525,8 +532,8 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
elif command.startswith("!import"):
path = shlex.split(command)
if len(path) < 2:
logger.warning(
"please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1"
print(
"** please provide (1) a URL to a .ckpt file to import; (2) a local path to a .ckpt file; or (3) a diffusers repository id in the form stabilityai/stable-diffusion-2-1"
)
else:
try:
@@ -539,7 +546,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
elif command.startswith(("!convert", "!optimize")):
path = shlex.split(command)
if len(path) < 2:
logger.warning("please provide the path to a .ckpt or .safetensors model")
print("** please provide the path to a .ckpt or .safetensors model")
else:
try:
convert_model(path[1], gen, opt, completer)
@@ -551,7 +558,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
elif command.startswith("!edit"):
path = shlex.split(command)
if len(path) < 2:
logger.warning("please provide the name of a model")
print("** please provide the name of a model")
else:
edit_model(path[1], gen, opt, completer)
completer.add_history(command)
@@ -560,7 +567,7 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
elif command.startswith("!del"):
path = shlex.split(command)
if len(path) < 2:
logger.warning("please provide the name of a model")
print("** please provide the name of a model")
else:
del_config(path[1], gen, opt, completer)
completer.add_history(command)
@@ -574,7 +581,6 @@ def do_command(command: str, gen, opt: Args, completer) -> tuple:
elif command.startswith("!replay"):
file_path = command.replace("!replay", "", 1).strip()
file_path = os.path.join(opt.outdir, file_path)
if infile is None and os.path.isfile(file_path):
infile = open(file_path, "r", encoding="utf-8")
completer.add_history(command)
@@ -620,7 +626,7 @@ def set_default_output_dir(opt: Args, completer: Completer):
completer.set_default_dir(opt.outdir)
def import_model(model_path: str, gen, opt, completer):
def import_model(model_path: str, gen, opt, completer, convert=False):
"""
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
(3) a huggingface repository id; or (4) a local directory containing a
@@ -640,8 +646,8 @@ def import_model(model_path: str, gen, opt, completer):
try:
default_name = url_attachment_name(model_path)
default_name = Path(default_name).stem
except Exception:
logger.warning(f"A problem occurred while assigning the name of the downloaded model",exc_info=True)
except Exception as e:
print(f"** URL: {str(e)}")
model_name, model_desc = _get_model_name_and_desc(
gen.model_manager,
completer,
@@ -651,6 +657,7 @@ def import_model(model_path: str, gen, opt, completer):
model_path,
model_name=model_name,
description=model_desc,
convert=convert,
)
if not imported_name:
@@ -659,14 +666,15 @@ def import_model(model_path: str, gen, opt, completer):
model_path,
model_name=model_name,
description=model_desc,
convert=convert,
model_config_file=config_file,
)
if not imported_name:
logger.error("Aborting import.")
print("** Aborting import.")
return
if not _verify_load(imported_name, gen):
logger.error("model failed to load. Discarding configuration entry")
print("** model failed to load. Discarding configuration entry")
gen.model_manager.del_model(imported_name)
return
if click.confirm("Make this the default model?", default=False):
@@ -674,7 +682,7 @@ def import_model(model_path: str, gen, opt, completer):
gen.model_manager.commit(opt.conf)
completer.update_models(gen.model_manager.list_models())
logger.info(f"{imported_name} successfully installed")
print(f">> {imported_name} successfully installed")
def _pick_configuration_file(completer)->Path:
print(
@@ -718,21 +726,21 @@ Please select the type of this model:
return choice
def _verify_load(model_name: str, gen) -> bool:
logger.info("Verifying that new model loads...")
print(">> Verifying that new model loads...")
current_model = gen.model_name
try:
if not gen.set_model(model_name):
return
except Exception as e:
logger.warning(f"model failed to load: {str(e)}")
logger.warning(
print(f"** model failed to load: {str(e)}")
print(
"** note that importing 2.X checkpoints is not supported. Please use !convert_model instead."
)
return False
if click.confirm("Keep model loaded?", default=True):
gen.set_model(model_name)
else:
logger.info("Restoring previous model")
print(">> Restoring previous model")
gen.set_model(current_model)
return True
@@ -749,13 +757,14 @@ def _get_model_name_and_desc(
)
return model_name, model_description
def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
model_name_or_path = model_name_or_path.replace("\\", "/") # windows
manager = gen.model_manager
ckpt_path = None
original_config_file = None
if model_name_or_path == gen.model_name:
logger.warning("Can't convert the active model. !switch to another model first. **")
print("** Can't convert the active model. !switch to another model first. **")
return
elif model_info := manager.model_info(model_name_or_path):
if "weights" in model_info:
@@ -763,10 +772,16 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
original_config_file = Path(model_info["config"])
model_name = model_name_or_path
model_description = model_info["description"]
vae_path = model_info.get("vae")
vae = model_info["vae"]
else:
logger.warning(f"{model_name_or_path} is not a legacy .ckpt weights file")
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
return
if vae_repo := invokeai.backend.model_management.model_manager.VAE_TO_REPO_ID.get(
Path(vae).stem
):
vae_repo = dict(repo_id=vae_repo)
else:
vae_repo = None
model_name = manager.convert_and_import(
ckpt_path,
diffusers_path=Path(
@@ -775,27 +790,27 @@ def convert_model(model_name_or_path: Union[Path, str], gen, opt, completer):
model_name=model_name,
model_description=model_description,
original_config_file=original_config_file,
vae_path=vae_path,
vae=vae_repo,
)
else:
try:
import_model(model_name_or_path, gen, opt, completer)
import_model(model_name_or_path, gen, opt, completer, convert=True)
except KeyboardInterrupt:
return
manager.commit(opt.conf)
if click.confirm(f"Delete the original .ckpt file at {ckpt_path}?", default=False):
ckpt_path.unlink(missing_ok=True)
logger.warning(f"{ckpt_path} deleted")
print(f"{ckpt_path} deleted")
def del_config(model_name: str, gen, opt, completer):
current_model = gen.model_name
if model_name == current_model:
logger.warning("Can't delete active model. !switch to another model first. **")
print("** Can't delete active model. !switch to another model first. **")
return
if model_name not in gen.model_manager.config:
logger.warning(f"Unknown model {model_name}")
print(f"** Unknown model {model_name}")
return
if not click.confirm(
@@ -808,17 +823,17 @@ def del_config(model_name: str, gen, opt, completer):
)
gen.model_manager.del_model(model_name, delete_files=delete_completely)
gen.model_manager.commit(opt.conf)
logger.warning(f"{model_name} deleted")
print(f"** {model_name} deleted")
completer.update_models(gen.model_manager.list_models())
def edit_model(model_name: str, gen, opt, completer):
manager = gen.model_manager
if not (info := manager.model_info(model_name)):
logger.warning(f"** Unknown model {model_name}")
print(f"** Unknown model {model_name}")
return
print()
logger.info(f"Editing model {model_name} from configuration file {opt.conf}")
print(f"\n>> Editing model {model_name} from configuration file {opt.conf}")
new_name = _get_model_name(manager.list_models(), completer, model_name)
for attribute in info.keys():
@@ -856,7 +871,7 @@ def edit_model(model_name: str, gen, opt, completer):
manager.set_default_model(new_name)
manager.commit(opt.conf)
completer.update_models(manager.list_models())
logger.info("Model successfully updated")
print(">> Model successfully updated")
def _get_model_name(existing_names, completer, default_name: str = "") -> str:
@@ -867,11 +882,11 @@ def _get_model_name(existing_names, completer, default_name: str = "") -> str:
if len(model_name) == 0:
model_name = default_name
if not re.match("^[\w._+:/-]+$", model_name):
logger.warning(
'model name must contain only words, digits and the characters "._+:/-" **'
print(
'** model name must contain only words, digits and the characters "._+:/-" **'
)
elif model_name != default_name and model_name in existing_names:
logger.warning(f"the name {model_name} is already in use. Pick another.")
print(f"** the name {model_name} is already in use. Pick another.")
else:
done = True
return model_name
@@ -938,10 +953,11 @@ def do_postprocess(gen, opt, callback):
opt=opt,
)
except OSError:
logger.error(f"{file_path}: file could not be read",exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
print(f"** {file_path}: file could not be read")
return
except (KeyError, AttributeError):
logger.error(f"an error occurred while applying the {tool} postprocessor",exc_info=True)
print(traceback.format_exc(), file=sys.stderr)
return
return opt.last_operation
@@ -996,13 +1012,13 @@ def prepare_image_metadata(
try:
filename = opt.fnformat.format(**wildcards)
except KeyError as e:
logger.error(
f"The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead"
print(
f"** The filename format contains an unknown key '{e.args[0]}'. Will use {{prefix}}.{{seed}}.png' instead"
)
filename = f"{prefix}.{seed}.png"
except IndexError:
logger.error(
"The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead"
print(
"** The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead"
)
filename = f"{prefix}.{seed}.png"
@@ -1091,14 +1107,14 @@ def split_variations(variations_string) -> list:
for part in variations_string.split(","):
seed_and_weight = part.split(":")
if len(seed_and_weight) != 2:
logger.warning(f'Could not parse with_variation part "{part}"')
print(f'** Could not parse with_variation part "{part}"')
broken = True
break
try:
seed = int(seed_and_weight[0])
weight = float(seed_and_weight[1])
except ValueError:
logger.warning(f'Could not parse with_variation part "{part}"')
print(f'** Could not parse with_variation part "{part}"')
broken = True
break
parts.append([seed, weight])
@@ -1122,23 +1138,23 @@ def load_face_restoration(opt):
opt.gfpgan_model_path
)
else:
logger.info("Face restoration disabled")
print(">> Face restoration disabled")
if opt.esrgan:
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
else:
logger.info("Upscaling disabled")
print(">> Upscaling disabled")
else:
logger.info("Face restoration and upscaling disabled")
print(">> Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
return gfpgan, codeformer, esrgan
def make_step_callback(gen, opt, prefix):
destination = os.path.join(opt.outdir, "intermediates", prefix)
os.makedirs(destination, exist_ok=True)
logger.info(f"Intermediate images will be written into {destination}")
print(f">> Intermediate images will be written into {destination}")
def callback(state: PipelineIntermediateState):
latents = state.latents
@@ -1180,20 +1196,21 @@ def retrieve_dream_command(opt, command, completer):
try:
cmd = dream_cmd_from_png(path)
except OSError:
logger.error(f"{tokens[0]}: file could not be read")
print(f"## {tokens[0]}: file could not be read")
except (KeyError, AttributeError, IndexError):
logger.error(f"{tokens[0]}: file has no metadata")
print(f"## {tokens[0]}: file has no metadata")
except:
logger.error(f"{tokens[0]}: file could not be processed")
print(f"## {tokens[0]}: file could not be processed")
if len(cmd) > 0:
completer.set_line(cmd)
def write_commands(opt, file_path: str, outfilepath: str):
dir, basename = os.path.split(file_path)
try:
paths = sorted(list(Path(dir).glob(basename)))
except ValueError:
logger.error(f'"{basename}": unacceptable pattern')
print(f'## "{basename}": unacceptable pattern')
return
commands = []
@@ -1202,9 +1219,9 @@ def write_commands(opt, file_path: str, outfilepath: str):
try:
cmd = dream_cmd_from_png(path)
except (KeyError, AttributeError, IndexError):
logger.error(f"{path}: file has no metadata")
print(f"## {path}: file has no metadata")
except:
logger.error(f"{path}: file could not be processed")
print(f"## {path}: file could not be processed")
if cmd:
commands.append(f"# {path}")
commands.append(cmd)
@@ -1214,18 +1231,18 @@ def write_commands(opt, file_path: str, outfilepath: str):
outfilepath = os.path.join(opt.outdir, basename)
with open(outfilepath, "w", encoding="utf-8") as f:
f.write("\n".join(commands))
logger.info(f"File {outfilepath} with commands created")
print(f">> File {outfilepath} with commands created")
def report_model_error(opt: Namespace, e: Exception):
logger.warning(f'An error occurred while attempting to initialize the model: "{str(e)}"')
logger.warning(
"This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
print(f'** An error occurred while attempting to initialize the model: "{str(e)}"')
print(
"** This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
)
yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
if yes_to_all:
logger.warning(
"Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
print(
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
)
else:
if not click.confirm(
@@ -1234,7 +1251,7 @@ def report_model_error(opt: Namespace, e: Exception):
):
return
logger.info("invokeai-configure is launching....\n")
print("invokeai-configure is launching....\n")
# Match arguments that were set on the CLI
# only the arguments accepted by the configuration script are parsed
@@ -1251,7 +1268,7 @@ def report_model_error(opt: Namespace, e: Exception):
from ..install import invokeai_configure
invokeai_configure()
logger.warning("InvokeAI will now restart")
print("** InvokeAI will now restart")
sys.argv = previous_args
main() # would rather do a os.exec(), but doesn't exist?
sys.exit(0)

View File

@@ -1,9 +1,10 @@
'''
"""
Minimalist updater script. Prompts user for the tag or branch to update to and runs
pip install <path_to_git_source>.
'''
"""
import os
import platform
import requests
from rich import box, print
from rich.console import Console, Group, group
@@ -15,10 +16,8 @@ from rich.text import Text
from invokeai.version import __version__
INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive"
INVOKE_AI_TAG="https://github.com/invoke-ai/InvokeAI/archive/refs/tags"
INVOKE_AI_BRANCH="https://github.com/invoke-ai/InvokeAI/archive/refs/heads"
INVOKE_AI_REL="https://api.github.com/repos/invoke-ai/InvokeAI/releases"
INVOKE_AI_SRC = "https://github.com/invoke-ai/InvokeAI/archive"
INVOKE_AI_REL = "https://api.github.com/repos/invoke-ai/InvokeAI/releases"
OS = platform.uname().system
ARCH = platform.uname().machine
@@ -29,22 +28,22 @@ if OS == "Windows":
else:
console = Console(style=Style(color="grey74", bgcolor="grey19"))
def get_versions()->dict:
def get_versions() -> dict:
return requests.get(url=INVOKE_AI_REL).json()
def welcome(versions: dict):
@group()
def text():
yield f'InvokeAI Version: [bold yellow]{__version__}'
yield ''
yield 'This script will update InvokeAI to the latest release, or to a development version of your choice.'
yield ''
yield '[bold yellow]Options:'
yield f'''[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])
yield f"InvokeAI Version: [bold yellow]{__version__}"
yield ""
yield "This script will update InvokeAI to the latest release, or to a development version of your choice."
yield ""
yield "[bold yellow]Options:"
yield f"""[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic])
[2] Update to the bleeding-edge development version ([italic]main[/italic])
[3] Manually enter the [bold]tag name[/bold] for the version you wish to update to
[4] Manually enter the [bold]branch name[/bold] for the version you wish to update to'''
[3] Manually enter the tag or branch name you wish to update"""
console.rule()
print(
@@ -60,41 +59,33 @@ def welcome(versions: dict):
)
console.line()
def main():
versions = get_versions()
welcome(versions)
tag = None
branch = None
release = None
choice = Prompt.ask('Choice:',choices=['1','2','3','4'],default='1')
if choice=='1':
release = versions[0]['tag_name']
elif choice=='2':
release = 'main'
elif choice=='3':
tag = Prompt.ask('Enter an InvokeAI tag name')
elif choice=='4':
branch = Prompt.ask('Enter an InvokeAI branch name')
choice = Prompt.ask("Choice:", choices=["1", "2", "3"], default="1")
print(f':crossed_fingers: Upgrading to [yellow]{tag if tag else release}[/yellow]')
if release:
cmd = f'pip install {INVOKE_AI_SRC}/{release}.zip --use-pep517 --upgrade'
elif tag:
cmd = f'pip install {INVOKE_AI_TAG}/{tag}.zip --use-pep517 --upgrade'
if choice == "1":
tag = versions[0]["tag_name"]
elif choice == "2":
tag = "main"
elif choice == "3":
tag = Prompt.ask("Enter an InvokeAI tag or branch name")
print(f":crossed_fingers: Upgrading to [yellow]{tag}[/yellow]")
cmd = f"pip install {INVOKE_AI_SRC}/{tag}.zip --use-pep517"
print("")
print("")
if os.system(cmd) == 0:
print(f":heavy_check_mark: Upgrade successful")
else:
cmd = f'pip install {INVOKE_AI_BRANCH}/{branch}.zip --use-pep517 --upgrade'
print('')
print('')
if os.system(cmd)==0:
print(f':heavy_check_mark: Upgrade successful')
else:
print(f':exclamation: [bold red]Upgrade failed[/red bold]')
print(f":exclamation: [bold red]Upgrade failed[/red bold]")
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
pass
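
After this rewrite the updater needs only one URL template, because GitHub's /archive endpoint serves release tags and branch heads from the same path. A condensed sketch of the resulting flow, with the constant as defined in the script:

from rich.prompt import Prompt

INVOKE_AI_SRC = "https://github.com/invoke-ai/InvokeAI/archive"

def build_upgrade_command(choice: str, versions: list) -> str:
    # map the menu choice onto a single git ref: latest release, main, or user input
    if choice == "1":
        tag = versions[0]["tag_name"]
    elif choice == "2":
        tag = "main"
    else:
        tag = Prompt.ask("Enter an InvokeAI tag or branch name")
    # tags and branches resolve through the same archive URL
    return f"pip install {INVOKE_AI_SRC}/{tag}.zip --use-pep517"
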

View File

@@ -22,7 +22,6 @@ import torch
from npyscreen import widget
from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals, global_config_dir
from ...backend.config.model_install_backend import (
@@ -200,6 +199,17 @@ class addModelsForm(npyscreen.FormMultiPage):
relx=4,
scroll_exit=True,
)
self.nextrely += 1
self.convert_models = self.add_widget_intelligent(
npyscreen.TitleSelectOne,
name="== CONVERT IMPORTED MODELS INTO DIFFUSERS==",
values=["Keep original format", "Convert to diffusers"],
value=0,
begin_entry_at=4,
max_height=4,
hidden=True, # will appear when imported models box is edited
scroll_exit=True,
)
self.cancel = self.add_widget_intelligent(
npyscreen.ButtonPress,
name="CANCEL",
@@ -234,6 +244,8 @@ class addModelsForm(npyscreen.FormMultiPage):
self.show_directory_fields.addVisibleWhenSelected(i)
self.show_directory_fields.when_value_edited = self._clear_scan_directory
self.import_model_paths.when_value_edited = self._show_hide_convert
self.autoload_directory.when_value_edited = self._show_hide_convert
def resize(self):
super().resize()
@@ -244,6 +256,13 @@ class addModelsForm(npyscreen.FormMultiPage):
if not self.show_directory_fields.value:
self.autoload_directory.value = ""
def _show_hide_convert(self):
model_paths = self.import_model_paths.value or ""
autoload_directory = self.autoload_directory.value or ""
self.convert_models.hidden = (
len(model_paths) == 0 and len(autoload_directory) == 0
)
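
The conversion prompt stays hidden until either import field is non-empty, driven by npyscreen's when_value_edited hook and the widget's hidden attribute, both used above. A minimal standalone sketch of the same show/hide pattern (form and field names are illustrative):

import npyscreen

class DemoForm(npyscreen.FormBaseNew):
    def create(self):
        self.paths = self.add(npyscreen.TitleText, name="Import paths:")
        self.convert = self.add(
            npyscreen.TitleSelectOne,
            name="Convert to diffusers?",
            values=["Keep original format", "Convert to diffusers"],
            value=[0],
            max_height=4,
            hidden=True,  # revealed once the paths field has content
        )
        # re-evaluate visibility whenever the text field changes
        self.paths.when_value_edited = self._show_hide_convert

    def _show_hide_convert(self):
        self.convert.hidden = len(self.paths.value or "") == 0
        self.display()
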
def _get_starter_model_labels(self) -> List[str]:
window_width, window_height = get_terminal_size()
label_width = 25
@@ -303,6 +322,7 @@ class addModelsForm(npyscreen.FormMultiPage):
.scan_directory: Path to a directory of models to scan and import
.autoscan_on_startup: True if invokeai should scan and import at startup time
.import_model_paths: list of URLs, repo_ids and file paths to import
.convert_to_diffusers: if True, convert legacy checkpoints into diffusers
"""
# we're using a global here rather than storing the result in the parentapp
# due to some bug in npyscreen that is causing attributes to be lost
@@ -339,6 +359,7 @@ class addModelsForm(npyscreen.FormMultiPage):
# URLs and the like
selections.import_model_paths = self.import_model_paths.value.split()
selections.convert_to_diffusers = self.convert_models.value[0] == 1
class AddModelApplication(npyscreen.NPSAppManaged):
@@ -351,6 +372,7 @@ class AddModelApplication(npyscreen.NPSAppManaged):
scan_directory=None,
autoscan_on_startup=None,
import_model_paths=None,
convert_to_diffusers=None,
)
def onStart(self):
@@ -371,6 +393,7 @@ def process_and_execute(opt: Namespace, selections: Namespace):
directory_to_scan = selections.scan_directory
scan_at_startup = selections.autoscan_on_startup
potential_models_to_install = selections.import_model_paths
convert_to_diffusers = selections.convert_to_diffusers
install_requested_models(
install_initial_models=models_to_install,
@@ -378,6 +401,7 @@ def process_and_execute(opt: Namespace, selections: Namespace):
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
external_models=potential_models_to_install,
scan_at_startup=scan_at_startup,
convert_to_diffusers=convert_to_diffusers,
precision="float32"
if opt.full_precision
else choose_precision(torch.device(choose_torch_device())),
@@ -456,8 +480,8 @@ def main():
Globals.root = os.path.expanduser(get_root(opt.root) or "")
if not global_config_dir().exists():
logger.info(
"Your InvokeAI root directory is not set up. Calling invokeai-configure."
print(
">> Your InvokeAI root directory is not set up. Calling invokeai-configure."
)
from invokeai.frontend.install import invokeai_configure
@@ -467,18 +491,18 @@ def main():
try:
select_and_download_models(opt)
except AssertionError as e:
logger.error(e)
print(str(e))
sys.exit(-1)
except KeyboardInterrupt:
logger.info("Goodbye! Come back soon.")
print("\nGoodbye! Come back soon.")
except widget.NotEnoughSpaceForWidget as e:
if str(e).startswith("Height of 1 allocated"):
logger.error(
"Insufficient vertical space for the interface. Please make your window taller and try again"
print(
"** Insufficient vertical space for the interface. Please make your window taller and try again"
)
elif str(e).startswith("addwstr"):
logger.error(
"Insufficient horizontal space for the interface. Please make your window wider and try again."
print(
"** Insufficient horizontal space for the interface. Please make your window wider and try again."
)

Some files were not shown because too many files have changed in this diff.