mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 09:18:00 -05:00
Compare commits
215 Commits
v3.6.1
...
refactor/m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ffe672bc1 | ||
|
|
ed2d9ae0d9 | ||
|
|
09e7d35b55 | ||
|
|
9758082dc5 | ||
|
|
5f4ce0b118 | ||
|
|
8ac4b9b32c | ||
|
|
ec77599e79 | ||
|
|
2c1b8c0bc2 | ||
|
|
d4525e1282 | ||
|
|
b0d67ea2cc | ||
|
|
bd802d1e7a | ||
|
|
433eb73d8e | ||
|
|
b71f53ba86 | ||
|
|
68064c133a | ||
|
|
411ec1ed64 | ||
|
|
40a81c358d | ||
|
|
c45a43519a | ||
|
|
763816ca0c | ||
|
|
83a7c9059f | ||
|
|
c5f069a255 | ||
|
|
1d724bca4a | ||
|
|
a6508d1391 | ||
|
|
1eeca48529 | ||
|
|
cd169ee082 | ||
|
|
66b106f107 | ||
|
|
b10d745dae | ||
|
|
d20f98fb4f | ||
|
|
c9c150f850 | ||
|
|
a60e2b7c77 | ||
|
|
79d028ecbd | ||
|
|
da6e5b2ba1 | ||
|
|
c65d497cbc | ||
|
|
a68d8fe203 | ||
|
|
5de2288cfa | ||
|
|
2ce70b4457 | ||
|
|
6c5f743e2b | ||
|
|
bb242c4e1e | ||
|
|
c9e246ed1b | ||
|
|
2175fe3823 | ||
|
|
f64fc2c8b7 | ||
|
|
3d1b5c57ea | ||
|
|
31b9538976 | ||
|
|
97c1545cca | ||
|
|
6a8a3b50bc | ||
|
|
5a816818dc | ||
|
|
1cb866d1fc | ||
|
|
29bcc4b595 | ||
|
|
ca2bb6f0cc | ||
|
|
1c8fc908b2 | ||
|
|
d397beaa47 | ||
|
|
60eea09629 | ||
|
|
5b7b1122cb | ||
|
|
dfc8d1bb10 | ||
|
|
f9fa62164e | ||
|
|
d47905d2fb | ||
|
|
03b1cde97d | ||
|
|
7162ff04df | ||
|
|
32b1e974ca | ||
|
|
82c3c7fc38 | ||
|
|
3dcbb79ef7 | ||
|
|
3b41104427 | ||
|
|
35bf7ee66d | ||
|
|
430e17a5d2 | ||
|
|
400d66fa5d | ||
|
|
800c481515 | ||
|
|
79ae9c4e64 | ||
|
|
0dc6cb0535 | ||
|
|
810fc19e43 | ||
|
|
e0e106367d | ||
|
|
531d2c8fd7 | ||
|
|
37675ee4f5 | ||
|
|
26f721d0ec | ||
|
|
14472dc09d | ||
|
|
e8095b73ae | ||
|
|
c979cf5ecc | ||
|
|
1b4dbd283e | ||
|
|
fb50a221f8 | ||
|
|
52e07db06b | ||
|
|
6643b5cec4 | ||
|
|
e8bf9ea058 | ||
|
|
420f6050a6 | ||
|
|
ce3d37e829 | ||
|
|
8a61063e84 | ||
|
|
87ff96553a | ||
|
|
209bf105bc | ||
|
|
804dbeba34 | ||
|
|
067cd4dc2e | ||
|
|
feb4a3f242 | ||
|
|
4a886c0a4a | ||
|
|
8e500283b6 | ||
|
|
9804cb0e67 | ||
|
|
3205371654 | ||
|
|
d713620d9e | ||
|
|
c1300fa8b1 | ||
|
|
0976ddba23 | ||
|
|
3ebb806410 | ||
|
|
9f274c79dc | ||
|
|
88c08bbfc7 | ||
|
|
c2af124622 | ||
|
|
f972fe9836 | ||
|
|
dcfc883ab3 | ||
|
|
1d2bd6b8f7 | ||
|
|
4c5aedbcba | ||
|
|
f2777f5096 | ||
|
|
d3320dc4ee | ||
|
|
72db2ee352 | ||
|
|
60c3a4ad5e | ||
|
|
cf7a7928af | ||
|
|
1057314508 | ||
|
|
73a077956b | ||
|
|
5e1e50bd47 | ||
|
|
413fe566b8 | ||
|
|
c9b5f06c42 | ||
|
|
b53e432b0f | ||
|
|
88164447e9 | ||
|
|
1ac85fd049 | ||
|
|
ee6fc4ab1d | ||
|
|
9f793bdae8 | ||
|
|
a0eecaecd0 | ||
|
|
d532073f5b | ||
|
|
198e8c9d55 | ||
|
|
30367deeca | ||
|
|
e73298aea2 | ||
|
|
59279851a3 | ||
|
|
2965357d99 | ||
|
|
8bd32ee142 | ||
|
|
a4f892dcfb | ||
|
|
e675983e20 | ||
|
|
e9558f97c4 | ||
|
|
a1a611f8cb | ||
|
|
182dc859a0 | ||
|
|
c0240a8568 | ||
|
|
02bcff29e8 | ||
|
|
d4ed64df7d | ||
|
|
701f14c1e3 | ||
|
|
45bf2c7da6 | ||
|
|
a380d1f3b2 | ||
|
|
67ada70a26 | ||
|
|
06bcc07f65 | ||
|
|
4410ecf62c | ||
|
|
9f6b9d4d23 | ||
|
|
b24e8dd829 | ||
|
|
25291a2e01 | ||
|
|
332f3930a5 | ||
|
|
ed466a99ec | ||
|
|
f68f8898c0 | ||
|
|
a0996b1c0a | ||
|
|
522ff4a042 | ||
|
|
a769f93be0 | ||
|
|
2c5ef92979 | ||
|
|
5d773dc94c | ||
|
|
088e3420e6 | ||
|
|
14efc95707 | ||
|
|
f48a2c5fd2 | ||
|
|
74ae4d7774 | ||
|
|
191203ea0c | ||
|
|
6aceae5c22 | ||
|
|
8c6b3efd39 | ||
|
|
4602efd598 | ||
|
|
f70c0936ca | ||
|
|
0d4de4cc63 | ||
|
|
1e855f8290 | ||
|
|
bb2787584d | ||
|
|
a04981b418 | ||
|
|
d7f16b7c87 | ||
|
|
4477e04d59 | ||
|
|
30e11b4b42 | ||
|
|
b93695b78f | ||
|
|
b01311813b | ||
|
|
5ae80fab87 | ||
|
|
c4291f2136 | ||
|
|
287d3c2b04 | ||
|
|
7fde19730e | ||
|
|
13575642d8 | ||
|
|
3f5370b284 | ||
|
|
d048eb5b20 | ||
|
|
dd7031a472 | ||
|
|
4160d5ef26 | ||
|
|
51bdf2fd19 | ||
|
|
6a44697911 | ||
|
|
7a1d0ec228 | ||
|
|
b5928fd411 | ||
|
|
2f345d1976 | ||
|
|
f5d0721fa8 | ||
|
|
c3b36cb61d | ||
|
|
189c430e46 | ||
|
|
b922ee566a | ||
|
|
89da69f647 | ||
|
|
138caa34de | ||
|
|
26c3378ede | ||
|
|
aa134a2db8 | ||
|
|
d0391cb430 | ||
|
|
c955ea9de0 | ||
|
|
fc29a5d439 | ||
|
|
7e9942dbab | ||
|
|
c003967eaa | ||
|
|
b28fcc6be5 | ||
|
|
418cdbabb7 | ||
|
|
18e61e92d9 | ||
|
|
de20711637 | ||
|
|
55e91b97be | ||
|
|
f79bbd2d6e | ||
|
|
e1c2c3905d | ||
|
|
03ac93bfc7 | ||
|
|
89da976949 | ||
|
|
57dafd294d | ||
|
|
e611baa4b4 | ||
|
|
fc448d5b6d | ||
|
|
e59954f956 | ||
|
|
e160cbb1e9 | ||
|
|
68232e642f | ||
|
|
b94f6a4a29 | ||
|
|
4caf63d53d | ||
|
|
6b8a6e12bc | ||
|
|
6057229ceb |
98
.github/ISSUE_TEMPLATE/BUG_REPORT.yml
vendored
98
.github/ISSUE_TEMPLATE/BUG_REPORT.yml
vendored
@@ -6,10 +6,6 @@ title: '[bug]: '
|
||||
|
||||
labels: ['bug']
|
||||
|
||||
# assignees:
|
||||
# - moderator_bot
|
||||
# - lstein
|
||||
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
@@ -18,10 +14,9 @@ body:
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Is there an existing issue for this?
|
||||
label: Is there an existing issue for this problem?
|
||||
description: |
|
||||
Please use the [search function](https://github.com/invoke-ai/InvokeAI/issues?q=is%3Aissue+is%3Aopen+label%3Abug)
|
||||
irst to see if an issue already exists for the bug you encountered.
|
||||
Please [search](https://github.com/invoke-ai/InvokeAI/issues) first to see if an issue already exists for the problem.
|
||||
options:
|
||||
- label: I have searched the existing issues
|
||||
required: true
|
||||
@@ -33,80 +28,119 @@ body:
|
||||
- type: dropdown
|
||||
id: os_dropdown
|
||||
attributes:
|
||||
label: OS
|
||||
description: Which operating System did you use when the bug occured
|
||||
label: Operating system
|
||||
description: Your computer's operating system.
|
||||
multiple: false
|
||||
options:
|
||||
- 'Linux'
|
||||
- 'Windows'
|
||||
- 'macOS'
|
||||
- 'other'
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: gpu_dropdown
|
||||
attributes:
|
||||
label: GPU
|
||||
description: Which kind of Graphic-Adapter is your System using
|
||||
label: GPU vendor
|
||||
description: Your GPU's vendor.
|
||||
multiple: false
|
||||
options:
|
||||
- 'cuda'
|
||||
- 'amd'
|
||||
- 'mps'
|
||||
- 'cpu'
|
||||
- 'Nvidia (CUDA)'
|
||||
- 'AMD (ROCm)'
|
||||
- 'Apple Silicon (MPS)'
|
||||
- 'None (CPU)'
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: gpu_model
|
||||
attributes:
|
||||
label: GPU model
|
||||
description: Your GPU's model. If on Apple Silicon, this is your Mac's chip. Leave blank if on CPU.
|
||||
placeholder: ex. RTX 2080 Ti, Mac M1 Pro
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: input
|
||||
id: vram
|
||||
attributes:
|
||||
label: VRAM
|
||||
description: Size of the VRAM if known
|
||||
label: GPU VRAM
|
||||
description: Your GPU's VRAM. If on Apple Silicon, this is your Mac's unified memory. Leave blank if on CPU.
|
||||
placeholder: 8GB
|
||||
validations:
|
||||
required: false
|
||||
|
||||
|
||||
- type: input
|
||||
id: version-number
|
||||
attributes:
|
||||
label: What version did you experience this issue on?
|
||||
label: Version number
|
||||
description: |
|
||||
Please share the version of Invoke AI that you experienced the issue on. If this is not the latest version, please update first to confirm the issue still exists. If you are testing main, please include the commit hash instead.
|
||||
placeholder: X.X.X
|
||||
The version of Invoke you have installed. If it is not the latest version, please update and try again to confirm the issue still exists. If you are testing main, please include the commit hash instead.
|
||||
placeholder: ex. 3.6.1
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: input
|
||||
id: browser-version
|
||||
attributes:
|
||||
label: Browser
|
||||
description: Your web browser and version.
|
||||
placeholder: ex. Firefox 123.0b3
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: python-deps
|
||||
attributes:
|
||||
label: Python dependencies
|
||||
description: |
|
||||
If the problem occurred during image generation, click the gear icon at the bottom left corner, click "About", click the copy button and then paste here.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: textarea
|
||||
id: what-happened
|
||||
attributes:
|
||||
label: What happened?
|
||||
label: What happened
|
||||
description: |
|
||||
Briefly describe what happened, what you expected to happen and how to reproduce this bug.
|
||||
placeholder: When using the webinterface and right-clicking on button X instead of the popup-menu there error Y appears
|
||||
Describe what happened. Include any relevant error messages, stack traces and screenshots here.
|
||||
placeholder: I clicked button X and then Y happened.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: what-you-expected
|
||||
attributes:
|
||||
label: Screenshots
|
||||
description: If applicable, add screenshots to help explain your problem
|
||||
placeholder: this is what the result looked like <screenshot>
|
||||
label: What you expected to happen
|
||||
description: Describe what you expected to happen.
|
||||
placeholder: I expected Z to happen.
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: how-to-repro
|
||||
attributes:
|
||||
label: How to reproduce the problem
|
||||
description: List steps to reproduce the problem.
|
||||
placeholder: Start the app, generate an image with these settings, then click button X.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: textarea
|
||||
id: additional-context
|
||||
attributes:
|
||||
label: Additional context
|
||||
description: Add any other context about the problem here
|
||||
description: Any other context that might help us to understand the problem.
|
||||
placeholder: Only happens when there is full moon and Friday the 13th on Christmas Eve 🎅🏻
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: input
|
||||
id: contact
|
||||
id: discord-username
|
||||
attributes:
|
||||
label: Contact Details
|
||||
description: __OPTIONAL__ How can we get in touch with you if we need more info (besides this issue)?
|
||||
placeholder: ex. email@example.com, discordname, twitter, ...
|
||||
label: Discord username
|
||||
description: If you are on the Invoke discord and would prefer to be contacted there, please provide your username.
|
||||
placeholder: supercoolusername123
|
||||
validations:
|
||||
required: false
|
||||
|
||||
59
.github/pr_labels.yml
vendored
Normal file
59
.github/pr_labels.yml
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
Root:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: '*'
|
||||
|
||||
PythonDeps:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'pyproject.toml'
|
||||
|
||||
Python:
|
||||
- changed-files:
|
||||
- all-globs-to-any-file:
|
||||
- 'invokeai/**'
|
||||
- '!invokeai/frontend/web/**'
|
||||
|
||||
PythonTests:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'tests/**'
|
||||
|
||||
CICD:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: .github/**
|
||||
|
||||
Docker:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: docker/**
|
||||
|
||||
Installer:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: installer/**
|
||||
|
||||
Documentation:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: docs/**
|
||||
|
||||
Invocations:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/app/invocations/**'
|
||||
|
||||
Backend:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/backend/**'
|
||||
|
||||
Api:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/app/api/**'
|
||||
|
||||
Services:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/app/services/**'
|
||||
|
||||
FrontendDeps:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file:
|
||||
- '**/*/package.json'
|
||||
- '**/*/pnpm-lock.yaml'
|
||||
|
||||
Frontend:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: 'invokeai/frontend/web/**'
|
||||
16
.github/workflows/label-pr.yml
vendored
Normal file
16
.github/workflows/label-pr.yml
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
name: "Pull Request Labeler"
|
||||
on:
|
||||
- pull_request_target
|
||||
|
||||
jobs:
|
||||
labeler:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- uses: actions/labeler@v5
|
||||
with:
|
||||
configuration-path: .github/pr_labels.yml
|
||||
@@ -169,7 +169,7 @@ the command `npm install -g pnpm` if needed)
|
||||
_For Linux with an AMD GPU:_
|
||||
|
||||
```sh
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.6
|
||||
```
|
||||
|
||||
_For non-GPU systems:_
|
||||
|
||||
@@ -28,7 +28,7 @@ model. These are the:
|
||||
Hugging Face, as well as discriminating among model versions in
|
||||
Civitai, but can be used for arbitrary content.
|
||||
|
||||
* _ModelLoadServiceBase_ (**CURRENTLY UNDER DEVELOPMENT - NOT IMPLEMENTED**)
|
||||
* _ModelLoadServiceBase_
|
||||
Responsible for loading a model from disk
|
||||
into RAM and VRAM and getting it ready for inference.
|
||||
|
||||
@@ -41,10 +41,10 @@ The four main services can be found in
|
||||
* `invokeai/app/services/model_records/`
|
||||
* `invokeai/app/services/model_install/`
|
||||
* `invokeai/app/services/downloads/`
|
||||
* `invokeai/app/services/model_loader/` (**under development**)
|
||||
* `invokeai/app/services/model_load/`
|
||||
|
||||
Code related to the FastAPI web API can be found in
|
||||
`invokeai/app/api/routers/model_records.py`.
|
||||
`invokeai/app/api/routers/model_manager_v2.py`.
|
||||
|
||||
***
|
||||
|
||||
@@ -84,10 +84,10 @@ diffusers model. When this happens, `original_hash` is unchanged, but
|
||||
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
|
||||
are defined in `invokeai.backend.model_manager.config`. They are also
|
||||
imported by, and can be reexported from,
|
||||
`invokeai.app.services.model_record_service`:
|
||||
`invokeai.app.services.model_manager.model_records`:
|
||||
|
||||
```
|
||||
from invokeai.app.services.model_record_service import ModelType, ModelFormat, BaseModelType
|
||||
from invokeai.app.services.model_records import ModelType, ModelFormat, BaseModelType
|
||||
```
|
||||
|
||||
The `path` field can be absolute or relative. If relative, it is taken
|
||||
@@ -123,7 +123,7 @@ taken to be the `models_dir` directory.
|
||||
|
||||
`variant` is an enumerated string class with values `normal`,
|
||||
`inpaint` and `depth`. If needed, it can be imported if needed from
|
||||
either `invokeai.app.services.model_record_service` or
|
||||
either `invokeai.app.services.model_records` or
|
||||
`invokeai.backend.model_manager.config`.
|
||||
|
||||
### ONNXSD2Config
|
||||
@@ -134,7 +134,7 @@ either `invokeai.app.services.model_record_service` or
|
||||
| `upcast_attention` | bool | Model requires its attention module to be upcast |
|
||||
|
||||
The `SchedulerPredictionType` enum can be imported from either
|
||||
`invokeai.app.services.model_record_service` or
|
||||
`invokeai.app.services.model_records` or
|
||||
`invokeai.backend.model_manager.config`.
|
||||
|
||||
### Other config classes
|
||||
@@ -157,15 +157,6 @@ indicates that the model is compatible with any of the base
|
||||
models. This works OK for some models, such as the IP Adapter image
|
||||
encoders, but is an all-or-nothing proposition.
|
||||
|
||||
Another issue is that the config class hierarchy is paralleled to some
|
||||
extent by a `ModelBase` class hierarchy defined in
|
||||
`invokeai.backend.model_manager.models.base` and its subclasses. These
|
||||
are classes representing the models after they are loaded into RAM and
|
||||
include runtime information such as load status and bytes used. Some
|
||||
of the fields, including `name`, `model_type` and `base_model`, are
|
||||
shared between `ModelConfigBase` and `ModelBase`, and this is a
|
||||
potential source of confusion.
|
||||
|
||||
## Reading and Writing Model Configuration Records
|
||||
|
||||
The `ModelRecordService` provides the ability to retrieve model
|
||||
@@ -177,11 +168,11 @@ initialization and can be retrieved within an invocation from the
|
||||
`InvocationContext` object:
|
||||
|
||||
```
|
||||
store = context.services.model_record_store
|
||||
store = context.services.model_manager.store
|
||||
```
|
||||
|
||||
or from elsewhere in the code by accessing
|
||||
`ApiDependencies.invoker.services.model_record_store`.
|
||||
`ApiDependencies.invoker.services.model_manager.store`.
|
||||
|
||||
### Creating a `ModelRecordService`
|
||||
|
||||
@@ -190,7 +181,7 @@ you can directly create either a `ModelRecordServiceSQL` or a
|
||||
`ModelRecordServiceFile` object:
|
||||
|
||||
```
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceSQL, ModelRecordServiceFile
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL, ModelRecordServiceFile
|
||||
|
||||
store = ModelRecordServiceSQL.from_connection(connection, lock)
|
||||
store = ModelRecordServiceSQL.from_db_file('/path/to/sqlite_database.db')
|
||||
@@ -252,7 +243,7 @@ So a typical startup pattern would be:
|
||||
```
|
||||
import sqlite3
|
||||
from invokeai.app.services.thread import lock
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
@@ -260,19 +251,6 @@ db_conn = sqlite3.connect(config.db_path.as_posix(), check_same_thread=False)
|
||||
store = ModelRecordServiceBase.open(config, db_conn, lock)
|
||||
```
|
||||
|
||||
_A note on simultaneous access to `invokeai.db`_: The current InvokeAI
|
||||
service architecture for the image and graph databases is careful to
|
||||
use a shared sqlite3 connection and a thread lock to ensure that two
|
||||
threads don't attempt to access the database simultaneously. However,
|
||||
the default `sqlite3` library used by Python reports using
|
||||
**Serialized** mode, which allows multiple threads to access the
|
||||
database simultaneously using multiple database connections (see
|
||||
https://www.sqlite.org/threadsafe.html and
|
||||
https://ricardoanderegg.com/posts/python-sqlite-thread-safety/). Therefore
|
||||
it should be safe to allow the record service to open its own SQLite
|
||||
database connection. Opening a model record service should then be as
|
||||
simple as `ModelRecordServiceBase.open(config)`.
|
||||
|
||||
### Fetching a Model's Configuration from `ModelRecordServiceBase`
|
||||
|
||||
Configurations can be retrieved in several ways.
|
||||
@@ -468,6 +446,44 @@ required parameters:
|
||||
|
||||
Once initialized, the installer will provide the following methods:
|
||||
|
||||
#### install_job = installer.heuristic_import(source, [config], [access_token])
|
||||
|
||||
This is a simplified interface to the installer which takes a source
|
||||
string, an optional model configuration dictionary and an optional
|
||||
access token.
|
||||
|
||||
The `source` is a string that can be any of these forms
|
||||
|
||||
1. A path on the local filesystem (`C:\\users\\fred\\model.safetensors`)
|
||||
2. A Url pointing to a single downloadable model file (`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
|
||||
3. A HuggingFace repo_id with any of the following formats:
|
||||
- `model/name` -- entire model
|
||||
- `model/name:fp32` -- entire model, using the fp32 variant
|
||||
- `model/name:fp16:vae` -- vae submodel, using the fp16 variant
|
||||
- `model/name::vae` -- vae submodel, using default precision
|
||||
- `model/name:fp16:path/to/model.safetensors` -- an individual model file, fp16 variant
|
||||
- `model/name::path/to/model.safetensors` -- an individual model file, default variant
|
||||
|
||||
Note that by specifying a relative path to the top of the HuggingFace
|
||||
repo, you can download and install arbitrary models files.
|
||||
|
||||
The variant, if not provided, will be automatically filled in with
|
||||
`fp32` if the user has requested full precision, and `fp16`
|
||||
otherwise. If a variant that does not exist is requested, then the
|
||||
method will install whatever HuggingFace returns as its default
|
||||
revision.
|
||||
|
||||
`config` is an optional dict of values that will override the
|
||||
autoprobed values for model type, base, scheduler prediction type, and
|
||||
so forth. See [Model configuration and
|
||||
probing](#Model-configuration-and-probing) for details.
|
||||
|
||||
`access_token` is an optional access token for accessing resources
|
||||
that need authentication.
|
||||
|
||||
The method will return a `ModelInstallJob`. This object is discussed
|
||||
at length in the following section.
|
||||
|
||||
#### install_job = installer.import_model()
|
||||
|
||||
The `import_model()` method is the core of the installer. The
|
||||
@@ -486,9 +502,10 @@ source2 = LocalModelSource(path='/opt/models/sushi_diffusers') # a local dif
|
||||
source3 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5') # a repo_id
|
||||
source4 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='vae') # a subfolder within a repo_id
|
||||
source5 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', variant='fp16') # a named variant of a HF model
|
||||
source6 = HFModelSource(repo_id='runwayml/stable-diffusion-v1-5', subfolder='OrangeMix/OrangeMix1.ckpt') # path to an individual model file
|
||||
|
||||
source6 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL
|
||||
source7 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token
|
||||
source7 = URLModelSource(url='https://civitai.com/api/download/models/63006') # model located at a URL
|
||||
source8 = URLModelSource(url='https://civitai.com/api/download/models/63006', access_token='letmein') # with an access token
|
||||
|
||||
for source in [source1, source2, source3, source4, source5, source6, source7]:
|
||||
install_job = installer.install_model(source)
|
||||
@@ -544,7 +561,6 @@ can be passed to `import_model()`.
|
||||
attributes returned by the model prober. See the section below for
|
||||
details.
|
||||
|
||||
|
||||
#### LocalModelSource
|
||||
|
||||
This is used for a model that is located on a locally-accessible Posix
|
||||
@@ -737,7 +753,7 @@ and `cancelled`, as well as `in_terminal_state`. The last will return
|
||||
True if the job is in the complete, errored or cancelled states.
|
||||
|
||||
|
||||
#### Model confguration and probing
|
||||
#### Model configuration and probing
|
||||
|
||||
The install service uses the `invokeai.backend.model_manager.probe`
|
||||
module during import to determine the model's type, base type, and
|
||||
@@ -776,6 +792,14 @@ returns a list of completed jobs. The optional `timeout` argument will
|
||||
return from the call if jobs aren't completed in the specified
|
||||
time. An argument of 0 (the default) will block indefinitely.
|
||||
|
||||
#### jobs = installer.wait_for_job(job, [timeout])
|
||||
|
||||
Like `wait_for_installs()`, but block until a specific job has
|
||||
completed or errored, and then return the job. The optional `timeout`
|
||||
argument will return from the call if the job doesn't complete in the
|
||||
specified time. An argument of 0 (the default) will block
|
||||
indefinitely.
|
||||
|
||||
#### jobs = installer.list_jobs()
|
||||
|
||||
Return a list of all active and complete `ModelInstallJobs`.
|
||||
@@ -838,6 +862,31 @@ This method is similar to `unregister()`, but also unconditionally
|
||||
deletes the corresponding model weights file(s), regardless of whether
|
||||
they are inside or outside the InvokeAI models hierarchy.
|
||||
|
||||
|
||||
#### path = installer.download_and_cache(remote_source, [access_token], [timeout])
|
||||
|
||||
This utility routine will download the model file located at source,
|
||||
cache it, and return the path to the cached file. It does not attempt
|
||||
to determine the model type, probe its configuration values, or
|
||||
register it with the models database.
|
||||
|
||||
You may provide an access token if the remote source requires
|
||||
authorization. The call will block indefinitely until the file is
|
||||
completely downloaded, cancelled or raises an error of some sort. If
|
||||
you provide a timeout (in seconds), the call will raise a
|
||||
`TimeoutError` exception if the download hasn't completed in the
|
||||
specified period.
|
||||
|
||||
You may use this mechanism to request any type of file, not just a
|
||||
model. The file will be stored in a subdirectory of
|
||||
`INVOKEAI_ROOT/models/.cache`. If the requested file is found in the
|
||||
cache, its path will be returned without redownloading it.
|
||||
|
||||
Be aware that the models cache is cleared of infrequently-used files
|
||||
and directories at regular intervals when the size of the cache
|
||||
exceeds the value specified in Invoke's `convert_cache` configuration
|
||||
variable.
|
||||
|
||||
#### List[str]=installer.scan_directory(scan_dir: Path, install: bool)
|
||||
|
||||
This method will recursively scan the directory indicated in
|
||||
@@ -1128,7 +1177,7 @@ job = queue.create_download_job(
|
||||
event_handlers=[my_handler1, my_handler2], # if desired
|
||||
start=True,
|
||||
)
|
||||
```
|
||||
```
|
||||
|
||||
The `filename` argument forces the downloader to use the specified
|
||||
name for the file rather than the name provided by the remote source,
|
||||
@@ -1171,6 +1220,13 @@ queue or was not created by this queue.
|
||||
This method will block until all the active jobs in the queue have
|
||||
reached a terminal state (completed, errored or cancelled).
|
||||
|
||||
#### queue.wait_for_job(job, [timeout])
|
||||
|
||||
This method will block until the indicated job has reached a terminal
|
||||
state (completed, errored or cancelled). If the optional timeout is
|
||||
provided, the call will block for at most timeout seconds, and raise a
|
||||
TimeoutError otherwise.
|
||||
|
||||
#### jobs = queue.list_jobs()
|
||||
|
||||
This will return a list of all jobs, including ones that have not yet
|
||||
@@ -1449,9 +1505,9 @@ set of keys to the corresponding model config objects.
|
||||
Find all model metadata records that have the given author and return
|
||||
a set of keys to the corresponding model config objects.
|
||||
|
||||
# The remainder of this documentation is provisional, pending implementation of the Load service
|
||||
***
|
||||
|
||||
## Let's get loaded, the lowdown on ModelLoadService
|
||||
## The Lowdown on the ModelLoadService
|
||||
|
||||
The `ModelLoadService` is responsible for loading a named model into
|
||||
memory so that it can be used for inference. Despite the fact that it
|
||||
@@ -1465,7 +1521,7 @@ create alternative instances if you wish.
|
||||
### Creating a ModelLoadService object
|
||||
|
||||
The class is defined in
|
||||
`invokeai.app.services.model_loader_service`. It is initialized with
|
||||
`invokeai.app.services.model_load`. It is initialized with
|
||||
an InvokeAIAppConfig object, from which it gets configuration
|
||||
information such as the user's desired GPU and precision, and with a
|
||||
previously-created `ModelRecordServiceBase` object, from which it
|
||||
@@ -1475,26 +1531,29 @@ Here is a typical initialization pattern:
|
||||
|
||||
```
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_record_service import ModelRecordServiceBase
|
||||
from invokeai.app.services.model_loader_service import ModelLoadService
|
||||
from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegistry
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
store = ModelRecordServiceBase.open(config)
|
||||
loader = ModelLoadService(config, store)
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
|
||||
)
|
||||
convert_cache = ModelConvertCache(
|
||||
cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size
|
||||
)
|
||||
loader = ModelLoadService(
|
||||
app_config=config,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
registry=ModelLoaderRegistry
|
||||
)
|
||||
```
|
||||
|
||||
Note that we are relying on the contents of the application
|
||||
configuration to choose the implementation of
|
||||
`ModelRecordServiceBase`.
|
||||
### load_model(model_config, [submodel_type], [context]) -> LoadedModel
|
||||
|
||||
### get_model(key, [submodel_type], [context]) -> ModelInfo:
|
||||
|
||||
*** TO DO: change to get_model(key, context=None, **kwargs)
|
||||
|
||||
The `get_model()` method, like its similarly-named cousin in
|
||||
`ModelRecordService`, receives the unique key that identifies the
|
||||
The `load_model()` method takes an `AnyModelConfig` returned by
|
||||
`ModelRecordService.get_model()` and returns the corresponding loaded
|
||||
model. It loads the model into memory, gets the model ready for use,
|
||||
and returns a `ModelInfo` object.
|
||||
and returns a `LoadedModel` object.
|
||||
|
||||
The optional second argument, `subtype` is a `SubModelType` string
|
||||
enum, such as "vae". It is mandatory when used with a main model, and
|
||||
@@ -1504,46 +1563,45 @@ The optional third argument, `context` can be provided by
|
||||
an invocation to trigger model load event reporting. See below for
|
||||
details.
|
||||
|
||||
The returned `ModelInfo` object shares some fields in common with
|
||||
`ModelConfigBase`, but is otherwise a completely different beast:
|
||||
The returned `LoadedModel` object contains a copy of the configuration
|
||||
record returned by the model record `get_model()` method, as well as
|
||||
the in-memory loaded model:
|
||||
|
||||
| **Field Name** | **Type** | **Description** |
|
||||
|
||||
| **Attribute Name** | **Type** | **Description** |
|
||||
|----------------|-----------------|------------------|
|
||||
| `key` | str | The model key derived from the ModelRecordService database |
|
||||
| `name` | str | Name of this model |
|
||||
| `base_model` | BaseModelType | Base model for this model |
|
||||
| `type` | ModelType or SubModelType | Either the model type (non-main) or the submodel type (main models)|
|
||||
| `location` | Path or str | Location of the model on the filesystem |
|
||||
| `precision` | torch.dtype | The torch.precision to use for inference |
|
||||
| `context` | ModelCache.ModelLocker | A context class used to lock the model in VRAM while in use |
|
||||
| `config` | AnyModelConfig | A copy of the model's configuration record for retrieving base type, etc. |
|
||||
| `model` | AnyModel | The instantiated model (details below) |
|
||||
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
|
||||
|
||||
The types for `ModelInfo` and `SubModelType` can be imported from
|
||||
`invokeai.app.services.model_loader_service`.
|
||||
Because the loader can return multiple model types, it is typed to
|
||||
return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
|
||||
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
|
||||
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
|
||||
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
|
||||
models. The others are obvious.
|
||||
|
||||
To use the model, you use the `ModelInfo` as a context manager using
|
||||
the following pattern:
|
||||
|
||||
`LoadedModel` acts as a context manager. The context loads the model
|
||||
into the execution device (e.g. VRAM on CUDA systems), locks the model
|
||||
in the execution device for the duration of the context, and returns
|
||||
the model. Use it like this:
|
||||
|
||||
```
|
||||
model_info = loader.get_model('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||
with model_info as vae:
|
||||
image = vae.decode(latents)[0]
|
||||
```
|
||||
|
||||
The `vae` model will stay locked in the GPU during the period of time
|
||||
it is in the context manager's scope.
|
||||
`get_model_by_key()` may raise any of the following exceptions:
|
||||
|
||||
`get_model()` may raise any of the following exceptions:
|
||||
|
||||
- `UnknownModelException` -- key not in database
|
||||
- `ModelNotFoundException` -- key in database but model not found at path
|
||||
- `InvalidModelException` -- the model is guilty of a variety of sins
|
||||
- `UnknownModelException` -- key not in database
|
||||
- `ModelNotFoundException` -- key in database but model not found at path
|
||||
- `NotImplementedException` -- the loader doesn't know how to load this type of model
|
||||
|
||||
** TO DO: ** Resolve discrepancy between ModelInfo.location and
|
||||
ModelConfig.path.
|
||||
|
||||
### Emitting model loading events
|
||||
|
||||
When the `context` argument is passed to `get_model()`, it will
|
||||
When the `context` argument is passed to `load_model_*()`, it will
|
||||
retrieve the invocation event bus from the passed `InvocationContext`
|
||||
object to emit events on the invocation bus. The two events are
|
||||
"model_load_started" and "model_load_completed". Both carry the
|
||||
@@ -1556,10 +1614,174 @@ payload=dict(
|
||||
queue_batch_id=queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_key=model_key,
|
||||
submodel=submodel,
|
||||
submodel_type=submodel,
|
||||
hash=model_info.hash,
|
||||
location=str(model_info.location),
|
||||
precision=str(model_info.precision),
|
||||
)
|
||||
```
|
||||
|
||||
### Adding Model Loaders
|
||||
|
||||
Model loaders are small classes that inherit from the `ModelLoader`
|
||||
base class. They typically implement one method `_load_model()` whose
|
||||
signature is:
|
||||
|
||||
```
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
```
|
||||
|
||||
`_load_model()` will be passed the path to the model on disk, an
|
||||
optional repository variant (used by the diffusers loaders to select,
|
||||
e.g. the `fp16` variant, and an optional submodel_type for main and
|
||||
onnx models.
|
||||
|
||||
To install a new loader, place it in
|
||||
`invokeai/backend/model_manager/load/model_loaders`. Inherit from
|
||||
`ModelLoader` and use the `@ModelLoaderRegistry.register()` decorator to
|
||||
indicate what type of models the loader can handle.
|
||||
|
||||
Here is a complete example from `generic_diffusers.py`, which is able
|
||||
to load several different diffusers types:
|
||||
|
||||
```
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
|
||||
class GenericDiffusersLoader(ModelLoader):
|
||||
"""Class to load simple diffusers models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
model_class = self._get_hf_load_class(model_path)
|
||||
if submodel_type is not None:
|
||||
raise Exception(f"There are no submodels in models of type {model_class}")
|
||||
variant = model_variant.value if model_variant else None
|
||||
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore
|
||||
return result
|
||||
```
|
||||
|
||||
Note that a loader can register itself to handle several different
|
||||
model types. An exception will be raised if more than one loader tries
|
||||
to register the same model type.
|
||||
|
||||
#### Conversion
|
||||
|
||||
Some models require conversion to diffusers format before they can be
|
||||
loaded. These loaders should override two additional methods:
|
||||
|
||||
```
|
||||
_needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool
|
||||
_convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
```
|
||||
|
||||
The first method accepts the model configuration, the path to where
|
||||
the unmodified model is currently installed, and a proposed
|
||||
destination for the converted model. This method returns True if the
|
||||
model needs to be converted. It typically does this by comparing the
|
||||
last modification time of the original model file to the modification
|
||||
time of the converted model. In some cases you will also want to check
|
||||
the modification date of the configuration record, in the event that
|
||||
the user has changed something like the scheduler prediction type that
|
||||
will require the model to be re-converted. See `controlnet.py` for an
|
||||
example of this logic.
|
||||
|
||||
The second method accepts the model configuration, the path to the
|
||||
original model on disk, and the desired output path for the converted
|
||||
model. It does whatever it needs to do to get the model into diffusers
|
||||
format, and returns the Path of the resulting model. (The path should
|
||||
ordinarily be the same as `output_path`.)
|
||||
|
||||
## The ModelManagerService object
|
||||
|
||||
For convenience, the API provides a `ModelManagerService` object which
|
||||
gives a single point of access to the major model manager
|
||||
services. This object is created at initialization time and can be
|
||||
found in the global `ApiDependencies.invoker.services.model_manager`
|
||||
object, or in `context.services.model_manager` from within an
|
||||
invocation.
|
||||
|
||||
In the examples below, we have retrieved the manager using:
|
||||
```
|
||||
mm = ApiDependencies.invoker.services.model_manager
|
||||
```
|
||||
|
||||
The following properties and methods will be available:
|
||||
|
||||
### mm.store
|
||||
|
||||
This retrieves the `ModelRecordService` associated with the
|
||||
manager. Example:
|
||||
|
||||
```
|
||||
configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5')
|
||||
```
|
||||
|
||||
### mm.install
|
||||
|
||||
This retrieves the `ModelInstallService` associated with the manager.
|
||||
Example:
|
||||
|
||||
```
|
||||
job = mm.install.heuristic_import(`https://civitai.com/models/58390/detail-tweaker-lora-lora`)
|
||||
```
|
||||
|
||||
### mm.load
|
||||
|
||||
This retrieves the `ModelLoaderService` associated with the manager. Example:
|
||||
|
||||
```
|
||||
configs = mm.store.get_model_by_attr(name='stable-diffusion-v1-5')
|
||||
assert len(configs) > 0
|
||||
|
||||
loaded_model = mm.load.load_model(configs[0])
|
||||
```
|
||||
|
||||
The model manager also offers a few convenience shortcuts for loading
|
||||
models:
|
||||
|
||||
### mm.load_model_by_config(model_config, [submodel], [context]) -> LoadedModel
|
||||
|
||||
Same as `mm.load.load_model()`.
|
||||
|
||||
### mm.load_model_by_attr(model_name, base_model, model_type, [submodel], [context]) -> LoadedModel
|
||||
|
||||
This accepts the combination of the model's name, type and base, which
|
||||
it passes to the model record config store for retrieval. If a unique
|
||||
model config is found, this method returns a `LoadedModel`. It can
|
||||
raise the following exceptions:
|
||||
|
||||
```
|
||||
UnknownModelException -- model with these attributes not known
|
||||
NotImplementedException -- the loader doesn't know how to load this type of model
|
||||
ValueError -- more than one model matches this combination of base/type/name
|
||||
```
|
||||
|
||||
### mm.load_model_by_key(key, [submodel], [context]) -> LoadedModel
|
||||
|
||||
This method takes a model key, looks it up using the
|
||||
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
|
||||
model configuration to `load_model_by_config()`. It may raise a
|
||||
`NotImplementedException`.
|
||||
|
||||
BIN
docs/img/favicon.ico
Normal file
BIN
docs/img/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.2 KiB |
@@ -117,6 +117,11 @@ Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
|
||||
|
||||
## :octicons-gift-24: InvokeAI Features
|
||||
|
||||
### Installation
|
||||
- [Automated Installer](installation/010_INSTALL_AUTOMATED.md)
|
||||
- [Manual Installation](installation/020_INSTALL_MANUAL.md)
|
||||
- [Docker Installation](installation/040_INSTALL_DOCKER.md)
|
||||
|
||||
### The InvokeAI Web Interface
|
||||
- [WebUI overview](features/WEB.md)
|
||||
- [WebUI hotkey reference guide](features/WEBUIHOTKEYS.md)
|
||||
|
||||
@@ -477,7 +477,7 @@ Then type the following commands:
|
||||
|
||||
=== "AMD System"
|
||||
```bash
|
||||
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||
pip install torch torchvision --force-reinstall --extra-index-url https://download.pytorch.org/whl/rocm5.6
|
||||
```
|
||||
|
||||
### Corrupted configuration file
|
||||
|
||||
@@ -154,7 +154,7 @@ manager, please follow these steps:
|
||||
=== "ROCm (AMD)"
|
||||
|
||||
```bash
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||
pip install InvokeAI --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.6
|
||||
```
|
||||
|
||||
=== "CPU (Intel Macs & non-GPU systems)"
|
||||
@@ -313,7 +313,7 @@ code for InvokeAI. For this to work, you will need to install the
|
||||
on your system, please see the [Git Installation
|
||||
Guide](https://github.com/git-guides/install-git)
|
||||
|
||||
You will also need to install the [frontend development toolchain](https://github.com/invoke-ai/InvokeAI/blob/main/docs/contributing/contribution_guides/contributingToFrontend.md).
|
||||
You will also need to install the [frontend development toolchain](https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/README.md).
|
||||
|
||||
If you have a "normal" installation, you should create a totally separate virtual environment for the git-based installation, else the two may interfere.
|
||||
|
||||
@@ -345,7 +345,7 @@ installation protocol (important!)
|
||||
|
||||
=== "ROCm (AMD)"
|
||||
```bash
|
||||
pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
|
||||
pip install -e . --use-pep517 --extra-index-url https://download.pytorch.org/whl/rocm5.6
|
||||
```
|
||||
|
||||
=== "CPU (Intel Macs & non-GPU systems)"
|
||||
@@ -361,7 +361,7 @@ installation protocol (important!)
|
||||
Be sure to pass `-e` (for an editable install) and don't forget the
|
||||
dot ("."). It is part of the command.
|
||||
|
||||
5. Install the [frontend toolchain](https://github.com/invoke-ai/InvokeAI/blob/main/docs/contributing/contribution_guides/contributingToFrontend.md) and do a production build of the UI as described.
|
||||
5. Install the [frontend toolchain](https://github.com/invoke-ai/InvokeAI/blob/main/invokeai/frontend/web/README.md) and do a production build of the UI as described.
|
||||
|
||||
6. You can now run `invokeai` and its related commands. The code will be
|
||||
read from the repository, so that you can edit the .py source files
|
||||
|
||||
@@ -134,7 +134,7 @@ recipes are available
|
||||
|
||||
When installing torch and torchvision manually with `pip`, remember to provide
|
||||
the argument `--extra-index-url
|
||||
https://download.pytorch.org/whl/rocm5.4.2` as described in the [Manual
|
||||
https://download.pytorch.org/whl/rocm5.6` as described in the [Manual
|
||||
Installation Guide](020_INSTALL_MANUAL.md).
|
||||
|
||||
This will be done automatically for you if you use the installer
|
||||
|
||||
@@ -18,13 +18,18 @@ either an Nvidia-based card (with CUDA support) or an AMD card (using the ROCm
|
||||
driver).
|
||||
|
||||
|
||||
## **[Automated Installer](010_INSTALL_AUTOMATED.md)**
|
||||
✅ This is the recommended installation method for first-time users.
|
||||
## **[Automated Installer (Recommended)](010_INSTALL_AUTOMATED.md)**
|
||||
✅ This is the recommended installation method for first-time users.
|
||||
|
||||
This is a script that will install all of InvokeAI's essential
|
||||
third party libraries and InvokeAI itself. It includes access to a
|
||||
"developer console" which will help us debug problems with you and
|
||||
give you to access experimental features.
|
||||
third party libraries and InvokeAI itself.
|
||||
|
||||
🖥️ **Download the latest installer .zip file here** : https://github.com/invoke-ai/InvokeAI/releases/latest
|
||||
|
||||
- *Look for the file labelled "InvokeAI-installer-v3.X.X.zip" at the bottom of the page*
|
||||
- If you experience issues, read through the full [installation instructions](010_INSTALL_AUTOMATED.md) to make sure you have met all of the installation requirements. If you need more help, join the [Discord](discord.gg/invoke-ai) or create an issue on [Github](https://github.com/invoke-ai/InvokeAI).
|
||||
|
||||
|
||||
|
||||
## **[Manual Installation](020_INSTALL_MANUAL.md)**
|
||||
This method is recommended for experienced users and developers.
|
||||
|
||||
@@ -14,6 +14,7 @@ To use a community workflow, download the the `.json` node graph file and load i
|
||||
|
||||
- Community Nodes
|
||||
+ [Adapters-Linked](#adapters-linked-nodes)
|
||||
+ [Autostereogram](#autostereogram-nodes)
|
||||
+ [Average Images](#average-images)
|
||||
+ [Clean Image Artifacts After Cut](#clean-image-artifacts-after-cut)
|
||||
+ [Close Color Mask](#close-color-mask)
|
||||
@@ -25,6 +26,7 @@ To use a community workflow, download the the `.json` node graph file and load i
|
||||
+ [GPT2RandomPromptMaker](#gpt2randompromptmaker)
|
||||
+ [Grid to Gif](#grid-to-gif)
|
||||
+ [Halftone](#halftone)
|
||||
+ [Hand Refiner with MeshGraphormer](#hand-refiner-with-meshgraphormer)
|
||||
+ [Image and Mask Composition Pack](#image-and-mask-composition-pack)
|
||||
+ [Image Dominant Color](#image-dominant-color)
|
||||
+ [Image to Character Art Image Nodes](#image-to-character-art-image-nodes)
|
||||
@@ -40,6 +42,7 @@ To use a community workflow, download the the `.json` node graph file and load i
|
||||
+ [Oobabooga](#oobabooga)
|
||||
+ [Prompt Tools](#prompt-tools)
|
||||
+ [Remote Image](#remote-image)
|
||||
+ [BriaAI Background Remove](#briaai-remove-background)
|
||||
+ [Remove Background](#remove-background)
|
||||
+ [Retroize](#retroize)
|
||||
+ [Size Stepper Nodes](#size-stepper-nodes)
|
||||
@@ -66,6 +69,17 @@ Note: These are inherited from the core nodes so any update to the core nodes sh
|
||||
|
||||
**Node Link:** https://github.com/skunkworxdark/adapters-linked-nodes
|
||||
|
||||
--------------------------------
|
||||
### Autostereogram Nodes
|
||||
|
||||
**Description:** Generate autostereogram images from a depth map. This is not a very practically useful node but more a 90s nostalgic indulgence as I used to love these images as a kid.
|
||||
|
||||
**Node Link:** https://github.com/skunkworxdark/autostereogram_nodes
|
||||
|
||||
**Example Usage:**
|
||||
</br>
|
||||
<img src="https://github.com/skunkworxdark/autostereogram_nodes/blob/main/images/spider.png" width="200" /> -> <img src="https://github.com/skunkworxdark/autostereogram_nodes/blob/main/images/spider-depth.png" width="200" /> -> <img src="https://github.com/skunkworxdark/autostereogram_nodes/raw/main/images/spider-dots.png" width="200" /> <img src="https://github.com/skunkworxdark/autostereogram_nodes/raw/main/images/spider-pattern.png" width="200" />
|
||||
|
||||
--------------------------------
|
||||
### Average Images
|
||||
|
||||
@@ -196,6 +210,18 @@ CMYK Halftone Output:
|
||||
<img src="https://github.com/invoke-ai/InvokeAI/assets/34005131/c59c578f-db8e-4d66-8c66-2851752d75ea" width="300" />
|
||||
|
||||
--------------------------------
|
||||
|
||||
### Hand Refiner with MeshGraphormer
|
||||
|
||||
**Description**: Hand Refiner takes in your image and automatically generates a fixed depth map for the hands along with a mask of the hands region that will conveniently allow you to use them along with ControlNet to fix the wonky hands generated by Stable Diffusion
|
||||
|
||||
**Node Link:** https://github.com/blessedcoolant/invoke_meshgraphormer
|
||||
|
||||
**View**
|
||||
<img src="https://raw.githubusercontent.com/blessedcoolant/invoke_meshgraphormer/main/assets/preview.jpg" />
|
||||
|
||||
--------------------------------
|
||||
|
||||
### Image and Mask Composition Pack
|
||||
|
||||
**Description:** This is a pack of nodes for composing masks and images, including a simple text mask creator and both image and latent offset nodes. The offsets wrap around, so these can be used in conjunction with the Seamless node to progressively generate centered on different parts of the seamless tiling.
|
||||
@@ -409,6 +435,17 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
|
||||
|
||||
**Node Link:** https://github.com/fieldOfView/InvokeAI-remote_image
|
||||
|
||||
--------------------------------
|
||||
|
||||
### BriaAI Remove Background
|
||||
|
||||
**Description**: Implements one click background removal with BriaAI's new version 1.4 model which seems to be be producing better results than any other previous background removal tool.
|
||||
|
||||
**Node Link:** https://github.com/blessedcoolant/invoke_bria_rmbg
|
||||
|
||||
**View**
|
||||
<img src="https://raw.githubusercontent.com/blessedcoolant/invoke_bria_rmbg/main/assets/preview.jpg" />
|
||||
|
||||
--------------------------------
|
||||
### Remove Background
|
||||
|
||||
|
||||
@@ -13,46 +13,69 @@ We thank them for all of their time and hard work.
|
||||
|
||||
- [Lincoln D. Stein](mailto:lincoln.stein@gmail.com)
|
||||
|
||||
## **Current core team**
|
||||
## **Current Core Team**
|
||||
|
||||
* @lstein (Lincoln Stein) - Co-maintainer
|
||||
* @blessedcoolant - Co-maintainer
|
||||
* @hipsterusername (Kent Keirsey) - Co-maintainer, CEO, Positive Vibes
|
||||
* @psychedelicious (Spencer Mabrito) - Web Team Leader
|
||||
* @Kyle0654 (Kyle Schouviller) - Node Architect and General Backend Wizard
|
||||
* @damian0815 - Attention Systems and Compel Maintainer
|
||||
* @ebr (Eugene Brodsky) - Cloud/DevOps/Sofware engineer; your friendly neighbourhood cluster-autoscaler
|
||||
* @genomancer (Gregg Helt) - Controlnet support
|
||||
* @StAlKeR7779 (Sergey Borisov) - Torch stack, ONNX, model management, optimization
|
||||
* @chainchompa (Jennifer Player) - Web Development & Chain-Chomping
|
||||
* @josh is toast (Josh Corbett) - Web Development
|
||||
* @cheerio (Mary Rogers) - Lead Engineer & Web App Development
|
||||
* @ebr (Eugene Brodsky) - Cloud/DevOps/Sofware engineer; your friendly neighbourhood cluster-autoscaler
|
||||
* @sunija - Standalone version
|
||||
* @genomancer (Gregg Helt) - Controlnet support
|
||||
* @brandon (Brandon Rising) - Platform, Infrastructure, Backend Systems
|
||||
* @ryanjdick (Ryan Dick) - Machine Learning & Training
|
||||
* @millu (Millun Atluri) - Community Manager, Documentation, Node-wrangler
|
||||
* @chainchompa (Jennifer Player) - Web Development & Chain-Chomping
|
||||
* @JPPhoto - Core image generation nodes
|
||||
* @dunkeroni - Image generation backend
|
||||
* @SkunkWorxDark - Image generation backend
|
||||
* @keturn (Kevin Turner) - Diffusers
|
||||
* @millu (Millun Atluri) - Community Wizard, Documentation, Node-wrangler,
|
||||
* @glimmerleaf (Devon Hopkins) - Community Wizard
|
||||
* @gogurt enjoyer - Discord moderator and end user support
|
||||
* @whosawhatsis - Discord moderator and end user support
|
||||
* @dwinrger - Discord moderator and end user support
|
||||
* @526christian - Discord moderator and end user support
|
||||
* @harvester62 - Discord moderator and end user support
|
||||
|
||||
|
||||
## **Honored Team Alumni**
|
||||
|
||||
* @StAlKeR7779 (Sergey Borisov) - Torch stack, ONNX, model management, optimization
|
||||
* @damian0815 - Attention Systems and Compel Maintainer
|
||||
* @netsvetaev (Artur) - Localization support
|
||||
* @Kyle0654 (Kyle Schouviller) - Node Architect and General Backend Wizard
|
||||
* @tildebyte - Installation and configuration
|
||||
* @mauwii (Matthias Wilde) - Installation, release, continuous integration
|
||||
|
||||
|
||||
## **Full List of Contributors by Commit Name**
|
||||
|
||||
- 이승석
|
||||
- AbdBarho
|
||||
- ablattmann
|
||||
- AdamOStark
|
||||
- Adam Rice
|
||||
- Airton Silva
|
||||
- Aldo Hoeben
|
||||
- Alexander Eichhorn
|
||||
- Alexandre D. Roberge
|
||||
- Alexandre Macabies
|
||||
- Alfie John
|
||||
- Andreas Rozek
|
||||
- Andre LaBranche
|
||||
- Andy Bearman
|
||||
- Andy Luhrs
|
||||
- Andy Pilate
|
||||
- Anonymous
|
||||
- Anthony Monthe
|
||||
- Any-Winter-4079
|
||||
- apolinario
|
||||
- Ar7ific1al
|
||||
- ArDiouscuros
|
||||
- Armando C. Santisbon
|
||||
- Arnold Cordewiner
|
||||
- Arthur Holstvoogd
|
||||
- artmen1516
|
||||
- Artur
|
||||
@@ -64,13 +87,16 @@ We thank them for all of their time and hard work.
|
||||
- blhook
|
||||
- BlueAmulet
|
||||
- Bouncyknighter
|
||||
- Brandon
|
||||
- Brandon Rising
|
||||
- Brent Ozar
|
||||
- Brian Racer
|
||||
- bsilvereagle
|
||||
- c67e708d
|
||||
- camenduru
|
||||
- CapableWeb
|
||||
- Carson Katri
|
||||
- chainchompa
|
||||
- Chloe
|
||||
- Chris Dawson
|
||||
- Chris Hayes
|
||||
@@ -86,30 +112,45 @@ We thank them for all of their time and hard work.
|
||||
- cpacker
|
||||
- Cragin Godley
|
||||
- creachec
|
||||
- CrypticWit
|
||||
- d8ahazard
|
||||
- damian
|
||||
- damian0815
|
||||
- Damian at mba
|
||||
- Damian Stewart
|
||||
- Daniel Manzke
|
||||
- Danny Beer
|
||||
- Dan Sully
|
||||
- Darren Ringer
|
||||
- David Burnett
|
||||
- David Ford
|
||||
- David Regla
|
||||
- David Sisco
|
||||
- David Wager
|
||||
- Daya Adianto
|
||||
- db3000
|
||||
- DekitaRPG
|
||||
- Denis Olshin
|
||||
- Dennis
|
||||
- dependabot[bot]
|
||||
- Dmitry Parnas
|
||||
- Dobrynia100
|
||||
- Dominic Letz
|
||||
- DrGunnarMallon
|
||||
- Drun555
|
||||
- dunkeroni
|
||||
- Edward Johan
|
||||
- elliotsayes
|
||||
- Elrik
|
||||
- ElrikUnderlake
|
||||
- Eric Khun
|
||||
- Eric Wolf
|
||||
- Eugene
|
||||
- Eugene Brodsky
|
||||
- ExperimentalCyborg
|
||||
- Fabian Bahl
|
||||
- Fabio 'MrWHO' Torchetti
|
||||
- Fattire
|
||||
- fattire
|
||||
- Felipe Nogueira
|
||||
- Félix Sanz
|
||||
@@ -118,8 +159,12 @@ We thank them for all of their time and hard work.
|
||||
- gabrielrotbart
|
||||
- gallegonovato
|
||||
- Gérald LONLAS
|
||||
- Gille
|
||||
- GitHub Actions Bot
|
||||
- glibesyck
|
||||
- gogurtenjoyer
|
||||
- Gohsuke Shimada
|
||||
- greatwolf
|
||||
- greentext2
|
||||
- Gregg Helt
|
||||
- H4rk
|
||||
@@ -131,6 +176,7 @@ We thank them for all of their time and hard work.
|
||||
- Hosted Weblate
|
||||
- Iman Karim
|
||||
- ismail ihsan bülbül
|
||||
- ItzAttila
|
||||
- Ivan Efimov
|
||||
- jakehl
|
||||
- Jakub Kolčář
|
||||
@@ -141,6 +187,7 @@ We thank them for all of their time and hard work.
|
||||
- Jason Toffaletti
|
||||
- Jaulustus
|
||||
- Jeff Mahoney
|
||||
- Jennifer Player
|
||||
- jeremy
|
||||
- Jeremy Clark
|
||||
- JigenD
|
||||
@@ -148,19 +195,26 @@ We thank them for all of their time and hard work.
|
||||
- Johan Roxendal
|
||||
- Johnathon Selstad
|
||||
- Jonathan
|
||||
- Jordan Hewitt
|
||||
- Joseph Dries III
|
||||
- Josh Corbett
|
||||
- JPPhoto
|
||||
- jspraul
|
||||
- junzi
|
||||
- Justin Wong
|
||||
- Juuso V
|
||||
- Kaspar Emanuel
|
||||
- Katsuyuki-Karasawa
|
||||
- Keerigan45
|
||||
- Kent Keirsey
|
||||
- Kevin Brack
|
||||
- Kevin Coakley
|
||||
- Kevin Gibbons
|
||||
- Kevin Schaul
|
||||
- Kevin Turner
|
||||
- Kieran Klaassen
|
||||
- krummrey
|
||||
- Kyle
|
||||
- Kyle Lacy
|
||||
- Kyle Schouviller
|
||||
- Lawrence Norton
|
||||
@@ -171,10 +225,15 @@ We thank them for all of their time and hard work.
|
||||
- Lynne Whitehorn
|
||||
- majick
|
||||
- Marco Labarile
|
||||
- Marta Nahorniuk
|
||||
- Martin Kristiansen
|
||||
- Mary Hipp
|
||||
- maryhipp
|
||||
- Mary Hipp Rogers
|
||||
- mastercaster
|
||||
- mastercaster9000
|
||||
- Matthias Wild
|
||||
- mauwii
|
||||
- michaelk71
|
||||
- mickr777
|
||||
- Mihai
|
||||
@@ -182,11 +241,15 @@ We thank them for all of their time and hard work.
|
||||
- Mikhail Tishin
|
||||
- Millun Atluri
|
||||
- Minjune Song
|
||||
- Mitchell Allain
|
||||
- mitien
|
||||
- mofuzz
|
||||
- Muhammad Usama
|
||||
- Name
|
||||
- _nderscore
|
||||
- Neil Wang
|
||||
- nekowaiz
|
||||
- nemuruibai
|
||||
- Netzer R
|
||||
- Nicholas Koh
|
||||
- Nicholas Körfer
|
||||
@@ -197,9 +260,11 @@ We thank them for all of their time and hard work.
|
||||
- ofirkris
|
||||
- Olivier Louvignes
|
||||
- owenvincent
|
||||
- pand4z31
|
||||
- Patrick Esser
|
||||
- Patrick Tien
|
||||
- Patrick von Platen
|
||||
- Paul Curry
|
||||
- Paul Sajna
|
||||
- pejotr
|
||||
- Peter Baylies
|
||||
@@ -207,6 +272,7 @@ We thank them for all of their time and hard work.
|
||||
- plucked
|
||||
- prixt
|
||||
- psychedelicious
|
||||
- psychedelicious@windows
|
||||
- Rainer Bernhardt
|
||||
- Riccardo Giovanetti
|
||||
- Rich Jones
|
||||
@@ -215,17 +281,22 @@ We thank them for all of their time and hard work.
|
||||
- Robert Bolender
|
||||
- Robin Rombach
|
||||
- Rohan Barar
|
||||
- rohinish404
|
||||
- Rohinish
|
||||
- rpagliuca
|
||||
- rromb
|
||||
- Rupesh Sreeraman
|
||||
- Ryan
|
||||
- Ryan Cao
|
||||
- Ryan Dick
|
||||
- Saifeddine
|
||||
- Saifeddine ALOUI
|
||||
- Sam
|
||||
- SammCheese
|
||||
- Sam McLeod
|
||||
- Sammy
|
||||
- sammyf
|
||||
- Samuel Husso
|
||||
- Saurav Maheshkar
|
||||
- Scott Lahteine
|
||||
- Sean McLellan
|
||||
- Sebastian Aigner
|
||||
@@ -233,16 +304,21 @@ We thank them for all of their time and hard work.
|
||||
- Sergey Krashevich
|
||||
- Shapor Naghibzadeh
|
||||
- Shawn Zhong
|
||||
- Simona Liliac
|
||||
- Simon Vans-Colina
|
||||
- skunkworxdark
|
||||
- slashtechno
|
||||
- SoheilRezaei
|
||||
- Song, Pengcheng
|
||||
- spezialspezial
|
||||
- ssantos
|
||||
- StAlKeR7779
|
||||
- Stefan Tobler
|
||||
- Stephan Koglin-Fischer
|
||||
- SteveCaruso
|
||||
- Steve Martinelli
|
||||
- Steven Frank
|
||||
- Surisen
|
||||
- System X - Files
|
||||
- Taylor Kems
|
||||
- techicode
|
||||
@@ -261,26 +337,34 @@ We thank them for all of their time and hard work.
|
||||
- tyler
|
||||
- unknown
|
||||
- user1
|
||||
- vedant-3010
|
||||
- Vedant Madane
|
||||
- veprogames
|
||||
- wa.code
|
||||
- wfng92
|
||||
- whjms
|
||||
- whosawhatsis
|
||||
- Will
|
||||
- William Becher
|
||||
- William Chong
|
||||
- Wilson E. Alvarez
|
||||
- woweenie
|
||||
- Wubbbi
|
||||
- xra
|
||||
- Yeung Yiu Hung
|
||||
- ymgenesis
|
||||
- Yorzaren
|
||||
- Yosuke Shinya
|
||||
- yun saki
|
||||
- ZachNagengast
|
||||
- Zadagu
|
||||
- zeptofine
|
||||
- Zerdoumi
|
||||
- Васянатор
|
||||
- 冯不游
|
||||
- 唐澤 克幸
|
||||
|
||||
## **Original CompVis Authors**
|
||||
## **Original CompVis (Stable Diffusion) Authors**
|
||||
|
||||
- [Robin Rombach](https://github.com/rromb)
|
||||
- [Patrick von Platen](https://github.com/patrickvonplaten)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,11 +14,19 @@ function is_bin_in_path {
|
||||
}
|
||||
|
||||
function git_show {
|
||||
git show -s --format='%h %s' $1
|
||||
git show -s --format=oneline --abbrev-commit "$1" | cat
|
||||
}
|
||||
|
||||
if [[ -v "VIRTUAL_ENV" ]]; then
|
||||
# we can't just call 'deactivate' because this function is not exported
|
||||
# to the environment of this script from the bash process that runs the script
|
||||
echo -e "${BRED}A virtual environment is activated. Please deactivate it before proceeding.${RESET}"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
cd "$(dirname "$0")"
|
||||
|
||||
echo
|
||||
echo -e "${BYELLOW}This script must be run from the installer directory!${RESET}"
|
||||
echo "The current working directory is $(pwd)"
|
||||
read -p "If that looks right, press any key to proceed, or CTRL-C to exit..."
|
||||
@@ -32,13 +40,6 @@ if ! is_bin_in_path python && is_bin_in_path python3; then
|
||||
}
|
||||
fi
|
||||
|
||||
if [[ -v "VIRTUAL_ENV" ]]; then
|
||||
# we can't just call 'deactivate' because this function is not exported
|
||||
# to the environment of this script from the bash process that runs the script
|
||||
echo -e "${BRED}A virtual environment is activated. Please deactivate it before proceeding.${RESET}"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
VERSION=$(
|
||||
cd ..
|
||||
python -c "from invokeai.version import __version__ as version; print(version)"
|
||||
@@ -47,38 +48,9 @@ PATCH=""
|
||||
VERSION="v${VERSION}${PATCH}"
|
||||
|
||||
echo -e "${BGREEN}HEAD${RESET}:"
|
||||
git_show
|
||||
git_show HEAD
|
||||
echo
|
||||
|
||||
# ---------------------- FRONTEND ----------------------
|
||||
|
||||
pushd ../invokeai/frontend/web >/dev/null
|
||||
echo
|
||||
echo "Installing frontend dependencies..."
|
||||
echo
|
||||
pnpm i --frozen-lockfile
|
||||
echo
|
||||
echo "Building frontend..."
|
||||
echo
|
||||
pnpm build
|
||||
popd
|
||||
|
||||
# ---------------------- BACKEND ----------------------
|
||||
|
||||
echo
|
||||
echo "Building wheel..."
|
||||
echo
|
||||
|
||||
# install the 'build' package in the user site packages, if needed
|
||||
# could be improved by using a temporary venv, but it's tiny and harmless
|
||||
if [[ $(python -c 'from importlib.util import find_spec; print(find_spec("build") is None)') == "True" ]]; then
|
||||
pip install --user build
|
||||
fi
|
||||
|
||||
rm -rf ../build
|
||||
|
||||
python -m build --wheel --outdir dist/ ../.
|
||||
|
||||
# ----------------------
|
||||
|
||||
echo
|
||||
@@ -97,16 +69,13 @@ done
|
||||
mkdir InvokeAI-Installer/lib
|
||||
cp lib/*.py InvokeAI-Installer/lib
|
||||
|
||||
# Move the wheel
|
||||
mv dist/*.whl InvokeAI-Installer/lib/
|
||||
|
||||
# Install scripts
|
||||
# Mac/Linux
|
||||
cp install.sh.in InvokeAI-Installer/install.sh
|
||||
chmod a+x InvokeAI-Installer/install.sh
|
||||
|
||||
# Windows
|
||||
perl -p -e "s/^set INVOKEAI_VERSION=.*/set INVOKEAI_VERSION=$VERSION/" install.bat.in >InvokeAI-Installer/install.bat
|
||||
cp install.bat.in InvokeAI-Installer/install.bat
|
||||
cp WinLongPathsEnabled.reg InvokeAI-Installer/
|
||||
|
||||
# Zip everything up
|
||||
|
||||
@@ -15,7 +15,6 @@ if "%1" == "use-cache" (
|
||||
@rem Config
|
||||
@rem The version in the next line is replaced by an up to date release number
|
||||
@rem when create_installer.sh is run. Change the release number there.
|
||||
set INVOKEAI_VERSION=latest
|
||||
set INSTRUCTIONS=https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/
|
||||
set TROUBLESHOOTING=https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/#troubleshooting
|
||||
set PYTHON_URL=https://www.python.org/downloads/windows/
|
||||
|
||||
@@ -11,7 +11,7 @@ import sys
|
||||
import venv
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Union
|
||||
from typing import Optional, Tuple
|
||||
|
||||
SUPPORTED_PYTHON = ">=3.10.0,<=3.11.100"
|
||||
INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]
|
||||
@@ -21,40 +21,20 @@ OS = platform.uname().system
|
||||
ARCH = platform.uname().machine
|
||||
VERSION = "latest"
|
||||
|
||||
### Feature flags
|
||||
# Install the virtualenv into the runtime dir
|
||||
FF_VENV_IN_RUNTIME = True
|
||||
|
||||
# Install the wheel packaged with the installer
|
||||
FF_USE_LOCAL_WHEEL = True
|
||||
|
||||
|
||||
class Installer:
|
||||
"""
|
||||
Deploys an InvokeAI installation into a given path
|
||||
"""
|
||||
|
||||
reqs: list[str] = INSTALLER_REQS
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.reqs = INSTALLER_REQS
|
||||
self.preflight()
|
||||
if os.getenv("VIRTUAL_ENV") is not None:
|
||||
print("A virtual environment is already activated. Please 'deactivate' before installation.")
|
||||
sys.exit(-1)
|
||||
self.bootstrap()
|
||||
|
||||
def preflight(self) -> None:
|
||||
"""
|
||||
Preflight checks
|
||||
"""
|
||||
|
||||
# TODO
|
||||
# verify python version
|
||||
# on macOS verify XCode tools are present
|
||||
# verify libmesa, libglx on linux
|
||||
# check that the system arch is not i386 (?)
|
||||
# check that the system has a GPU, and the type of GPU
|
||||
|
||||
pass
|
||||
self.available_releases = get_github_releases()
|
||||
|
||||
def mktemp_venv(self) -> TemporaryDirectory:
|
||||
"""
|
||||
@@ -78,12 +58,9 @@ class Installer:
|
||||
|
||||
return venv_dir
|
||||
|
||||
def bootstrap(self, verbose: bool = False) -> TemporaryDirectory:
|
||||
def bootstrap(self, verbose: bool = False) -> TemporaryDirectory | None:
|
||||
"""
|
||||
Bootstrap the installer venv with packages required at install time
|
||||
|
||||
:return: path to the virtual environment directory that was bootstrapped
|
||||
:rtype: TemporaryDirectory
|
||||
"""
|
||||
|
||||
print("Initializing the installer. This may take a minute - please wait...")
|
||||
@@ -95,39 +72,27 @@ class Installer:
|
||||
cmd.extend(self.reqs)
|
||||
|
||||
try:
|
||||
res = subprocess.check_output(cmd).decode()
|
||||
# upgrade pip to the latest version to avoid a confusing message
|
||||
res = upgrade_pip(Path(venv_dir.name))
|
||||
if verbose:
|
||||
print(res)
|
||||
|
||||
# run the install prerequisites installation
|
||||
res = subprocess.check_output(cmd).decode()
|
||||
|
||||
if verbose:
|
||||
print(res)
|
||||
|
||||
return venv_dir
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e)
|
||||
|
||||
def app_venv(self, path: str = None):
|
||||
def app_venv(self, venv_parent) -> Path:
|
||||
"""
|
||||
Create a virtualenv for the InvokeAI installation
|
||||
"""
|
||||
|
||||
# explicit venv location
|
||||
# currently unused in normal operation
|
||||
# useful for testing or special cases
|
||||
if path is not None:
|
||||
venv_dir = Path(path)
|
||||
|
||||
# experimental / testing
|
||||
elif not FF_VENV_IN_RUNTIME:
|
||||
if OS == "Windows":
|
||||
venv_dir_parent = os.getenv("APPDATA", "~/AppData/Roaming")
|
||||
elif OS == "Darwin":
|
||||
# there is no environment variable on macOS to find this
|
||||
# TODO: confirm this is working as expected
|
||||
venv_dir_parent = "~/Library/Application Support"
|
||||
elif OS == "Linux":
|
||||
venv_dir_parent = os.getenv("XDG_DATA_DIR", "~/.local/share")
|
||||
venv_dir = Path(venv_dir_parent).expanduser().resolve() / f"InvokeAI/{VERSION}/venv"
|
||||
|
||||
# stable / current
|
||||
else:
|
||||
venv_dir = self.dest / ".venv"
|
||||
venv_dir = venv_parent / ".venv"
|
||||
|
||||
# Prefer to copy python executables
|
||||
# so that updates to system python don't break InvokeAI
|
||||
@@ -141,7 +106,7 @@ class Installer:
|
||||
return venv_dir
|
||||
|
||||
def install(
|
||||
self, root: str = "~/invokeai", version: str = "latest", yes_to_all=False, find_links: Path = None
|
||||
self, version=None, root: str = "~/invokeai", yes_to_all=False, find_links: Optional[Path] = None
|
||||
) -> None:
|
||||
"""
|
||||
Install the InvokeAI application into the given runtime path
|
||||
@@ -158,15 +123,20 @@ class Installer:
|
||||
|
||||
import messages
|
||||
|
||||
messages.welcome()
|
||||
messages.welcome(self.available_releases)
|
||||
|
||||
default_path = os.environ.get("INVOKEAI_ROOT") or Path(root).expanduser().resolve()
|
||||
self.dest = default_path if yes_to_all else messages.dest_path(root)
|
||||
version = messages.choose_version(self.available_releases)
|
||||
|
||||
auto_dest = Path(os.environ.get("INVOKEAI_ROOT", root)).expanduser().resolve()
|
||||
destination = auto_dest if yes_to_all else messages.dest_path(root)
|
||||
if destination is None:
|
||||
print("Could not find or create the destination directory. Installation cancelled.")
|
||||
sys.exit(0)
|
||||
|
||||
# create the venv for the app
|
||||
self.venv = self.app_venv()
|
||||
self.venv = self.app_venv(venv_parent=destination)
|
||||
|
||||
self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv, version=version)
|
||||
self.instance = InvokeAiInstance(runtime=destination, venv=self.venv, version=version)
|
||||
|
||||
# install dependencies and the InvokeAI application
|
||||
(extra_index_url, optional_modules) = get_torch_source() if not yes_to_all else (None, None)
|
||||
@@ -190,7 +160,7 @@ class InvokeAiInstance:
|
||||
A single runtime directory *may* be shared by multiple virtual environments, though this isn't currently tested or supported.
|
||||
"""
|
||||
|
||||
def __init__(self, runtime: Path, venv: Path, version: str) -> None:
|
||||
def __init__(self, runtime: Path, venv: Path, version: str = "stable") -> None:
|
||||
self.runtime = runtime
|
||||
self.venv = venv
|
||||
self.pip = get_pip_from_venv(venv)
|
||||
@@ -199,6 +169,7 @@ class InvokeAiInstance:
|
||||
set_sys_path(venv)
|
||||
os.environ["INVOKEAI_ROOT"] = str(self.runtime.expanduser().resolve())
|
||||
os.environ["VIRTUAL_ENV"] = str(self.venv.expanduser().resolve())
|
||||
upgrade_pip(venv)
|
||||
|
||||
def get(self) -> tuple[Path, Path]:
|
||||
"""
|
||||
@@ -212,54 +183,7 @@ class InvokeAiInstance:
|
||||
|
||||
def install(self, extra_index_url=None, optional_modules=None, find_links=None):
|
||||
"""
|
||||
Install this instance, including dependencies and the app itself
|
||||
|
||||
:param extra_index_url: the "--extra-index-url ..." line for pip to look in extra indexes.
|
||||
:type extra_index_url: str
|
||||
"""
|
||||
|
||||
import messages
|
||||
|
||||
# install torch first to ensure the correct version gets installed.
|
||||
# works with either source or wheel install with negligible impact on installation times.
|
||||
messages.simple_banner("Installing PyTorch :fire:")
|
||||
self.install_torch(extra_index_url, find_links)
|
||||
|
||||
messages.simple_banner("Installing the InvokeAI Application :art:")
|
||||
self.install_app(extra_index_url, optional_modules, find_links)
|
||||
|
||||
def install_torch(self, extra_index_url=None, find_links=None):
|
||||
"""
|
||||
Install PyTorch
|
||||
"""
|
||||
|
||||
from plumbum import FG, local
|
||||
|
||||
pip = local[self.pip]
|
||||
|
||||
(
|
||||
pip[
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"numpy==1.26.3", # choose versions that won't be uninstalled during phase 2
|
||||
"urllib3~=1.26.0",
|
||||
"requests~=2.28.0",
|
||||
"torch==2.1.2",
|
||||
"torchmetrics==0.11.4",
|
||||
"torchvision==0.16.2",
|
||||
"--force-reinstall",
|
||||
"--find-links" if find_links is not None else None,
|
||||
find_links,
|
||||
"--extra-index-url" if extra_index_url is not None else None,
|
||||
extra_index_url,
|
||||
]
|
||||
& FG
|
||||
)
|
||||
|
||||
def install_app(self, extra_index_url=None, optional_modules=None, find_links=None):
|
||||
"""
|
||||
Install the application with pip.
|
||||
Supports installation from PyPi or from a local source directory.
|
||||
Install the package from PyPi.
|
||||
|
||||
:param extra_index_url: the "--extra-index-url ..." line for pip to look in extra indexes.
|
||||
:type extra_index_url: str
|
||||
@@ -271,53 +195,52 @@ class InvokeAiInstance:
|
||||
:type find_links: Path
|
||||
"""
|
||||
|
||||
## this only applies to pypi installs; TODO actually use this
|
||||
if self.version == "pre":
|
||||
import messages
|
||||
|
||||
# not currently used, but may be useful for "install most recent version" option
|
||||
if self.version == "prerelease":
|
||||
version = None
|
||||
pre = "--pre"
|
||||
pre_flag = "--pre"
|
||||
elif self.version == "stable":
|
||||
version = None
|
||||
pre_flag = None
|
||||
else:
|
||||
version = self.version
|
||||
pre = None
|
||||
pre_flag = None
|
||||
|
||||
## TODO: only local wheel will be installed as of now; support for --version arg is TODO
|
||||
if FF_USE_LOCAL_WHEEL:
|
||||
# if no wheel, try to do a source install before giving up
|
||||
try:
|
||||
src = str(next(Path(__file__).parent.glob("InvokeAI-*.whl")))
|
||||
except StopIteration:
|
||||
try:
|
||||
src = Path(__file__).parents[1].expanduser().resolve()
|
||||
# if the above directory contains one of these files, we'll do a source install
|
||||
next(src.glob("pyproject.toml"))
|
||||
next(src.glob("invokeai"))
|
||||
except StopIteration:
|
||||
print("Unable to find a wheel or perform a source install. Giving up.")
|
||||
src = "invokeai"
|
||||
if optional_modules:
|
||||
src += optional_modules
|
||||
if version:
|
||||
src += f"=={version}"
|
||||
|
||||
elif version == "source":
|
||||
# this makes an assumption about the location of the installer package in the source tree
|
||||
src = Path(__file__).parents[1].expanduser().resolve()
|
||||
else:
|
||||
# will install from PyPi
|
||||
src = f"invokeai=={version}" if version is not None else "invokeai"
|
||||
messages.simple_banner("Installing the InvokeAI Application :art:")
|
||||
|
||||
from plumbum import FG, local
|
||||
from plumbum import FG, ProcessExecutionError, local # type: ignore
|
||||
|
||||
pip = local[self.pip]
|
||||
|
||||
(
|
||||
pip[
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"--use-pep517",
|
||||
str(src) + (optional_modules if optional_modules else ""),
|
||||
"--find-links" if find_links is not None else None,
|
||||
find_links,
|
||||
"--extra-index-url" if extra_index_url is not None else None,
|
||||
extra_index_url,
|
||||
pre,
|
||||
]
|
||||
& FG
|
||||
)
|
||||
pipeline = pip[
|
||||
"install",
|
||||
"--require-virtualenv",
|
||||
"--force-reinstall",
|
||||
"--use-pep517",
|
||||
str(src),
|
||||
"--find-links" if find_links is not None else None,
|
||||
find_links,
|
||||
"--extra-index-url" if extra_index_url is not None else None,
|
||||
extra_index_url,
|
||||
pre_flag,
|
||||
]
|
||||
|
||||
try:
|
||||
_ = pipeline & FG
|
||||
except ProcessExecutionError as e:
|
||||
print(f"Error: {e}")
|
||||
print(
|
||||
"Could not install InvokeAI. Please try downloading the latest version of the installer and install again."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
def configure(self):
|
||||
"""
|
||||
@@ -373,7 +296,6 @@ class InvokeAiInstance:
|
||||
|
||||
ext = "bat" if OS == "Windows" else "sh"
|
||||
|
||||
# scripts = ['invoke', 'update']
|
||||
scripts = ["invoke"]
|
||||
|
||||
for script in scripts:
|
||||
@@ -408,6 +330,23 @@ def get_pip_from_venv(venv_path: Path) -> str:
|
||||
return str(venv_path.expanduser().resolve() / pip)
|
||||
|
||||
|
||||
def upgrade_pip(venv_path: Path) -> str | None:
|
||||
"""
|
||||
Upgrade the pip executable in the given virtual environment
|
||||
"""
|
||||
|
||||
python = "Scripts\\python.exe" if OS == "Windows" else "bin/python"
|
||||
python = str(venv_path.expanduser().resolve() / python)
|
||||
|
||||
try:
|
||||
result = subprocess.check_output([python, "-m", "pip", "install", "--upgrade", "pip"]).decode()
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e)
|
||||
result = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def set_sys_path(venv_path: Path) -> None:
|
||||
"""
|
||||
Given a path to a virtual environment, set the sys.path, in a cross-platform fashion,
|
||||
@@ -431,7 +370,43 @@ def set_sys_path(venv_path: Path) -> None:
|
||||
sys.path.append(str(Path(venv_path, lib, "site-packages").expanduser().resolve()))
|
||||
|
||||
|
||||
def get_torch_source() -> (Union[str, None], str):
|
||||
def get_github_releases() -> tuple[list, list] | None:
|
||||
"""
|
||||
Query Github for published (pre-)release versions.
|
||||
Return a tuple where the first element is a list of stable releases and the second element is a list of pre-releases.
|
||||
Return None if the query fails for any reason.
|
||||
"""
|
||||
|
||||
import requests
|
||||
|
||||
## get latest releases using github api
|
||||
url = "https://api.github.com/repos/invoke-ai/InvokeAI/releases"
|
||||
releases, pre_releases = [], []
|
||||
try:
|
||||
res = requests.get(url)
|
||||
res.raise_for_status()
|
||||
tag_info = res.json()
|
||||
for tag in tag_info:
|
||||
if not tag["prerelease"]:
|
||||
releases.append(tag["tag_name"].lstrip("v"))
|
||||
else:
|
||||
pre_releases.append(tag["tag_name"].lstrip("v"))
|
||||
except requests.HTTPError as e:
|
||||
print(f"Error: {e}")
|
||||
print("Could not fetch version information from GitHub. Please check your network connection and try again.")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print("An unexpected error occurred while trying to fetch version information from GitHub. Please try again.")
|
||||
return
|
||||
|
||||
releases.sort(reverse=True)
|
||||
pre_releases.sort(reverse=True)
|
||||
|
||||
return releases, pre_releases
|
||||
|
||||
|
||||
def get_torch_source() -> Tuple[str | None, str | None]:
|
||||
"""
|
||||
Determine the extra index URL for pip to use for torch installation.
|
||||
This depends on the OS and the graphics accelerator in use.
|
||||
@@ -446,25 +421,26 @@ def get_torch_source() -> (Union[str, None], str):
|
||||
:rtype: list
|
||||
"""
|
||||
|
||||
from messages import graphical_accelerator
|
||||
from messages import select_gpu
|
||||
|
||||
# device can be one of: "cuda", "rocm", "cpu", "idk"
|
||||
device = graphical_accelerator()
|
||||
# device can be one of: "cuda", "rocm", "cpu", "cuda_and_dml, autodetect"
|
||||
device = select_gpu()
|
||||
|
||||
url = None
|
||||
optional_modules = "[onnx]"
|
||||
if OS == "Linux":
|
||||
if device == "rocm":
|
||||
url = "https://download.pytorch.org/whl/rocm5.4.2"
|
||||
elif device == "cpu":
|
||||
if device.value == "rocm":
|
||||
url = "https://download.pytorch.org/whl/rocm5.6"
|
||||
elif device.value == "cpu":
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
|
||||
if device == "cuda":
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
optional_modules = "[xformers,onnx-cuda]"
|
||||
if device == "cuda_and_dml":
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
optional_modules = "[xformers,onnx-directml]"
|
||||
elif OS == "Windows":
|
||||
if device.value == "cuda":
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
optional_modules = "[xformers,onnx-cuda]"
|
||||
if device.value == "cuda_and_dml":
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
optional_modules = "[xformers,onnx-directml]"
|
||||
|
||||
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
|
||||
|
||||
|
||||
@@ -5,10 +5,11 @@ Installer user interaction
|
||||
|
||||
import os
|
||||
import platform
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
from prompt_toolkit import HTML, prompt
|
||||
from prompt_toolkit.completion import PathCompleter
|
||||
from prompt_toolkit.completion import FuzzyWordCompleter, PathCompleter
|
||||
from prompt_toolkit.validation import Validator
|
||||
from rich import box, print
|
||||
from rich.console import Console, Group, group
|
||||
@@ -35,16 +36,26 @@ else:
|
||||
console = Console(style=Style(color="grey74", bgcolor="grey19"))
|
||||
|
||||
|
||||
def welcome():
|
||||
def welcome(available_releases: tuple | None = None) -> None:
|
||||
@group()
|
||||
def text():
|
||||
if (platform_specific := _platform_specific_help()) != "":
|
||||
if (platform_specific := _platform_specific_help()) is not None:
|
||||
yield platform_specific
|
||||
yield ""
|
||||
yield Text.from_markup(
|
||||
"Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.",
|
||||
justify="center",
|
||||
)
|
||||
if available_releases is not None:
|
||||
latest_stable = available_releases[0][0]
|
||||
last_pre = available_releases[1][0]
|
||||
yield ""
|
||||
yield Text.from_markup(
|
||||
f"[red3]🠶[/] Latest stable release (recommended): [b bright_white]{latest_stable}", justify="center"
|
||||
)
|
||||
yield Text.from_markup(
|
||||
f"[red3]🠶[/] Last published pre-release version: [b bright_white]{last_pre}", justify="center"
|
||||
)
|
||||
|
||||
console.rule()
|
||||
print(
|
||||
@@ -61,19 +72,31 @@ def welcome():
|
||||
console.line()
|
||||
|
||||
|
||||
def confirm_install(dest: Path) -> bool:
|
||||
if dest.exists():
|
||||
print(f":exclamation: Directory {dest} already exists :exclamation:")
|
||||
dest_confirmed = Confirm.ask(
|
||||
":stop_sign: (re)install in this location?",
|
||||
default=False,
|
||||
)
|
||||
else:
|
||||
print(f"InvokeAI will be installed in {dest}")
|
||||
dest_confirmed = Confirm.ask("Use this location?", default=True)
|
||||
def choose_version(available_releases: tuple | None = None) -> str:
|
||||
"""
|
||||
Prompt the user to choose an Invoke version to install
|
||||
"""
|
||||
|
||||
# short circuit if we couldn't get a version list
|
||||
# still try to install the latest stable version
|
||||
if available_releases is None:
|
||||
return "stable"
|
||||
|
||||
console.print(":grey_question: [orange3]Please choose an Invoke version to install.")
|
||||
|
||||
choices = available_releases[0] + available_releases[1]
|
||||
|
||||
response = prompt(
|
||||
message=f" <Enter> to install the recommended release ({choices[0]}). <Tab> or type to pick a version: ",
|
||||
complete_while_typing=True,
|
||||
completer=FuzzyWordCompleter(choices),
|
||||
)
|
||||
|
||||
console.print(f" Version {choices[0] if response == "" else response} will be installed.")
|
||||
|
||||
console.line()
|
||||
|
||||
return dest_confirmed
|
||||
return "stable" if response == "" else response
|
||||
|
||||
|
||||
def user_wants_auto_configuration() -> bool:
|
||||
@@ -109,7 +132,23 @@ def user_wants_auto_configuration() -> bool:
|
||||
return choice.lower().startswith("a")
|
||||
|
||||
|
||||
def dest_path(dest=None) -> Path:
|
||||
def confirm_install(dest: Path) -> bool:
|
||||
if dest.exists():
|
||||
print(f":stop_sign: Directory {dest} already exists!")
|
||||
print(" Is this location correct?")
|
||||
default = False
|
||||
else:
|
||||
print(f":file_folder: InvokeAI will be installed in {dest}")
|
||||
default = True
|
||||
|
||||
dest_confirmed = Confirm.ask(" Please confirm:", default=default)
|
||||
|
||||
console.line()
|
||||
|
||||
return dest_confirmed
|
||||
|
||||
|
||||
def dest_path(dest=None) -> Path | None:
|
||||
"""
|
||||
Prompt the user for the destination path and create the path
|
||||
|
||||
@@ -124,25 +163,21 @@ def dest_path(dest=None) -> Path:
|
||||
else:
|
||||
dest = Path.cwd().expanduser().resolve()
|
||||
prev_dest = init_path = dest
|
||||
|
||||
dest_confirmed = confirm_install(dest)
|
||||
dest_confirmed = False
|
||||
|
||||
while not dest_confirmed:
|
||||
# if the given destination already exists, the starting point for browsing is its parent directory.
|
||||
# the user may have made a typo, or otherwise wants to place the root dir next to an existing one.
|
||||
# if the destination dir does NOT exist, then the user must have changed their mind about the selection.
|
||||
# since we can't read their mind, start browsing at Path.cwd().
|
||||
browse_start = (prev_dest.parent if prev_dest.exists() else Path.cwd()).expanduser().resolve()
|
||||
browse_start = (dest or Path.cwd()).expanduser().resolve()
|
||||
|
||||
path_completer = PathCompleter(
|
||||
only_directories=True,
|
||||
expanduser=True,
|
||||
get_paths=lambda: [browse_start], # noqa: B023
|
||||
get_paths=lambda: [str(browse_start)], # noqa: B023
|
||||
# get_paths=lambda: [".."].extend(list(browse_start.iterdir()))
|
||||
)
|
||||
|
||||
console.line()
|
||||
console.print(f"[orange3]Please select the destination directory for the installation:[/] \\[{browse_start}]: ")
|
||||
|
||||
console.print(f":grey_question: [orange3]Please select the install destination:[/] \\[{browse_start}]: ")
|
||||
selected = prompt(
|
||||
">>> ",
|
||||
complete_in_thread=True,
|
||||
@@ -155,6 +190,7 @@ def dest_path(dest=None) -> Path:
|
||||
)
|
||||
prev_dest = dest
|
||||
dest = Path(selected)
|
||||
|
||||
console.line()
|
||||
|
||||
dest_confirmed = confirm_install(dest.expanduser().resolve())
|
||||
@@ -182,41 +218,45 @@ def dest_path(dest=None) -> Path:
|
||||
console.rule("Goodbye!")
|
||||
|
||||
|
||||
def graphical_accelerator():
|
||||
class GpuType(Enum):
|
||||
CUDA = "cuda"
|
||||
CUDA_AND_DML = "cuda_and_dml"
|
||||
ROCM = "rocm"
|
||||
CPU = "cpu"
|
||||
AUTODETECT = "autodetect"
|
||||
|
||||
|
||||
def select_gpu() -> GpuType:
|
||||
"""
|
||||
Prompt the user to select the graphical accelerator in their system
|
||||
This does not validate user's choices (yet), but only offers choices
|
||||
valid for the platform.
|
||||
CUDA is the fallback.
|
||||
We may be able to detect the GPU driver by shelling out to `modprobe` or `lspci`,
|
||||
but this is not yet supported or reliable. Also, some users may have exotic preferences.
|
||||
Prompt the user to select the GPU driver
|
||||
"""
|
||||
|
||||
if ARCH == "arm64" and OS != "Darwin":
|
||||
print(f"Only CPU acceleration is available on {ARCH} architecture. Proceeding with that.")
|
||||
return "cpu"
|
||||
return GpuType.CPU
|
||||
|
||||
nvidia = (
|
||||
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
|
||||
"cuda",
|
||||
GpuType.CUDA,
|
||||
)
|
||||
nvidia_with_dml = (
|
||||
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
|
||||
"cuda_and_dml",
|
||||
GpuType.CUDA_AND_DML,
|
||||
)
|
||||
amd = (
|
||||
"an [gold1 b]AMD[/] GPU (using ROCm™)",
|
||||
"rocm",
|
||||
GpuType.ROCM,
|
||||
)
|
||||
cpu = (
|
||||
"no compatible GPU, or specifically prefer to use the CPU",
|
||||
"cpu",
|
||||
"Do not install any GPU support, use CPU for generation (slow)",
|
||||
GpuType.CPU,
|
||||
)
|
||||
idk = (
|
||||
autodetect = (
|
||||
"I'm not sure what to choose",
|
||||
"idk",
|
||||
GpuType.AUTODETECT,
|
||||
)
|
||||
|
||||
options = []
|
||||
if OS == "Windows":
|
||||
options = [nvidia, nvidia_with_dml, cpu]
|
||||
if OS == "Linux":
|
||||
@@ -230,7 +270,7 @@ def graphical_accelerator():
|
||||
return options[0][1]
|
||||
|
||||
# "I don't know" is always added the last option
|
||||
options.append(idk)
|
||||
options.append(autodetect) # type: ignore
|
||||
|
||||
options = {str(i): opt for i, opt in enumerate(options, 1)}
|
||||
|
||||
@@ -265,9 +305,9 @@ def graphical_accelerator():
|
||||
),
|
||||
)
|
||||
|
||||
if options[choice][1] == "idk":
|
||||
if options[choice][1] is GpuType.AUTODETECT:
|
||||
console.print(
|
||||
"No problem. We will try to install a version that [i]should[/i] be compatible. :crossed_fingers:"
|
||||
"No problem. We will install CUDA support first :crossed_fingers: If Invoke does not detect a GPU, please re-run the installer and select one of the other GPU types."
|
||||
)
|
||||
|
||||
return options[choice][1]
|
||||
@@ -291,7 +331,7 @@ def windows_long_paths_registry() -> None:
|
||||
"""
|
||||
|
||||
with open(str(Path(__file__).parent / "WinLongPathsEnabled.reg"), "r", encoding="utf-16le") as code:
|
||||
syntax = Syntax(code.read(), line_numbers=True)
|
||||
syntax = Syntax(code.read(), line_numbers=True, lexer="regedit")
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
@@ -301,7 +341,7 @@ def windows_long_paths_registry() -> None:
|
||||
"We will now apply a registry fix to enable long paths on Windows. InvokeAI needs this to function correctly. We are asking your permission to modify the Windows Registry on your behalf.",
|
||||
"",
|
||||
"This is the change that will be applied:",
|
||||
syntax,
|
||||
str(syntax),
|
||||
]
|
||||
)
|
||||
),
|
||||
@@ -340,7 +380,7 @@ def introduction() -> None:
|
||||
console.line(2)
|
||||
|
||||
|
||||
def _platform_specific_help() -> str:
|
||||
def _platform_specific_help() -> Text | None:
|
||||
if OS == "Darwin":
|
||||
text = Text.from_markup(
|
||||
"""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/]."""
|
||||
@@ -354,5 +394,5 @@ def _platform_specific_help() -> str:
|
||||
[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]"""
|
||||
)
|
||||
else:
|
||||
text = ""
|
||||
return
|
||||
return text
|
||||
|
||||
@@ -15,7 +15,7 @@ echo 4. Download and install models
|
||||
echo 5. Change InvokeAI startup options
|
||||
echo 6. Re-run the configure script to fix a broken install or to complete a major upgrade
|
||||
echo 7. Open the developer console
|
||||
echo 8. Update InvokeAI
|
||||
echo 8. Update InvokeAI (DEPRECATED - please use the installer)
|
||||
echo 9. Run the InvokeAI image database maintenance script
|
||||
echo 10. Command-line help
|
||||
echo Q - Quit
|
||||
@@ -52,8 +52,10 @@ IF /I "%choice%" == "1" (
|
||||
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
||||
call cmd /k
|
||||
) ELSE IF /I "%choice%" == "8" (
|
||||
echo Running invokeai-update...
|
||||
python -m invokeai.frontend.install.invokeai_update
|
||||
echo UPDATING FROM WITHIN THE APP IS BEING DEPRECATED.
|
||||
echo Please download the installer from https://github.com/invoke-ai/InvokeAI/releases/latest and run it to update your installation.
|
||||
timeout 4
|
||||
python -m invokeai.frontend.install.invokeai_update
|
||||
) ELSE IF /I "%choice%" == "9" (
|
||||
echo Running the db maintenance script...
|
||||
python .venv\Scripts\invokeai-db-maintenance.exe
|
||||
@@ -77,4 +79,3 @@ pause
|
||||
|
||||
:ending
|
||||
exit /b
|
||||
|
||||
|
||||
@@ -90,7 +90,9 @@ do_choice() {
|
||||
;;
|
||||
8)
|
||||
clear
|
||||
printf "Update InvokeAI\n"
|
||||
printf "UPDATING FROM WITHIN THE APP IS BEING DEPRECATED\n"
|
||||
printf "Please download the installer from https://github.com/invoke-ai/InvokeAI/releases/latest and run it to update your installation.\n"
|
||||
sleep 4
|
||||
python -m invokeai.frontend.install.invokeai_update
|
||||
;;
|
||||
9)
|
||||
@@ -122,7 +124,7 @@ do_dialog() {
|
||||
5 "Change InvokeAI startup options"
|
||||
6 "Re-run the configure script to fix a broken install or to complete a major upgrade"
|
||||
7 "Open the developer console"
|
||||
8 "Update InvokeAI"
|
||||
8 "Update InvokeAI (DEPRECATED - please use the installer)"
|
||||
9 "Run the InvokeAI image database maintenance script"
|
||||
10 "Command-line help"
|
||||
)
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
@echo off
|
||||
setlocal EnableExtensions EnableDelayedExpansion
|
||||
|
||||
PUSHD "%~dp0"
|
||||
|
||||
set INVOKE_AI_VERSION=latest
|
||||
set arg=%1
|
||||
if "%arg%" neq "" (
|
||||
if "%arg:~0,2%" equ "/?" (
|
||||
echo Usage: update.bat ^<release name or branch^>
|
||||
echo Updates InvokeAI to use the indicated version of the code base.
|
||||
echo Find the version or branch for the release you want, and pass it as the argument.
|
||||
echo For example '.\update.bat v2.2.5' for release 2.2.5.
|
||||
echo '.\update.bat main' for the latest development version
|
||||
echo.
|
||||
echo If no argument provided then will install the most recent release, equivalent to
|
||||
echo '.\update.bat latest'
|
||||
exit /b
|
||||
) else (
|
||||
set INVOKE_AI_VERSION=%arg%
|
||||
)
|
||||
)
|
||||
|
||||
set INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive/!INVOKE_AI_VERSION!.zip"
|
||||
set INVOKE_AI_DEP=https://raw.githubusercontent.com/invoke-ai/InvokeAI/!INVOKE_AI_VERSION!/environments-and-requirements/requirements-base.txt
|
||||
set INVOKE_AI_MODELS=https://raw.githubusercontent.com/invoke-ai/InvokeAI/$INVOKE_AI_VERSION/configs/INITIAL_MODELS.yaml
|
||||
|
||||
call curl -I "%INVOKE_AI_DEP%" -fs >.tmp.out
|
||||
if %errorlevel% neq 0 (
|
||||
echo '!INVOKE_AI_VERSION!' is not a known branch name or tag. Please check the version and try again.
|
||||
echo "Press any key to continue"
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
del .tmp.out
|
||||
|
||||
echo This script will update InvokeAI and all its dependencies to !INVOKE_AI_SRC!.
|
||||
echo If you do not want to do this, press control-C now!
|
||||
pause
|
||||
|
||||
call curl -L "%INVOKE_AI_DEP%" > environments-and-requirements/requirements-base.txt
|
||||
call curl -L "%INVOKE_AI_MODELS%" > configs/INITIAL_MODELS.yaml
|
||||
|
||||
|
||||
call .venv\Scripts\activate.bat
|
||||
call .venv\Scripts\python -mpip install -r requirements.txt
|
||||
if %errorlevel% neq 0 (
|
||||
echo Installation of requirements failed. See https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/#troubleshooting for suggestions.
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
|
||||
call .venv\Scripts\python -mpip install !INVOKE_AI_SRC!
|
||||
if %errorlevel% neq 0 (
|
||||
echo Installation of InvokeAI failed. See https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/#troubleshooting for suggestions.
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
|
||||
@rem call .venv\Scripts\invokeai-configure --root=.
|
||||
|
||||
@rem if %errorlevel% neq 0 (
|
||||
@rem echo Configuration InvokeAI failed. See https://invoke-ai.github.io/InvokeAI/installation/INSTALL_AUTOMATED/#troubleshooting for suggestions.
|
||||
@rem pause
|
||||
@rem exit /b
|
||||
@rem )
|
||||
|
||||
echo InvokeAI has been updated to '%INVOKE_AI_VERSION%'
|
||||
|
||||
echo "Press any key to continue"
|
||||
pause
|
||||
endlocal
|
||||
@@ -1,58 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -eu
|
||||
|
||||
if [ $# -ge 1 ] && [ "${1:0:2}" == "-h" ]; then
|
||||
echo "Usage: update.sh <release>"
|
||||
echo "Updates InvokeAI to use the indicated version of the code base."
|
||||
echo "Find the version or branch for the release you want, and pass it as the argument."
|
||||
echo "For example: update.sh v2.2.5 for release 2.2.5."
|
||||
echo " update.sh main for the current development version."
|
||||
echo ""
|
||||
echo "If no argument provided then will install the version tagged with 'latest', equivalent to"
|
||||
echo "update.sh latest"
|
||||
exit -1
|
||||
fi
|
||||
|
||||
INVOKE_AI_VERSION=${1:-latest}
|
||||
|
||||
INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive/$INVOKE_AI_VERSION.zip"
|
||||
INVOKE_AI_DEP=https://raw.githubusercontent.com/invoke-ai/InvokeAI/$INVOKE_AI_VERSION/environments-and-requirements/requirements-base.txt
|
||||
INVOKE_AI_MODELS=https://raw.githubusercontent.com/invoke-ai/InvokeAI/$INVOKE_AI_VERSION/configs/INITIAL_MODELS.yaml
|
||||
|
||||
# ensure we're in the correct folder in case user's CWD is somewhere else
|
||||
scriptdir=$(dirname "$0")
|
||||
cd "$scriptdir"
|
||||
|
||||
function _err_exit {
|
||||
if test "$1" -ne 0
|
||||
then
|
||||
echo "Something went wrong while installing InvokeAI and/or its requirements."
|
||||
echo "Update cannot continue. Please report this error to https://github.com/invoke-ai/InvokeAI/issues"
|
||||
echo -e "Error code $1; Error caught was '$2'"
|
||||
read -p "Press any key to exit..."
|
||||
exit
|
||||
fi
|
||||
}
|
||||
|
||||
if ! curl -I "$INVOKE_AI_DEP" -fs >/dev/null; then
|
||||
echo \'$INVOKE_AI_VERSION\' is not a known branch name or tag. Please check the version and try again.
|
||||
exit
|
||||
fi
|
||||
|
||||
echo This script will update InvokeAI and all its dependencies to version \'$INVOKE_AI_VERSION\'.
|
||||
echo If you do not want to do this, press control-C now!
|
||||
read -p "Press any key to continue, or CTRL-C to exit..."
|
||||
|
||||
curl -L "$INVOKE_AI_DEP" > environments-and-requirements/requirements-base.txt
|
||||
curl -L "$INVOKE_AI_MODELS" > configs/INITIAL_MODELS.yaml
|
||||
|
||||
. .venv/bin/activate
|
||||
|
||||
./.venv/bin/python -mpip install -r requirements.txt
|
||||
_err_exit $? "The pip program failed to install InvokeAI's requirements."
|
||||
|
||||
./.venv/bin/python -mpip install $INVOKE_AI_SRC
|
||||
_err_exit $? "The pip program failed to install InvokeAI."
|
||||
|
||||
echo InvokeAI updated to \'$INVOKE_AI_VERSION\'
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
@@ -22,11 +22,10 @@ from ..services.invocation_queue.invocation_queue_memory import MemoryInvocation
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
|
||||
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
|
||||
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
|
||||
from ..services.model_install import ModelInstallService
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_metadata import ModelMetadataStoreSQL
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
@@ -80,21 +79,18 @@ class ApiDependencies:
|
||||
board_records = SqliteBoardRecordStorage(db=db)
|
||||
boards = BoardService()
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
||||
graph_execution_manager = ItemStorageMemory[GraphExecutionState]()
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
model_manager = ModelManagerService(config, logger)
|
||||
model_record_service = ModelRecordServiceSQL(db=db)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
metadata_store = ModelMetadataStore(db=db)
|
||||
model_install_service = ModelInstallService(
|
||||
app_config=config,
|
||||
record_store=model_record_service,
|
||||
model_metadata_service = ModelMetadataStoreSQL(db=db)
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration,
|
||||
model_record_service=ModelRecordServiceSQL(db=db, metadata_store=model_metadata_service),
|
||||
download_queue=download_queue_service,
|
||||
metadata_store=metadata_store,
|
||||
event_bus=events,
|
||||
events=events,
|
||||
)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
@@ -120,9 +116,7 @@ class ApiDependencies:
|
||||
latents=latents,
|
||||
logger=logger,
|
||||
model_manager=model_manager,
|
||||
model_records=model_record_service,
|
||||
download_queue=download_queue_service,
|
||||
model_install=model_install_service,
|
||||
names=names,
|
||||
performance_statistics=performance_statistics,
|
||||
processor=processor,
|
||||
|
||||
@@ -36,7 +36,7 @@ async def list_downloads() -> List[DownloadJob]:
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def prune_downloads():
|
||||
async def prune_downloads() -> Response:
|
||||
"""Prune completed and errored jobs."""
|
||||
queue = ApiDependencies.invoker.services.download_queue
|
||||
queue.prune_jobs()
|
||||
@@ -55,7 +55,7 @@ async def download(
|
||||
) -> DownloadJob:
|
||||
"""Download the source URL to the file or directory indicted in dest."""
|
||||
queue = ApiDependencies.invoker.services.download_queue
|
||||
return queue.download(source, dest, priority, access_token)
|
||||
return queue.download(source, Path(dest), priority, access_token)
|
||||
|
||||
|
||||
@download_queue_router.get(
|
||||
@@ -87,7 +87,7 @@ async def get_download_job(
|
||||
)
|
||||
async def cancel_download_job(
|
||||
id: int = Path(description="ID of the download job to cancel."),
|
||||
):
|
||||
) -> Response:
|
||||
"""Cancel a download job using its ID."""
|
||||
try:
|
||||
queue = ApiDependencies.invoker.services.download_queue
|
||||
@@ -105,7 +105,7 @@ async def cancel_download_job(
|
||||
204: {"description": "Download jobs have been cancelled"},
|
||||
},
|
||||
)
|
||||
async def cancel_all_download_jobs():
|
||||
async def cancel_all_download_jobs() -> Response:
|
||||
"""Cancel all download jobs."""
|
||||
ApiDependencies.invoker.services.download_queue.cancel_all_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
759
invokeai/app/api/routers/model_manager.py
Normal file
759
invokeai/app/api/routers/model_manager.py
Normal file
@@ -0,0 +1,759 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""FastAPI route for model configuration records."""
|
||||
|
||||
import pathlib
|
||||
import shutil
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelRecordOrderBy,
|
||||
ModelSummary,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
MainCheckpointConfig,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
"""Return list of configs."""
|
||||
|
||||
models: List[AnyModelConfig]
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class ModelTagSet(BaseModel):
|
||||
"""Return tags for a set of models."""
|
||||
|
||||
key: str
|
||||
name: str
|
||||
author: str
|
||||
tags: Set[str]
|
||||
|
||||
|
||||
##############################################################################
|
||||
# These are example inputs and outputs that are used in places where Swagger
|
||||
# is unable to generate a correct example.
|
||||
##############################################################################
|
||||
example_model_config = {
|
||||
"path": "string",
|
||||
"name": "string",
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "checkpoint",
|
||||
"config": "string",
|
||||
"key": "string",
|
||||
"original_hash": "string",
|
||||
"current_hash": "string",
|
||||
"description": "string",
|
||||
"source": "string",
|
||||
"last_modified": 0,
|
||||
"vae": "string",
|
||||
"variant": "normal",
|
||||
"prediction_type": "epsilon",
|
||||
"repo_variant": "fp16",
|
||||
"upcast_attention": False,
|
||||
"ztsnr_training": False,
|
||||
}
|
||||
|
||||
example_model_input = {
|
||||
"path": "/path/to/model",
|
||||
"name": "model_name",
|
||||
"base": "sd-1",
|
||||
"type": "main",
|
||||
"format": "checkpoint",
|
||||
"config": "configs/stable-diffusion/v1-inference.yaml",
|
||||
"description": "Model description",
|
||||
"vae": None,
|
||||
"variant": "normal",
|
||||
}
|
||||
|
||||
example_model_metadata = {
|
||||
"name": "ip_adapter_sd_image_encoder",
|
||||
"author": "InvokeAI",
|
||||
"tags": [
|
||||
"transformers",
|
||||
"safetensors",
|
||||
"clip_vision_model",
|
||||
"endpoints_compatible",
|
||||
"region:us",
|
||||
"has_space",
|
||||
"license:apache-2.0",
|
||||
],
|
||||
"files": [
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/README.md",
|
||||
"path": "ip_adapter_sd_image_encoder/README.md",
|
||||
"size": 628,
|
||||
"sha256": None,
|
||||
},
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/config.json",
|
||||
"path": "ip_adapter_sd_image_encoder/config.json",
|
||||
"size": 560,
|
||||
"sha256": None,
|
||||
},
|
||||
{
|
||||
"url": "https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder/resolve/main/model.safetensors",
|
||||
"path": "ip_adapter_sd_image_encoder/model.safetensors",
|
||||
"size": 2528373448,
|
||||
"sha256": "6ca9667da1ca9e0b0f75e46bb030f7e011f44f86cbfb8d5a36590fcd7507b030",
|
||||
},
|
||||
],
|
||||
"type": "huggingface",
|
||||
"id": "InvokeAI/ip_adapter_sd_image_encoder",
|
||||
"tag_dict": {"license": "apache-2.0"},
|
||||
"last_modified": "2023-09-23T17:33:25Z",
|
||||
}
|
||||
|
||||
##############################################################################
|
||||
# ROUTES
|
||||
##############################################################################
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/",
|
||||
operation_id="list_model_records",
|
||||
)
|
||||
async def list_model_records(
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
|
||||
model_format: Optional[ModelFormat] = Query(
|
||||
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
|
||||
),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
found_models: list[AnyModelConfig] = []
|
||||
if base_models:
|
||||
for base_model in base_models:
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(
|
||||
base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
|
||||
)
|
||||
)
|
||||
else:
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||
)
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/i/{key}",
|
||||
operation_id="get_model_record",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model configuration was retrieved successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
},
|
||||
)
|
||||
async def get_model_record(
|
||||
key: str = Path(description="Key of the model record to fetch."),
|
||||
) -> AnyModelConfig:
|
||||
"""Get a model record"""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
config: AnyModelConfig = record_store.get_model(key)
|
||||
return config
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_manager_router.get("/summary", operation_id="list_model_summary")
|
||||
async def list_model_summary(
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: int = Query(default=10, description="The number of models per page"),
|
||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Gets a page of model summary data."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
return results
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/meta/i/{key}",
|
||||
operation_id="get_model_metadata",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model metadata was retrieved successfully",
|
||||
"content": {"application/json": {"example": example_model_metadata}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "No metadata available"},
|
||||
},
|
||||
)
|
||||
async def get_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Get a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags",
|
||||
operation_id="list_tags",
|
||||
)
|
||||
async def list_tags() -> Set[str]:
|
||||
"""Get a unique set of all the model tags."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Set[str] = record_store.list_tags()
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags/search",
|
||||
operation_id="search_by_metadata_tags",
|
||||
)
|
||||
async def search_by_metadata_tags(
|
||||
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
results = record_store.search_by_metadata_tag(tags)
|
||||
return ModelsList(models=results)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}",
|
||||
operation_id="update_model_record",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model was updated successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
409: {"description": "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
info: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
) -> AnyModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
try:
|
||||
model_response: AnyModelConfig = record_store.update_model(key, config=info)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return model_response
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="del_model_record",
|
||||
responses={
|
||||
204: {"description": "Model deleted successfully"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=204,
|
||||
)
|
||||
async def del_model_record(
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
) -> Response:
|
||||
"""
|
||||
Delete model record from database.
|
||||
|
||||
The configuration record will be removed. The corresponding weights files will be
|
||||
deleted as well if they reside within the InvokeAI "models" directory.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
installer.delete(key)
|
||||
logger.info(f"Deleted model: {key}")
|
||||
return Response(status_code=204)
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/i/",
|
||||
operation_id="add_model_record",
|
||||
responses={
|
||||
201: {
|
||||
"description": "The model added successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def add_model_record(
|
||||
config: Annotated[
|
||||
AnyModelConfig, Body(description="Model config", discriminator="type", example=example_model_input)
|
||||
],
|
||||
) -> AnyModelConfig:
|
||||
"""Add a model using the configuration information appropriate for its type."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
if config.key == "<NOKEY>":
|
||||
config.key = sha1(randbytes(100)).hexdigest()
|
||||
logger.info(f"Created model {config.key} for {config.name}")
|
||||
try:
|
||||
record_store.add_model(config.key, config)
|
||||
except DuplicateModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
# now fetch it out
|
||||
result: AnyModelConfig = record_store.get_model(config.key)
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/heuristic_import",
|
||||
operation_id="heuristic_import_model",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def heuristic_import(
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
example={"name": "modelT", "description": "antique cars"},
|
||||
),
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
"""Install a model using a string identifier.
|
||||
|
||||
`source` can be any of the following.
|
||||
|
||||
1. A path on the local filesystem ('C:\\users\\fred\\model.safetensors')
|
||||
2. A Url pointing to a single downloadable model file
|
||||
3. A HuggingFace repo_id with any of the following formats:
|
||||
- model/name
|
||||
- model/name:fp16:vae
|
||||
- model/name::vae -- use default precision
|
||||
- model/name:fp16:path/to/model.safetensors
|
||||
- model/name::path/to/model.safetensors
|
||||
|
||||
`config` is an optional dict containing model configuration values that will override
|
||||
the ones that are probed automatically.
|
||||
|
||||
`access_token` is an optional access token for use with Urls that require
|
||||
authentication.
|
||||
|
||||
Models will be downloaded, probed, configured and installed in a
|
||||
series of background threads. The return object has `status` attribute
|
||||
that can be used to monitor progress.
|
||||
|
||||
See the documentation for `import_model_record` for more information on
|
||||
interpreting the job information returned by this route.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
result: ModelInstallJob = installer.heuristic_import(
|
||||
source=source,
|
||||
config=config,
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=424, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/install",
|
||||
operation_id="import_model",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def import_model(
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
),
|
||||
) -> ModelInstallJob:
|
||||
"""Install a model using its local path, repo_id, or remote URL.
|
||||
|
||||
Models will be downloaded, probed, configured and installed in a
|
||||
series of background threads. The return object has `status` attribute
|
||||
that can be used to monitor progress.
|
||||
|
||||
The source object is a discriminated Union of LocalModelSource,
|
||||
HFModelSource and URLModelSource. Set the "type" field to the
|
||||
appropriate value:
|
||||
|
||||
* To install a local path using LocalModelSource, pass a source of form:
|
||||
```
|
||||
{
|
||||
"type": "local",
|
||||
"path": "/path/to/model",
|
||||
"inplace": false
|
||||
}
|
||||
```
|
||||
The "inplace" flag, if true, will register the model in place in its
|
||||
current filesystem location. Otherwise, the model will be copied
|
||||
into the InvokeAI models directory.
|
||||
|
||||
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
|
||||
```
|
||||
{
|
||||
"type": "hf",
|
||||
"repo_id": "stabilityai/stable-diffusion-2.0",
|
||||
"variant": "fp16",
|
||||
"subfolder": "vae",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}
|
||||
```
|
||||
The `variant`, `subfolder` and `access_token` fields are optional.
|
||||
|
||||
* To install a remote model using an arbitrary URL, pass:
|
||||
```
|
||||
{
|
||||
"type": "url",
|
||||
"url": "http://www.civitai.com/models/123456",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}
|
||||
```
|
||||
The `access_token` field is optonal
|
||||
|
||||
The model's configuration record will be probed and filled in
|
||||
automatically. To override the default guesses, pass "metadata"
|
||||
with a Dict containing the attributes you wish to override.
|
||||
|
||||
Installation occurs in the background. Either use list_model_install_jobs()
|
||||
to poll for completion, or listen on the event bus for the following events:
|
||||
|
||||
* "model_install_running"
|
||||
* "model_install_completed"
|
||||
* "model_install_error"
|
||||
|
||||
On successful completion, the event's payload will contain the field "key"
|
||||
containing the installed ID of the model. On an error, the event's payload
|
||||
will contain the fields "error_type" and "error" describing the nature of the
|
||||
error and its traceback, respectively.
|
||||
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
result: ModelInstallJob = installer.import_model(
|
||||
source=source,
|
||||
config=config,
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=424, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/import",
|
||||
operation_id="list_model_install_jobs",
|
||||
)
|
||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
"""Return the list of model install jobs.
|
||||
|
||||
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
|
||||
the nature of the job and its progress. The `status` is one of:
|
||||
|
||||
* "waiting" -- Job is waiting in the queue to run
|
||||
* "downloading" -- Model file(s) are downloading
|
||||
* "running" -- Model has downloaded and the model probing and registration process is running
|
||||
* "completed" -- Installation completed successfully
|
||||
* "error" -- An error occurred. Details will be in the "error_type" and "error" fields.
|
||||
* "cancelled" -- Job was cancelled before completion.
|
||||
|
||||
Once completed, information about the model such as its size, base
|
||||
model, type, and metadata can be retrieved from the `config_out`
|
||||
field. For multi-file models such as diffusers, information on individual files
|
||||
can be retrieved from `download_parts`.
|
||||
|
||||
See the example and schema below for more information.
|
||||
"""
|
||||
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_manager.install.list_jobs()
|
||||
return jobs
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/import/{id}",
|
||||
operation_id="get_model_install_job",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
404: {"description": "No such job"},
|
||||
},
|
||||
)
|
||||
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
|
||||
"""
|
||||
Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
|
||||
for information on the format of the return value.
|
||||
"""
|
||||
try:
|
||||
result: ModelInstallJob = ApiDependencies.invoker.services.model_manager.install.get_job_by_id(id)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/import/{id}",
|
||||
operation_id="cancel_model_install_job",
|
||||
responses={
|
||||
201: {"description": "The job was cancelled successfully"},
|
||||
415: {"description": "No such job"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
|
||||
"""Cancel the model install job(s) corresponding to the given job ID."""
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
try:
|
||||
job = installer.get_job_by_id(id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=415, detail=str(e))
|
||||
installer.cancel_job(job)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/import",
|
||||
operation_id="prune_model_install_jobs",
|
||||
responses={
|
||||
204: {"description": "All completed and errored jobs have been pruned"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def prune_model_install_jobs() -> Response:
|
||||
"""Prune all completed and errored jobs from the install job list."""
|
||||
ApiDependencies.invoker.services.model_manager.install.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/sync",
|
||||
operation_id="sync_models_to_config",
|
||||
responses={
|
||||
204: {"description": "Model config record database resynced with files on disk"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def sync_models_to_config() -> Response:
|
||||
"""
|
||||
Traverse the models and autoimport directories.
|
||||
|
||||
Model files without a corresponding
|
||||
record in the database are added. Orphan records without a models file are deleted.
|
||||
"""
|
||||
ApiDependencies.invoker.services.model_manager.install.sync_to_config()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@model_manager_router.put(
|
||||
"/convert/{key}",
|
||||
operation_id="convert_model",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Model converted successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "Model not found"},
|
||||
409: {"description": "There is already a model registered at this location"},
|
||||
},
|
||||
)
|
||||
async def convert_model(
|
||||
key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Permanently convert a model into diffusers format, replacing the safetensors version.
|
||||
Note that during the conversion process the key and model hash will change.
|
||||
The return value is the model configuration for the converted model.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
loader = ApiDependencies.invoker.services.model_manager.load
|
||||
store = ApiDependencies.invoker.services.model_manager.store
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
|
||||
try:
|
||||
model_config = store.get_model(key)
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=424, detail=str(e))
|
||||
|
||||
if not isinstance(model_config, MainCheckpointConfig):
|
||||
logger.error(f"The model with key {key} is not a main checkpoint model.")
|
||||
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
|
||||
|
||||
# loading the model will convert it into a cached diffusers file
|
||||
loader.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
|
||||
|
||||
# Get the path of the converted model from the loader
|
||||
cache_path = loader.convert_cache.cache_path(key)
|
||||
assert cache_path.exists()
|
||||
|
||||
# temporarily rename the original safetensors file so that there is no naming conflict
|
||||
original_name = model_config.name
|
||||
model_config.name = f"{original_name}.DELETE"
|
||||
store.update_model(key, config=model_config)
|
||||
|
||||
# install the diffusers
|
||||
try:
|
||||
new_key = installer.install_path(
|
||||
cache_path,
|
||||
config={
|
||||
"name": original_name,
|
||||
"description": model_config.description,
|
||||
"original_hash": model_config.original_hash,
|
||||
"source": model_config.source,
|
||||
},
|
||||
)
|
||||
except DuplicateModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
# get the original metadata
|
||||
if orig_metadata := store.get_metadata(key):
|
||||
store.metadata_store.add_metadata(new_key, orig_metadata)
|
||||
|
||||
# delete the original safetensors file
|
||||
installer.delete(key)
|
||||
|
||||
# delete the cached version
|
||||
shutil.rmtree(cache_path)
|
||||
|
||||
# return the config record for the new diffusers directory
|
||||
new_config: AnyModelConfig = store.get_model(new_key)
|
||||
return new_config
|
||||
|
||||
|
||||
@model_manager_router.put(
|
||||
"/merge",
|
||||
operation_id="merge",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Model converted successfully",
|
||||
"content": {"application/json": {"example": example_model_config}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "Model not found"},
|
||||
409: {"description": "There is already a model registered at this location"},
|
||||
},
|
||||
)
|
||||
async def merge(
|
||||
keys: List[str] = Body(description="Keys for two to three models to merge", min_length=2, max_length=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model", default=None),
|
||||
alpha: float = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
force: bool = Body(
|
||||
description="Force merging of models created with different versions of diffusers",
|
||||
default=False,
|
||||
),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method", default=None),
|
||||
merge_dest_directory: Optional[str] = Body(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
),
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Merge diffusers models. The process is controlled by a set parameters provided in the body of the request.
|
||||
```
|
||||
Argument Description [default]
|
||||
-------- ----------------------
|
||||
keys List of 2-3 model keys to merge together. All models must use the same base type.
|
||||
merged_model_name Name for the merged model [Concat model names]
|
||||
alpha Alpha value (0.0-1.0). Higher values give more weight to the second model [0.5]
|
||||
force If true, force the merge even if the models were generated by different versions of the diffusers library [False]
|
||||
interp Interpolation method. One of "weighted_sum", "sigmoid", "inv_sigmoid" or "add_difference" [weighted_sum]
|
||||
merge_dest_directory Specify a directory to store the merged model in [models directory]
|
||||
```
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {keys} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
merger = ModelMerger(installer)
|
||||
model_names = [installer.record_store.get_model(x).name for x in keys]
|
||||
response = merger.merge_diffusion_models_and_save(
|
||||
model_keys=keys,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
except UnknownModelException:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"One or more of the models '{keys}' not found",
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
@@ -1,417 +0,0 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""FastAPI route for model configuration records."""
|
||||
|
||||
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelRecordOrderBy,
|
||||
ModelSummary,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
model_records_router = APIRouter(prefix="/v1/model/record", tags=["model_manager_v2_unstable"])
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
"""Return list of configs."""
|
||||
|
||||
models: List[AnyModelConfig]
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class ModelTagSet(BaseModel):
|
||||
"""Return tags for a set of models."""
|
||||
|
||||
key: str
|
||||
name: str
|
||||
author: str
|
||||
tags: Set[str]
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/",
|
||||
operation_id="list_model_records",
|
||||
)
|
||||
async def list_model_records(
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
|
||||
model_format: Optional[ModelFormat] = Query(
|
||||
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
|
||||
),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
found_models: list[AnyModelConfig] = []
|
||||
if base_models:
|
||||
for base_model in base_models:
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(
|
||||
base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
|
||||
)
|
||||
)
|
||||
else:
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||
)
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/i/{key}",
|
||||
operation_id="get_model_record",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
},
|
||||
)
|
||||
async def get_model_record(
|
||||
key: str = Path(description="Key of the model record to fetch."),
|
||||
) -> AnyModelConfig:
|
||||
"""Get a model record"""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
try:
|
||||
return record_store.get_model(key)
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.get("/meta", operation_id="list_model_summary")
|
||||
async def list_model_summary(
|
||||
page: int = Query(default=0, description="The page to get"),
|
||||
per_page: int = Query(default=10, description="The number of models per page"),
|
||||
order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
|
||||
) -> PaginatedResults[ModelSummary]:
|
||||
"""Gets a page of model summary data."""
|
||||
return ApiDependencies.invoker.services.model_records.list_models(page=page, per_page=per_page, order_by=order_by)
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/meta/i/{key}",
|
||||
operation_id="get_model_metadata",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "No metadata available"},
|
||||
},
|
||||
)
|
||||
async def get_model_metadata(
|
||||
key: str = Path(description="Key of the model repo metadata to fetch."),
|
||||
) -> Optional[AnyModelRepoMetadata]:
|
||||
"""Get a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
result = record_store.get_metadata(key)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
|
||||
return result
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/tags",
|
||||
operation_id="list_tags",
|
||||
)
|
||||
async def list_tags() -> Set[str]:
|
||||
"""Get a unique set of all the model tags."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
return record_store.list_tags()
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/tags/search",
|
||||
operation_id="search_by_metadata_tags",
|
||||
)
|
||||
async def search_by_metadata_tags(
|
||||
tags: Set[str] = Query(default=None, description="Tags to search for"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
results = record_store.search_by_metadata_tag(tags)
|
||||
return ModelsList(models=results)
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
"/i/{key}",
|
||||
operation_id="update_model_record",
|
||||
responses={
|
||||
200: {"description": "The model was updated successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
409: {"description": "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=AnyModelConfig,
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
|
||||
) -> AnyModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
try:
|
||||
model_response = record_store.update_model(key, config=info)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return model_response
|
||||
|
||||
|
||||
@model_records_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="del_model_record",
|
||||
responses={
|
||||
204: {"description": "Model deleted successfully"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=204,
|
||||
)
|
||||
async def del_model_record(
|
||||
key: str = Path(description="Unique key of model to remove from model registry."),
|
||||
) -> Response:
|
||||
"""
|
||||
Delete model record from database.
|
||||
|
||||
The configuration record will be removed. The corresponding weights files will be
|
||||
deleted as well if they reside within the InvokeAI "models" directory.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
installer.delete(key)
|
||||
logger.info(f"Deleted model: {key}")
|
||||
return Response(status_code=204)
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.post(
|
||||
"/i/",
|
||||
operation_id="add_model_record",
|
||||
responses={
|
||||
201: {"description": "The model added successfully"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def add_model_record(
|
||||
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
|
||||
) -> AnyModelConfig:
|
||||
"""Add a model using the configuration information appropriate for its type."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
if config.key == "<NOKEY>":
|
||||
config.key = sha1(randbytes(100)).hexdigest()
|
||||
logger.info(f"Created model {config.key} for {config.name}")
|
||||
try:
|
||||
record_store.add_model(config.key, config)
|
||||
except DuplicateModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
# now fetch it out
|
||||
return record_store.get_model(config.key)
|
||||
|
||||
|
||||
@model_records_router.post(
|
||||
"/import",
|
||||
operation_id="import_model_record",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def import_model(
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
),
|
||||
) -> ModelInstallJob:
|
||||
"""Add a model using its local path, repo_id, or remote URL.
|
||||
|
||||
Models will be downloaded, probed, configured and installed in a
|
||||
series of background threads. The return object has `status` attribute
|
||||
that can be used to monitor progress.
|
||||
|
||||
The source object is a discriminated Union of LocalModelSource,
|
||||
HFModelSource and URLModelSource. Set the "type" field to the
|
||||
appropriate value:
|
||||
|
||||
* To install a local path using LocalModelSource, pass a source of form:
|
||||
`{
|
||||
"type": "local",
|
||||
"path": "/path/to/model",
|
||||
"inplace": false
|
||||
}`
|
||||
The "inplace" flag, if true, will register the model in place in its
|
||||
current filesystem location. Otherwise, the model will be copied
|
||||
into the InvokeAI models directory.
|
||||
|
||||
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
|
||||
`{
|
||||
"type": "hf",
|
||||
"repo_id": "stabilityai/stable-diffusion-2.0",
|
||||
"variant": "fp16",
|
||||
"subfolder": "vae",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}`
|
||||
The `variant`, `subfolder` and `access_token` fields are optional.
|
||||
|
||||
* To install a remote model using an arbitrary URL, pass:
|
||||
`{
|
||||
"type": "url",
|
||||
"url": "http://www.civitai.com/models/123456",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}`
|
||||
The `access_token` field is optonal
|
||||
|
||||
The model's configuration record will be probed and filled in
|
||||
automatically. To override the default guesses, pass "metadata"
|
||||
with a Dict containing the attributes you wish to override.
|
||||
|
||||
Installation occurs in the background. Either use list_model_install_jobs()
|
||||
to poll for completion, or listen on the event bus for the following events:
|
||||
|
||||
"model_install_running"
|
||||
"model_install_completed"
|
||||
"model_install_error"
|
||||
|
||||
On successful completion, the event's payload will contain the field "key"
|
||||
containing the installed ID of the model. On an error, the event's payload
|
||||
will contain the fields "error_type" and "error" describing the nature of the
|
||||
error and its traceback, respectively.
|
||||
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
result: ModelInstallJob = installer.import_model(
|
||||
source=source,
|
||||
config=config,
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=424, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return result
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/import",
|
||||
operation_id="list_model_install_jobs",
|
||||
)
|
||||
async def list_model_install_jobs() -> List[ModelInstallJob]:
|
||||
"""Return list of model install jobs."""
|
||||
jobs: List[ModelInstallJob] = ApiDependencies.invoker.services.model_install.list_jobs()
|
||||
return jobs
|
||||
|
||||
|
||||
@model_records_router.get(
|
||||
"/import/{id}",
|
||||
operation_id="get_model_install_job",
|
||||
responses={
|
||||
200: {"description": "Success"},
|
||||
404: {"description": "No such job"},
|
||||
},
|
||||
)
|
||||
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
|
||||
"""Return model install job corresponding to the given source."""
|
||||
try:
|
||||
return ApiDependencies.invoker.services.model_install.get_job_by_id(id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@model_records_router.delete(
|
||||
"/import/{id}",
|
||||
operation_id="cancel_model_install_job",
|
||||
responses={
|
||||
201: {"description": "The job was cancelled successfully"},
|
||||
415: {"description": "No such job"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
|
||||
"""Cancel the model install job(s) corresponding to the given job ID."""
|
||||
installer = ApiDependencies.invoker.services.model_install
|
||||
try:
|
||||
job = installer.get_job_by_id(id)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=415, detail=str(e))
|
||||
installer.cancel_job(job)
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
"/import",
|
||||
operation_id="prune_model_install_jobs",
|
||||
responses={
|
||||
204: {"description": "All completed and errored jobs have been pruned"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def prune_model_install_jobs() -> Response:
|
||||
"""Prune all completed and errored jobs from the install job list."""
|
||||
ApiDependencies.invoker.services.model_install.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@model_records_router.patch(
|
||||
"/sync",
|
||||
operation_id="sync_models_to_config",
|
||||
responses={
|
||||
204: {"description": "Model config record database resynced with files on disk"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
)
|
||||
async def sync_models_to_config() -> Response:
|
||||
"""
|
||||
Traverse the models and autoimport directories.
|
||||
|
||||
Model files without a corresponding
|
||||
record in the database are added. Orphan records without a models file are deleted.
|
||||
"""
|
||||
ApiDependencies.invoker.services.model_install.sync_to_config()
|
||||
return Response(status_code=204)
|
||||
@@ -1,427 +0,0 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
|
||||
|
||||
import pathlib
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management import MergeInterpolationMethod
|
||||
from invokeai.backend.model_management.models import (
|
||||
OPENAPI_MODEL_CONFIGS,
|
||||
InvalidModelException,
|
||||
ModelNotFoundException,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
|
||||
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
UpdateModelResponseValidator = TypeAdapter(UpdateModelResponse)
|
||||
|
||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelResponseValidator = TypeAdapter(ImportModelResponse)
|
||||
|
||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ConvertModelResponseValidator = TypeAdapter(ConvertModelResponse)
|
||||
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
ModelsListValidator = TypeAdapter(ModelsList)
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
operation_id="list_models",
|
||||
responses={200: {"model": ModelsList}},
|
||||
)
|
||||
async def list_models(
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
) -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
if base_models and len(base_models) > 0:
|
||||
models_raw = []
|
||||
for base_model in base_models:
|
||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||
else:
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
||||
models = ModelsListValidator.validate_python({"models": models_raw})
|
||||
return models
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="update_model",
|
||||
responses={
|
||||
200: {"description": "The model was updated successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
409: {"description": "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=UpdateModelResponse,
|
||||
)
|
||||
async def update_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> UpdateModelResponse:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
# rename operation requested
|
||||
if info.model_name != model_name or info.base_model != base_model:
|
||||
ApiDependencies.invoker.services.model_manager.rename_model(
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
new_name=info.model_name,
|
||||
new_base=info.base_model,
|
||||
)
|
||||
logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
|
||||
# update information to support an update of attributes
|
||||
model_name = info.model_name
|
||||
base_model = info.base_model
|
||||
new_info = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
if new_info.get("path") != previous_info.get(
|
||||
"path"
|
||||
): # model manager moved model path during rename - don't overwrite it
|
||||
info.path = new_info.get("path")
|
||||
|
||||
# replace empty string values with None/null to avoid phenomenon of vae: ''
|
||||
info_dict = info.model_dump()
|
||||
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
|
||||
|
||||
ApiDependencies.invoker.services.model_manager.update_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_attributes=info_dict,
|
||||
)
|
||||
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
model_response = UpdateModelResponseValidator.validate_python(model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/import",
|
||||
operation_id="import_model",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
404: {"description": "The model could not be found"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse,
|
||||
)
|
||||
async def import_model(
|
||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
|
||||
description="Prediction type for SDv2 checkpoints and rare SDv1 checkpoints",
|
||||
default=None,
|
||||
),
|
||||
) -> ImportModelResponse:
|
||||
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
|
||||
|
||||
location = location.strip("\"' ")
|
||||
items_to_import = {location}
|
||||
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import=items_to_import,
|
||||
prediction_type_helper=lambda x: prediction_types.get(prediction_type),
|
||||
)
|
||||
info = installed_models.get(location)
|
||||
|
||||
if not info:
|
||||
logger.error("Import failed")
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
logger.info(f"Successfully imported {location}, got {info}")
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
||||
)
|
||||
return ImportModelResponseValidator.validate_python(model_raw)
|
||||
|
||||
except ModelNotFoundException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/add",
|
||||
operation_id="add_model",
|
||||
responses={
|
||||
201: {"description": "The model added successfully"},
|
||||
404: {"description": "The model could not be found"},
|
||||
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse,
|
||||
)
|
||||
async def add_model(
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> ImportModelResponse:
|
||||
"""Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
info.model_name,
|
||||
info.base_model,
|
||||
info.model_type,
|
||||
model_attributes=info.model_dump(),
|
||||
)
|
||||
logger.info(f"Successfully added {info.model_name}")
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.model_name,
|
||||
base_model=info.base_model,
|
||||
model_type=info.model_type,
|
||||
)
|
||||
return ImportModelResponseValidator.validate_python(model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
@models_router.delete(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="del_model",
|
||||
responses={
|
||||
204: {"description": "Model deleted successfully"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=204,
|
||||
response_model=None,
|
||||
)
|
||||
async def delete_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
) -> Response:
|
||||
"""Delete Model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.del_model(
|
||||
model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
logger.info(f"Deleted model: {model_name}")
|
||||
return Response(status_code=204)
|
||||
except ModelNotFoundException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/convert/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="convert_model",
|
||||
responses={
|
||||
200: {"description": "Model converted successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=ConvertModelResponse,
|
||||
)
|
||||
async def convert_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
convert_dest_directory: Optional[str] = Query(
|
||||
default=None, description="Save the converted model to the designated directory"
|
||||
),
|
||||
) -> ConvertModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Converting model: {model_name}")
|
||||
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(
|
||||
model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
convert_dest_directory=dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
response = ConvertModelResponseValidator.validate_python(model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/search",
|
||||
operation_id="search_for_models",
|
||||
responses={
|
||||
200: {"description": "Directory searched successfully"},
|
||||
404: {"description": "Invalid directory path"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=List[pathlib.Path],
|
||||
)
|
||||
async def search_for_models(
|
||||
search_path: pathlib.Path = Query(description="Directory path to search for models"),
|
||||
) -> List[pathlib.Path]:
|
||||
if not search_path.is_dir():
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"The search path '{search_path}' does not exist or is not directory",
|
||||
)
|
||||
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/ckpt_confs",
|
||||
operation_id="list_ckpt_configs",
|
||||
responses={
|
||||
200: {"description": "paths retrieved successfully"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=List[pathlib.Path],
|
||||
)
|
||||
async def list_ckpt_configs() -> List[pathlib.Path]:
|
||||
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
||||
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/sync",
|
||||
operation_id="sync_to_config",
|
||||
responses={
|
||||
201: {"description": "synchronization successful"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=bool,
|
||||
)
|
||||
async def sync_to_config() -> bool:
|
||||
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||
in-memory data structures with disk data structures."""
|
||||
ApiDependencies.invoker.services.model_manager.sync_to_config()
|
||||
return True
|
||||
|
||||
|
||||
# There's some weird pydantic-fastapi behaviour that requires this to be a separate class
|
||||
# TODO: After a few updates, see if it works inside the route operation handler?
|
||||
class MergeModelsBody(BaseModel):
|
||||
model_names: List[str] = Field(description="model name", min_length=2, max_length=3)
|
||||
merged_model_name: Optional[str] = Field(description="Name of destination model")
|
||||
alpha: Optional[float] = Field(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5)
|
||||
interp: Optional[MergeInterpolationMethod] = Field(description="Interpolation method")
|
||||
force: Optional[bool] = Field(
|
||||
description="Force merging of models created with different versions of diffusers",
|
||||
default=False,
|
||||
)
|
||||
|
||||
merge_dest_directory: Optional[str] = Field(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/merge/{base_model}",
|
||||
operation_id="merge_models",
|
||||
responses={
|
||||
200: {"description": "Model converted successfully"},
|
||||
400: {"description": "Incompatible models"},
|
||||
404: {"description": "One or more models not found"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=MergeModelResponse,
|
||||
)
|
||||
async def merge_models(
|
||||
body: Annotated[MergeModelsBody, Body(description="Model configuration", embed=True)],
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
) -> MergeModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(
|
||||
f"Merging models: {body.model_names} into {body.merge_dest_directory or '<MODELS>'}/{body.merged_model_name}"
|
||||
)
|
||||
dest = pathlib.Path(body.merge_dest_directory) if body.merge_dest_directory else None
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
||||
model_names=body.model_names,
|
||||
base_model=base_model,
|
||||
merged_model_name=body.merged_model_name or "+".join(body.model_names),
|
||||
alpha=body.alpha,
|
||||
interp=body.interp,
|
||||
force=body.force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
result.name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Main,
|
||||
)
|
||||
response = ConvertModelResponseValidator.validate_python(model_raw)
|
||||
except ModelNotFoundException:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"One or more of the models '{body.model_names}' not found",
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
@@ -14,7 +14,7 @@ class SocketIO:
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="socket.io")
|
||||
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
|
||||
app.mount("/ws", self.__app)
|
||||
|
||||
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
|
||||
|
||||
@@ -47,8 +47,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
boards,
|
||||
download_queue,
|
||||
images,
|
||||
model_records,
|
||||
models,
|
||||
model_manager,
|
||||
session_queue,
|
||||
sessions,
|
||||
utilities,
|
||||
@@ -115,8 +114,7 @@ async def shutdown_event() -> None:
|
||||
app.include_router(sessions.session_router, prefix="/api")
|
||||
|
||||
app.include_router(utilities.utilities_router, prefix="/api")
|
||||
app.include_router(models.models_router, prefix="/api")
|
||||
app.include_router(model_records.model_records_router, prefix="/api")
|
||||
app.include_router(model_manager.model_manager_router, prefix="/api")
|
||||
app.include_router(download_queue.download_queue_router, prefix="/api")
|
||||
app.include_router(images.images_router, prefix="/api")
|
||||
app.include_router(boards.boards_router, prefix="/api")
|
||||
@@ -178,21 +176,23 @@ def custom_openapi() -> dict[str, Any]:
|
||||
invoker_schema["class"] = "invocation"
|
||||
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
|
||||
|
||||
from invokeai.backend.model_management.models import get_model_config_enums
|
||||
# This code no longer seems to be necessary?
|
||||
# Leave it here just in case
|
||||
#
|
||||
# from invokeai.backend.model_manager import get_model_config_formats
|
||||
# formats = get_model_config_formats()
|
||||
# for model_config_name, enum_set in formats.items():
|
||||
|
||||
for model_config_format_enum in set(get_model_config_enums()):
|
||||
name = model_config_format_enum.__qualname__
|
||||
# if model_config_name in openapi_schema["components"]["schemas"]:
|
||||
# # print(f"Config with name {name} already defined")
|
||||
# continue
|
||||
|
||||
if name in openapi_schema["components"]["schemas"]:
|
||||
# print(f"Config with name {name} already defined")
|
||||
continue
|
||||
|
||||
openapi_schema["components"]["schemas"][name] = {
|
||||
"title": name,
|
||||
"description": "An enumeration.",
|
||||
"type": "string",
|
||||
"enum": [v.value for v in model_config_format_enum],
|
||||
}
|
||||
# openapi_schema["components"]["schemas"][model_config_name] = {
|
||||
# "title": model_config_name,
|
||||
# "description": "An enumeration.",
|
||||
# "type": "string",
|
||||
# "enum": [v.value for v in enum_set],
|
||||
# }
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
@@ -1,22 +1,27 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
from typing import Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ExtraConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
from invokeai.backend.util.devices import torch_dtype
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.models import ModelNotFoundException, ModelType
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from ..util.ti_utils import extract_ti_triggers_from_prompt
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -66,21 +71,22 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
tokenizer_info = context.services.model_manager.load_model_by_key(
|
||||
**self.clip.tokenizer.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
text_encoder_info = context.services.model_manager.load_model_by_key(
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _lora_loader():
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
lora_info = context.services.model_manager.load_model_by_key(
|
||||
**lora.model_dump(exclude={"weight"}), context=context
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
@@ -90,25 +96,20 @@ class CompelInvocation(BaseInvocation):
|
||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
loaded_model = context.services.model_manager.load_model_by_key(
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
context=context,
|
||||
).model
|
||||
assert isinstance(loaded_model, TextualInversionModelRaw)
|
||||
ti_list.append((name, loaded_model))
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with (
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
@@ -116,7 +117,7 @@ class CompelInvocation(BaseInvocation):
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
|
||||
):
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
@@ -150,7 +151,7 @@ class CompelInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
context.services.latents.save(conditioning_name, conditioning_data) # TODO: fix type mismatch here
|
||||
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
@@ -160,6 +161,8 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
|
||||
class SDXLPromptInvocationBase:
|
||||
"""Prompt processor for SDXL models."""
|
||||
|
||||
def run_clip_compel(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
@@ -168,26 +171,27 @@ class SDXLPromptInvocationBase:
|
||||
get_pooled: bool,
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
):
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
||||
tokenizer_info = context.services.model_manager.load_model_by_key(
|
||||
**clip_field.tokenizer.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
text_encoder_info = context.services.model_manager.load_model_by_key(
|
||||
**clip_field.text_encoder.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
# return zero on empty
|
||||
if prompt == "" and zero_on_empty:
|
||||
cpu_text_encoder = text_encoder_info.context.model
|
||||
cpu_text_encoder = text_encoder_info.model
|
||||
assert isinstance(cpu_text_encoder, torch.nn.Module)
|
||||
c = torch.zeros(
|
||||
(
|
||||
1,
|
||||
cpu_text_encoder.config.max_position_embeddings,
|
||||
cpu_text_encoder.config.hidden_size,
|
||||
),
|
||||
dtype=text_encoder_info.context.cache.precision,
|
||||
dtype=cpu_text_encoder.dtype,
|
||||
)
|
||||
if get_pooled:
|
||||
c_pooled = torch.zeros(
|
||||
@@ -198,12 +202,14 @@ class SDXLPromptInvocationBase:
|
||||
c_pooled = None
|
||||
return c, c_pooled, None
|
||||
|
||||
def _lora_loader():
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
lora_info = context.services.model_manager.load_model_by_key(
|
||||
**lora.model_dump(exclude={"weight"}), context=context
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
lora_model = lora_info.model
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
yield (lora_model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
@@ -213,25 +219,24 @@ class SDXLPromptInvocationBase:
|
||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
model_name=name,
|
||||
base_model=clip_field.text_encoder.base_model,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).context.model,
|
||||
)
|
||||
)
|
||||
except ModelNotFoundException:
|
||||
ti_model = context.services.model_manager.load_model_by_attr(
|
||||
model_name=name,
|
||||
base_model=text_encoder_info.config.base,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).model
|
||||
assert isinstance(ti_model, TextualInversionModelRaw)
|
||||
ti_list.append((name, ti_model))
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
logger.warning(f'trigger: "{trigger}" not found')
|
||||
except ValueError:
|
||||
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
|
||||
|
||||
with (
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
@@ -239,7 +244,7 @@ class SDXLPromptInvocationBase:
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
|
||||
):
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
@@ -357,6 +362,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
dim=1,
|
||||
)
|
||||
|
||||
assert c2_pooled is not None
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
SDXLConditioningInfo(
|
||||
@@ -410,6 +416,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
|
||||
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
|
||||
|
||||
assert c2_pooled is not None
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[
|
||||
SDXLConditioningInfo(
|
||||
@@ -459,9 +466,9 @@ class ClipSkipInvocation(BaseInvocation):
|
||||
|
||||
|
||||
def get_max_token_count(
|
||||
tokenizer,
|
||||
tokenizer: CLIPTokenizer,
|
||||
prompt: Union[FlattenedPrompt, Blend, Conjunction],
|
||||
truncate_if_too_long=False,
|
||||
truncate_if_too_long: bool = False,
|
||||
) -> int:
|
||||
if type(prompt) is Blend:
|
||||
blend: Blend = prompt
|
||||
@@ -473,7 +480,9 @@ def get_max_token_count(
|
||||
return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
|
||||
|
||||
|
||||
def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
|
||||
def get_tokens_for_prompt_object(
|
||||
tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
|
||||
) -> List[str]:
|
||||
if type(parsed_prompt) is Blend:
|
||||
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
|
||||
|
||||
@@ -486,24 +495,29 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
|
||||
for x in parsed_prompt.children
|
||||
]
|
||||
text = " ".join(text_fragments)
|
||||
tokens = tokenizer.tokenize(text)
|
||||
tokens: List[str] = tokenizer.tokenize(text)
|
||||
if truncate_if_too_long:
|
||||
max_tokens_length = tokenizer.model_max_length - 2 # typically 75
|
||||
tokens = tokens[0:max_tokens_length]
|
||||
return tokens
|
||||
|
||||
|
||||
def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
|
||||
def log_tokenization_for_conjunction(
|
||||
c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
|
||||
) -> None:
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
for i, p in enumerate(c.prompts):
|
||||
if len(c.prompts) > 1:
|
||||
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
||||
else:
|
||||
assert display_label_prefix is not None
|
||||
this_display_label_prefix = display_label_prefix
|
||||
log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
|
||||
|
||||
|
||||
def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
|
||||
def log_tokenization_for_prompt_object(
|
||||
p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
|
||||
) -> None:
|
||||
display_label_prefix = display_label_prefix or ""
|
||||
if type(p) is Blend:
|
||||
blend: Blend = p
|
||||
@@ -543,7 +557,12 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz
|
||||
log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
|
||||
|
||||
|
||||
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||
def log_tokenization_for_text(
|
||||
text: str,
|
||||
tokenizer: CLIPTokenizer,
|
||||
display_label: Optional[str] = None,
|
||||
truncate_if_too_long: Optional[bool] = False,
|
||||
) -> None:
|
||||
"""shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
|
||||
@@ -24,7 +24,7 @@ from controlnet_aux import (
|
||||
)
|
||||
from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
@@ -32,7 +32,6 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
|
||||
|
||||
from ...backend.model_management import BaseModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -57,10 +56,7 @@ CONTROLNET_RESIZE_VALUES = Literal[
|
||||
class ControlNetModelField(BaseModel):
|
||||
"""ControlNet model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the ControlNet model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Model config record key for the ControlNet model")
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
from builtins import float
|
||||
from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@@ -17,22 +17,16 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
|
||||
|
||||
# LS: Consider moving these two classes into model.py
|
||||
class IPAdapterModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the IP-Adapter model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Key to the IP-Adapter model")
|
||||
|
||||
|
||||
class CLIPVisionModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the CLIP Vision image encoder model")
|
||||
base_model: BaseModelType = Field(description="Base model (usually 'Any')")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Key to the CLIP Vision image encoder model")
|
||||
|
||||
|
||||
class IPAdapterField(BaseModel):
|
||||
@@ -49,12 +43,12 @@ class IPAdapterField(BaseModel):
|
||||
|
||||
@field_validator("weight")
|
||||
@classmethod
|
||||
def validate_ip_adapter_weight(cls, v):
|
||||
def validate_ip_adapter_weight(cls, v: float) -> float:
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self):
|
||||
def validate_begin_end_step_percent(self) -> Self:
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
@@ -87,33 +81,25 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
|
||||
@field_validator("weight")
|
||||
@classmethod
|
||||
def validate_ip_adapter_weight(cls, v):
|
||||
def validate_ip_adapter_weight(cls, v: float) -> float:
|
||||
validate_weights(v)
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_begin_end_step_percent(self):
|
||||
def validate_begin_end_step_percent(self) -> Self:
|
||||
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
|
||||
return self
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.services.model_manager.model_info(
|
||||
self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
|
||||
)
|
||||
# HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model
|
||||
# directly, and 2) we are reading from disk every time this invocation is called without caching the result.
|
||||
# A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this
|
||||
# is currently messy due to differences between how the model info is generated when installing a model from
|
||||
# disk vs. downloading the model.
|
||||
image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
|
||||
os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"])
|
||||
)
|
||||
ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key)
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
image_encoder_model = CLIPVisionModelField(
|
||||
model_name=image_encoder_model_name,
|
||||
base_model=BaseModelType.Any,
|
||||
image_encoder_models = context.services.model_manager.store.search_by_attr(
|
||||
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
|
||||
)
|
||||
assert len(image_encoder_models) == 1
|
||||
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
|
||||
return IPAdapterOutput(
|
||||
ip_adapter=IPAdapterField(
|
||||
image=self.image,
|
||||
|
||||
@@ -3,13 +3,15 @@
|
||||
import math
|
||||
from contextlib import ExitStack
|
||||
from functools import singledispatchmethod
|
||||
from typing import List, Literal, Optional, Union
|
||||
from typing import Any, Iterator, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers import AutoencoderKL, AutoencoderTiny
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models.adapter import T2IAdapter
|
||||
from diffusers.models.attention_processor import (
|
||||
@@ -18,8 +20,10 @@ from diffusers.models.attention_processor import (
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from PIL import Image
|
||||
from pydantic import field_validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
@@ -39,13 +43,13 @@ from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.model_management.seamless import set_seamless
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
IPAdapterData,
|
||||
@@ -77,7 +81,9 @@ if choose_torch_device() == torch.device("mps"):
|
||||
|
||||
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(SCHEDULER_MAP.keys())
|
||||
] # FIXME: "Invalid type alias". This defeats static type checking.
|
||||
|
||||
# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
|
||||
# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale
|
||||
@@ -131,10 +137,10 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
ui_order=4,
|
||||
)
|
||||
|
||||
def prep_mask_tensor(self, mask_image):
|
||||
def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor:
|
||||
if mask_image.mode != "L":
|
||||
mask_image = mask_image.convert("L")
|
||||
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
if mask_tensor.dim() == 3:
|
||||
mask_tensor = mask_tensor.unsqueeze(0)
|
||||
# if shape is not None:
|
||||
@@ -145,24 +151,24 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
if self.image is not None:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image.dim() == 3:
|
||||
image = image.unsqueeze(0)
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = image_tensor.unsqueeze(0)
|
||||
else:
|
||||
image = None
|
||||
image_tensor = None
|
||||
|
||||
mask = self.prep_mask_tensor(
|
||||
context.services.images.get_pil_image(self.mask.image_name),
|
||||
)
|
||||
|
||||
if image is not None:
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
if image_tensor is not None:
|
||||
vae_info = context.services.model_manager.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
# TODO:
|
||||
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
|
||||
|
||||
@@ -189,7 +195,7 @@ def get_scheduler(
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.services.model_manager.get_model(
|
||||
orig_scheduler_info = context.services.model_manager.load_model_by_key(
|
||||
**scheduler_info.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
@@ -200,7 +206,7 @@ def get_scheduler(
|
||||
scheduler_config = scheduler_config["_backup"]
|
||||
scheduler_config = {
|
||||
**scheduler_config,
|
||||
**scheduler_extra_config,
|
||||
**scheduler_extra_config, # FIXME
|
||||
"_backup": scheduler_config,
|
||||
}
|
||||
|
||||
@@ -213,6 +219,7 @@ def get_scheduler(
|
||||
# hack copied over from generate.py
|
||||
if not hasattr(scheduler, "uses_inpainting_model"):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
assert isinstance(scheduler, Scheduler)
|
||||
return scheduler
|
||||
|
||||
|
||||
@@ -296,7 +303,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
@field_validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
def ge_one(cls, v: Union[List[float], float]) -> Union[List[float], float]:
|
||||
"""validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
@@ -326,9 +333,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
def get_conditioning_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
scheduler,
|
||||
unet,
|
||||
seed,
|
||||
scheduler: Scheduler,
|
||||
unet: UNet2DConditionModel,
|
||||
seed: int,
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
@@ -351,7 +358,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
),
|
||||
)
|
||||
|
||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
|
||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
|
||||
scheduler,
|
||||
# for ddim scheduler
|
||||
eta=0.0, # ddim_eta
|
||||
@@ -363,8 +370,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def create_pipeline(
|
||||
self,
|
||||
unet,
|
||||
scheduler,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Scheduler,
|
||||
) -> StableDiffusionGeneratorPipeline:
|
||||
# TODO:
|
||||
# configure_model_padding(
|
||||
@@ -375,10 +382,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
class FakeVae:
|
||||
class FakeVaeConfig:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.block_out_channels = [0]
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.config = FakeVae.FakeVaeConfig()
|
||||
|
||||
return StableDiffusionGeneratorPipeline(
|
||||
@@ -395,11 +402,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
def prep_control_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
control_input: Union[ControlField, List[ControlField]],
|
||||
control_input: Optional[Union[ControlField, List[ControlField]]],
|
||||
latents_shape: List[int],
|
||||
exit_stack: ExitStack,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
) -> List[ControlNetData]:
|
||||
) -> Optional[List[ControlNetData]]:
|
||||
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
|
||||
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
|
||||
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
|
||||
@@ -422,10 +429,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
controlnet_data = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=control_info.control_model.model_name,
|
||||
model_type=ModelType.ControlNet,
|
||||
base_model=control_info.control_model.base_model,
|
||||
context.services.model_manager.load_model_by_key(
|
||||
key=control_info.control_model.key,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
@@ -490,27 +495,25 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=single_ip_adapter.ip_adapter_model.model_name,
|
||||
model_type=ModelType.IPAdapter,
|
||||
base_model=single_ip_adapter.ip_adapter_model.base_model,
|
||||
context.services.model_manager.load_model_by_key(
|
||||
key=single_ip_adapter.ip_adapter_model.key,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
image_encoder_model_info = context.services.model_manager.get_model(
|
||||
model_name=single_ip_adapter.image_encoder_model.model_name,
|
||||
model_type=ModelType.CLIPVision,
|
||||
base_model=single_ip_adapter.image_encoder_model.base_model,
|
||||
image_encoder_model_info = context.services.model_manager.load_model_by_key(
|
||||
key=single_ip_adapter.image_encoder_model.key,
|
||||
context=context,
|
||||
)
|
||||
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_images = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_images, list):
|
||||
single_ipa_images = [single_ipa_images]
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
single_ipa_image_fields = [single_ipa_image_fields]
|
||||
|
||||
single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images]
|
||||
single_ipa_images = [
|
||||
context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields
|
||||
]
|
||||
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
@@ -554,23 +557,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_info = context.services.model_manager.get_model(
|
||||
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
|
||||
model_type=ModelType.T2IAdapter,
|
||||
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
|
||||
t2i_adapter_model_info = context.services.model_manager.load_model_by_key(
|
||||
key=t2i_adapter_field.t2i_adapter_model.key,
|
||||
context=context,
|
||||
)
|
||||
image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
|
||||
if t2i_adapter_model_info.base == BaseModelType.StableDiffusion1:
|
||||
max_unet_downscale = 8
|
||||
elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL:
|
||||
elif t2i_adapter_model_info.base == BaseModelType.StableDiffusionXL:
|
||||
max_unet_downscale = 4
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
|
||||
)
|
||||
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_info.base}'.")
|
||||
|
||||
t2i_adapter_model: T2IAdapter
|
||||
with t2i_adapter_model_info as t2i_adapter_model:
|
||||
@@ -593,7 +592,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
do_classifier_free_guidance=False,
|
||||
width=t2i_input_width,
|
||||
height=t2i_input_height,
|
||||
num_channels=t2i_adapter_model.config.in_channels,
|
||||
num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
|
||||
device=t2i_adapter_model.device,
|
||||
dtype=t2i_adapter_model.dtype,
|
||||
resize_mode=t2i_adapter_field.resize_mode,
|
||||
@@ -618,7 +617,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
||||
# TODO: research more for second order schedulers timesteps
|
||||
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
|
||||
def init_scheduler(
|
||||
self,
|
||||
scheduler: Union[Scheduler, ConfigMixin],
|
||||
device: torch.device,
|
||||
steps: int,
|
||||
denoising_start: float,
|
||||
denoising_end: float,
|
||||
) -> Tuple[int, List[int], int]:
|
||||
assert isinstance(scheduler, ConfigMixin)
|
||||
if scheduler.config.get("cpu_only", False):
|
||||
scheduler.set_timesteps(steps, device="cpu")
|
||||
timesteps = scheduler.timesteps.to(device=device)
|
||||
@@ -630,11 +637,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
_timesteps = timesteps[:: scheduler.order]
|
||||
|
||||
# get start timestep index
|
||||
t_start_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_start)))
|
||||
t_start_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_start)))
|
||||
t_start_idx = len(list(filter(lambda ts: ts >= t_start_val, _timesteps)))
|
||||
|
||||
# get end timestep index
|
||||
t_end_val = int(round(scheduler.config.num_train_timesteps * (1 - denoising_end)))
|
||||
t_end_val = int(round(scheduler.config["num_train_timesteps"] * (1 - denoising_end)))
|
||||
t_end_idx = len(list(filter(lambda ts: ts >= t_end_val, _timesteps[t_start_idx:])))
|
||||
|
||||
# apply order to indexes
|
||||
@@ -647,7 +654,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return num_inference_steps, timesteps, init_timestep
|
||||
|
||||
def prep_inpaint_mask(self, context, latents):
|
||||
def prep_inpaint_mask(
|
||||
self, context: InvocationContext, latents: torch.Tensor
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
if self.denoise_mask is None:
|
||||
return None, None
|
||||
|
||||
@@ -700,31 +709,36 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
def step_callback(state: PipelineIntermediateState):
|
||||
self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
|
||||
# get the unet's config so that we can pass the base to dispatch_progress()
|
||||
unet_config = context.services.model_manager.store.get_model(self.unet.unet.key)
|
||||
|
||||
def _lora_loader():
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
self.dispatch_progress(context, source_node_id, state, unet_config.base)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.services.model_manager.get_model(
|
||||
lora_info = context.services.model_manager.load_model_by_key(
|
||||
**lora.model_dump(exclude={"weight"}),
|
||||
context=context,
|
||||
)
|
||||
yield (lora_info.context.model, lora.weight)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.services.model_manager.get_model(
|
||||
unet_info = context.services.model_manager.load_model_by_key(
|
||||
**self.unet.unet.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
|
||||
set_seamless(unet_info.context.model, self.unet.seamless_axes),
|
||||
ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config),
|
||||
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
|
||||
unet_info as unet,
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
|
||||
):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
@@ -822,12 +836,13 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
vae_info = context.services.model_manager.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
latents = latents.to(vae.device)
|
||||
if self.fp32:
|
||||
vae.to(dtype=torch.float32)
|
||||
@@ -1016,8 +1031,9 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||
|
||||
@staticmethod
|
||||
def vae_encode(vae_info, upcast, tiled, image_tensor):
|
||||
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
orig_dtype = vae.dtype
|
||||
if upcast:
|
||||
vae.to(dtype=torch.float32)
|
||||
@@ -1063,7 +1079,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
vae_info = context.services.model_manager.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
@@ -1082,14 +1098,19 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
@singledispatchmethod
|
||||
@staticmethod
|
||||
def _encode_to_tensor(vae: AutoencoderKL, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
|
||||
latents: torch.Tensor = image_tensor_dist.sample().to(
|
||||
dtype=vae.dtype
|
||||
) # FIXME: uses torch.randn. make reproducible!
|
||||
return latents
|
||||
|
||||
@_encode_to_tensor.register
|
||||
@staticmethod
|
||||
def _(vae: AutoencoderTiny, image_tensor: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return vae.encode(image_tensor).latents
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
latents: torch.FloatTensor = vae.encode(image_tensor).latents
|
||||
return latents
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -1122,7 +1143,12 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
|
||||
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
|
||||
def slerp(
|
||||
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
|
||||
v0: Union[torch.Tensor, npt.NDArray[Any]],
|
||||
v1: Union[torch.Tensor, npt.NDArray[Any]],
|
||||
DOT_THRESHOLD: float = 0.9995,
|
||||
) -> Union[torch.Tensor, npt.NDArray[Any]]:
|
||||
"""
|
||||
Spherical linear interpolation
|
||||
Args:
|
||||
@@ -1155,12 +1181,16 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(device)
|
||||
|
||||
return v2
|
||||
v2_torch: torch.Tensor = torch.from_numpy(v2).to(device)
|
||||
return v2_torch
|
||||
else:
|
||||
assert isinstance(v2, np.ndarray)
|
||||
return v2
|
||||
|
||||
# blend
|
||||
blended_latents = slerp(self.alpha, latents_a, latents_b)
|
||||
bl = slerp(self.alpha, latents_a, latents_b)
|
||||
assert isinstance(bl, torch.Tensor)
|
||||
blended_latents: torch.Tensor = bl # for type checking convenience
|
||||
|
||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||
blended_latents = blended_latents.to("cpu")
|
||||
@@ -1256,15 +1286,19 @@ class IdealSizeInvocation(BaseInvocation):
|
||||
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in initial generation artifacts if too large)",
|
||||
)
|
||||
|
||||
def trim_to_multiple_of(self, *args, multiple_of=LATENT_SCALE_FACTOR):
|
||||
def trim_to_multiple_of(self, *args: int, multiple_of: int = LATENT_SCALE_FACTOR) -> Tuple[int, ...]:
|
||||
return tuple((x - x % multiple_of) for x in args)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
|
||||
unet_config = context.services.model_manager.load_model_by_key(
|
||||
**self.unet.unet.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
aspect = self.width / self.height
|
||||
dimension = 512
|
||||
if self.unet.unet.base_model == BaseModelType.StableDiffusion2:
|
||||
dimension: float = 512
|
||||
if unet_config.base == BaseModelType.StableDiffusion2:
|
||||
dimension = 768
|
||||
elif self.unet.unet.base_model == BaseModelType.StableDiffusionXL:
|
||||
elif unet_config.base == BaseModelType.StableDiffusionXL:
|
||||
dimension = 1024
|
||||
dimension = dimension * self.multiplier
|
||||
min_dimension = math.floor(dimension * 0.5)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import copy
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
|
||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||
from ...backend.model_manager import SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -20,12 +20,8 @@ from .baseinvocation import (
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
model_name: str = Field(description="Info to load submodel")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Info to load submodel")
|
||||
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
||||
|
||||
|
||||
class LoraInfo(ModelInfo):
|
||||
@@ -55,7 +51,7 @@ class VaeField(BaseModel):
|
||||
|
||||
@invocation_output("unet_output")
|
||||
class UNetOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a UNet field"""
|
||||
"""Base class for invocations that output a UNet field."""
|
||||
|
||||
unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet")
|
||||
|
||||
@@ -84,20 +80,13 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
|
||||
class MainModelField(BaseModel):
|
||||
"""Main model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Model key")
|
||||
|
||||
|
||||
class LoRAModelField(BaseModel):
|
||||
"""LoRA model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the LoRA model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="LoRA model key")
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -114,85 +103,40 @@ class MainModelLoaderInvocation(BaseInvocation):
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Main
|
||||
key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
"""
|
||||
if not context.services.model_manager.store.exists(key):
|
||||
raise Exception(f"Unknown model {key}")
|
||||
|
||||
return ModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
key=key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
key=key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
key=key,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
key=key,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Vae,
|
||||
key=key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -229,21 +173,16 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unkown lora name: {lora_name}!")
|
||||
if not context.services.model_manager.store.exists(lora_key):
|
||||
raise Exception(f"Unkown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||
|
||||
output = LoraLoaderOutput()
|
||||
|
||||
@@ -251,10 +190,8 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -263,10 +200,8 @@ class LoraLoaderInvocation(BaseInvocation):
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -318,24 +253,19 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
if self.lora is None:
|
||||
raise Exception("No LoRA provided")
|
||||
|
||||
base_model = self.lora.base_model
|
||||
lora_name = self.lora.model_name
|
||||
lora_key = self.lora.key
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
):
|
||||
raise Exception(f"Unknown lora name: {lora_name}!")
|
||||
if not context.services.model_manager.store.exists(lora_key):
|
||||
raise Exception(f"Unknown lora: {lora_key}!")
|
||||
|
||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||
|
||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||
|
||||
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{lora_name}" already applied to clip2')
|
||||
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
|
||||
raise Exception(f'Lora "{lora_key}" already applied to clip2')
|
||||
|
||||
output = SDXLLoraLoaderOutput()
|
||||
|
||||
@@ -343,10 +273,8 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.unet = copy.deepcopy(self.unet)
|
||||
output.unet.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -355,10 +283,8 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.clip = copy.deepcopy(self.clip)
|
||||
output.clip.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -367,10 +293,8 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
output.clip2 = copy.deepcopy(self.clip2)
|
||||
output.clip2.loras.append(
|
||||
LoraInfo(
|
||||
base_model=base_model,
|
||||
model_name=lora_name,
|
||||
model_type=ModelType.Lora,
|
||||
submodel=None,
|
||||
key=lora_key,
|
||||
submodel_type=None,
|
||||
weight=self.weight,
|
||||
)
|
||||
)
|
||||
@@ -381,10 +305,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||
class VAEModelField(BaseModel):
|
||||
"""Vae model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Model's key")
|
||||
|
||||
|
||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
|
||||
@@ -398,25 +319,12 @@ class VaeLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> VAEOutput:
|
||||
base_model = self.vae_model.base_model
|
||||
model_name = self.vae_model.model_name
|
||||
model_type = ModelType.Vae
|
||||
key = self.vae_model.key
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unkown vae name: {model_name}!")
|
||||
return VAEOutput(
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
)
|
||||
)
|
||||
if not context.services.model_manager.store.exists(key):
|
||||
raise Exception(f"Unkown vae: {key}!")
|
||||
|
||||
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
|
||||
|
||||
|
||||
@invocation_output("seamless_output")
|
||||
|
||||
@@ -8,16 +8,16 @@ from typing import List, Literal, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import ModelType, SubModelType
|
||||
from invokeai.backend.model_patcher import ONNXModelPatcher
|
||||
|
||||
from ...backend.model_management import ONNXModelPatcher
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util import choose_torch_device
|
||||
from ..util.ti_utils import extract_ti_triggers_from_prompt
|
||||
@@ -62,16 +62,16 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.services.model_manager.get_model(
|
||||
tokenizer_info = context.services.model_manager.load_model_by_key(
|
||||
**self.clip.tokenizer.model_dump(),
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.get_model(
|
||||
text_encoder_info = context.services.model_manager.load_model_by_key(
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
)
|
||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack:
|
||||
loras = [
|
||||
(
|
||||
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
|
||||
context.services.model_manager.load_model_by_key(**lora.model_dump(exclude={"weight"})).model,
|
||||
lora.weight,
|
||||
)
|
||||
for lora in self.clip.loras
|
||||
@@ -84,11 +84,11 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
ti_list.append(
|
||||
(
|
||||
name,
|
||||
context.services.model_manager.get_model(
|
||||
context.services.model_manager.load_model_by_attr(
|
||||
model_name=name,
|
||||
base_model=self.clip.text_encoder.base_model,
|
||||
base_model=text_encoder_info.config.base,
|
||||
model_type=ModelType.TextualInversion,
|
||||
).context.model,
|
||||
).model,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
@@ -257,13 +257,13 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
eta=0.0,
|
||||
)
|
||||
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump())
|
||||
unet_info = context.services.model_manager.load_model_by_key(**self.unet.unet.model_dump())
|
||||
|
||||
with unet_info as unet: # , ExitStack() as stack:
|
||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
||||
loras = [
|
||||
(
|
||||
context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model,
|
||||
context.services.model_manager.load_model_by_key(**lora.model_dump(exclude={"weight"})).model,
|
||||
lora.weight,
|
||||
)
|
||||
for lora in self.unet.loras
|
||||
@@ -344,9 +344,9 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
if self.vae.vae.submodel != SubModelType.VaeDecoder:
|
||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.submodel}")
|
||||
|
||||
vae_info = context.services.model_manager.get_model(
|
||||
vae_info = context.services.model_manager.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
)
|
||||
|
||||
@@ -400,11 +400,7 @@ class ONNXModelLoaderOutput(BaseInvocationOutput):
|
||||
class OnnxModelField(BaseModel):
|
||||
"""Onnx model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
model_type: ModelType = Field(description="Model Type")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Model ID")
|
||||
|
||||
|
||||
@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
|
||||
@@ -416,93 +412,46 @@ class OnnxModelLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.ONNX
|
||||
model_key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
|
||||
"""
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.Tokenizer,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.TextEncoder,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=self.model_name,
|
||||
model_type=SDModelType.Diffusers,
|
||||
submodel=SDModelType.UNet,
|
||||
):
|
||||
raise Exception(
|
||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||
)
|
||||
"""
|
||||
if not context.services.model_manager.store.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
return ONNXModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae_decoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.VaeDecoder,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.VaeDecoder,
|
||||
),
|
||||
),
|
||||
vae_encoder=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.VaeEncoder,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.VaeEncoder,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -368,7 +368,7 @@ class LatentsCollectionInvocation(BaseInvocation):
|
||||
return LatentsCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None):
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> LatentsOutput:
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name, seed=seed),
|
||||
width=latents.size()[3] * 8,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
|
||||
from ...backend.model_management import ModelType, SubModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -44,72 +44,52 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Main
|
||||
model_key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
if not context.services.model_manager.store.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
return SDXLModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer2,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder2,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Vae,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
@@ -133,56 +113,40 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
# TODO: precision?
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
|
||||
base_model = self.model.base_model
|
||||
model_name = self.model.model_name
|
||||
model_type = ModelType.Main
|
||||
model_key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.services.model_manager.model_exists(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
):
|
||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||
if not context.services.model_manager.store.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
return SDXLRefinerModelLoaderOutput(
|
||||
unet=UNetField(
|
||||
unet=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.UNet,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.UNet,
|
||||
),
|
||||
scheduler=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Scheduler,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Scheduler,
|
||||
),
|
||||
loras=[],
|
||||
),
|
||||
clip2=ClipField(
|
||||
tokenizer=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Tokenizer2,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Tokenizer2,
|
||||
),
|
||||
text_encoder=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.TextEncoder2,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.TextEncoder2,
|
||||
),
|
||||
loras=[],
|
||||
skipped_layers=0,
|
||||
),
|
||||
vae=VaeField(
|
||||
vae=ModelInfo(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=SubModelType.Vae,
|
||||
key=model_key,
|
||||
submodel_type=SubModelType.Vae,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@@ -16,14 +16,10 @@ from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESI
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.backend.model_management.models.base import BaseModelType
|
||||
|
||||
|
||||
class T2IAdapterModelField(BaseModel):
|
||||
model_name: str = Field(description="Name of the T2I-Adapter model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
key: str = Field(description="Model record key for the T2I-Adapter model")
|
||||
|
||||
|
||||
class T2IAdapterField(BaseModel):
|
||||
|
||||
@@ -5,12 +5,12 @@ from typing import Literal
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from PIL import Image
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
|
||||
@@ -27,11 +27,11 @@ class InvokeAISettings(BaseSettings):
|
||||
"""Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
|
||||
|
||||
initconf: ClassVar[Optional[DictConfig]] = None
|
||||
argparse_groups: ClassVar[Dict] = {}
|
||||
argparse_groups: ClassVar[Dict[str, Any]] = {}
|
||||
|
||||
model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
|
||||
|
||||
def parse_args(self, argv: Optional[list] = sys.argv[1:]):
|
||||
def parse_args(self, argv: Optional[List[str]] = sys.argv[1:]) -> None:
|
||||
"""Call to parse command-line arguments."""
|
||||
parser = self.get_parser()
|
||||
opt, unknown_opts = parser.parse_known_args(argv)
|
||||
@@ -68,7 +68,7 @@ class InvokeAISettings(BaseSettings):
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
def add_parser_arguments(cls, parser: ArgumentParser) -> None:
|
||||
"""Dynamically create arguments for a settings parser."""
|
||||
if "type" in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||
@@ -117,7 +117,8 @@ class InvokeAISettings(BaseSettings):
|
||||
"""Return the category of a setting."""
|
||||
hints = get_type_hints(cls)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
result: str = get_args(hints[command_field])[0]
|
||||
return result
|
||||
else:
|
||||
return "Uncategorized"
|
||||
|
||||
@@ -158,7 +159,7 @@ class InvokeAISettings(BaseSettings):
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None) -> None:
|
||||
"""Add the argparse arguments for a setting parser."""
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = (
|
||||
|
||||
@@ -21,7 +21,7 @@ class PagingArgumentParser(argparse.ArgumentParser):
|
||||
It also supports reading defaults from an init file.
|
||||
"""
|
||||
|
||||
def print_help(self, file=None):
|
||||
def print_help(self, file=None) -> None:
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
|
||||
@@ -173,10 +173,10 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import Field, TypeAdapter
|
||||
from pydantic import Field
|
||||
from pydantic.config import JsonDict
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
|
||||
@@ -185,7 +185,9 @@ from .config_base import InvokeAISettings
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
DEFAULT_MAX_VRAM = 0.5
|
||||
DEFAULT_RAM_CACHE = 10.0
|
||||
DEFAULT_VRAM_CACHE = 0.25
|
||||
DEFAULT_CONVERT_CACHE = 20.0
|
||||
|
||||
|
||||
class Categories(object):
|
||||
@@ -237,6 +239,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
|
||||
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
|
||||
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths)
|
||||
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
|
||||
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
|
||||
outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
|
||||
@@ -251,13 +254,19 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
log_level : Literal["debug", "info", "warning", "error", "critical"] = Field(default="info", description="Emit logging messages at this level or higher", json_schema_extra=Categories.Logging)
|
||||
log_sql : bool = Field(default=False, description="Log SQL queries", json_schema_extra=Categories.Logging)
|
||||
|
||||
# Development
|
||||
dev_reload : bool = Field(default=False, description="Automatically reload when Python sources are changed.", json_schema_extra=Categories.Development)
|
||||
profile_graphs : bool = Field(default=False, description="Enable graph profiling", json_schema_extra=Categories.Development)
|
||||
profile_prefix : Optional[str] = Field(default=None, description="An optional prefix for profile output files.", json_schema_extra=Categories.Development)
|
||||
profiles_dir : Path = Field(default=Path('profiles'), description="Directory for graph profiles", json_schema_extra=Categories.Development)
|
||||
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", json_schema_extra=Categories.Other)
|
||||
|
||||
# CACHE
|
||||
ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||
vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||
ram : float = Field(default=DEFAULT_RAM_CACHE, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||
vram : float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
|
||||
convert_cache : float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB)", json_schema_extra=Categories.ModelCache)
|
||||
|
||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
|
||||
log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
|
||||
|
||||
@@ -270,7 +279,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
attention_type : Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] = Field(default="auto", description="Attention type", json_schema_extra=Categories.Generation)
|
||||
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', json_schema_extra=Categories.Generation)
|
||||
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.Generation)
|
||||
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
|
||||
png_compress_level : int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", json_schema_extra=Categories.Generation)
|
||||
|
||||
# QUEUE
|
||||
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", json_schema_extra=Categories.Queue)
|
||||
@@ -280,6 +289,9 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", json_schema_extra=Categories.Nodes)
|
||||
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
|
||||
|
||||
# MODEL IMPORT
|
||||
civitai_api_key : Optional[str] = Field(default=os.environ.get("CIVITAI_API_KEY"), description="API key for CivitAI", json_schema_extra=Categories.Other)
|
||||
|
||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
|
||||
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
|
||||
@@ -289,6 +301,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
|
||||
# this is not referred to in the source code and can be removed entirely
|
||||
#free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
|
||||
|
||||
@@ -328,13 +341,9 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
super().parse_args(argv)
|
||||
|
||||
if self.singleton_init and not clobber:
|
||||
hints = get_type_hints(self.__class__)
|
||||
for k in self.singleton_init:
|
||||
setattr(
|
||||
self,
|
||||
k,
|
||||
TypeAdapter(hints[k]).validate_python(self.singleton_init[k]),
|
||||
)
|
||||
# When setting values in this way, set validate_assignment to true if you want to validate the value.
|
||||
for k, v in self.singleton_init.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, **kwargs: Any) -> InvokeAIAppConfig:
|
||||
@@ -400,6 +409,11 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""Path to the models directory."""
|
||||
return self._resolve(self.models_dir)
|
||||
|
||||
@property
|
||||
def models_convert_cache_path(self) -> Path:
|
||||
"""Path to the converted cache models directory."""
|
||||
return self._resolve(self.convert_cache_dir)
|
||||
|
||||
@property
|
||||
def custom_nodes_path(self) -> Path:
|
||||
"""Path to the custom nodes directory."""
|
||||
@@ -429,15 +443,20 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
return True
|
||||
|
||||
@property
|
||||
def ram_cache_size(self) -> Union[Literal["auto"], float]:
|
||||
"""Return the ram cache size using the legacy or modern setting."""
|
||||
def ram_cache_size(self) -> float:
|
||||
"""Return the ram cache size using the legacy or modern setting (GB)."""
|
||||
return self.max_cache_size or self.ram
|
||||
|
||||
@property
|
||||
def vram_cache_size(self) -> Union[Literal["auto"], float]:
|
||||
"""Return the vram cache size using the legacy or modern setting."""
|
||||
def vram_cache_size(self) -> float:
|
||||
"""Return the vram cache size using the legacy or modern setting (GB)."""
|
||||
return self.max_vram_cache_size or self.vram
|
||||
|
||||
@property
|
||||
def convert_cache_size(self) -> float:
|
||||
"""Return the convert cache size on disk (GB)."""
|
||||
return self.convert_cache
|
||||
|
||||
@property
|
||||
def use_cpu(self) -> bool:
|
||||
"""Return true if the device is set to CPU or the always_use_cpu flag is set."""
|
||||
@@ -449,6 +468,11 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
disabled_in_config = not self.xformers_enabled
|
||||
return disabled_in_config and self.attention_type != "xformers"
|
||||
|
||||
@property
|
||||
def profiles_path(self) -> Path:
|
||||
"""Path to the graph profiles directory."""
|
||||
return self._resolve(self.profiles_dir)
|
||||
|
||||
@staticmethod
|
||||
def find_root() -> Path:
|
||||
"""Choose the runtime root directory when not specified on command line or init file."""
|
||||
|
||||
@@ -260,3 +260,16 @@ class DownloadQueueServiceBase(ABC):
|
||||
def join(self) -> None:
|
||||
"""Wait until all jobs are off the queue."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
|
||||
"""Wait until the indicated download job has reached a terminal state.
|
||||
|
||||
This will block until the indicated install job has completed,
|
||||
been cancelled, or errored out.
|
||||
|
||||
:param job: The job to wait on.
|
||||
:param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if
|
||||
the job hasn't completed within the indicated time.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -4,10 +4,11 @@
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import requests
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
@@ -48,11 +49,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
|
||||
:param requests_session: Optional requests.sessions.Session object, for unit tests.
|
||||
"""
|
||||
self._jobs = {}
|
||||
self._jobs: Dict[int, DownloadJob] = {}
|
||||
self._next_job_id = 0
|
||||
self._queue = PriorityQueue()
|
||||
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
|
||||
self._stop_event = threading.Event()
|
||||
self._worker_pool = set()
|
||||
self._job_completed_event = threading.Event()
|
||||
self._worker_pool: Set[threading.Thread] = set()
|
||||
self._lock = threading.Lock()
|
||||
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
|
||||
self._event_bus = event_bus
|
||||
@@ -188,6 +190,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
if not job.in_terminal_state:
|
||||
self.cancel_job(job)
|
||||
|
||||
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
|
||||
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
|
||||
start = time.time()
|
||||
while not job.in_terminal_state:
|
||||
if self._job_completed_event.wait(timeout=0.25): # in case we miss an event
|
||||
self._job_completed_event.clear()
|
||||
if timeout > 0 and time.time() - start > timeout:
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
return job
|
||||
|
||||
def _start_workers(self, max_workers: int) -> None:
|
||||
"""Start the requested number of worker threads."""
|
||||
self._stop_event.clear()
|
||||
@@ -208,7 +220,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
job = self._queue.get(timeout=1)
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
try:
|
||||
job.job_started = get_iso_timestamp()
|
||||
self._do_download(job)
|
||||
@@ -224,6 +235,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
|
||||
finally:
|
||||
job.job_ended = get_iso_timestamp()
|
||||
self._job_completed_event.set() # signal a change to terminal state
|
||||
self._queue.task_done()
|
||||
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
|
||||
|
||||
@@ -408,11 +420,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
|
||||
# Example on_progress event handler to display a TQDM status bar
|
||||
# Activate with:
|
||||
# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().job_update
|
||||
# download_service.download(DownloadJob('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().update))
|
||||
class TqdmProgress(object):
|
||||
"""TQDM-based progress bar object to use in on_progress handlers."""
|
||||
|
||||
_bars: Dict[int, tqdm] # the tqdm object
|
||||
_bars: Dict[int, tqdm] # type: ignore
|
||||
_last: Dict[int, int] # last bytes downloaded
|
||||
|
||||
def __init__(self) -> None: # noqa D107
|
||||
|
||||
@@ -11,8 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_management.model_manager import ModelInfo
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
@@ -171,10 +170,7 @@ class EventServiceBase:
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: SubModelType,
|
||||
model_config: AnyModelConfig,
|
||||
) -> None:
|
||||
"""Emitted when a model is requested"""
|
||||
self.__emit_queue_event(
|
||||
@@ -184,10 +180,7 @@ class EventServiceBase:
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_name": model_name,
|
||||
"base_model": base_model,
|
||||
"model_type": model_type,
|
||||
"submodel": submodel,
|
||||
"model_config": model_config.model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -197,11 +190,7 @@ class EventServiceBase:
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: SubModelType,
|
||||
model_info: ModelInfo,
|
||||
model_config: AnyModelConfig,
|
||||
) -> None:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
self.__emit_queue_event(
|
||||
@@ -211,13 +200,7 @@ class EventServiceBase:
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_name": model_name,
|
||||
"base_model": base_model,
|
||||
"model_type": model_type,
|
||||
"submodel": submodel,
|
||||
"hash": model_info.hash,
|
||||
"location": str(model_info.location),
|
||||
"precision": str(model_info.precision),
|
||||
"model_config": model_config.model_dump(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import suppress
|
||||
from threading import BoundedSemaphore, Event, Thread
|
||||
from typing import Optional
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import (
|
||||
GESStatsNotFoundError,
|
||||
)
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
|
||||
from ..invoker import Invoker
|
||||
from .invocation_processor_base import InvocationProcessorABC
|
||||
@@ -18,7 +23,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker: Invoker
|
||||
__threadLimit: BoundedSemaphore
|
||||
|
||||
def start(self, invoker) -> None:
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
# if we do want multithreading at some point, we could make this configurable
|
||||
self.__threadLimit = BoundedSemaphore(1)
|
||||
self.__invoker = invoker
|
||||
@@ -39,6 +44,27 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
self.__threadLimit.acquire()
|
||||
queue_item: Optional[InvocationQueueItem] = None
|
||||
|
||||
profiler = (
|
||||
Profiler(
|
||||
logger=self.__invoker.services.logger,
|
||||
output_dir=self.__invoker.services.configuration.profiles_path,
|
||||
prefix=self.__invoker.services.configuration.profile_prefix,
|
||||
)
|
||||
if self.__invoker.services.configuration.profile_graphs
|
||||
else None
|
||||
)
|
||||
|
||||
def stats_cleanup(graph_execution_state_id: str) -> None:
|
||||
if profiler:
|
||||
profile_path = profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self.__invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=graph_execution_state_id, output_path=stats_path
|
||||
)
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self.__invoker.services.performance_statistics.log_stats(graph_execution_state_id)
|
||||
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state_id)
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
queue_item = self.__invoker.services.queue.get()
|
||||
@@ -49,6 +75,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# do not hammer the queue
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
||||
if profiler and profiler.profile_id != queue_item.graph_execution_state_id:
|
||||
profiler.start(profile_id=queue_item.graph_execution_state_id)
|
||||
|
||||
try:
|
||||
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
|
||||
queue_item.graph_execution_state_id
|
||||
@@ -137,7 +167,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
|
||||
stats_cleanup(graph_execution_state.id)
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
@@ -162,7 +192,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
|
||||
pass
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
@@ -194,13 +223,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
elif is_complete:
|
||||
self.__invoker.services.performance_statistics.log_stats(graph_execution_state.id)
|
||||
self.__invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
)
|
||||
stats_cleanup(graph_execution_state.id)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||
|
||||
@@ -22,9 +22,7 @@ if TYPE_CHECKING:
|
||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||
from .item_storage.item_storage_base import ItemStorageABC
|
||||
from .latents_storage.latents_storage_base import LatentsStorageBase
|
||||
from .model_install import ModelInstallServiceBase
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .model_records import ModelRecordServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
from .session_queue.session_queue_base import SessionQueueBase
|
||||
@@ -50,9 +48,7 @@ class InvocationServices:
|
||||
latents: "LatentsStorageBase"
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
model_records: "ModelRecordServiceBase"
|
||||
download_queue: "DownloadQueueServiceBase"
|
||||
model_install: "ModelInstallServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
queue: "InvocationQueueABC"
|
||||
@@ -78,9 +74,7 @@ class InvocationServices:
|
||||
latents: "LatentsStorageBase",
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
model_records: "ModelRecordServiceBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
model_install: "ModelInstallServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
@@ -104,9 +98,7 @@ class InvocationServices:
|
||||
self.latents = latents
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.model_records = model_records
|
||||
self.download_queue = download_queue
|
||||
self.model_install = model_install
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
self.queue = queue
|
||||
|
||||
@@ -29,27 +29,28 @@ writes to the system log is stored in InvocationServices.performance_statistics.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import AbstractContextManager
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary
|
||||
|
||||
|
||||
class InvocationStatsServiceBase(ABC):
|
||||
"Abstract base class for recording node memory/time performance statistics"
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize the InvocationStatsService and reset counters to zero
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_stats(
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_execution_state_id: str,
|
||||
) -> AbstractContextManager:
|
||||
) -> Iterator[None]:
|
||||
"""
|
||||
Return a context object that will capture the statistics on the execution
|
||||
of invocaation. Use with: to place around the part of the code that executes the invocation.
|
||||
@@ -59,16 +60,38 @@ class InvocationStatsServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_stats(self, graph_execution_state_id: str):
|
||||
def reset_stats(self, graph_execution_state_id: str) -> None:
|
||||
"""
|
||||
Reset all statistics for the indicated graph
|
||||
:param graph_execution_state_id
|
||||
Reset all statistics for the indicated graph.
|
||||
:param graph_execution_state_id: The id of the session whose stats to reset.
|
||||
:raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def log_stats(self, graph_execution_state_id: str):
|
||||
def log_stats(self, graph_execution_state_id: str) -> None:
|
||||
"""
|
||||
Write out the accumulated statistics to the log or somewhere else.
|
||||
:param graph_execution_state_id: The id of the session whose stats to log.
|
||||
:raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
|
||||
"""
|
||||
Gets the accumulated statistics for the indicated graph.
|
||||
:param graph_execution_state_id: The id of the session whose stats to get.
|
||||
:raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def dump_stats(self, graph_execution_state_id: str, output_path: Path) -> None:
|
||||
"""
|
||||
Write out the accumulated statistics to the indicated path as JSON.
|
||||
:param graph_execution_state_id: The id of the session whose stats to dump.
|
||||
:param output_path: The file to write the stats to.
|
||||
:raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -1,5 +1,91 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class GESStatsNotFoundError(Exception):
|
||||
"""Raised when execution stats are not found for a given Graph Execution State."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class NodeExecutionStatsSummary:
|
||||
"""The stats for a specific type of node."""
|
||||
|
||||
node_type: str
|
||||
num_calls: int
|
||||
time_used_seconds: float
|
||||
peak_vram_gb: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelCacheStatsSummary:
|
||||
"""The stats for the model cache."""
|
||||
|
||||
high_water_mark_gb: float
|
||||
cache_size_gb: float
|
||||
total_usage_gb: float
|
||||
cache_hits: int
|
||||
cache_misses: int
|
||||
models_cached: int
|
||||
models_cleared: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphExecutionStatsSummary:
|
||||
"""The stats for the graph execution state."""
|
||||
|
||||
graph_execution_state_id: str
|
||||
execution_time_seconds: float
|
||||
# `wall_time_seconds`, `ram_usage_gb` and `ram_change_gb` are derived from the node execution stats.
|
||||
# In some situations, there are no node stats, so these values are optional.
|
||||
wall_time_seconds: Optional[float]
|
||||
ram_usage_gb: Optional[float]
|
||||
ram_change_gb: Optional[float]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvocationStatsSummary:
|
||||
"""
|
||||
The accumulated stats for a graph execution.
|
||||
Its `__str__` method returns a human-readable stats summary.
|
||||
"""
|
||||
|
||||
vram_usage_gb: Optional[float]
|
||||
graph_stats: GraphExecutionStatsSummary
|
||||
model_cache_stats: ModelCacheStatsSummary
|
||||
node_stats: list[NodeExecutionStatsSummary]
|
||||
|
||||
def __str__(self) -> str:
|
||||
_str = ""
|
||||
_str = f"Graph stats: {self.graph_stats.graph_execution_state_id}\n"
|
||||
_str += f"{'Node':>30} {'Calls':>7} {'Seconds':>9} {'VRAM Used':>10}\n"
|
||||
|
||||
for summary in self.node_stats:
|
||||
_str += f"{summary.node_type:>30} {summary.num_calls:>7} {summary.time_used_seconds:>8.3f}s {summary.peak_vram_gb:>9.3f}G\n"
|
||||
|
||||
_str += f"TOTAL GRAPH EXECUTION TIME: {self.graph_stats.execution_time_seconds:7.3f}s\n"
|
||||
|
||||
if self.graph_stats.wall_time_seconds is not None:
|
||||
_str += f"TOTAL GRAPH WALL TIME: {self.graph_stats.wall_time_seconds:7.3f}s\n"
|
||||
|
||||
if self.graph_stats.ram_usage_gb is not None and self.graph_stats.ram_change_gb is not None:
|
||||
_str += f"RAM used by InvokeAI process: {self.graph_stats.ram_usage_gb:4.2f}G ({self.graph_stats.ram_change_gb:+5.3f}G)\n"
|
||||
|
||||
_str += f"RAM used to load models: {self.model_cache_stats.total_usage_gb:4.2f}G\n"
|
||||
if self.vram_usage_gb:
|
||||
_str += f"VRAM in use: {self.vram_usage_gb:4.3f}G\n"
|
||||
_str += "RAM cache statistics:\n"
|
||||
_str += f" Model cache hits: {self.model_cache_stats.cache_hits}\n"
|
||||
_str += f" Model cache misses: {self.model_cache_stats.cache_misses}\n"
|
||||
_str += f" Models cached: {self.model_cache_stats.models_cached}\n"
|
||||
_str += f" Models cleared from cache: {self.model_cache_stats.models_cleared}\n"
|
||||
_str += f" Cache high water mark: {self.model_cache_stats.high_water_mark_gb:4.2f}/{self.model_cache_stats.cache_size_gb:4.2f}G\n"
|
||||
|
||||
return _str
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
"""Returns the stats as a dictionary."""
|
||||
return asdict(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -55,12 +141,33 @@ class GraphExecutionStats:
|
||||
|
||||
return last_node
|
||||
|
||||
def get_pretty_log(self, graph_execution_state_id: str) -> str:
|
||||
log = f"Graph stats: {graph_execution_state_id}\n"
|
||||
log += f"{'Node':>30} {'Calls':>7}{'Seconds':>9} {'VRAM Used':>10}\n"
|
||||
def get_graph_stats_summary(self, graph_execution_state_id: str) -> GraphExecutionStatsSummary:
|
||||
"""Get a summary of the graph stats."""
|
||||
first_node = self.get_first_node_stats()
|
||||
last_node = self.get_last_node_stats()
|
||||
|
||||
# Log stats aggregated by node type.
|
||||
wall_time_seconds: Optional[float] = None
|
||||
ram_usage_gb: Optional[float] = None
|
||||
ram_change_gb: Optional[float] = None
|
||||
|
||||
if last_node and first_node:
|
||||
wall_time_seconds = last_node.end_time - first_node.start_time
|
||||
ram_usage_gb = last_node.end_ram_gb
|
||||
ram_change_gb = last_node.end_ram_gb - first_node.start_ram_gb
|
||||
|
||||
return GraphExecutionStatsSummary(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
execution_time_seconds=self.get_total_run_time(),
|
||||
wall_time_seconds=wall_time_seconds,
|
||||
ram_usage_gb=ram_usage_gb,
|
||||
ram_change_gb=ram_change_gb,
|
||||
)
|
||||
|
||||
def get_node_stats_summaries(self) -> list[NodeExecutionStatsSummary]:
|
||||
"""Get a summary of the node stats."""
|
||||
summaries: list[NodeExecutionStatsSummary] = []
|
||||
node_stats_by_type: dict[str, list[NodeExecutionStats]] = defaultdict(list)
|
||||
|
||||
for node_stats in self._node_stats_list:
|
||||
node_stats_by_type[node_stats.invocation_type].append(node_stats)
|
||||
|
||||
@@ -68,17 +175,9 @@ class GraphExecutionStats:
|
||||
num_calls = len(node_type_stats_list)
|
||||
time_used = sum([n.total_time() for n in node_type_stats_list])
|
||||
peak_vram = max([n.peak_vram_gb for n in node_type_stats_list])
|
||||
log += f"{node_type:>30} {num_calls:>4} {time_used:7.3f}s {peak_vram:4.3f}G\n"
|
||||
summary = NodeExecutionStatsSummary(
|
||||
node_type=node_type, num_calls=num_calls, time_used_seconds=time_used, peak_vram_gb=peak_vram
|
||||
)
|
||||
summaries.append(summary)
|
||||
|
||||
# Log stats for the entire graph.
|
||||
log += f"TOTAL GRAPH EXECUTION TIME: {self.get_total_run_time():7.3f}s\n"
|
||||
|
||||
first_node = self.get_first_node_stats()
|
||||
last_node = self.get_last_node_stats()
|
||||
if first_node is not None and last_node is not None:
|
||||
total_wall_time = last_node.end_time - first_node.start_time
|
||||
ram_change = last_node.end_ram_gb - first_node.start_ram_gb
|
||||
log += f"TOTAL GRAPH WALL TIME: {total_wall_time:7.3f}s\n"
|
||||
log += f"RAM used by InvokeAI process: {last_node.end_ram_gb:4.2f}G ({ram_change:+5.3f}G)\n"
|
||||
|
||||
return log
|
||||
return summaries
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import json
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
@@ -7,10 +10,19 @@ import torch
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError
|
||||
from invokeai.backend.model_manager.load.model_cache import CacheStats
|
||||
|
||||
from .invocation_stats_base import InvocationStatsServiceBase
|
||||
from .invocation_stats_common import GraphExecutionStats, NodeExecutionStats
|
||||
from .invocation_stats_common import (
|
||||
GESStatsNotFoundError,
|
||||
GraphExecutionStats,
|
||||
GraphExecutionStatsSummary,
|
||||
InvocationStatsSummary,
|
||||
ModelCacheStatsSummary,
|
||||
NodeExecutionStats,
|
||||
NodeExecutionStatsSummary,
|
||||
)
|
||||
|
||||
# Size of 1GB in bytes.
|
||||
GB = 2**30
|
||||
@@ -30,7 +42,10 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self._invoker = invoker
|
||||
|
||||
@contextmanager
|
||||
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str):
|
||||
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]:
|
||||
# This is to handle case of the model manager not being initialized, which happens
|
||||
# during some tests.
|
||||
services = self._invoker.services
|
||||
if not self._stats.get(graph_execution_state_id):
|
||||
# First time we're seeing this graph_execution_state_id.
|
||||
self._stats[graph_execution_state_id] = GraphExecutionStats()
|
||||
@@ -44,8 +59,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
start_ram = psutil.Process().memory_info().rss
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
if self._invoker.services.model_manager:
|
||||
self._invoker.services.model_manager.collect_cache_stats(self._cache_stats[graph_execution_state_id])
|
||||
|
||||
assert services.model_manager.load is not None
|
||||
services.model_manager.load.ram_cache.stats = self._cache_stats[graph_execution_state_id]
|
||||
|
||||
try:
|
||||
# Let the invocation run.
|
||||
@@ -53,7 +69,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
finally:
|
||||
# Record state after the invocation.
|
||||
node_stats = NodeExecutionStats(
|
||||
invocation_type=invocation.type,
|
||||
invocation_type=invocation.get_type(),
|
||||
start_time=start_time,
|
||||
end_time=time.time(),
|
||||
start_ram_gb=start_ram / GB,
|
||||
@@ -62,17 +78,17 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
)
|
||||
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
|
||||
|
||||
def _prune_stale_stats(self):
|
||||
def _prune_stale_stats(self) -> None:
|
||||
"""Check all graphs being tracked and prune any that have completed/errored.
|
||||
|
||||
This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so
|
||||
for now we call this function periodically to prevent them from accumulating.
|
||||
"""
|
||||
to_prune = []
|
||||
to_prune: list[str] = []
|
||||
for graph_execution_state_id in self._stats:
|
||||
try:
|
||||
graph_execution_state = self._invoker.services.graph_execution_manager.get(graph_execution_state_id)
|
||||
except Exception:
|
||||
except ItemNotFoundError:
|
||||
# TODO(ryand): What would cause this? Should this exception just be allowed to propagate?
|
||||
logger.warning(f"Failed to get graph state for {graph_execution_state_id}.")
|
||||
continue
|
||||
@@ -95,31 +111,66 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
del self._stats[graph_execution_state_id]
|
||||
del self._cache_stats[graph_execution_state_id]
|
||||
except KeyError as e:
|
||||
logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}.")
|
||||
raise GESStatsNotFoundError(
|
||||
f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}."
|
||||
) from e
|
||||
|
||||
def log_stats(self, graph_execution_state_id: str):
|
||||
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
|
||||
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
|
||||
node_stats_summaries = self._get_node_summaries(graph_execution_state_id)
|
||||
model_cache_stats_summary = self._get_model_cache_summary(graph_execution_state_id)
|
||||
vram_usage_gb = torch.cuda.memory_allocated() / GB if torch.cuda.is_available() else None
|
||||
|
||||
return InvocationStatsSummary(
|
||||
graph_stats=graph_stats_summary,
|
||||
model_cache_stats=model_cache_stats_summary,
|
||||
node_stats=node_stats_summaries,
|
||||
vram_usage_gb=vram_usage_gb,
|
||||
)
|
||||
|
||||
def log_stats(self, graph_execution_state_id: str) -> None:
|
||||
stats = self.get_stats(graph_execution_state_id)
|
||||
logger.info(str(stats))
|
||||
|
||||
def dump_stats(self, graph_execution_state_id: str, output_path: Path) -> None:
|
||||
stats = self.get_stats(graph_execution_state_id)
|
||||
with open(output_path, "w") as f:
|
||||
f.write(json.dumps(stats.as_dict(), indent=2))
|
||||
|
||||
def _get_model_cache_summary(self, graph_execution_state_id: str) -> ModelCacheStatsSummary:
|
||||
try:
|
||||
graph_stats = self._stats[graph_execution_state_id]
|
||||
cache_stats = self._cache_stats[graph_execution_state_id]
|
||||
except KeyError as e:
|
||||
logger.warning(f"Attempted to log statistics for unknown graph {graph_execution_state_id}: {e}.")
|
||||
return
|
||||
raise GESStatsNotFoundError(
|
||||
f"Attempted to get model cache statistics for unknown graph {graph_execution_state_id}: {e}."
|
||||
) from e
|
||||
|
||||
log = graph_stats.get_pretty_log(graph_execution_state_id)
|
||||
return ModelCacheStatsSummary(
|
||||
cache_hits=cache_stats.hits,
|
||||
cache_misses=cache_stats.misses,
|
||||
high_water_mark_gb=cache_stats.high_watermark / GB,
|
||||
cache_size_gb=cache_stats.cache_size / GB,
|
||||
total_usage_gb=sum(list(cache_stats.loaded_model_sizes.values())) / GB,
|
||||
models_cached=cache_stats.in_cache,
|
||||
models_cleared=cache_stats.cleared,
|
||||
)
|
||||
|
||||
hwm = cache_stats.high_watermark / GB
|
||||
tot = cache_stats.cache_size / GB
|
||||
loaded = sum(list(cache_stats.loaded_model_sizes.values())) / GB
|
||||
log += f"RAM used to load models: {loaded:4.2f}G\n"
|
||||
if torch.cuda.is_available():
|
||||
log += f"VRAM in use: {(torch.cuda.memory_allocated() / GB):4.3f}G\n"
|
||||
log += "RAM cache statistics:\n"
|
||||
log += f" Model cache hits: {cache_stats.hits}\n"
|
||||
log += f" Model cache misses: {cache_stats.misses}\n"
|
||||
log += f" Models cached: {cache_stats.in_cache}\n"
|
||||
log += f" Models cleared from cache: {cache_stats.cleared}\n"
|
||||
log += f" Cache high water mark: {hwm:4.2f}/{tot:4.2f}G\n"
|
||||
logger.info(log)
|
||||
def _get_graph_summary(self, graph_execution_state_id: str) -> GraphExecutionStatsSummary:
|
||||
try:
|
||||
graph_stats = self._stats[graph_execution_state_id]
|
||||
except KeyError as e:
|
||||
raise GESStatsNotFoundError(
|
||||
f"Attempted to get graph statistics for unknown graph {graph_execution_state_id}: {e}."
|
||||
) from e
|
||||
|
||||
del self._stats[graph_execution_state_id]
|
||||
del self._cache_stats[graph_execution_state_id]
|
||||
return graph_stats.get_graph_stats_summary(graph_execution_state_id)
|
||||
|
||||
def _get_node_summaries(self, graph_execution_state_id: str) -> list[NodeExecutionStatsSummary]:
|
||||
try:
|
||||
graph_stats = self._stats[graph_execution_state_id]
|
||||
except KeyError as e:
|
||||
raise GESStatsNotFoundError(
|
||||
f"Attempted to get node statistics for unknown graph {graph_execution_state_id}: {e}."
|
||||
) from e
|
||||
|
||||
return graph_stats.get_node_stats_summaries()
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Generic, Optional, TypeVar
|
||||
from typing import Callable, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
@@ -22,26 +20,26 @@ class ItemStorageABC(ABC, Generic[T]):
|
||||
|
||||
@abstractmethod
|
||||
def get(self, item_id: str) -> T:
|
||||
"""Gets the item, parsing it into a Pydantic model"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_raw(self, item_id: str) -> Optional[str]:
|
||||
"""Gets the raw item as a string, skipping Pydantic parsing"""
|
||||
"""
|
||||
Gets the item.
|
||||
:param item_id: the id of the item to get
|
||||
:raises ItemNotFoundError: if the item is not found
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, item: T) -> None:
|
||||
"""Sets the item"""
|
||||
"""
|
||||
Sets the item. The id will be extracted based on id_field.
|
||||
:param item: the item to set
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
"""Gets a paginated list of items"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
def delete(self, item_id: str) -> None:
|
||||
"""
|
||||
Deletes the item, if it exists.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_changed(self, on_changed: Callable[[T], None]) -> None:
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
class ItemNotFoundError(KeyError):
|
||||
"""Raised when an item is not found in storage"""
|
||||
|
||||
def __init__(self, item_id: str) -> None:
|
||||
super().__init__(f"Item with id {item_id} not found")
|
||||
52
invokeai/app/services/item_storage/item_storage_memory.py
Normal file
52
invokeai/app/services/item_storage/item_storage_memory.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from collections import OrderedDict
|
||||
from contextlib import suppress
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
|
||||
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class ItemStorageMemory(ItemStorageABC[T], Generic[T]):
|
||||
"""
|
||||
Provides a simple in-memory storage for items, with a maximum number of items to store.
|
||||
The storage uses the LRU strategy to evict items from storage when the max has been reached.
|
||||
"""
|
||||
|
||||
def __init__(self, id_field: str = "id", max_items: int = 10) -> None:
|
||||
super().__init__()
|
||||
if max_items < 1:
|
||||
raise ValueError("max_items must be at least 1")
|
||||
if not id_field:
|
||||
raise ValueError("id_field must not be empty")
|
||||
self._id_field = id_field
|
||||
self._items: OrderedDict[str, T] = OrderedDict()
|
||||
self._max_items = max_items
|
||||
|
||||
def get(self, item_id: str) -> T:
|
||||
# If the item exists, move it to the end of the OrderedDict.
|
||||
item = self._items.pop(item_id, None)
|
||||
if item is None:
|
||||
raise ItemNotFoundError(item_id)
|
||||
self._items[item_id] = item
|
||||
return item
|
||||
|
||||
def set(self, item: T) -> None:
|
||||
item_id = getattr(item, self._id_field)
|
||||
if item_id in self._items:
|
||||
# If item already exists, remove it and add it to the end
|
||||
self._items.pop(item_id)
|
||||
elif len(self._items) >= self._max_items:
|
||||
# If cache is full, evict the least recently used item
|
||||
self._items.popitem(last=False)
|
||||
self._items[item_id] = item
|
||||
self._on_changed(item)
|
||||
|
||||
def delete(self, item_id: str) -> None:
|
||||
# This is a no-op if the item doesn't exist.
|
||||
with suppress(KeyError):
|
||||
del self._items[item_id]
|
||||
self._on_deleted(item_id)
|
||||
@@ -1,147 +0,0 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Generic, Optional, TypeVar, get_args
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
from .item_storage_base import ItemStorageABC
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||
_table_name: str
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_id_field: str
|
||||
_lock: threading.RLock
|
||||
_validator: Optional[TypeAdapter[T]]
|
||||
|
||||
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
|
||||
super().__init__()
|
||||
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._table_name = table_name
|
||||
self._id_field = id_field # TODO: validate that T has this field
|
||||
self._cursor = self._conn.cursor()
|
||||
self._validator: Optional[TypeAdapter[T]] = None
|
||||
|
||||
self._create_table()
|
||||
|
||||
def _create_table(self):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
|
||||
item TEXT,
|
||||
id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
|
||||
)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _parse_item(self, item: str) -> T:
|
||||
if self._validator is None:
|
||||
"""
|
||||
We don't get access to `__orig_class__` in `__init__()`, and we need this before start(), so
|
||||
we can create it when it is first needed instead.
|
||||
__orig_class__ is technically an implementation detail of the typing module, not a supported API
|
||||
"""
|
||||
self._validator = TypeAdapter(get_args(self.__orig_class__)[0]) # type: ignore [attr-defined]
|
||||
return self._validator.validate_json(item)
|
||||
|
||||
def set(self, item: T):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
|
||||
(item.model_dump_json(warnings=False, exclude_none=True),),
|
||||
)
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_changed(item)
|
||||
|
||||
def get(self, id: str) -> Optional[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
return self._parse_item(result[0])
|
||||
|
||||
def get_raw(self, id: str) -> Optional[str]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||
result = self._cursor.fetchone()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
return result[0]
|
||||
|
||||
def delete(self, id: str):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
self._on_deleted(id)
|
||||
|
||||
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
|
||||
(per_page, page * per_page),
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = [self._parse_item(r[0]) for r in result]
|
||||
|
||||
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||
|
||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
|
||||
(f"%{query}%", per_page, page * per_page),
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
items = [self._parse_item(r[0]) for r in result]
|
||||
|
||||
self._cursor.execute(
|
||||
f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
|
||||
(f"%{query}%",),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
||||
@@ -1,10 +1,12 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable
|
||||
from typing import Callable, Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.compel import ConditioningFieldData
|
||||
|
||||
|
||||
class LatentsStorageBase(ABC):
|
||||
"""Responsible for storing and retrieving latents."""
|
||||
@@ -20,8 +22,10 @@ class LatentsStorageBase(ABC):
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
# (LS) Added a Union with ConditioningFieldData to fix type mismatch errors in compel.py
|
||||
# Not 100% sure this isn't an existing bug.
|
||||
@abstractmethod
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
def save(self, name: str, data: Union[torch.Tensor, ConditioningFieldData]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.compel import ConditioningFieldData
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
from .latents_storage_base import LatentsStorageBase
|
||||
@@ -27,7 +28,7 @@ class DiskLatentsStorage(LatentsStorageBase):
|
||||
latent_path = self.get_path(name)
|
||||
return torch.load(latent_path)
|
||||
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
def save(self, name: str, data: Union[torch.Tensor, ConditioningFieldData]) -> None:
|
||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||
latent_path = self.get_path(name)
|
||||
torch.save(data, latent_path)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.compel import ConditioningFieldData
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
from .latents_storage_base import LatentsStorageBase
|
||||
@@ -46,7 +47,9 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
self.__set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
def save(self, name: str, data: torch.Tensor) -> None:
|
||||
# TODO: (LS) ConditioningFieldData added as Union because of type-checking errors
|
||||
# in compel.py. Unclear whether this is a long-standing bug, but seems to run.
|
||||
def save(self, name: str, data: Union[torch.Tensor, ConditioningFieldData]) -> None:
|
||||
self.__underlying_storage.save(name, data)
|
||||
self.__set_cache(name, data)
|
||||
self._on_changed(data)
|
||||
|
||||
@@ -14,11 +14,13 @@ from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
@@ -127,8 +129,8 @@ class HFModelSource(StringLikeSource):
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of repoid when string rep needed."""
|
||||
base: str = self.repo_id
|
||||
base += f":{self.variant or ''}"
|
||||
base += f":{self.subfolder}" if self.subfolder else ""
|
||||
base += f" ({self.variant})" if self.variant else ""
|
||||
return base
|
||||
|
||||
|
||||
@@ -165,8 +167,8 @@ class ModelInstallJob(BaseModel):
|
||||
)
|
||||
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
|
||||
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
|
||||
bytes: Optional[int] = Field(
|
||||
default=None, description="For a remote model, the number of bytes downloaded so far (may not be available)"
|
||||
bytes: int = Field(
|
||||
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
|
||||
)
|
||||
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
|
||||
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
||||
@@ -243,7 +245,7 @@ class ModelInstallServiceBase(ABC):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
metadata_store: ModelMetadataStore,
|
||||
metadata_store: ModelMetadataStoreBase,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
):
|
||||
"""
|
||||
@@ -324,6 +326,43 @@ class ModelInstallServiceBase(ABC):
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(
|
||||
self,
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
r"""Install the indicated model using heuristics to interpret user intentions.
|
||||
|
||||
:param source: String source
|
||||
:param config: Optional dict. Any fields in this dict
|
||||
will override corresponding autoassigned probe fields in the
|
||||
model's config record as described in `import_model()`.
|
||||
:param access_token: Optional access token for remote sources.
|
||||
|
||||
The source can be:
|
||||
1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`)
|
||||
2. An http or https URL (`https://foo.bar/foo`)
|
||||
3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`)
|
||||
|
||||
We extend the HuggingFace repo_id syntax to include the variant and the
|
||||
subfolder or path. The following are acceptable alternatives:
|
||||
stabilityai/stable-diffusion-v4
|
||||
stabilityai/stable-diffusion-v4:fp16
|
||||
stabilityai/stable-diffusion-v4:fp16:vae
|
||||
stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
|
||||
stabilityai/stable-diffusion-v4:onnx:vae
|
||||
|
||||
Because a local file path can look like a huggingface repo_id, the logic
|
||||
first checks whether the path exists on disk, and if not, it is treated as
|
||||
a parseable huggingface repo.
|
||||
|
||||
The previous support for recursing into a local folder and loading all model-like files
|
||||
has been removed.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def import_model(
|
||||
self,
|
||||
@@ -385,6 +424,18 @@ class ModelInstallServiceBase(ABC):
|
||||
def cancel_job(self, job: ModelInstallJob) -> None:
|
||||
"""Cancel the indicated job."""
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob:
|
||||
"""Wait for the indicated job to reach a terminal state.
|
||||
|
||||
This will block until the indicated install job has completed,
|
||||
been cancelled, or errored out.
|
||||
|
||||
:param job: The job to wait on.
|
||||
:param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if
|
||||
the job hasn't completed within the indicated time.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]:
|
||||
"""
|
||||
@@ -394,7 +445,8 @@ class ModelInstallServiceBase(ABC):
|
||||
completed, been cancelled, or errored out.
|
||||
|
||||
:param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if
|
||||
installs do not complete within the indicated time.
|
||||
installs do not complete within the indicated time. A timeout of zero (the default)
|
||||
will block indefinitely until the installs complete.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -410,3 +462,22 @@ class ModelInstallServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def sync_to_config(self) -> None:
|
||||
"""Synchronize models on disk to those in the model record database."""
|
||||
|
||||
@abstractmethod
|
||||
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Download the model file located at source to the models cache and return its Path.
|
||||
|
||||
:param source: A Url or a string that can be converted into one.
|
||||
:param access_token: Optional access token to access restricted resources.
|
||||
|
||||
The model file will be downloaded into the system-wide model cache
|
||||
(`models/.cache`) if it isn't already there. Note that the model cache
|
||||
is periodically cleared of infrequently-used entries when the model
|
||||
converter runs.
|
||||
|
||||
Note that this doesn't automaticallly install or register the model, but is
|
||||
intended for use by nodes that need access to models that aren't directly
|
||||
supported by InvokeAI. The downloading process takes advantage of the download queue
|
||||
to avoid interrupting other operations.
|
||||
"""
|
||||
|
||||
@@ -17,10 +17,10 @@ from pydantic.networks import AnyHttpUrl
|
||||
from requests import Session
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
@@ -33,7 +33,6 @@ from invokeai.backend.model_manager.metadata import (
|
||||
AnyModelRepoMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
HuggingFaceMetadataFetch,
|
||||
ModelMetadataStore,
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
)
|
||||
@@ -50,6 +49,7 @@ from .model_install_base import (
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
StringLikeSource,
|
||||
URLModelSource,
|
||||
)
|
||||
|
||||
@@ -64,7 +64,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
metadata_store: Optional[ModelMetadataStore] = None,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
session: Optional[Session] = None,
|
||||
):
|
||||
@@ -86,19 +85,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._lock = threading.Lock()
|
||||
self._stop_event = threading.Event()
|
||||
self._downloads_changed_event = threading.Event()
|
||||
self._install_completed_event = threading.Event()
|
||||
self._download_queue = download_queue
|
||||
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._next_job_id = 0
|
||||
# There may not necessarily be a metadata store initialized
|
||||
# so we create one and initialize it with the same sql database
|
||||
# used by the record store service.
|
||||
if metadata_store:
|
||||
self._metadata_store = metadata_store
|
||||
else:
|
||||
assert isinstance(record_store, ModelRecordServiceSQL)
|
||||
self._metadata_store = ModelMetadataStore(record_store.db)
|
||||
self._metadata_store = record_store.metadata_store # for convenience
|
||||
|
||||
@property
|
||||
def app_config(self) -> InvokeAIAppConfig: # noqa D102
|
||||
@@ -145,7 +138,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if config.get("source") is None:
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
return self._register(model_path, config)
|
||||
|
||||
@@ -156,12 +149,14 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
config = config or {}
|
||||
if config.get("source") is None:
|
||||
if not config.get("source"):
|
||||
config["source"] = model_path.resolve().as_posix()
|
||||
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
old_hash = info.original_hash
|
||||
dest_path = self.app_config.models_path / info.base.value / info.type.value / model_path.name
|
||||
old_hash = info.current_hash
|
||||
dest_path = (
|
||||
self.app_config.models_path / info.base.value / info.type.value / (config.get("name") or model_path.name)
|
||||
)
|
||||
try:
|
||||
new_path = self._copy_model(model_path, dest_path)
|
||||
except FileExistsError as excp:
|
||||
@@ -177,7 +172,40 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
info,
|
||||
)
|
||||
|
||||
def heuristic_import(
|
||||
self,
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source))
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
variant=match.group(2) if match.group(2) else None, # pass None rather than ''
|
||||
subfolder=Path(match.group(3)) if match.group(3) else None,
|
||||
access_token=access_token,
|
||||
)
|
||||
elif re.match(r"^https?://[^/]+", source):
|
||||
source_obj = URLModelSource(
|
||||
url=AnyHttpUrl(source),
|
||||
access_token=access_token,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model source: '{source}'")
|
||||
return self.import_model(source_obj, config)
|
||||
|
||||
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
|
||||
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
|
||||
if similar_jobs:
|
||||
self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
|
||||
return similar_jobs[0]
|
||||
|
||||
if isinstance(source, LocalModelSource):
|
||||
install_job = self._import_local_model(source, config)
|
||||
self._install_queue.put(install_job) # synchronously install
|
||||
@@ -207,14 +235,25 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
assert isinstance(jobs[0], ModelInstallJob)
|
||||
return jobs[0]
|
||||
|
||||
def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob:
|
||||
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
|
||||
start = time.time()
|
||||
while not job.in_terminal_state:
|
||||
if self._install_completed_event.wait(timeout=5): # in case we miss an event
|
||||
self._install_completed_event.clear()
|
||||
if timeout > 0 and time.time() - start > timeout:
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
return job
|
||||
|
||||
# TODO: Better name? Maybe wait_for_jobs()? Maybe too easily confused with above
|
||||
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
|
||||
"""Block until all installation jobs are done."""
|
||||
start = time.time()
|
||||
while len(self._download_cache) > 0:
|
||||
if self._downloads_changed_event.wait(timeout=5): # in case we miss an event
|
||||
if self._downloads_changed_event.wait(timeout=0.25): # in case we miss an event
|
||||
self._downloads_changed_event.clear()
|
||||
if timeout > 0 and time.time() - start > timeout:
|
||||
raise Exception("Timeout exceeded")
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
self._install_queue.join()
|
||||
return self._install_jobs
|
||||
|
||||
@@ -268,6 +307,38 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
path.unlink()
|
||||
self.unregister(key)
|
||||
|
||||
def download_and_cache(
|
||||
self,
|
||||
source: Union[str, AnyHttpUrl],
|
||||
access_token: Optional[str] = None,
|
||||
timeout: int = 0,
|
||||
) -> Path:
|
||||
"""Download the model file located at source to the models cache and return its Path."""
|
||||
model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
|
||||
model_path = self._app_config.models_convert_cache_path / model_hash
|
||||
|
||||
# We expect the cache directory to contain one and only one downloaded file.
|
||||
# We don't know the file's name in advance, as it is set by the download
|
||||
# content-disposition header.
|
||||
if model_path.exists():
|
||||
contents = [x for x in model_path.iterdir() if x.is_file()]
|
||||
if len(contents) > 0:
|
||||
return contents[0]
|
||||
|
||||
model_path.mkdir(parents=True, exist_ok=True)
|
||||
job = self._download_queue.download(
|
||||
source=AnyHttpUrl(str(source)),
|
||||
dest=model_path,
|
||||
access_token=access_token,
|
||||
on_progress=TqdmProgress().update,
|
||||
)
|
||||
self._download_queue.wait_for_job(job, timeout)
|
||||
if job.complete:
|
||||
assert job.download_path is not None
|
||||
return job.download_path
|
||||
else:
|
||||
raise Exception(job.error)
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Internal functions that manage the installer threads
|
||||
# --------------------------------------------------------------------------------------------
|
||||
@@ -300,6 +371,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.total_bytes = self._stat_size(job.local_path)
|
||||
job.bytes = job.total_bytes
|
||||
self._signal_job_running(job)
|
||||
job.config_in["source"] = str(job.source)
|
||||
if job.inplace:
|
||||
key = self.register_path(job.local_path, job.config_in)
|
||||
else:
|
||||
@@ -330,6 +402,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# if this is an install of a remote file, then clean up the temporary directory
|
||||
if job._install_tmpdir is not None:
|
||||
rmtree(job._install_tmpdir)
|
||||
self._install_completed_event.set()
|
||||
self._install_queue.task_done()
|
||||
|
||||
self._logger.info("Install thread exiting")
|
||||
@@ -489,10 +562,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
return id
|
||||
|
||||
@staticmethod
|
||||
def _guess_variant() -> ModelRepoVariant:
|
||||
def _guess_variant() -> Optional[ModelRepoVariant]:
|
||||
"""Guess the best HuggingFace variant type to download."""
|
||||
precision = choose_precision(choose_torch_device())
|
||||
return ModelRepoVariant.FP16 if precision == "float16" else ModelRepoVariant.DEFAULT
|
||||
return ModelRepoVariant.FP16 if precision == "float16" else None
|
||||
|
||||
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
return ModelInstallJob(
|
||||
@@ -517,7 +590,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
if not source.access_token:
|
||||
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
|
||||
|
||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id)
|
||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
remote_files = metadata.download_urls(
|
||||
variant=source.variant or self._guess_variant(),
|
||||
@@ -535,19 +608,19 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
# URLs from Civitai or HuggingFace will be handled specially
|
||||
url_patterns = {
|
||||
r"https?://civitai.com/": CivitaiMetadataFetch,
|
||||
r"https?://huggingface.co/": HuggingFaceMetadataFetch,
|
||||
r"^https?://civitai.com/": CivitaiMetadataFetch,
|
||||
r"^https?://huggingface.co/[^/]+/[^/]+$": HuggingFaceMetadataFetch,
|
||||
}
|
||||
metadata = None
|
||||
for pattern, fetcher in url_patterns.items():
|
||||
if re.match(pattern, str(source.url), re.IGNORECASE):
|
||||
metadata = fetcher(self._session).from_url(source.url)
|
||||
break
|
||||
self._logger.debug(f"metadata={metadata}")
|
||||
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
else:
|
||||
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
|
||||
|
||||
return self._import_remote_model(
|
||||
source=source,
|
||||
config=config,
|
||||
@@ -565,6 +638,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
|
||||
# Currently the tmpdir isn't automatically removed at exit because it is
|
||||
# being held in a daemon thread.
|
||||
if len(remote_files) == 0:
|
||||
raise ValueError(f"{source}: No downloadable files found")
|
||||
tmpdir = Path(
|
||||
mkdtemp(
|
||||
dir=self._app_config.models_path,
|
||||
@@ -580,15 +655,26 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
bytes=0,
|
||||
total_bytes=0,
|
||||
)
|
||||
# In the event that there is a subfolder specified in the source,
|
||||
# we need to remove it from the destination path in order to avoid
|
||||
# creating unwanted subfolders
|
||||
if hasattr(source, "subfolder") and source.subfolder:
|
||||
root = Path(remote_files[0].path.parts[0])
|
||||
subfolder = root / source.subfolder
|
||||
else:
|
||||
root = Path(".")
|
||||
subfolder = Path(".")
|
||||
|
||||
# we remember the path up to the top of the tmpdir so that it may be
|
||||
# removed safely at the end of the install process.
|
||||
install_job._install_tmpdir = tmpdir
|
||||
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
|
||||
|
||||
self._logger.info(f"Queuing {source} for downloading")
|
||||
self._logger.debug(f"remote_files={remote_files}")
|
||||
for model_file in remote_files:
|
||||
url = model_file.url
|
||||
path = model_file.path
|
||||
path = root / model_file.path.relative_to(subfolder)
|
||||
self._logger.info(f"Downloading {url} => {path}")
|
||||
install_job.total_bytes += model_file.size
|
||||
assert hasattr(source, "access_token")
|
||||
|
||||
6
invokeai/app/services/model_load/__init__.py
Normal file
6
invokeai/app/services/model_load/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Initialization file for model load service module."""
|
||||
|
||||
from .model_load_base import ModelLoadServiceBase
|
||||
from .model_load_default import ModelLoadService
|
||||
|
||||
__all__ = ["ModelLoadServiceBase", "ModelLoadService"]
|
||||
40
invokeai/app/services/model_load/model_load_base.py
Normal file
40
invokeai/app/services/model_load/model_load_base.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Base class for model loader."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
|
||||
class ModelLoadServiceBase(ABC):
|
||||
"""Wrapper around AnyModelLoader."""
|
||||
|
||||
@abstractmethod
|
||||
def load_model(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
"""Return the RAM cache used by this loader."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
106
invokeai/app/services/model_load/model_load_default.py
Normal file
106
invokeai/app/services/model_load/model_load_default.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of model loader service."""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel, ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .model_load_base import ModelLoadServiceBase
|
||||
|
||||
|
||||
class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Wrapper around ModelLoaderRegistry."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
convert_cache: ModelConvertCacheBase,
|
||||
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
|
||||
):
|
||||
"""Initialize the model load service."""
|
||||
logger = InvokeAILogger.get_logger(self.__class__.__name__)
|
||||
logger.setLevel(app_config.log_level.upper())
|
||||
self._logger = logger
|
||||
self._app_config = app_config
|
||||
self._ram_cache = ram_cache
|
||||
self._convert_cache = convert_cache
|
||||
self._registry = registry
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
"""Return the RAM cache used by this loader."""
|
||||
return self._ram_cache
|
||||
|
||||
@property
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
return self._convert_cache
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
|
||||
loaded_model: LoadedModel = implementation(
|
||||
app_config=self._app_config,
|
||||
logger=self._logger,
|
||||
ram_cache=self._ram_cache,
|
||||
convert_cache=self._convert_cache,
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_config=model_config,
|
||||
loaded=True,
|
||||
)
|
||||
return loaded_model
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
model_config: AnyModelConfig,
|
||||
loaded: Optional[bool] = False,
|
||||
) -> None:
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException()
|
||||
|
||||
if not loaded:
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_config=model_config,
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_config=model_config,
|
||||
)
|
||||
@@ -1 +1,17 @@
|
||||
from .model_manager_default import ModelManagerService # noqa F401
|
||||
"""Initialization file for model manager service."""
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
|
||||
from .model_manager_default import ModelManagerService, ModelManagerServiceBase
|
||||
|
||||
__all__ = [
|
||||
"ModelManagerServiceBase",
|
||||
"ModelManagerService",
|
||||
"AnyModel",
|
||||
"AnyModelConfig",
|
||||
"BaseModelType",
|
||||
"ModelType",
|
||||
"SubModelType",
|
||||
"LoadedModel",
|
||||
]
|
||||
|
||||
@@ -1,286 +1,67 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from pydantic import Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import (
|
||||
AddModelResult,
|
||||
BaseModelType,
|
||||
MergeInterpolationMethod,
|
||||
ModelInfo,
|
||||
ModelType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||
from ..config import InvokeAIAppConfig
|
||||
from ..download import DownloadQueueServiceBase
|
||||
from ..events.events_base import EventServiceBase
|
||||
from ..model_install import ModelInstallServiceBase
|
||||
from ..model_load import ModelLoadServiceBase
|
||||
from ..model_records import ModelRecordServiceBase
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class ModelManagerServiceBase(ABC):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
"""Abstract base class for the model manager service."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
pass
|
||||
# attributes:
|
||||
# store: ModelRecordServiceBase = Field(description="An instance of the model record configuration service.")
|
||||
# install: ModelInstallServiceBase = Field(description="An instance of the model install service.")
|
||||
# load: ModelLoadServiceBase = Field(description="An instance of the model load service.")
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
node: Optional[BaseInvocation] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""Retrieve the indicated model with name and type.
|
||||
submodel can be used to get a part (such as the vae)
|
||||
of a diffusers pipeline."""
|
||||
def build_model_manager(
|
||||
cls,
|
||||
app_config: InvokeAIAppConfig,
|
||||
db: SqliteDatabase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
) -> Self:
|
||||
"""
|
||||
Construct the model manager service instance.
|
||||
|
||||
Use it rather than the __init__ constructor. This class
|
||||
method simplifies the construction considerably.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def logger(self):
|
||||
def store(self) -> ModelRecordServiceBase:
|
||||
"""Return the ModelRecordServiceBase used to store and retrieve configuration records."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def load(self) -> ModelLoadServiceBase:
|
||||
"""Return the ModelLoadServiceBase used to load models from their configuration records."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def install(self) -> ModelInstallServiceBase:
|
||||
"""Return the ModelInstallServiceBase used to download and manipulate model files."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
Uses the exact format as the omegaconf stanza.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
|
||||
"""
|
||||
Return a dict of models in the format:
|
||||
{ model_type1:
|
||||
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
||||
'model_name' : name,
|
||||
'model_type' : SDModelType,
|
||||
'description': description,
|
||||
'format': 'folder'|'safetensors'|'ckpt'
|
||||
},
|
||||
model_name2: { etc }
|
||||
},
|
||||
model_type2:
|
||||
{ model_name_n: etc
|
||||
}
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException if the name does not already exist.
|
||||
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well. Call commit() to write to disk.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: str,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
|
||||
The prediction type helper is necessary to distinguish between
|
||||
models based on Stable Diffusion 2 Base (requiring
|
||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||
(requiring SchedulerPredictionType.VPrediction). It is
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(
|
||||
default=None, min_length=2, max_length=3, description="List of model names to merge"
|
||||
),
|
||||
base_model: Union[BaseModelType, str] = Field(
|
||||
default=None, description="Base model shared by all models to be merged"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: Optional[float] = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: Optional[bool] = False,
|
||||
merge_dest_directory: Optional[Path] = None,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search_for_models(self, directory: Path) -> List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics for graph with graph_id.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
original file/database used to initialize the object.
|
||||
"""
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
pass
|
||||
|
||||
@@ -1,413 +1,149 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of ModelManagerServiceBase."""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
||||
from typing_extensions import Self
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
||||
from invokeai.backend.model_management import (
|
||||
AddModelResult,
|
||||
BaseModelType,
|
||||
MergeInterpolationMethod,
|
||||
ModelInfo,
|
||||
ModelManager,
|
||||
ModelMerger,
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_management.model_cache import CacheStats
|
||||
from invokeai.backend.model_management.model_search import FindModels
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
from ..download import DownloadQueueServiceBase
|
||||
from ..events.events_base import EventServiceBase
|
||||
from ..model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from ..model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from ..model_records import ModelRecordServiceBase, UnknownModelException
|
||||
from .model_manager_base import ModelManagerServiceBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
|
||||
|
||||
# simple implementation
|
||||
class ModelManagerService(ModelManagerServiceBase):
|
||||
"""Responsible for managing models on disk and in memory"""
|
||||
"""
|
||||
The ModelManagerService handles various aspects of model installation, maintenance and loading.
|
||||
|
||||
It bundles three distinct services:
|
||||
model_manager.store -- Routines to manage the database of model configuration records.
|
||||
model_manager.install -- Routines to install, move and delete models.
|
||||
model_manager.load -- Routines to load models into memory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
store: ModelRecordServiceBase,
|
||||
install: ModelInstallServiceBase,
|
||||
load: ModelLoadServiceBase,
|
||||
):
|
||||
"""
|
||||
Initialize with the path to the models.yaml config file.
|
||||
Optional parameters are the torch device type, precision, max_models,
|
||||
and sequential_offload boolean. Note that the default device
|
||||
type and precision are set up for a CUDA system running at half precision.
|
||||
"""
|
||||
if config.model_conf_path and config.model_conf_path.exists():
|
||||
config_file = config.model_conf_path
|
||||
else:
|
||||
config_file = config.root_dir / "configs/models.yaml"
|
||||
self._store = store
|
||||
self._install = install
|
||||
self._load = load
|
||||
|
||||
logger.debug(f"Config file={config_file}")
|
||||
@property
|
||||
def store(self) -> ModelRecordServiceBase:
|
||||
return self._store
|
||||
|
||||
device = torch.device(choose_torch_device())
|
||||
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
|
||||
logger.info(f"GPU device = {device} {device_name}")
|
||||
@property
|
||||
def install(self) -> ModelInstallServiceBase:
|
||||
return self._install
|
||||
|
||||
precision = config.precision
|
||||
if precision == "auto":
|
||||
precision = choose_precision(device)
|
||||
dtype = torch.float32 if precision == "float32" else torch.float16
|
||||
@property
|
||||
def load(self) -> ModelLoadServiceBase:
|
||||
return self._load
|
||||
|
||||
# this is transitional backward compatibility
|
||||
# support for the deprecated `max_loaded_models`
|
||||
# configuration value. If present, then the
|
||||
# cache size is set to 2.5 GB times
|
||||
# the number of max_loaded_models. Otherwise
|
||||
# use new `ram_cache_size` config setting
|
||||
max_cache_size = config.ram_cache_size
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
for service in [self._store, self._install, self._load]:
|
||||
if hasattr(service, "start"):
|
||||
service.start(invoker)
|
||||
|
||||
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
for service in [self._store, self._install, self._load]:
|
||||
if hasattr(service, "stop"):
|
||||
service.stop(invoker)
|
||||
|
||||
sequential_offload = config.sequential_guidance
|
||||
def load_model_by_config(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
return self.load.load_model(model_config, submodel_type, context)
|
||||
|
||||
self.mgr = ModelManager(
|
||||
config=config_file,
|
||||
device_type=device,
|
||||
precision=dtype,
|
||||
max_cache_size=max_cache_size,
|
||||
sequential_offload=sequential_offload,
|
||||
logger=logger,
|
||||
)
|
||||
logger.info("Model manager service initialized")
|
||||
def load_model_by_key(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
config = self.store.get_model(key)
|
||||
return self.load.load_model(config, submodel_type, context)
|
||||
|
||||
def get_model(
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> ModelInfo:
|
||||
"""
|
||||
Retrieve the indicated model. submodel can be used to get a
|
||||
part (such as the vae) of a diffusers mode.
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
||||
|
||||
# we can emit model loading events if we are executing with access to the invocation context
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
)
|
||||
This is provided for API compatability with the get_model() method
|
||||
in the original model manager. However, note that LoadedModel is
|
||||
not the same as the original ModelInfo that ws returned.
|
||||
|
||||
model_info = self.mgr.get_model(
|
||||
model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
submodel,
|
||||
)
|
||||
:param model_name: Name of to be fetched.
|
||||
:param base_model: Base model
|
||||
:param model_type: Type of the model
|
||||
:param submodel: For main (pipeline models), the submodel to fetch
|
||||
:param context: The invocation context.
|
||||
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context=context,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
)
|
||||
|
||||
return model_info
|
||||
|
||||
def model_exists(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
) -> bool:
|
||||
Exceptions: UnknownModelException -- model with this key not known
|
||||
NotImplementedException -- a model loader was not provided at initialization time
|
||||
ValueError -- more than one model matches this combination
|
||||
"""
|
||||
Given a model name, returns True if it is a valid
|
||||
identifier.
|
||||
"""
|
||||
return self.mgr.model_exists(
|
||||
model_name,
|
||||
base_model,
|
||||
model_type,
|
||||
)
|
||||
|
||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
||||
"""
|
||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||
"""
|
||||
return self.mgr.model_info(model_name, base_model, model_type)
|
||||
|
||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||
"""
|
||||
Returns a list of all the model names known.
|
||||
"""
|
||||
return self.mgr.model_names()
|
||||
|
||||
def list_models(
|
||||
self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Return a list of models.
|
||||
"""
|
||||
return self.mgr.list_models(base_model, model_type)
|
||||
|
||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> Union[dict, None]:
|
||||
"""
|
||||
Return information about the model using the same format as list_models()
|
||||
"""
|
||||
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
|
||||
|
||||
def add_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
clobber: bool = False,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with an
|
||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f"add/update model {model_name}")
|
||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||
|
||||
def update_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
model_attributes: dict,
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Update the named model with a dictionary of attributes. Will fail with a
|
||||
ModelNotFoundException exception if the name does not already exist.
|
||||
On a successful update, the config will be changed in memory. Will fail
|
||||
with an assertion error if provided attributes are incorrect or
|
||||
the model name is missing. Call commit() to write changes to disk.
|
||||
"""
|
||||
self.logger.debug(f"update model {model_name}")
|
||||
if not self.model_exists(model_name, base_model, model_type):
|
||||
raise ModelNotFoundException(f"Unknown model {model_name}")
|
||||
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
"""
|
||||
Delete the named model from configuration. If delete_files is true,
|
||||
then the underlying weight file or diffusers directory will be deleted
|
||||
as well.
|
||||
"""
|
||||
self.logger.debug(f"delete model {model_name}")
|
||||
self.mgr.del_model(model_name, base_model, model_type)
|
||||
self.mgr.commit()
|
||||
|
||||
def convert_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||
convert_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||
version and deleting the original checkpoint file if it is in the models
|
||||
directory.
|
||||
:param model_name: Name of the model to convert
|
||||
:param base_model: Base model type
|
||||
:param model_type: Type of model ['vae' or 'main']
|
||||
:param convert_dest_directory: Save the converted model to the designated directory (`models/etc/etc` by default)
|
||||
|
||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||
directory already in place.
|
||||
"""
|
||||
self.logger.debug(f"convert model {model_name}")
|
||||
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
||||
|
||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||
"""
|
||||
Reset model cache statistics for graph with graph_id.
|
||||
"""
|
||||
self.mgr.cache.stats = cache_stats
|
||||
|
||||
def commit(self, conf_file: Optional[Path] = None):
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
If no conf_file is provided, then replaces the
|
||||
original file/database used to initialize the object.
|
||||
"""
|
||||
return self.mgr.commit(conf_file)
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
model_info: Optional[ModelInfo] = None,
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException()
|
||||
|
||||
if model_info:
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
model_info=model_info,
|
||||
)
|
||||
configs = self.store.search_by_attr(model_name, base_model, model_type)
|
||||
if len(configs) == 0:
|
||||
raise UnknownModelException(f"{base_model}/{model_type}/{model_name}: Unknown model")
|
||||
elif len(configs) > 1:
|
||||
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
|
||||
else:
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel=submodel,
|
||||
)
|
||||
return self.load.load_model(configs[0], submodel, context)
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
return self.mgr.logger
|
||||
@classmethod
|
||||
def build_model_manager(
|
||||
cls,
|
||||
app_config: InvokeAIAppConfig,
|
||||
model_record_service: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
) -> Self:
|
||||
"""
|
||||
Construct the model manager service instance.
|
||||
|
||||
def heuristic_import(
|
||||
self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> dict[str, AddModelResult]:
|
||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||
For simplicity, use this class method rather than the __init__ constructor.
|
||||
"""
|
||||
logger = InvokeAILogger.get_logger(cls.__name__)
|
||||
logger.setLevel(app_config.log_level.upper())
|
||||
|
||||
The prediction type helper is necessary to distinguish between
|
||||
models based on Stable Diffusion 2 Base (requiring
|
||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||
(requiring SchedulerPredictionType.VPrediction). It is
|
||||
generally impossible to do this programmatically, so the
|
||||
prediction_type_helper usually asks the user to choose.
|
||||
|
||||
The result is a set of successfully installed models. Each element
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
"""
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
model_names: List[str] = Field(
|
||||
default=None, min_length=2, max_length=3, description="List of model names to merge"
|
||||
),
|
||||
base_model: Union[BaseModelType, str] = Field(
|
||||
default=None, description="Base model shared by all models to be merged"
|
||||
),
|
||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||
alpha: float = 0.5,
|
||||
interp: Optional[MergeInterpolationMethod] = None,
|
||||
force: bool = False,
|
||||
merge_dest_directory: Optional[Path] = Field(
|
||||
default=None, description="Optional directory location for merged model"
|
||||
),
|
||||
) -> AddModelResult:
|
||||
"""
|
||||
Merge two to three diffusrs pipeline models and save as a new model.
|
||||
:param model_names: List of 2-3 models to merge
|
||||
:param base_model: Base model to use for all models
|
||||
:param merged_model_name: Name of destination merged model
|
||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||
:param interp: Interpolation method. None (default)
|
||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||
"""
|
||||
merger = ModelMerger(self.mgr)
|
||||
try:
|
||||
result = merger.merge_diffusion_models_and_save(
|
||||
model_names=model_names,
|
||||
base_model=base_model,
|
||||
merged_model_name=merged_model_name,
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=merge_dest_directory,
|
||||
)
|
||||
except AssertionError as e:
|
||||
raise ValueError(e)
|
||||
return result
|
||||
|
||||
def search_for_models(self, directory: Path) -> List[Path]:
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
search = FindModels([directory], self.logger)
|
||||
return search.list_models()
|
||||
|
||||
def sync_to_config(self):
|
||||
"""
|
||||
Re-read models.yaml, rescan the models directory, and reimport models
|
||||
in the autoimport directories. Call after making changes outside the
|
||||
model manager API.
|
||||
"""
|
||||
return self.mgr.sync_to_config()
|
||||
|
||||
def list_checkpoint_configs(self) -> List[Path]:
|
||||
"""
|
||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||
"""
|
||||
config = self.mgr.app_config
|
||||
conf_path = config.legacy_conf_path
|
||||
root_path = config.root_path
|
||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
|
||||
|
||||
def rename_model(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
new_name: Optional[str] = None,
|
||||
new_base: Optional[BaseModelType] = None,
|
||||
):
|
||||
"""
|
||||
Rename the indicated model. Can provide a new name and/or a new base.
|
||||
:param model_name: Current name of the model
|
||||
:param base_model: Current base of the model
|
||||
:param model_type: Model type (can't be changed)
|
||||
:param new_name: New name for the model
|
||||
:param new_base: New base for the model
|
||||
"""
|
||||
self.mgr.rename_model(
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
new_name=new_name,
|
||||
new_base=new_base,
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size, logger=logger
|
||||
)
|
||||
convert_cache = ModelConvertCache(
|
||||
cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size
|
||||
)
|
||||
loader = ModelLoadService(
|
||||
app_config=app_config,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
registry=ModelLoaderRegistry,
|
||||
)
|
||||
installer = ModelInstallService(
|
||||
app_config=app_config,
|
||||
record_store=model_record_service,
|
||||
download_queue=download_queue,
|
||||
event_bus=events,
|
||||
)
|
||||
return cls(store=model_record_service, install=installer, load=loader)
|
||||
|
||||
9
invokeai/app/services/model_metadata/__init__.py
Normal file
9
invokeai/app/services/model_metadata/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Init file for ModelMetadataStoreService module."""
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
from .metadata_store_sql import ModelMetadataStoreSQL
|
||||
|
||||
__all__ = [
|
||||
"ModelMetadataStoreBase",
|
||||
"ModelMetadataStoreSQL",
|
||||
]
|
||||
65
invokeai/app/services/model_metadata/metadata_store_base.py
Normal file
65
invokeai/app/services/model_metadata/metadata_store_base.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Storage for Model Metadata
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class ModelMetadataStoreBase(ABC):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
@abstractmethod
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
|
||||
@abstractmethod
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
|
||||
@abstractmethod
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
|
||||
@abstractmethod
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
222
invokeai/app/services/model_metadata/metadata_store_sql.py
Normal file
222
invokeai/app/services/model_metadata/metadata_store_sql.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
SQL Storage for Model Metadata
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.metadata.fetch import ModelMetadataFetchBase
|
||||
|
||||
from .metadata_store_base import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class ModelMetadataStoreSQL(ModelMetadataStoreBase):
|
||||
"""Store, search and fetch model metadata retrieved from remote repositories."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
def add_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> None:
|
||||
"""
|
||||
Add a block of repo metadata to a model record.
|
||||
|
||||
The model record config must already exist in the database with the
|
||||
same key. Otherwise a FOREIGN KEY constraint exception will be raised.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to store
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json()
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_metadata(
|
||||
id,
|
||||
metadata
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
model_key,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
|
||||
self._db.conn.rollback()
|
||||
raise UnknownMetadataException from excp
|
||||
except sqlite3.Error as excp:
|
||||
self._db.conn.rollback()
|
||||
raise excp
|
||||
|
||||
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
|
||||
"""Retrieve the ModelRepoMetadata corresponding to model key."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM model_metadata
|
||||
WHERE id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
return ModelMetadataFetchBase.from_json(rows[0])
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]: # key, metadata
|
||||
"""Dump out all the metadata."""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id,metadata FROM model_metadata;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
return [(x[0], ModelMetadataFetchBase.from_json(x[1])) for x in rows]
|
||||
|
||||
def update_metadata(self, model_key: str, metadata: AnyModelRepoMetadata) -> AnyModelRepoMetadata:
|
||||
"""
|
||||
Update metadata corresponding to the model with the indicated key.
|
||||
|
||||
:param model_key: Existing model key in the `model_config` table
|
||||
:param metadata: ModelRepoMetadata object to update
|
||||
"""
|
||||
json_serialized = metadata.model_dump_json() # turn it into a json string.
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_metadata
|
||||
SET
|
||||
metadata=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, model_key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownMetadataException("model metadata not found")
|
||||
self._update_tags(model_key, metadata.tags)
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_metadata(model_key)
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return all tags in the tags table."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select tag_text from tags;
|
||||
"""
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_tag(self, tags: Set[str]) -> Set[str]:
|
||||
"""Return the keys of models containing all of the listed tags."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
matches: Optional[Set[str]] = None
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT a.model_id FROM model_tags AS a,
|
||||
tags AS b
|
||||
WHERE a.tag_id=b.tag_id
|
||||
AND b.tag_text=?;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
model_keys = {x[0] for x in self._cursor.fetchall()}
|
||||
if matches is None:
|
||||
matches = model_keys
|
||||
matches = matches.intersection(model_keys)
|
||||
except sqlite3.Error as e:
|
||||
raise e
|
||||
return matches if matches else set()
|
||||
|
||||
def search_by_author(self, author: str) -> Set[str]:
|
||||
"""Return the keys of models authored by the indicated author."""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE author=?;
|
||||
""",
|
||||
(author,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def search_by_name(self, name: str) -> Set[str]:
|
||||
"""
|
||||
Return the keys of models with the indicated name.
|
||||
|
||||
Note that this is the name of the model given to it by
|
||||
the remote source. The user may have changed the local
|
||||
name. The local name will be located in the model config
|
||||
record object.
|
||||
"""
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT id FROM model_metadata
|
||||
WHERE name=?;
|
||||
""",
|
||||
(name,),
|
||||
)
|
||||
return {x[0] for x in self._cursor.fetchall()}
|
||||
|
||||
def _update_tags(self, model_key: str, tags: Set[str]) -> None:
|
||||
"""Update tags for the model referenced by model_key."""
|
||||
# remove previous tags from this model
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_tags
|
||||
WHERE model_id=?;
|
||||
""",
|
||||
(model_key,),
|
||||
)
|
||||
|
||||
for tag in tags:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO tags (
|
||||
tag_text
|
||||
)
|
||||
VALUES (?);
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT tag_id
|
||||
FROM tags
|
||||
WHERE tag_text = ?
|
||||
LIMIT 1;
|
||||
""",
|
||||
(tag,),
|
||||
)
|
||||
tag_id = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO model_tags (
|
||||
model_id,
|
||||
tag_id
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(model_key, tag_id),
|
||||
)
|
||||
@@ -11,8 +11,15 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
@@ -104,7 +111,7 @@ class ModelRecordServiceBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def metadata_store(self) -> ModelMetadataStore:
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
pass
|
||||
|
||||
@@ -146,7 +153,7 @@ class ModelRecordServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
"""
|
||||
Return True if a model with the indicated key exists in the databse.
|
||||
Return True if a model with the indicated key exists in the database.
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
|
||||
@@ -54,8 +54,9 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, ModelMetadataStore, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata, UnknownMetadataException
|
||||
|
||||
from ..model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from .model_records_base import (
|
||||
DuplicateModelException,
|
||||
@@ -69,16 +70,16 @@ from .model_records_base import (
|
||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
def __init__(self, db: SqliteDatabase, metadata_store: ModelMetadataStoreBase):
|
||||
"""
|
||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||
|
||||
:param conn: sqlite3 connection object
|
||||
:param lock: threading Lock object
|
||||
:param db: Sqlite connection object
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = self._db.conn.cursor()
|
||||
self._cursor = db.conn.cursor()
|
||||
self._metadata_store = metadata_store
|
||||
|
||||
@property
|
||||
def db(self) -> SqliteDatabase:
|
||||
@@ -158,7 +159,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
|
||||
def update_model(self, key: str, config: Union[Dict[str, Any], AnyModelConfig]) -> AnyModelConfig:
|
||||
"""
|
||||
Update the model, returning the updated version.
|
||||
|
||||
@@ -199,7 +200,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@@ -207,7 +208,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]))
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
@@ -265,12 +266,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
select config FROM model_config
|
||||
select config, strftime('%s',updated_at) FROM model_config
|
||||
{where};
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
return results
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||
@@ -279,12 +282,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
@@ -293,18 +298,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE original_hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
return results
|
||||
|
||||
@property
|
||||
def metadata_store(self) -> ModelMetadataStore:
|
||||
def metadata_store(self) -> ModelMetadataStoreBase:
|
||||
"""Return a ModelMetadataStore initialized on the same database."""
|
||||
return ModelMetadataStore(self._db)
|
||||
return self._metadata_store
|
||||
|
||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
||||
"""
|
||||
@@ -325,18 +332,18 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
:param tags: Set of tags to search for. All tags must be present.
|
||||
"""
|
||||
store = ModelMetadataStore(self._db)
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
keys = store.search_by_tag(tags)
|
||||
return [self.get_model(x) for x in keys]
|
||||
|
||||
def list_tags(self) -> Set[str]:
|
||||
"""Return a unique set of all the model tags in the metadata database."""
|
||||
store = ModelMetadataStore(self._db)
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_tags()
|
||||
|
||||
def list_all_metadata(self) -> List[Tuple[str, AnyModelRepoMetadata]]:
|
||||
"""List metadata for all models that have it."""
|
||||
store = ModelMetadataStore(self._db)
|
||||
store = ModelMetadataStoreSQL(self._db)
|
||||
return store.list_all_metadata()
|
||||
|
||||
def list_models(
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
|
||||
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
|
||||
@@ -141,6 +141,16 @@ def are_connections_compatible(
|
||||
return are_connection_types_compatible(from_node_field, to_node_field)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def copydeep(obj: T) -> T:
|
||||
"""Deep-copies an object. If it is a pydantic model, use the model's copy method."""
|
||||
if isinstance(obj, BaseModel):
|
||||
return obj.model_copy(deep=True)
|
||||
return copy.deepcopy(obj)
|
||||
|
||||
|
||||
class NodeAlreadyInGraphError(ValueError):
|
||||
pass
|
||||
|
||||
@@ -1118,17 +1128,22 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def _prepare_inputs(self, node: BaseInvocation):
|
||||
input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
|
||||
# Inputs must be deep-copied, else if a node mutates the object, other nodes that get the same input
|
||||
# will see the mutation.
|
||||
if isinstance(node, CollectInvocation):
|
||||
output_collection = [
|
||||
getattr(self.results[edge.source.node_id], edge.source.field)
|
||||
copydeep(getattr(self.results[edge.source.node_id], edge.source.field))
|
||||
for edge in input_edges
|
||||
if edge.destination.field == "item"
|
||||
]
|
||||
node.collection = output_collection
|
||||
else:
|
||||
for edge in input_edges:
|
||||
output_value = getattr(self.results[edge.source.node_id], edge.source.field)
|
||||
setattr(node, edge.destination.field, output_value)
|
||||
setattr(
|
||||
node,
|
||||
edge.destination.field,
|
||||
copydeep(getattr(self.results[edge.source.node_id], edge.source.field)),
|
||||
)
|
||||
|
||||
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
||||
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||
|
||||
@@ -7,6 +7,8 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_2 import build_migration_2
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import build_migration_3
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@@ -31,6 +33,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_2(image_files=image_files, logger=logger))
|
||||
migrator.register_migration(build_migration_3(app_config=config, logger=logger))
|
||||
migrator.register_migration(build_migration_4())
|
||||
migrator.register_migration(build_migration_5())
|
||||
migrator.register_migration(build_migration_6())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration5Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._drop_graph_executions(cursor)
|
||||
|
||||
def _drop_graph_executions(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Drops the `graph_executions` table."""
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DROP TABLE IF EXISTS graph_executions;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def build_migration_5() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 4 to 5.
|
||||
|
||||
Introduced in v3.6.3, this migration:
|
||||
- Drops the `graph_executions` table. We are able to do this because we are moving the graph storage
|
||||
to be purely in-memory.
|
||||
"""
|
||||
migration_5 = Migration(
|
||||
from_version=4,
|
||||
to_version=5,
|
||||
callback=Migration5Callback(),
|
||||
)
|
||||
|
||||
return migration_5
|
||||
@@ -0,0 +1,62 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration6Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._recreate_model_triggers(cursor)
|
||||
self._delete_ip_adapters(cursor)
|
||||
|
||||
def _recreate_model_triggers(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
Adds the timestamp trigger to the model_config table.
|
||||
|
||||
This trigger was inadvertently dropped in earlier migration scripts.
|
||||
"""
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS model_config_updated_at
|
||||
AFTER UPDATE
|
||||
ON model_config FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE model_config SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
def _delete_ip_adapters(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
Delete all the IP adapters.
|
||||
|
||||
The model manager will automatically find and re-add them after the migration
|
||||
is done. This allows the manager to add the correct image encoder to their
|
||||
configuration records.
|
||||
"""
|
||||
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_config
|
||||
WHERE type='ip_adapter';
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def build_migration_6() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 5 to 6.
|
||||
|
||||
This migration does the following:
|
||||
- Adds the model_config_updated_at trigger if it does not exist
|
||||
- Delete all ip_adapter models so that the model prober can find and
|
||||
update with the correct image processor model.
|
||||
"""
|
||||
migration_6 = Migration(
|
||||
from_version=5,
|
||||
to_version=6,
|
||||
callback=Migration6Callback(),
|
||||
)
|
||||
|
||||
return migration_6
|
||||
@@ -72,7 +72,12 @@ class MigrateModelYamlToDb1:
|
||||
continue
|
||||
|
||||
base_type, model_type, model_name = str(model_key).split("/")
|
||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||
try:
|
||||
hash = FastModelHash.hash(self.config.models_path / stanza.path)
|
||||
except OSError:
|
||||
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
|
||||
continue
|
||||
|
||||
assert isinstance(model_key, str)
|
||||
new_key = sha1(model_key.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import uuid
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_timestamp():
|
||||
def get_timestamp() -> int:
|
||||
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||
|
||||
|
||||
@@ -20,16 +20,16 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
|
||||
SEED_MAX = np.iinfo(np.uint32).max
|
||||
|
||||
|
||||
def get_random_seed():
|
||||
def get_random_seed() -> int:
|
||||
rng = np.random.default_rng(seed=None)
|
||||
return int(rng.integers(0, SEED_MAX))
|
||||
|
||||
|
||||
def uuid_string():
|
||||
def uuid_string() -> str:
|
||||
res = uuid.uuid4()
|
||||
return str(res)
|
||||
|
||||
|
||||
def is_optional(value: typing.Any):
|
||||
def is_optional(value: typing.Any) -> bool:
|
||||
"""Checks if a value is typed as Optional. Note that Optional is sugar for Union[x, None]."""
|
||||
return typing.get_origin(value) is typing.Union and type(None) in typing.get_args(value)
|
||||
|
||||
67
invokeai/app/util/profiler.py
Normal file
67
invokeai/app/util/profiler.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import cProfile
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Profiler:
|
||||
"""
|
||||
Simple wrapper around cProfile.
|
||||
|
||||
Usage
|
||||
```
|
||||
# Create a profiler
|
||||
profiler = Profiler(logger, output_dir, "sql_query_perf")
|
||||
# Start a new profile
|
||||
profiler.start("my_profile")
|
||||
# Do stuff
|
||||
profiler.stop()
|
||||
```
|
||||
|
||||
Visualize a profile as a flamegraph with [snakeviz](https://jiffyclub.github.io/snakeviz/)
|
||||
```sh
|
||||
snakeviz my_profile.prof
|
||||
```
|
||||
|
||||
Visualize a profile as directed graph with [graphviz](https://graphviz.org/download/) & [gprof2dot](https://github.com/jrfonseca/gprof2dot)
|
||||
```sh
|
||||
gprof2dot -f pstats my_profile.prof | dot -Tpng -o my_profile.png
|
||||
# SVG or PDF may be nicer - you can search for function names
|
||||
gprof2dot -f pstats my_profile.prof | dot -Tsvg -o my_profile.svg
|
||||
gprof2dot -f pstats my_profile.prof | dot -Tpdf -o my_profile.pdf
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, logger: Logger, output_dir: Path, prefix: Optional[str] = None) -> None:
|
||||
self._logger = logger.getChild(f"profiler.{prefix}" if prefix else "profiler")
|
||||
self._output_dir = output_dir
|
||||
self._output_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._profiler: Optional[cProfile.Profile] = None
|
||||
self._prefix = prefix
|
||||
|
||||
self.profile_id: Optional[str] = None
|
||||
|
||||
def start(self, profile_id: str) -> None:
|
||||
if self._profiler:
|
||||
self.stop()
|
||||
|
||||
self.profile_id = profile_id
|
||||
|
||||
self._profiler = cProfile.Profile()
|
||||
self._profiler.enable()
|
||||
self._logger.info(f"Started profiling {self.profile_id}.")
|
||||
|
||||
def stop(self) -> Path:
|
||||
if not self._profiler:
|
||||
raise RuntimeError("Profiler not initialized. Call start() first.")
|
||||
self._profiler.disable()
|
||||
|
||||
filename = f"{self._prefix}_{self.profile_id}.prof" if self._prefix else f"{self.profile_id}.prof"
|
||||
path = Path(self._output_dir, filename)
|
||||
|
||||
self._profiler.dump_stats(path)
|
||||
self._logger.info(f"Stopped profiling, profile dumped to {path}.")
|
||||
self._profiler = None
|
||||
self.profile_id = None
|
||||
|
||||
return path
|
||||
@@ -3,7 +3,7 @@ from PIL import Image
|
||||
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
|
||||
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.model_manager import BaseModelType
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend
|
||||
"""
|
||||
from .model_management import BaseModelType, ModelCache, ModelInfo, ModelManager, ModelType, SubModelType # noqa: F401
|
||||
from .model_management.models import SilenceWarnings # noqa: F401
|
||||
|
||||
4
invokeai/backend/embeddings/__init__.py
Normal file
4
invokeai/backend/embeddings/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""Initialization file for invokeai.backend.embeddings modules."""
|
||||
|
||||
# from .model_patcher import ModelPatcher
|
||||
# __all__ = ["ModelPatcher"]
|
||||
12
invokeai/backend/embeddings/embedding_base.py
Normal file
12
invokeai/backend/embeddings/embedding_base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Base class for LoRA and Textual Inversion models.
|
||||
|
||||
The EmbeddingRaw class is the base class of LoRAModelRaw and TextualInversionModelRaw,
|
||||
and is used for type checking of calls to the model patcher.
|
||||
|
||||
The use of "Raw" here is a historical artifact, and carried forward in
|
||||
order to avoid confusion.
|
||||
"""
|
||||
|
||||
|
||||
class EmbeddingModelRaw:
|
||||
"""Base class for LoRA and Textual Inversion models."""
|
||||
201
invokeai/backend/image_util/basicsr/LICENSE
Normal file
201
invokeai/backend/image_util/basicsr/LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2018-2022 BasicSR Authors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
18
invokeai/backend/image_util/basicsr/__init__.py
Normal file
18
invokeai/backend/image_util/basicsr/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
Adapted from https://github.com/XPixelGroup/BasicSR
|
||||
License: Apache-2.0
|
||||
|
||||
As of Feb 2024, `basicsr` appears to be unmaintained. It imports a function from `torchvision` that is removed in
|
||||
`torchvision` 0.17. Here is the deprecation warning:
|
||||
|
||||
UserWarning: The torchvision.transforms.functional_tensor module is deprecated in 0.15 and will be **removed in
|
||||
0.17**. Please don't rely on it. You probably just need to use APIs in torchvision.transforms.functional or in
|
||||
torchvision.transforms.v2.functional.
|
||||
|
||||
As a result, a dependency on `basicsr` means we cannot keep our `torchvision` dependency up to date.
|
||||
|
||||
Because we only rely on a single class `RRDBNet` from `basicsr`, we've copied the relevant code here and removed the
|
||||
dependency on `basicsr`.
|
||||
|
||||
The code is almost unchanged, only a few type annotations have been added. The license is also copied.
|
||||
"""
|
||||
75
invokeai/backend/image_util/basicsr/arch_util.py
Normal file
75
invokeai/backend/image_util/basicsr/arch_util.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import init as init
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def default_init_weights(
|
||||
module_list: list[nn.Module] | nn.Module, scale: float = 1, bias_fill: float = 0, **kwargs
|
||||
) -> None:
|
||||
"""Initialize network weights.
|
||||
|
||||
Args:
|
||||
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
|
||||
scale (float): Scale initialized weights, especially for residual
|
||||
blocks. Default: 1.
|
||||
bias_fill (float): The value to fill bias. Default: 0
|
||||
kwargs (dict): Other arguments for initialization function.
|
||||
"""
|
||||
if not isinstance(module_list, list):
|
||||
module_list = [module_list]
|
||||
for module in module_list:
|
||||
for m in module.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, **kwargs)
|
||||
m.weight.data *= scale
|
||||
if m.bias is not None:
|
||||
m.bias.data.fill_(bias_fill)
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.kaiming_normal_(m.weight, **kwargs)
|
||||
m.weight.data *= scale
|
||||
if m.bias is not None:
|
||||
m.bias.data.fill_(bias_fill)
|
||||
elif isinstance(m, _BatchNorm):
|
||||
init.constant_(m.weight, 1)
|
||||
if m.bias is not None:
|
||||
m.bias.data.fill_(bias_fill)
|
||||
|
||||
|
||||
def make_layer(basic_block: Type[nn.Module], num_basic_block: int, **kwarg) -> nn.Sequential:
|
||||
"""Make layers by stacking the same blocks.
|
||||
|
||||
Args:
|
||||
basic_block (Type[nn.Module]): nn.Module class for basic block.
|
||||
num_basic_block (int): number of blocks.
|
||||
|
||||
Returns:
|
||||
nn.Sequential: Stacked blocks in nn.Sequential.
|
||||
"""
|
||||
layers = []
|
||||
for _ in range(num_basic_block):
|
||||
layers.append(basic_block(**kwarg))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
# TODO: may write a cpp file
|
||||
def pixel_unshuffle(x: torch.Tensor, scale: int) -> torch.Tensor:
|
||||
"""Pixel unshuffle.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input feature with shape (b, c, hh, hw).
|
||||
scale (int): Downsample ratio.
|
||||
|
||||
Returns:
|
||||
Tensor: the pixel unshuffled feature.
|
||||
"""
|
||||
b, c, hh, hw = x.size()
|
||||
out_channel = c * (scale**2)
|
||||
assert hh % scale == 0 and hw % scale == 0
|
||||
h = hh // scale
|
||||
w = hw // scale
|
||||
x_view = x.view(b, c, h, scale, w, scale)
|
||||
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
|
||||
125
invokeai/backend/image_util/basicsr/rrdbnet_arch.py
Normal file
125
invokeai/backend/image_util/basicsr/rrdbnet_arch.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .arch_util import default_init_weights, make_layer, pixel_unshuffle
|
||||
|
||||
|
||||
class ResidualDenseBlock(nn.Module):
|
||||
"""Residual Dense Block.
|
||||
|
||||
Used in RRDB block in ESRGAN.
|
||||
|
||||
Args:
|
||||
num_feat (int): Channel number of intermediate features.
|
||||
num_grow_ch (int): Channels for each growth.
|
||||
"""
|
||||
|
||||
def __init__(self, num_feat: int = 64, num_grow_ch: int = 32) -> None:
|
||||
super(ResidualDenseBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
|
||||
self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
|
||||
self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
||||
self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
|
||||
self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
# initialization
|
||||
default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
# Empirically, we use 0.2 to scale the residual for better performance
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
"""Residual in Residual Dense Block.
|
||||
|
||||
Used in RRDB-Net in ESRGAN.
|
||||
|
||||
Args:
|
||||
num_feat (int): Channel number of intermediate features.
|
||||
num_grow_ch (int): Channels for each growth.
|
||||
"""
|
||||
|
||||
def __init__(self, num_feat: int, num_grow_ch: int = 32) -> None:
|
||||
super(RRDB, self).__init__()
|
||||
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
|
||||
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
|
||||
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = self.rdb1(x)
|
||||
out = self.rdb2(out)
|
||||
out = self.rdb3(out)
|
||||
# Empirically, we use 0.2 to scale the residual for better performance
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
"""Networks consisting of Residual in Residual Dense Block, which is used
|
||||
in ESRGAN.
|
||||
|
||||
ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
|
||||
|
||||
We extend ESRGAN for scale x2 and scale x1.
|
||||
Note: This is one option for scale 1, scale 2 in RRDBNet.
|
||||
We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size
|
||||
and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
|
||||
|
||||
Args:
|
||||
num_in_ch (int): Channel number of inputs.
|
||||
num_out_ch (int): Channel number of outputs.
|
||||
num_feat (int): Channel number of intermediate features.
|
||||
Default: 64
|
||||
num_block (int): Block number in the trunk network. Defaults: 23
|
||||
num_grow_ch (int): Channels for each growth. Default: 32.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_in_ch: int,
|
||||
num_out_ch: int,
|
||||
scale: int = 4,
|
||||
num_feat: int = 64,
|
||||
num_block: int = 23,
|
||||
num_grow_ch: int = 32,
|
||||
) -> None:
|
||||
super(RRDBNet, self).__init__()
|
||||
self.scale = scale
|
||||
if scale == 2:
|
||||
num_in_ch = num_in_ch * 4
|
||||
elif scale == 1:
|
||||
num_in_ch = num_in_ch * 16
|
||||
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
|
||||
self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
|
||||
self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
||||
# upsample
|
||||
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
||||
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
||||
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
|
||||
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.scale == 2:
|
||||
feat = pixel_unshuffle(x, scale=2)
|
||||
elif self.scale == 1:
|
||||
feat = pixel_unshuffle(x, scale=4)
|
||||
else:
|
||||
feat = x
|
||||
feat = self.conv_first(feat)
|
||||
body_feat = self.conv_body(self.body(feat))
|
||||
feat = feat + body_feat
|
||||
# upsample
|
||||
feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest")))
|
||||
feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest")))
|
||||
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
|
||||
return out
|
||||
@@ -7,10 +7,10 @@ import cv2
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from cv2.typing import MatLike
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
"""
|
||||
|
||||
@@ -8,8 +8,8 @@ from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend import SilenceWarnings
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
|
||||
299
invokeai/backend/install/install_helper.py
Normal file
299
invokeai/backend/install/install_helper.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""Utility (backend) functions used by model_install.py"""
|
||||
import re
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import omegaconf
|
||||
from huggingface_hub import HfFolder
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from invokeai.app.services.model_install import (
|
||||
HFModelSource,
|
||||
LocalModelSource,
|
||||
ModelInstallService,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager import (
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import UnknownMetadataException
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
# name of the starter models file
|
||||
INITIAL_MODELS = "INITIAL_MODELS.yaml"
|
||||
|
||||
|
||||
def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||
"""Return an initialized ModelConfigRecordServiceBase object."""
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
image_files = DiskImageFileStorage(f"{app_config.output_path}/images")
|
||||
db = init_db(config=app_config, logger=logger, image_files=image_files)
|
||||
obj: ModelRecordServiceBase = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
return obj
|
||||
|
||||
|
||||
def initialize_installer(
|
||||
app_config: InvokeAIAppConfig, event_bus: Optional[EventServiceBase] = None
|
||||
) -> ModelInstallServiceBase:
|
||||
"""Return an initialized ModelInstallService object."""
|
||||
record_store = initialize_record_store(app_config)
|
||||
download_queue = DownloadQueueService()
|
||||
installer = ModelInstallService(
|
||||
app_config=app_config,
|
||||
record_store=record_store,
|
||||
download_queue=download_queue,
|
||||
event_bus=event_bus,
|
||||
)
|
||||
download_queue.start()
|
||||
installer.start()
|
||||
return installer
|
||||
|
||||
|
||||
class UnifiedModelInfo(BaseModel):
|
||||
"""Catchall class for information in INITIAL_MODELS2.yaml."""
|
||||
|
||||
name: Optional[str] = None
|
||||
base: Optional[BaseModelType] = None
|
||||
type: Optional[ModelType] = None
|
||||
source: Optional[str] = None
|
||||
subfolder: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
recommended: bool = False
|
||||
installed: bool = False
|
||||
default: bool = False
|
||||
requires: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstallSelections:
|
||||
"""Lists of models to install and remove."""
|
||||
|
||||
install_models: List[UnifiedModelInfo] = Field(default_factory=list)
|
||||
remove_models: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TqdmEventService(EventServiceBase):
|
||||
"""An event service to track downloads."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Create a new TqdmEventService object."""
|
||||
super().__init__()
|
||||
self._bars: Dict[str, tqdm] = {}
|
||||
self._last: Dict[str, int] = {}
|
||||
self._logger = InvokeAILogger.get_logger(__name__)
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event by appending it to self.events."""
|
||||
data = payload["data"]
|
||||
source = data["source"]
|
||||
if payload["event"] == "model_install_downloading":
|
||||
dest = data["local_path"]
|
||||
total_bytes = data["total_bytes"]
|
||||
bytes = data["bytes"]
|
||||
if dest not in self._bars:
|
||||
self._bars[dest] = tqdm(desc=Path(dest).name, initial=0, total=total_bytes, unit="iB", unit_scale=True)
|
||||
self._last[dest] = 0
|
||||
self._bars[dest].update(bytes - self._last[dest])
|
||||
self._last[dest] = bytes
|
||||
elif payload["event"] == "model_install_completed":
|
||||
self._logger.info(f"{source}: installed successfully.")
|
||||
elif payload["event"] == "model_install_error":
|
||||
self._logger.warning(f"{source}: installation failed with error {data['error']}")
|
||||
elif payload["event"] == "model_install_cancelled":
|
||||
self._logger.warning(f"{source}: installation cancelled")
|
||||
|
||||
|
||||
class InstallHelper(object):
|
||||
"""Capture information stored jointly in INITIAL_MODELS.yaml and the installed models db."""
|
||||
|
||||
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger):
|
||||
"""Create new InstallHelper object."""
|
||||
self._app_config = app_config
|
||||
self.all_models: Dict[str, UnifiedModelInfo] = {}
|
||||
|
||||
omega = omegaconf.OmegaConf.load(Path(configs.__path__[0]) / INITIAL_MODELS)
|
||||
assert isinstance(omega, omegaconf.dictconfig.DictConfig)
|
||||
|
||||
self._installer = initialize_installer(app_config, TqdmEventService())
|
||||
self._initial_models = omega
|
||||
self._installed_models: List[str] = []
|
||||
self._starter_models: List[str] = []
|
||||
self._default_model: Optional[str] = None
|
||||
self._logger = logger
|
||||
self._initialize_model_lists()
|
||||
|
||||
@property
|
||||
def installer(self) -> ModelInstallServiceBase:
|
||||
"""Return the installer object used internally."""
|
||||
return self._installer
|
||||
|
||||
def _initialize_model_lists(self) -> None:
|
||||
"""
|
||||
Initialize our model slots.
|
||||
|
||||
Set up the following:
|
||||
installed_models -- list of installed model keys
|
||||
starter_models -- list of starter model keys from INITIAL_MODELS
|
||||
all_models -- dict of key => UnifiedModelInfo
|
||||
default_model -- key to default model
|
||||
"""
|
||||
# previously-installed models
|
||||
for model in self._installer.record_store.all_models():
|
||||
info = UnifiedModelInfo.parse_obj(model.dict())
|
||||
info.installed = True
|
||||
model_key = f"{model.base.value}/{model.type.value}/{model.name}"
|
||||
self.all_models[model_key] = info
|
||||
self._installed_models.append(model_key)
|
||||
|
||||
for key in self._initial_models.keys():
|
||||
assert isinstance(key, str)
|
||||
if key in self.all_models:
|
||||
# we want to preserve the description
|
||||
description = self.all_models[key].description or self._initial_models[key].get("description")
|
||||
self.all_models[key].description = description
|
||||
else:
|
||||
base_model, model_type, model_name = key.split("/")
|
||||
info = UnifiedModelInfo(
|
||||
name=model_name,
|
||||
type=ModelType(model_type),
|
||||
base=BaseModelType(base_model),
|
||||
source=self._initial_models[key].source,
|
||||
description=self._initial_models[key].get("description"),
|
||||
recommended=self._initial_models[key].get("recommended", False),
|
||||
default=self._initial_models[key].get("default", False),
|
||||
subfolder=self._initial_models[key].get("subfolder"),
|
||||
requires=list(self._initial_models[key].get("requires", [])),
|
||||
)
|
||||
self.all_models[key] = info
|
||||
if not self.default_model():
|
||||
self._default_model = key
|
||||
elif self._initial_models[key].get("default", False):
|
||||
self._default_model = key
|
||||
self._starter_models.append(key)
|
||||
|
||||
# previously-installed models
|
||||
for model in self._installer.record_store.all_models():
|
||||
info = UnifiedModelInfo.parse_obj(model.dict())
|
||||
info.installed = True
|
||||
model_key = f"{model.base.value}/{model.type.value}/{model.name}"
|
||||
self.all_models[model_key] = info
|
||||
self._installed_models.append(model_key)
|
||||
|
||||
def recommended_models(self) -> List[UnifiedModelInfo]:
|
||||
"""List of the models recommended in INITIAL_MODELS.yaml."""
|
||||
return [self._to_model(x) for x in self._starter_models if self._to_model(x).recommended]
|
||||
|
||||
def installed_models(self) -> List[UnifiedModelInfo]:
|
||||
"""List of models already installed."""
|
||||
return [self._to_model(x) for x in self._installed_models]
|
||||
|
||||
def starter_models(self) -> List[UnifiedModelInfo]:
|
||||
"""List of starter models."""
|
||||
return [self._to_model(x) for x in self._starter_models]
|
||||
|
||||
def default_model(self) -> Optional[UnifiedModelInfo]:
|
||||
"""Return the default model."""
|
||||
return self._to_model(self._default_model) if self._default_model else None
|
||||
|
||||
def _to_model(self, key: str) -> UnifiedModelInfo:
|
||||
return self.all_models[key]
|
||||
|
||||
def _add_required_models(self, model_list: List[UnifiedModelInfo]) -> None:
|
||||
installed = {x.source for x in self.installed_models()}
|
||||
reverse_source = {x.source: x for x in self.all_models.values()}
|
||||
additional_models: List[UnifiedModelInfo] = []
|
||||
for model_info in model_list:
|
||||
for requirement in model_info.requires:
|
||||
if requirement not in installed and reverse_source.get(requirement):
|
||||
additional_models.append(reverse_source[requirement])
|
||||
model_list.extend(additional_models)
|
||||
|
||||
def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource:
|
||||
assert model_info.source
|
||||
model_path_id_or_url = model_info.source.strip("\"' ")
|
||||
model_path = Path(model_path_id_or_url)
|
||||
|
||||
if model_path.exists(): # local file on disk
|
||||
return LocalModelSource(path=model_path.absolute(), inplace=True)
|
||||
|
||||
# parsing huggingface repo ids
|
||||
# we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16"
|
||||
variants = "|".join([x.lower() for x in ModelRepoVariant.__members__])
|
||||
if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url):
|
||||
repo_id = match.group(1)
|
||||
repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None
|
||||
subfolder = Path(model_info.subfolder) if model_info.subfolder else None
|
||||
return HFModelSource(
|
||||
repo_id=repo_id,
|
||||
access_token=HfFolder.get_token(),
|
||||
subfolder=subfolder,
|
||||
variant=repo_variant,
|
||||
)
|
||||
if re.match(r"^(http|https):", model_path_id_or_url):
|
||||
return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
|
||||
raise ValueError(f"Unsupported model source: {model_path_id_or_url}")
|
||||
|
||||
def add_or_delete(self, selections: InstallSelections) -> None:
|
||||
"""Add or delete selected models."""
|
||||
installer = self._installer
|
||||
self._add_required_models(selections.install_models)
|
||||
for model in selections.install_models:
|
||||
source = self._make_install_source(model)
|
||||
config = (
|
||||
{
|
||||
"description": model.description,
|
||||
"name": model.name,
|
||||
}
|
||||
if model.name
|
||||
else None
|
||||
)
|
||||
|
||||
try:
|
||||
installer.import_model(
|
||||
source=source,
|
||||
config=config,
|
||||
)
|
||||
except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e:
|
||||
self._logger.warning(f"{source}: {e}")
|
||||
|
||||
for model_to_remove in selections.remove_models:
|
||||
parts = model_to_remove.split("/")
|
||||
if len(parts) == 1:
|
||||
base_model, model_type, model_name = (None, None, model_to_remove)
|
||||
else:
|
||||
base_model, model_type, model_name = parts
|
||||
matches = installer.record_store.search_by_attr(
|
||||
base_model=BaseModelType(base_model) if base_model else None,
|
||||
model_type=ModelType(model_type) if model_type else None,
|
||||
model_name=model_name,
|
||||
)
|
||||
if len(matches) > 1:
|
||||
self._logger.error(
|
||||
"{model_to_remove} is ambiguous. Please use model_base/model_type/model_name (e.g. sd-1/main/my_model) to disambiguate"
|
||||
)
|
||||
elif not matches:
|
||||
self._logger.error(f"{model_to_remove}: unknown model")
|
||||
else:
|
||||
for m in matches:
|
||||
self._logger.info(f"Deleting {m.type}:{m.name}")
|
||||
installer.delete(m.key)
|
||||
|
||||
installer.wait_for_installs()
|
||||
@@ -18,31 +18,30 @@ from argparse import Namespace
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
from typing import Any, get_args, get_type_hints
|
||||
from typing import Any, Optional, Set, Tuple, Type, get_args, get_type_hints
|
||||
from urllib import request
|
||||
|
||||
import npyscreen
|
||||
import omegaconf
|
||||
import psutil
|
||||
import torch
|
||||
import transformers
|
||||
import yaml
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers import AutoencoderKL, ModelMixin
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from huggingface_hub import HfFolder
|
||||
from huggingface_hub import login as hf_hub_login
|
||||
from omegaconf import OmegaConf
|
||||
from pydantic import ValidationError
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic.error_wrappers import ValidationError
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
|
||||
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
|
||||
from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, hf_download_from_pretrained
|
||||
from invokeai.backend.model_management.model_probe import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from invokeai.frontend.install.model_install import addModelsForm
|
||||
|
||||
# TO DO - Move all the frontend code into invokeai.frontend.install
|
||||
from invokeai.frontend.install.widgets import (
|
||||
@@ -61,7 +60,7 @@ warnings.filterwarnings("ignore")
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
def get_literal_fields(field) -> list[Any]:
|
||||
def get_literal_fields(field: str) -> Tuple[Any]:
|
||||
return get_args(get_type_hints(InvokeAIAppConfig).get(field))
|
||||
|
||||
|
||||
@@ -80,8 +79,7 @@ ATTENTION_SLICE_CHOICES = get_literal_fields("attention_slice_size")
|
||||
GENERATION_OPT_CHOICES = ["sequential_guidance", "force_tiled_decode", "lazy_offload"]
|
||||
GB = 1073741824 # GB in bytes
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
|
||||
|
||||
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0.0, 0.0)
|
||||
|
||||
MAX_VRAM /= GB
|
||||
MAX_RAM = psutil.virtual_memory().total / GB
|
||||
@@ -96,13 +94,15 @@ logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class DummyWidgetValue(Enum):
|
||||
"""Dummy widget values."""
|
||||
|
||||
zero = 0
|
||||
true = True
|
||||
false = False
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
def postscript(errors: None):
|
||||
def postscript(errors: Set[str]) -> None:
|
||||
if not any(errors):
|
||||
message = f"""
|
||||
** INVOKEAI INSTALLATION SUCCESSFUL **
|
||||
@@ -143,7 +143,7 @@ def yes_or_no(prompt: str, default_yes=True):
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def HfLogin(access_token) -> str:
|
||||
def HfLogin(access_token) -> None:
|
||||
"""
|
||||
Helper for logging in to Huggingface
|
||||
The stdout capture is needed to hide the irrelevant "git credential helper" warning
|
||||
@@ -162,7 +162,7 @@ def HfLogin(access_token) -> str:
|
||||
|
||||
# -------------------------------------
|
||||
class ProgressBar:
|
||||
def __init__(self, model_name="file"):
|
||||
def __init__(self, model_name: str = "file"):
|
||||
self.pbar = None
|
||||
self.name = model_name
|
||||
|
||||
@@ -179,6 +179,22 @@ class ProgressBar:
|
||||
self.pbar.update(block_size)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_from_pretrained(model_class: Type[ModelMixin], model_name: str, destination: Path, **kwargs: Any):
|
||||
filter = lambda x: "fp16 is not a valid" not in x.getMessage() # noqa E731
|
||||
logger.addFilter(filter)
|
||||
try:
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
model.save_pretrained(destination, safe_serialization=True)
|
||||
finally:
|
||||
logger.removeFilter(filter)
|
||||
return destination
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
|
||||
try:
|
||||
@@ -249,6 +265,7 @@ def download_conversion_models():
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
# TO DO: use the download queue here.
|
||||
def download_realesrgan():
|
||||
logger.info("Installing ESRGAN Upscaling models...")
|
||||
URLs = [
|
||||
@@ -288,18 +305,19 @@ def download_lama():
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_support_models():
|
||||
def download_support_models() -> None:
|
||||
download_realesrgan()
|
||||
download_lama()
|
||||
download_conversion_models()
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: str = None) -> str:
|
||||
def get_root(root: Optional[str] = None) -> str:
|
||||
if root:
|
||||
return root
|
||||
elif os.environ.get("INVOKEAI_ROOT"):
|
||||
return os.environ.get("INVOKEAI_ROOT")
|
||||
elif root := os.environ.get("INVOKEAI_ROOT"):
|
||||
assert root is not None
|
||||
return root
|
||||
else:
|
||||
return str(config.root_path)
|
||||
|
||||
@@ -455,6 +473,25 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
max_width=110,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Model disk conversion cache size (GB). This is used to cache safetensors files that need to be converted to diffusers..",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.disk = self.add_widget_intelligent(
|
||||
npyscreen.Slider,
|
||||
value=clip(old_opts.convert_cache, range=(0, 100), step=0.5),
|
||||
out_of=100,
|
||||
lowest=0.0,
|
||||
step=0.5,
|
||||
relx=8,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
|
||||
@@ -495,6 +532,14 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
)
|
||||
else:
|
||||
self.vram = DummyWidgetValue.zero
|
||||
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Location of the database used to store model path and configuration information:",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.outdir = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
@@ -506,19 +551,21 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=40,
|
||||
max_height=3,
|
||||
max_width=127,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.autoimport_dirs = {}
|
||||
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||
value=str(config.root_path / config.autoimport_dir),
|
||||
name="Optional folder to scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||
value=str(config.root_path / config.autoimport_dir) if config.autoimport_dir else "",
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
max_height=3,
|
||||
max_width=127,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
@@ -555,6 +602,10 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
self.attention_slice_label.hidden = not show
|
||||
self.attention_slice_size.hidden = not show
|
||||
|
||||
def show_hide_model_conf_override(self, value):
|
||||
self.model_conf_override.hidden = value
|
||||
self.model_conf_override.display()
|
||||
|
||||
def on_ok(self):
|
||||
options = self.marshall_arguments()
|
||||
if self.validate_field_values(options):
|
||||
@@ -584,18 +635,21 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
else:
|
||||
return True
|
||||
|
||||
def marshall_arguments(self):
|
||||
def marshall_arguments(self) -> Namespace:
|
||||
new_opts = Namespace()
|
||||
|
||||
for attr in [
|
||||
"ram",
|
||||
"vram",
|
||||
"convert_cache",
|
||||
"outdir",
|
||||
]:
|
||||
if hasattr(self, attr):
|
||||
setattr(new_opts, attr, getattr(self, attr).value)
|
||||
|
||||
for attr in self.autoimport_dirs:
|
||||
if not self.autoimport_dirs[attr].value:
|
||||
continue
|
||||
directory = Path(self.autoimport_dirs[attr].value)
|
||||
if directory.is_relative_to(config.root_path):
|
||||
directory = directory.relative_to(config.root_path)
|
||||
@@ -615,13 +669,14 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
|
||||
|
||||
class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self, program_opts: Namespace, invokeai_opts: Namespace):
|
||||
def __init__(self, program_opts: Namespace, invokeai_opts: InvokeAIAppConfig, install_helper: InstallHelper):
|
||||
super().__init__()
|
||||
self.program_opts = program_opts
|
||||
self.invokeai_opts = invokeai_opts
|
||||
self.user_cancelled = False
|
||||
self.autoload_pending = True
|
||||
self.install_selections = default_user_selections(program_opts)
|
||||
self.install_helper = install_helper
|
||||
self.install_selections = default_user_selections(program_opts, install_helper)
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
@@ -640,16 +695,10 @@ class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
cycle_widgets=False,
|
||||
)
|
||||
|
||||
def new_opts(self):
|
||||
def new_opts(self) -> Namespace:
|
||||
return self.options.marshall_arguments()
|
||||
|
||||
|
||||
def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Namespace:
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||
editApp.run()
|
||||
return editApp.new_opts()
|
||||
|
||||
|
||||
def default_ramcache() -> float:
|
||||
"""Run a heuristic for the default RAM cache based on installed RAM."""
|
||||
|
||||
@@ -660,27 +709,18 @@ def default_ramcache() -> float:
|
||||
) # 2.1 is just large enough for sd 1.5 ;-)
|
||||
|
||||
|
||||
def default_startup_options(init_file: Path) -> Namespace:
|
||||
def default_startup_options(init_file: Path) -> InvokeAIAppConfig:
|
||||
opts = InvokeAIAppConfig.get_config()
|
||||
opts.ram = opts.ram or default_ramcache()
|
||||
opts.ram = default_ramcache()
|
||||
return opts
|
||||
|
||||
|
||||
def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
||||
try:
|
||||
installer = ModelInstall(config)
|
||||
except omegaconf.errors.ConfigKeyError:
|
||||
logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing")
|
||||
initialize_rootdir(config.root_path, True)
|
||||
installer = ModelInstall(config)
|
||||
|
||||
models = installer.all_models()
|
||||
def default_user_selections(program_opts: Namespace, install_helper: InstallHelper) -> InstallSelections:
|
||||
default_model = install_helper.default_model()
|
||||
assert default_model is not None
|
||||
default_models = [default_model] if program_opts.default_only else install_helper.recommended_models()
|
||||
return InstallSelections(
|
||||
install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
|
||||
if program_opts.default_only
|
||||
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
|
||||
if program_opts.yes_to_all
|
||||
else [],
|
||||
install_models=default_models if program_opts.yes_to_all else [],
|
||||
)
|
||||
|
||||
|
||||
@@ -716,21 +756,10 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def maybe_create_models_yaml(root: Path):
|
||||
models_yaml = root / "configs" / "models.yaml"
|
||||
if models_yaml.exists():
|
||||
if OmegaConf.load(models_yaml).get("__metadata__"): # up to date
|
||||
return
|
||||
else:
|
||||
logger.info("Creating new models.yaml, original saved as models.yaml.orig")
|
||||
models_yaml.rename(models_yaml.parent / "models.yaml.orig")
|
||||
|
||||
with open(models_yaml, "w") as yaml_file:
|
||||
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
|
||||
def run_console_ui(
|
||||
program_opts: Namespace, initfile: Path, install_helper: InstallHelper
|
||||
) -> Tuple[Optional[Namespace], Optional[InstallSelections]]:
|
||||
invokeai_opts = default_startup_options(initfile)
|
||||
invokeai_opts.root = program_opts.root
|
||||
|
||||
@@ -739,22 +768,16 @@ def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace
|
||||
"Could not increase terminal size. Try running again with a larger window or smaller font size."
|
||||
)
|
||||
|
||||
# the install-models application spawns a subprocess to install
|
||||
# models, and will crash unless this is set before running.
|
||||
import torch
|
||||
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts, install_helper)
|
||||
editApp.run()
|
||||
if editApp.user_cancelled:
|
||||
return (None, None)
|
||||
else:
|
||||
return (editApp.new_opts, editApp.install_selections)
|
||||
return (editApp.new_opts(), editApp.install_selections)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def write_opts(opts: Namespace, init_file: Path):
|
||||
def write_opts(opts: InvokeAIAppConfig, init_file: Path) -> None:
|
||||
"""
|
||||
Update the invokeai.yaml file with values from current settings.
|
||||
"""
|
||||
@@ -762,7 +785,7 @@ def write_opts(opts: Namespace, init_file: Path):
|
||||
new_config = InvokeAIAppConfig.get_config()
|
||||
new_config.root = config.root
|
||||
|
||||
for key, value in opts.__dict__.items():
|
||||
for key, value in opts.model_dump().items():
|
||||
if hasattr(new_config, key):
|
||||
setattr(new_config, key, value)
|
||||
|
||||
@@ -779,7 +802,7 @@ def default_output_dir() -> Path:
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def write_default_options(program_opts: Namespace, initfile: Path):
|
||||
def write_default_options(program_opts: Namespace, initfile: Path) -> None:
|
||||
opt = default_startup_options(initfile)
|
||||
write_opts(opt, initfile)
|
||||
|
||||
@@ -789,16 +812,11 @@ def write_default_options(program_opts: Namespace, initfile: Path):
|
||||
# the legacy Args object in order to parse
|
||||
# the old init file and write out the new
|
||||
# yaml format.
|
||||
def migrate_init_file(legacy_format: Path):
|
||||
def migrate_init_file(legacy_format: Path) -> None:
|
||||
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
|
||||
new = InvokeAIAppConfig.get_config()
|
||||
|
||||
fields = [
|
||||
x
|
||||
for x, y in InvokeAIAppConfig.model_fields.items()
|
||||
if (y.json_schema_extra.get("category", None) if y.json_schema_extra else None) != "DEPRECATED"
|
||||
]
|
||||
for attr in fields:
|
||||
for attr in InvokeAIAppConfig.model_fields.keys():
|
||||
if hasattr(old, attr):
|
||||
try:
|
||||
setattr(new, attr, getattr(old, attr))
|
||||
@@ -819,7 +837,7 @@ def migrate_init_file(legacy_format: Path):
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def migrate_models(root: Path):
|
||||
def migrate_models(root: Path) -> None:
|
||||
from invokeai.backend.install.migrate_to_3 import do_migrate
|
||||
|
||||
do_migrate(root, root)
|
||||
@@ -838,7 +856,9 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool:
|
||||
):
|
||||
logger.info("** Migrating invokeai.init to invokeai.yaml")
|
||||
migrate_init_file(old_init_file)
|
||||
config.parse_args(argv=[], conf=OmegaConf.load(new_init_file))
|
||||
omegaconf = OmegaConf.load(new_init_file)
|
||||
assert isinstance(omegaconf, DictConfig)
|
||||
config.parse_args(argv=[], conf=omegaconf)
|
||||
|
||||
if old_hub.exists():
|
||||
migrate_models(config.root_path)
|
||||
@@ -908,6 +928,7 @@ def main():
|
||||
if opt.full_precision:
|
||||
invoke_args.extend(["--precision", "float32"])
|
||||
config.parse_args(invoke_args)
|
||||
config.precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||
logger = InvokeAILogger().get_logger(config=config)
|
||||
|
||||
errors = set()
|
||||
@@ -921,14 +942,18 @@ def main():
|
||||
# run this unconditionally in case new directories need to be added
|
||||
initialize_rootdir(config.root_path, opt.yes_to_all)
|
||||
|
||||
models_to_download = default_user_selections(opt)
|
||||
# this will initialize the models.yaml file if not present
|
||||
install_helper = InstallHelper(config, logger)
|
||||
|
||||
models_to_download = default_user_selections(opt, install_helper)
|
||||
new_init_file = config.root_path / "invokeai.yaml"
|
||||
|
||||
if opt.yes_to_all:
|
||||
write_default_options(opt, new_init_file)
|
||||
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
|
||||
|
||||
else:
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file, install_helper)
|
||||
if init_options:
|
||||
write_opts(init_options, new_init_file)
|
||||
else:
|
||||
@@ -943,10 +968,12 @@ def main():
|
||||
|
||||
if opt.skip_sd_weights:
|
||||
logger.warning("Skipping diffusion weights download per user request")
|
||||
|
||||
elif models_to_download:
|
||||
process_and_execute(opt, models_to_download)
|
||||
install_helper.add_or_delete(models_to_download)
|
||||
|
||||
postscript(errors=errors)
|
||||
|
||||
if not opt.yes_to_all:
|
||||
input("Press any key to continue...")
|
||||
except WindowTooSmallException as e:
|
||||
|
||||
@@ -1,591 +0,0 @@
|
||||
"""
|
||||
Migrate the models directory and models.yaml file from an existing
|
||||
InvokeAI 2.3 installation to 3.0.0.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import diffusers
|
||||
import transformers
|
||||
import yaml
|
||||
from diffusers import AutoencoderKL, StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import ModelManager
|
||||
from invokeai.backend.model_management.model_probe import BaseModelType, ModelProbe, ModelProbeInfo, ModelType
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
transformers.logging.set_verbosity_error()
|
||||
diffusers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
# holder for paths that we will migrate
|
||||
@dataclass
|
||||
class ModelPaths:
|
||||
models: Path
|
||||
embeddings: Path
|
||||
loras: Path
|
||||
controlnets: Path
|
||||
|
||||
|
||||
class MigrateTo3(object):
|
||||
def __init__(
|
||||
self,
|
||||
from_root: Path,
|
||||
to_models: Path,
|
||||
model_manager: ModelManager,
|
||||
src_paths: ModelPaths,
|
||||
):
|
||||
self.root_directory = from_root
|
||||
self.dest_models = to_models
|
||||
self.mgr = model_manager
|
||||
self.src_paths = src_paths
|
||||
|
||||
@classmethod
|
||||
def initialize_yaml(cls, yaml_file: Path):
|
||||
with open(yaml_file, "w") as file:
|
||||
file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
||||
|
||||
def create_directory_structure(self):
|
||||
"""
|
||||
Create the basic directory structure for the models folder.
|
||||
"""
|
||||
for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
for model_type in [
|
||||
ModelType.Main,
|
||||
ModelType.Vae,
|
||||
ModelType.Lora,
|
||||
ModelType.ControlNet,
|
||||
ModelType.TextualInversion,
|
||||
]:
|
||||
path = self.dest_models / model_base.value / model_type.value
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
path = self.dest_models / "core"
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def copy_file(src: Path, dest: Path):
|
||||
"""
|
||||
copy a single file with logging
|
||||
"""
|
||||
if dest.exists():
|
||||
logger.info(f"Skipping existing {str(dest)}")
|
||||
return
|
||||
logger.info(f"Copying {str(src)} to {str(dest)}")
|
||||
try:
|
||||
shutil.copy(src, dest)
|
||||
except Exception as e:
|
||||
logger.error(f"COPY FAILED: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def copy_dir(src: Path, dest: Path):
|
||||
"""
|
||||
Recursively copy a directory with logging
|
||||
"""
|
||||
if dest.exists():
|
||||
logger.info(f"Skipping existing {str(dest)}")
|
||||
return
|
||||
|
||||
logger.info(f"Copying {str(src)} to {str(dest)}")
|
||||
try:
|
||||
shutil.copytree(src, dest)
|
||||
except Exception as e:
|
||||
logger.error(f"COPY FAILED: {str(e)}")
|
||||
|
||||
def migrate_models(self, src_dir: Path):
|
||||
"""
|
||||
Recursively walk through src directory, probe anything
|
||||
that looks like a model, and copy the model into the
|
||||
appropriate location within the destination models directory.
|
||||
"""
|
||||
directories_scanned = set()
|
||||
for root, dirs, files in os.walk(src_dir, followlinks=True):
|
||||
for d in dirs:
|
||||
try:
|
||||
model = Path(root, d)
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
if not info:
|
||||
continue
|
||||
dest = self._model_probe_to_path(info) / model.name
|
||||
self.copy_dir(model, dest)
|
||||
directories_scanned.add(model)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
for f in files:
|
||||
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
|
||||
# let them be copied as part of a tree copy operation
|
||||
try:
|
||||
if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}:
|
||||
continue
|
||||
model = Path(root, f)
|
||||
if model.parent in directories_scanned:
|
||||
continue
|
||||
info = ModelProbe().heuristic_probe(model)
|
||||
if not info:
|
||||
continue
|
||||
dest = self._model_probe_to_path(info) / f
|
||||
self.copy_file(model, dest)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
|
||||
def migrate_support_models(self):
|
||||
"""
|
||||
Copy the clipseg, upscaler, and restoration models to their new
|
||||
locations.
|
||||
"""
|
||||
dest_directory = self.dest_models
|
||||
if (self.root_directory / "models/clipseg").exists():
|
||||
self.copy_dir(self.root_directory / "models/clipseg", dest_directory / "core/misc/clipseg")
|
||||
if (self.root_directory / "models/realesrgan").exists():
|
||||
self.copy_dir(self.root_directory / "models/realesrgan", dest_directory / "core/upscaling/realesrgan")
|
||||
for d in ["codeformer", "gfpgan"]:
|
||||
path = self.root_directory / "models" / d
|
||||
if path.exists():
|
||||
self.copy_dir(path, dest_directory / f"core/face_restoration/{d}")
|
||||
|
||||
def migrate_tuning_models(self):
|
||||
"""
|
||||
Migrate the embeddings, loras and controlnets directories to their new homes.
|
||||
"""
|
||||
for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]:
|
||||
if not src:
|
||||
continue
|
||||
if src.is_dir():
|
||||
logger.info(f"Scanning {src}")
|
||||
self.migrate_models(src)
|
||||
else:
|
||||
logger.info(f"{src} directory not found; skipping")
|
||||
continue
|
||||
|
||||
def migrate_conversion_models(self):
|
||||
"""
|
||||
Migrate all the models that are needed by the ckpt_to_diffusers conversion
|
||||
script.
|
||||
"""
|
||||
|
||||
dest_directory = self.dest_models
|
||||
kwargs = {
|
||||
"cache_dir": self.root_directory / "models/hub",
|
||||
# local_files_only = True
|
||||
}
|
||||
try:
|
||||
logger.info("Migrating core tokenizers and text encoders")
|
||||
target_dir = dest_directory / "core" / "convert"
|
||||
|
||||
self._migrate_pretrained(
|
||||
BertTokenizerFast, repo_id="bert-base-uncased", dest=target_dir / "bert-base-uncased", **kwargs
|
||||
)
|
||||
|
||||
# sd-1
|
||||
repo_id = "openai/clip-vit-large-patch14"
|
||||
self._migrate_pretrained(
|
||||
CLIPTokenizer, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", **kwargs
|
||||
)
|
||||
self._migrate_pretrained(
|
||||
CLIPTextModel, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", force=True, **kwargs
|
||||
)
|
||||
|
||||
# sd-2
|
||||
repo_id = "stabilityai/stable-diffusion-2"
|
||||
self._migrate_pretrained(
|
||||
CLIPTokenizer,
|
||||
repo_id=repo_id,
|
||||
dest=target_dir / "stable-diffusion-2-clip" / "tokenizer",
|
||||
**{"subfolder": "tokenizer", **kwargs},
|
||||
)
|
||||
self._migrate_pretrained(
|
||||
CLIPTextModel,
|
||||
repo_id=repo_id,
|
||||
dest=target_dir / "stable-diffusion-2-clip" / "text_encoder",
|
||||
**{"subfolder": "text_encoder", **kwargs},
|
||||
)
|
||||
|
||||
# VAE
|
||||
logger.info("Migrating stable diffusion VAE")
|
||||
self._migrate_pretrained(
|
||||
AutoencoderKL, repo_id="stabilityai/sd-vae-ft-mse", dest=target_dir / "sd-vae-ft-mse", **kwargs
|
||||
)
|
||||
|
||||
# safety checking
|
||||
logger.info("Migrating safety checker")
|
||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||
self._migrate_pretrained(
|
||||
AutoFeatureExtractor, repo_id=repo_id, dest=target_dir / "stable-diffusion-safety-checker", **kwargs
|
||||
)
|
||||
self._migrate_pretrained(
|
||||
StableDiffusionSafetyChecker,
|
||||
repo_id=repo_id,
|
||||
dest=target_dir / "stable-diffusion-safety-checker",
|
||||
**kwargs,
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
def _model_probe_to_path(self, info: ModelProbeInfo) -> Path:
|
||||
return Path(self.dest_models, info.base_type.value, info.model_type.value)
|
||||
|
||||
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force: bool = False, **kwargs):
|
||||
if dest.exists() and not force:
|
||||
logger.info(f"Skipping existing {dest}")
|
||||
return
|
||||
model = model_class.from_pretrained(repo_id, **kwargs)
|
||||
self._save_pretrained(model, dest, overwrite=force)
|
||||
|
||||
def _save_pretrained(self, model, dest: Path, overwrite: bool = False):
|
||||
model_name = dest.name
|
||||
if overwrite:
|
||||
model.save_pretrained(dest, safe_serialization=True)
|
||||
else:
|
||||
download_path = dest.with_name(f"{model_name}.downloading")
|
||||
model.save_pretrained(download_path, safe_serialization=True)
|
||||
download_path.replace(dest)
|
||||
|
||||
def _download_vae(self, repo_id: str, subfolder: str = None) -> Path:
|
||||
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder)
|
||||
info = ModelProbe().heuristic_probe(vae)
|
||||
_, model_name = repo_id.split("/")
|
||||
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
|
||||
vae.save_pretrained(dest, safe_serialization=True)
|
||||
return dest
|
||||
|
||||
def _vae_path(self, vae: Union[str, dict]) -> Path:
|
||||
"""
|
||||
Convert 2.3 VAE stanza to a straight path.
|
||||
"""
|
||||
vae_path = None
|
||||
|
||||
# First get a path
|
||||
if isinstance(vae, str):
|
||||
vae_path = vae
|
||||
|
||||
elif isinstance(vae, DictConfig):
|
||||
if p := vae.get("path"):
|
||||
vae_path = p
|
||||
elif repo_id := vae.get("repo_id"):
|
||||
if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded
|
||||
vae_path = "models/core/convert/sd-vae-ft-mse"
|
||||
return vae_path
|
||||
else:
|
||||
vae_path = self._download_vae(repo_id, vae.get("subfolder"))
|
||||
|
||||
assert vae_path is not None, "Couldn't find VAE for this model"
|
||||
|
||||
# if the VAE is in the old models directory, then we must move it into the new
|
||||
# one. VAEs outside of this directory can stay where they are.
|
||||
vae_path = Path(vae_path)
|
||||
if vae_path.is_relative_to(self.src_paths.models):
|
||||
info = ModelProbe().heuristic_probe(vae_path)
|
||||
dest = self._model_probe_to_path(info) / vae_path.name
|
||||
if not dest.exists():
|
||||
if vae_path.is_dir():
|
||||
self.copy_dir(vae_path, dest)
|
||||
else:
|
||||
self.copy_file(vae_path, dest)
|
||||
vae_path = dest
|
||||
|
||||
if vae_path.is_relative_to(self.dest_models):
|
||||
rel_path = vae_path.relative_to(self.dest_models)
|
||||
return Path("models", rel_path)
|
||||
else:
|
||||
return vae_path
|
||||
|
||||
def migrate_repo_id(self, repo_id: str, model_name: str = None, **extra_config):
|
||||
"""
|
||||
Migrate a locally-cached diffusers pipeline identified with a repo_id
|
||||
"""
|
||||
dest_dir = self.dest_models
|
||||
|
||||
cache = self.root_directory / "models/hub"
|
||||
kwargs = {
|
||||
"cache_dir": cache,
|
||||
"safety_checker": None,
|
||||
# local_files_only = True,
|
||||
}
|
||||
|
||||
owner, repo_name = repo_id.split("/")
|
||||
model_name = model_name or repo_name
|
||||
model = cache / "--".join(["models", owner, repo_name])
|
||||
|
||||
if len(list(model.glob("snapshots/**/model_index.json"))) == 0:
|
||||
return
|
||||
revisions = [x.name for x in model.glob("refs/*")]
|
||||
|
||||
# if an fp16 is available we use that
|
||||
revision = "fp16" if len(revisions) > 1 and "fp16" in revisions else revisions[0]
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, revision=revision, **kwargs)
|
||||
|
||||
info = ModelProbe().heuristic_probe(pipeline)
|
||||
if not info:
|
||||
return
|
||||
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
|
||||
return
|
||||
|
||||
dest = self._model_probe_to_path(info) / model_name
|
||||
self._save_pretrained(pipeline, dest)
|
||||
|
||||
rel_path = Path("models", dest.relative_to(dest_dir))
|
||||
self._add_model(model_name, info, rel_path, **extra_config)
|
||||
|
||||
def migrate_path(self, location: Path, model_name: str = None, **extra_config):
|
||||
"""
|
||||
Migrate a model referred to using 'weights' or 'path'
|
||||
"""
|
||||
|
||||
# handle relative paths
|
||||
dest_dir = self.dest_models
|
||||
location = self.root_directory / location
|
||||
model_name = model_name or location.stem
|
||||
|
||||
info = ModelProbe().heuristic_probe(location)
|
||||
if not info:
|
||||
return
|
||||
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
|
||||
return
|
||||
|
||||
# uh oh, weights is in the old models directory - move it into the new one
|
||||
if Path(location).is_relative_to(self.src_paths.models):
|
||||
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
|
||||
if location.is_dir():
|
||||
self.copy_dir(location, dest)
|
||||
else:
|
||||
self.copy_file(location, dest)
|
||||
location = Path("models", info.base_type.value, info.model_type.value, location.name)
|
||||
|
||||
self._add_model(model_name, info, location, **extra_config)
|
||||
|
||||
def _add_model(self, model_name: str, info: ModelProbeInfo, location: Path, **extra_config):
|
||||
if info.model_type != ModelType.Main:
|
||||
return
|
||||
|
||||
self.mgr.add_model(
|
||||
model_name=model_name,
|
||||
base_model=info.base_type,
|
||||
model_type=info.model_type,
|
||||
clobber=True,
|
||||
model_attributes={
|
||||
"path": str(location),
|
||||
"description": f"A {info.base_type.value} {info.model_type.value} model",
|
||||
"model_format": info.format,
|
||||
"variant": info.variant_type.value,
|
||||
**extra_config,
|
||||
},
|
||||
)
|
||||
|
||||
def migrate_defined_models(self):
|
||||
"""
|
||||
Migrate models defined in models.yaml
|
||||
"""
|
||||
# find any models referred to in old models.yaml
|
||||
conf = OmegaConf.load(self.root_directory / "configs/models.yaml")
|
||||
|
||||
for model_name, stanza in conf.items():
|
||||
try:
|
||||
passthru_args = {}
|
||||
|
||||
if vae := stanza.get("vae"):
|
||||
try:
|
||||
passthru_args["vae"] = str(self._vae_path(vae))
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
|
||||
logger.warning(str(e))
|
||||
|
||||
if config := stanza.get("config"):
|
||||
passthru_args["config"] = config
|
||||
|
||||
if description := stanza.get("description"):
|
||||
passthru_args["description"] = description
|
||||
|
||||
if repo_id := stanza.get("repo_id"):
|
||||
logger.info(f"Migrating diffusers model {model_name}")
|
||||
self.migrate_repo_id(repo_id, model_name, **passthru_args)
|
||||
|
||||
elif location := stanza.get("weights"):
|
||||
logger.info(f"Migrating checkpoint model {model_name}")
|
||||
self.migrate_path(Path(location), model_name, **passthru_args)
|
||||
|
||||
elif location := stanza.get("path"):
|
||||
logger.info(f"Migrating diffusers model {model_name}")
|
||||
self.migrate_path(Path(location), model_name, **passthru_args)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
|
||||
def migrate(self):
|
||||
self.create_directory_structure()
|
||||
# the configure script is doing this
|
||||
self.migrate_support_models()
|
||||
self.migrate_conversion_models()
|
||||
self.migrate_tuning_models()
|
||||
self.migrate_defined_models()
|
||||
|
||||
|
||||
def _parse_legacy_initfile(root: Path, initfile: Path) -> ModelPaths:
|
||||
"""
|
||||
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
||||
"""
|
||||
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
|
||||
parser.add_argument(
|
||||
"--embedding_directory",
|
||||
"--embedding_path",
|
||||
type=Path,
|
||||
dest="embedding_path",
|
||||
default=Path("embeddings"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_directory",
|
||||
dest="lora_path",
|
||||
type=Path,
|
||||
default=Path("loras"),
|
||||
)
|
||||
opt, _ = parser.parse_known_args([f"@{str(initfile)}"])
|
||||
return ModelPaths(
|
||||
models=root / "models",
|
||||
embeddings=root / str(opt.embedding_path).strip('"'),
|
||||
loras=root / str(opt.lora_path).strip('"'),
|
||||
controlnets=root / "controlnets",
|
||||
)
|
||||
|
||||
|
||||
def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths:
|
||||
"""
|
||||
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
||||
"""
|
||||
# Don't use the config object because it is unforgiving of version updates
|
||||
# Just use omegaconf directly
|
||||
opt = OmegaConf.load(initfile)
|
||||
paths = opt.InvokeAI.Paths
|
||||
models = paths.get("models_dir", "models")
|
||||
embeddings = paths.get("embedding_dir", "embeddings")
|
||||
loras = paths.get("lora_dir", "loras")
|
||||
controlnets = paths.get("controlnet_dir", "controlnets")
|
||||
return ModelPaths(
|
||||
models=root / models if models else None,
|
||||
embeddings=root / embeddings if embeddings else None,
|
||||
loras=root / loras if loras else None,
|
||||
controlnets=root / controlnets if controlnets else None,
|
||||
)
|
||||
|
||||
|
||||
def get_legacy_embeddings(root: Path) -> ModelPaths:
|
||||
path = root / "invokeai.init"
|
||||
if path.exists():
|
||||
return _parse_legacy_initfile(root, path)
|
||||
path = root / "invokeai.yaml"
|
||||
if path.exists():
|
||||
return _parse_legacy_yamlfile(root, path)
|
||||
|
||||
|
||||
def do_migrate(src_directory: Path, dest_directory: Path):
|
||||
"""
|
||||
Migrate models from src to dest InvokeAI root directories
|
||||
"""
|
||||
config_file = dest_directory / "configs" / "models.yaml.3"
|
||||
dest_models = dest_directory / "models.3"
|
||||
|
||||
version_3 = (dest_directory / "models" / "core").exists()
|
||||
|
||||
# Here we create the destination models.yaml file.
|
||||
# If we are writing into a version 3 directory and the
|
||||
# file already exists, then we write into a copy of it to
|
||||
# avoid deleting its previous customizations. Otherwise we
|
||||
# create a new empty one.
|
||||
if version_3: # write into the dest directory
|
||||
try:
|
||||
shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
|
||||
except Exception:
|
||||
MigrateTo3.initialize_yaml(config_file)
|
||||
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
|
||||
(dest_directory / "models").replace(dest_models)
|
||||
else:
|
||||
MigrateTo3.initialize_yaml(config_file)
|
||||
mgr = ModelManager(config_file)
|
||||
|
||||
paths = get_legacy_embeddings(src_directory)
|
||||
migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths)
|
||||
migrator.migrate()
|
||||
print("Migration successful.")
|
||||
|
||||
if not version_3:
|
||||
(dest_directory / "models").replace(src_directory / "models.orig")
|
||||
print(f"Original models directory moved to {dest_directory}/models.orig")
|
||||
|
||||
(dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig")
|
||||
print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig")
|
||||
|
||||
config_file.replace(config_file.with_suffix(""))
|
||||
dest_models.replace(dest_models.with_suffix(""))
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="invokeai-migrate3",
|
||||
description="""
|
||||
This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
|
||||
'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a
|
||||
|
||||
The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively.
|
||||
It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure
|
||||
script, which will perform a full upgrade in place.""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from-directory",
|
||||
dest="src_root",
|
||||
type=Path,
|
||||
required=True,
|
||||
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--to-directory",
|
||||
dest="dest_root",
|
||||
type=Path,
|
||||
required=True,
|
||||
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
src_root = args.src_root
|
||||
assert src_root.is_dir(), f"{src_root} is not a valid directory"
|
||||
assert (src_root / "models").is_dir(), f"{src_root} does not contain a 'models' subdirectory"
|
||||
assert (src_root / "models" / "hub").exists(), f"{src_root} does not contain a version 2.3 models directory"
|
||||
assert (src_root / "invokeai.init").exists() or (
|
||||
src_root / "invokeai.yaml"
|
||||
).exists(), f"{src_root} does not contain an InvokeAI init file."
|
||||
|
||||
dest_root = args.dest_root
|
||||
assert dest_root.is_dir(), f"{dest_root} is not a valid directory"
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args(["--root", str(dest_root)])
|
||||
|
||||
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
|
||||
dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists()
|
||||
if not dest_is_setup:
|
||||
from invokeai.backend.install.invokeai_configure import initialize_rootdir
|
||||
|
||||
initialize_rootdir(dest_root, True)
|
||||
|
||||
do_migrate(src_root, dest_root)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,631 +0,0 @@
|
||||
"""
|
||||
Utility (backend) functions used by model_install.py
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Callable, Dict, List, Optional, Set, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers import logging as dlogging
|
||||
from huggingface_hub import HfApi, HfFolder, hf_hub_url
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType
|
||||
from invokeai.backend.model_management.model_probe import ModelProbe, ModelProbeInfo, SchedulerPredictionType
|
||||
from invokeai.backend.util import download_with_resume
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
|
||||
from ..util.logging import InvokeAILogger
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# --------------------------globals-----------------------
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
logger = InvokeAILogger.get_logger(name="InvokeAI")
|
||||
|
||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||
|
||||
Config_preamble = """
|
||||
# This file describes the alternative machine learning models
|
||||
# available to InvokeAI script.
|
||||
#
|
||||
# To add a new model, follow the examples below. Each
|
||||
# model requires a model config file, a weights file,
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
"""
|
||||
|
||||
LEGACY_CONFIGS = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
ModelVariantType.Normal: {
|
||||
SchedulerPredictionType.Epsilon: "v1-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v1-inference-v.yaml",
|
||||
},
|
||||
ModelVariantType.Inpaint: {
|
||||
SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml",
|
||||
},
|
||||
},
|
||||
BaseModelType.StableDiffusion2: {
|
||||
ModelVariantType.Normal: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
|
||||
},
|
||||
ModelVariantType.Inpaint: {
|
||||
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
|
||||
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
|
||||
},
|
||||
},
|
||||
BaseModelType.StableDiffusionXL: {
|
||||
ModelVariantType.Normal: "sd_xl_base.yaml",
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class InstallSelections:
|
||||
install_models: List[str] = field(default_factory=list)
|
||||
remove_models: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelLoadInfo:
|
||||
name: str
|
||||
model_type: ModelType
|
||||
base_type: BaseModelType
|
||||
path: Optional[Path] = None
|
||||
repo_id: Optional[str] = None
|
||||
subfolder: Optional[str] = None
|
||||
description: str = ""
|
||||
installed: bool = False
|
||||
recommended: bool = False
|
||||
default: bool = False
|
||||
requires: Optional[List[str]] = field(default_factory=list)
|
||||
|
||||
|
||||
class ModelInstall(object):
|
||||
def __init__(
|
||||
self,
|
||||
config: InvokeAIAppConfig,
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
model_manager: Optional[ModelManager] = None,
|
||||
access_token: Optional[str] = None,
|
||||
):
|
||||
self.config = config
|
||||
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
||||
self.datasets = OmegaConf.load(Dataset_path)
|
||||
self.prediction_helper = prediction_type_helper
|
||||
self.access_token = access_token or HfFolder.get_token()
|
||||
self.reverse_paths = self._reverse_paths(self.datasets)
|
||||
|
||||
def all_models(self) -> Dict[str, ModelLoadInfo]:
|
||||
"""
|
||||
Return dict of model_key=>ModelLoadInfo objects.
|
||||
This method consolidates and simplifies the entries in both
|
||||
models.yaml and INITIAL_MODELS.yaml so that they can
|
||||
be treated uniformly. It also sorts the models alphabetically
|
||||
by their name, to improve the display somewhat.
|
||||
"""
|
||||
model_dict = {}
|
||||
|
||||
# first populate with the entries in INITIAL_MODELS.yaml
|
||||
for key, value in self.datasets.items():
|
||||
name, base, model_type = ModelManager.parse_key(key)
|
||||
value["name"] = name
|
||||
value["base_type"] = base
|
||||
value["model_type"] = model_type
|
||||
model_info = ModelLoadInfo(**value)
|
||||
if model_info.subfolder and model_info.repo_id:
|
||||
model_info.repo_id += f":{model_info.subfolder}"
|
||||
model_dict[key] = model_info
|
||||
|
||||
# supplement with entries in models.yaml
|
||||
installed_models = list(self.mgr.list_models())
|
||||
|
||||
for md in installed_models:
|
||||
base = md["base_model"]
|
||||
model_type = md["model_type"]
|
||||
name = md["model_name"]
|
||||
key = ModelManager.create_key(name, base, model_type)
|
||||
if key in model_dict:
|
||||
model_dict[key].installed = True
|
||||
else:
|
||||
model_dict[key] = ModelLoadInfo(
|
||||
name=name,
|
||||
base_type=base,
|
||||
model_type=model_type,
|
||||
path=value.get("path"),
|
||||
installed=True,
|
||||
)
|
||||
return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())}
|
||||
|
||||
def _is_autoloaded(self, model_info: dict) -> bool:
|
||||
path = model_info.get("path")
|
||||
if not path:
|
||||
return False
|
||||
for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]:
|
||||
if autodir_path := getattr(self.config, autodir):
|
||||
autodir_path = self.config.root_path / autodir_path
|
||||
if Path(path).is_relative_to(autodir_path):
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_models(self, model_type):
|
||||
installed = self.mgr.list_models(model_type=model_type)
|
||||
print()
|
||||
print(f"Installed models of type `{model_type}`:")
|
||||
print(f"{'Model Key':50} Model Path")
|
||||
for i in installed:
|
||||
print(f"{'/'.join([i['base_model'],i['model_type'],i['model_name']]):50} {i['path']}")
|
||||
print()
|
||||
|
||||
# logic here a little reversed to maintain backward compatibility
|
||||
def starter_models(self, all_models: bool = False) -> Set[str]:
|
||||
models = set()
|
||||
for key, _value in self.datasets.items():
|
||||
name, base, model_type = ModelManager.parse_key(key)
|
||||
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
|
||||
models.add(key)
|
||||
return models
|
||||
|
||||
def recommended_models(self) -> Set[str]:
|
||||
starters = self.starter_models(all_models=True)
|
||||
return {x for x in starters if self.datasets[x].get("recommended", False)}
|
||||
|
||||
def default_model(self) -> str:
|
||||
starters = self.starter_models()
|
||||
defaults = [x for x in starters if self.datasets[x].get("default", False)]
|
||||
return defaults[0]
|
||||
|
||||
def install(self, selections: InstallSelections):
|
||||
verbosity = dlogging.get_verbosity() # quench NSFW nags
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
job = 1
|
||||
jobs = len(selections.remove_models) + len(selections.install_models)
|
||||
|
||||
# remove requested models
|
||||
for key in selections.remove_models:
|
||||
name, base, mtype = self.mgr.parse_key(key)
|
||||
logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]")
|
||||
try:
|
||||
self.mgr.del_model(name, base, mtype)
|
||||
except FileNotFoundError as e:
|
||||
logger.warning(e)
|
||||
job += 1
|
||||
|
||||
# add requested models
|
||||
self._remove_installed(selections.install_models)
|
||||
self._add_required_models(selections.install_models)
|
||||
for path in selections.install_models:
|
||||
logger.info(f"Installing {path} [{job}/{jobs}]")
|
||||
try:
|
||||
self.heuristic_import(path)
|
||||
except (ValueError, KeyError) as e:
|
||||
logger.error(str(e))
|
||||
job += 1
|
||||
|
||||
dlogging.set_verbosity(verbosity)
|
||||
self.mgr.commit()
|
||||
|
||||
def heuristic_import(
|
||||
self,
|
||||
model_path_id_or_url: Union[str, Path],
|
||||
models_installed: Set[Path] = None,
|
||||
) -> Dict[str, AddModelResult]:
|
||||
"""
|
||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||
:param models_installed: Set of installed models, used for recursive invocation
|
||||
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
||||
"""
|
||||
|
||||
if not models_installed:
|
||||
models_installed = {}
|
||||
|
||||
model_path_id_or_url = str(model_path_id_or_url).strip("\"' ")
|
||||
|
||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||
self.current_id = model_path_id_or_url
|
||||
path = Path(model_path_id_or_url)
|
||||
|
||||
# fix relative paths
|
||||
if path.exists() and not path.is_absolute():
|
||||
path = path.absolute() # make relative to current WD
|
||||
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.update({str(path): self._install_path(path)})
|
||||
|
||||
# folders style or similar
|
||||
elif path.is_dir() and any(
|
||||
(path / x).exists()
|
||||
for x in {
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"pytorch_lora_weights.safetensors",
|
||||
}
|
||||
):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
|
||||
|
||||
# recursive scan
|
||||
elif path.is_dir():
|
||||
for child in path.iterdir():
|
||||
self.heuristic_import(child, models_installed=models_installed)
|
||||
|
||||
# huggingface repo
|
||||
elif len(str(model_path_id_or_url).split("/")) == 2:
|
||||
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
||||
|
||||
# a URL
|
||||
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
||||
|
||||
else:
|
||||
raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping")
|
||||
|
||||
return models_installed
|
||||
|
||||
def _remove_installed(self, model_list: List[str]):
|
||||
all_models = self.all_models()
|
||||
models_to_remove = []
|
||||
|
||||
for path in model_list:
|
||||
key = self.reverse_paths.get(path)
|
||||
if key and all_models[key].installed:
|
||||
models_to_remove.append(path)
|
||||
|
||||
for path in models_to_remove:
|
||||
logger.warning(f"{path} already installed. Skipping")
|
||||
model_list.remove(path)
|
||||
|
||||
def _add_required_models(self, model_list: List[str]):
|
||||
additional_models = []
|
||||
all_models = self.all_models()
|
||||
for path in model_list:
|
||||
if not (key := self.reverse_paths.get(path)):
|
||||
continue
|
||||
for requirement in all_models[key].requires:
|
||||
requirement_key = self.reverse_paths.get(requirement)
|
||||
if not all_models[requirement_key].installed:
|
||||
additional_models.append(requirement)
|
||||
model_list.extend(additional_models)
|
||||
|
||||
# install a model from a local path. The optional info parameter is there to prevent
|
||||
# the model from being probed twice in the event that it has already been probed.
|
||||
def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult:
|
||||
info = info or ModelProbe().heuristic_probe(path, self.prediction_helper)
|
||||
if not info:
|
||||
logger.warning(f"Unable to parse format of {path}")
|
||||
return None
|
||||
model_name = path.stem if path.is_file() else path.name
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||
attributes = self._make_attributes(path, info)
|
||||
return self.mgr.add_model(
|
||||
model_name=model_name,
|
||||
base_model=info.base_type,
|
||||
model_type=info.model_type,
|
||||
model_attributes=attributes,
|
||||
)
|
||||
|
||||
def _install_url(self, url: str) -> AddModelResult:
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
location = download_with_resume(url, Path(staging))
|
||||
if not location:
|
||||
logger.error(f"Unable to download {url}. Skipping.")
|
||||
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
||||
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
models_path = shutil.move(location, dest)
|
||||
|
||||
# staged version will be garbage-collected at this time
|
||||
return self._install_path(Path(models_path), info)
|
||||
|
||||
def _install_repo(self, repo_id: str) -> AddModelResult:
|
||||
# hack to recover models stored in subfolders --
|
||||
# Required to get the "v2" model of monster-labs/control_v1p_sd15_qrcode_monster
|
||||
subfolder = None
|
||||
if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id):
|
||||
repo_id = match.group(1)
|
||||
subfolder = match.group(2)
|
||||
|
||||
hinfo = HfApi().model_info(repo_id)
|
||||
|
||||
# we try to figure out how to download this most economically
|
||||
# list all the files in the repo
|
||||
files = [x.rfilename for x in hinfo.siblings]
|
||||
if subfolder:
|
||||
files = [x for x in files if x.startswith(f"{subfolder}/")]
|
||||
prefix = f"{subfolder}/" if subfolder else ""
|
||||
|
||||
location = None
|
||||
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
staging = Path(staging)
|
||||
if f"{prefix}model_index.json" in files:
|
||||
location = self._download_hf_pipeline(repo_id, staging, subfolder=subfolder) # pipeline
|
||||
elif f"{prefix}unet/model.onnx" in files:
|
||||
location = self._download_hf_model(repo_id, files, staging)
|
||||
else:
|
||||
for suffix in ["safetensors", "bin"]:
|
||||
if f"{prefix}pytorch_lora_weights.{suffix}" in files:
|
||||
location = self._download_hf_model(
|
||||
repo_id, [f"pytorch_lora_weights.{suffix}"], staging, subfolder=subfolder
|
||||
) # LoRA
|
||||
break
|
||||
elif (
|
||||
self.config.precision == "float16" and f"{prefix}diffusion_pytorch_model.fp16.{suffix}" in files
|
||||
): # vae, controlnet or some other standalone
|
||||
files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
||||
break
|
||||
elif f"{prefix}diffusion_pytorch_model.{suffix}" in files:
|
||||
files = ["config.json", f"diffusion_pytorch_model.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
||||
break
|
||||
elif f"{prefix}learned_embeds.{suffix}" in files:
|
||||
location = self._download_hf_model(
|
||||
repo_id, [f"learned_embeds.{suffix}"], staging, subfolder=subfolder
|
||||
)
|
||||
break
|
||||
elif (
|
||||
f"{prefix}image_encoder.txt" in files and f"{prefix}ip_adapter.{suffix}" in files
|
||||
): # IP-Adapter
|
||||
files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
||||
break
|
||||
elif f"{prefix}model.{suffix}" in files and f"{prefix}config.json" in files:
|
||||
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
|
||||
# by InvokeAI for use with IP-Adapters.
|
||||
files = ["config.json", f"model.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder)
|
||||
break
|
||||
if not location:
|
||||
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
|
||||
return {}
|
||||
|
||||
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
||||
if not info:
|
||||
logger.warning(f"Could not probe {location}. Skipping install.")
|
||||
return {}
|
||||
dest = (
|
||||
self.config.models_path
|
||||
/ info.base_type.value
|
||||
/ info.model_type.value
|
||||
/ self._get_model_name(repo_id, location)
|
||||
)
|
||||
if dest.exists():
|
||||
shutil.rmtree(dest)
|
||||
shutil.copytree(location, dest)
|
||||
return self._install_path(dest, info)
|
||||
|
||||
def _get_model_name(self, path_name: str, location: Path) -> str:
|
||||
"""
|
||||
Calculate a name for the model - primitive implementation.
|
||||
"""
|
||||
if key := self.reverse_paths.get(path_name):
|
||||
(name, base, mtype) = ModelManager.parse_key(key)
|
||||
return name
|
||||
elif location.is_dir():
|
||||
return location.name
|
||||
else:
|
||||
return location.stem
|
||||
|
||||
def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict:
|
||||
model_name = path.name if path.is_dir() else path.stem
|
||||
description = f"{info.base_type.value} {info.model_type.value} model {model_name}"
|
||||
if key := self.reverse_paths.get(self.current_id):
|
||||
if key in self.datasets:
|
||||
description = self.datasets[key].get("description") or description
|
||||
|
||||
rel_path = self.relative_to_root(path, self.config.models_path)
|
||||
|
||||
attributes = {
|
||||
"path": str(rel_path),
|
||||
"description": str(description),
|
||||
"model_format": info.format,
|
||||
}
|
||||
legacy_conf = None
|
||||
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
|
||||
attributes.update(
|
||||
{
|
||||
"variant": info.variant_type,
|
||||
}
|
||||
)
|
||||
if info.format == "checkpoint":
|
||||
try:
|
||||
possible_conf = path.with_suffix(".yaml")
|
||||
if possible_conf.exists():
|
||||
legacy_conf = str(self.relative_to_root(possible_conf))
|
||||
elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
||||
legacy_conf = Path(
|
||||
self.config.legacy_conf_dir,
|
||||
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
|
||||
)
|
||||
else:
|
||||
legacy_conf = Path(
|
||||
self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]
|
||||
)
|
||||
except KeyError:
|
||||
legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess
|
||||
|
||||
if info.model_type == ModelType.ControlNet and info.format == "checkpoint":
|
||||
possible_conf = path.with_suffix(".yaml")
|
||||
if possible_conf.exists():
|
||||
legacy_conf = str(self.relative_to_root(possible_conf))
|
||||
else:
|
||||
legacy_conf = Path(
|
||||
self.config.root_path,
|
||||
"configs/controlnet",
|
||||
("cldm_v15.yaml" if info.base_type == BaseModelType("sd-1") else "cldm_v21.yaml"),
|
||||
)
|
||||
|
||||
if legacy_conf:
|
||||
attributes.update({"config": str(legacy_conf)})
|
||||
return attributes
|
||||
|
||||
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
|
||||
root = root or self.config.root_path
|
||||
if path.is_relative_to(root):
|
||||
return path.relative_to(root)
|
||||
else:
|
||||
return path
|
||||
|
||||
def _download_hf_pipeline(self, repo_id: str, staging: Path, subfolder: str = None) -> Path:
|
||||
"""
|
||||
Retrieve a StableDiffusion model from cache or remote and then
|
||||
does a save_pretrained() to the indicated staging area.
|
||||
"""
|
||||
_, name = repo_id.split("/")
|
||||
precision = torch_dtype(choose_torch_device())
|
||||
variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]
|
||||
|
||||
model = None
|
||||
for variant in variants:
|
||||
try:
|
||||
model = DiffusionPipeline.from_pretrained(
|
||||
repo_id,
|
||||
variant=variant,
|
||||
torch_dtype=precision,
|
||||
safety_checker=None,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors
|
||||
if "fp16" not in str(e):
|
||||
print(e)
|
||||
|
||||
if model:
|
||||
break
|
||||
|
||||
if not model:
|
||||
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
|
||||
return None
|
||||
model.save_pretrained(staging / name, safe_serialization=True)
|
||||
return staging / name
|
||||
|
||||
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path:
|
||||
_, name = repo_id.split("/")
|
||||
location = staging / name
|
||||
paths = []
|
||||
for filename in files:
|
||||
filePath = Path(filename)
|
||||
p = hf_download_with_resume(
|
||||
repo_id,
|
||||
model_dir=location / filePath.parent,
|
||||
model_name=filePath.name,
|
||||
access_token=self.access_token,
|
||||
subfolder=filePath.parent / subfolder if subfolder else filePath.parent,
|
||||
)
|
||||
if p:
|
||||
paths.append(p)
|
||||
else:
|
||||
logger.warning(f"Could not download {filename} from {repo_id}.")
|
||||
|
||||
return location if len(paths) > 0 else None
|
||||
|
||||
@classmethod
|
||||
def _reverse_paths(cls, datasets) -> dict:
|
||||
"""
|
||||
Reverse mapping from repo_id/path to destination name.
|
||||
"""
|
||||
return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()}
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def yes_or_no(prompt: str, default_yes=True):
|
||||
default = "y" if default_yes else "n"
|
||||
response = input(f"{prompt} [{default}] ") or default
|
||||
if default_yes:
|
||||
return response[0] not in ("n", "N")
|
||||
else:
|
||||
return response[0] in ("y", "Y")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
|
||||
logger = InvokeAILogger.get_logger("InvokeAI")
|
||||
logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
|
||||
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
model.save_pretrained(destination, safe_serialization=True)
|
||||
return destination
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_with_resume(
|
||||
repo_id: str,
|
||||
model_dir: str,
|
||||
model_name: str,
|
||||
model_dest: Path = None,
|
||||
access_token: str = None,
|
||||
subfolder: str = None,
|
||||
) -> Path:
|
||||
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
url = hf_hub_url(repo_id, model_name, subfolder=subfolder)
|
||||
|
||||
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
||||
open_mode = "wb"
|
||||
exist_size = 0
|
||||
|
||||
if os.path.exists(model_dest):
|
||||
exist_size = os.path.getsize(model_dest)
|
||||
header["Range"] = f"bytes={exist_size}-"
|
||||
open_mode = "ab"
|
||||
|
||||
resp = requests.get(url, headers=header, stream=True)
|
||||
total = int(resp.headers.get("content-length", 0))
|
||||
|
||||
if resp.status_code == 416: # "range not satisfiable", which means nothing to return
|
||||
logger.info(f"{model_name}: complete file found. Skipping.")
|
||||
return model_dest
|
||||
elif resp.status_code == 404:
|
||||
logger.warning("File not found")
|
||||
return None
|
||||
elif resp.status_code != 200:
|
||||
logger.warning(f"{model_name}: {resp.reason}")
|
||||
elif exist_size > 0:
|
||||
logger.info(f"{model_name}: partial file found. Resuming...")
|
||||
else:
|
||||
logger.info(f"{model_name}: Downloading...")
|
||||
|
||||
try:
|
||||
with (
|
||||
open(model_dest, open_mode) as file,
|
||||
tqdm(
|
||||
desc=model_name,
|
||||
initial=exist_size,
|
||||
total=total + exist_size,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
) as bar,
|
||||
):
|
||||
for data in resp.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
return None
|
||||
return model_dest
|
||||
@@ -8,8 +8,8 @@ from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
|
||||
from invokeai.backend.model_management.models.base import calc_model_size_by_data
|
||||
|
||||
from ..raw_model import RawModel
|
||||
from .resampler import Resampler
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ class MLPProjModel(torch.nn.Module):
|
||||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
class IPAdapter:
|
||||
class IPAdapter(RawModel):
|
||||
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
|
||||
|
||||
def __init__(
|
||||
@@ -124,6 +124,9 @@ class IPAdapter:
|
||||
self.attn_weights.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
def calc_size(self):
|
||||
# workaround for circular import
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
|
||||
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
|
||||
|
||||
def _init_image_proj_model(self, state_dict):
|
||||
|
||||
@@ -1,98 +1,17 @@
|
||||
# Copyright (c) 2024 The InvokeAI Development team
|
||||
"""LoRA model support."""
|
||||
|
||||
import bisect
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from typing_extensions import Self
|
||||
|
||||
from .base import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
ModelBase,
|
||||
ModelConfigBase,
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
classproperty,
|
||||
)
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
|
||||
|
||||
class LoRAModelFormat(str, Enum):
|
||||
LyCORIS = "lycoris"
|
||||
Diffusers = "diffusers"
|
||||
|
||||
|
||||
class LoRAModel(ModelBase):
|
||||
# model_size: int
|
||||
|
||||
class Config(ModelConfigBase):
|
||||
model_format: LoRAModelFormat # TODO:
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.Lora
|
||||
super().__init__(model_path, base_model, model_type)
|
||||
|
||||
self.model_size = os.path.getsize(self.model_path)
|
||||
|
||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in lora")
|
||||
return self.model_size
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
torch_dtype: Optional[torch.dtype],
|
||||
child_type: Optional[SubModelType] = None,
|
||||
):
|
||||
if child_type is not None:
|
||||
raise Exception("There is no child models in lora")
|
||||
|
||||
model = LoRAModelRaw.from_checkpoint(
|
||||
file_path=self.model_path,
|
||||
dtype=torch_dtype,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
|
||||
self.model_size = model.calc_size()
|
||||
return model
|
||||
|
||||
@classproperty
|
||||
def save_to_config(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def detect_format(cls, path: str):
|
||||
if not os.path.exists(path):
|
||||
raise ModelNotFoundException()
|
||||
|
||||
if os.path.isdir(path):
|
||||
for ext in ["safetensors", "bin"]:
|
||||
if os.path.exists(os.path.join(path, f"pytorch_lora_weights.{ext}")):
|
||||
return LoRAModelFormat.Diffusers
|
||||
|
||||
if os.path.isfile(path):
|
||||
if any(path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt"]):
|
||||
return LoRAModelFormat.LyCORIS
|
||||
|
||||
raise InvalidModelException(f"Not a valid model: {path}")
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
cls,
|
||||
model_path: str,
|
||||
output_path: str,
|
||||
config: ModelConfigBase,
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
if cls.detect_format(model_path) == LoRAModelFormat.Diffusers:
|
||||
for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder
|
||||
path = Path(model_path, f"pytorch_lora_weights.{ext}")
|
||||
if path.exists():
|
||||
return path
|
||||
else:
|
||||
return model_path
|
||||
from .raw_model import RawModel
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
@@ -108,7 +27,7 @@ class LoRALayerBase:
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
if "alpha" in values:
|
||||
self.alpha = values["alpha"].item()
|
||||
@@ -116,7 +35,7 @@ class LoRALayerBase:
|
||||
self.alpha = None
|
||||
|
||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||
self.bias = torch.sparse_coo_tensor(
|
||||
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
|
||||
values["bias_indices"],
|
||||
values["bias_values"],
|
||||
tuple(values["bias_size"]),
|
||||
@@ -128,7 +47,7 @@ class LoRALayerBase:
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def calc_size(self) -> int:
|
||||
@@ -142,7 +61,7 @@ class LoRALayerBase:
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
) -> None:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
@@ -156,20 +75,20 @@ class LoRALayer(LoRALayerBase):
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
if "lora_mid.weight" in values:
|
||||
self.mid = values["lora_mid.weight"]
|
||||
self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
|
||||
else:
|
||||
self.mid = None
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
@@ -190,7 +109,7 @@ class LoRALayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
@@ -208,11 +127,7 @@ class LoHALayer(LoRALayerBase):
|
||||
# t1: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
):
|
||||
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1_a = values["hada_w1_a"]
|
||||
@@ -221,20 +136,20 @@ class LoHALayer(LoRALayerBase):
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
|
||||
if "hada_t1" in values:
|
||||
self.t1 = values["hada_t1"]
|
||||
self.t1: Optional[torch.Tensor] = values["hada_t1"]
|
||||
else:
|
||||
self.t1 = None
|
||||
|
||||
if "hada_t2" in values:
|
||||
self.t2 = values["hada_t2"]
|
||||
self.t2: Optional[torch.Tensor] = values["hada_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
if self.t1 is None:
|
||||
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
else:
|
||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||
@@ -254,7 +169,7 @@ class LoHALayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
@@ -280,12 +195,12 @@ class LoKRLayer(LoRALayerBase):
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
if "lokr_w1" in values:
|
||||
self.w1 = values["lokr_w1"]
|
||||
self.w1: Optional[torch.Tensor] = values["lokr_w1"]
|
||||
self.w1_a = None
|
||||
self.w1_b = None
|
||||
else:
|
||||
@@ -294,7 +209,7 @@ class LoKRLayer(LoRALayerBase):
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
|
||||
if "lokr_w2" in values:
|
||||
self.w2 = values["lokr_w2"]
|
||||
self.w2: Optional[torch.Tensor] = values["lokr_w2"]
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
else:
|
||||
@@ -303,7 +218,7 @@ class LoKRLayer(LoRALayerBase):
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
|
||||
if "lokr_t2" in values:
|
||||
self.t2 = values["lokr_t2"]
|
||||
self.t2: Optional[torch.Tensor] = values["lokr_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
|
||||
@@ -314,14 +229,18 @@ class LoKRLayer(LoRALayerBase):
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
w1 = self.w1
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
w1: Optional[torch.Tensor] = self.w1
|
||||
if w1 is None:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
|
||||
w2 = self.w2
|
||||
if w2 is None:
|
||||
if self.t2 is None:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
w2 = self.w2_a @ self.w2_b
|
||||
else:
|
||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||
@@ -329,6 +248,8 @@ class LoKRLayer(LoRALayerBase):
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
assert w1 is not None
|
||||
assert w2 is not None
|
||||
weight = torch.kron(w1, w2)
|
||||
|
||||
return weight
|
||||
@@ -344,18 +265,22 @@ class LoKRLayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
if self.w1 is not None:
|
||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
|
||||
@@ -369,7 +294,7 @@ class FullLayer(LoRALayerBase):
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
@@ -382,7 +307,7 @@ class FullLayer(LoRALayerBase):
|
||||
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
@@ -394,7 +319,7 @@ class FullLayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
@@ -407,7 +332,7 @@ class IA3Layer(LoRALayerBase):
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
@@ -416,10 +341,11 @@ class IA3Layer(LoRALayerBase):
|
||||
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
assert orig_weight is not None
|
||||
return orig_weight * weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
@@ -439,28 +365,30 @@ class IA3Layer(LoRALayerBase):
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
|
||||
class LoRAModelRaw: # (torch.nn.Module):
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||
|
||||
|
||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
_name: str
|
||||
layers: Dict[str, LoRALayer]
|
||||
layers: Dict[str, AnyLoRALayer]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
layers: Dict[str, LoRALayer],
|
||||
layers: Dict[str, AnyLoRALayer],
|
||||
):
|
||||
self._name = name
|
||||
self.layers = layers
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
) -> None:
|
||||
# TODO: try revert if exception?
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
@@ -472,7 +400,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
return model_size
|
||||
|
||||
@classmethod
|
||||
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict):
|
||||
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
@@ -536,7 +464,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
):
|
||||
) -> Self:
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
@@ -544,16 +472,16 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
file_path = Path(file_path)
|
||||
|
||||
model = cls(
|
||||
name=file_path.stem, # TODO:
|
||||
name=file_path.stem,
|
||||
layers={},
|
||||
)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
sd = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(file_path, map_location="cpu")
|
||||
sd = torch.load(file_path, map_location="cpu")
|
||||
|
||||
state_dict = cls._group_state(state_dict)
|
||||
state_dict = cls._group_state(sd)
|
||||
|
||||
if base_model == BaseModelType.StableDiffusionXL:
|
||||
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
@@ -561,7 +489,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
for layer_key, values in state_dict.items():
|
||||
# lora and locon
|
||||
if "lora_down.weight" in values:
|
||||
layer = LoRALayer(layer_key, values)
|
||||
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
||||
|
||||
# loha
|
||||
elif "hada_w1_b" in values:
|
||||
@@ -592,8 +520,8 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _group_state(state_dict: dict):
|
||||
state_dict_groupped = {}
|
||||
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
@@ -606,7 +534,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
|
||||
# code from
|
||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
def make_sdxl_unet_conversion_map():
|
||||
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
|
||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
||||
unet_conversion_map_layer = []
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
# Model Cache
|
||||
|
||||
## `glibc` Memory Allocator Fragmentation
|
||||
|
||||
Python (and PyTorch) relies on the memory allocator from the C Standard Library (`libc`). On linux, with the GNU C Standard Library implementation (`glibc`), our memory access patterns have been observed to cause severe memory fragmentation. This fragmentation results in large amounts of memory that has been freed but can't be released back to the OS. Loading models from disk and moving them between CPU/CUDA seem to be the operations that contribute most to the fragmentation. This memory fragmentation issue can result in OOM crashes during frequent model switching, even if `max_cache_size` is set to a reasonable value (e.g. a OOM crash with `max_cache_size=16` on a system with 32GB of RAM).
|
||||
|
||||
This problem may also exist on other OSes, and other `libc` implementations. But, at the time of writing, it has only been investigated on linux with `glibc`.
|
||||
|
||||
To better understand how the `glibc` memory allocator works, see these references:
|
||||
- Basics: https://www.gnu.org/software/libc/manual/html_node/The-GNU-Allocator.html
|
||||
- Details: https://sourceware.org/glibc/wiki/MallocInternals
|
||||
|
||||
Note the differences between memory allocated as chunks in an arena vs. memory allocated with `mmap`. Under `glibc`'s default configuration, most model tensors get allocated as chunks in an arena making them vulnerable to the problem of fragmentation.
|
||||
|
||||
We can work around this memory fragmentation issue by setting the following env var:
|
||||
|
||||
```bash
|
||||
# Force blocks >1MB to be allocated with `mmap` so that they are released to the system immediately when they are freed.
|
||||
MALLOC_MMAP_THRESHOLD_=1048576
|
||||
```
|
||||
|
||||
See the following references for more information about the `malloc` tunable parameters:
|
||||
- https://www.gnu.org/software/libc/manual/html_node/Malloc-Tunable-Parameters.html
|
||||
- https://www.gnu.org/software/libc/manual/html_node/Memory-Allocation-Tunables.html
|
||||
- https://man7.org/linux/man-pages/man3/mallopt.3.html
|
||||
|
||||
The model cache emits debug logs that provide visibility into the state of the `libc` memory allocator. See the `LibcUtil` class for more info on how these `libc` malloc stats are collected.
|
||||
@@ -1,20 +0,0 @@
|
||||
# ruff: noqa: I001, F401
|
||||
"""
|
||||
Initialization file for invokeai.backend.model_management
|
||||
"""
|
||||
# This import must be first
|
||||
from .model_manager import AddModelResult, ModelInfo, ModelManager, SchedulerPredictionType
|
||||
from .lora import ModelPatcher, ONNXModelPatcher
|
||||
from .model_cache import ModelCache
|
||||
|
||||
from .models import (
|
||||
BaseModelType,
|
||||
DuplicateModelException,
|
||||
ModelNotFoundException,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
# This import must be last
|
||||
from .model_merge import MergeInterpolationMethod, ModelMerger
|
||||
@@ -1,31 +0,0 @@
|
||||
# Copyright (c) 2024 Lincoln Stein and the InvokeAI Development Team
|
||||
"""
|
||||
This module exports the function has_baked_in_sdxl_vae().
|
||||
It returns True if an SDXL checkpoint model has the original SDXL 1.0 VAE,
|
||||
which doesn't work properly in fp16 mode.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
SDXL_1_0_VAE_HASH = "bc40b16c3a0fa4625abdfc01c04ffc21bf3cefa6af6c7768ec61eb1f1ac0da51"
|
||||
|
||||
|
||||
def has_baked_in_sdxl_vae(checkpoint_path: Path) -> bool:
|
||||
"""Return true if the checkpoint contains a custom (non SDXL-1.0) VAE."""
|
||||
hash = _vae_hash(checkpoint_path)
|
||||
return hash != SDXL_1_0_VAE_HASH
|
||||
|
||||
|
||||
def _vae_hash(checkpoint_path: Path) -> str:
|
||||
checkpoint = load_file(checkpoint_path, device="cpu")
|
||||
vae_keys = [x for x in checkpoint.keys() if x.startswith("first_stage_model.")]
|
||||
hash = hashlib.new("sha256")
|
||||
for key in vae_keys:
|
||||
value = checkpoint[key]
|
||||
hash.update(bytes(key, "UTF-8"))
|
||||
hash.update(bytes(str(value), "UTF-8"))
|
||||
|
||||
return hash.hexdigest()
|
||||
@@ -1,553 +0,0 @@
|
||||
"""
|
||||
Manage a RAM cache of diffusion/transformer models for fast switching.
|
||||
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
|
||||
grows larger than a preset maximum, then the least recently used
|
||||
model will be cleared and (re)loaded from disk when next needed.
|
||||
|
||||
The cache returns context manager generators designed to load the
|
||||
model into the GPU within the context, and unload outside the
|
||||
context. Use like this:
|
||||
|
||||
cache = ModelCache(max_cache_size=7.5)
|
||||
with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
|
||||
cache.get_model('stabilityai/stable-diffusion-2') as SD2:
|
||||
do_something_in_GPU(SD1,SD2)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import gc
|
||||
import hashlib
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Type, Union, types
|
||||
|
||||
import torch
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
|
||||
|
||||
from ..util.devices import choose_torch_device
|
||||
from .models import BaseModelType, ModelBase, ModelType, SubModelType
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||
|
||||
# amount of GPU memory to hold in reserve for use by generations (GB)
|
||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
||||
|
||||
# actual size of a gig
|
||||
GIG = 1073741824
|
||||
# Size of a MB in bytes.
|
||||
MB = 2**20
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats(object):
|
||||
hits: int = 0 # cache hits
|
||||
misses: int = 0 # cache misses
|
||||
high_watermark: int = 0 # amount of cache used
|
||||
in_cache: int = 0 # number of models in cache
|
||||
cleared: int = 0 # number of models cleared to make space
|
||||
cache_size: int = 0 # total size of cache
|
||||
# {submodel_key => size}
|
||||
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ModelLocker(object):
|
||||
"Forward declaration"
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelCache(object):
|
||||
"Forward declaration"
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class _CacheRecord:
|
||||
size: int
|
||||
model: Any
|
||||
cache: ModelCache
|
||||
_locks: int
|
||||
|
||||
def __init__(self, cache, model: Any, size: int):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.cache = cache
|
||||
self._locks = 0
|
||||
|
||||
def lock(self):
|
||||
self._locks += 1
|
||||
|
||||
def unlock(self):
|
||||
self._locks -= 1
|
||||
assert self._locks >= 0
|
||||
|
||||
@property
|
||||
def locked(self):
|
||||
return self._locks > 0
|
||||
|
||||
@property
|
||||
def loaded(self):
|
||||
if self.model is not None and hasattr(self.model, "device"):
|
||||
return self.model.device != self.cache.storage_device
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
class ModelCache(object):
|
||||
def __init__(
|
||||
self,
|
||||
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
||||
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
sequential_offload: bool = False,
|
||||
lazy_offloading: bool = True,
|
||||
sha_chunksize: int = 16777216,
|
||||
logger: types.ModuleType = logger,
|
||||
log_memory_usage: bool = False,
|
||||
):
|
||||
"""
|
||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
behaviour.
|
||||
"""
|
||||
self.model_infos: Dict[str, ModelBase] = {}
|
||||
# allow lazy offloading only when vram cache enabled
|
||||
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||
self.precision: torch.dtype = precision
|
||||
self.max_cache_size: float = max_cache_size
|
||||
self.max_vram_cache_size: float = max_vram_cache_size
|
||||
self.execution_device: torch.device = execution_device
|
||||
self.storage_device: torch.device = storage_device
|
||||
self.sha_chunksize = sha_chunksize
|
||||
self.logger = logger
|
||||
self._log_memory_usage = log_memory_usage
|
||||
|
||||
# used for stats collection
|
||||
self.stats = None
|
||||
|
||||
self._cached_models = {}
|
||||
self._cache_stack = []
|
||||
|
||||
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
|
||||
if self._log_memory_usage:
|
||||
return MemorySnapshot.capture()
|
||||
return None
|
||||
|
||||
def get_key(
|
||||
self,
|
||||
model_path: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
):
|
||||
key = f"{model_path}:{base_model}:{model_type}"
|
||||
if submodel_type:
|
||||
key += f":{submodel_type}"
|
||||
return key
|
||||
|
||||
def _get_model_info(
|
||||
self,
|
||||
model_path: str,
|
||||
model_class: Type[ModelBase],
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
):
|
||||
model_info_key = self.get_key(
|
||||
model_path=model_path,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel_type=None,
|
||||
)
|
||||
|
||||
if model_info_key not in self.model_infos:
|
||||
self.model_infos[model_info_key] = model_class(
|
||||
model_path,
|
||||
base_model,
|
||||
model_type,
|
||||
)
|
||||
|
||||
return self.model_infos[model_info_key]
|
||||
|
||||
# TODO: args
|
||||
def get_model(
|
||||
self,
|
||||
model_path: Union[str, Path],
|
||||
model_class: Type[ModelBase],
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
gpu_load: bool = True,
|
||||
) -> Any:
|
||||
if not isinstance(model_path, Path):
|
||||
model_path = Path(model_path)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
raise Exception(f"Model not found: {model_path}")
|
||||
|
||||
model_info = self._get_model_info(
|
||||
model_path=model_path,
|
||||
model_class=model_class,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
key = self.get_key(
|
||||
model_path=model_path,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
submodel_type=submodel,
|
||||
)
|
||||
# TODO: lock for no copies on simultaneous calls?
|
||||
cache_entry = self._cached_models.get(key, None)
|
||||
if cache_entry is None:
|
||||
self.logger.info(
|
||||
f"Loading model {model_path}, type"
|
||||
f" {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
|
||||
)
|
||||
if self.stats:
|
||||
self.stats.misses += 1
|
||||
|
||||
self_reported_model_size_before_load = model_info.get_size(submodel)
|
||||
# Remove old models from the cache to make room for the new model.
|
||||
self._make_cache_room(self_reported_model_size_before_load)
|
||||
|
||||
# Load the model from disk and capture a memory snapshot before/after.
|
||||
start_load_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
with skip_torch_weight_init():
|
||||
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
|
||||
snapshot_after = self._capture_memory_snapshot()
|
||||
end_load_time = time.time()
|
||||
|
||||
self_reported_model_size_after_load = model_info.get_size(submodel)
|
||||
|
||||
self.logger.debug(
|
||||
f"Moved model '{key}' from disk to cpu in {(end_load_time-start_load_time):.2f}s.\n"
|
||||
f"Self-reported size before/after load: {(self_reported_model_size_before_load/GIG):.3f}GB /"
|
||||
f" {(self_reported_model_size_after_load/GIG):.3f}GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
if abs(self_reported_model_size_after_load - self_reported_model_size_before_load) > 10 * MB:
|
||||
self.logger.debug(
|
||||
f"Model '{key}' mis-reported its size before load. Self-reported size before/after load:"
|
||||
f" {(self_reported_model_size_before_load/GIG):.2f}GB /"
|
||||
f" {(self_reported_model_size_after_load/GIG):.2f}GB."
|
||||
)
|
||||
|
||||
cache_entry = _CacheRecord(self, model, self_reported_model_size_after_load)
|
||||
self._cached_models[key] = cache_entry
|
||||
else:
|
||||
if self.stats:
|
||||
self.stats.hits += 1
|
||||
|
||||
if self.stats:
|
||||
self.stats.cache_size = self.max_cache_size * GIG
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[key] = max(
|
||||
self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
|
||||
)
|
||||
|
||||
with suppress(Exception):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
|
||||
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
|
||||
|
||||
def _move_model_to_device(self, key: str, target_device: torch.device):
|
||||
cache_entry = self._cached_models[key]
|
||||
|
||||
source_device = cache_entry.model.device
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support
|
||||
# multi-GPU.
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
cache_entry.model.to(target_device)
|
||||
snapshot_after = self._capture_memory_snapshot()
|
||||
end_model_to_time = time.time()
|
||||
self.logger.debug(
|
||||
f"Moved model '{key}' from {source_device} to"
|
||||
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n"
|
||||
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
if (
|
||||
snapshot_before is not None
|
||||
and snapshot_after is not None
|
||||
and snapshot_before.vram is not None
|
||||
and snapshot_after.vram is not None
|
||||
):
|
||||
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
|
||||
|
||||
# If the estimated model size does not match the change in VRAM, log a warning.
|
||||
if not math.isclose(
|
||||
vram_change,
|
||||
cache_entry.size,
|
||||
rel_tol=0.1,
|
||||
abs_tol=10 * MB,
|
||||
):
|
||||
self.logger.debug(
|
||||
f"Moving model '{key}' from {source_device} to"
|
||||
f" {target_device} caused an unexpected change in VRAM usage. The model's"
|
||||
" estimated size may be incorrect. Estimated model size:"
|
||||
f" {(cache_entry.size/GIG):.3f} GB.\n"
|
||||
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
|
||||
)
|
||||
|
||||
class ModelLocker(object):
|
||||
def __init__(self, cache, key, model, gpu_load, size_needed):
|
||||
"""
|
||||
:param cache: The model_cache object
|
||||
:param key: The key of the model to lock in GPU
|
||||
:param model: The model to lock
|
||||
:param gpu_load: True if load into gpu
|
||||
:param size_needed: Size of the model to load
|
||||
"""
|
||||
self.gpu_load = gpu_load
|
||||
self.cache = cache
|
||||
self.key = key
|
||||
self.model = model
|
||||
self.size_needed = size_needed
|
||||
self.cache_entry = self.cache._cached_models[self.key]
|
||||
|
||||
def __enter__(self) -> Any:
|
||||
if not hasattr(self.model, "to"):
|
||||
return self.model
|
||||
|
||||
# NOTE that the model has to have the to() method in order for this
|
||||
# code to move it into GPU!
|
||||
if self.gpu_load:
|
||||
self.cache_entry.lock()
|
||||
|
||||
try:
|
||||
if self.cache.lazy_offloading:
|
||||
self.cache._offload_unlocked_models(self.size_needed)
|
||||
|
||||
self.cache._move_model_to_device(self.key, self.cache.execution_device)
|
||||
|
||||
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
|
||||
self.cache._print_cuda_stats()
|
||||
|
||||
except Exception:
|
||||
self.cache_entry.unlock()
|
||||
raise
|
||||
|
||||
# TODO: not fully understand
|
||||
# in the event that the caller wants the model in RAM, we
|
||||
# move it into CPU if it is in GPU and not locked
|
||||
elif self.cache_entry.loaded and not self.cache_entry.locked:
|
||||
self.cache._move_model_to_device(self.key, self.cache.storage_device)
|
||||
|
||||
return self.model
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
if not hasattr(self.model, "to"):
|
||||
return
|
||||
|
||||
self.cache_entry.unlock()
|
||||
if not self.cache.lazy_offloading:
|
||||
self.cache._offload_unlocked_models()
|
||||
self.cache._print_cuda_stats()
|
||||
|
||||
# TODO: should it be called untrack_model?
|
||||
def uncache_model(self, cache_id: str):
|
||||
with suppress(ValueError):
|
||||
self._cache_stack.remove(cache_id)
|
||||
self._cached_models.pop(cache_id, None)
|
||||
|
||||
def model_hash(
|
||||
self,
|
||||
model_path: Union[str, Path],
|
||||
) -> str:
|
||||
"""
|
||||
Given the HF repo id or path to a model on disk, returns a unique
|
||||
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
||||
|
||||
:param model_path: Path to model file/directory on disk.
|
||||
"""
|
||||
return self._local_model_hash(model_path)
|
||||
|
||||
def cache_size(self) -> float:
|
||||
"""Return the current size of the cache, in GB."""
|
||||
return self._cache_size() / GIG
|
||||
|
||||
def _has_cuda(self) -> bool:
|
||||
return self.execution_device.type == "cuda"
|
||||
|
||||
def _print_cuda_stats(self):
|
||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
||||
ram = "%4.2fG" % self.cache_size()
|
||||
|
||||
cached_models = 0
|
||||
loaded_models = 0
|
||||
locked_models = 0
|
||||
for model_info in self._cached_models.values():
|
||||
cached_models += 1
|
||||
if model_info.loaded:
|
||||
loaded_models += 1
|
||||
if model_info.locked:
|
||||
locked_models += 1
|
||||
|
||||
self.logger.debug(
|
||||
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ ="
|
||||
f" {cached_models}/{loaded_models}/{locked_models}"
|
||||
)
|
||||
|
||||
def _cache_size(self) -> int:
|
||||
return sum([m.size for m in self._cached_models.values()])
|
||||
|
||||
def _make_cache_room(self, model_size):
|
||||
# calculate how much memory this model will require
|
||||
# multiplier = 2 if self.precision==torch.float32 else 1
|
||||
bytes_needed = model_size
|
||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||
current_size = self._cache_size()
|
||||
|
||||
if current_size + bytes_needed > maximum_size:
|
||||
self.logger.debug(
|
||||
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
|
||||
f" {(bytes_needed/GIG):.2f} GB"
|
||||
)
|
||||
|
||||
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
|
||||
|
||||
pos = 0
|
||||
models_cleared = 0
|
||||
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
|
||||
refs = sys.getrefcount(cache_entry.model)
|
||||
|
||||
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
|
||||
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
|
||||
# https://docs.python.org/3/library/gc.html#gc.get_referrers
|
||||
|
||||
# manualy clear local variable references of just finished function calls
|
||||
# for some reason python don't want to collect it even by gc.collect() immidiately
|
||||
if refs > 2:
|
||||
while True:
|
||||
cleared = False
|
||||
for referrer in gc.get_referrers(cache_entry.model):
|
||||
if type(referrer).__name__ == "frame":
|
||||
# RuntimeError: cannot clear an executing frame
|
||||
with suppress(RuntimeError):
|
||||
referrer.clear()
|
||||
cleared = True
|
||||
# break
|
||||
|
||||
# repeat if referrers changes(due to frame clear), else exit loop
|
||||
if cleared:
|
||||
gc.collect()
|
||||
else:
|
||||
break
|
||||
|
||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||
self.logger.debug(
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
|
||||
f" refs: {refs}"
|
||||
)
|
||||
|
||||
# Expected refs:
|
||||
# 1 from cache_entry
|
||||
# 1 from getrefcount function
|
||||
# 1 from onnx runtime object
|
||||
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
|
||||
self.logger.debug(
|
||||
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
if self.stats:
|
||||
self.stats.cleared += 1
|
||||
del self._cache_stack[pos]
|
||||
del self._cached_models[model_key]
|
||||
del cache_entry
|
||||
|
||||
else:
|
||||
pos += 1
|
||||
|
||||
if models_cleared > 0:
|
||||
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
|
||||
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
|
||||
# is high even if no garbage gets collected.)
|
||||
#
|
||||
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
|
||||
# - If models had to be cleared, it's a signal that we are close to our memory limit.
|
||||
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
|
||||
# collected.
|
||||
#
|
||||
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
|
||||
# immediately when their reference count hits 0.
|
||||
gc.collect()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
||||
|
||||
def _offload_unlocked_models(self, size_needed: int = 0):
|
||||
reserved = self.max_vram_cache_size * GIG
|
||||
vram_in_use = torch.cuda.memory_allocated()
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
if not cache_entry.locked and cache_entry.loaded:
|
||||
self._move_model_to_device(model_key, self.storage_device)
|
||||
|
||||
vram_in_use = torch.cuda.memory_allocated()
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
|
||||
sha = hashlib.sha256()
|
||||
path = Path(model_path)
|
||||
|
||||
hashpath = path / "checksum.sha256"
|
||||
if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
|
||||
with open(hashpath) as f:
|
||||
hash = f.read()
|
||||
return hash
|
||||
|
||||
self.logger.debug(f"computing hash of model {path.name}")
|
||||
for file in list(path.rglob("*.ckpt")) + list(path.rglob("*.safetensors")) + list(path.rglob("*.pth")):
|
||||
with open(file, "rb") as f:
|
||||
while chunk := f.read(self.sha_chunksize):
|
||||
sha.update(chunk)
|
||||
hash = sha.hexdigest()
|
||||
with open(hashpath, "w") as f:
|
||||
f.write(hash)
|
||||
return hash
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,664 +0,0 @@
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Literal, Optional, Union
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
|
||||
|
||||
from .models import (
|
||||
BaseModelType,
|
||||
InvalidModelException,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
SilenceWarnings,
|
||||
)
|
||||
from .models.base import read_checkpoint_meta
|
||||
from .util import lora_token_vector_length
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelProbeInfo(object):
|
||||
model_type: ModelType
|
||||
base_type: BaseModelType
|
||||
variant_type: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
|
||||
image_size: int
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class ProbeBase(object):
|
||||
"""forward declaration"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelProbe(object):
|
||||
PROBES = {
|
||||
"diffusers": {},
|
||||
"checkpoint": {},
|
||||
"onnx": {},
|
||||
}
|
||||
|
||||
CLASS2TYPE = {
|
||||
"StableDiffusionPipeline": ModelType.Main,
|
||||
"StableDiffusionInpaintPipeline": ModelType.Main,
|
||||
"StableDiffusionXLPipeline": ModelType.Main,
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.Vae,
|
||||
"AutoencoderTiny": ModelType.Vae,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||
"T2IAdapter": ModelType.T2IAdapter,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_probe(
|
||||
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
|
||||
):
|
||||
cls.PROBES[format][model_type] = probe_class
|
||||
|
||||
@classmethod
|
||||
def heuristic_probe(
|
||||
cls,
|
||||
model: Union[Dict, ModelMixin, Path],
|
||||
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
||||
) -> ModelProbeInfo:
|
||||
if isinstance(model, Path):
|
||||
return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
|
||||
elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
|
||||
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
|
||||
else:
|
||||
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
|
||||
|
||||
@classmethod
|
||||
def probe(
|
||||
cls,
|
||||
model_path: Path,
|
||||
model: Optional[Union[Dict, ModelMixin]] = None,
|
||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||
) -> ModelProbeInfo:
|
||||
"""
|
||||
Probe the model at model_path and return sufficient information about it
|
||||
to place it somewhere in the models directory hierarchy. If the model is
|
||||
already loaded into memory, you may provide it as model in order to avoid
|
||||
opening it a second time. The prediction_type_helper callable is a function that receives
|
||||
the path to the model and returns the SchedulerPredictionType.
|
||||
"""
|
||||
if model_path:
|
||||
format_type = "diffusers" if model_path.is_dir() else "checkpoint"
|
||||
else:
|
||||
format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"
|
||||
model_info = None
|
||||
try:
|
||||
model_type = (
|
||||
cls.get_model_type_from_folder(model_path, model)
|
||||
if format_type == "diffusers"
|
||||
else cls.get_model_type_from_checkpoint(model_path, model)
|
||||
)
|
||||
format_type = "onnx" if model_type == ModelType.ONNX else format_type
|
||||
probe_class = cls.PROBES[format_type].get(model_type)
|
||||
if not probe_class:
|
||||
return None
|
||||
probe = probe_class(model_path, model, prediction_type_helper)
|
||||
base_type = probe.get_base_type()
|
||||
variant_type = probe.get_variant_type()
|
||||
prediction_type = probe.get_scheduler_prediction_type()
|
||||
name = cls.get_model_name(model_path)
|
||||
description = f"{base_type.value} {model_type.value} model {name}"
|
||||
format = probe.get_format()
|
||||
model_info = ModelProbeInfo(
|
||||
model_type=model_type,
|
||||
base_type=base_type,
|
||||
variant_type=variant_type,
|
||||
prediction_type=prediction_type,
|
||||
name=name,
|
||||
description=description,
|
||||
upcast_attention=(
|
||||
base_type == BaseModelType.StableDiffusion2
|
||||
and prediction_type == SchedulerPredictionType.VPrediction
|
||||
),
|
||||
format=format,
|
||||
image_size=(
|
||||
1024
|
||||
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
|
||||
else (
|
||||
768
|
||||
if (
|
||||
base_type == BaseModelType.StableDiffusion2
|
||||
and prediction_type == SchedulerPredictionType.VPrediction
|
||||
)
|
||||
else 512
|
||||
)
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
return model_info
|
||||
|
||||
@classmethod
|
||||
def get_model_name(cls, model_path: Path) -> str:
|
||||
if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
|
||||
return model_path.stem
|
||||
else:
|
||||
return model_path.name
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
|
||||
if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
|
||||
return None
|
||||
|
||||
if model_path.name == "learned_embeds.bin":
|
||||
return ModelType.TextualInversion
|
||||
|
||||
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
|
||||
ckpt = ckpt.get("state_dict", ckpt)
|
||||
|
||||
for key in ckpt.keys():
|
||||
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
||||
return ModelType.Main
|
||||
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
||||
return ModelType.Vae
|
||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||
return ModelType.Lora
|
||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||
return ModelType.Lora
|
||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||
return ModelType.ControlNet
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
return ModelType.TextualInversion
|
||||
|
||||
else:
|
||||
# diffusers-ti
|
||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||
return ModelType.TextualInversion
|
||||
|
||||
raise InvalidModelException(f"Unable to determine model type for {model_path}")
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
|
||||
"""
|
||||
Get the model type of a hugging-face style folder.
|
||||
"""
|
||||
class_name = None
|
||||
error_hint = None
|
||||
if model:
|
||||
class_name = model.__class__.__name__
|
||||
else:
|
||||
for suffix in ["bin", "safetensors"]:
|
||||
if (folder_path / f"learned_embeds.{suffix}").exists():
|
||||
return ModelType.TextualInversion
|
||||
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
|
||||
return ModelType.Lora
|
||||
if (folder_path / "unet/model.onnx").exists():
|
||||
return ModelType.ONNX
|
||||
if (folder_path / "image_encoder.txt").exists():
|
||||
return ModelType.IPAdapter
|
||||
|
||||
i = folder_path / "model_index.json"
|
||||
c = folder_path / "config.json"
|
||||
config_path = i if i.exists() else c if c.exists() else None
|
||||
|
||||
if config_path:
|
||||
with open(config_path, "r") as file:
|
||||
conf = json.load(file)
|
||||
if "_class_name" in conf:
|
||||
class_name = conf["_class_name"]
|
||||
elif "architectures" in conf:
|
||||
class_name = conf["architectures"][0]
|
||||
else:
|
||||
class_name = None
|
||||
else:
|
||||
error_hint = f"No model_index.json or config.json found in {folder_path}."
|
||||
|
||||
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
|
||||
return type
|
||||
else:
|
||||
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
|
||||
|
||||
# give up
|
||||
raise InvalidModelException(
|
||||
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
|
||||
with SilenceWarnings():
|
||||
if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
|
||||
cls._scan_model(model_path, model_path)
|
||||
return torch.load(model_path, map_location="cpu")
|
||||
else:
|
||||
return safetensors.torch.load_file(model_path)
|
||||
|
||||
@classmethod
|
||||
def _scan_model(cls, model_name, checkpoint):
|
||||
"""
|
||||
Apply picklescanner to the indicated checkpoint and issue a warning
|
||||
and option to exit if an infected file is identified.
|
||||
"""
|
||||
# scan model
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||
|
||||
|
||||
# ##################################################3
|
||||
# Checkpoint probing
|
||||
# ##################################################3
|
||||
class ProbeBase(object):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
pass
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
pass
|
||||
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
pass
|
||||
|
||||
def get_format(self) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class CheckpointProbeBase(ProbeBase):
|
||||
def __init__(
|
||||
self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None
|
||||
) -> BaseModelType:
|
||||
self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
|
||||
self.checkpoint_path = checkpoint_path
|
||||
self.helper = helper
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
pass
|
||||
|
||||
def get_format(self) -> str:
|
||||
return "checkpoint"
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
|
||||
if model_type != ModelType.Main:
|
||||
return ModelVariantType.Normal
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
if in_channels == 9:
|
||||
return ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
return ModelVariantType.Depth
|
||||
elif in_channels == 4:
|
||||
return ModelVariantType.Normal
|
||||
else:
|
||||
raise InvalidModelException(
|
||||
f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}"
|
||||
)
|
||||
|
||||
|
||||
class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
else:
|
||||
raise InvalidModelException("Cannot determine base type")
|
||||
|
||||
def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
|
||||
"""Return model prediction type."""
|
||||
# if there is a .yaml associated with this checkpoint, then we do not need
|
||||
# to probe for the prediction type as it will be ignored.
|
||||
if self.checkpoint_path and self.checkpoint_path.with_suffix(".yaml").exists():
|
||||
return None
|
||||
|
||||
type = self.get_base_type()
|
||||
if type == BaseModelType.StableDiffusion2:
|
||||
checkpoint = self.checkpoint
|
||||
state_dict = self.checkpoint.get("state_dict") or checkpoint
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||
if "global_step" in checkpoint:
|
||||
if checkpoint["global_step"] == 220000:
|
||||
return SchedulerPredictionType.Epsilon
|
||||
elif checkpoint["global_step"] == 110000:
|
||||
return SchedulerPredictionType.VPrediction
|
||||
if self.helper and self.checkpoint_path:
|
||||
if helper_guess := self.helper(self.checkpoint_path):
|
||||
return helper_guess
|
||||
return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts
|
||||
|
||||
elif type == BaseModelType.StableDiffusion1:
|
||||
if self.helper and self.checkpoint_path:
|
||||
if helper_guess := self.helper(self.checkpoint_path):
|
||||
return helper_guess
|
||||
return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class VaeCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
# I can't find any standalone 2.X VAEs to test with!
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
|
||||
class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return "lycoris"
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
token_vector_length = lora_token_vector_length(checkpoint)
|
||||
|
||||
if token_vector_length == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_vector_length == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif token_vector_length == 1280:
|
||||
return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641
|
||||
elif token_vector_length == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")
|
||||
|
||||
|
||||
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return None
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
if "string_to_token" in checkpoint:
|
||||
token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
|
||||
elif "emb_params" in checkpoint:
|
||||
token_dim = checkpoint["emb_params"].shape[-1]
|
||||
elif "clip_g" in checkpoint:
|
||||
token_dim = checkpoint["clip_g"].shape[-1]
|
||||
else:
|
||||
token_dim = list(checkpoint.values())[0].shape[-1]
|
||||
if token_dim == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif token_dim == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif token_dim == 1280:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
for key_name in (
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
"input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
):
|
||||
if key_name not in checkpoint:
|
||||
continue
|
||||
if checkpoint[key_name].shape[-1] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif checkpoint[key_name].shape[-1] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif self.checkpoint_path and self.helper:
|
||||
return self.helper(self.checkpoint_path)
|
||||
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
|
||||
|
||||
|
||||
class IPAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
########################################################
|
||||
# classes for probing folders
|
||||
#######################################################
|
||||
class FolderProbeBase(ProbeBase):
|
||||
def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used
|
||||
self.model = model
|
||||
self.folder_path = folder_path
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
return ModelVariantType.Normal
|
||||
|
||||
def get_format(self) -> str:
|
||||
return "diffusers"
|
||||
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if self.model:
|
||||
unet_conf = self.model.unet.config
|
||||
else:
|
||||
with open(self.folder_path / "unet" / "config.json", "r") as file:
|
||||
unet_conf = json.load(file)
|
||||
if unet_conf["cross_attention_dim"] == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif unet_conf["cross_attention_dim"] == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif unet_conf["cross_attention_dim"] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
elif unet_conf["cross_attention_dim"] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelException(f"Unknown base model for {self.folder_path}")
|
||||
|
||||
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
|
||||
if self.model:
|
||||
scheduler_conf = self.model.scheduler.config
|
||||
else:
|
||||
with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
|
||||
scheduler_conf = json.load(file)
|
||||
if scheduler_conf["prediction_type"] == "v_prediction":
|
||||
return SchedulerPredictionType.VPrediction
|
||||
elif scheduler_conf["prediction_type"] == "epsilon":
|
||||
return SchedulerPredictionType.Epsilon
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
# This only works for pipelines! Any kind of
|
||||
# exception results in our returning the
|
||||
# "normal" variant type
|
||||
try:
|
||||
if self.model:
|
||||
conf = self.model.unet.config
|
||||
else:
|
||||
config_file = self.folder_path / "unet" / "config.json"
|
||||
with open(config_file, "r") as file:
|
||||
conf = json.load(file)
|
||||
|
||||
in_channels = conf["in_channels"]
|
||||
if in_channels == 9:
|
||||
return ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
return ModelVariantType.Depth
|
||||
elif in_channels == 4:
|
||||
return ModelVariantType.Normal
|
||||
except Exception:
|
||||
pass
|
||||
return ModelVariantType.Normal
|
||||
|
||||
|
||||
class VaeFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if self._config_looks_like_sdxl():
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif self._name_looks_like_sdxl():
|
||||
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
|
||||
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
def _config_looks_like_sdxl(self) -> bool:
|
||||
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
|
||||
config_file = self.folder_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
|
||||
|
||||
def _name_looks_like_sdxl(self) -> bool:
|
||||
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
|
||||
|
||||
def _guess_name(self) -> str:
|
||||
name = self.folder_path.name
|
||||
if name == "vae":
|
||||
name = self.folder_path.parent.name
|
||||
return name
|
||||
|
||||
|
||||
class TextualInversionFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return None
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
path = self.folder_path / "learned_embeds.bin"
|
||||
if not path.exists():
|
||||
return None
|
||||
checkpoint = ModelProbe._scan_and_load_checkpoint(path)
|
||||
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
|
||||
|
||||
|
||||
class ONNXFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return "onnx"
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.StableDiffusion1
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
return ModelVariantType.Normal
|
||||
|
||||
|
||||
class ControlNetFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
config_file = self.folder_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
# no obvious way to distinguish between sd2-base and sd2-768
|
||||
dimension = config["cross_attention_dim"]
|
||||
base_model = (
|
||||
BaseModelType.StableDiffusion1
|
||||
if dimension == 768
|
||||
else (
|
||||
BaseModelType.StableDiffusion2
|
||||
if dimension == 1024
|
||||
else BaseModelType.StableDiffusionXL
|
||||
if dimension == 2048
|
||||
else None
|
||||
)
|
||||
)
|
||||
if not base_model:
|
||||
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
|
||||
return base_model
|
||||
|
||||
|
||||
class LoRAFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
model_file = None
|
||||
for suffix in ["safetensors", "bin"]:
|
||||
base_file = self.folder_path / f"pytorch_lora_weights.{suffix}"
|
||||
if base_file.exists():
|
||||
model_file = base_file
|
||||
break
|
||||
if not model_file:
|
||||
raise InvalidModelException("Unknown LoRA format encountered")
|
||||
return LoRACheckpointProbe(model_file, None).get_base_type()
|
||||
|
||||
|
||||
class IPAdapterFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return IPAdapterModelFormat.InvokeAI.value
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
model_file = self.folder_path / "ip_adapter.bin"
|
||||
if not model_file.exists():
|
||||
raise InvalidModelException("Unknown IP-Adapter model format.")
|
||||
|
||||
state_dict = torch.load(model_file, map_location="cpu")
|
||||
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||
if cross_attention_dim == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif cross_attention_dim == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif cross_attention_dim == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
|
||||
|
||||
|
||||
class CLIPVisionFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
config_file = self.folder_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
|
||||
adapter_type = config.get("adapter_type", None)
|
||||
if adapter_type == "full_adapter_xl":
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif adapter_type == "full_adapter" or "light_adapter":
|
||||
# I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
|
||||
return BaseModelType.StableDiffusion1
|
||||
else:
|
||||
raise InvalidModelException(
|
||||
f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})."
|
||||
)
|
||||
|
||||
|
||||
############## register probe classes ######
|
||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
|
||||
|
||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
||||
@@ -1,112 +0,0 @@
|
||||
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
|
||||
"""
|
||||
Abstract base class for recursive directory search for models.
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Set, types
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
|
||||
class ModelSearch(ABC):
|
||||
def __init__(self, directories: List[Path], logger: types.ModuleType = logger):
|
||||
"""
|
||||
Initialize a recursive model directory search.
|
||||
:param directories: List of directory Paths to recurse through
|
||||
:param logger: Logger to use
|
||||
"""
|
||||
self.directories = directories
|
||||
self.logger = logger
|
||||
self._items_scanned = 0
|
||||
self._models_found = 0
|
||||
self._scanned_dirs = set()
|
||||
self._scanned_paths = set()
|
||||
self._pruned_paths = set()
|
||||
|
||||
@abstractmethod
|
||||
def on_search_started(self):
|
||||
"""
|
||||
Called before the scan starts.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_model_found(self, model: Path):
|
||||
"""
|
||||
Process a found model. Raise an exception if something goes wrong.
|
||||
:param model: Model to process - could be a directory or checkpoint.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_search_completed(self):
|
||||
"""
|
||||
Perform some activity when the scan is completed. May use instance
|
||||
variables, items_scanned and models_found
|
||||
"""
|
||||
pass
|
||||
|
||||
def search(self):
|
||||
self.on_search_started()
|
||||
for dir in self.directories:
|
||||
self.walk_directory(dir)
|
||||
self.on_search_completed()
|
||||
|
||||
def walk_directory(self, path: Path):
|
||||
for root, dirs, files in os.walk(path, followlinks=True):
|
||||
if str(Path(root).name).startswith("."):
|
||||
self._pruned_paths.add(root)
|
||||
if any(Path(root).is_relative_to(x) for x in self._pruned_paths):
|
||||
continue
|
||||
|
||||
self._items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path in self._scanned_paths or path.parent in self._scanned_dirs:
|
||||
self._scanned_dirs.add(path)
|
||||
continue
|
||||
if any(
|
||||
(path / x).exists()
|
||||
for x in {
|
||||
"config.json",
|
||||
"model_index.json",
|
||||
"learned_embeds.bin",
|
||||
"pytorch_lora_weights.bin",
|
||||
"image_encoder.txt",
|
||||
}
|
||||
):
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
self._scanned_dirs.add(path)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to process '{path}': {e}")
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path.parent in self._scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to process '{path}': {e}")
|
||||
|
||||
|
||||
class FindModels(ModelSearch):
|
||||
def on_search_started(self):
|
||||
self.models_found: Set[Path] = set()
|
||||
|
||||
def on_model_found(self, model: Path):
|
||||
self.models_found.add(model)
|
||||
|
||||
def on_search_completed(self):
|
||||
pass
|
||||
|
||||
def list_models(self) -> List[Path]:
|
||||
self.search()
|
||||
return list(self.models_found)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user