mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-22 14:58:00 -05:00
Compare commits
1 Commits
fix/diffus
...
ebr/fix-in
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a784a8e4b8 |
@@ -20,13 +20,13 @@ def calc_images_mean_L1(image1_path, image2_path):
|
|||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("image1_path")
|
parser.add_argument('image1_path')
|
||||||
parser.add_argument("image2_path")
|
parser.add_argument('image2_path')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
mean_L1 = calc_images_mean_L1(args.image1_path, args.image2_path)
|
mean_L1 = calc_images_mean_L1(args.image1_path, args.image2_path)
|
||||||
print(mean_L1)
|
print(mean_L1)
|
||||||
|
|||||||
@@ -1,2 +1 @@
|
|||||||
b3dccfaeb636599c02effc377cdd8a87d658256c
|
b3dccfaeb636599c02effc377cdd8a87d658256c
|
||||||
218b6d0546b990fc449c876fb99f44b50c4daa35
|
|
||||||
|
|||||||
4
.github/workflows/lint-frontend.yml
vendored
4
.github/workflows/lint-frontend.yml
vendored
@@ -2,6 +2,8 @@ name: Lint frontend
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
pull_request:
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- 'invokeai/frontend/web/**'
|
||||||
types:
|
types:
|
||||||
- 'ready_for_review'
|
- 'ready_for_review'
|
||||||
- 'opened'
|
- 'opened'
|
||||||
@@ -9,6 +11,8 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- 'main'
|
- 'main'
|
||||||
|
paths:
|
||||||
|
- 'invokeai/frontend/web/**'
|
||||||
merge_group:
|
merge_group:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
|
|||||||
27
.github/workflows/style-checks.yml
vendored
27
.github/workflows/style-checks.yml
vendored
@@ -1,27 +0,0 @@
|
|||||||
name: Black # TODO: add isort and flake8 later
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request: {}
|
|
||||||
push:
|
|
||||||
branches: master
|
|
||||||
tags: "*"
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
test:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v3
|
|
||||||
|
|
||||||
- name: Setup Python
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: '3.10'
|
|
||||||
|
|
||||||
- name: Install dependencies with pip
|
|
||||||
run: |
|
|
||||||
pip install --upgrade pip wheel
|
|
||||||
pip install .[test]
|
|
||||||
|
|
||||||
# - run: isort --check-only .
|
|
||||||
- run: black --check .
|
|
||||||
# - run: flake8
|
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -38,6 +38,7 @@ develop-eggs/
|
|||||||
downloads/
|
downloads/
|
||||||
eggs/
|
eggs/
|
||||||
.eggs/
|
.eggs/
|
||||||
|
lib/
|
||||||
lib64/
|
lib64/
|
||||||
parts/
|
parts/
|
||||||
sdist/
|
sdist/
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
# See https://pre-commit.com/ for usage and config
|
|
||||||
repos:
|
|
||||||
- repo: local
|
|
||||||
hooks:
|
|
||||||
- id: black
|
|
||||||
name: black
|
|
||||||
stages: [commit]
|
|
||||||
language: system
|
|
||||||
entry: black
|
|
||||||
types: [python]
|
|
||||||
@@ -123,7 +123,7 @@ and go to http://localhost:9090.
|
|||||||
|
|
||||||
### Command-Line Installation (for developers and users familiar with Terminals)
|
### Command-Line Installation (for developers and users familiar with Terminals)
|
||||||
|
|
||||||
You must have Python 3.9 through 3.11 installed on your machine. Earlier or
|
You must have Python 3.9 or 3.10 installed on your machine. Earlier or
|
||||||
later versions are not supported.
|
later versions are not supported.
|
||||||
Node.js also needs to be installed along with yarn (can be installed with
|
Node.js also needs to be installed along with yarn (can be installed with
|
||||||
the command `npm install -g yarn` if needed)
|
the command `npm install -g yarn` if needed)
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 131 KiB |
@@ -16,7 +16,7 @@ If you don't feel ready to make a code contribution yet, no problem! You can als
|
|||||||
There are two paths to making a development contribution:
|
There are two paths to making a development contribution:
|
||||||
|
|
||||||
1. Choosing an open issue to address. Open issues can be found in the [Issues](https://github.com/invoke-ai/InvokeAI/issues?q=is%3Aissue+is%3Aopen) section of the InvokeAI repository. These are tagged by the issue type (bug, enhancement, etc.) along with the “good first issues” tag denoting if they are suitable for first time contributors.
|
1. Choosing an open issue to address. Open issues can be found in the [Issues](https://github.com/invoke-ai/InvokeAI/issues?q=is%3Aissue+is%3Aopen) section of the InvokeAI repository. These are tagged by the issue type (bug, enhancement, etc.) along with the “good first issues” tag denoting if they are suitable for first time contributors.
|
||||||
1. Additional items can be found on our [roadmap](https://github.com/orgs/invoke-ai/projects/7). The roadmap is organized in terms of priority, and contains features of varying size and complexity. If there is an inflight item you’d like to help with, reach out to the contributor assigned to the item to see how you can help.
|
1. Additional items can be found on our roadmap <******************************link to roadmap>******************************. The roadmap is organized in terms of priority, and contains features of varying size and complexity. If there is an inflight item you’d like to help with, reach out to the contributor assigned to the item to see how you can help.
|
||||||
2. Opening a new issue or feature to add. **Please make sure you have searched through existing issues before creating new ones.**
|
2. Opening a new issue or feature to add. **Please make sure you have searched through existing issues before creating new ones.**
|
||||||
|
|
||||||
*Regardless of what you choose, please post in the [#dev-chat](https://discord.com/channels/1020123559063990373/1049495067846524939) channel of the Discord before you start development in order to confirm that the issue or feature is aligned with the current direction of the project. We value our contributors time and effort and want to ensure that no one’s time is being misspent.*
|
*Regardless of what you choose, please post in the [#dev-chat](https://discord.com/channels/1020123559063990373/1049495067846524939) channel of the Discord before you start development in order to confirm that the issue or feature is aligned with the current direction of the project. We value our contributors time and effort and want to ensure that no one’s time is being misspent.*
|
||||||
|
|||||||
@@ -4,9 +4,6 @@ title: Overview
|
|||||||
|
|
||||||
Here you can find the documentation for InvokeAI's various features.
|
Here you can find the documentation for InvokeAI's various features.
|
||||||
|
|
||||||
## The [Getting Started Guide](../help/gettingStartedWithAI)
|
|
||||||
A getting started guide for those new to AI image generation.
|
|
||||||
|
|
||||||
## The Basics
|
## The Basics
|
||||||
### * The [Web User Interface](WEB.md)
|
### * The [Web User Interface](WEB.md)
|
||||||
Guide to the Web interface. Also see the [WebUI Hotkeys Reference Guide](WEBUIHOTKEYS.md)
|
Guide to the Web interface. Also see the [WebUI Hotkeys Reference Guide](WEBUIHOTKEYS.md)
|
||||||
@@ -49,7 +46,7 @@ Personalize models by adding your own style or subjects.
|
|||||||
|
|
||||||
## Other Features
|
## Other Features
|
||||||
|
|
||||||
### * [The NSFW Checker](WATERMARK+NSFW.md)
|
### * [The NSFW Checker](NSFW.md)
|
||||||
Prevent InvokeAI from displaying unwanted racy images.
|
Prevent InvokeAI from displaying unwanted racy images.
|
||||||
|
|
||||||
### * [Controlling Logging](LOGGING.md)
|
### * [Controlling Logging](LOGGING.md)
|
||||||
|
|||||||
@@ -1,95 +0,0 @@
|
|||||||
# Getting Started with AI Image Generation
|
|
||||||
|
|
||||||
New to image generation with AI? You’re in the right place!
|
|
||||||
|
|
||||||
This is a high level walkthrough of some of the concepts and terms you’ll see as you start using InvokeAI. Please note, this is not an exhaustive guide and may be out of date due to the rapidly changing nature of the space.
|
|
||||||
|
|
||||||
## Using InvokeAI
|
|
||||||
|
|
||||||
### **Prompt Crafting**
|
|
||||||
|
|
||||||
- Prompts are the basis of using InvokeAI, providing the models directions on what to generate. As a general rule of thumb, the more detailed your prompt is, the better your result will be.
|
|
||||||
|
|
||||||
*To get started, here’s an easy template to use for structuring your prompts:*
|
|
||||||
|
|
||||||
- Subject, Style, Quality, Aesthetic
|
|
||||||
- **Subject:** What your image will be about. E.g. “a futuristic city with trains”, “penguins floating on icebergs”, “friends sharing beers”
|
|
||||||
- **Style:** The style or medium in which your image will be in. E.g. “photograph”, “pencil sketch”, “oil paints”, or “pop art”, “cubism”, “abstract”
|
|
||||||
- **Quality:** A particular aspect or trait that you would like to see emphasized in your image. E.g. "award-winning", "featured in {relevant set of high quality works}", "professionally acclaimed". Many people often use "masterpiece".
|
|
||||||
- **Aesthetics:** The visual impact and design of the artwork. This can be colors, mood, lighting, setting, etc.
|
|
||||||
- There are two prompt boxes: *Positive Prompt* & *Negative Prompt*.
|
|
||||||
- A **Positive** Prompt includes words you want the model to reference when creating an image.
|
|
||||||
- Negative Prompt is for anything you want the model to eliminate when creating an image. It doesn’t always interpret things exactly the way you would, but helps control the generation process. Always try to include a few terms - you can typically use lower quality image terms like “blurry” or “distorted” with good success.
|
|
||||||
- Some examples prompts you can try on your own:
|
|
||||||
- A detailed oil painting of a tranquil forest at sunset with vibrant+ colors and soft, golden light filtering through the trees
|
|
||||||
- friends sharing beers in a busy city, realistic colored pencil sketch, twilight, masterpiece, bright, lively
|
|
||||||
|
|
||||||
### Generation Workflows
|
|
||||||
|
|
||||||
- Invoke offers a number of different workflows for interacting with models to produce images. Each is extremely powerful on its own, but together provide you an unparalleled way of producing high quality creative outputs that align with your vision.
|
|
||||||
- **Text to Image:** The text to image tab focuses on the key workflow of using a prompt to generate a new image. It includes other features that help control the generation process as well.
|
|
||||||
- **Image to Image:** With image to image, you provide an image as a reference (called the “initial image”), which provides more guidance around color and structure to the AI as it generates a new image. This is provided alongside the same features as Text to Image.
|
|
||||||
- **Unified Canvas:** The Unified Canvas is an advanced AI-first image editing tool that is easy to use, but hard to master. Drag an image onto the canvas from your gallery in order to regenerate certain elements, edit content or colors (known as inpainting), or extend the image with an exceptional degree of consistency and clarity (called outpainting).
|
|
||||||
|
|
||||||
### Improving Image Quality
|
|
||||||
|
|
||||||
- Fine tuning your prompt - the more specific you are, the closer the image will turn out to what is in your head! Adding more details in the Positive Prompt or Negative Prompt can help add / remove pieces of your image to improve it - You can also use advanced techniques like upweighting and downweighting to control the influence of certain words. [Learn more here](https://invoke-ai.github.io/InvokeAI/features/PROMPTS/#prompt-syntax-features).
|
|
||||||
- **Tip: If you’re seeing poor results, try adding the things you don’t like about the image to your negative prompt may help. E.g. distorted, low quality, unrealistic, etc.**
|
|
||||||
- Explore different models - Other models can produce different results due to the data they’ve been trained on. Each model has specific language and settings it works best with; a model’s documentation is your friend here. Play around with some and see what works best for you!
|
|
||||||
- Increasing Steps - The number of steps used controls how much time the model is given to produce an image, and depends on the “Scheduler” used. The schedule controls how each step is processed by the model. More steps tends to mean better results, but will take longer - We recommend at least 30 steps for most
|
|
||||||
- Tweak and Iterate - Remember, it’s best to change one thing at a time so you know what is working and what isn't. Sometimes you just need to try a new image, and other times using a new prompt might be the ticket. For testing, consider turning off the “random” Seed - Using the same seed with the same settings will produce the same image, which makes it the perfect way to learn exactly what your changes are doing.
|
|
||||||
- Explore Advanced Settings - InvokeAI has a full suite of tools available to allow you complete control over your image creation process - Check out our [docs if you want to learn more](https://invoke-ai.github.io/InvokeAI/features/).
|
|
||||||
|
|
||||||
|
|
||||||
## Terms & Concepts
|
|
||||||
|
|
||||||
If you're interested in learning more, check out [this presentation](https://docs.google.com/presentation/d/1IO78i8oEXFTZ5peuHHYkVF-Y3e2M6iM5tCnc-YBfcCM/edit?usp=sharing) from one of our maintainers (@lstein).
|
|
||||||
|
|
||||||
### Stable Diffusion
|
|
||||||
|
|
||||||
Stable Diffusion is deep learning, text-to-image model that is the foundation of the capabilities found in InvokeAI. Since the release of Stable Diffusion, there have been many subsequent models created based on Stable Diffusion that are designed to generate specific types of images.
|
|
||||||
|
|
||||||
### Prompts
|
|
||||||
|
|
||||||
Prompts provide the models directions on what to generate. As a general rule of thumb, the more detailed your prompt is, the better your result will be.
|
|
||||||
|
|
||||||
### Models
|
|
||||||
|
|
||||||
Models are the magic that power InvokeAI. These files represent the output of training a machine on understanding massive amounts of images - providing them with the capability to generate new images using just a text description of what you’d like to see. (Like Stable Diffusion!)
|
|
||||||
|
|
||||||
Invoke offers a simple way to download several different models upon installation, but many more can be discovered online, including at ****. Each model can produce a unique style of output, based on the images it was trained on - Try out different models to see which best fits your creative vision!
|
|
||||||
|
|
||||||
- *Models that contain “inpainting” in the name are designed for use with the inpainting feature of the Unified Canvas*
|
|
||||||
|
|
||||||
### Scheduler
|
|
||||||
|
|
||||||
Schedulers guide the process of removing noise (de-noising) from data. They determine:
|
|
||||||
|
|
||||||
1. The number of steps to take to remove the noise.
|
|
||||||
2. Whether the steps are random (stochastic) or predictable (deterministic).
|
|
||||||
3. The specific method (algorithm) used for de-noising.
|
|
||||||
|
|
||||||
Experimenting with different schedulers is recommended as each will produce different outputs!
|
|
||||||
|
|
||||||
### Steps
|
|
||||||
|
|
||||||
The number of de-noising steps each generation through.
|
|
||||||
|
|
||||||
Schedulers can be intricate and there's often a balance to strike between how quickly they can de-noise data and how well they can do it. It's typically advised to experiment with different schedulers to see which one gives the best results. There has been a lot written on the internet about different schedulers, as well as exploring what the right level of "steps" are for each. You can save generation time by reducing the number of steps used, but you'll want to make sure that you are satisfied with the quality of images produced!
|
|
||||||
|
|
||||||
### Low-Rank Adaptations / LoRAs
|
|
||||||
|
|
||||||
Low-Rank Adaptations (LoRAs) are like a smaller, more focused version of models, intended to focus on training a better understanding of how a specific character, style, or concept looks.
|
|
||||||
|
|
||||||
### Textual Inversion Embeddings
|
|
||||||
|
|
||||||
Textual Inversion Embeddings, like LoRAs, assist with more easily prompting for certain characters, styles, or concepts. However, embeddings are trained to update the relationship between a specific word (known as the “trigger”) and the intended output.
|
|
||||||
|
|
||||||
### ControlNet
|
|
||||||
|
|
||||||
ControlNets are neural network models that are able to extract key features from an existing image and use these features to guide the output of the image generation model.
|
|
||||||
|
|
||||||
### VAE
|
|
||||||
|
|
||||||
Variational auto-encoder (VAE) is a encode/decode model that translates the "latents" image produced during the image generation procees to the large pixel images that we see.
|
|
||||||
|
|
||||||
@@ -11,33 +11,6 @@ title: Home
|
|||||||
```
|
```
|
||||||
-->
|
-->
|
||||||
|
|
||||||
<!-- CSS styling -->
|
|
||||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@fortawesome/fontawesome-free@6.2.1/css/fontawesome.min.css">
|
|
||||||
<style>
|
|
||||||
.button {
|
|
||||||
width: 300px;
|
|
||||||
height: 50px;
|
|
||||||
background-color: #448AFF;
|
|
||||||
color: #fff;
|
|
||||||
font-size: 16px;
|
|
||||||
border: none;
|
|
||||||
cursor: pointer;
|
|
||||||
border-radius: 0.2rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.button-container {
|
|
||||||
display: grid;
|
|
||||||
grid-template-columns: repeat(3, 300px);
|
|
||||||
gap: 20px;
|
|
||||||
}
|
|
||||||
|
|
||||||
.button:hover {
|
|
||||||
background-color: #526CFE;
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
<div align="center" markdown>
|
<div align="center" markdown>
|
||||||
|
|
||||||
|
|
||||||
@@ -97,23 +70,61 @@ image-to-image generator. It provides a streamlined process with various new
|
|||||||
features and options to aid the image generation process. It runs on Windows,
|
features and options to aid the image generation process. It runs on Windows,
|
||||||
Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
|
Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
|
||||||
|
|
||||||
|
**Quick links**: [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>]
|
||||||
|
[<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a
|
||||||
|
href="https://github.com/invoke-ai/InvokeAI/issues">Bug Reports</a>] [<a
|
||||||
|
href="https://github.com/invoke-ai/InvokeAI/discussions">Discussion, Ideas &
|
||||||
|
Q&A</a>]
|
||||||
|
|
||||||
<div align="center"><img src="assets/invoke-web-server-1.png" width=640></div>
|
<div align="center"><img src="assets/invoke-web-server-1.png" width=640></div>
|
||||||
|
|
||||||
!!! Note
|
!!! note
|
||||||
|
|
||||||
This project is rapidly evolving. Please use the [Issues tab](https://github.com/invoke-ai/InvokeAI/issues) to report bugs and make feature requests. Be sure to use the provided templates as it will help aid response time.
|
This fork is rapidly evolving. Please use the [Issues tab](https://github.com/invoke-ai/InvokeAI/issues) to report bugs and make feature requests. Be sure to use the provided templates. They will help aid diagnose issues faster.
|
||||||
|
|
||||||
## :octicons-link-24: Quick Links
|
## :octicons-package-dependencies-24: Installation
|
||||||
|
|
||||||
<div class="button-container">
|
This fork is supported across Linux, Windows and Macintosh. Linux users can use
|
||||||
<a href="installation/INSTALLATION"> <button class="button">Installation</button> </a>
|
either an Nvidia-based card (with CUDA support) or an AMD card (using the ROCm
|
||||||
<a href="features/"> <button class="button">Features</button> </a>
|
driver).
|
||||||
<a href="help/gettingStartedWithAI/"> <button class="button">Getting Started</button> </a>
|
|
||||||
<a href="contributing/CONTRIBUTING/"> <button class="button">Contributing</button> </a>
|
### [Installation Getting Started Guide](installation)
|
||||||
<a href="https://github.com/invoke-ai/InvokeAI/"> <button class="button">Code and Downloads</button> </a>
|
#### **[Automated Installer](installation/010_INSTALL_AUTOMATED.md)**
|
||||||
<a href="https://github.com/invoke-ai/InvokeAI/issues"> <button class="button">Bug Reports </button> </a>
|
✅ This is the recommended installation method for first-time users.
|
||||||
<a href="https://discord.gg/ZmtBAhwWhy"> <button class="button"> Join the Discord Server!</button> </a>
|
#### [Manual Installation](installation/020_INSTALL_MANUAL.md)
|
||||||
</div>
|
This method is recommended for experienced users and developers
|
||||||
|
#### [Docker Installation](installation/040_INSTALL_DOCKER.md)
|
||||||
|
This method is recommended for those familiar with running Docker containers
|
||||||
|
### Other Installation Guides
|
||||||
|
- [PyPatchMatch](installation/060_INSTALL_PATCHMATCH.md)
|
||||||
|
- [XFormers](installation/070_INSTALL_XFORMERS.md)
|
||||||
|
- [CUDA and ROCm Drivers](installation/030_INSTALL_CUDA_AND_ROCM.md)
|
||||||
|
- [Installing New Models](installation/050_INSTALLING_MODELS.md)
|
||||||
|
|
||||||
|
## :fontawesome-solid-computer: Hardware Requirements
|
||||||
|
|
||||||
|
### :octicons-cpu-24: System
|
||||||
|
|
||||||
|
You wil need one of the following:
|
||||||
|
|
||||||
|
- :simple-nvidia: An NVIDIA-based graphics card with 4 GB or more VRAM memory.
|
||||||
|
- :simple-amd: An AMD-based graphics card with 4 GB or more VRAM memory (Linux
|
||||||
|
only)
|
||||||
|
- :fontawesome-brands-apple: An Apple computer with an M1 chip.
|
||||||
|
|
||||||
|
We do **not recommend** the following video cards due to issues with their
|
||||||
|
running in half-precision mode and having insufficient VRAM to render 512x512
|
||||||
|
images in full-precision mode:
|
||||||
|
|
||||||
|
- NVIDIA 10xx series cards such as the 1080ti
|
||||||
|
- GTX 1650 series cards
|
||||||
|
- GTX 1660 series cards
|
||||||
|
|
||||||
|
### :fontawesome-solid-memory: Memory and Disk
|
||||||
|
|
||||||
|
- At least 12 GB Main Memory RAM.
|
||||||
|
- At least 18 GB of free disk space for the machine learning model, Python, and
|
||||||
|
all its dependencies.
|
||||||
|
|
||||||
|
|
||||||
## :octicons-gift-24: InvokeAI Features
|
## :octicons-gift-24: InvokeAI Features
|
||||||
@@ -219,7 +230,7 @@ encouraged to do so.
|
|||||||
|
|
||||||
## :octicons-person-24: Contributors
|
## :octicons-person-24: Contributors
|
||||||
|
|
||||||
This software is a combined effort of various people from across the world.
|
This fork is a combined effort of various people from across the world.
|
||||||
[Check out the list of all these amazing people](other/CONTRIBUTORS.md). We
|
[Check out the list of all these amazing people](other/CONTRIBUTORS.md). We
|
||||||
thank them for their time, hard work and effort.
|
thank them for their time, hard work and effort.
|
||||||
|
|
||||||
|
|||||||
@@ -40,8 +40,10 @@ experimental versions later.
|
|||||||
this, open up a command-line window ("Terminal" on Linux and
|
this, open up a command-line window ("Terminal" on Linux and
|
||||||
Macintosh, "Command" or "Powershell" on Windows) and type `python
|
Macintosh, "Command" or "Powershell" on Windows) and type `python
|
||||||
--version`. If Python is installed, it will print out the version
|
--version`. If Python is installed, it will print out the version
|
||||||
number. If it is version `3.9.*`, `3.10.*` or `3.11.*` you meet
|
number. If it is version `3.9.*` or `3.10.*`, you meet
|
||||||
requirements.
|
requirements. We do not recommend using Python 3.11 or higher,
|
||||||
|
as not all the libraries that InvokeAI depends on work properly
|
||||||
|
with this version.
|
||||||
|
|
||||||
!!! warning "What to do if you have an unsupported version"
|
!!! warning "What to do if you have an unsupported version"
|
||||||
|
|
||||||
@@ -372,71 +374,8 @@ experimental versions later.
|
|||||||
Once InvokeAI is installed, do not move or remove this directory."
|
Once InvokeAI is installed, do not move or remove this directory."
|
||||||
|
|
||||||
|
|
||||||
<a name="troubleshooting"></a>
|
|
||||||
## Troubleshooting
|
## Troubleshooting
|
||||||
|
|
||||||
### _OSErrors on Windows while installing dependencies_
|
|
||||||
|
|
||||||
During a zip file installation or an online update, installation stops
|
|
||||||
with an error like this:
|
|
||||||
|
|
||||||
{:width="800px"}
|
|
||||||
|
|
||||||
This seems to happen particularly often with the `pydantic` and
|
|
||||||
`numpy` packages. The most reliable solution requires several manual
|
|
||||||
steps to complete installation.
|
|
||||||
|
|
||||||
Open up a Powershell window and navigate to the `invokeai` directory
|
|
||||||
created by the installer. Then give the following series of commands:
|
|
||||||
|
|
||||||
```cmd
|
|
||||||
rm .\.venv -r -force
|
|
||||||
python -mvenv .venv
|
|
||||||
.\.venv\Scripts\activate
|
|
||||||
pip install invokeai
|
|
||||||
invokeai-configure --yes --root .
|
|
||||||
```
|
|
||||||
|
|
||||||
If you see anything marked as an error during this process please stop
|
|
||||||
and seek help on the Discord [installation support
|
|
||||||
channel](https://discord.com/channels/1020123559063990373/1041391462190956654). A
|
|
||||||
few warning messages are OK.
|
|
||||||
|
|
||||||
If you are updating from a previous version, this should restore your
|
|
||||||
system to a working state. If you are installing from scratch, there
|
|
||||||
is one additional command to give:
|
|
||||||
|
|
||||||
```cmd
|
|
||||||
wget -O invoke.bat https://raw.githubusercontent.com/invoke-ai/InvokeAI/main/installer/templates/invoke.bat.in
|
|
||||||
```
|
|
||||||
|
|
||||||
This will create the `invoke.bat` script needed to launch InvokeAI and
|
|
||||||
its related programs.
|
|
||||||
|
|
||||||
|
|
||||||
### _Stable Diffusion XL Generation Fails after Trying to Load unet_
|
|
||||||
|
|
||||||
InvokeAI is working in other respects, but when trying to generate
|
|
||||||
images with Stable Diffusion XL you get a "Server Error". The text log
|
|
||||||
in the launch window contains this log line above several more lines of
|
|
||||||
error messages:
|
|
||||||
|
|
||||||
```INFO --> Loading model:D:\LONG\PATH\TO\MODEL, type sdxl:main:unet```
|
|
||||||
|
|
||||||
This failure mode occurs when there is a network glitch during
|
|
||||||
downloading the very large SDXL model.
|
|
||||||
|
|
||||||
To address this, first go to the Web Model Manager and delete the
|
|
||||||
Stable-Diffusion-XL-base-1.X model. Then navigate to HuggingFace and
|
|
||||||
manually download the .safetensors version of the model. The 1.0
|
|
||||||
version is located at
|
|
||||||
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main
|
|
||||||
and the file is named `sd_xl_base_1.0.safetensors`.
|
|
||||||
|
|
||||||
Save this file to disk and then reenter the Model Manager. Navigate to
|
|
||||||
Import Models->Add Model, then type (or drag-and-drop) the path to the
|
|
||||||
.safetensors file. Press "Add Model".
|
|
||||||
|
|
||||||
### _Package dependency conflicts_
|
### _Package dependency conflicts_
|
||||||
|
|
||||||
If you have previously installed InvokeAI or another Stable Diffusion
|
If you have previously installed InvokeAI or another Stable Diffusion
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ gaming):
|
|||||||
|
|
||||||
* **Python**
|
* **Python**
|
||||||
|
|
||||||
version 3.9 through 3.11
|
version 3.9 or 3.10 (3.11 is not recommended).
|
||||||
|
|
||||||
* **CUDA Tools**
|
* **CUDA Tools**
|
||||||
|
|
||||||
@@ -65,7 +65,7 @@ gaming):
|
|||||||
To install InvokeAI with virtual environments and the PIP package
|
To install InvokeAI with virtual environments and the PIP package
|
||||||
manager, please follow these steps:
|
manager, please follow these steps:
|
||||||
|
|
||||||
1. Please make sure you are using Python 3.9 through 3.11. The rest of the install
|
1. Please make sure you are using Python 3.9 or 3.10. The rest of the install
|
||||||
procedure depends on this and will not work with other versions:
|
procedure depends on this and will not work with other versions:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
# Overview
|
---
|
||||||
|
title: Overview
|
||||||
|
---
|
||||||
|
|
||||||
We offer several ways to install InvokeAI, each one suited to your
|
We offer several ways to install InvokeAI, each one suited to your
|
||||||
experience and preferences. We suggest that everyone start by
|
experience and preferences. We suggest that everyone start by
|
||||||
@@ -13,56 +15,6 @@ See the [troubleshooting
|
|||||||
section](010_INSTALL_AUTOMATED.md#troubleshooting) of the automated
|
section](010_INSTALL_AUTOMATED.md#troubleshooting) of the automated
|
||||||
install guide for frequently-encountered installation issues.
|
install guide for frequently-encountered installation issues.
|
||||||
|
|
||||||
This fork is supported across Linux, Windows and Macintosh. Linux users can use
|
|
||||||
either an Nvidia-based card (with CUDA support) or an AMD card (using the ROCm
|
|
||||||
driver).
|
|
||||||
|
|
||||||
### [Installation Getting Started Guide](installation)
|
|
||||||
#### **[Automated Installer](010_INSTALL_AUTOMATED.md)**
|
|
||||||
✅ This is the recommended installation method for first-time users.
|
|
||||||
#### [Manual Installation](020_INSTALL_MANUAL.md)
|
|
||||||
This method is recommended for experienced users and developers
|
|
||||||
#### [Docker Installation](040_INSTALL_DOCKER.md)
|
|
||||||
This method is recommended for those familiar with running Docker containers
|
|
||||||
### Other Installation Guides
|
|
||||||
- [PyPatchMatch](installation/060_INSTALL_PATCHMATCH.md)
|
|
||||||
- [XFormers](installation/070_INSTALL_XFORMERS.md)
|
|
||||||
- [CUDA and ROCm Drivers](installation/030_INSTALL_CUDA_AND_ROCM.md)
|
|
||||||
- [Installing New Models](installation/050_INSTALLING_MODELS.md)
|
|
||||||
|
|
||||||
## :fontawesome-solid-computer: Hardware Requirements
|
|
||||||
|
|
||||||
### :octicons-cpu-24: System
|
|
||||||
|
|
||||||
You wil need one of the following:
|
|
||||||
|
|
||||||
- :simple-nvidia: An NVIDIA-based graphics card with 4 GB or more VRAM memory.
|
|
||||||
- :simple-amd: An AMD-based graphics card with 4 GB or more VRAM memory (Linux
|
|
||||||
only)
|
|
||||||
- :fontawesome-brands-apple: An Apple computer with an M1 chip.
|
|
||||||
|
|
||||||
** SDXL 1.0 Requirements*
|
|
||||||
To use SDXL, user must have one of the following:
|
|
||||||
- :simple-nvidia: An NVIDIA-based graphics card with 8 GB or more VRAM memory.
|
|
||||||
- :simple-amd: An AMD-based graphics card with 16 GB or more VRAM memory (Linux
|
|
||||||
only)
|
|
||||||
- :fontawesome-brands-apple: An Apple computer with an M1 chip.
|
|
||||||
|
|
||||||
|
|
||||||
### :fontawesome-solid-memory: Memory and Disk
|
|
||||||
|
|
||||||
- At least 12 GB Main Memory RAM.
|
|
||||||
- At least 18 GB of free disk space for the machine learning model, Python, and
|
|
||||||
all its dependencies.
|
|
||||||
|
|
||||||
We do **not recommend** the following video cards due to issues with their
|
|
||||||
running in half-precision mode and having insufficient VRAM to render 512x512
|
|
||||||
images in full-precision mode:
|
|
||||||
|
|
||||||
- NVIDIA 10xx series cards such as the 1080ti
|
|
||||||
- GTX 1650 series cards
|
|
||||||
- GTX 1660 series cards
|
|
||||||
|
|
||||||
## Installation options
|
## Installation options
|
||||||
|
|
||||||
1. [Automated Installer](010_INSTALL_AUTOMATED.md)
|
1. [Automated Installer](010_INSTALL_AUTOMATED.md)
|
||||||
@@ -14,28 +14,23 @@ The nodes linked below have been developed and contributed by members of the Inv
|
|||||||
|
|
||||||
## List of Nodes
|
## List of Nodes
|
||||||
|
|
||||||
### FaceTools
|
### Face Mask
|
||||||
|
|
||||||
**Description:** FaceTools is a collection of nodes created to manipulate faces as you would in Unified Canvas. It includes FaceMask, FaceOff, and FacePlace. FaceMask autodetects a face in the image using MediaPipe and creates a mask from it. FaceOff similarly detects a face, then takes the face off of the image by adding a square bounding box around it and cropping/scaling it. FacePlace puts the bounded face image from FaceOff back onto the original image. Using these nodes with other inpainting node(s), you can put new faces on existing things, put new things around existing faces, and work closer with a face as a bounded image. Additionally, you can supply X and Y offset values to scale/change the shape of the mask for finer control on FaceMask and FaceOff. See GitHub repository below for usage examples.
|
**Description:** This node autodetects a face in the image using MediaPipe and masks it by making it transparent. Via outpainting you can swap faces with other faces, or invert the mask and swap things around the face with other things. Additionally, you can supply X and Y offset values to scale/change the shape of the mask for finer control. The node also outputs an all-white mask in the same dimensions as the input image. This is needed by the inpaint node (and unified canvas) for outpainting.
|
||||||
|
|
||||||
**Node Link:** https://github.com/ymgenesis/FaceTools/
|
**Node Link:** https://github.com/ymgenesis/InvokeAI/blob/facemaskmediapipe/invokeai/app/invocations/facemask.py
|
||||||
|
|
||||||
**FaceMask Output Examples**
|
**Example Node Graph:** https://www.mediafire.com/file/gohn5sb1bfp8use/21-July_2023-FaceMask.json/file
|
||||||
|
|
||||||

|
**Output Examples**
|
||||||

|
|
||||||

|
|
||||||
|
|
||||||
<hr>
|

|
||||||
|

|
||||||
### Ideal Size
|

|
||||||
|

|
||||||
**Description:** This node calculates an ideal image size for a first pass of a multi-pass upscaling. The aim is to avoid duplication that results from choosing a size larger than the model is capable of.
|
|
||||||
|
|
||||||
**Node Link:** https://github.com/JPPhoto/ideal-size-node
|
|
||||||
|
|
||||||
--------------------------------
|
--------------------------------
|
||||||
### Example Node Template
|
### Super Cool Node Template
|
||||||
|
|
||||||
**Description:** This node allows you to do super cool things with InvokeAI.
|
**Description:** This node allows you to do super cool things with InvokeAI.
|
||||||
|
|
||||||
@@ -45,9 +40,13 @@ The nodes linked below have been developed and contributed by members of the Inv
|
|||||||
|
|
||||||
**Output Examples**
|
**Output Examples**
|
||||||
|
|
||||||
{: style="height:115px;width:240px"}
|

|
||||||
|
|
||||||
|
### Ideal Size
|
||||||
|
|
||||||
|
**Description:** This node calculates an ideal image size for a first pass of a multi-pass upscaling. The aim is to avoid duplication that results from choosing a size larger than the model is capable of.
|
||||||
|
|
||||||
|
**Node Link:** https://github.com/JPPhoto/ideal-size-node
|
||||||
|
|
||||||
## Help
|
## Help
|
||||||
If you run into any issues with a node, please post in the [InvokeAI Discord](https://discord.gg/ZmtBAhwWhy).
|
If you run into any issues with a node, please post in the [InvokeAI Discord](https://discord.gg/ZmtBAhwWhy).
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
25
flake.lock
generated
25
flake.lock
generated
@@ -1,25 +0,0 @@
|
|||||||
{
|
|
||||||
"nodes": {
|
|
||||||
"nixpkgs": {
|
|
||||||
"locked": {
|
|
||||||
"lastModified": 1690630721,
|
|
||||||
"narHash": "sha256-Y04onHyBQT4Erfr2fc82dbJTfXGYrf4V0ysLUYnPOP8=",
|
|
||||||
"owner": "NixOS",
|
|
||||||
"repo": "nixpkgs",
|
|
||||||
"rev": "d2b52322f35597c62abf56de91b0236746b2a03d",
|
|
||||||
"type": "github"
|
|
||||||
},
|
|
||||||
"original": {
|
|
||||||
"id": "nixpkgs",
|
|
||||||
"type": "indirect"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"root": {
|
|
||||||
"inputs": {
|
|
||||||
"nixpkgs": "nixpkgs"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"root": "root",
|
|
||||||
"version": 7
|
|
||||||
}
|
|
||||||
81
flake.nix
81
flake.nix
@@ -1,81 +0,0 @@
|
|||||||
# Important note: this flake does not attempt to create a fully isolated, 'pure'
|
|
||||||
# Python environment for InvokeAI. Instead, it depends on local invocations of
|
|
||||||
# virtualenv/pip to install the required (binary) packages, most importantly the
|
|
||||||
# prebuilt binary pytorch packages with CUDA support.
|
|
||||||
# ML Python packages with CUDA support, like pytorch, are notoriously expensive
|
|
||||||
# to compile so it's purposefuly not what this flake does.
|
|
||||||
|
|
||||||
{
|
|
||||||
description = "An (impure) flake to develop on InvokeAI.";
|
|
||||||
|
|
||||||
outputs = { self, nixpkgs }:
|
|
||||||
let
|
|
||||||
system = "x86_64-linux";
|
|
||||||
pkgs = import nixpkgs {
|
|
||||||
inherit system;
|
|
||||||
config.allowUnfree = true;
|
|
||||||
};
|
|
||||||
|
|
||||||
python = pkgs.python310;
|
|
||||||
|
|
||||||
mkShell = { dir, install }:
|
|
||||||
let
|
|
||||||
setupScript = pkgs.writeScript "setup-invokai" ''
|
|
||||||
# This must be sourced using 'source', not executed.
|
|
||||||
${python}/bin/python -m venv ${dir}
|
|
||||||
${dir}/bin/python -m pip install ${install}
|
|
||||||
# ${dir}/bin/python -c 'import torch; assert(torch.cuda.is_available())'
|
|
||||||
source ${dir}/bin/activate
|
|
||||||
'';
|
|
||||||
in
|
|
||||||
pkgs.mkShell rec {
|
|
||||||
buildInputs = with pkgs; [
|
|
||||||
# Backend: graphics, CUDA.
|
|
||||||
cudaPackages.cudnn
|
|
||||||
cudaPackages.cuda_nvrtc
|
|
||||||
cudatoolkit
|
|
||||||
freeglut
|
|
||||||
glib
|
|
||||||
gperf
|
|
||||||
procps
|
|
||||||
libGL
|
|
||||||
libGLU
|
|
||||||
linuxPackages.nvidia_x11
|
|
||||||
python
|
|
||||||
stdenv.cc
|
|
||||||
stdenv.cc.cc.lib
|
|
||||||
xorg.libX11
|
|
||||||
xorg.libXext
|
|
||||||
xorg.libXi
|
|
||||||
xorg.libXmu
|
|
||||||
xorg.libXrandr
|
|
||||||
xorg.libXv
|
|
||||||
zlib
|
|
||||||
|
|
||||||
# Pre-commit hooks.
|
|
||||||
black
|
|
||||||
|
|
||||||
# Frontend.
|
|
||||||
yarn
|
|
||||||
nodejs
|
|
||||||
];
|
|
||||||
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath buildInputs;
|
|
||||||
CUDA_PATH = pkgs.cudatoolkit;
|
|
||||||
EXTRA_LDFLAGS = "-L${pkgs.linuxPackages.nvidia_x11}/lib";
|
|
||||||
shellHook = ''
|
|
||||||
if [[ -f "${dir}/bin/activate" ]]; then
|
|
||||||
source "${dir}/bin/activate"
|
|
||||||
echo "Using Python: $(which python)"
|
|
||||||
else
|
|
||||||
echo "Use 'source ${setupScript}' to set up the environment."
|
|
||||||
fi
|
|
||||||
'';
|
|
||||||
};
|
|
||||||
in
|
|
||||||
{
|
|
||||||
devShells.${system} = rec {
|
|
||||||
develop = mkShell { dir = "venv"; install = "-e '.[xformers]' --extra-index-url https://download.pytorch.org/whl/cu118"; };
|
|
||||||
default = develop;
|
|
||||||
};
|
|
||||||
};
|
|
||||||
}
|
|
||||||
@@ -9,20 +9,16 @@ cd $scriptdir
|
|||||||
function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; }
|
function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; }
|
||||||
|
|
||||||
MINIMUM_PYTHON_VERSION=3.9.0
|
MINIMUM_PYTHON_VERSION=3.9.0
|
||||||
MAXIMUM_PYTHON_VERSION=3.11.100
|
MAXIMUM_PYTHON_VERSION=3.11.0
|
||||||
PYTHON=""
|
PYTHON=""
|
||||||
for candidate in python3.11 python3.10 python3.9 python3 python ; do
|
for candidate in python3.10 python3.9 python3 python ; do
|
||||||
if ppath=`which $candidate`; then
|
if ppath=`which $candidate`; then
|
||||||
# when using `pyenv`, the executable for an inactive Python version will exist but will not be operational
|
|
||||||
# we check that this found executable can actually run
|
|
||||||
if [ $($candidate --version &>/dev/null; echo ${PIPESTATUS}) -gt 0 ]; then continue; fi
|
|
||||||
|
|
||||||
python_version=$($ppath -V | awk '{ print $2 }')
|
python_version=$($ppath -V | awk '{ print $2 }')
|
||||||
if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then
|
if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then
|
||||||
if [ $(version $python_version) -le $(version "$MAXIMUM_PYTHON_VERSION") ]; then
|
if [ $(version $python_version) -lt $(version "$MAXIMUM_PYTHON_VERSION") ]; then
|
||||||
PYTHON=$ppath
|
PYTHON=$ppath
|
||||||
break
|
break
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
done
|
done
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from pathlib import Path
|
|||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
SUPPORTED_PYTHON = ">=3.9.0,<=3.11.100"
|
SUPPORTED_PYTHON = ">=3.9.0,<3.11"
|
||||||
INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]
|
INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]
|
||||||
BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp"
|
BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp"
|
||||||
|
|
||||||
@@ -141,16 +141,15 @@ class Installer:
|
|||||||
|
|
||||||
# upgrade pip in Python 3.9 environments
|
# upgrade pip in Python 3.9 environments
|
||||||
if int(platform.python_version_tuple()[1]) == 9:
|
if int(platform.python_version_tuple()[1]) == 9:
|
||||||
|
|
||||||
from plumbum import FG, local
|
from plumbum import FG, local
|
||||||
|
|
||||||
pip = local[get_pip_from_venv(venv_dir)]
|
pip = local[get_pip_from_venv(venv_dir)]
|
||||||
pip["install", "--upgrade", "pip"] & FG
|
pip[ "install", "--upgrade", "pip"] & FG
|
||||||
|
|
||||||
return venv_dir
|
return venv_dir
|
||||||
|
|
||||||
def install(
|
def install(self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None) -> None:
|
||||||
self, root: str = "~/invokeai", version: str = "latest", yes_to_all=False, find_links: Path = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Install the InvokeAI application into the given runtime path
|
Install the InvokeAI application into the given runtime path
|
||||||
|
|
||||||
@@ -168,8 +167,7 @@ class Installer:
|
|||||||
|
|
||||||
messages.welcome()
|
messages.welcome()
|
||||||
|
|
||||||
default_path = os.environ.get("INVOKEAI_ROOT") or Path(root).expanduser().resolve()
|
self.dest = Path(root).expanduser().resolve() if yes_to_all else messages.dest_path(root)
|
||||||
self.dest = default_path if yes_to_all else messages.dest_path(root)
|
|
||||||
|
|
||||||
# create the venv for the app
|
# create the venv for the app
|
||||||
self.venv = self.app_venv()
|
self.venv = self.app_venv()
|
||||||
@@ -177,7 +175,7 @@ class Installer:
|
|||||||
self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv, version=version)
|
self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv, version=version)
|
||||||
|
|
||||||
# install dependencies and the InvokeAI application
|
# install dependencies and the InvokeAI application
|
||||||
(extra_index_url, optional_modules) = get_torch_source() if not yes_to_all else (None, None)
|
(extra_index_url,optional_modules) = get_torch_source() if not yes_to_all else (None,None)
|
||||||
self.instance.install(
|
self.instance.install(
|
||||||
extra_index_url,
|
extra_index_url,
|
||||||
optional_modules,
|
optional_modules,
|
||||||
@@ -190,7 +188,6 @@ class Installer:
|
|||||||
# run through the configuration flow
|
# run through the configuration flow
|
||||||
self.instance.configure()
|
self.instance.configure()
|
||||||
|
|
||||||
|
|
||||||
class InvokeAiInstance:
|
class InvokeAiInstance:
|
||||||
"""
|
"""
|
||||||
Manages an installed instance of InvokeAI, comprising a virtual environment and a runtime directory.
|
Manages an installed instance of InvokeAI, comprising a virtual environment and a runtime directory.
|
||||||
@@ -199,6 +196,7 @@ class InvokeAiInstance:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, runtime: Path, venv: Path, version: str) -> None:
|
def __init__(self, runtime: Path, venv: Path, version: str) -> None:
|
||||||
|
|
||||||
self.runtime = runtime
|
self.runtime = runtime
|
||||||
self.venv = venv
|
self.venv = venv
|
||||||
self.pip = get_pip_from_venv(venv)
|
self.pip = get_pip_from_venv(venv)
|
||||||
@@ -249,9 +247,6 @@ class InvokeAiInstance:
|
|||||||
pip[
|
pip[
|
||||||
"install",
|
"install",
|
||||||
"--require-virtualenv",
|
"--require-virtualenv",
|
||||||
"numpy~=1.24.0", # choose versions that won't be uninstalled during phase 2
|
|
||||||
"urllib3~=1.26.0",
|
|
||||||
"requests~=2.28.0",
|
|
||||||
"torch~=2.0.0",
|
"torch~=2.0.0",
|
||||||
"torchmetrics==0.11.4",
|
"torchmetrics==0.11.4",
|
||||||
"torchvision>=0.14.1",
|
"torchvision>=0.14.1",
|
||||||
@@ -317,7 +312,7 @@ class InvokeAiInstance:
|
|||||||
"install",
|
"install",
|
||||||
"--require-virtualenv",
|
"--require-virtualenv",
|
||||||
"--use-pep517",
|
"--use-pep517",
|
||||||
str(src) + (optional_modules if optional_modules else ""),
|
str(src)+(optional_modules if optional_modules else ''),
|
||||||
"--find-links" if find_links is not None else None,
|
"--find-links" if find_links is not None else None,
|
||||||
find_links,
|
find_links,
|
||||||
"--extra-index-url" if extra_index_url is not None else None,
|
"--extra-index-url" if extra_index_url is not None else None,
|
||||||
@@ -334,12 +329,12 @@ class InvokeAiInstance:
|
|||||||
|
|
||||||
# set sys.argv to a consistent state
|
# set sys.argv to a consistent state
|
||||||
new_argv = [sys.argv[0]]
|
new_argv = [sys.argv[0]]
|
||||||
for i in range(1, len(sys.argv)):
|
for i in range(1,len(sys.argv)):
|
||||||
el = sys.argv[i]
|
el = sys.argv[i]
|
||||||
if el in ["-r", "--root"]:
|
if el in ['-r','--root']:
|
||||||
new_argv.append(el)
|
new_argv.append(el)
|
||||||
new_argv.append(sys.argv[i + 1])
|
new_argv.append(sys.argv[i+1])
|
||||||
elif el in ["-y", "--yes", "--yes-to-all"]:
|
elif el in ['-y','--yes','--yes-to-all']:
|
||||||
new_argv.append(el)
|
new_argv.append(el)
|
||||||
sys.argv = new_argv
|
sys.argv = new_argv
|
||||||
|
|
||||||
@@ -358,16 +353,16 @@ class InvokeAiInstance:
|
|||||||
invokeai_configure()
|
invokeai_configure()
|
||||||
succeeded = True
|
succeeded = True
|
||||||
except requests.exceptions.ConnectionError as e:
|
except requests.exceptions.ConnectionError as e:
|
||||||
print(f"\nA network error was encountered during configuration and download: {str(e)}")
|
print(f'\nA network error was encountered during configuration and download: {str(e)}')
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
print(f"\nAn OS error was encountered during configuration and download: {str(e)}")
|
print(f'\nAn OS error was encountered during configuration and download: {str(e)}')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\nA problem was encountered during the configuration and download steps: {str(e)}")
|
print(f'\nA problem was encountered during the configuration and download steps: {str(e)}')
|
||||||
finally:
|
finally:
|
||||||
if not succeeded:
|
if not succeeded:
|
||||||
print('To try again, find the "invokeai" directory, run the script "invoke.sh" or "invoke.bat"')
|
print('To try again, find the "invokeai" directory, run the script "invoke.sh" or "invoke.bat"')
|
||||||
print("and choose option 7 to fix a broken install, optionally followed by option 5 to install models.")
|
print('and choose option 7 to fix a broken install, optionally followed by option 5 to install models.')
|
||||||
print("Alternatively you can relaunch the installer.")
|
print('Alternatively you can relaunch the installer.')
|
||||||
|
|
||||||
def install_user_scripts(self):
|
def install_user_scripts(self):
|
||||||
"""
|
"""
|
||||||
@@ -376,11 +371,11 @@ class InvokeAiInstance:
|
|||||||
|
|
||||||
ext = "bat" if OS == "Windows" else "sh"
|
ext = "bat" if OS == "Windows" else "sh"
|
||||||
|
|
||||||
# scripts = ['invoke', 'update']
|
#scripts = ['invoke', 'update']
|
||||||
scripts = ["invoke"]
|
scripts = ['invoke']
|
||||||
|
|
||||||
for script in scripts:
|
for script in scripts:
|
||||||
src = Path(__file__).parent / ".." / "templates" / f"{script}.{ext}.in"
|
src = Path(__file__).parent / '..' / "templates" / f"{script}.{ext}.in"
|
||||||
dest = self.runtime / f"{script}.{ext}"
|
dest = self.runtime / f"{script}.{ext}"
|
||||||
shutil.copy(src, dest)
|
shutil.copy(src, dest)
|
||||||
os.chmod(dest, 0o0755)
|
os.chmod(dest, 0o0755)
|
||||||
@@ -425,7 +420,11 @@ def set_sys_path(venv_path: Path) -> None:
|
|||||||
# filter out any paths in sys.path that may be system- or user-wide
|
# filter out any paths in sys.path that may be system- or user-wide
|
||||||
# but leave the temporary bootstrap virtualenv as it contains packages we
|
# but leave the temporary bootstrap virtualenv as it contains packages we
|
||||||
# temporarily need at install time
|
# temporarily need at install time
|
||||||
sys.path = list(filter(lambda p: not p.endswith("-packages") or p.find(BOOTSTRAP_VENV_PREFIX) != -1, sys.path))
|
sys.path = list(filter(
|
||||||
|
lambda p: not p.endswith("-packages")
|
||||||
|
or p.find(BOOTSTRAP_VENV_PREFIX) != -1,
|
||||||
|
sys.path
|
||||||
|
))
|
||||||
|
|
||||||
# determine site-packages/lib directory location for the venv
|
# determine site-packages/lib directory location for the venv
|
||||||
lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
|
lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
|
||||||
@@ -434,7 +433,7 @@ def set_sys_path(venv_path: Path) -> None:
|
|||||||
sys.path.append(str(Path(venv_path, lib, "site-packages").expanduser().resolve()))
|
sys.path.append(str(Path(venv_path, lib, "site-packages").expanduser().resolve()))
|
||||||
|
|
||||||
|
|
||||||
def get_torch_source() -> (Union[str, None], str):
|
def get_torch_source() -> (Union[str, None],str):
|
||||||
"""
|
"""
|
||||||
Determine the extra index URL for pip to use for torch installation.
|
Determine the extra index URL for pip to use for torch installation.
|
||||||
This depends on the OS and the graphics accelerator in use.
|
This depends on the OS and the graphics accelerator in use.
|
||||||
@@ -455,19 +454,16 @@ def get_torch_source() -> (Union[str, None], str):
|
|||||||
device = graphical_accelerator()
|
device = graphical_accelerator()
|
||||||
|
|
||||||
url = None
|
url = None
|
||||||
optional_modules = "[onnx]"
|
optional_modules = None
|
||||||
if OS == "Linux":
|
if OS == "Linux":
|
||||||
if device == "rocm":
|
if device == "rocm":
|
||||||
url = "https://download.pytorch.org/whl/rocm5.4.2"
|
url = "https://download.pytorch.org/whl/rocm5.4.2"
|
||||||
elif device == "cpu":
|
elif device == "cpu":
|
||||||
url = "https://download.pytorch.org/whl/cpu"
|
url = "https://download.pytorch.org/whl/cpu"
|
||||||
|
|
||||||
if device == "cuda":
|
if device == 'cuda':
|
||||||
url = "https://download.pytorch.org/whl/cu117"
|
url = 'https://download.pytorch.org/whl/cu117'
|
||||||
optional_modules = "[xformers,onnx-cuda]"
|
optional_modules = '[xformers]'
|
||||||
if device == "cuda_and_dml":
|
|
||||||
url = "https://download.pytorch.org/whl/cu117"
|
|
||||||
optional_modules = "[xformers,onnx-directml]"
|
|
||||||
|
|
||||||
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
|
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ InvokeAI Installer
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from installer import Installer
|
from installer import Installer
|
||||||
|
|
||||||
@@ -16,7 +15,7 @@ if __name__ == "__main__":
|
|||||||
dest="root",
|
dest="root",
|
||||||
type=str,
|
type=str,
|
||||||
help="Destination path for installation",
|
help="Destination path for installation",
|
||||||
default=os.environ.get("INVOKEAI_ROOT") or "~/invokeai",
|
default="~/invokeai",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-y",
|
"-y",
|
||||||
|
|||||||
@@ -36,15 +36,13 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
def welcome():
|
def welcome():
|
||||||
|
|
||||||
@group()
|
@group()
|
||||||
def text():
|
def text():
|
||||||
if (platform_specific := _platform_specific_help()) != "":
|
if (platform_specific := _platform_specific_help()) != "":
|
||||||
yield platform_specific
|
yield platform_specific
|
||||||
yield ""
|
yield ""
|
||||||
yield Text.from_markup(
|
yield Text.from_markup("Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.", justify="center")
|
||||||
"Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.",
|
|
||||||
justify="center",
|
|
||||||
)
|
|
||||||
|
|
||||||
console.rule()
|
console.rule()
|
||||||
print(
|
print(
|
||||||
@@ -60,7 +58,6 @@ def welcome():
|
|||||||
)
|
)
|
||||||
console.line()
|
console.line()
|
||||||
|
|
||||||
|
|
||||||
def confirm_install(dest: Path) -> bool:
|
def confirm_install(dest: Path) -> bool:
|
||||||
if dest.exists():
|
if dest.exists():
|
||||||
print(f":exclamation: Directory {dest} already exists :exclamation:")
|
print(f":exclamation: Directory {dest} already exists :exclamation:")
|
||||||
@@ -95,6 +92,7 @@ def dest_path(dest=None) -> Path:
|
|||||||
dest_confirmed = confirm_install(dest)
|
dest_confirmed = confirm_install(dest)
|
||||||
|
|
||||||
while not dest_confirmed:
|
while not dest_confirmed:
|
||||||
|
|
||||||
# if the given destination already exists, the starting point for browsing is its parent directory.
|
# if the given destination already exists, the starting point for browsing is its parent directory.
|
||||||
# the user may have made a typo, or otherwise wants to place the root dir next to an existing one.
|
# the user may have made a typo, or otherwise wants to place the root dir next to an existing one.
|
||||||
# if the destination dir does NOT exist, then the user must have changed their mind about the selection.
|
# if the destination dir does NOT exist, then the user must have changed their mind about the selection.
|
||||||
@@ -167,10 +165,6 @@ def graphical_accelerator():
|
|||||||
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
|
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
|
||||||
"cuda",
|
"cuda",
|
||||||
)
|
)
|
||||||
nvidia_with_dml = (
|
|
||||||
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
|
|
||||||
"cuda_and_dml",
|
|
||||||
)
|
|
||||||
amd = (
|
amd = (
|
||||||
"an [gold1 b]AMD[/] GPU (using ROCm™)",
|
"an [gold1 b]AMD[/] GPU (using ROCm™)",
|
||||||
"rocm",
|
"rocm",
|
||||||
@@ -185,7 +179,7 @@ def graphical_accelerator():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if OS == "Windows":
|
if OS == "Windows":
|
||||||
options = [nvidia, nvidia_with_dml, cpu]
|
options = [nvidia, cpu]
|
||||||
if OS == "Linux":
|
if OS == "Linux":
|
||||||
options = [nvidia, amd, cpu]
|
options = [nvidia, amd, cpu]
|
||||||
elif OS == "Darwin":
|
elif OS == "Darwin":
|
||||||
@@ -306,20 +300,15 @@ def introduction() -> None:
|
|||||||
)
|
)
|
||||||
console.line(2)
|
console.line(2)
|
||||||
|
|
||||||
|
def _platform_specific_help()->str:
|
||||||
def _platform_specific_help() -> str:
|
|
||||||
if OS == "Darwin":
|
if OS == "Darwin":
|
||||||
text = Text.from_markup(
|
text = Text.from_markup("""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/].""")
|
||||||
"""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/]."""
|
|
||||||
)
|
|
||||||
elif OS == "Windows":
|
elif OS == "Windows":
|
||||||
text = Text.from_markup(
|
text = Text.from_markup("""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
|
||||||
"""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
|
|
||||||
1. Double-click on the file [b wheat1]WinLongPathsEnabled.reg[/] in order to
|
1. Double-click on the file [b wheat1]WinLongPathsEnabled.reg[/] in order to
|
||||||
enable long path support on your system.
|
enable long path support on your system.
|
||||||
2. Make sure you have the [b wheat1]Visual C++ core libraries[/] installed. If not, install from
|
2. Make sure you have the [b wheat1]Visual C++ core libraries[/] installed. If not, install from
|
||||||
[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]"""
|
[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]""")
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
text = ""
|
text = ""
|
||||||
return text
|
return text
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ IF /I "%choice%" == "1" (
|
|||||||
python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models
|
python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models
|
||||||
) ELSE IF /I "%choice%" == "7" (
|
) ELSE IF /I "%choice%" == "7" (
|
||||||
echo Running invokeai-configure...
|
echo Running invokeai-configure...
|
||||||
python .venv\Scripts\invokeai-configure.exe --yes --skip-sd-weight
|
python .venv\Scripts\invokeai-configure.exe --yes --default_only
|
||||||
) ELSE IF /I "%choice%" == "8" (
|
) ELSE IF /I "%choice%" == "8" (
|
||||||
echo Developer Console
|
echo Developer Console
|
||||||
echo Python command is:
|
echo Python command is:
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ do_choice() {
|
|||||||
7)
|
7)
|
||||||
clear
|
clear
|
||||||
printf "Re-run the configure script to fix a broken install or to complete a major upgrade\n"
|
printf "Re-run the configure script to fix a broken install or to complete a major upgrade\n"
|
||||||
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only --skip-sd-weights
|
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
|
||||||
;;
|
;;
|
||||||
8)
|
8)
|
||||||
clear
|
clear
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
import os
|
import os
|
||||||
from invokeai.app.services.board_image_record_storage import (
|
from invokeai.app.services.board_image_record_storage import (
|
||||||
@@ -55,7 +54,7 @@ logger = InvokeAILogger.getLogger()
|
|||||||
class ApiDependencies:
|
class ApiDependencies:
|
||||||
"""Contains and initializes all dependencies for the API"""
|
"""Contains and initializes all dependencies for the API"""
|
||||||
|
|
||||||
invoker: Optional[Invoker] = None
|
invoker: Invoker = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
|
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
|
||||||
@@ -79,7 +78,9 @@ class ApiDependencies:
|
|||||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
names = SimpleNameService()
|
names = SimpleNameService()
|
||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
latents = ForwardCacheLatentsStorage(
|
||||||
|
DiskLatentsStorage(f"{output_folder}/latents")
|
||||||
|
)
|
||||||
|
|
||||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||||
@@ -124,7 +125,9 @@ class ApiDependencies:
|
|||||||
boards=boards,
|
boards=boards,
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](
|
||||||
|
filename=db_location, table_name="graphs"
|
||||||
|
),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
configuration=config,
|
configuration=config,
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ from invokeai.version import __version__
|
|||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
from invokeai.backend.util.logging import logging
|
from invokeai.backend.util.logging import logging
|
||||||
|
|
||||||
|
|
||||||
class LogLevel(int, Enum):
|
class LogLevel(int, Enum):
|
||||||
NotSet = logging.NOTSET
|
NotSet = logging.NOTSET
|
||||||
Debug = logging.DEBUG
|
Debug = logging.DEBUG
|
||||||
@@ -24,12 +23,10 @@ class LogLevel(int, Enum):
|
|||||||
Error = logging.ERROR
|
Error = logging.ERROR
|
||||||
Critical = logging.CRITICAL
|
Critical = logging.CRITICAL
|
||||||
|
|
||||||
|
|
||||||
class Upscaler(BaseModel):
|
class Upscaler(BaseModel):
|
||||||
upscaling_method: str = Field(description="Name of upscaling method")
|
upscaling_method: str = Field(description="Name of upscaling method")
|
||||||
upscaling_models: list[str] = Field(description="List of upscaling models for this method")
|
upscaling_models: list[str] = Field(description="List of upscaling models for this method")
|
||||||
|
|
||||||
|
|
||||||
app_router = APIRouter(prefix="/v1/app", tags=["app"])
|
app_router = APIRouter(prefix="/v1/app", tags=["app"])
|
||||||
|
|
||||||
|
|
||||||
@@ -48,29 +45,37 @@ class AppConfig(BaseModel):
|
|||||||
watermarking_methods: list[str] = Field(description="List of invisible watermark methods")
|
watermarking_methods: list[str] = Field(description="List of invisible watermark methods")
|
||||||
|
|
||||||
|
|
||||||
@app_router.get("/version", operation_id="app_version", status_code=200, response_model=AppVersion)
|
@app_router.get(
|
||||||
|
"/version", operation_id="app_version", status_code=200, response_model=AppVersion
|
||||||
|
)
|
||||||
async def get_version() -> AppVersion:
|
async def get_version() -> AppVersion:
|
||||||
return AppVersion(version=__version__)
|
return AppVersion(version=__version__)
|
||||||
|
|
||||||
|
|
||||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
@app_router.get(
|
||||||
|
"/config", operation_id="get_config", status_code=200, response_model=AppConfig
|
||||||
|
)
|
||||||
async def get_config() -> AppConfig:
|
async def get_config() -> AppConfig:
|
||||||
infill_methods = ["tile"]
|
infill_methods = ['tile']
|
||||||
if PatchMatch.patchmatch_available():
|
if PatchMatch.patchmatch_available():
|
||||||
infill_methods.append("patchmatch")
|
infill_methods.append('patchmatch')
|
||||||
|
|
||||||
|
|
||||||
upscaling_models = []
|
upscaling_models = []
|
||||||
for model in typing.get_args(ESRGAN_MODELS):
|
for model in typing.get_args(ESRGAN_MODELS):
|
||||||
upscaling_models.append(str(Path(model).stem))
|
upscaling_models.append(str(Path(model).stem))
|
||||||
upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models)
|
upscaler = Upscaler(
|
||||||
|
upscaling_method = 'esrgan',
|
||||||
|
upscaling_models = upscaling_models
|
||||||
|
)
|
||||||
|
|
||||||
nsfw_methods = []
|
nsfw_methods = []
|
||||||
if SafetyChecker.safety_checker_available():
|
if SafetyChecker.safety_checker_available():
|
||||||
nsfw_methods.append("nsfw_checker")
|
nsfw_methods.append('nsfw_checker')
|
||||||
|
|
||||||
watermarking_methods = []
|
watermarking_methods = []
|
||||||
if InvisibleWatermark.invisible_watermark_available():
|
if InvisibleWatermark.invisible_watermark_available():
|
||||||
watermarking_methods.append("invisible_watermark")
|
watermarking_methods.append('invisible_watermark')
|
||||||
|
|
||||||
return AppConfig(
|
return AppConfig(
|
||||||
infill_methods=infill_methods,
|
infill_methods=infill_methods,
|
||||||
@@ -79,26 +84,25 @@ async def get_config() -> AppConfig:
|
|||||||
watermarking_methods=watermarking_methods,
|
watermarking_methods=watermarking_methods,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app_router.get(
|
@app_router.get(
|
||||||
"/logging",
|
"/logging",
|
||||||
operation_id="get_log_level",
|
operation_id="get_log_level",
|
||||||
responses={200: {"description": "The operation was successful"}},
|
responses={200: {"description" : "The operation was successful"}},
|
||||||
response_model=LogLevel,
|
response_model = LogLevel,
|
||||||
)
|
)
|
||||||
async def get_log_level() -> LogLevel:
|
async def get_log_level(
|
||||||
|
) -> LogLevel:
|
||||||
"""Returns the log level"""
|
"""Returns the log level"""
|
||||||
return LogLevel(ApiDependencies.invoker.services.logger.level)
|
return LogLevel(ApiDependencies.invoker.services.logger.level)
|
||||||
|
|
||||||
|
|
||||||
@app_router.post(
|
@app_router.post(
|
||||||
"/logging",
|
"/logging",
|
||||||
operation_id="set_log_level",
|
operation_id="set_log_level",
|
||||||
responses={200: {"description": "The operation was successful"}},
|
responses={200: {"description" : "The operation was successful"}},
|
||||||
response_model=LogLevel,
|
response_model = LogLevel,
|
||||||
)
|
)
|
||||||
async def set_log_level(
|
async def set_log_level(
|
||||||
level: LogLevel = Body(description="New log verbosity level"),
|
level: LogLevel = Body(description="New log verbosity level"),
|
||||||
) -> LogLevel:
|
) -> LogLevel:
|
||||||
"""Sets the log verbosity level"""
|
"""Sets the log verbosity level"""
|
||||||
ApiDependencies.invoker.services.logger.setLevel(level)
|
ApiDependencies.invoker.services.logger.setLevel(level)
|
||||||
|
|||||||
@@ -52,3 +52,4 @@ async def remove_board_image(
|
|||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,9 @@ class DeleteBoardResult(BaseModel):
|
|||||||
deleted_board_images: list[str] = Field(
|
deleted_board_images: list[str] = Field(
|
||||||
description="The image names of the board-images relationships that were deleted."
|
description="The image names of the board-images relationships that were deleted."
|
||||||
)
|
)
|
||||||
deleted_images: list[str] = Field(description="The names of the images that were deleted.")
|
deleted_images: list[str] = Field(
|
||||||
|
description="The names of the images that were deleted."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@boards_router.post(
|
@boards_router.post(
|
||||||
@@ -71,16 +73,22 @@ async def update_board(
|
|||||||
) -> BoardDTO:
|
) -> BoardDTO:
|
||||||
"""Updates a board"""
|
"""Updates a board"""
|
||||||
try:
|
try:
|
||||||
result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
|
result = ApiDependencies.invoker.services.boards.update(
|
||||||
|
board_id=board_id, changes=changes
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||||
|
|
||||||
|
|
||||||
@boards_router.delete("/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult)
|
@boards_router.delete(
|
||||||
|
"/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult
|
||||||
|
)
|
||||||
async def delete_board(
|
async def delete_board(
|
||||||
board_id: str = Path(description="The id of board to delete"),
|
board_id: str = Path(description="The id of board to delete"),
|
||||||
include_images: Optional[bool] = Query(description="Permanently delete all images on the board", default=False),
|
include_images: Optional[bool] = Query(
|
||||||
|
description="Permanently delete all images on the board", default=False
|
||||||
|
),
|
||||||
) -> DeleteBoardResult:
|
) -> DeleteBoardResult:
|
||||||
"""Deletes a board"""
|
"""Deletes a board"""
|
||||||
try:
|
try:
|
||||||
@@ -88,7 +96,9 @@ async def delete_board(
|
|||||||
deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||||
board_id=board_id
|
board_id=board_id
|
||||||
)
|
)
|
||||||
ApiDependencies.invoker.services.images.delete_images_on_board(board_id=board_id)
|
ApiDependencies.invoker.services.images.delete_images_on_board(
|
||||||
|
board_id=board_id
|
||||||
|
)
|
||||||
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
|
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
|
||||||
return DeleteBoardResult(
|
return DeleteBoardResult(
|
||||||
board_id=board_id,
|
board_id=board_id,
|
||||||
@@ -117,7 +127,9 @@ async def delete_board(
|
|||||||
async def list_boards(
|
async def list_boards(
|
||||||
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
|
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
|
||||||
offset: Optional[int] = Query(default=None, description="The page offset"),
|
offset: Optional[int] = Query(default=None, description="The page offset"),
|
||||||
limit: Optional[int] = Query(default=None, description="The number of boards per page"),
|
limit: Optional[int] = Query(
|
||||||
|
default=None, description="The number of boards per page"
|
||||||
|
),
|
||||||
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
|
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
|
||||||
"""Gets a list of boards"""
|
"""Gets a list of boards"""
|
||||||
if all:
|
if all:
|
||||||
|
|||||||
@@ -40,9 +40,15 @@ async def upload_image(
|
|||||||
response: Response,
|
response: Response,
|
||||||
image_category: ImageCategory = Query(description="The category of the image"),
|
image_category: ImageCategory = Query(description="The category of the image"),
|
||||||
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
|
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
|
||||||
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
|
board_id: Optional[str] = Query(
|
||||||
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
|
default=None, description="The board to add this image to, if any"
|
||||||
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
|
),
|
||||||
|
session_id: Optional[str] = Query(
|
||||||
|
default=None, description="The session ID associated with this upload, if any"
|
||||||
|
),
|
||||||
|
crop_visible: Optional[bool] = Query(
|
||||||
|
default=False, description="Whether to crop the image"
|
||||||
|
),
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Uploads an image"""
|
"""Uploads an image"""
|
||||||
if not file.content_type.startswith("image"):
|
if not file.content_type.startswith("image"):
|
||||||
@@ -109,7 +115,9 @@ async def clear_intermediates() -> int:
|
|||||||
)
|
)
|
||||||
async def update_image(
|
async def update_image(
|
||||||
image_name: str = Path(description="The name of the image to update"),
|
image_name: str = Path(description="The name of the image to update"),
|
||||||
image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"),
|
image_changes: ImageRecordChanges = Body(
|
||||||
|
description="The changes to apply to the image"
|
||||||
|
),
|
||||||
) -> ImageDTO:
|
) -> ImageDTO:
|
||||||
"""Updates an image"""
|
"""Updates an image"""
|
||||||
|
|
||||||
@@ -204,11 +212,15 @@ async def get_image_thumbnail(
|
|||||||
"""Gets a thumbnail image file"""
|
"""Gets a thumbnail image file"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True)
|
path = ApiDependencies.invoker.services.images.get_path(
|
||||||
|
image_name, thumbnail=True
|
||||||
|
)
|
||||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
|
response = FileResponse(
|
||||||
|
path, media_type="image/webp", content_disposition_type="inline"
|
||||||
|
)
|
||||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||||
return response
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -227,7 +239,9 @@ async def get_image_urls(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
|
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
|
||||||
thumbnail_url = ApiDependencies.invoker.services.images.get_url(image_name, thumbnail=True)
|
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
||||||
|
image_name, thumbnail=True
|
||||||
|
)
|
||||||
return ImageUrlsDTO(
|
return ImageUrlsDTO(
|
||||||
image_name=image_name,
|
image_name=image_name,
|
||||||
image_url=image_url,
|
image_url=image_url,
|
||||||
@@ -243,9 +257,15 @@ async def get_image_urls(
|
|||||||
response_model=OffsetPaginatedResults[ImageDTO],
|
response_model=OffsetPaginatedResults[ImageDTO],
|
||||||
)
|
)
|
||||||
async def list_image_dtos(
|
async def list_image_dtos(
|
||||||
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
|
image_origin: Optional[ResourceOrigin] = Query(
|
||||||
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
|
default=None, description="The origin of images to list."
|
||||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
|
),
|
||||||
|
categories: Optional[list[ImageCategory]] = Query(
|
||||||
|
default=None, description="The categories of image to include."
|
||||||
|
),
|
||||||
|
is_intermediate: Optional[bool] = Query(
|
||||||
|
default=None, description="Whether to list intermediate images."
|
||||||
|
),
|
||||||
board_id: Optional[str] = Query(
|
board_id: Optional[str] = Query(
|
||||||
default=None,
|
default=None,
|
||||||
description="The board id to filter by. Use 'none' to find images without a board.",
|
description="The board id to filter by. Use 'none' to find images without a board.",
|
||||||
|
|||||||
@@ -28,52 +28,49 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
|||||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_models",
|
operation_id="list_models",
|
||||||
responses={200: {"model": ModelsList}},
|
responses={200: {"model": ModelsList }},
|
||||||
)
|
)
|
||||||
async def list_models(
|
async def list_models(
|
||||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
"""Gets a list of models"""
|
"""Gets a list of models"""
|
||||||
if base_models and len(base_models) > 0:
|
if base_models and len(base_models)>0:
|
||||||
models_raw = list()
|
models_raw = list()
|
||||||
for base_model in base_models:
|
for base_model in base_models:
|
||||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||||
else:
|
else:
|
||||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
||||||
models = parse_obj_as(ModelsList, {"models": models_raw})
|
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
@models_router.patch(
|
@models_router.patch(
|
||||||
"/{base_model}/{model_type}/{model_name}",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="update_model",
|
operation_id="update_model",
|
||||||
responses={
|
responses={200: {"description" : "The model was updated successfully"},
|
||||||
200: {"description": "The model was updated successfully"},
|
400: {"description" : "Bad request"},
|
||||||
400: {"description": "Bad request"},
|
404: {"description" : "The model could not be found"},
|
||||||
404: {"description": "The model could not be found"},
|
409: {"description" : "There is already a model corresponding to the new name"},
|
||||||
409: {"description": "There is already a model corresponding to the new name"},
|
},
|
||||||
},
|
status_code = 200,
|
||||||
status_code=200,
|
response_model = UpdateModelResponse,
|
||||||
response_model=UpdateModelResponse,
|
|
||||||
)
|
)
|
||||||
async def update_model(
|
async def update_model(
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
model_type: ModelType = Path(description="The type of model"),
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
model_name: str = Path(description="model name"),
|
model_name: str = Path(description="model name"),
|
||||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||||
) -> UpdateModelResponse:
|
) -> UpdateModelResponse:
|
||||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
""" Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
|
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@@ -84,13 +81,13 @@ async def update_model(
|
|||||||
# rename operation requested
|
# rename operation requested
|
||||||
if info.model_name != model_name or info.base_model != base_model:
|
if info.model_name != model_name or info.base_model != base_model:
|
||||||
ApiDependencies.invoker.services.model_manager.rename_model(
|
ApiDependencies.invoker.services.model_manager.rename_model(
|
||||||
base_model=base_model,
|
base_model = base_model,
|
||||||
model_type=model_type,
|
model_type = model_type,
|
||||||
model_name=model_name,
|
model_name = model_name,
|
||||||
new_name=info.model_name,
|
new_name = info.model_name,
|
||||||
new_base=info.base_model,
|
new_base = info.base_model,
|
||||||
)
|
)
|
||||||
logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
|
logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')
|
||||||
# update information to support an update of attributes
|
# update information to support an update of attributes
|
||||||
model_name = info.model_name
|
model_name = info.model_name
|
||||||
base_model = info.base_model
|
base_model = info.base_model
|
||||||
@@ -99,13 +96,14 @@ async def update_model(
|
|||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
)
|
)
|
||||||
if new_info.get("path") != previous_info.get(
|
if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it
|
||||||
"path"
|
info.path = new_info.get('path')
|
||||||
): # model manager moved model path during rename - don't overwrite it
|
|
||||||
info.path = new_info.get("path")
|
|
||||||
|
|
||||||
ApiDependencies.invoker.services.model_manager.update_model(
|
ApiDependencies.invoker.services.model_manager.update_model(
|
||||||
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict()
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
model_attributes=info.dict()
|
||||||
)
|
)
|
||||||
|
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
@@ -125,35 +123,34 @@ async def update_model(
|
|||||||
|
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
|
||||||
@models_router.post(
|
@models_router.post(
|
||||||
"/import",
|
"/import",
|
||||||
operation_id="import_model",
|
operation_id="import_model",
|
||||||
responses={
|
responses= {
|
||||||
201: {"description": "The model imported successfully"},
|
201: {"description" : "The model imported successfully"},
|
||||||
404: {"description": "The model could not be found"},
|
404: {"description" : "The model could not be found"},
|
||||||
415: {"description": "Unrecognized file/folder format"},
|
415: {"description" : "Unrecognized file/folder format"},
|
||||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
|
||||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
response_model=ImportModelResponse,
|
response_model=ImportModelResponse
|
||||||
)
|
)
|
||||||
async def import_model(
|
async def import_model(
|
||||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||||
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
|
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||||
description="Prediction type for SDv2 checkpoint files", default="v_prediction"
|
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||||
),
|
|
||||||
) -> ImportModelResponse:
|
) -> ImportModelResponse:
|
||||||
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
|
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
|
||||||
|
|
||||||
items_to_import = {location}
|
items_to_import = {location}
|
||||||
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||||
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
|
items_to_import = items_to_import,
|
||||||
|
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||||
)
|
)
|
||||||
info = installed_models.get(location)
|
info = installed_models.get(location)
|
||||||
|
|
||||||
@@ -161,9 +158,11 @@ async def import_model(
|
|||||||
logger.error("Import failed")
|
logger.error("Import failed")
|
||||||
raise HTTPException(status_code=415)
|
raise HTTPException(status_code=415)
|
||||||
|
|
||||||
logger.info(f"Successfully imported {location}, got {info}")
|
logger.info(f'Successfully imported {location}, got {info}')
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
model_name=info.name,
|
||||||
|
base_model=info.base_model,
|
||||||
|
model_type=info.model_type
|
||||||
)
|
)
|
||||||
return parse_obj_as(ImportModelResponse, model_raw)
|
return parse_obj_as(ImportModelResponse, model_raw)
|
||||||
|
|
||||||
@@ -177,33 +176,37 @@ async def import_model(
|
|||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@models_router.post(
|
@models_router.post(
|
||||||
"/add",
|
"/add",
|
||||||
operation_id="add_model",
|
operation_id="add_model",
|
||||||
responses={
|
responses= {
|
||||||
201: {"description": "The model added successfully"},
|
201: {"description" : "The model added successfully"},
|
||||||
404: {"description": "The model could not be found"},
|
404: {"description" : "The model could not be found"},
|
||||||
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
|
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
|
||||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
response_model=ImportModelResponse,
|
response_model=ImportModelResponse
|
||||||
)
|
)
|
||||||
async def add_model(
|
async def add_model(
|
||||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||||
) -> ImportModelResponse:
|
) -> ImportModelResponse:
|
||||||
"""Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||||
|
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ApiDependencies.invoker.services.model_manager.add_model(
|
ApiDependencies.invoker.services.model_manager.add_model(
|
||||||
info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
|
info.model_name,
|
||||||
|
info.base_model,
|
||||||
|
info.model_type,
|
||||||
|
model_attributes = info.dict()
|
||||||
)
|
)
|
||||||
logger.info(f"Successfully added {info.model_name}")
|
logger.info(f'Successfully added {info.model_name}')
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
|
model_name=info.model_name,
|
||||||
|
base_model=info.base_model,
|
||||||
|
model_type=info.model_type
|
||||||
)
|
)
|
||||||
return parse_obj_as(ImportModelResponse, model_raw)
|
return parse_obj_as(ImportModelResponse, model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
@@ -217,62 +220,62 @@ async def add_model(
|
|||||||
@models_router.delete(
|
@models_router.delete(
|
||||||
"/{base_model}/{model_type}/{model_name}",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="del_model",
|
operation_id="del_model",
|
||||||
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
|
responses={
|
||||||
status_code=204,
|
204: { "description": "Model deleted successfully" },
|
||||||
response_model=None,
|
404: { "description": "Model not found" }
|
||||||
|
},
|
||||||
|
status_code = 204,
|
||||||
|
response_model = None,
|
||||||
)
|
)
|
||||||
async def delete_model(
|
async def delete_model(
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
model_type: ModelType = Path(description="The type of model"),
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
model_name: str = Path(description="model name"),
|
model_name: str = Path(description="model name"),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""Delete Model"""
|
"""Delete Model"""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ApiDependencies.invoker.services.model_manager.del_model(
|
ApiDependencies.invoker.services.model_manager.del_model(model_name,
|
||||||
model_name, base_model=base_model, model_type=model_type
|
base_model = base_model,
|
||||||
)
|
model_type = model_type
|
||||||
|
)
|
||||||
logger.info(f"Deleted model: {model_name}")
|
logger.info(f"Deleted model: {model_name}")
|
||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@models_router.put(
|
@models_router.put(
|
||||||
"/convert/{base_model}/{model_type}/{model_name}",
|
"/convert/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="convert_model",
|
operation_id="convert_model",
|
||||||
responses={
|
responses={
|
||||||
200: {"description": "Model converted successfully"},
|
200: { "description": "Model converted successfully" },
|
||||||
400: {"description": "Bad request"},
|
400: {"description" : "Bad request" },
|
||||||
404: {"description": "Model not found"},
|
404: { "description": "Model not found" },
|
||||||
},
|
},
|
||||||
status_code=200,
|
status_code = 200,
|
||||||
response_model=ConvertModelResponse,
|
response_model = ConvertModelResponse,
|
||||||
)
|
)
|
||||||
async def convert_model(
|
async def convert_model(
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
model_type: ModelType = Path(description="The type of model"),
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
model_name: str = Path(description="model name"),
|
model_name: str = Path(description="model name"),
|
||||||
convert_dest_directory: Optional[str] = Query(
|
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
|
||||||
default=None, description="Save the converted model to the designated directory"
|
|
||||||
),
|
|
||||||
) -> ConvertModelResponse:
|
) -> ConvertModelResponse:
|
||||||
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
try:
|
try:
|
||||||
logger.info(f"Converting model: {model_name}")
|
logger.info(f"Converting model: {model_name}")
|
||||||
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
||||||
ApiDependencies.invoker.services.model_manager.convert_model(
|
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
||||||
model_name,
|
base_model = base_model,
|
||||||
base_model=base_model,
|
model_type = model_type,
|
||||||
model_type=model_type,
|
convert_dest_directory = dest,
|
||||||
convert_dest_directory=dest,
|
)
|
||||||
)
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
base_model = base_model,
|
||||||
model_name, base_model=base_model, model_type=model_type
|
model_type = model_type)
|
||||||
)
|
|
||||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||||
@@ -280,37 +283,34 @@ async def convert_model(
|
|||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
"/search",
|
"/search",
|
||||||
operation_id="search_for_models",
|
operation_id="search_for_models",
|
||||||
responses={
|
responses={
|
||||||
200: {"description": "Directory searched successfully"},
|
200: { "description": "Directory searched successfully" },
|
||||||
404: {"description": "Invalid directory path"},
|
404: { "description": "Invalid directory path" },
|
||||||
},
|
},
|
||||||
status_code=200,
|
status_code = 200,
|
||||||
response_model=List[pathlib.Path],
|
response_model = List[pathlib.Path]
|
||||||
)
|
)
|
||||||
async def search_for_models(
|
async def search_for_models(
|
||||||
search_path: pathlib.Path = Query(description="Directory path to search for models"),
|
search_path: pathlib.Path = Query(description="Directory path to search for models")
|
||||||
) -> List[pathlib.Path]:
|
)->List[pathlib.Path]:
|
||||||
if not search_path.is_dir():
|
if not search_path.is_dir():
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
|
||||||
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
|
|
||||||
)
|
|
||||||
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
"/ckpt_confs",
|
"/ckpt_confs",
|
||||||
operation_id="list_ckpt_configs",
|
operation_id="list_ckpt_configs",
|
||||||
responses={
|
responses={
|
||||||
200: {"description": "paths retrieved successfully"},
|
200: { "description" : "paths retrieved successfully" },
|
||||||
},
|
},
|
||||||
status_code=200,
|
status_code = 200,
|
||||||
response_model=List[pathlib.Path],
|
response_model = List[pathlib.Path]
|
||||||
)
|
)
|
||||||
async def list_ckpt_configs() -> List[pathlib.Path]:
|
async def list_ckpt_configs(
|
||||||
|
)->List[pathlib.Path]:
|
||||||
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
||||||
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
|
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
|
||||||
|
|
||||||
@@ -319,62 +319,55 @@ async def list_ckpt_configs() -> List[pathlib.Path]:
|
|||||||
"/sync",
|
"/sync",
|
||||||
operation_id="sync_to_config",
|
operation_id="sync_to_config",
|
||||||
responses={
|
responses={
|
||||||
201: {"description": "synchronization successful"},
|
201: { "description": "synchronization successful" },
|
||||||
},
|
},
|
||||||
status_code=201,
|
status_code = 201,
|
||||||
response_model=bool,
|
response_model = bool
|
||||||
)
|
)
|
||||||
async def sync_to_config() -> bool:
|
async def sync_to_config(
|
||||||
|
)->bool:
|
||||||
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||||
in-memory data structures with disk data structures."""
|
in-memory data structures with disk data structures."""
|
||||||
ApiDependencies.invoker.services.model_manager.sync_to_config()
|
ApiDependencies.invoker.services.model_manager.sync_to_config()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@models_router.put(
|
@models_router.put(
|
||||||
"/merge/{base_model}",
|
"/merge/{base_model}",
|
||||||
operation_id="merge_models",
|
operation_id="merge_models",
|
||||||
responses={
|
responses={
|
||||||
200: {"description": "Model converted successfully"},
|
200: { "description": "Model converted successfully" },
|
||||||
400: {"description": "Incompatible models"},
|
400: { "description": "Incompatible models" },
|
||||||
404: {"description": "One or more models not found"},
|
404: { "description": "One or more models not found" },
|
||||||
},
|
},
|
||||||
status_code=200,
|
status_code = 200,
|
||||||
response_model=MergeModelResponse,
|
response_model = MergeModelResponse,
|
||||||
)
|
)
|
||||||
async def merge_models(
|
async def merge_models(
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||||
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
||||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||||
force: Optional[bool] = Body(
|
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
|
||||||
description="Force merging of models created with different versions of diffusers", default=False
|
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
|
||||||
),
|
|
||||||
merge_dest_directory: Optional[str] = Body(
|
|
||||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
|
||||||
default=None,
|
|
||||||
),
|
|
||||||
) -> MergeModelResponse:
|
) -> MergeModelResponse:
|
||||||
"""Convert a checkpoint model into a diffusers model"""
|
"""Convert a checkpoint model into a diffusers model"""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
try:
|
try:
|
||||||
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||||
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
||||||
model_names,
|
base_model,
|
||||||
base_model,
|
merged_model_name=merged_model_name or "+".join(model_names),
|
||||||
merged_model_name=merged_model_name or "+".join(model_names),
|
alpha=alpha,
|
||||||
alpha=alpha,
|
interp=interp,
|
||||||
interp=interp,
|
force=force,
|
||||||
force=force,
|
merge_dest_directory = dest
|
||||||
merge_dest_directory=dest,
|
)
|
||||||
)
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
base_model = base_model,
|
||||||
result.name,
|
model_type = ModelType.Main,
|
||||||
base_model=base_model,
|
)
|
||||||
model_type=ModelType.Main,
|
|
||||||
)
|
|
||||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||||
|
|||||||
@@ -30,7 +30,9 @@ session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def create_session(
|
async def create_session(
|
||||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
|
graph: Optional[Graph] = Body(
|
||||||
|
default=None, description="The graph to initialize the session with"
|
||||||
|
)
|
||||||
) -> GraphExecutionState:
|
) -> GraphExecutionState:
|
||||||
"""Creates a new session, optionally initializing it with an invocation graph"""
|
"""Creates a new session, optionally initializing it with an invocation graph"""
|
||||||
session = ApiDependencies.invoker.create_execution_state(graph)
|
session = ApiDependencies.invoker.create_execution_state(graph)
|
||||||
@@ -49,9 +51,13 @@ async def list_sessions(
|
|||||||
) -> PaginatedResults[GraphExecutionState]:
|
) -> PaginatedResults[GraphExecutionState]:
|
||||||
"""Gets a list of sessions, optionally searching"""
|
"""Gets a list of sessions, optionally searching"""
|
||||||
if query == "":
|
if query == "":
|
||||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
|
result = ApiDependencies.invoker.services.graph_execution_manager.list(
|
||||||
|
page, per_page
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
|
result = ApiDependencies.invoker.services.graph_execution_manager.search(
|
||||||
|
query, page, per_page
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -85,9 +91,9 @@ async def get_session(
|
|||||||
)
|
)
|
||||||
async def add_node(
|
async def add_node(
|
||||||
session_id: str = Path(description="The id of the session"),
|
session_id: str = Path(description="The id of the session"),
|
||||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
node: Annotated[
|
||||||
description="The node to add"
|
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||||
),
|
] = Body(description="The node to add"),
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Adds a node to the graph"""
|
"""Adds a node to the graph"""
|
||||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||||
@@ -118,9 +124,9 @@ async def add_node(
|
|||||||
async def update_node(
|
async def update_node(
|
||||||
session_id: str = Path(description="The id of the session"),
|
session_id: str = Path(description="The id of the session"),
|
||||||
node_path: str = Path(description="The path to the node in the graph"),
|
node_path: str = Path(description="The path to the node in the graph"),
|
||||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
node: Annotated[
|
||||||
description="The new node"
|
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||||
),
|
] = Body(description="The new node"),
|
||||||
) -> GraphExecutionState:
|
) -> GraphExecutionState:
|
||||||
"""Updates a node in the graph and removes all linked edges"""
|
"""Updates a node in the graph and removes all linked edges"""
|
||||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||||
@@ -224,7 +230,7 @@ async def delete_edge(
|
|||||||
try:
|
try:
|
||||||
edge = Edge(
|
edge = Edge(
|
||||||
source=EdgeConnection(node_id=from_node_id, field=from_field),
|
source=EdgeConnection(node_id=from_node_id, field=from_field),
|
||||||
destination=EdgeConnection(node_id=to_node_id, field=to_field),
|
destination=EdgeConnection(node_id=to_node_id, field=to_field)
|
||||||
)
|
)
|
||||||
session.delete_edge(edge)
|
session.delete_edge(edge)
|
||||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||||
@@ -249,7 +255,9 @@ async def delete_edge(
|
|||||||
)
|
)
|
||||||
async def invoke_session(
|
async def invoke_session(
|
||||||
session_id: str = Path(description="The id of the session to invoke"),
|
session_id: str = Path(description="The id of the session to invoke"),
|
||||||
all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
|
all: bool = Query(
|
||||||
|
default=False, description="Whether or not to invoke all remaining invocations"
|
||||||
|
),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""Invokes a session"""
|
"""Invokes a session"""
|
||||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||||
@@ -266,7 +274,9 @@ async def invoke_session(
|
|||||||
@session_router.delete(
|
@session_router.delete(
|
||||||
"/{session_id}/invoke",
|
"/{session_id}/invoke",
|
||||||
operation_id="cancel_session_invoke",
|
operation_id="cancel_session_invoke",
|
||||||
responses={202: {"description": "The invocation is canceled"}},
|
responses={
|
||||||
|
202: {"description": "The invocation is canceled"}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
async def cancel_session_invoke(
|
async def cancel_session_invoke(
|
||||||
session_id: str = Path(description="The id of the session to cancel"),
|
session_id: str = Path(description="The id of the session to cancel"),
|
||||||
|
|||||||
@@ -16,7 +16,9 @@ class SocketIO:
|
|||||||
self.__sio.on("subscribe", handler=self._handle_sub)
|
self.__sio.on("subscribe", handler=self._handle_sub)
|
||||||
self.__sio.on("unsubscribe", handler=self._handle_unsub)
|
self.__sio.on("unsubscribe", handler=self._handle_unsub)
|
||||||
|
|
||||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
|
local_handler.register(
|
||||||
|
event_name=EventServiceBase.session_event, _func=self._handle_session_event
|
||||||
|
)
|
||||||
|
|
||||||
async def _handle_session_event(self, event: Event):
|
async def _handle_session_event(self, event: Event):
|
||||||
await self.__sio.emit(
|
await self.__sio.emit(
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import asyncio
|
|||||||
import sys
|
import sys
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
|
|
||||||
import logging
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
@@ -17,10 +16,9 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pydantic.schema import schema
|
from pydantic.schema import schema
|
||||||
|
|
||||||
# This should come early so that modules can log their initialization properly
|
#This should come early so that modules can log their initialization properly
|
||||||
from .services.config import InvokeAIAppConfig
|
from .services.config import InvokeAIAppConfig
|
||||||
from ..backend.util.logging import InvokeAILogger
|
from ..backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
app_config = InvokeAIAppConfig.get_config()
|
app_config = InvokeAIAppConfig.get_config()
|
||||||
app_config.parse_args()
|
app_config.parse_args()
|
||||||
logger = InvokeAILogger.getLogger(config=app_config)
|
logger = InvokeAILogger.getLogger(config=app_config)
|
||||||
@@ -29,7 +27,7 @@ from invokeai.version.invokeai_version import __version__
|
|||||||
# we call this early so that the message appears before
|
# we call this early so that the message appears before
|
||||||
# other invokeai initialization messages
|
# other invokeai initialization messages
|
||||||
if app_config.version:
|
if app_config.version:
|
||||||
print(f"InvokeAI version {__version__}")
|
print(f'InvokeAI version {__version__}')
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
import invokeai.frontend.web as web_dir
|
import invokeai.frontend.web as web_dir
|
||||||
@@ -43,14 +41,13 @@ from .invocations.baseinvocation import BaseInvocation
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import invokeai.backend.util.hotfixes
|
import invokeai.backend.util.hotfixes
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
import invokeai.backend.util.mps_fixes
|
import invokeai.backend.util.mps_fixes
|
||||||
|
|
||||||
# fix for windows mimetypes registry entries being borked
|
# fix for windows mimetypes registry entries being borked
|
||||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||||
mimetypes.add_type("application/javascript", ".js")
|
mimetypes.add_type('application/javascript', '.js')
|
||||||
mimetypes.add_type("text/css", ".css")
|
mimetypes.add_type('text/css', '.css')
|
||||||
|
|
||||||
# Create the app
|
# Create the app
|
||||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||||
@@ -60,13 +57,14 @@ app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
|
|||||||
event_handler_id: int = id(app)
|
event_handler_id: int = id(app)
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
EventHandlerASGIMiddleware,
|
EventHandlerASGIMiddleware,
|
||||||
handlers=[local_handler], # TODO: consider doing this in services to support different configurations
|
handlers=[
|
||||||
|
local_handler
|
||||||
|
], # TODO: consider doing this in services to support different configurations
|
||||||
middleware_id=event_handler_id,
|
middleware_id=event_handler_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
socket_io = SocketIO(app)
|
socket_io = SocketIO(app)
|
||||||
|
|
||||||
|
|
||||||
# Add startup event to load dependencies
|
# Add startup event to load dependencies
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
@@ -78,7 +76,9 @@ async def startup_event():
|
|||||||
allow_headers=app_config.allow_headers,
|
allow_headers=app_config.allow_headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
ApiDependencies.initialize(
|
||||||
|
config=app_config, event_handler_id=event_handler_id, logger=logger
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Shut down threads
|
# Shut down threads
|
||||||
@@ -103,8 +103,7 @@ app.include_router(boards.boards_router, prefix="/api")
|
|||||||
|
|
||||||
app.include_router(board_images.board_images_router, prefix="/api")
|
app.include_router(board_images.board_images_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(app_info.app_router, prefix="/api")
|
app.include_router(app_info.app_router, prefix='/api')
|
||||||
|
|
||||||
|
|
||||||
# Build a custom OpenAPI to include all outputs
|
# Build a custom OpenAPI to include all outputs
|
||||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||||
@@ -145,7 +144,6 @@ def custom_openapi():
|
|||||||
invoker_schema["output"] = outputs_ref
|
invoker_schema["output"] = outputs_ref
|
||||||
|
|
||||||
from invokeai.backend.model_management.models import get_model_config_enums
|
from invokeai.backend.model_management.models import get_model_config_enums
|
||||||
|
|
||||||
for model_config_format_enum in set(get_model_config_enums()):
|
for model_config_format_enum in set(get_model_config_enums()):
|
||||||
name = model_config_format_enum.__qualname__
|
name = model_config_format_enum.__qualname__
|
||||||
|
|
||||||
@@ -168,8 +166,7 @@ def custom_openapi():
|
|||||||
app.openapi = custom_openapi
|
app.openapi = custom_openapi
|
||||||
|
|
||||||
# Override API doc favicons
|
# Override API doc favicons
|
||||||
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], "static/dream_web")), name="static")
|
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], 'static/dream_web')), name="static")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/docs", include_in_schema=False)
|
@app.get("/docs", include_in_schema=False)
|
||||||
def overridden_swagger():
|
def overridden_swagger():
|
||||||
@@ -190,8 +187,11 @@ def overridden_redoc():
|
|||||||
|
|
||||||
|
|
||||||
# Must mount *after* the other routes else it borks em
|
# Must mount *after* the other routes else it borks em
|
||||||
app.mount("/", StaticFiles(directory=Path(web_dir.__path__[0], "dist"), html=True), name="ui")
|
app.mount("/",
|
||||||
|
StaticFiles(directory=Path(web_dir.__path__[0],"dist"),
|
||||||
|
html=True
|
||||||
|
), name="ui"
|
||||||
|
)
|
||||||
|
|
||||||
def invoke_api():
|
def invoke_api():
|
||||||
def find_port(port: int):
|
def find_port(port: int):
|
||||||
@@ -205,33 +205,17 @@ def invoke_api():
|
|||||||
return port
|
return port
|
||||||
|
|
||||||
from invokeai.backend.install.check_root import check_invokeai_root
|
from invokeai.backend.install.check_root import check_invokeai_root
|
||||||
|
|
||||||
check_invokeai_root(app_config) # note, may exit with an exception if root not set up
|
check_invokeai_root(app_config) # note, may exit with an exception if root not set up
|
||||||
|
|
||||||
port = find_port(app_config.port)
|
port = find_port(app_config.port)
|
||||||
if port != app_config.port:
|
if port != app_config.port:
|
||||||
logger.warn(f"Port {app_config.port} in use, using port {port}")
|
logger.warn(f"Port {app_config.port} in use, using port {port}")
|
||||||
|
|
||||||
# Start our own event loop for eventing usage
|
# Start our own event loop for eventing usage
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
config = uvicorn.Config(
|
config = uvicorn.Config(app=app, host=app_config.host, port=port, loop=loop)
|
||||||
app=app,
|
# Use access_log to turn off logging
|
||||||
host=app_config.host,
|
|
||||||
port=port,
|
|
||||||
loop=loop,
|
|
||||||
log_level=app_config.log_level,
|
|
||||||
)
|
|
||||||
server = uvicorn.Server(config)
|
server = uvicorn.Server(config)
|
||||||
|
|
||||||
# replace uvicorn's loggers with InvokeAI's for consistent appearance
|
|
||||||
for logname in ["uvicorn.access", "uvicorn"]:
|
|
||||||
l = logging.getLogger(logname)
|
|
||||||
l.handlers.clear()
|
|
||||||
for ch in logger.handlers:
|
|
||||||
l.addHandler(ch)
|
|
||||||
|
|
||||||
loop.run_until_complete(server.serve())
|
loop.run_until_complete(server.serve())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
invoke_api()
|
invoke_api()
|
||||||
|
|||||||
@@ -14,14 +14,8 @@ from ..services.graph import GraphExecutionState, LibraryGraph, Edge
|
|||||||
from ..services.invoker import Invoker
|
from ..services.invoker import Invoker
|
||||||
|
|
||||||
|
|
||||||
def add_field_argument(command_parser, name: str, field, default_override=None):
|
def add_field_argument(command_parser, name: str, field, default_override = None):
|
||||||
default = (
|
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||||
default_override
|
|
||||||
if default_override is not None
|
|
||||||
else field.default
|
|
||||||
if field.default_factory is None
|
|
||||||
else field.default_factory()
|
|
||||||
)
|
|
||||||
if get_origin(field.type_) == Literal:
|
if get_origin(field.type_) == Literal:
|
||||||
allowed_values = get_args(field.type_)
|
allowed_values = get_args(field.type_)
|
||||||
allowed_types = set()
|
allowed_types = set()
|
||||||
@@ -53,8 +47,8 @@ def add_parsers(
|
|||||||
commands: list[type],
|
commands: list[type],
|
||||||
command_field: str = "type",
|
command_field: str = "type",
|
||||||
exclude_fields: list[str] = ["id", "type"],
|
exclude_fields: list[str] = ["id", "type"],
|
||||||
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None,
|
add_arguments: Union[Callable[[argparse.ArgumentParser], None],None] = None
|
||||||
):
|
):
|
||||||
"""Adds parsers for each command to the subparsers"""
|
"""Adds parsers for each command to the subparsers"""
|
||||||
|
|
||||||
# Create subparsers for each command
|
# Create subparsers for each command
|
||||||
@@ -67,7 +61,7 @@ def add_parsers(
|
|||||||
add_arguments(command_parser)
|
add_arguments(command_parser)
|
||||||
|
|
||||||
# Convert all fields to arguments
|
# Convert all fields to arguments
|
||||||
fields = command.__fields__ # type: ignore
|
fields = command.__fields__ # type: ignore
|
||||||
for name, field in fields.items():
|
for name, field in fields.items():
|
||||||
if name in exclude_fields:
|
if name in exclude_fields:
|
||||||
continue
|
continue
|
||||||
@@ -76,7 +70,9 @@ def add_parsers(
|
|||||||
|
|
||||||
|
|
||||||
def add_graph_parsers(
|
def add_graph_parsers(
|
||||||
subparsers, graphs: list[LibraryGraph], add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
|
subparsers,
|
||||||
|
graphs: list[LibraryGraph],
|
||||||
|
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
|
||||||
):
|
):
|
||||||
for graph in graphs:
|
for graph in graphs:
|
||||||
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
command_parser = subparsers.add_parser(graph.name, help=graph.description)
|
||||||
@@ -132,7 +128,6 @@ class CliContext:
|
|||||||
|
|
||||||
class ExitCli(Exception):
|
class ExitCli(Exception):
|
||||||
"""Exception to exit the CLI"""
|
"""Exception to exit the CLI"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -160,7 +155,7 @@ class BaseCommand(ABC, BaseModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_commands_map(cls):
|
def get_commands_map(cls):
|
||||||
# Get the type strings out of the literals and into a dictionary
|
# Get the type strings out of the literals and into a dictionary
|
||||||
return dict(map(lambda t: (get_args(get_type_hints(t)["type"])[0], t), BaseCommand.get_all_subclasses()))
|
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseCommand.get_all_subclasses()))
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def run(self, context: CliContext) -> None:
|
def run(self, context: CliContext) -> None:
|
||||||
@@ -170,8 +165,7 @@ class BaseCommand(ABC, BaseModel):
|
|||||||
|
|
||||||
class ExitCommand(BaseCommand):
|
class ExitCommand(BaseCommand):
|
||||||
"""Exits the CLI"""
|
"""Exits the CLI"""
|
||||||
|
type: Literal['exit'] = 'exit'
|
||||||
type: Literal["exit"] = "exit"
|
|
||||||
|
|
||||||
def run(self, context: CliContext) -> None:
|
def run(self, context: CliContext) -> None:
|
||||||
raise ExitCli()
|
raise ExitCli()
|
||||||
@@ -179,8 +173,7 @@ class ExitCommand(BaseCommand):
|
|||||||
|
|
||||||
class HelpCommand(BaseCommand):
|
class HelpCommand(BaseCommand):
|
||||||
"""Shows help"""
|
"""Shows help"""
|
||||||
|
type: Literal['help'] = 'help'
|
||||||
type: Literal["help"] = "help"
|
|
||||||
|
|
||||||
def run(self, context: CliContext) -> None:
|
def run(self, context: CliContext) -> None:
|
||||||
context.parser.print_help()
|
context.parser.print_help()
|
||||||
@@ -190,7 +183,11 @@ def get_graph_execution_history(
|
|||||||
graph_execution_state: GraphExecutionState,
|
graph_execution_state: GraphExecutionState,
|
||||||
) -> Iterable[str]:
|
) -> Iterable[str]:
|
||||||
"""Gets the history of fully-executed invocations for a graph execution"""
|
"""Gets the history of fully-executed invocations for a graph execution"""
|
||||||
return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes)
|
return (
|
||||||
|
n
|
||||||
|
for n in reversed(graph_execution_state.executed_history)
|
||||||
|
if n in graph_execution_state.graph.nodes
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_invocation_command(invocation) -> str:
|
def get_invocation_command(invocation) -> str:
|
||||||
@@ -221,8 +218,7 @@ def get_invocation_command(invocation) -> str:
|
|||||||
|
|
||||||
class HistoryCommand(BaseCommand):
|
class HistoryCommand(BaseCommand):
|
||||||
"""Shows the invocation history"""
|
"""Shows the invocation history"""
|
||||||
|
type: Literal['history'] = 'history'
|
||||||
type: Literal["history"] = "history"
|
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@@ -239,8 +235,7 @@ class HistoryCommand(BaseCommand):
|
|||||||
|
|
||||||
class SetDefaultCommand(BaseCommand):
|
class SetDefaultCommand(BaseCommand):
|
||||||
"""Sets a default value for a field"""
|
"""Sets a default value for a field"""
|
||||||
|
type: Literal['default'] = 'default'
|
||||||
type: Literal["default"] = "default"
|
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@@ -258,8 +253,7 @@ class SetDefaultCommand(BaseCommand):
|
|||||||
|
|
||||||
class DrawGraphCommand(BaseCommand):
|
class DrawGraphCommand(BaseCommand):
|
||||||
"""Debugs a graph"""
|
"""Debugs a graph"""
|
||||||
|
type: Literal['draw_graph'] = 'draw_graph'
|
||||||
type: Literal["draw_graph"] = "draw_graph"
|
|
||||||
|
|
||||||
def run(self, context: CliContext) -> None:
|
def run(self, context: CliContext) -> None:
|
||||||
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||||
@@ -277,8 +271,7 @@ class DrawGraphCommand(BaseCommand):
|
|||||||
|
|
||||||
class DrawExecutionGraphCommand(BaseCommand):
|
class DrawExecutionGraphCommand(BaseCommand):
|
||||||
"""Debugs an execution graph"""
|
"""Debugs an execution graph"""
|
||||||
|
type: Literal['draw_xgraph'] = 'draw_xgraph'
|
||||||
type: Literal["draw_xgraph"] = "draw_xgraph"
|
|
||||||
|
|
||||||
def run(self, context: CliContext) -> None:
|
def run(self, context: CliContext) -> None:
|
||||||
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
|
||||||
@@ -293,7 +286,6 @@ class DrawExecutionGraphCommand(BaseCommand):
|
|||||||
plt.axis("off")
|
plt.axis("off")
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
class SortedHelpFormatter(argparse.HelpFormatter):
|
class SortedHelpFormatter(argparse.HelpFormatter):
|
||||||
def _iter_indented_subactions(self, action):
|
def _iter_indented_subactions(self, action):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ from ..services.invocation_services import InvocationServices
|
|||||||
# singleton object, class variable
|
# singleton object, class variable
|
||||||
completer = None
|
completer = None
|
||||||
|
|
||||||
|
|
||||||
class Completer(object):
|
class Completer(object):
|
||||||
|
|
||||||
def __init__(self, model_manager: ModelManager):
|
def __init__(self, model_manager: ModelManager):
|
||||||
self.commands = self.get_commands()
|
self.commands = self.get_commands()
|
||||||
self.matches = None
|
self.matches = None
|
||||||
@@ -56,17 +56,17 @@ class Completer(object):
|
|||||||
return match
|
return match
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_commands(self) -> List[object]:
|
def get_commands(self)->List[object]:
|
||||||
"""
|
"""
|
||||||
Return a list of all the client commands and invocations.
|
Return a list of all the client commands and invocations.
|
||||||
"""
|
"""
|
||||||
return BaseCommand.get_commands() + BaseInvocation.get_invocations()
|
return BaseCommand.get_commands() + BaseInvocation.get_invocations()
|
||||||
|
|
||||||
def get_current_command(self, buffer: str) -> tuple[str, str]:
|
def get_current_command(self, buffer: str)->tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Parse the readline buffer to find the most recent command and its switch.
|
Parse the readline buffer to find the most recent command and its switch.
|
||||||
"""
|
"""
|
||||||
if len(buffer) == 0:
|
if len(buffer)==0:
|
||||||
return None, None
|
return None, None
|
||||||
tokens = shlex.split(buffer)
|
tokens = shlex.split(buffer)
|
||||||
command = None
|
command = None
|
||||||
@@ -78,11 +78,11 @@ class Completer(object):
|
|||||||
else:
|
else:
|
||||||
switch = t
|
switch = t
|
||||||
# don't try to autocomplete switches that are already complete
|
# don't try to autocomplete switches that are already complete
|
||||||
if switch and buffer.endswith(" "):
|
if switch and buffer.endswith(' '):
|
||||||
switch = None
|
switch=None
|
||||||
return command or "", switch or ""
|
return command or '', switch or ''
|
||||||
|
|
||||||
def parse_commands(self) -> Dict[str, List[str]]:
|
def parse_commands(self)->Dict[str, List[str]]:
|
||||||
"""
|
"""
|
||||||
Return a dict in which the keys are the command name
|
Return a dict in which the keys are the command name
|
||||||
and the values are the parameters the command takes.
|
and the values are the parameters the command takes.
|
||||||
@@ -90,11 +90,11 @@ class Completer(object):
|
|||||||
result = dict()
|
result = dict()
|
||||||
for command in self.commands:
|
for command in self.commands:
|
||||||
hints = get_type_hints(command)
|
hints = get_type_hints(command)
|
||||||
name = get_args(hints["type"])[0]
|
name = get_args(hints['type'])[0]
|
||||||
result.update({name: hints})
|
result.update({name:hints})
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_command_options(self, command: str, switch: str) -> List[str]:
|
def get_command_options(self, command: str, switch: str)->List[str]:
|
||||||
"""
|
"""
|
||||||
Return all the parameters that can be passed to the command as
|
Return all the parameters that can be passed to the command as
|
||||||
command-line switches. Returns None if the command is unrecognized.
|
command-line switches. Returns None if the command is unrecognized.
|
||||||
@@ -105,28 +105,25 @@ class Completer(object):
|
|||||||
|
|
||||||
# handle switches in the format "-foo=bar"
|
# handle switches in the format "-foo=bar"
|
||||||
argument = None
|
argument = None
|
||||||
if switch and "=" in switch:
|
if switch and '=' in switch:
|
||||||
switch, argument = switch.split("=")
|
switch, argument = switch.split('=')
|
||||||
|
|
||||||
parameter = switch.strip("-")
|
parameter = switch.strip('-')
|
||||||
if parameter in parsed_commands[command]:
|
if parameter in parsed_commands[command]:
|
||||||
if argument is None:
|
if argument is None:
|
||||||
return self.get_parameter_options(parameter, parsed_commands[command][parameter])
|
return self.get_parameter_options(parameter, parsed_commands[command][parameter])
|
||||||
else:
|
else:
|
||||||
return [
|
return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])]
|
||||||
f"--{parameter}={x}"
|
|
||||||
for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
return [f"--{x}" for x in parsed_commands[command].keys()]
|
return [f"--{x}" for x in parsed_commands[command].keys()]
|
||||||
|
|
||||||
def get_parameter_options(self, parameter: str, typehint) -> List[str]:
|
def get_parameter_options(self, parameter: str, typehint)->List[str]:
|
||||||
"""
|
"""
|
||||||
Given a parameter type (such as Literal), offers autocompletions.
|
Given a parameter type (such as Literal), offers autocompletions.
|
||||||
"""
|
"""
|
||||||
if get_origin(typehint) == Literal:
|
if get_origin(typehint) == Literal:
|
||||||
return get_args(typehint)
|
return get_args(typehint)
|
||||||
if parameter == "model":
|
if parameter == 'model':
|
||||||
return self.manager.model_names()
|
return self.manager.model_names()
|
||||||
|
|
||||||
def _pre_input_hook(self):
|
def _pre_input_hook(self):
|
||||||
@@ -135,7 +132,6 @@ class Completer(object):
|
|||||||
readline.redisplay()
|
readline.redisplay()
|
||||||
self.linebuffer = None
|
self.linebuffer = None
|
||||||
|
|
||||||
|
|
||||||
def set_autocompleter(services: InvocationServices) -> Completer:
|
def set_autocompleter(services: InvocationServices) -> Completer:
|
||||||
global completer
|
global completer
|
||||||
|
|
||||||
@@ -166,6 +162,8 @@ def set_autocompleter(services: InvocationServices) -> Completer:
|
|||||||
pass
|
pass
|
||||||
except OSError: # file likely corrupted
|
except OSError: # file likely corrupted
|
||||||
newname = f"{histfile}.old"
|
newname = f"{histfile}.old"
|
||||||
logger.error(f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}")
|
logger.error(
|
||||||
|
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
|
||||||
|
)
|
||||||
histfile.replace(Path(newname))
|
histfile.replace(Path(newname))
|
||||||
atexit.register(readline.write_history_file, histfile)
|
atexit.register(readline.write_history_file, histfile)
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from pydantic.fields import Field
|
|||||||
# This should come early so that the logger can pick up its configuration options
|
# This should come early so that the logger can pick up its configuration options
|
||||||
from .services.config import InvokeAIAppConfig
|
from .services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
config.parse_args()
|
config.parse_args()
|
||||||
logger = InvokeAILogger().getLogger(config=config)
|
logger = InvokeAILogger().getLogger(config=config)
|
||||||
@@ -21,7 +20,7 @@ from invokeai.version.invokeai_version import __version__
|
|||||||
|
|
||||||
# we call this early so that the message appears before other invokeai initialization messages
|
# we call this early so that the message appears before other invokeai initialization messages
|
||||||
if config.version:
|
if config.version:
|
||||||
print(f"InvokeAI version {__version__}")
|
print(f'InvokeAI version {__version__}')
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
from invokeai.app.services.board_image_record_storage import (
|
from invokeai.app.services.board_image_record_storage import (
|
||||||
@@ -37,21 +36,18 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
|||||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
from invokeai.app.services.resource_name import SimpleNameService
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
from invokeai.app.services.urls import LocalUrlService
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
|
from .services.default_graphs import (default_text_to_image_graph_id,
|
||||||
|
create_system_graphs)
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
|
|
||||||
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
|
from .cli.commands import (BaseCommand, CliContext, ExitCli,
|
||||||
|
SortedHelpFormatter, add_graph_parsers, add_parsers)
|
||||||
from .cli.completer import set_autocompleter
|
from .cli.completer import set_autocompleter
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
from .services.events import EventServiceBase
|
from .services.events import EventServiceBase
|
||||||
from .services.graph import (
|
from .services.graph import (Edge, EdgeConnection, GraphExecutionState,
|
||||||
Edge,
|
GraphInvocation, LibraryGraph,
|
||||||
EdgeConnection,
|
are_connection_types_compatible)
|
||||||
GraphExecutionState,
|
|
||||||
GraphInvocation,
|
|
||||||
LibraryGraph,
|
|
||||||
are_connection_types_compatible,
|
|
||||||
)
|
|
||||||
from .services.image_file_storage import DiskImageFileStorage
|
from .services.image_file_storage import DiskImageFileStorage
|
||||||
from .services.invocation_queue import MemoryInvocationQueue
|
from .services.invocation_queue import MemoryInvocationQueue
|
||||||
from .services.invocation_services import InvocationServices
|
from .services.invocation_services import InvocationServices
|
||||||
@@ -62,7 +58,6 @@ from .services.sqlite import SqliteItemStorage
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import invokeai.backend.util.hotfixes
|
import invokeai.backend.util.hotfixes
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
import invokeai.backend.util.mps_fixes
|
import invokeai.backend.util.mps_fixes
|
||||||
|
|
||||||
@@ -74,7 +69,6 @@ class CliCommand(BaseModel):
|
|||||||
class InvalidArgs(Exception):
|
class InvalidArgs(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def add_invocation_args(command_parser):
|
def add_invocation_args(command_parser):
|
||||||
# Add linking capability
|
# Add linking capability
|
||||||
command_parser.add_argument(
|
command_parser.add_argument(
|
||||||
@@ -119,7 +113,7 @@ def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
|
|||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
class NodeField:
|
class NodeField():
|
||||||
alias: str
|
alias: str
|
||||||
node_path: str
|
node_path: str
|
||||||
field: str
|
field: str
|
||||||
@@ -132,20 +126,15 @@ class NodeField:
|
|||||||
self.field_type = field_type
|
self.field_type = field_type
|
||||||
|
|
||||||
|
|
||||||
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str, NodeField]:
|
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
|
||||||
return {k: NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
|
||||||
|
|
||||||
|
|
||||||
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||||
"""Gets the node field for the specified field alias"""
|
"""Gets the node field for the specified field alias"""
|
||||||
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
|
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
|
||||||
node_type = type(graph.graph.get_node(exposed_input.node_path))
|
node_type = type(graph.graph.get_node(exposed_input.node_path))
|
||||||
return NodeField(
|
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
|
||||||
alias=exposed_input.alias,
|
|
||||||
node_path=f"{node_id}.{exposed_input.node_path}",
|
|
||||||
field=exposed_input.field,
|
|
||||||
field_type=get_type_hints(node_type)[exposed_input.field],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
|
||||||
@@ -153,12 +142,7 @@ def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -
|
|||||||
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
|
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
|
||||||
node_type = type(graph.graph.get_node(exposed_output.node_path))
|
node_type = type(graph.graph.get_node(exposed_output.node_path))
|
||||||
node_output_type = node_type.get_output_type()
|
node_output_type = node_type.get_output_type()
|
||||||
return NodeField(
|
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
|
||||||
alias=exposed_output.alias,
|
|
||||||
node_path=f"{node_id}.{exposed_output.node_path}",
|
|
||||||
field=exposed_output.field,
|
|
||||||
field_type=get_type_hints(node_output_type)[exposed_output.field],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
|
||||||
@@ -181,7 +165,9 @@ def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[st
|
|||||||
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
|
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
|
||||||
|
|
||||||
|
|
||||||
def generate_matching_edges(a: BaseInvocation, b: BaseInvocation, context: CliContext) -> list[Edge]:
|
def generate_matching_edges(
|
||||||
|
a: BaseInvocation, b: BaseInvocation, context: CliContext
|
||||||
|
) -> list[Edge]:
|
||||||
"""Generates all possible edges between two invocations"""
|
"""Generates all possible edges between two invocations"""
|
||||||
afields = get_node_outputs(a, context)
|
afields = get_node_outputs(a, context)
|
||||||
bfields = get_node_inputs(b, context)
|
bfields = get_node_inputs(b, context)
|
||||||
@@ -193,14 +179,12 @@ def generate_matching_edges(a: BaseInvocation, b: BaseInvocation, context: CliCo
|
|||||||
matching_fields = matching_fields.difference(invalid_fields)
|
matching_fields = matching_fields.difference(invalid_fields)
|
||||||
|
|
||||||
# Validate types
|
# Validate types
|
||||||
matching_fields = [
|
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
|
||||||
f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)
|
|
||||||
]
|
|
||||||
|
|
||||||
edges = [
|
edges = [
|
||||||
Edge(
|
Edge(
|
||||||
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
|
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
|
||||||
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field),
|
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
|
||||||
)
|
)
|
||||||
for alias in matching_fields
|
for alias in matching_fields
|
||||||
]
|
]
|
||||||
@@ -209,7 +193,6 @@ def generate_matching_edges(a: BaseInvocation, b: BaseInvocation, context: CliCo
|
|||||||
|
|
||||||
class SessionError(Exception):
|
class SessionError(Exception):
|
||||||
"""Raised when a session error has occurred"""
|
"""Raised when a session error has occurred"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -229,20 +212,19 @@ def invoke_all(context: CliContext):
|
|||||||
|
|
||||||
raise SessionError()
|
raise SessionError()
|
||||||
|
|
||||||
|
|
||||||
def invoke_cli():
|
def invoke_cli():
|
||||||
logger.info(f"InvokeAI version {__version__}")
|
logger.info(f'InvokeAI version {__version__}')
|
||||||
# get the optional list of invocations to execute on the command line
|
# get the optional list of invocations to execute on the command line
|
||||||
parser = config.get_parser()
|
parser = config.get_parser()
|
||||||
parser.add_argument("commands", nargs="*")
|
parser.add_argument('commands',nargs='*')
|
||||||
invocation_commands = parser.parse_args().commands
|
invocation_commands = parser.parse_args().commands
|
||||||
|
|
||||||
# get the optional file to read commands from.
|
# get the optional file to read commands from.
|
||||||
# Simplest is to use it for STDIN
|
# Simplest is to use it for STDIN
|
||||||
if infile := config.from_file:
|
if infile := config.from_file:
|
||||||
sys.stdin = open(infile, "r")
|
sys.stdin = open(infile,"r")
|
||||||
|
|
||||||
model_manager = ModelManagerService(config, logger)
|
model_manager = ModelManagerService(config,logger)
|
||||||
|
|
||||||
events = EventServiceBase()
|
events = EventServiceBase()
|
||||||
output_folder = config.output_path
|
output_folder = config.output_path
|
||||||
@@ -252,13 +234,13 @@ def invoke_cli():
|
|||||||
db_location = ":memory:"
|
db_location = ":memory:"
|
||||||
else:
|
else:
|
||||||
db_location = config.db_path
|
db_location = config.db_path
|
||||||
db_location.parent.mkdir(parents=True, exist_ok=True)
|
db_location.parent.mkdir(parents=True,exist_ok=True)
|
||||||
|
|
||||||
logger.info(f'InvokeAI database location is "{db_location}"')
|
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||||
|
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||||
filename=db_location, table_name="graph_executions"
|
filename=db_location, table_name="graph_executions"
|
||||||
)
|
)
|
||||||
|
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||||
@@ -303,18 +285,21 @@ def invoke_cli():
|
|||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=model_manager,
|
model_manager=model_manager,
|
||||||
events=events,
|
events=events,
|
||||||
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||||
images=images,
|
images=images,
|
||||||
boards=boards,
|
boards=boards,
|
||||||
board_images=board_images,
|
board_images=board_images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
graph_library=SqliteItemStorage[LibraryGraph](
|
||||||
|
filename=db_location, table_name="graphs"
|
||||||
|
),
|
||||||
graph_execution_manager=graph_execution_manager,
|
graph_execution_manager=graph_execution_manager,
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
logger=logger,
|
logger=logger,
|
||||||
configuration=config,
|
configuration=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
system_graphs = create_system_graphs(services.graph_library)
|
system_graphs = create_system_graphs(services.graph_library)
|
||||||
system_graph_names = set([g.name for g in system_graphs])
|
system_graph_names = set([g.name for g in system_graphs])
|
||||||
set_autocompleter(services)
|
set_autocompleter(services)
|
||||||
@@ -323,7 +308,7 @@ def invoke_cli():
|
|||||||
session: GraphExecutionState = invoker.create_execution_state()
|
session: GraphExecutionState = invoker.create_execution_state()
|
||||||
parser = get_command_parser(services)
|
parser = get_command_parser(services)
|
||||||
|
|
||||||
re_negid = re.compile("^-[0-9]+$")
|
re_negid = re.compile('^-[0-9]+$')
|
||||||
|
|
||||||
# Uncomment to print out previous sessions at startup
|
# Uncomment to print out previous sessions at startup
|
||||||
# print(services.session_manager.list())
|
# print(services.session_manager.list())
|
||||||
@@ -347,7 +332,7 @@ def invoke_cli():
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Refresh the state of the session
|
# Refresh the state of the session
|
||||||
# history = list(get_graph_execution_history(context.session))
|
#history = list(get_graph_execution_history(context.session))
|
||||||
history = list(reversed(context.nodes_added))
|
history = list(reversed(context.nodes_added))
|
||||||
|
|
||||||
# Split the command for piping
|
# Split the command for piping
|
||||||
@@ -368,17 +353,17 @@ def invoke_cli():
|
|||||||
args[field_name] = field_default
|
args[field_name] = field_default
|
||||||
|
|
||||||
# Parse invocation
|
# Parse invocation
|
||||||
command: CliCommand = None # type:ignore
|
command: CliCommand = None # type:ignore
|
||||||
system_graph: Optional[LibraryGraph] = None
|
system_graph: Optional[LibraryGraph] = None
|
||||||
if args["type"] in system_graph_names:
|
if args['type'] in system_graph_names:
|
||||||
system_graph = next(filter(lambda g: g.name == args["type"], system_graphs))
|
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
|
||||||
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
|
||||||
for exposed_input in system_graph.exposed_inputs:
|
for exposed_input in system_graph.exposed_inputs:
|
||||||
if exposed_input.alias in args:
|
if exposed_input.alias in args:
|
||||||
node = invocation.graph.get_node(exposed_input.node_path)
|
node = invocation.graph.get_node(exposed_input.node_path)
|
||||||
field = exposed_input.field
|
field = exposed_input.field
|
||||||
setattr(node, field, args[exposed_input.alias])
|
setattr(node, field, args[exposed_input.alias])
|
||||||
command = CliCommand(command=invocation)
|
command = CliCommand(command = invocation)
|
||||||
context.graph_nodes[invocation.id] = system_graph.id
|
context.graph_nodes[invocation.id] = system_graph.id
|
||||||
else:
|
else:
|
||||||
args["id"] = current_id
|
args["id"] = current_id
|
||||||
@@ -400,13 +385,17 @@ def invoke_cli():
|
|||||||
# Pipe previous command output (if there was a previous command)
|
# Pipe previous command output (if there was a previous command)
|
||||||
edges: list[Edge] = list()
|
edges: list[Edge] = list()
|
||||||
if len(history) > 0 or current_id != start_id:
|
if len(history) > 0 or current_id != start_id:
|
||||||
from_id = history[0] if current_id == start_id else str(current_id - 1)
|
from_id = (
|
||||||
|
history[0] if current_id == start_id else str(current_id - 1)
|
||||||
|
)
|
||||||
from_node = (
|
from_node = (
|
||||||
next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
|
next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
|
||||||
if current_id != start_id
|
if current_id != start_id
|
||||||
else context.session.graph.get_node(from_id)
|
else context.session.graph.get_node(from_id)
|
||||||
)
|
)
|
||||||
matching_edges = generate_matching_edges(from_node, command.command, context)
|
matching_edges = generate_matching_edges(
|
||||||
|
from_node, command.command, context
|
||||||
|
)
|
||||||
edges.extend(matching_edges)
|
edges.extend(matching_edges)
|
||||||
|
|
||||||
# Parse provided links
|
# Parse provided links
|
||||||
@@ -417,18 +406,16 @@ def invoke_cli():
|
|||||||
node_id = str(current_id + int(node_id))
|
node_id = str(current_id + int(node_id))
|
||||||
|
|
||||||
link_node = context.session.graph.get_node(node_id)
|
link_node = context.session.graph.get_node(node_id)
|
||||||
matching_edges = generate_matching_edges(link_node, command.command, context)
|
matching_edges = generate_matching_edges(
|
||||||
|
link_node, command.command, context
|
||||||
|
)
|
||||||
matching_destinations = [e.destination for e in matching_edges]
|
matching_destinations = [e.destination for e in matching_edges]
|
||||||
edges = [e for e in edges if e.destination not in matching_destinations]
|
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||||
edges.extend(matching_edges)
|
edges.extend(matching_edges)
|
||||||
|
|
||||||
if "link" in args and args["link"]:
|
if "link" in args and args["link"]:
|
||||||
for link in args["link"]:
|
for link in args["link"]:
|
||||||
edges = [
|
edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
|
||||||
e
|
|
||||||
for e in edges
|
|
||||||
if e.destination.node_id != command.command.id or e.destination.field != link[2]
|
|
||||||
]
|
|
||||||
|
|
||||||
node_id = link[0]
|
node_id = link[0]
|
||||||
if re_negid.match(node_id):
|
if re_negid.match(node_id):
|
||||||
@@ -441,7 +428,7 @@ def invoke_cli():
|
|||||||
edges.append(
|
edges.append(
|
||||||
Edge(
|
Edge(
|
||||||
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
|
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
|
||||||
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field),
|
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,5 +4,9 @@ __all__ = []
|
|||||||
|
|
||||||
dirname = os.path.dirname(os.path.abspath(__file__))
|
dirname = os.path.dirname(os.path.abspath(__file__))
|
||||||
for f in os.listdir(dirname):
|
for f in os.listdir(dirname):
|
||||||
if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py":
|
if (
|
||||||
|
f != "__init__.py"
|
||||||
|
and os.path.isfile("%s/%s" % (dirname, f))
|
||||||
|
and f[-3:] == ".py"
|
||||||
|
):
|
||||||
__all__.append(f[:-3])
|
__all__.append(f[:-3])
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, get_type_hints
|
from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args,
|
||||||
|
get_type_hints)
|
||||||
|
|
||||||
from pydantic import BaseConfig, BaseModel, Field
|
from pydantic import BaseConfig, BaseModel, Field
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ from pydantic import Field, validator
|
|||||||
from invokeai.app.models.image import ImageField
|
from invokeai.app.models.image import ImageField
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext, UIConfig
|
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||||
|
InvocationConfig, InvocationContext, UIConfig)
|
||||||
|
|
||||||
|
|
||||||
class IntCollectionOutput(BaseInvocationOutput):
|
class IntCollectionOutput(BaseInvocationOutput):
|
||||||
@@ -26,7 +27,8 @@ class FloatCollectionOutput(BaseInvocationOutput):
|
|||||||
type: Literal["float_collection"] = "float_collection"
|
type: Literal["float_collection"] = "float_collection"
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
collection: list[float] = Field(default=[], description="The float collection")
|
collection: list[float] = Field(
|
||||||
|
default=[], description="The float collection")
|
||||||
|
|
||||||
|
|
||||||
class ImageCollectionOutput(BaseInvocationOutput):
|
class ImageCollectionOutput(BaseInvocationOutput):
|
||||||
@@ -35,7 +37,8 @@ class ImageCollectionOutput(BaseInvocationOutput):
|
|||||||
type: Literal["image_collection"] = "image_collection"
|
type: Literal["image_collection"] = "image_collection"
|
||||||
|
|
||||||
# Outputs
|
# Outputs
|
||||||
collection: list[ImageField] = Field(default=[], description="The output images")
|
collection: list[ImageField] = Field(
|
||||||
|
default=[], description="The output images")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {"required": ["type", "collection"]}
|
schema_extra = {"required": ["type", "collection"]}
|
||||||
@@ -53,7 +56,10 @@ class RangeInvocation(BaseInvocation):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Range", "tags": ["range", "integer", "collection"]},
|
"ui": {
|
||||||
|
"title": "Range",
|
||||||
|
"tags": ["range", "integer", "collection"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@validator("stop")
|
@validator("stop")
|
||||||
@@ -63,7 +69,9 @@ class RangeInvocation(BaseInvocation):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||||
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
|
return IntCollectionOutput(
|
||||||
|
collection=list(range(self.start, self.stop, self.step))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RangeOfSizeInvocation(BaseInvocation):
|
class RangeOfSizeInvocation(BaseInvocation):
|
||||||
@@ -78,11 +86,18 @@ class RangeOfSizeInvocation(BaseInvocation):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]},
|
"ui": {
|
||||||
|
"title": "Sized Range",
|
||||||
|
"tags": ["range", "integer", "size", "collection"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||||
return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
|
return IntCollectionOutput(
|
||||||
|
collection=list(
|
||||||
|
range(
|
||||||
|
self.start, self.start + self.size,
|
||||||
|
self.step)))
|
||||||
|
|
||||||
|
|
||||||
class RandomRangeInvocation(BaseInvocation):
|
class RandomRangeInvocation(BaseInvocation):
|
||||||
@@ -92,7 +107,9 @@ class RandomRangeInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
low: int = Field(default=0, description="The inclusive low value")
|
low: int = Field(default=0, description="The inclusive low value")
|
||||||
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
|
high: int = Field(
|
||||||
|
default=np.iinfo(np.int32).max, description="The exclusive high value"
|
||||||
|
)
|
||||||
size: int = Field(default=1, description="The number of values to generate")
|
size: int = Field(default=1, description="The number of values to generate")
|
||||||
seed: int = Field(
|
seed: int = Field(
|
||||||
ge=0,
|
ge=0,
|
||||||
@@ -103,12 +120,19 @@ class RandomRangeInvocation(BaseInvocation):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]},
|
"ui": {
|
||||||
|
"title": "Random Range",
|
||||||
|
"tags": ["range", "integer", "random", "collection"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||||
rng = np.random.default_rng(self.seed)
|
rng = np.random.default_rng(self.seed)
|
||||||
return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
|
return IntCollectionOutput(
|
||||||
|
collection=list(
|
||||||
|
rng.integers(
|
||||||
|
low=self.low, high=self.high,
|
||||||
|
size=self.size)))
|
||||||
|
|
||||||
|
|
||||||
class ImageCollectionInvocation(BaseInvocation):
|
class ImageCollectionInvocation(BaseInvocation):
|
||||||
|
|||||||
@@ -1,73 +1,66 @@
|
|||||||
from typing import Literal, Optional, Union, List, Annotated
|
from typing import Literal, Optional, Union, List, Annotated
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
|
||||||
from .model import ClipField
|
|
||||||
|
|
||||||
from ...backend.util.devices import torch_dtype
|
|
||||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compel import Compel, ReturnedEmbeddingsType
|
from compel import Compel, ReturnedEmbeddingsType
|
||||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
from compel.prompt_parser import (Blend, Conjunction,
|
||||||
|
CrossAttentionControlSubstitute,
|
||||||
|
FlattenedPrompt, Fragment)
|
||||||
from ...backend.util.devices import torch_dtype
|
from ...backend.util.devices import torch_dtype
|
||||||
from ...backend.model_management import ModelType
|
from ...backend.model_management import ModelType
|
||||||
from ...backend.model_management.models import ModelNotFoundException
|
from ...backend.model_management.models import ModelNotFoundException
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||||
|
InvocationConfig, InvocationContext)
|
||||||
from .model import ClipField
|
from .model import ClipField
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
class ConditioningField(BaseModel):
|
class ConditioningField(BaseModel):
|
||||||
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
|
conditioning_name: Optional[str] = Field(
|
||||||
|
default=None, description="The name of conditioning data")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {"required": ["conditioning_name"]}
|
schema_extra = {"required": ["conditioning_name"]}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BasicConditioningInfo:
|
class BasicConditioningInfo:
|
||||||
# type: Literal["basic_conditioning"] = "basic_conditioning"
|
#type: Literal["basic_conditioning"] = "basic_conditioning"
|
||||||
embeds: torch.Tensor
|
embeds: torch.Tensor
|
||||||
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
|
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
|
||||||
# weight: float
|
# weight: float
|
||||||
# mode: ConditioningAlgo
|
# mode: ConditioningAlgo
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||||
# type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
|
#type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
|
||||||
pooled_embeds: torch.Tensor
|
pooled_embeds: torch.Tensor
|
||||||
add_time_ids: torch.Tensor
|
add_time_ids: torch.Tensor
|
||||||
|
|
||||||
|
ConditioningInfoType = Annotated[
|
||||||
ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]
|
Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||||
|
Field(discriminator="type")
|
||||||
|
]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConditioningFieldData:
|
class ConditioningFieldData:
|
||||||
conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
|
conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
|
||||||
# unconditioned: Optional[torch.Tensor]
|
#unconditioned: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
#class ConditioningAlgo(str, Enum):
|
||||||
# class ConditioningAlgo(str, Enum):
|
|
||||||
# Compose = "compose"
|
# Compose = "compose"
|
||||||
# ComposeEx = "compose_ex"
|
# ComposeEx = "compose_ex"
|
||||||
# PerpNeg = "perp_neg"
|
# PerpNeg = "perp_neg"
|
||||||
|
|
||||||
|
|
||||||
class CompelOutput(BaseInvocationOutput):
|
class CompelOutput(BaseInvocationOutput):
|
||||||
"""Compel parser output"""
|
"""Compel parser output"""
|
||||||
|
|
||||||
# fmt: off
|
#fmt: off
|
||||||
type: Literal["compel_output"] = "compel_output"
|
type: Literal["compel_output"] = "compel_output"
|
||||||
|
|
||||||
conditioning: ConditioningField = Field(default=None, description="Conditioning")
|
conditioning: ConditioningField = Field(default=None, description="Conditioning")
|
||||||
# fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
class CompelInvocation(BaseInvocation):
|
class CompelInvocation(BaseInvocation):
|
||||||
@@ -81,28 +74,33 @@ class CompelInvocation(BaseInvocation):
|
|||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
"ui": {
|
||||||
|
"title": "Prompt (Compel)",
|
||||||
|
"tags": ["prompt", "compel"],
|
||||||
|
"type_hints": {
|
||||||
|
"model": "model"
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
def invoke(self, context: InvocationContext) -> CompelOutput:
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**self.clip.tokenizer.dict(),
|
**self.clip.tokenizer.dict(), context=context,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**self.clip.text_encoder.dict(),
|
**self.clip.text_encoder.dict(), context=context,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.clip.loras:
|
for lora in self.clip.loras:
|
||||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
lora_info = context.services.model_manager.get_model(
|
||||||
|
**lora.dict(exclude={"weight"}), context=context)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
|
|
||||||
ti_list = []
|
ti_list = []
|
||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
||||||
@@ -118,18 +116,15 @@ class CompelInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
# import traceback
|
#import traceback
|
||||||
# print(traceback.format_exc())
|
#print(traceback.format_exc())
|
||||||
print(f'Warn: trigger: "{trigger}" not found')
|
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||||
|
|
||||||
|
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||||
|
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||||
|
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),\
|
||||||
|
text_encoder_info as text_encoder:
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_text_encoder(
|
|
||||||
text_encoder_info.context.model, _lora_loader()
|
|
||||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
|
||||||
tokenizer,
|
|
||||||
ti_manager,
|
|
||||||
), ModelPatcher.apply_clip_skip(
|
|
||||||
text_encoder_info.context.model, self.clip.skipped_layers
|
|
||||||
), text_encoder_info as text_encoder:
|
|
||||||
compel = Compel(
|
compel = Compel(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
@@ -144,12 +139,14 @@ class CompelInvocation(BaseInvocation):
|
|||||||
if context.services.configuration.log_tokenization:
|
if context.services.configuration.log_tokenization:
|
||||||
log_tokenization_for_prompt_object(prompt, tokenizer)
|
log_tokenization_for_prompt_object(prompt, tokenizer)
|
||||||
|
|
||||||
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
|
c, options = compel.build_conditioning_tensor_for_prompt_object(
|
||||||
|
prompt)
|
||||||
|
|
||||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
tokens_count_including_eos_bos=get_max_token_count(
|
||||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
tokenizer, conjunction),
|
||||||
)
|
cross_attention_control_args=options.get(
|
||||||
|
"cross_attention_control", None),)
|
||||||
|
|
||||||
c = c.detach().to("cpu")
|
c = c.detach().to("cpu")
|
||||||
|
|
||||||
@@ -171,26 +168,24 @@ class CompelInvocation(BaseInvocation):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDXLPromptInvocationBase:
|
class SDXLPromptInvocationBase:
|
||||||
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(),
|
**clip_field.tokenizer.dict(), context=context,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**clip_field.text_encoder.dict(),
|
**clip_field.text_encoder.dict(), context=context,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
lora_info = context.services.model_manager.get_model(
|
||||||
|
**lora.dict(exclude={"weight"}), context=context)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
|
|
||||||
ti_list = []
|
ti_list = []
|
||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||||
@@ -206,18 +201,15 @@ class SDXLPromptInvocationBase:
|
|||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
# import traceback
|
#import traceback
|
||||||
# print(traceback.format_exc())
|
#print(traceback.format_exc())
|
||||||
print(f'Warn: trigger: "{trigger}" not found')
|
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||||
|
|
||||||
|
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||||
|
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||||
|
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
|
||||||
|
text_encoder_info as text_encoder:
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_text_encoder(
|
|
||||||
text_encoder_info.context.model, _lora_loader()
|
|
||||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
|
||||||
tokenizer,
|
|
||||||
ti_manager,
|
|
||||||
), ModelPatcher.apply_clip_skip(
|
|
||||||
text_encoder_info.context.model, clip_field.skipped_layers
|
|
||||||
), text_encoder_info as text_encoder:
|
|
||||||
text_inputs = tokenizer(
|
text_inputs = tokenizer(
|
||||||
prompt,
|
prompt,
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
@@ -249,22 +241,21 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
tokenizer_info = context.services.model_manager.get_model(
|
||||||
**clip_field.tokenizer.dict(),
|
**clip_field.tokenizer.dict(), context=context,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
text_encoder_info = context.services.model_manager.get_model(
|
||||||
**clip_field.text_encoder.dict(),
|
**clip_field.text_encoder.dict(), context=context,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
|
lora_info = context.services.model_manager.get_model(
|
||||||
|
**lora.dict(exclude={"weight"}), context=context)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
|
|
||||||
ti_list = []
|
ti_list = []
|
||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||||
@@ -280,25 +271,22 @@ class SDXLPromptInvocationBase:
|
|||||||
)
|
)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
# print(e)
|
# print(e)
|
||||||
# import traceback
|
#import traceback
|
||||||
# print(traceback.format_exc())
|
#print(traceback.format_exc())
|
||||||
print(f'Warn: trigger: "{trigger}" not found')
|
print(f"Warn: trigger: \"{trigger}\" not found")
|
||||||
|
|
||||||
|
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
|
||||||
|
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
|
||||||
|
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
|
||||||
|
text_encoder_info as text_encoder:
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_text_encoder(
|
|
||||||
text_encoder_info.context.model, _lora_loader()
|
|
||||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
|
||||||
tokenizer,
|
|
||||||
ti_manager,
|
|
||||||
), ModelPatcher.apply_clip_skip(
|
|
||||||
text_encoder_info.context.model, clip_field.skipped_layers
|
|
||||||
), text_encoder_info as text_encoder:
|
|
||||||
compel = Compel(
|
compel = Compel(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
textual_inversion_manager=ti_manager,
|
textual_inversion_manager=ti_manager,
|
||||||
dtype_for_device_getter=torch_dtype,
|
dtype_for_device_getter=torch_dtype,
|
||||||
truncate_long_prompts=True, # TODO:
|
truncate_long_prompts=True, # TODO:
|
||||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
|
||||||
requires_pooled=True,
|
requires_pooled=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -332,7 +320,6 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
return c, c_pooled, ec
|
return c, c_pooled, ec
|
||||||
|
|
||||||
|
|
||||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
@@ -352,7 +339,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
"ui": {
|
||||||
|
"title": "SDXL Prompt (Compel)",
|
||||||
|
"tags": ["prompt", "compel"],
|
||||||
|
"type_hints": {
|
||||||
|
"model": "model"
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -367,7 +360,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
target_size = (self.target_height, self.target_width)
|
target_size = (self.target_height, self.target_width)
|
||||||
|
|
||||||
add_time_ids = torch.tensor([original_size + crop_coords + target_size])
|
add_time_ids = torch.tensor([
|
||||||
|
original_size + crop_coords + target_size
|
||||||
|
])
|
||||||
|
|
||||||
conditioning_data = ConditioningFieldData(
|
conditioning_data = ConditioningFieldData(
|
||||||
conditionings=[
|
conditionings=[
|
||||||
@@ -389,13 +384,12 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
|
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
|
||||||
|
|
||||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||||
original_width: int = Field(1024, description="")
|
original_width: int = Field(1024, description="")
|
||||||
original_height: int = Field(1024, description="")
|
original_height: int = Field(1024, description="")
|
||||||
crop_top: int = Field(0, description="")
|
crop_top: int = Field(0, description="")
|
||||||
@@ -409,7 +403,9 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
"ui": {
|
"ui": {
|
||||||
"title": "SDXL Refiner Prompt (Compel)",
|
"title": "SDXL Refiner Prompt (Compel)",
|
||||||
"tags": ["prompt", "compel"],
|
"tags": ["prompt", "compel"],
|
||||||
"type_hints": {"model": "model"},
|
"type_hints": {
|
||||||
|
"model": "model"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -420,7 +416,9 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
|
|
||||||
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
|
add_time_ids = torch.tensor([
|
||||||
|
original_size + crop_coords + (self.aesthetic_score,)
|
||||||
|
])
|
||||||
|
|
||||||
conditioning_data = ConditioningFieldData(
|
conditioning_data = ConditioningFieldData(
|
||||||
conditionings=[
|
conditionings=[
|
||||||
@@ -428,7 +426,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
embeds=c2,
|
embeds=c2,
|
||||||
pooled_embeds=c2_pooled,
|
pooled_embeds=c2_pooled,
|
||||||
add_time_ids=add_time_ids,
|
add_time_ids=add_time_ids,
|
||||||
extra_conditioning=ec2, # or None
|
extra_conditioning=ec2, # or None
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -442,7 +440,6 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Pass unmodified prompt to conditioning without compel processing."""
|
"""Pass unmodified prompt to conditioning without compel processing."""
|
||||||
|
|
||||||
@@ -462,7 +459,13 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "SDXL Prompt (Raw)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
|
"ui": {
|
||||||
|
"title": "SDXL Prompt (Raw)",
|
||||||
|
"tags": ["prompt", "compel"],
|
||||||
|
"type_hints": {
|
||||||
|
"model": "model"
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -477,7 +480,9 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
target_size = (self.target_height, self.target_width)
|
target_size = (self.target_height, self.target_width)
|
||||||
|
|
||||||
add_time_ids = torch.tensor([original_size + crop_coords + target_size])
|
add_time_ids = torch.tensor([
|
||||||
|
original_size + crop_coords + target_size
|
||||||
|
])
|
||||||
|
|
||||||
conditioning_data = ConditioningFieldData(
|
conditioning_data = ConditioningFieldData(
|
||||||
conditionings=[
|
conditionings=[
|
||||||
@@ -499,13 +504,12 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||||
"""Parse prompt using compel package to conditioning."""
|
"""Parse prompt using compel package to conditioning."""
|
||||||
|
|
||||||
type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
|
type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
|
||||||
|
|
||||||
style: str = Field(default="", description="Style prompt") # TODO: ?
|
style: str = Field(default="", description="Style prompt") # TODO: ?
|
||||||
original_width: int = Field(1024, description="")
|
original_width: int = Field(1024, description="")
|
||||||
original_height: int = Field(1024, description="")
|
original_height: int = Field(1024, description="")
|
||||||
crop_top: int = Field(0, description="")
|
crop_top: int = Field(0, description="")
|
||||||
@@ -519,7 +523,9 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
"ui": {
|
"ui": {
|
||||||
"title": "SDXL Refiner Prompt (Raw)",
|
"title": "SDXL Refiner Prompt (Raw)",
|
||||||
"tags": ["prompt", "compel"],
|
"tags": ["prompt", "compel"],
|
||||||
"type_hints": {"model": "model"},
|
"type_hints": {
|
||||||
|
"model": "model"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -530,7 +536,9 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
original_size = (self.original_height, self.original_width)
|
original_size = (self.original_height, self.original_width)
|
||||||
crop_coords = (self.crop_top, self.crop_left)
|
crop_coords = (self.crop_top, self.crop_left)
|
||||||
|
|
||||||
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
|
add_time_ids = torch.tensor([
|
||||||
|
original_size + crop_coords + (self.aesthetic_score,)
|
||||||
|
])
|
||||||
|
|
||||||
conditioning_data = ConditioningFieldData(
|
conditioning_data = ConditioningFieldData(
|
||||||
conditionings=[
|
conditionings=[
|
||||||
@@ -538,7 +546,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
embeds=c2,
|
embeds=c2,
|
||||||
pooled_embeds=c2_pooled,
|
pooled_embeds=c2_pooled,
|
||||||
add_time_ids=add_time_ids,
|
add_time_ids=add_time_ids,
|
||||||
extra_conditioning=ec2, # or None
|
extra_conditioning=ec2, # or None
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -555,14 +563,11 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
|
|
||||||
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
class ClipSkipInvocationOutput(BaseInvocationOutput):
|
||||||
"""Clip skip node output"""
|
"""Clip skip node output"""
|
||||||
|
|
||||||
type: Literal["clip_skip_output"] = "clip_skip_output"
|
type: Literal["clip_skip_output"] = "clip_skip_output"
|
||||||
clip: ClipField = Field(None, description="Clip with skipped layers")
|
clip: ClipField = Field(None, description="Clip with skipped layers")
|
||||||
|
|
||||||
|
|
||||||
class ClipSkipInvocation(BaseInvocation):
|
class ClipSkipInvocation(BaseInvocation):
|
||||||
"""Skip layers in clip text_encoder model."""
|
"""Skip layers in clip text_encoder model."""
|
||||||
|
|
||||||
type: Literal["clip_skip"] = "clip_skip"
|
type: Literal["clip_skip"] = "clip_skip"
|
||||||
|
|
||||||
clip: ClipField = Field(None, description="Clip to use")
|
clip: ClipField = Field(None, description="Clip to use")
|
||||||
@@ -570,7 +575,10 @@ class ClipSkipInvocation(BaseInvocation):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]},
|
"ui": {
|
||||||
|
"title": "CLIP Skip",
|
||||||
|
"tags": ["clip", "skip"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||||
@@ -581,26 +589,46 @@ class ClipSkipInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
def get_max_token_count(
|
def get_max_token_count(
|
||||||
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
|
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction],
|
||||||
) -> int:
|
truncate_if_too_long=False) -> int:
|
||||||
if type(prompt) is Blend:
|
if type(prompt) is Blend:
|
||||||
blend: Blend = prompt
|
blend: Blend = prompt
|
||||||
return max([get_max_token_count(tokenizer, p, truncate_if_too_long) for p in blend.prompts])
|
return max(
|
||||||
|
[
|
||||||
|
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||||
|
for p in blend.prompts
|
||||||
|
]
|
||||||
|
)
|
||||||
elif type(prompt) is Conjunction:
|
elif type(prompt) is Conjunction:
|
||||||
conjunction: Conjunction = prompt
|
conjunction: Conjunction = prompt
|
||||||
return sum([get_max_token_count(tokenizer, p, truncate_if_too_long) for p in conjunction.prompts])
|
return sum(
|
||||||
|
[
|
||||||
|
get_max_token_count(tokenizer, p, truncate_if_too_long)
|
||||||
|
for p in conjunction.prompts
|
||||||
|
]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
|
return len(
|
||||||
|
get_tokens_for_prompt_object(
|
||||||
|
tokenizer, prompt, truncate_if_too_long))
|
||||||
|
|
||||||
|
|
||||||
def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
|
def get_tokens_for_prompt_object(
|
||||||
|
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
|
||||||
|
) -> List[str]:
|
||||||
if type(parsed_prompt) is Blend:
|
if type(parsed_prompt) is Blend:
|
||||||
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
|
raise ValueError(
|
||||||
|
"Blend is not supported here - you need to get tokens for each of its .children"
|
||||||
|
)
|
||||||
|
|
||||||
text_fragments = [
|
text_fragments = [
|
||||||
x.text
|
x.text
|
||||||
if type(x) is Fragment
|
if type(x) is Fragment
|
||||||
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
|
else (
|
||||||
|
" ".join([f.text for f in x.original])
|
||||||
|
if type(x) is CrossAttentionControlSubstitute
|
||||||
|
else str(x)
|
||||||
|
)
|
||||||
for x in parsed_prompt.children
|
for x in parsed_prompt.children
|
||||||
]
|
]
|
||||||
text = " ".join(text_fragments)
|
text = " ".join(text_fragments)
|
||||||
@@ -611,17 +639,25 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
|
|||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
|
def log_tokenization_for_conjunction(
|
||||||
|
c: Conjunction, tokenizer, display_label_prefix=None
|
||||||
|
):
|
||||||
display_label_prefix = display_label_prefix or ""
|
display_label_prefix = display_label_prefix or ""
|
||||||
for i, p in enumerate(c.prompts):
|
for i, p in enumerate(c.prompts):
|
||||||
if len(c.prompts) > 1:
|
if len(c.prompts) > 1:
|
||||||
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
|
||||||
else:
|
else:
|
||||||
this_display_label_prefix = display_label_prefix
|
this_display_label_prefix = display_label_prefix
|
||||||
log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
|
log_tokenization_for_prompt_object(
|
||||||
|
p,
|
||||||
|
tokenizer,
|
||||||
|
display_label_prefix=this_display_label_prefix
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
|
def log_tokenization_for_prompt_object(
|
||||||
|
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
|
||||||
|
):
|
||||||
display_label_prefix = display_label_prefix or ""
|
display_label_prefix = display_label_prefix or ""
|
||||||
if type(p) is Blend:
|
if type(p) is Blend:
|
||||||
blend: Blend = p
|
blend: Blend = p
|
||||||
@@ -658,10 +694,13 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
text = " ".join([x.text for x in flattened_prompt.children])
|
text = " ".join([x.text for x in flattened_prompt.children])
|
||||||
log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
|
log_tokenization_for_text(
|
||||||
|
text, tokenizer, display_label=display_label_prefix
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
|
def log_tokenization_for_text(
|
||||||
|
text, tokenizer, display_label=None, truncate_if_too_long=False):
|
||||||
"""shows how the prompt is tokenized
|
"""shows how the prompt is tokenized
|
||||||
# usually tokens have '</w>' to indicate end-of-word,
|
# usually tokens have '</w>' to indicate end-of-word,
|
||||||
# but for readability it has been replaced with ' '
|
# but for readability it has been replaced with ' '
|
||||||
|
|||||||
@@ -6,29 +6,20 @@ from typing import Dict, List, Literal, Optional, Union
|
|||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from controlnet_aux import (
|
from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector,
|
||||||
CannyDetector,
|
LeresDetector, LineartAnimeDetector,
|
||||||
ContentShuffleDetector,
|
LineartDetector, MediapipeFaceDetector,
|
||||||
HEDdetector,
|
MidasDetector, MLSDdetector, NormalBaeDetector,
|
||||||
LeresDetector,
|
OpenposeDetector, PidiNetDetector, SamDetector,
|
||||||
LineartAnimeDetector,
|
ZoeDetector)
|
||||||
LineartDetector,
|
|
||||||
MediapipeFaceDetector,
|
|
||||||
MidasDetector,
|
|
||||||
MLSDdetector,
|
|
||||||
NormalBaeDetector,
|
|
||||||
OpenposeDetector,
|
|
||||||
PidiNetDetector,
|
|
||||||
SamDetector,
|
|
||||||
ZoeDetector,
|
|
||||||
)
|
|
||||||
from controlnet_aux.util import HWC3, ade_palette
|
from controlnet_aux.util import HWC3, ade_palette
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, Field, validator
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType
|
from ...backend.model_management import BaseModelType, ModelType
|
||||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||||
|
InvocationConfig, InvocationContext)
|
||||||
from ..models.image import ImageOutput, PILInvocationConfig
|
from ..models.image import ImageOutput, PILInvocationConfig
|
||||||
|
|
||||||
CONTROLNET_DEFAULT_MODELS = [
|
CONTROLNET_DEFAULT_MODELS = [
|
||||||
@@ -43,6 +34,7 @@ CONTROLNET_DEFAULT_MODELS = [
|
|||||||
"lllyasviel/sd-controlnet-scribble",
|
"lllyasviel/sd-controlnet-scribble",
|
||||||
"lllyasviel/sd-controlnet-normal",
|
"lllyasviel/sd-controlnet-normal",
|
||||||
"lllyasviel/sd-controlnet-mlsd",
|
"lllyasviel/sd-controlnet-mlsd",
|
||||||
|
|
||||||
#############################################
|
#############################################
|
||||||
# lllyasviel sd v1.5, ControlNet v1.1 models
|
# lllyasviel sd v1.5, ControlNet v1.1 models
|
||||||
#############################################
|
#############################################
|
||||||
@@ -64,6 +56,7 @@ CONTROLNET_DEFAULT_MODELS = [
|
|||||||
"lllyasviel/control_v11e_sd15_shuffle",
|
"lllyasviel/control_v11e_sd15_shuffle",
|
||||||
"lllyasviel/control_v11e_sd15_ip2p",
|
"lllyasviel/control_v11e_sd15_ip2p",
|
||||||
"lllyasviel/control_v11f1e_sd15_tile",
|
"lllyasviel/control_v11f1e_sd15_tile",
|
||||||
|
|
||||||
#################################################
|
#################################################
|
||||||
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
|
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
|
||||||
##################################################
|
##################################################
|
||||||
@@ -78,6 +71,7 @@ CONTROLNET_DEFAULT_MODELS = [
|
|||||||
"thibaud/controlnet-sd21-lineart-diffusers",
|
"thibaud/controlnet-sd21-lineart-diffusers",
|
||||||
"thibaud/controlnet-sd21-normalbae-diffusers",
|
"thibaud/controlnet-sd21-normalbae-diffusers",
|
||||||
"thibaud/controlnet-sd21-ade20k-diffusers",
|
"thibaud/controlnet-sd21-ade20k-diffusers",
|
||||||
|
|
||||||
##############################################
|
##############################################
|
||||||
# ControlNetMediaPipeface, ControlNet v1.1
|
# ControlNetMediaPipeface, ControlNet v1.1
|
||||||
##############################################
|
##############################################
|
||||||
@@ -89,17 +83,10 @@ CONTROLNET_DEFAULT_MODELS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
|
||||||
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
|
CONTROLNET_MODE_VALUES = Literal[tuple(
|
||||||
CONTROLNET_RESIZE_VALUES = Literal[
|
["balanced", "more_prompt", "more_control", "unbalanced"])]
|
||||||
tuple(
|
CONTROLNET_RESIZE_VALUES = Literal[tuple(
|
||||||
[
|
["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])]
|
||||||
"just_resize",
|
|
||||||
"crop_resize",
|
|
||||||
"fill_resize",
|
|
||||||
"just_resize_simple",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class ControlNetModelField(BaseModel):
|
class ControlNetModelField(BaseModel):
|
||||||
@@ -111,17 +98,21 @@ class ControlNetModelField(BaseModel):
|
|||||||
|
|
||||||
class ControlField(BaseModel):
|
class ControlField(BaseModel):
|
||||||
image: ImageField = Field(default=None, description="The control image")
|
image: ImageField = Field(default=None, description="The control image")
|
||||||
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
|
control_model: Optional[ControlNetModelField] = Field(
|
||||||
|
default=None, description="The ControlNet model to use")
|
||||||
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
|
||||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
control_weight: Union[float, List[float]] = Field(
|
||||||
|
default=1, description="The weight given to the ControlNet")
|
||||||
begin_step_percent: float = Field(
|
begin_step_percent: float = Field(
|
||||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
default=0, ge=0, le=1,
|
||||||
)
|
description="When the ControlNet is first applied (% of total steps)")
|
||||||
end_step_percent: float = Field(
|
end_step_percent: float = Field(
|
||||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
default=1, ge=0, le=1,
|
||||||
)
|
description="When the ControlNet is last applied (% of total steps)")
|
||||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
control_mode: CONTROLNET_MODE_VALUES = Field(
|
||||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
default="balanced", description="The control mode to use")
|
||||||
|
resize_mode: CONTROLNET_RESIZE_VALUES = Field(
|
||||||
|
default="just_resize", description="The resize mode to use")
|
||||||
|
|
||||||
@validator("control_weight")
|
@validator("control_weight")
|
||||||
def validate_control_weight(cls, v):
|
def validate_control_weight(cls, v):
|
||||||
@@ -129,10 +120,11 @@ class ControlField(BaseModel):
|
|||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
for i in v:
|
for i in v:
|
||||||
if i < -1 or i > 2:
|
if i < -1 or i > 2:
|
||||||
raise ValueError("Control weights must be within -1 to 2 range")
|
raise ValueError(
|
||||||
|
'Control weights must be within -1 to 2 range')
|
||||||
else:
|
else:
|
||||||
if v < -1 or v > 2:
|
if v < -1 or v > 2:
|
||||||
raise ValueError("Control weights must be within -1 to 2 range")
|
raise ValueError('Control weights must be within -1 to 2 range')
|
||||||
return v
|
return v
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@@ -144,13 +136,12 @@ class ControlField(BaseModel):
|
|||||||
"control_model": "controlnet_model",
|
"control_model": "controlnet_model",
|
||||||
# "control_weight": "number",
|
# "control_weight": "number",
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ControlOutput(BaseInvocationOutput):
|
class ControlOutput(BaseInvocationOutput):
|
||||||
"""node output for ControlNet info"""
|
"""node output for ControlNet info"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["control_output"] = "control_output"
|
type: Literal["control_output"] = "control_output"
|
||||||
control: ControlField = Field(default=None, description="The control info")
|
control: ControlField = Field(default=None, description="The control info")
|
||||||
@@ -159,7 +150,6 @@ class ControlOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
class ControlNetInvocation(BaseInvocation):
|
class ControlNetInvocation(BaseInvocation):
|
||||||
"""Collects ControlNet info to pass to other nodes"""
|
"""Collects ControlNet info to pass to other nodes"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["controlnet"] = "controlnet"
|
type: Literal["controlnet"] = "controlnet"
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -186,7 +176,7 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
# "cfg_scale": "float",
|
# "cfg_scale": "float",
|
||||||
"cfg_scale": "number",
|
"cfg_scale": "number",
|
||||||
"control_weight": "float",
|
"control_weight": "float",
|
||||||
},
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -215,7 +205,10 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Image Processor", "tags": ["image", "processor"]},
|
"ui": {
|
||||||
|
"title": "Image Processor",
|
||||||
|
"tags": ["image", "processor"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
@@ -240,7 +233,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
image_category=ImageCategory.CONTROL,
|
image_category=ImageCategory.CONTROL,
|
||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate
|
||||||
)
|
)
|
||||||
|
|
||||||
"""Builds an ImageOutput and its ImageField"""
|
"""Builds an ImageOutput and its ImageField"""
|
||||||
@@ -255,9 +248,9 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class CannyImageProcessorInvocation(
|
||||||
|
ImageProcessorInvocation, PILInvocationConfig):
|
||||||
"""Canny edge detection for ControlNet"""
|
"""Canny edge detection for ControlNet"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["canny_image_processor"] = "canny_image_processor"
|
type: Literal["canny_image_processor"] = "canny_image_processor"
|
||||||
# Input
|
# Input
|
||||||
@@ -267,18 +260,22 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]},
|
"ui": {
|
||||||
|
"title": "Canny Processor",
|
||||||
|
"tags": ["controlnet", "canny", "image", "processor"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
canny_processor = CannyDetector()
|
canny_processor = CannyDetector()
|
||||||
processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
|
processed_image = canny_processor(
|
||||||
|
image, self.low_threshold, self.high_threshold)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class HedImageProcessorInvocation(
|
||||||
|
ImageProcessorInvocation, PILInvocationConfig):
|
||||||
"""Applies HED edge detection to image"""
|
"""Applies HED edge detection to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["hed_image_processor"] = "hed_image_processor"
|
type: Literal["hed_image_processor"] = "hed_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -291,25 +288,27 @@ class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig)
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]},
|
"ui": {
|
||||||
|
"title": "Softedge(HED) Processor",
|
||||||
|
"tags": ["controlnet", "softedge", "hed", "image", "processor"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
|
||||||
processed_image = hed_processor(
|
processed_image = hed_processor(image,
|
||||||
image,
|
detect_resolution=self.detect_resolution,
|
||||||
detect_resolution=self.detect_resolution,
|
image_resolution=self.image_resolution,
|
||||||
image_resolution=self.image_resolution,
|
# safe not supported in controlnet_aux v0.0.3
|
||||||
# safe not supported in controlnet_aux v0.0.3
|
# safe=self.safe,
|
||||||
# safe=self.safe,
|
scribble=self.scribble,
|
||||||
scribble=self.scribble,
|
)
|
||||||
)
|
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class LineartImageProcessorInvocation(
|
||||||
|
ImageProcessorInvocation, PILInvocationConfig):
|
||||||
"""Applies line art processing to image"""
|
"""Applies line art processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
type: Literal["lineart_image_processor"] = "lineart_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -320,20 +319,24 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]},
|
"ui": {
|
||||||
|
"title": "Lineart Processor",
|
||||||
|
"tags": ["controlnet", "lineart", "image", "processor"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
|
lineart_processor = LineartDetector.from_pretrained(
|
||||||
|
"lllyasviel/Annotators")
|
||||||
processed_image = lineart_processor(
|
processed_image = lineart_processor(
|
||||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
|
image, detect_resolution=self.detect_resolution,
|
||||||
)
|
image_resolution=self.image_resolution, coarse=self.coarse)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class LineartAnimeImageProcessorInvocation(
|
||||||
|
ImageProcessorInvocation, PILInvocationConfig):
|
||||||
"""Applies line art anime processing to image"""
|
"""Applies line art anime processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -345,23 +348,23 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati
|
|||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {
|
||||||
"title": "Lineart Anime Processor",
|
"title": "Lineart Anime Processor",
|
||||||
"tags": ["controlnet", "lineart", "anime", "image", "processor"],
|
"tags": ["controlnet", "lineart", "anime", "image", "processor"]
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
|
processor = LineartAnimeDetector.from_pretrained(
|
||||||
processed_image = processor(
|
"lllyasviel/Annotators")
|
||||||
image,
|
processed_image = processor(image,
|
||||||
detect_resolution=self.detect_resolution,
|
detect_resolution=self.detect_resolution,
|
||||||
image_resolution=self.image_resolution,
|
image_resolution=self.image_resolution,
|
||||||
)
|
)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class OpenposeImageProcessorInvocation(
|
||||||
|
ImageProcessorInvocation, PILInvocationConfig):
|
||||||
"""Applies Openpose processing to image"""
|
"""Applies Openpose processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
type: Literal["openpose_image_processor"] = "openpose_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -372,23 +375,25 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]},
|
"ui": {
|
||||||
|
"title": "Openpose Processor",
|
||||||
|
"tags": ["controlnet", "openpose", "image", "processor"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
openpose_processor = OpenposeDetector.from_pretrained(
|
||||||
|
"lllyasviel/Annotators")
|
||||||
processed_image = openpose_processor(
|
processed_image = openpose_processor(
|
||||||
image,
|
image, detect_resolution=self.detect_resolution,
|
||||||
detect_resolution=self.detect_resolution,
|
|
||||||
image_resolution=self.image_resolution,
|
image_resolution=self.image_resolution,
|
||||||
hand_and_face=self.hand_and_face,
|
hand_and_face=self.hand_and_face,)
|
||||||
)
|
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class MidasDepthImageProcessorInvocation(
|
||||||
|
ImageProcessorInvocation, PILInvocationConfig):
|
||||||
"""Applies Midas depth processing to image"""
|
"""Applies Midas depth processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -400,24 +405,26 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]},
|
"ui": {
|
||||||
|
"title": "Midas (Depth) Processor",
|
||||||
|
"tags": ["controlnet", "midas", "depth", "image", "processor"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
processed_image = midas_processor(
|
processed_image = midas_processor(image,
|
||||||
image,
|
a=np.pi * self.a_mult,
|
||||||
a=np.pi * self.a_mult,
|
bg_th=self.bg_th,
|
||||||
bg_th=self.bg_th,
|
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||||
# dept_and_normal not supported in controlnet_aux v0.0.3
|
# depth_and_normal=self.depth_and_normal,
|
||||||
# depth_and_normal=self.depth_and_normal,
|
)
|
||||||
)
|
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class NormalbaeImageProcessorInvocation(
|
||||||
|
ImageProcessorInvocation, PILInvocationConfig):
|
||||||
"""Applies NormalBae processing to image"""
|
"""Applies NormalBae processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -427,20 +434,24 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]},
|
"ui": {
|
||||||
|
"title": "Normal BAE Processor",
|
||||||
|
"tags": ["controlnet", "normal", "bae", "image", "processor"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
normalbae_processor = NormalBaeDetector.from_pretrained(
|
||||||
|
"lllyasviel/Annotators")
|
||||||
processed_image = normalbae_processor(
|
processed_image = normalbae_processor(
|
||||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
|
image, detect_resolution=self.detect_resolution,
|
||||||
)
|
image_resolution=self.image_resolution)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class MlsdImageProcessorInvocation(
|
||||||
|
ImageProcessorInvocation, PILInvocationConfig):
|
||||||
"""Applies MLSD processing to image"""
|
"""Applies MLSD processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -452,24 +463,24 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]},
|
"ui": {
|
||||||
|
"title": "MLSD Processor",
|
||||||
|
"tags": ["controlnet", "mlsd", "image", "processor"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||||
processed_image = mlsd_processor(
|
processed_image = mlsd_processor(
|
||||||
image,
|
image, detect_resolution=self.detect_resolution,
|
||||||
detect_resolution=self.detect_resolution,
|
image_resolution=self.image_resolution, thr_v=self.thr_v,
|
||||||
image_resolution=self.image_resolution,
|
thr_d=self.thr_d)
|
||||||
thr_v=self.thr_v,
|
|
||||||
thr_d=self.thr_d,
|
|
||||||
)
|
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class PidiImageProcessorInvocation(
|
||||||
|
ImageProcessorInvocation, PILInvocationConfig):
|
||||||
"""Applies PIDI processing to image"""
|
"""Applies PIDI processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
type: Literal["pidi_image_processor"] = "pidi_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -481,24 +492,25 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]},
|
"ui": {
|
||||||
|
"title": "PIDI Processor",
|
||||||
|
"tags": ["controlnet", "pidi", "image", "processor"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
pidi_processor = PidiNetDetector.from_pretrained(
|
||||||
|
"lllyasviel/Annotators")
|
||||||
processed_image = pidi_processor(
|
processed_image = pidi_processor(
|
||||||
image,
|
image, detect_resolution=self.detect_resolution,
|
||||||
detect_resolution=self.detect_resolution,
|
image_resolution=self.image_resolution, safe=self.safe,
|
||||||
image_resolution=self.image_resolution,
|
scribble=self.scribble)
|
||||||
safe=self.safe,
|
|
||||||
scribble=self.scribble,
|
|
||||||
)
|
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class ContentShuffleImageProcessorInvocation(
|
||||||
|
ImageProcessorInvocation, PILInvocationConfig):
|
||||||
"""Applies content shuffle processing to image"""
|
"""Applies content shuffle processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -513,45 +525,48 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
|
|||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {
|
||||||
"title": "Content Shuffle Processor",
|
"title": "Content Shuffle Processor",
|
||||||
"tags": ["controlnet", "contentshuffle", "image", "processor"],
|
"tags": ["controlnet", "contentshuffle", "image", "processor"]
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
content_shuffle_processor = ContentShuffleDetector()
|
content_shuffle_processor = ContentShuffleDetector()
|
||||||
processed_image = content_shuffle_processor(
|
processed_image = content_shuffle_processor(image,
|
||||||
image,
|
detect_resolution=self.detect_resolution,
|
||||||
detect_resolution=self.detect_resolution,
|
image_resolution=self.image_resolution,
|
||||||
image_resolution=self.image_resolution,
|
h=self.h,
|
||||||
h=self.h,
|
w=self.w,
|
||||||
w=self.w,
|
f=self.f
|
||||||
f=self.f,
|
)
|
||||||
)
|
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
|
||||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class ZoeDepthImageProcessorInvocation(
|
||||||
|
ImageProcessorInvocation, PILInvocationConfig):
|
||||||
"""Applies Zoe depth processing to image"""
|
"""Applies Zoe depth processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
|
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]},
|
"ui": {
|
||||||
|
"title": "Zoe (Depth) Processor",
|
||||||
|
"tags": ["controlnet", "zoe", "depth", "image", "processor"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
zoe_depth_processor = ZoeDetector.from_pretrained(
|
||||||
|
"lllyasviel/Annotators")
|
||||||
processed_image = zoe_depth_processor(image)
|
processed_image = zoe_depth_processor(image)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class MediapipeFaceProcessorInvocation(
|
||||||
|
ImageProcessorInvocation, PILInvocationConfig):
|
||||||
"""Applies mediapipe face processing to image"""
|
"""Applies mediapipe face processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -561,22 +576,26 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]},
|
"ui": {
|
||||||
|
"title": "Mediapipe Processor",
|
||||||
|
"tags": ["controlnet", "mediapipe", "image", "processor"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
# MediaPipeFaceDetector throws an error if image has alpha channel
|
# MediaPipeFaceDetector throws an error if image has alpha channel
|
||||||
# so convert to RGB if needed
|
# so convert to RGB if needed
|
||||||
if image.mode == "RGBA":
|
if image.mode == 'RGBA':
|
||||||
image = image.convert("RGB")
|
image = image.convert('RGB')
|
||||||
mediapipe_face_processor = MediapipeFaceDetector()
|
mediapipe_face_processor = MediapipeFaceDetector()
|
||||||
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
|
processed_image = mediapipe_face_processor(
|
||||||
|
image, max_faces=self.max_faces, min_confidence=self.min_confidence)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class LeresImageProcessorInvocation(
|
||||||
|
ImageProcessorInvocation, PILInvocationConfig):
|
||||||
"""Applies leres processing to image"""
|
"""Applies leres processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["leres_image_processor"] = "leres_image_processor"
|
type: Literal["leres_image_processor"] = "leres_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -589,23 +608,24 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]},
|
"ui": {
|
||||||
|
"title": "Leres (Depth) Processor",
|
||||||
|
"tags": ["controlnet", "leres", "depth", "image", "processor"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
processed_image = leres_processor(
|
processed_image = leres_processor(
|
||||||
image,
|
image, thr_a=self.thr_a, thr_b=self.thr_b, boost=self.boost,
|
||||||
thr_a=self.thr_a,
|
|
||||||
thr_b=self.thr_b,
|
|
||||||
boost=self.boost,
|
|
||||||
detect_resolution=self.detect_resolution,
|
detect_resolution=self.detect_resolution,
|
||||||
image_resolution=self.image_resolution,
|
image_resolution=self.image_resolution)
|
||||||
)
|
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class TileResamplerProcessorInvocation(
|
||||||
|
ImageProcessorInvocation, PILInvocationConfig):
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["tile_image_processor"] = "tile_image_processor"
|
type: Literal["tile_image_processor"] = "tile_image_processor"
|
||||||
# Inputs
|
# Inputs
|
||||||
@@ -617,17 +637,16 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
|||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {
|
||||||
"title": "Tile Resample Processor",
|
"title": "Tile Resample Processor",
|
||||||
"tags": ["controlnet", "tile", "resample", "image", "processor"],
|
"tags": ["controlnet", "tile", "resample", "image", "processor"]
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
|
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
|
||||||
def tile_resample(
|
def tile_resample(self,
|
||||||
self,
|
np_img: np.ndarray,
|
||||||
np_img: np.ndarray,
|
res=512, # never used?
|
||||||
res=512, # never used?
|
down_sampling_rate=1.0,
|
||||||
down_sampling_rate=1.0,
|
):
|
||||||
):
|
|
||||||
np_img = HWC3(np_img)
|
np_img = HWC3(np_img)
|
||||||
if down_sampling_rate < 1.1:
|
if down_sampling_rate < 1.1:
|
||||||
return np_img
|
return np_img
|
||||||
@@ -639,41 +658,36 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
|
|||||||
|
|
||||||
def run_processor(self, img):
|
def run_processor(self, img):
|
||||||
np_img = np.array(img, dtype=np.uint8)
|
np_img = np.array(img, dtype=np.uint8)
|
||||||
processed_np_image = self.tile_resample(
|
processed_np_image = self.tile_resample(np_img,
|
||||||
np_img,
|
# res=self.tile_size,
|
||||||
# res=self.tile_size,
|
down_sampling_rate=self.down_sampling_rate
|
||||||
down_sampling_rate=self.down_sampling_rate,
|
)
|
||||||
)
|
|
||||||
processed_image = Image.fromarray(processed_np_image)
|
processed_image = Image.fromarray(processed_np_image)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
|
class SegmentAnythingProcessorInvocation(
|
||||||
|
ImageProcessorInvocation, PILInvocationConfig):
|
||||||
"""Applies segment anything processing to image"""
|
"""Applies segment anything processing to image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["segment_anything_processor"] = "segment_anything_processor"
|
type: Literal["segment_anything_processor"] = "segment_anything_processor"
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {"ui": {"title": "Segment Anything Processor", "tags": [
|
||||||
"ui": {
|
"controlnet", "segment", "anything", "sam", "image", "processor"]}, }
|
||||||
"title": "Segment Anything Processor",
|
|
||||||
"tags": ["controlnet", "segment", "anything", "sam", "image", "processor"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def run_processor(self, image):
|
def run_processor(self, image):
|
||||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||||
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
||||||
"ybelkada/segment-anything", subfolder="checkpoints"
|
"ybelkada/segment-anything", subfolder="checkpoints")
|
||||||
)
|
|
||||||
np_img = np.array(image, dtype=np.uint8)
|
np_img = np.array(image, dtype=np.uint8)
|
||||||
processed_image = segment_anything_processor(np_img)
|
processed_image = segment_anything_processor(np_img)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
class SamDetectorReproducibleColors(SamDetector):
|
class SamDetectorReproducibleColors(SamDetector):
|
||||||
|
|
||||||
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
|
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
|
||||||
# base class show_anns() method randomizes colors,
|
# base class show_anns() method randomizes colors,
|
||||||
# which seems to also lead to non-reproducible image generation
|
# which seems to also lead to non-reproducible image generation
|
||||||
@@ -681,15 +695,19 @@ class SamDetectorReproducibleColors(SamDetector):
|
|||||||
def show_anns(self, anns: List[Dict]):
|
def show_anns(self, anns: List[Dict]):
|
||||||
if len(anns) == 0:
|
if len(anns) == 0:
|
||||||
return
|
return
|
||||||
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
|
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
||||||
h, w = anns[0]["segmentation"].shape
|
h, w = anns[0]['segmentation'].shape
|
||||||
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
|
final_img = Image.fromarray(
|
||||||
|
np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
|
||||||
palette = ade_palette()
|
palette = ade_palette()
|
||||||
for i, ann in enumerate(sorted_anns):
|
for i, ann in enumerate(sorted_anns):
|
||||||
m = ann["segmentation"]
|
m = ann['segmentation']
|
||||||
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
|
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
|
||||||
# doing modulo just in case number of annotated regions exceeds number of colors in palette
|
# doing modulo just in case number of annotated regions exceeds number of colors in palette
|
||||||
ann_color = palette[i % len(palette)]
|
ann_color = palette[i % len(palette)]
|
||||||
img[:, :] = ann_color
|
img[:, :] = ann_color
|
||||||
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
|
final_img.paste(
|
||||||
|
Image.fromarray(img, mode="RGB"),
|
||||||
|
(0, 0),
|
||||||
|
Image.fromarray(np.uint8(m * 255)))
|
||||||
return np.array(final_img, dtype=np.uint8)
|
return np.array(final_img, dtype=np.uint8)
|
||||||
|
|||||||
@@ -37,7 +37,10 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]},
|
"ui": {
|
||||||
|
"title": "OpenCV Inpaint",
|
||||||
|
"tags": ["opencv", "inpaint"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ from typing import Literal, Optional, get_args
|
|||||||
import torch
|
import torch
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
|
from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
|
||||||
|
ResourceOrigin)
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.backend.generator.inpaint import infill_methods
|
from invokeai.backend.generator.inpaint import infill_methods
|
||||||
|
|
||||||
@@ -24,12 +25,13 @@ from contextlib import contextmanager, ExitStack, ContextDecorator
|
|||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||||
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
DEFAULT_INFILL_METHOD = (
|
||||||
|
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
from .latent import get_scheduler
|
from .latent import get_scheduler
|
||||||
|
|
||||||
|
|
||||||
class OldModelContext(ContextDecorator):
|
class OldModelContext(ContextDecorator):
|
||||||
model: StableDiffusionGeneratorPipeline
|
model: StableDiffusionGeneratorPipeline
|
||||||
|
|
||||||
@@ -42,7 +44,6 @@ class OldModelContext(ContextDecorator):
|
|||||||
def __exit__(self, *exc):
|
def __exit__(self, *exc):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class OldModelInfo:
|
class OldModelInfo:
|
||||||
name: str
|
name: str
|
||||||
hash: str
|
hash: str
|
||||||
@@ -63,34 +64,20 @@ class InpaintInvocation(BaseInvocation):
|
|||||||
|
|
||||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||||
seed: int = Field(
|
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
|
||||||
ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed
|
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
||||||
)
|
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
|
||||||
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
|
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
|
||||||
width: int = Field(
|
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
default=512,
|
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||||
multiple_of=8,
|
|
||||||
gt=0,
|
|
||||||
description="The width of the resulting image",
|
|
||||||
)
|
|
||||||
height: int = Field(
|
|
||||||
default=512,
|
|
||||||
multiple_of=8,
|
|
||||||
gt=0,
|
|
||||||
description="The height of the resulting image",
|
|
||||||
)
|
|
||||||
cfg_scale: float = Field(
|
|
||||||
default=7.5,
|
|
||||||
ge=1,
|
|
||||||
description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",
|
|
||||||
)
|
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
|
|
||||||
unet: UNetField = Field(default=None, description="UNet model")
|
unet: UNetField = Field(default=None, description="UNet model")
|
||||||
vae: VaeField = Field(default=None, description="Vae model")
|
vae: VaeField = Field(default=None, description="Vae model")
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(description="The input image")
|
image: Optional[ImageField] = Field(description="The input image")
|
||||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the original image")
|
strength: float = Field(
|
||||||
|
default=0.75, gt=0, le=1, description="The strength of the original image"
|
||||||
|
)
|
||||||
fit: bool = Field(
|
fit: bool = Field(
|
||||||
default=True,
|
default=True,
|
||||||
description="Whether or not the result should be fit to the aspect ratio of the input image",
|
description="Whether or not the result should be fit to the aspect ratio of the input image",
|
||||||
@@ -99,10 +86,18 @@ class InpaintInvocation(BaseInvocation):
|
|||||||
# Inputs
|
# Inputs
|
||||||
mask: Optional[ImageField] = Field(description="The mask")
|
mask: Optional[ImageField] = Field(description="The mask")
|
||||||
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||||
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
|
seam_blur: int = Field(
|
||||||
seam_strength: float = Field(default=0.75, gt=0, le=1, description="The seam inpaint strength")
|
default=16, ge=0, description="The seam inpaint blur radius (px)"
|
||||||
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
|
)
|
||||||
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
|
seam_strength: float = Field(
|
||||||
|
default=0.75, gt=0, le=1, description="The seam inpaint strength"
|
||||||
|
)
|
||||||
|
seam_steps: int = Field(
|
||||||
|
default=30, ge=1, description="The number of steps to use for seam inpaint"
|
||||||
|
)
|
||||||
|
tile_size: int = Field(
|
||||||
|
default=32, ge=1, description="The tile infill method size (px)"
|
||||||
|
)
|
||||||
infill_method: INFILL_METHODS = Field(
|
infill_method: INFILL_METHODS = Field(
|
||||||
default=DEFAULT_INFILL_METHOD,
|
default=DEFAULT_INFILL_METHOD,
|
||||||
description="The method used to infill empty regions (px)",
|
description="The method used to infill empty regions (px)",
|
||||||
@@ -133,7 +128,10 @@ class InpaintInvocation(BaseInvocation):
|
|||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"tags": ["stable-diffusion", "image"], "title": "Inpaint"},
|
"ui": {
|
||||||
|
"tags": ["stable-diffusion", "image"],
|
||||||
|
"title": "Inpaint"
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def dispatch_progress(
|
def dispatch_progress(
|
||||||
@@ -164,23 +162,18 @@ class InpaintInvocation(BaseInvocation):
|
|||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}),
|
**lora.dict(exclude={"weight"}), context=context,)
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,)
|
||||||
**self.unet.unet.dict(),
|
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,)
|
||||||
context=context,
|
|
||||||
)
|
with vae_info as vae,\
|
||||||
vae_info = context.services.model_manager.get_model(
|
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||||
**self.vae.vae.dict(),
|
unet_info as unet:
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
with vae_info as vae, ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
|
|
||||||
device = context.services.model_manager.mgr.cache.execution_device
|
device = context.services.model_manager.mgr.cache.execution_device
|
||||||
dtype = context.services.model_manager.mgr.cache.precision
|
dtype = context.services.model_manager.mgr.cache.precision
|
||||||
|
|
||||||
@@ -204,11 +197,21 @@ class InpaintInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = None if self.image is None else context.services.images.get_pil_image(self.image.image_name)
|
image = (
|
||||||
mask = None if self.mask is None else context.services.images.get_pil_image(self.mask.image_name)
|
None
|
||||||
|
if self.image is None
|
||||||
|
else context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
)
|
||||||
|
mask = (
|
||||||
|
None
|
||||||
|
if self.mask is None
|
||||||
|
else context.services.images.get_pil_image(self.mask.image_name)
|
||||||
|
)
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# Get the source node id (we are invoking the prepared node)
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
graph_execution_state = context.services.graph_execution_manager.get(
|
||||||
|
context.graph_execution_state_id
|
||||||
|
)
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
scheduler = get_scheduler(
|
scheduler = get_scheduler(
|
||||||
|
|||||||
@@ -9,12 +9,8 @@ from pathlib import Path
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
from ..models.image import (
|
from ..models.image import (
|
||||||
ImageCategory,
|
ImageCategory, ImageField, ResourceOrigin,
|
||||||
ImageField,
|
PILInvocationConfig, ImageOutput, MaskOutput,
|
||||||
ResourceOrigin,
|
|
||||||
PILInvocationConfig,
|
|
||||||
ImageOutput,
|
|
||||||
MaskOutput,
|
|
||||||
)
|
)
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@@ -24,7 +20,6 @@ from .baseinvocation import (
|
|||||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||||
|
|
||||||
|
|
||||||
class LoadImageInvocation(BaseInvocation):
|
class LoadImageInvocation(BaseInvocation):
|
||||||
"""Load an image and provide it as output."""
|
"""Load an image and provide it as output."""
|
||||||
|
|
||||||
@@ -39,7 +34,10 @@ class LoadImageInvocation(BaseInvocation):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Load Image", "tags": ["image", "load"]},
|
"ui": {
|
||||||
|
"title": "Load Image",
|
||||||
|
"tags": ["image", "load"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@@ -58,11 +56,16 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
type: Literal["show_image"] = "show_image"
|
type: Literal["show_image"] = "show_image"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to show")
|
image: Optional[ImageField] = Field(
|
||||||
|
default=None, description="The image to show"
|
||||||
|
)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Show Image", "tags": ["image", "show"]},
|
"ui": {
|
||||||
|
"title": "Show Image",
|
||||||
|
"tags": ["image", "show"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@@ -95,13 +98,18 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Crop Image", "tags": ["image", "crop"]},
|
"ui": {
|
||||||
|
"title": "Crop Image",
|
||||||
|
"tags": ["image", "crop"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0))
|
image_crop = Image.new(
|
||||||
|
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
|
||||||
|
)
|
||||||
image_crop.paste(image, (-self.x, -self.y))
|
image_crop.paste(image, (-self.x, -self.y))
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
@@ -136,14 +144,21 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Paste Image", "tags": ["image", "paste"]},
|
"ui": {
|
||||||
|
"title": "Paste Image",
|
||||||
|
"tags": ["image", "paste"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
base_image = context.services.images.get_pil_image(self.base_image.image_name)
|
base_image = context.services.images.get_pil_image(self.base_image.image_name)
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
mask = (
|
mask = (
|
||||||
None if self.mask is None else ImageOps.invert(context.services.images.get_pil_image(self.mask.image_name))
|
None
|
||||||
|
if self.mask is None
|
||||||
|
else ImageOps.invert(
|
||||||
|
context.services.images.get_pil_image(self.mask.image_name)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||||
|
|
||||||
@@ -152,7 +167,9 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
max_x = max(base_image.width, image.width + self.x)
|
max_x = max(base_image.width, image.width + self.x)
|
||||||
max_y = max(base_image.height, image.height + self.y)
|
max_y = max(base_image.height, image.height + self.y)
|
||||||
|
|
||||||
new_image = Image.new(mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0))
|
new_image = Image.new(
|
||||||
|
mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0)
|
||||||
|
)
|
||||||
new_image.paste(base_image, (abs(min_x), abs(min_y)))
|
new_image.paste(base_image, (abs(min_x), abs(min_y)))
|
||||||
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
||||||
|
|
||||||
@@ -185,7 +202,10 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]},
|
"ui": {
|
||||||
|
"title": "Mask From Alpha",
|
||||||
|
"tags": ["image", "mask", "alpha"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||||
@@ -224,7 +244,10 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Multiply Images", "tags": ["image", "multiply"]},
|
"ui": {
|
||||||
|
"title": "Multiply Images",
|
||||||
|
"tags": ["image", "multiply"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@@ -265,7 +288,10 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Image Channel", "tags": ["image", "channel"]},
|
"ui": {
|
||||||
|
"title": "Image Channel",
|
||||||
|
"tags": ["image", "channel"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@@ -305,7 +331,10 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Convert Image", "tags": ["image", "convert"]},
|
"ui": {
|
||||||
|
"title": "Convert Image",
|
||||||
|
"tags": ["image", "convert"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@@ -328,7 +357,6 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Blurs an image"""
|
"""Blurs an image"""
|
||||||
|
|
||||||
@@ -343,14 +371,19 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Blur Image", "tags": ["image", "blur"]},
|
"ui": {
|
||||||
|
"title": "Blur Image",
|
||||||
|
"tags": ["image", "blur"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
blur = (
|
blur = (
|
||||||
ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
|
ImageFilter.GaussianBlur(self.radius)
|
||||||
|
if self.blur_type == "gaussian"
|
||||||
|
else ImageFilter.BoxBlur(self.radius)
|
||||||
)
|
)
|
||||||
blur_image = image.filter(blur)
|
blur_image = image.filter(blur)
|
||||||
|
|
||||||
@@ -405,7 +438,10 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Resize Image", "tags": ["image", "resize"]},
|
"ui": {
|
||||||
|
"title": "Resize Image",
|
||||||
|
"tags": ["image", "resize"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@@ -448,7 +484,10 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Scale Image", "tags": ["image", "scale"]},
|
"ui": {
|
||||||
|
"title": "Scale Image",
|
||||||
|
"tags": ["image", "scale"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@@ -493,7 +532,10 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]},
|
"ui": {
|
||||||
|
"title": "Image Linear Interpolation",
|
||||||
|
"tags": ["image", "linear", "interpolation", "lerp"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@@ -519,7 +561,6 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Inverse linear interpolation of all pixels of an image"""
|
"""Inverse linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
@@ -536,7 +577,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {
|
||||||
"title": "Image Inverse Linear Interpolation",
|
"title": "Image Inverse Linear Interpolation",
|
||||||
"tags": ["image", "linear", "interpolation", "inverse"],
|
"tags": ["image", "linear", "interpolation", "inverse"]
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -544,7 +585,12 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
||||||
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255
|
image_arr = (
|
||||||
|
numpy.minimum(
|
||||||
|
numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1
|
||||||
|
)
|
||||||
|
* 255
|
||||||
|
)
|
||||||
|
|
||||||
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||||
|
|
||||||
@@ -563,7 +609,6 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Add blur to NSFW-flagged images"""
|
"""Add blur to NSFW-flagged images"""
|
||||||
|
|
||||||
@@ -577,7 +622,10 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]},
|
"ui": {
|
||||||
|
"title": "Blur NSFW Images",
|
||||||
|
"tags": ["image", "nsfw", "checker"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@@ -589,7 +637,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
|
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
|
||||||
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||||
caution = self._get_caution_img()
|
caution = self._get_caution_img()
|
||||||
blurry_image.paste(caution, (0, 0), caution)
|
blurry_image.paste(caution,(0,0),caution)
|
||||||
image = blurry_image
|
image = blurry_image
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
@@ -608,15 +656,13 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_caution_img(self) -> Image:
|
def _get_caution_img(self)->Image:
|
||||||
import invokeai.app.assets.images as image_assets
|
import invokeai.app.assets.images as image_assets
|
||||||
|
caution = Image.open(Path(image_assets.__path__[0]) / 'caution.png')
|
||||||
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
|
return caution.resize((caution.width // 2, caution.height //2))
|
||||||
return caution.resize((caution.width // 2, caution.height // 2))
|
|
||||||
|
|
||||||
|
|
||||||
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Add an invisible watermark to an image"""
|
""" Add an invisible watermark to an image """
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["img_watermark"] = "img_watermark"
|
type: Literal["img_watermark"] = "img_watermark"
|
||||||
@@ -629,7 +675,10 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]},
|
"ui": {
|
||||||
|
"title": "Add Invisible Watermark",
|
||||||
|
"tags": ["image", "watermark", "invisible"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@@ -650,3 +699,6 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
width=image_dto.width,
|
width=image_dto.width,
|
||||||
height=image_dto.height,
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,9 @@ def infill_methods() -> list[str]:
|
|||||||
|
|
||||||
|
|
||||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||||
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
DEFAULT_INFILL_METHOD = (
|
||||||
|
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||||
@@ -42,7 +44,9 @@ def infill_patchmatch(im: Image.Image) -> Image.Image:
|
|||||||
return im
|
return im
|
||||||
|
|
||||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||||
im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
|
im_patched_np = PatchMatch.inpaint(
|
||||||
|
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
|
||||||
|
)
|
||||||
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
||||||
return im_patched
|
return im_patched
|
||||||
|
|
||||||
@@ -64,7 +68,9 @@ def get_tile_images(image: np.ndarray, width=8, height=8):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
|
def tile_fill_missing(
|
||||||
|
im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
|
||||||
|
) -> Image.Image:
|
||||||
# Only fill if there's an alpha layer
|
# Only fill if there's an alpha layer
|
||||||
if im.mode != "RGBA":
|
if im.mode != "RGBA":
|
||||||
return im
|
return im
|
||||||
@@ -97,7 +103,9 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
|||||||
# Find all invalid tiles and replace with a random valid tile
|
# Find all invalid tiles and replace with a random valid tile
|
||||||
replace_count = (tiles_mask == False).sum()
|
replace_count = (tiles_mask == False).sum()
|
||||||
rng = np.random.default_rng(seed=seed)
|
rng = np.random.default_rng(seed=seed)
|
||||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
|
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
|
||||||
|
rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
|
||||||
|
]
|
||||||
|
|
||||||
# Convert back to an image
|
# Convert back to an image
|
||||||
tiles_all = tiles_all.reshape(tshape)
|
tiles_all = tiles_all.reshape(tshape)
|
||||||
@@ -118,7 +126,9 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
"""Infills transparent areas of an image with a solid color"""
|
"""Infills transparent areas of an image with a solid color"""
|
||||||
|
|
||||||
type: Literal["infill_rgba"] = "infill_rgba"
|
type: Literal["infill_rgba"] = "infill_rgba"
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
image: Optional[ImageField] = Field(
|
||||||
|
default=None, description="The image to infill"
|
||||||
|
)
|
||||||
color: ColorField = Field(
|
color: ColorField = Field(
|
||||||
default=ColorField(r=127, g=127, b=127, a=255),
|
default=ColorField(r=127, g=127, b=127, a=255),
|
||||||
description="The color to use to infill",
|
description="The color to use to infill",
|
||||||
@@ -126,7 +136,10 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]},
|
"ui": {
|
||||||
|
"title": "Color Infill",
|
||||||
|
"tags": ["image", "inpaint", "color", "infill"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@@ -158,7 +171,9 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
|
|
||||||
type: Literal["infill_tile"] = "infill_tile"
|
type: Literal["infill_tile"] = "infill_tile"
|
||||||
|
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
image: Optional[ImageField] = Field(
|
||||||
|
default=None, description="The image to infill"
|
||||||
|
)
|
||||||
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
||||||
seed: int = Field(
|
seed: int = Field(
|
||||||
ge=0,
|
ge=0,
|
||||||
@@ -169,13 +184,18 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]},
|
"ui": {
|
||||||
|
"title": "Tile Infill",
|
||||||
|
"tags": ["image", "inpaint", "tile", "infill"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
|
infilled = tile_fill_missing(
|
||||||
|
image.copy(), seed=self.seed, tile_size=self.tile_size
|
||||||
|
)
|
||||||
infilled.paste(image, (0, 0), image.split()[-1])
|
infilled.paste(image, (0, 0), image.split()[-1])
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
@@ -199,11 +219,16 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
|
|
||||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||||
|
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
image: Optional[ImageField] = Field(
|
||||||
|
default=None, description="The image to infill"
|
||||||
|
)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]},
|
"ui": {
|
||||||
|
"title": "Patch Match Infill",
|
||||||
|
"tags": ["image", "inpaint", "patchmatch", "infill"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
|||||||
@@ -12,22 +12,20 @@ from pydantic import BaseModel, Field, validator
|
|||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
from invokeai.backend.model_management.models.base import ModelType
|
||||||
|
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||||
ConditioningData,
|
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
|
||||||
ControlNetData,
|
image_resized_to_grid_as_tensor)
|
||||||
StableDiffusionGeneratorPipeline,
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
|
||||||
image_resized_to_grid_as_tensor,
|
PostprocessingSettings
|
||||||
)
|
|
||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
|
||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from ...backend.model_management import ModelPatcher
|
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
|
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
|
||||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||||
|
InvocationConfig, InvocationContext)
|
||||||
from .compel import ConditioningField
|
from .compel import ConditioningField
|
||||||
from .controlnet_image_processors import ControlField
|
from .controlnet_image_processors import ControlField
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
@@ -48,7 +46,8 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
|||||||
class LatentsField(BaseModel):
|
class LatentsField(BaseModel):
|
||||||
"""A latents field used for passing latents between invocations"""
|
"""A latents field used for passing latents between invocations"""
|
||||||
|
|
||||||
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
latents_name: Optional[str] = Field(
|
||||||
|
default=None, description="The name of the latents")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {"required": ["latents_name"]}
|
schema_extra = {"required": ["latents_name"]}
|
||||||
@@ -56,15 +55,14 @@ class LatentsField(BaseModel):
|
|||||||
|
|
||||||
class LatentsOutput(BaseInvocationOutput):
|
class LatentsOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output latents"""
|
"""Base class for invocations that output latents"""
|
||||||
|
#fmt: off
|
||||||
# fmt: off
|
|
||||||
type: Literal["latents_output"] = "latents_output"
|
type: Literal["latents_output"] = "latents_output"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: LatentsField = Field(default=None, description="The output latents")
|
latents: LatentsField = Field(default=None, description="The output latents")
|
||||||
width: int = Field(description="The width of the latents in pixels")
|
width: int = Field(description="The width of the latents in pixels")
|
||||||
height: int = Field(description="The height of the latents in pixels")
|
height: int = Field(description="The height of the latents in pixels")
|
||||||
# fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
||||||
@@ -75,7 +73,9 @@ def build_latents_output(latents_name: str, latents: torch.Tensor):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
SAMPLER_NAME_VALUES = Literal[
|
||||||
|
tuple(list(SCHEDULER_MAP.keys()))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_scheduler(
|
def get_scheduler(
|
||||||
@@ -83,10 +83,11 @@ def get_scheduler(
|
|||||||
scheduler_info: ModelInfo,
|
scheduler_info: ModelInfo,
|
||||||
scheduler_name: str,
|
scheduler_name: str,
|
||||||
) -> Scheduler:
|
) -> Scheduler:
|
||||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
|
||||||
|
scheduler_name, SCHEDULER_MAP['ddim']
|
||||||
|
)
|
||||||
orig_scheduler_info = context.services.model_manager.get_model(
|
orig_scheduler_info = context.services.model_manager.get_model(
|
||||||
**scheduler_info.dict(),
|
**scheduler_info.dict(), context=context,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
with orig_scheduler_info as orig_scheduler:
|
with orig_scheduler_info as orig_scheduler:
|
||||||
scheduler_config = orig_scheduler.config
|
scheduler_config = orig_scheduler.config
|
||||||
@@ -101,7 +102,7 @@ def get_scheduler(
|
|||||||
scheduler = scheduler_class.from_config(scheduler_config)
|
scheduler = scheduler_class.from_config(scheduler_config)
|
||||||
|
|
||||||
# hack copied over from generate.py
|
# hack copied over from generate.py
|
||||||
if not hasattr(scheduler, "uses_inpainting_model"):
|
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||||
scheduler.uses_inpainting_model = lambda: False
|
scheduler.uses_inpainting_model = lambda: False
|
||||||
return scheduler
|
return scheduler
|
||||||
|
|
||||||
@@ -122,8 +123,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
@validator("cfg_scale")
|
@validator("cfg_scale")
|
||||||
@@ -132,10 +133,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
for i in v:
|
for i in v:
|
||||||
if i < 1:
|
if i < 1:
|
||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError('cfg_scale must be greater than 1')
|
||||||
else:
|
else:
|
||||||
if v < 1:
|
if v < 1:
|
||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError('cfg_scale must be greater than 1')
|
||||||
return v
|
return v
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
@@ -148,8 +149,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
"model": "model",
|
"model": "model",
|
||||||
"control": "control",
|
"control": "control",
|
||||||
# "cfg_scale": "float",
|
# "cfg_scale": "float",
|
||||||
"cfg_scale": "number",
|
"cfg_scale": "number"
|
||||||
},
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,14 +190,16 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
threshold=0.0, # threshold,
|
threshold=0.0, # threshold,
|
||||||
warmup=0.2, # warmup,
|
warmup=0.2, # warmup,
|
||||||
h_symmetry_time_pct=None, # h_symmetry_time_pct,
|
h_symmetry_time_pct=None, # h_symmetry_time_pct,
|
||||||
v_symmetry_time_pct=None, # v_symmetry_time_pct,
|
v_symmetry_time_pct=None # v_symmetry_time_pct,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
|
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
|
||||||
scheduler,
|
scheduler,
|
||||||
|
|
||||||
# for ddim scheduler
|
# for ddim scheduler
|
||||||
eta=0.0, # ddim_eta
|
eta=0.0, # ddim_eta
|
||||||
|
|
||||||
# for ancestral and sde schedulers
|
# for ancestral and sde schedulers
|
||||||
generator=torch.Generator(device=unet.device).manual_seed(0),
|
generator=torch.Generator(device=unet.device).manual_seed(0),
|
||||||
)
|
)
|
||||||
@@ -244,6 +247,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
exit_stack: ExitStack,
|
exit_stack: ExitStack,
|
||||||
do_classifier_free_guidance: bool = True,
|
do_classifier_free_guidance: bool = True,
|
||||||
) -> List[ControlNetData]:
|
) -> List[ControlNetData]:
|
||||||
|
|
||||||
# assuming fixed dimensional scaling of 8:1 for image:latents
|
# assuming fixed dimensional scaling of 8:1 for image:latents
|
||||||
control_height_resize = latents_shape[2] * 8
|
control_height_resize = latents_shape[2] * 8
|
||||||
control_width_resize = latents_shape[3] * 8
|
control_width_resize = latents_shape[3] * 8
|
||||||
@@ -257,7 +261,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
control_list = control_input
|
control_list = control_input
|
||||||
else:
|
else:
|
||||||
control_list = None
|
control_list = None
|
||||||
if control_list is None:
|
if (control_list is None):
|
||||||
control_data = None
|
control_data = None
|
||||||
# from above handling, any control that is not None should now be of type list[ControlField]
|
# from above handling, any control that is not None should now be of type list[ControlField]
|
||||||
else:
|
else:
|
||||||
@@ -277,7 +281,9 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
control_models.append(control_model)
|
control_models.append(control_model)
|
||||||
control_image_field = control_info.image
|
control_image_field = control_info.image
|
||||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
input_image = context.services.images.get_pil_image(
|
||||||
|
control_image_field.image_name
|
||||||
|
)
|
||||||
# self.image.image_type, self.image.image_name
|
# self.image.image_type, self.image.image_name
|
||||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||||
# and add in batch_size, num_images_per_prompt?
|
# and add in batch_size, num_images_per_prompt?
|
||||||
@@ -312,71 +318,69 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
with SilenceWarnings():
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# Get the source node id (we are invoking the prepared node)
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
graph_execution_state = context.services.graph_execution_manager.get(
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
context.graph_execution_state_id
|
||||||
|
)
|
||||||
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
def step_callback(state: PipelineIntermediateState):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
self.dispatch_progress(context, source_node_id, state)
|
self.dispatch_progress(context, source_node_id, state)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}),
|
**lora.dict(exclude={"weight"}), context=context,
|
||||||
context=context,
|
)
|
||||||
)
|
yield (lora_info.context.model, lora.weight)
|
||||||
yield (lora_info.context.model, lora.weight)
|
del lora_info
|
||||||
del lora_info
|
return
|
||||||
return
|
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict(),
|
**self.unet.unet.dict(), context=context,
|
||||||
|
)
|
||||||
|
with ExitStack() as exit_stack,\
|
||||||
|
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||||
|
unet_info as unet:
|
||||||
|
|
||||||
|
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
|
scheduler = get_scheduler(
|
||||||
context=context,
|
context=context,
|
||||||
|
scheduler_info=self.unet.scheduler,
|
||||||
|
scheduler_name=self.scheduler,
|
||||||
)
|
)
|
||||||
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
|
|
||||||
unet_info.context.model, _lora_loader()
|
|
||||||
), unet_info as unet:
|
|
||||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
|
||||||
|
|
||||||
scheduler = get_scheduler(
|
pipeline = self.create_pipeline(unet, scheduler)
|
||||||
context=context,
|
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
|
||||||
scheduler_info=self.unet.scheduler,
|
|
||||||
scheduler_name=self.scheduler,
|
|
||||||
)
|
|
||||||
|
|
||||||
pipeline = self.create_pipeline(unet, scheduler)
|
control_data = self.prep_control_data(
|
||||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
|
model=pipeline, context=context, control_input=self.control,
|
||||||
|
latents_shape=noise.shape,
|
||||||
|
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
|
do_classifier_free_guidance=True,
|
||||||
|
exit_stack=exit_stack,
|
||||||
|
)
|
||||||
|
|
||||||
control_data = self.prep_control_data(
|
# TODO: Verify the noise is the right size
|
||||||
model=pipeline,
|
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||||
context=context,
|
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
|
||||||
control_input=self.control,
|
noise=noise,
|
||||||
latents_shape=noise.shape,
|
num_inference_steps=self.steps,
|
||||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
conditioning_data=conditioning_data,
|
||||||
do_classifier_free_guidance=True,
|
control_data=control_data, # list[ControlNetData]
|
||||||
exit_stack=exit_stack,
|
callback=step_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
result_latents = result_latents.to("cpu")
|
||||||
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
|
torch.cuda.empty_cache()
|
||||||
noise=noise,
|
|
||||||
num_inference_steps=self.steps,
|
|
||||||
conditioning_data=conditioning_data,
|
|
||||||
control_data=control_data, # list[ControlNetData]
|
|
||||||
callback=step_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
result_latents = result_latents.to("cpu")
|
context.services.latents.save(name, result_latents)
|
||||||
torch.cuda.empty_cache()
|
return build_latents_output(latents_name=name, latents=result_latents)
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
|
||||||
context.services.latents.save(name, result_latents)
|
|
||||||
return build_latents_output(latents_name=name, latents=result_latents)
|
|
||||||
|
|
||||||
|
|
||||||
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
||||||
@@ -385,8 +389,11 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
type: Literal["l2l"] = "l2l"
|
type: Literal["l2l"] = "l2l"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
latents: Optional[LatentsField] = Field(
|
||||||
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
|
description="The latents to use as a base image")
|
||||||
|
strength: float = Field(
|
||||||
|
default=0.7, ge=0, le=1,
|
||||||
|
description="The strength of the latents to use")
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
@@ -398,89 +405,87 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
"model": "model",
|
"model": "model",
|
||||||
"control": "control",
|
"control": "control",
|
||||||
"cfg_scale": "number",
|
"cfg_scale": "number",
|
||||||
},
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
with SilenceWarnings(): # this quenches NSFW nag from diffusers
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
latent = context.services.latents.get(self.latents.latents_name)
|
||||||
latent = context.services.latents.get(self.latents.latents_name)
|
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# Get the source node id (we are invoking the prepared node)
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
graph_execution_state = context.services.graph_execution_manager.get(
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
context.graph_execution_state_id
|
||||||
|
)
|
||||||
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
def step_callback(state: PipelineIntermediateState):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
self.dispatch_progress(context, source_node_id, state)
|
self.dispatch_progress(context, source_node_id, state)
|
||||||
|
|
||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}),
|
**lora.dict(exclude={"weight"}), context=context,
|
||||||
context=context,
|
)
|
||||||
)
|
yield (lora_info.context.model, lora.weight)
|
||||||
yield (lora_info.context.model, lora.weight)
|
del lora_info
|
||||||
del lora_info
|
return
|
||||||
return
|
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict(),
|
**self.unet.unet.dict(), context=context,
|
||||||
|
)
|
||||||
|
with ExitStack() as exit_stack,\
|
||||||
|
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||||
|
unet_info as unet:
|
||||||
|
|
||||||
|
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
latent = latent.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
|
scheduler = get_scheduler(
|
||||||
context=context,
|
context=context,
|
||||||
|
scheduler_info=self.unet.scheduler,
|
||||||
|
scheduler_name=self.scheduler,
|
||||||
)
|
)
|
||||||
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
|
|
||||||
unet_info.context.model, _lora_loader()
|
|
||||||
), unet_info as unet:
|
|
||||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
|
||||||
latent = latent.to(device=unet.device, dtype=unet.dtype)
|
|
||||||
|
|
||||||
scheduler = get_scheduler(
|
pipeline = self.create_pipeline(unet, scheduler)
|
||||||
context=context,
|
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
|
||||||
scheduler_info=self.unet.scheduler,
|
|
||||||
scheduler_name=self.scheduler,
|
|
||||||
)
|
|
||||||
|
|
||||||
pipeline = self.create_pipeline(unet, scheduler)
|
control_data = self.prep_control_data(
|
||||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
|
model=pipeline, context=context, control_input=self.control,
|
||||||
|
latents_shape=noise.shape,
|
||||||
|
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
|
do_classifier_free_guidance=True,
|
||||||
|
exit_stack=exit_stack,
|
||||||
|
)
|
||||||
|
|
||||||
control_data = self.prep_control_data(
|
# TODO: Verify the noise is the right size
|
||||||
model=pipeline,
|
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
||||||
context=context,
|
latent, device=unet.device, dtype=latent.dtype
|
||||||
control_input=self.control,
|
)
|
||||||
latents_shape=noise.shape,
|
|
||||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
|
||||||
do_classifier_free_guidance=True,
|
|
||||||
exit_stack=exit_stack,
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
timesteps, _ = pipeline.get_img2img_timesteps(
|
||||||
initial_latents = (
|
self.steps,
|
||||||
latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
|
self.strength,
|
||||||
)
|
device=unet.device,
|
||||||
|
)
|
||||||
|
|
||||||
timesteps, _ = pipeline.get_img2img_timesteps(
|
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
||||||
self.steps,
|
latents=initial_latents,
|
||||||
self.strength,
|
timesteps=timesteps,
|
||||||
device=unet.device,
|
noise=noise,
|
||||||
)
|
num_inference_steps=self.steps,
|
||||||
|
conditioning_data=conditioning_data,
|
||||||
|
control_data=control_data, # list[ControlNetData]
|
||||||
|
callback=step_callback
|
||||||
|
)
|
||||||
|
|
||||||
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
latents=initial_latents,
|
result_latents = result_latents.to("cpu")
|
||||||
timesteps=timesteps,
|
torch.cuda.empty_cache()
|
||||||
noise=noise,
|
|
||||||
num_inference_steps=self.steps,
|
|
||||||
conditioning_data=conditioning_data,
|
|
||||||
control_data=control_data, # list[ControlNetData]
|
|
||||||
callback=step_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
result_latents = result_latents.to("cpu")
|
context.services.latents.save(name, result_latents)
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
|
||||||
context.services.latents.save(name, result_latents)
|
|
||||||
return build_latents_output(latents_name=name, latents=result_latents)
|
return build_latents_output(latents_name=name, latents=result_latents)
|
||||||
|
|
||||||
|
|
||||||
@@ -491,13 +496,14 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
type: Literal["l2i"] = "l2i"
|
type: Literal["l2i"] = "l2i"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
latents: Optional[LatentsField] = Field(
|
||||||
|
description="The latents to generate an image from")
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)")
|
tiled: bool = Field(
|
||||||
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
|
default=False,
|
||||||
metadata: Optional[CoreMetadata] = Field(
|
description="Decode latents by overlapping tiles(less memory consumption)")
|
||||||
default=None, description="Optional core metadata to be written to the image"
|
fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
|
||||||
)
|
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
@@ -513,8 +519,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.services.model_manager.get_model(
|
||||||
**self.vae.vae.dict(),
|
**self.vae.vae.dict(), context=context,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with vae_info as vae:
|
with vae_info as vae:
|
||||||
@@ -581,7 +586,8 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear",
|
||||||
|
"bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||||
|
|
||||||
|
|
||||||
class ResizeLatentsInvocation(BaseInvocation):
|
class ResizeLatentsInvocation(BaseInvocation):
|
||||||
@@ -590,30 +596,36 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
type: Literal["lresize"] = "lresize"
|
type: Literal["lresize"] = "lresize"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
latents: Optional[LatentsField] = Field(
|
||||||
width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
|
description="The latents to resize")
|
||||||
height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
|
width: Union[int, None] = Field(default=512,
|
||||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||||
|
height: Union[int, None] = Field(default=512,
|
||||||
|
ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||||
|
mode: LATENTS_INTERPOLATION_MODE = Field(
|
||||||
|
default="bilinear", description="The interpolation mode")
|
||||||
antialias: bool = Field(
|
antialias: bool = Field(
|
||||||
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
|
default=False,
|
||||||
)
|
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Resize Latents", "tags": ["latents", "resize"]},
|
"ui": {
|
||||||
|
"title": "Resize Latents",
|
||||||
|
"tags": ["latents", "resize"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
# TODO:
|
# TODO:
|
||||||
device = choose_torch_device()
|
device=choose_torch_device()
|
||||||
|
|
||||||
resized_latents = torch.nn.functional.interpolate(
|
resized_latents = torch.nn.functional.interpolate(
|
||||||
latents.to(device),
|
latents.to(device), size=(self.height // 8, self.width // 8),
|
||||||
size=(self.height // 8, self.width // 8),
|
mode=self.mode, antialias=self.antialias
|
||||||
mode=self.mode,
|
if self.mode in ["bilinear", "bicubic"] else False,
|
||||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
@@ -632,30 +644,35 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
type: Literal["lscale"] = "lscale"
|
type: Literal["lscale"] = "lscale"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
latents: Optional[LatentsField] = Field(
|
||||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
description="The latents to scale")
|
||||||
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
scale_factor: float = Field(
|
||||||
|
gt=0, description="The factor by which to scale the latents")
|
||||||
|
mode: LATENTS_INTERPOLATION_MODE = Field(
|
||||||
|
default="bilinear", description="The interpolation mode")
|
||||||
antialias: bool = Field(
|
antialias: bool = Field(
|
||||||
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
|
default=False,
|
||||||
)
|
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
|
"ui": {
|
||||||
|
"title": "Scale Latents",
|
||||||
|
"tags": ["latents", "scale"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
# TODO:
|
# TODO:
|
||||||
device = choose_torch_device()
|
device=choose_torch_device()
|
||||||
|
|
||||||
# resizing
|
# resizing
|
||||||
resized_latents = torch.nn.functional.interpolate(
|
resized_latents = torch.nn.functional.interpolate(
|
||||||
latents.to(device),
|
latents.to(device), scale_factor=self.scale_factor, mode=self.mode,
|
||||||
scale_factor=self.scale_factor,
|
antialias=self.antialias
|
||||||
mode=self.mode,
|
if self.mode in ["bilinear", "bicubic"] else False,
|
||||||
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
@@ -676,13 +693,19 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(description="The image to encode")
|
image: Optional[ImageField] = Field(description="The image to encode")
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
|
tiled: bool = Field(
|
||||||
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
|
default=False,
|
||||||
|
description="Encode latents by overlaping tiles(less memory consumption)")
|
||||||
|
fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
|
||||||
|
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Image To Latents", "tags": ["latents", "image"]},
|
"ui": {
|
||||||
|
"title": "Image To Latents",
|
||||||
|
"tags": ["latents", "image"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -692,10 +715,9 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
# )
|
# )
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
# vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.services.model_manager.get_model(
|
||||||
**self.vae.vae.dict(),
|
**self.vae.vae.dict(), context=context,
|
||||||
context=context,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||||
@@ -722,12 +744,12 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
vae.post_quant_conv.to(orig_dtype)
|
vae.post_quant_conv.to(orig_dtype)
|
||||||
vae.decoder.conv_in.to(orig_dtype)
|
vae.decoder.conv_in.to(orig_dtype)
|
||||||
vae.decoder.mid_block.to(orig_dtype)
|
vae.decoder.mid_block.to(orig_dtype)
|
||||||
# else:
|
#else:
|
||||||
# latents = latents.float()
|
# latents = latents.float()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
vae.to(dtype=torch.float16)
|
vae.to(dtype=torch.float16)
|
||||||
# latents = latents.half()
|
#latents = latents.half()
|
||||||
|
|
||||||
if self.tiled:
|
if self.tiled:
|
||||||
vae.enable_tiling()
|
vae.enable_tiling()
|
||||||
@@ -738,7 +760,9 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||||
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
|
latents = image_tensor_dist.sample().to(
|
||||||
|
dtype=vae.dtype
|
||||||
|
) # FIXME: uses torch.randn. make reproducible!
|
||||||
|
|
||||||
latents = vae.config.scaling_factor * latents
|
latents = vae.config.scaling_factor * latents
|
||||||
latents = latents.to(dtype=orig_dtype)
|
latents = latents.to(dtype=orig_dtype)
|
||||||
|
|||||||
@@ -54,7 +54,10 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Add", "tags": ["math", "add"]},
|
"ui": {
|
||||||
|
"title": "Add",
|
||||||
|
"tags": ["math", "add"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
@@ -72,7 +75,10 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Subtract", "tags": ["math", "subtract"]},
|
"ui": {
|
||||||
|
"title": "Subtract",
|
||||||
|
"tags": ["math", "subtract"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
@@ -90,7 +96,10 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Multiply", "tags": ["math", "multiply"]},
|
"ui": {
|
||||||
|
"title": "Multiply",
|
||||||
|
"tags": ["math", "multiply"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
@@ -108,7 +117,10 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Divide", "tags": ["math", "divide"]},
|
"ui": {
|
||||||
|
"title": "Divide",
|
||||||
|
"tags": ["math", "divide"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
@@ -128,7 +140,10 @@ class RandomIntInvocation(BaseInvocation):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]},
|
"ui": {
|
||||||
|
"title": "Random Integer",
|
||||||
|
"tags": ["math", "random", "integer"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||||
|
|
||||||
|
|
||||||
class LoRAMetadataField(BaseModel):
|
class LoRAMetadataField(BaseModel):
|
||||||
"""LoRA metadata for an image generated in InvokeAI."""
|
"""LoRA metadata for an image generated in InvokeAI."""
|
||||||
|
|
||||||
@@ -38,7 +37,9 @@ class CoreMetadata(BaseModel):
|
|||||||
description="The number of skipped CLIP layers",
|
description="The number of skipped CLIP layers",
|
||||||
)
|
)
|
||||||
model: MainModelField = Field(description="The main model used for inference")
|
model: MainModelField = Field(description="The main model used for inference")
|
||||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
controlnets: list[ControlField] = Field(
|
||||||
|
description="The ControlNets used for inference"
|
||||||
|
)
|
||||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||||
vae: Union[VAEModelField, None] = Field(
|
vae: Union[VAEModelField, None] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -50,24 +51,38 @@ class CoreMetadata(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description="The strength used for latents-to-latents",
|
description="The strength used for latents-to-latents",
|
||||||
)
|
)
|
||||||
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
|
init_image: Union[str, None] = Field(
|
||||||
|
default=None, description="The name of the initial image"
|
||||||
|
)
|
||||||
|
|
||||||
# SDXL
|
# SDXL
|
||||||
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
|
positive_style_prompt: Union[str, None] = Field(
|
||||||
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
|
default=None, description="The positive style prompt parameter"
|
||||||
|
)
|
||||||
|
negative_style_prompt: Union[str, None] = Field(
|
||||||
|
default=None, description="The negative style prompt parameter"
|
||||||
|
)
|
||||||
|
|
||||||
# SDXL Refiner
|
# SDXL Refiner
|
||||||
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
|
refiner_model: Union[MainModelField, None] = Field(
|
||||||
|
default=None, description="The SDXL Refiner model used"
|
||||||
|
)
|
||||||
refiner_cfg_scale: Union[float, None] = Field(
|
refiner_cfg_scale: Union[float, None] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The classifier-free guidance scale parameter used for the refiner",
|
description="The classifier-free guidance scale parameter used for the refiner",
|
||||||
)
|
)
|
||||||
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
|
refiner_steps: Union[int, None] = Field(
|
||||||
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
|
default=None, description="The number of steps used for the refiner"
|
||||||
|
)
|
||||||
|
refiner_scheduler: Union[str, None] = Field(
|
||||||
|
default=None, description="The scheduler used for the refiner"
|
||||||
|
)
|
||||||
refiner_aesthetic_store: Union[float, None] = Field(
|
refiner_aesthetic_store: Union[float, None] = Field(
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
default=None, description="The aesthetic score used for the refiner"
|
||||||
)
|
)
|
||||||
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
refiner_start: Union[float, None] = Field(
|
||||||
|
default=None, description="The start value used for refiner denoising"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageMetadata(BaseModel):
|
class ImageMetadata(BaseModel):
|
||||||
@@ -77,7 +92,9 @@ class ImageMetadata(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description="The image's core metadata, if it was created in the Linear or Canvas UI",
|
description="The image's core metadata, if it was created in the Linear or Canvas UI",
|
||||||
)
|
)
|
||||||
graph: Optional[dict] = Field(default=None, description="The graph that created the image")
|
graph: Optional[dict] = Field(
|
||||||
|
default=None, description="The graph that created the image"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MetadataAccumulatorOutput(BaseInvocationOutput):
|
class MetadataAccumulatorOutput(BaseInvocationOutput):
|
||||||
@@ -109,34 +126,50 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
|||||||
description="The number of skipped CLIP layers",
|
description="The number of skipped CLIP layers",
|
||||||
)
|
)
|
||||||
model: MainModelField = Field(description="The main model used for inference")
|
model: MainModelField = Field(description="The main model used for inference")
|
||||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
controlnets: list[ControlField] = Field(
|
||||||
|
description="The ControlNets used for inference"
|
||||||
|
)
|
||||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||||
strength: Union[float, None] = Field(
|
strength: Union[float, None] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The strength used for latents-to-latents",
|
description="The strength used for latents-to-latents",
|
||||||
)
|
)
|
||||||
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
|
init_image: Union[str, None] = Field(
|
||||||
|
default=None, description="The name of the initial image"
|
||||||
|
)
|
||||||
vae: Union[VAEModelField, None] = Field(
|
vae: Union[VAEModelField, None] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The VAE used for decoding, if the main model's default was not used",
|
description="The VAE used for decoding, if the main model's default was not used",
|
||||||
)
|
)
|
||||||
|
|
||||||
# SDXL
|
# SDXL
|
||||||
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
|
positive_style_prompt: Union[str, None] = Field(
|
||||||
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
|
default=None, description="The positive style prompt parameter"
|
||||||
|
)
|
||||||
|
negative_style_prompt: Union[str, None] = Field(
|
||||||
|
default=None, description="The negative style prompt parameter"
|
||||||
|
)
|
||||||
|
|
||||||
# SDXL Refiner
|
# SDXL Refiner
|
||||||
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
|
refiner_model: Union[MainModelField, None] = Field(
|
||||||
|
default=None, description="The SDXL Refiner model used"
|
||||||
|
)
|
||||||
refiner_cfg_scale: Union[float, None] = Field(
|
refiner_cfg_scale: Union[float, None] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The classifier-free guidance scale parameter used for the refiner",
|
description="The classifier-free guidance scale parameter used for the refiner",
|
||||||
)
|
)
|
||||||
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
|
refiner_steps: Union[int, None] = Field(
|
||||||
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
|
default=None, description="The number of steps used for the refiner"
|
||||||
|
)
|
||||||
|
refiner_scheduler: Union[str, None] = Field(
|
||||||
|
default=None, description="The scheduler used for the refiner"
|
||||||
|
)
|
||||||
refiner_aesthetic_store: Union[float, None] = Field(
|
refiner_aesthetic_store: Union[float, None] = Field(
|
||||||
default=None, description="The aesthetic score used for the refiner"
|
default=None, description="The aesthetic score used for the refiner"
|
||||||
)
|
)
|
||||||
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
|
refiner_start: Union[float, None] = Field(
|
||||||
|
default=None, description="The start value used for refiner denoising"
|
||||||
|
)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
|
|||||||
@@ -4,14 +4,17 @@ from typing import List, Literal, Optional, Union
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||||
|
InvocationConfig, InvocationContext)
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
class ModelInfo(BaseModel):
|
||||||
model_name: str = Field(description="Info to load submodel")
|
model_name: str = Field(description="Info to load submodel")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
model_type: ModelType = Field(description="Info to load submodel")
|
model_type: ModelType = Field(description="Info to load submodel")
|
||||||
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
submodel: Optional[SubModelType] = Field(
|
||||||
|
default=None, description="Info to load submodel"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LoraInfo(ModelInfo):
|
class LoraInfo(ModelInfo):
|
||||||
@@ -30,7 +33,6 @@ class ClipField(BaseModel):
|
|||||||
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
||||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
||||||
|
|
||||||
|
|
||||||
class VaeField(BaseModel):
|
class VaeField(BaseModel):
|
||||||
# TODO: better naming?
|
# TODO: better naming?
|
||||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
vae: ModelInfo = Field(description="Info to load vae submodel")
|
||||||
@@ -47,13 +49,11 @@ class ModelLoaderOutput(BaseInvocationOutput):
|
|||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
class MainModelField(BaseModel):
|
class MainModelField(BaseModel):
|
||||||
"""Main model field"""
|
"""Main model field"""
|
||||||
|
|
||||||
model_name: str = Field(description="Name of the model")
|
model_name: str = Field(description="Name of the model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
model_type: ModelType = Field(description="Model Type")
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModelField(BaseModel):
|
class LoRAModelField(BaseModel):
|
||||||
@@ -62,7 +62,6 @@ class LoRAModelField(BaseModel):
|
|||||||
model_name: str = Field(description="Name of the LoRA model")
|
model_name: str = Field(description="Name of the LoRA model")
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
|
||||||
class MainModelLoaderInvocation(BaseInvocation):
|
class MainModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
@@ -198,7 +197,9 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
|
|
||||||
type: Literal["lora_loader"] = "lora_loader"
|
type: Literal["lora_loader"] = "lora_loader"
|
||||||
|
|
||||||
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
|
lora: Union[LoRAModelField, None] = Field(
|
||||||
|
default=None, description="Lora model name"
|
||||||
|
)
|
||||||
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
weight: float = Field(default=0.75, description="With what weight to apply lora")
|
||||||
|
|
||||||
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
|
||||||
@@ -227,10 +228,14 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
):
|
):
|
||||||
raise Exception(f"Unkown lora name: {lora_name}!")
|
raise Exception(f"Unkown lora name: {lora_name}!")
|
||||||
|
|
||||||
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
|
if self.unet is not None and any(
|
||||||
|
lora.model_name == lora_name for lora in self.unet.loras
|
||||||
|
):
|
||||||
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
raise Exception(f'Lora "{lora_name}" already applied to unet')
|
||||||
|
|
||||||
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
|
if self.clip is not None and any(
|
||||||
|
lora.model_name == lora_name for lora in self.clip.loras
|
||||||
|
):
|
||||||
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
raise Exception(f'Lora "{lora_name}" already applied to clip')
|
||||||
|
|
||||||
output = LoraLoaderOutput()
|
output = LoraLoaderOutput()
|
||||||
|
|||||||
@@ -1,578 +0,0 @@
|
|||||||
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
|
|
||||||
|
|
||||||
from contextlib import ExitStack
|
|
||||||
from typing import List, Literal, Optional, Union
|
|
||||||
|
|
||||||
import re
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, validator
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
|
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
|
||||||
|
|
||||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
|
||||||
from ...backend.model_management import ONNXModelPatcher
|
|
||||||
from ...backend.util import choose_torch_device
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
|
||||||
from .compel import ConditioningField
|
|
||||||
from .controlnet_image_processors import ControlField
|
|
||||||
from .image import ImageOutput
|
|
||||||
from .model import ModelInfo, UNetField, VaeField
|
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
|
||||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
from .model import ClipField
|
|
||||||
from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES
|
|
||||||
from .compel import CompelOutput
|
|
||||||
|
|
||||||
|
|
||||||
ORT_TO_NP_TYPE = {
|
|
||||||
"tensor(bool)": np.bool_,
|
|
||||||
"tensor(int8)": np.int8,
|
|
||||||
"tensor(uint8)": np.uint8,
|
|
||||||
"tensor(int16)": np.int16,
|
|
||||||
"tensor(uint16)": np.uint16,
|
|
||||||
"tensor(int32)": np.int32,
|
|
||||||
"tensor(uint32)": np.uint32,
|
|
||||||
"tensor(int64)": np.int64,
|
|
||||||
"tensor(uint64)": np.uint64,
|
|
||||||
"tensor(float16)": np.float16,
|
|
||||||
"tensor(float)": np.float32,
|
|
||||||
"tensor(double)": np.float64,
|
|
||||||
}
|
|
||||||
|
|
||||||
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
|
|
||||||
|
|
||||||
|
|
||||||
class ONNXPromptInvocation(BaseInvocation):
|
|
||||||
type: Literal["prompt_onnx"] = "prompt_onnx"
|
|
||||||
|
|
||||||
prompt: str = Field(default="", description="Prompt")
|
|
||||||
clip: ClipField = Field(None, description="Clip to use")
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> CompelOutput:
|
|
||||||
tokenizer_info = context.services.model_manager.get_model(
|
|
||||||
**self.clip.tokenizer.dict(),
|
|
||||||
)
|
|
||||||
text_encoder_info = context.services.model_manager.get_model(
|
|
||||||
**self.clip.text_encoder.dict(),
|
|
||||||
)
|
|
||||||
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
|
|
||||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
|
|
||||||
loras = [
|
|
||||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
|
||||||
for lora in self.clip.loras
|
|
||||||
]
|
|
||||||
|
|
||||||
ti_list = []
|
|
||||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
|
||||||
name = trigger[1:-1]
|
|
||||||
try:
|
|
||||||
ti_list.append(
|
|
||||||
# stack.enter_context(
|
|
||||||
# context.services.model_manager.get_model(
|
|
||||||
# model_name=name,
|
|
||||||
# base_model=self.clip.text_encoder.base_model,
|
|
||||||
# model_type=ModelType.TextualInversion,
|
|
||||||
# )
|
|
||||||
# )
|
|
||||||
context.services.model_manager.get_model(
|
|
||||||
model_name=name,
|
|
||||||
base_model=self.clip.text_encoder.base_model,
|
|
||||||
model_type=ModelType.TextualInversion,
|
|
||||||
).context.model
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
# print(e)
|
|
||||||
# import traceback
|
|
||||||
# print(traceback.format_exc())
|
|
||||||
print(f'Warn: trigger: "{trigger}" not found')
|
|
||||||
if loras or ti_list:
|
|
||||||
text_encoder.release_session()
|
|
||||||
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), ONNXModelPatcher.apply_ti(
|
|
||||||
orig_tokenizer, text_encoder, ti_list
|
|
||||||
) as (tokenizer, ti_manager):
|
|
||||||
text_encoder.create_session()
|
|
||||||
|
|
||||||
# copy from
|
|
||||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153
|
|
||||||
text_inputs = tokenizer(
|
|
||||||
self.prompt,
|
|
||||||
padding="max_length",
|
|
||||||
max_length=tokenizer.model_max_length,
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="np",
|
|
||||||
)
|
|
||||||
text_input_ids = text_inputs.input_ids
|
|
||||||
"""
|
|
||||||
untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
|
|
||||||
|
|
||||||
if not np.array_equal(text_input_ids, untruncated_ids):
|
|
||||||
removed_text = self.tokenizer.batch_decode(
|
|
||||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
|
||||||
)
|
|
||||||
logger.warning(
|
|
||||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
|
||||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
|
||||||
|
|
||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
|
||||||
|
|
||||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
|
||||||
context.services.latents.save(conditioning_name, (prompt_embeds, None))
|
|
||||||
|
|
||||||
return CompelOutput(
|
|
||||||
conditioning=ConditioningField(
|
|
||||||
conditioning_name=conditioning_name,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Text to image
|
|
||||||
class ONNXTextToLatentsInvocation(BaseInvocation):
|
|
||||||
"""Generates latents from conditionings."""
|
|
||||||
|
|
||||||
type: Literal["t2l_onnx"] = "t2l_onnx"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
# fmt: off
|
|
||||||
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
|
||||||
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
|
||||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
|
||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
|
||||||
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
|
||||||
precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents")
|
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
|
||||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
|
||||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
|
||||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
@validator("cfg_scale")
|
|
||||||
def ge_one(cls, v):
|
|
||||||
"""validate that all cfg_scale values are >= 1"""
|
|
||||||
if isinstance(v, list):
|
|
||||||
for i in v:
|
|
||||||
if i < 1:
|
|
||||||
raise ValueError("cfg_scale must be greater than 1")
|
|
||||||
else:
|
|
||||||
if v < 1:
|
|
||||||
raise ValueError("cfg_scale must be greater than 1")
|
|
||||||
return v
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["latents"],
|
|
||||||
"type_hints": {
|
|
||||||
"model": "model",
|
|
||||||
"control": "control",
|
|
||||||
# "cfg_scale": "float",
|
|
||||||
"cfg_scale": "number",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# based on
|
|
||||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
|
||||||
c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
|
||||||
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
|
||||||
if isinstance(c, torch.Tensor):
|
|
||||||
c = c.cpu().numpy()
|
|
||||||
if isinstance(uc, torch.Tensor):
|
|
||||||
uc = uc.cpu().numpy()
|
|
||||||
device = torch.device(choose_torch_device())
|
|
||||||
prompt_embeds = np.concatenate([uc, c])
|
|
||||||
|
|
||||||
latents = context.services.latents.get(self.noise.latents_name)
|
|
||||||
if isinstance(latents, torch.Tensor):
|
|
||||||
latents = latents.cpu().numpy()
|
|
||||||
|
|
||||||
# TODO: better execution device handling
|
|
||||||
latents = latents.astype(ORT_TO_NP_TYPE[self.precision])
|
|
||||||
|
|
||||||
# get the initial random noise unless the user supplied it
|
|
||||||
do_classifier_free_guidance = True
|
|
||||||
# latents_dtype = prompt_embeds.dtype
|
|
||||||
# latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
|
||||||
# if latents.shape != latents_shape:
|
|
||||||
# raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
|
||||||
|
|
||||||
scheduler = get_scheduler(
|
|
||||||
context=context,
|
|
||||||
scheduler_info=self.unet.scheduler,
|
|
||||||
scheduler_name=self.scheduler,
|
|
||||||
)
|
|
||||||
|
|
||||||
def torch2numpy(latent: torch.Tensor):
|
|
||||||
return latent.cpu().numpy()
|
|
||||||
|
|
||||||
def numpy2torch(latent, device):
|
|
||||||
return torch.from_numpy(latent).to(device)
|
|
||||||
|
|
||||||
def dispatch_progress(
|
|
||||||
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
|
|
||||||
) -> None:
|
|
||||||
stable_diffusion_step_callback(
|
|
||||||
context=context,
|
|
||||||
intermediate_state=intermediate_state,
|
|
||||||
node=self.dict(),
|
|
||||||
source_node_id=source_node_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
scheduler.set_timesteps(self.steps)
|
|
||||||
latents = latents * np.float64(scheduler.init_noise_sigma)
|
|
||||||
|
|
||||||
extra_step_kwargs = dict()
|
|
||||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
|
||||||
extra_step_kwargs.update(
|
|
||||||
eta=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
|
|
||||||
|
|
||||||
with unet_info as unet, ExitStack() as stack:
|
|
||||||
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
|
|
||||||
loras = [
|
|
||||||
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
|
||||||
for lora in self.unet.loras
|
|
||||||
]
|
|
||||||
|
|
||||||
if loras:
|
|
||||||
unet.release_session()
|
|
||||||
with ONNXModelPatcher.apply_lora_unet(unet, loras):
|
|
||||||
# TODO:
|
|
||||||
_, _, h, w = latents.shape
|
|
||||||
unet.create_session(h, w)
|
|
||||||
|
|
||||||
timestep_dtype = next(
|
|
||||||
(input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
|
|
||||||
)
|
|
||||||
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
|
|
||||||
for i in tqdm(range(len(scheduler.timesteps))):
|
|
||||||
t = scheduler.timesteps[i]
|
|
||||||
# expand the latents if we are doing classifier free guidance
|
|
||||||
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
|
|
||||||
latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t)
|
|
||||||
latent_model_input = latent_model_input.cpu().numpy()
|
|
||||||
|
|
||||||
# predict the noise residual
|
|
||||||
timestep = np.array([t], dtype=timestep_dtype)
|
|
||||||
noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
|
|
||||||
noise_pred = noise_pred[0]
|
|
||||||
|
|
||||||
# perform guidance
|
|
||||||
if do_classifier_free_guidance:
|
|
||||||
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
|
|
||||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
|
||||||
scheduler_output = scheduler.step(
|
|
||||||
numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs
|
|
||||||
)
|
|
||||||
latents = torch2numpy(scheduler_output.prev_sample)
|
|
||||||
|
|
||||||
state = PipelineIntermediateState(
|
|
||||||
run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample
|
|
||||||
)
|
|
||||||
dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state)
|
|
||||||
|
|
||||||
# call the callback, if provided
|
|
||||||
# if callback is not None and i % callback_steps == 0:
|
|
||||||
# callback(i, t, latents)
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
|
||||||
context.services.latents.save(name, latents)
|
|
||||||
return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))
|
|
||||||
|
|
||||||
|
|
||||||
# Latent to image
|
|
||||||
class ONNXLatentsToImageInvocation(BaseInvocation):
|
|
||||||
"""Generates an image from latents."""
|
|
||||||
|
|
||||||
type: Literal["l2i_onnx"] = "l2i_onnx"
|
|
||||||
|
|
||||||
# Inputs
|
|
||||||
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
|
||||||
metadata: Optional[CoreMetadata] = Field(
|
|
||||||
default=None, description="Optional core metadata to be written to the image"
|
|
||||||
)
|
|
||||||
# tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["latents", "image"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
|
||||||
|
|
||||||
if self.vae.vae.submodel != SubModelType.VaeDecoder:
|
|
||||||
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
|
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
|
||||||
**self.vae.vae.dict(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# clear memory as vae decode can request a lot
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
with vae_info as vae:
|
|
||||||
vae.create_session()
|
|
||||||
|
|
||||||
# copied from
|
|
||||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427
|
|
||||||
latents = 1 / 0.18215 * latents
|
|
||||||
# image = self.vae_decoder(latent_sample=latents)[0]
|
|
||||||
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
|
|
||||||
image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])])
|
|
||||||
|
|
||||||
image = np.clip(image / 2 + 0.5, 0, 1)
|
|
||||||
image = image.transpose((0, 2, 3, 1))
|
|
||||||
image = VaeImageProcessor.numpy_to_pil(image)[0]
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
|
||||||
image=image,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ONNXModelLoaderOutput(BaseInvocationOutput):
|
|
||||||
"""Model loader output"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
|
|
||||||
|
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
|
||||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
|
||||||
vae_decoder: VaeField = Field(default=None, description="Vae submodel")
|
|
||||||
vae_encoder: VaeField = Field(default=None, description="Vae submodel")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class ONNXSD1ModelLoaderInvocation(BaseInvocation):
|
|
||||||
"""Loading submodels of selected model."""
|
|
||||||
|
|
||||||
type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"
|
|
||||||
|
|
||||||
model_name: str = Field(default="", description="Model to load")
|
|
||||||
# TODO: precision?
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}}, # TODO: rename to model_name?
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
|
||||||
model_name = "stable-diffusion-v1-5"
|
|
||||||
base_model = BaseModelType.StableDiffusion1
|
|
||||||
|
|
||||||
# TODO: not found exceptions
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=BaseModelType.StableDiffusion1,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
):
|
|
||||||
raise Exception(f"Unkown model name: {model_name}!")
|
|
||||||
|
|
||||||
return ONNXModelLoaderOutput(
|
|
||||||
unet=UNetField(
|
|
||||||
unet=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.UNet,
|
|
||||||
),
|
|
||||||
scheduler=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.Scheduler,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
clip=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.Tokenizer,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.TextEncoder,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
vae_decoder=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.VaeDecoder,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
vae_encoder=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.ONNX,
|
|
||||||
submodel=SubModelType.VaeEncoder,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OnnxModelField(BaseModel):
|
|
||||||
"""Onnx model field"""
|
|
||||||
|
|
||||||
model_name: str = Field(description="Name of the model")
|
|
||||||
base_model: BaseModelType = Field(description="Base model")
|
|
||||||
model_type: ModelType = Field(description="Model Type")
|
|
||||||
|
|
||||||
|
|
||||||
class OnnxModelLoaderInvocation(BaseInvocation):
|
|
||||||
"""Loads a main model, outputting its submodels."""
|
|
||||||
|
|
||||||
type: Literal["onnx_model_loader"] = "onnx_model_loader"
|
|
||||||
|
|
||||||
model: OnnxModelField = Field(description="The model to load")
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "Onnx Model Loader",
|
|
||||||
"tags": ["model", "loader"],
|
|
||||||
"type_hints": {"model": "model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
|
||||||
base_model = self.model.base_model
|
|
||||||
model_name = self.model.model_name
|
|
||||||
model_type = ModelType.ONNX
|
|
||||||
|
|
||||||
# TODO: not found exceptions
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
):
|
|
||||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.Tokenizer,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.TextEncoder,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.UNet,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
return ONNXModelLoaderOutput(
|
|
||||||
unet=UNetField(
|
|
||||||
unet=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.UNet,
|
|
||||||
),
|
|
||||||
scheduler=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.Scheduler,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
clip=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.Tokenizer,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.TextEncoder,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
skipped_layers=0,
|
|
||||||
),
|
|
||||||
vae_decoder=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.VaeDecoder,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
vae_encoder=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.VaeEncoder,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@@ -12,37 +12,16 @@ import matplotlib.pyplot as plt
|
|||||||
|
|
||||||
from easing_functions import (
|
from easing_functions import (
|
||||||
LinearInOut,
|
LinearInOut,
|
||||||
QuadEaseInOut,
|
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
|
||||||
QuadEaseIn,
|
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
|
||||||
QuadEaseOut,
|
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
|
||||||
CubicEaseInOut,
|
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
|
||||||
CubicEaseIn,
|
SineEaseInOut, SineEaseIn, SineEaseOut,
|
||||||
CubicEaseOut,
|
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
|
||||||
QuarticEaseInOut,
|
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
|
||||||
QuarticEaseIn,
|
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
|
||||||
QuarticEaseOut,
|
BackEaseIn, BackEaseInOut, BackEaseOut,
|
||||||
QuinticEaseInOut,
|
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
|
||||||
QuinticEaseIn,
|
|
||||||
QuinticEaseOut,
|
|
||||||
SineEaseInOut,
|
|
||||||
SineEaseIn,
|
|
||||||
SineEaseOut,
|
|
||||||
CircularEaseIn,
|
|
||||||
CircularEaseInOut,
|
|
||||||
CircularEaseOut,
|
|
||||||
ExponentialEaseInOut,
|
|
||||||
ExponentialEaseIn,
|
|
||||||
ExponentialEaseOut,
|
|
||||||
ElasticEaseIn,
|
|
||||||
ElasticEaseInOut,
|
|
||||||
ElasticEaseOut,
|
|
||||||
BackEaseIn,
|
|
||||||
BackEaseInOut,
|
|
||||||
BackEaseOut,
|
|
||||||
BounceEaseIn,
|
|
||||||
BounceEaseInOut,
|
|
||||||
BounceEaseOut,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@@ -66,12 +45,17 @@ class FloatLinearRangeInvocation(BaseInvocation):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]},
|
"ui": {
|
||||||
|
"title": "Linear Range (Float)",
|
||||||
|
"tags": ["math", "float", "linear", "range"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||||
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
param_list = list(np.linspace(self.start, self.stop, self.steps))
|
||||||
return FloatCollectionOutput(collection=param_list)
|
return FloatCollectionOutput(
|
||||||
|
collection=param_list
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
EASING_FUNCTIONS_MAP = {
|
EASING_FUNCTIONS_MAP = {
|
||||||
@@ -108,7 +92,9 @@ EASING_FUNCTIONS_MAP = {
|
|||||||
"BounceInOut": BounceEaseInOut,
|
"BounceInOut": BounceEaseInOut,
|
||||||
}
|
}
|
||||||
|
|
||||||
EASING_FUNCTION_KEYS: Any = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
|
EASING_FUNCTION_KEYS: Any = Literal[
|
||||||
|
tuple(list(EASING_FUNCTIONS_MAP.keys()))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
# actually I think for now could just use CollectionOutput (which is list[Any]
|
# actually I think for now could just use CollectionOutput (which is list[Any]
|
||||||
@@ -137,9 +123,13 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]},
|
"ui": {
|
||||||
|
"title": "Param Easing By Step",
|
||||||
|
"tags": ["param", "step", "easing"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
|
||||||
log_diagnostics = False
|
log_diagnostics = False
|
||||||
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
|
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
|
||||||
@@ -180,13 +170,12 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
# and create reverse copy of list[1:end-1]
|
# and create reverse copy of list[1:end-1]
|
||||||
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
|
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
|
||||||
|
|
||||||
base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
|
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
|
||||||
if log_diagnostics:
|
if log_diagnostics: context.services.logger.debug("base easing duration: " + str(base_easing_duration))
|
||||||
context.services.logger.debug("base easing duration: " + str(base_easing_duration))
|
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
|
||||||
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
|
easing_function = easing_class(start=self.start_value,
|
||||||
easing_function = easing_class(
|
end=self.end_value,
|
||||||
start=self.start_value, end=self.end_value, duration=base_easing_duration - 1
|
duration=base_easing_duration - 1)
|
||||||
)
|
|
||||||
base_easing_vals = list()
|
base_easing_vals = list()
|
||||||
for step_index in range(base_easing_duration):
|
for step_index in range(base_easing_duration):
|
||||||
easing_val = easing_function.ease(step_index)
|
easing_val = easing_function.ease(step_index)
|
||||||
@@ -225,7 +214,9 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
#
|
#
|
||||||
|
|
||||||
else: # no mirroring (default)
|
else: # no mirroring (default)
|
||||||
easing_function = easing_class(start=self.start_value, end=self.end_value, duration=num_easing_steps - 1)
|
easing_function = easing_class(start=self.start_value,
|
||||||
|
end=self.end_value,
|
||||||
|
duration=num_easing_steps - 1)
|
||||||
for step_index in range(num_easing_steps):
|
for step_index in range(num_easing_steps):
|
||||||
step_val = easing_function.ease(step_index)
|
step_val = easing_function.ease(step_index)
|
||||||
easing_list.append(step_val)
|
easing_list.append(step_val)
|
||||||
@@ -249,11 +240,13 @@ class StepParamEasingInvocation(BaseInvocation):
|
|||||||
ax = plt.gca()
|
ax = plt.gca()
|
||||||
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
plt.savefig(buf, format="png")
|
plt.savefig(buf, format='png')
|
||||||
buf.seek(0)
|
buf.seek(0)
|
||||||
im = PIL.Image.open(buf)
|
im = PIL.Image.open(buf)
|
||||||
im.show()
|
im.show()
|
||||||
buf.close()
|
buf.close()
|
||||||
|
|
||||||
# output array of size steps, each entry list[i] is param value for step i
|
# output array of size steps, each entry list[i] is param value for step i
|
||||||
return FloatCollectionOutput(collection=param_list)
|
return FloatCollectionOutput(
|
||||||
|
collection=param_list
|
||||||
|
)
|
||||||
|
|||||||
@@ -4,80 +4,67 @@ from typing import Literal
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.invocations.prompt import PromptOutput
|
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||||
|
InvocationConfig, InvocationContext)
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
|
||||||
from .math import FloatOutput, IntOutput
|
from .math import FloatOutput, IntOutput
|
||||||
|
|
||||||
# Pass-through parameter nodes - used by subgraphs
|
# Pass-through parameter nodes - used by subgraphs
|
||||||
|
|
||||||
|
|
||||||
class ParamIntInvocation(BaseInvocation):
|
class ParamIntInvocation(BaseInvocation):
|
||||||
"""An integer parameter"""
|
"""An integer parameter"""
|
||||||
|
#fmt: off
|
||||||
# fmt: off
|
|
||||||
type: Literal["param_int"] = "param_int"
|
type: Literal["param_int"] = "param_int"
|
||||||
a: int = Field(default=0, description="The integer value")
|
a: int = Field(default=0, description="The integer value")
|
||||||
# fmt: on
|
#fmt: on
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"tags": ["param", "integer"], "title": "Integer Parameter"},
|
"ui": {
|
||||||
}
|
"tags": ["param", "integer"],
|
||||||
|
"title": "Integer Parameter"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntOutput:
|
def invoke(self, context: InvocationContext) -> IntOutput:
|
||||||
return IntOutput(a=self.a)
|
return IntOutput(a=self.a)
|
||||||
|
|
||||||
|
|
||||||
class ParamFloatInvocation(BaseInvocation):
|
class ParamFloatInvocation(BaseInvocation):
|
||||||
"""A float parameter"""
|
"""A float parameter"""
|
||||||
|
#fmt: off
|
||||||
# fmt: off
|
|
||||||
type: Literal["param_float"] = "param_float"
|
type: Literal["param_float"] = "param_float"
|
||||||
param: float = Field(default=0.0, description="The float value")
|
param: float = Field(default=0.0, description="The float value")
|
||||||
# fmt: on
|
#fmt: on
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"tags": ["param", "float"], "title": "Float Parameter"},
|
"ui": {
|
||||||
}
|
"tags": ["param", "float"],
|
||||||
|
"title": "Float Parameter"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||||
return FloatOutput(param=self.param)
|
return FloatOutput(param=self.param)
|
||||||
|
|
||||||
|
|
||||||
class StringOutput(BaseInvocationOutput):
|
class StringOutput(BaseInvocationOutput):
|
||||||
"""A string output"""
|
"""A string output"""
|
||||||
|
|
||||||
type: Literal["string_output"] = "string_output"
|
type: Literal["string_output"] = "string_output"
|
||||||
text: str = Field(default=None, description="The output string")
|
text: str = Field(default=None, description="The output string")
|
||||||
|
|
||||||
|
|
||||||
class ParamStringInvocation(BaseInvocation):
|
class ParamStringInvocation(BaseInvocation):
|
||||||
"""A string parameter"""
|
"""A string parameter"""
|
||||||
|
type: Literal['param_string'] = 'param_string'
|
||||||
type: Literal["param_string"] = "param_string"
|
text: str = Field(default='', description='The string value')
|
||||||
text: str = Field(default="", description="The string value")
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"tags": ["param", "string"], "title": "String Parameter"},
|
"ui": {
|
||||||
}
|
"tags": ["param", "string"],
|
||||||
|
"title": "String Parameter"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||||
return StringOutput(text=self.text)
|
return StringOutput(text=self.text)
|
||||||
|
|
||||||
|
|
||||||
class ParamPromptInvocation(BaseInvocation):
|
|
||||||
"""A prompt input parameter"""
|
|
||||||
|
|
||||||
type: Literal["param_prompt"] = "param_prompt"
|
|
||||||
prompt: str = Field(default="", description="The prompt value")
|
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {"tags": ["param", "prompt"], "title": "Prompt"},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptOutput:
|
|
||||||
return PromptOutput(prompt=self.prompt)
|
|
||||||
|
|||||||
@@ -7,21 +7,19 @@ from pydantic import Field, validator
|
|||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||||
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
|
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
|
||||||
|
|
||||||
|
|
||||||
class PromptOutput(BaseInvocationOutput):
|
class PromptOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output a prompt"""
|
"""Base class for invocations that output a prompt"""
|
||||||
|
#fmt: off
|
||||||
# fmt: off
|
|
||||||
type: Literal["prompt"] = "prompt"
|
type: Literal["prompt"] = "prompt"
|
||||||
|
|
||||||
prompt: str = Field(default=None, description="The output prompt")
|
prompt: str = Field(default=None, description="The output prompt")
|
||||||
# fmt: on
|
#fmt: on
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"required": [
|
'required': [
|
||||||
"type",
|
'type',
|
||||||
"prompt",
|
'prompt',
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -46,11 +44,16 @@ class DynamicPromptInvocation(BaseInvocation):
|
|||||||
type: Literal["dynamic_prompt"] = "dynamic_prompt"
|
type: Literal["dynamic_prompt"] = "dynamic_prompt"
|
||||||
prompt: str = Field(description="The prompt to parse with dynamicprompts")
|
prompt: str = Field(description="The prompt to parse with dynamicprompts")
|
||||||
max_prompts: int = Field(default=1, description="The number of prompts to generate")
|
max_prompts: int = Field(default=1, description="The number of prompts to generate")
|
||||||
combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator")
|
combinatorial: bool = Field(
|
||||||
|
default=False, description="Whether to use the combinatorial generator"
|
||||||
|
)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]},
|
"ui": {
|
||||||
|
"title": "Dynamic Prompt",
|
||||||
|
"tags": ["prompt", "dynamic"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
||||||
@@ -65,8 +68,7 @@ class DynamicPromptInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
class PromptsFromFileInvocation(BaseInvocation):
|
class PromptsFromFileInvocation(BaseInvocation):
|
||||||
"""Loads prompts from a text file"""
|
'''Loads prompts from a text file'''
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal['prompt_from_file'] = 'prompt_from_file'
|
type: Literal['prompt_from_file'] = 'prompt_from_file'
|
||||||
|
|
||||||
@@ -76,11 +78,14 @@ class PromptsFromFileInvocation(BaseInvocation):
|
|||||||
post_prompt: Optional[str] = Field(description="String to append to each prompt")
|
post_prompt: Optional[str] = Field(description="String to append to each prompt")
|
||||||
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
|
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
|
||||||
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
|
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
|
||||||
# fmt: on
|
#fmt: on
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Prompts From File", "tags": ["prompt", "file"]},
|
"ui": {
|
||||||
|
"title": "Prompts From File",
|
||||||
|
"tags": ["prompt", "file"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@validator("file_path")
|
@validator("file_path")
|
||||||
@@ -98,13 +103,11 @@ class PromptsFromFileInvocation(BaseInvocation):
|
|||||||
with open(file_path) as f:
|
with open(file_path) as f:
|
||||||
for i, line in enumerate(f):
|
for i, line in enumerate(f):
|
||||||
if i >= start_line and i < end_line:
|
if i >= start_line and i < end_line:
|
||||||
prompts.append((pre_prompt or "") + line.strip() + (post_prompt or ""))
|
prompts.append((pre_prompt or '') + line.strip() + (post_prompt or ''))
|
||||||
if i >= end_line:
|
if i >= end_line:
|
||||||
break
|
break
|
||||||
return prompts
|
return prompts
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
|
||||||
prompts = self.promptsFromFile(
|
prompts = self.promptsFromFile(self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts)
|
||||||
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
|
|
||||||
)
|
|
||||||
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ from pydantic import Field, validator
|
|||||||
|
|
||||||
from ...backend.model_management import ModelType, SubModelType
|
from ...backend.model_management import ModelType, SubModelType
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||||
|
InvocationConfig, InvocationContext)
|
||||||
|
|
||||||
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
|
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
|
||||||
from .compel import ConditioningField
|
from .compel import ConditioningField
|
||||||
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
|
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
|
||||||
|
|
||||||
|
|
||||||
class SDXLModelLoaderOutput(BaseInvocationOutput):
|
class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""SDXL base model loader output"""
|
"""SDXL base model loader output"""
|
||||||
|
|
||||||
@@ -26,18 +26,15 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
|
|||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""SDXL refiner model loader output"""
|
"""SDXL refiner model loader output"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
# fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl base model, outputting its submodels."""
|
"""Loads an sdxl base model, outputting its submodels."""
|
||||||
@@ -128,10 +125,8 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||||
|
|
||||||
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
||||||
|
|
||||||
model: MainModelField = Field(description="The model to load")
|
model: MainModelField = Field(description="The model to load")
|
||||||
@@ -202,7 +197,6 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Text to image
|
# Text to image
|
||||||
class SDXLTextToLatentsInvocation(BaseInvocation):
|
class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||||
"""Generates latents from conditionings."""
|
"""Generates latents from conditionings."""
|
||||||
@@ -219,9 +213,9 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
|
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
|
||||||
# control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
@validator("cfg_scale")
|
@validator("cfg_scale")
|
||||||
@@ -230,10 +224,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
for i in v:
|
for i in v:
|
||||||
if i < 1:
|
if i < 1:
|
||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError('cfg_scale must be greater than 1')
|
||||||
else:
|
else:
|
||||||
if v < 1:
|
if v < 1:
|
||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError('cfg_scale must be greater than 1')
|
||||||
return v
|
return v
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
@@ -243,10 +237,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
"title": "SDXL Text To Latents",
|
"title": "SDXL Text To Latents",
|
||||||
"tags": ["latents"],
|
"tags": ["latents"],
|
||||||
"type_hints": {
|
"type_hints": {
|
||||||
"model": "model",
|
"model": "model",
|
||||||
# "cfg_scale": "float",
|
# "cfg_scale": "float",
|
||||||
"cfg_scale": "number",
|
"cfg_scale": "number"
|
||||||
},
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -271,7 +265,9 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
graph_execution_state = context.services.graph_execution_manager.get(
|
||||||
|
context.graph_execution_state_id
|
||||||
|
)
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
latents = context.services.latents.get(self.noise.latents_name)
|
latents = context.services.latents.get(self.noise.latents_name)
|
||||||
|
|
||||||
@@ -292,15 +288,18 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
num_inference_steps = self.steps
|
num_inference_steps = self.steps
|
||||||
|
scheduler.set_timesteps(num_inference_steps)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
|
latents = latents * scheduler.init_noise_sigma
|
||||||
|
|
||||||
|
|
||||||
|
unet_info = context.services.model_manager.get_model(
|
||||||
|
**self.unet.unet.dict(), context=context
|
||||||
|
)
|
||||||
do_classifier_free_guidance = True
|
do_classifier_free_guidance = True
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
with unet_info as unet:
|
with unet_info as unet:
|
||||||
scheduler.set_timesteps(num_inference_steps, device=unet.device)
|
|
||||||
timesteps = scheduler.timesteps
|
|
||||||
|
|
||||||
latents = latents.to(device=unet.device, dtype=unet.dtype) * scheduler.init_noise_sigma
|
|
||||||
|
|
||||||
extra_step_kwargs = dict()
|
extra_step_kwargs = dict()
|
||||||
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
@@ -351,10 +350,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
# del noise_pred_uncond
|
#del noise_pred_uncond
|
||||||
# del noise_pred_text
|
#del noise_pred_text
|
||||||
|
|
||||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||||
|
|
||||||
@@ -365,7 +364,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||||
progress_bar.update()
|
progress_bar.update()
|
||||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||||
# if callback is not None and i % callback_steps == 0:
|
#if callback is not None and i % callback_steps == 0:
|
||||||
# callback(i, t, latents)
|
# callback(i, t, latents)
|
||||||
else:
|
else:
|
||||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||||
@@ -379,13 +378,13 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
with tqdm(total=num_inference_steps) as progress_bar:
|
with tqdm(total=num_inference_steps) as progress_bar:
|
||||||
for i, t in enumerate(timesteps):
|
for i, t in enumerate(timesteps):
|
||||||
# expand the latents if we are doing classifier free guidance
|
# expand the latents if we are doing classifier free guidance
|
||||||
# latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
|
|
||||||
latent_model_input = scheduler.scale_model_input(latents, t)
|
latent_model_input = scheduler.scale_model_input(latents, t)
|
||||||
|
|
||||||
# import gc
|
#import gc
|
||||||
# gc.collect()
|
#gc.collect()
|
||||||
# torch.cuda.empty_cache()
|
#torch.cuda.empty_cache()
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
|
|
||||||
@@ -412,41 +411,42 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
|||||||
# perform guidance
|
# perform guidance
|
||||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
|
|
||||||
# del noise_pred_text
|
#del noise_pred_text
|
||||||
# del noise_pred_uncond
|
#del noise_pred_uncond
|
||||||
# import gc
|
#import gc
|
||||||
# gc.collect()
|
#gc.collect()
|
||||||
# torch.cuda.empty_cache()
|
#torch.cuda.empty_cache()
|
||||||
|
|
||||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||||
|
|
||||||
# del noise_pred
|
#del noise_pred
|
||||||
# import gc
|
#import gc
|
||||||
# gc.collect()
|
#gc.collect()
|
||||||
# torch.cuda.empty_cache()
|
#torch.cuda.empty_cache()
|
||||||
|
|
||||||
# call the callback, if provided
|
# call the callback, if provided
|
||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||||
progress_bar.update()
|
progress_bar.update()
|
||||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||||
# if callback is not None and i % callback_steps == 0:
|
#if callback is not None and i % callback_steps == 0:
|
||||||
# callback(i, t, latents)
|
# callback(i, t, latents)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#################
|
#################
|
||||||
|
|
||||||
latents = latents.to("cpu")
|
latents = latents.to("cpu")
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
context.services.latents.save(name, latents)
|
context.services.latents.save(name, latents)
|
||||||
return build_latents_output(latents_name=name, latents=latents)
|
return build_latents_output(latents_name=name, latents=latents)
|
||||||
|
|
||||||
|
|
||||||
class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
||||||
"""Generates latents from conditionings."""
|
"""Generates latents from conditionings."""
|
||||||
|
|
||||||
@@ -466,9 +466,9 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
|
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
|
||||||
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
|
||||||
|
|
||||||
# control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||||
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
@validator("cfg_scale")
|
@validator("cfg_scale")
|
||||||
@@ -477,10 +477,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
for i in v:
|
for i in v:
|
||||||
if i < 1:
|
if i < 1:
|
||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError('cfg_scale must be greater than 1')
|
||||||
else:
|
else:
|
||||||
if v < 1:
|
if v < 1:
|
||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError('cfg_scale must be greater than 1')
|
||||||
return v
|
return v
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
@@ -490,10 +490,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
"title": "SDXL Latents to Latents",
|
"title": "SDXL Latents to Latents",
|
||||||
"tags": ["latents"],
|
"tags": ["latents"],
|
||||||
"type_hints": {
|
"type_hints": {
|
||||||
"model": "model",
|
"model": "model",
|
||||||
# "cfg_scale": "float",
|
# "cfg_scale": "float",
|
||||||
"cfg_scale": "number",
|
"cfg_scale": "number"
|
||||||
},
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -518,7 +518,9 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
graph_execution_state = context.services.graph_execution_manager.get(
|
||||||
|
context.graph_execution_state_id
|
||||||
|
)
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
@@ -538,27 +540,26 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
scheduler_name=self.scheduler,
|
scheduler_name=self.scheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
# apply denoising_start
|
||||||
**self.unet.unet.dict(),
|
num_inference_steps = self.steps
|
||||||
context=context,
|
scheduler.set_timesteps(num_inference_steps)
|
||||||
)
|
|
||||||
|
|
||||||
|
t_start = int(round(self.denoising_start * num_inference_steps))
|
||||||
|
timesteps = scheduler.timesteps[t_start * scheduler.order:]
|
||||||
|
num_inference_steps = num_inference_steps - t_start
|
||||||
|
|
||||||
|
# apply noise(if provided)
|
||||||
|
if self.noise is not None and timesteps.shape[0] > 0:
|
||||||
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
|
latents = scheduler.add_noise(latents, noise, timesteps[:1])
|
||||||
|
del noise
|
||||||
|
|
||||||
|
unet_info = context.services.model_manager.get_model(
|
||||||
|
**self.unet.unet.dict(), context=context,
|
||||||
|
)
|
||||||
do_classifier_free_guidance = True
|
do_classifier_free_guidance = True
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
with unet_info as unet:
|
with unet_info as unet:
|
||||||
# apply denoising_start
|
|
||||||
num_inference_steps = self.steps
|
|
||||||
scheduler.set_timesteps(num_inference_steps, device=unet.device)
|
|
||||||
|
|
||||||
t_start = int(round(self.denoising_start * num_inference_steps))
|
|
||||||
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
|
||||||
num_inference_steps = num_inference_steps - t_start
|
|
||||||
|
|
||||||
# apply noise(if provided)
|
|
||||||
if self.noise is not None and timesteps.shape[0] > 0:
|
|
||||||
noise = context.services.latents.get(self.noise.latents_name)
|
|
||||||
latents = scheduler.add_noise(latents, noise, timesteps[:1])
|
|
||||||
del noise
|
|
||||||
|
|
||||||
# apply scheduler extra args
|
# apply scheduler extra args
|
||||||
extra_step_kwargs = dict()
|
extra_step_kwargs = dict()
|
||||||
@@ -610,10 +611,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
# del noise_pred_uncond
|
#del noise_pred_uncond
|
||||||
# del noise_pred_text
|
#del noise_pred_text
|
||||||
|
|
||||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||||
|
|
||||||
@@ -624,7 +625,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||||
progress_bar.update()
|
progress_bar.update()
|
||||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||||
# if callback is not None and i % callback_steps == 0:
|
#if callback is not None and i % callback_steps == 0:
|
||||||
# callback(i, t, latents)
|
# callback(i, t, latents)
|
||||||
else:
|
else:
|
||||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
|
||||||
@@ -638,13 +639,13 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
with tqdm(total=num_inference_steps) as progress_bar:
|
with tqdm(total=num_inference_steps) as progress_bar:
|
||||||
for i, t in enumerate(timesteps):
|
for i, t in enumerate(timesteps):
|
||||||
# expand the latents if we are doing classifier free guidance
|
# expand the latents if we are doing classifier free guidance
|
||||||
# latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
|
|
||||||
latent_model_input = scheduler.scale_model_input(latents, t)
|
latent_model_input = scheduler.scale_model_input(latents, t)
|
||||||
|
|
||||||
# import gc
|
#import gc
|
||||||
# gc.collect()
|
#gc.collect()
|
||||||
# torch.cuda.empty_cache()
|
#torch.cuda.empty_cache()
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
|
|
||||||
@@ -671,36 +672,38 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
|
|||||||
# perform guidance
|
# perform guidance
|
||||||
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
|
|
||||||
# del noise_pred_text
|
#del noise_pred_text
|
||||||
# del noise_pred_uncond
|
#del noise_pred_uncond
|
||||||
# import gc
|
#import gc
|
||||||
# gc.collect()
|
#gc.collect()
|
||||||
# torch.cuda.empty_cache()
|
#torch.cuda.empty_cache()
|
||||||
|
|
||||||
# if do_classifier_free_guidance and guidance_rescale > 0.0:
|
#if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||||
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||||
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||||
|
|
||||||
# del noise_pred
|
#del noise_pred
|
||||||
# import gc
|
#import gc
|
||||||
# gc.collect()
|
#gc.collect()
|
||||||
# torch.cuda.empty_cache()
|
#torch.cuda.empty_cache()
|
||||||
|
|
||||||
# call the callback, if provided
|
# call the callback, if provided
|
||||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
||||||
progress_bar.update()
|
progress_bar.update()
|
||||||
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
|
||||||
# if callback is not None and i % callback_steps == 0:
|
#if callback is not None and i % callback_steps == 0:
|
||||||
# callback(i, t, latents)
|
# callback(i, t, latents)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#################
|
#################
|
||||||
|
|
||||||
latents = latents.to("cpu")
|
latents = latents.to("cpu")
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
context.services.latents.save(name, latents)
|
context.services.latents.save(name, latents)
|
||||||
return build_latents_output(latents_name=name, latents=latents)
|
return build_latents_output(latents_name=name, latents=latents)
|
||||||
|
|||||||
@@ -29,11 +29,16 @@ class ESRGANInvocation(BaseInvocation):
|
|||||||
|
|
||||||
type: Literal["esrgan"] = "esrgan"
|
type: Literal["esrgan"] = "esrgan"
|
||||||
image: Union[ImageField, None] = Field(default=None, description="The input image")
|
image: Union[ImageField, None] = Field(default=None, description="The input image")
|
||||||
model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
model_name: ESRGAN_MODELS = Field(
|
||||||
|
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
|
||||||
|
)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]},
|
"ui": {
|
||||||
|
"title": "Upscale (RealESRGAN)",
|
||||||
|
"tags": ["image", "upscale", "realesrgan"]
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@@ -103,7 +108,9 @@ class ESRGANInvocation(BaseInvocation):
|
|||||||
upscaled_image, img_mode = upsampler.enhance(cv_image)
|
upscaled_image, img_mode = upsampler.enhance(cv_image)
|
||||||
|
|
||||||
# back to PIL
|
# back to PIL
|
||||||
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
|
pil_image = Image.fromarray(
|
||||||
|
cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
|
||||||
|
).convert("RGBA")
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=pil_image,
|
image=pil_image,
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
class CanceledException(Exception):
|
class CanceledException(Exception):
|
||||||
"""Execution canceled by user."""
|
"""Execution canceled by user."""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from ..invocations.baseinvocation import (
|
|||||||
InvocationConfig,
|
InvocationConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageField(BaseModel):
|
class ImageField(BaseModel):
|
||||||
"""An image field used for passing image objects between invocations"""
|
"""An image field used for passing image objects between invocations"""
|
||||||
|
|
||||||
@@ -35,7 +34,6 @@ class ProgressImage(BaseModel):
|
|||||||
height: int = Field(description="The effective height of the image in pixels")
|
height: int = Field(description="The effective height of the image in pixels")
|
||||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||||
|
|
||||||
|
|
||||||
class PILInvocationConfig(BaseModel):
|
class PILInvocationConfig(BaseModel):
|
||||||
"""Helper class to provide all PIL invocations with additional config"""
|
"""Helper class to provide all PIL invocations with additional config"""
|
||||||
|
|
||||||
@@ -46,7 +44,6 @@ class PILInvocationConfig(BaseModel):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ImageOutput(BaseInvocationOutput):
|
class ImageOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output an image"""
|
"""Base class for invocations that output an image"""
|
||||||
|
|
||||||
@@ -79,7 +76,6 @@ class MaskOutput(BaseInvocationOutput):
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||||
"""The origin of a resource (eg image).
|
"""The origin of a resource (eg image).
|
||||||
|
|
||||||
@@ -136,3 +132,5 @@ class InvalidImageCategoryException(ValueError):
|
|||||||
|
|
||||||
def __init__(self, message="Invalid image category."):
|
def __init__(self, message="Invalid image category."):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -207,7 +207,9 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
|||||||
raise e
|
raise e
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
return OffsetPaginatedResults(
|
||||||
|
items=images, offset=offset, limit=limit, total=count
|
||||||
|
)
|
||||||
|
|
||||||
def get_all_board_image_names_for_board(self, board_id: str) -> list[str]:
|
def get_all_board_image_names_for_board(self, board_id: str) -> list[str]:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -102,7 +102,9 @@ class BoardImagesService(BoardImagesServiceABC):
|
|||||||
self,
|
self,
|
||||||
board_id: str,
|
board_id: str,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
return self._services.board_image_records.get_all_board_image_names_for_board(board_id)
|
return self._services.board_image_records.get_all_board_image_names_for_board(
|
||||||
|
board_id
|
||||||
|
)
|
||||||
|
|
||||||
def get_board_for_image(
|
def get_board_for_image(
|
||||||
self,
|
self,
|
||||||
@@ -112,7 +114,9 @@ class BoardImagesService(BoardImagesServiceABC):
|
|||||||
return board_id
|
return board_id
|
||||||
|
|
||||||
|
|
||||||
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
|
def board_record_to_dto(
|
||||||
|
board_record: BoardRecord, cover_image_name: Optional[str], image_count: int
|
||||||
|
) -> BoardDTO:
|
||||||
"""Converts a board record to a board DTO."""
|
"""Converts a board record to a board DTO."""
|
||||||
return BoardDTO(
|
return BoardDTO(
|
||||||
**board_record.dict(exclude={"cover_image_name"}),
|
**board_record.dict(exclude={"cover_image_name"}),
|
||||||
|
|||||||
@@ -15,7 +15,9 @@ from pydantic import BaseModel, Field, Extra
|
|||||||
|
|
||||||
class BoardChanges(BaseModel, extra=Extra.forbid):
|
class BoardChanges(BaseModel, extra=Extra.forbid):
|
||||||
board_name: Optional[str] = Field(description="The board's new name.")
|
board_name: Optional[str] = Field(description="The board's new name.")
|
||||||
cover_image_name: Optional[str] = Field(description="The name of the board's new cover image.")
|
cover_image_name: Optional[str] = Field(
|
||||||
|
description="The name of the board's new cover image."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BoardRecordNotFoundException(Exception):
|
class BoardRecordNotFoundException(Exception):
|
||||||
@@ -290,7 +292,9 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
|||||||
|
|
||||||
count = cast(int, self._cursor.fetchone()[0])
|
count = cast(int, self._cursor.fetchone()[0])
|
||||||
|
|
||||||
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
|
return OffsetPaginatedResults[BoardRecord](
|
||||||
|
items=boards, offset=offset, limit=limit, total=count
|
||||||
|
)
|
||||||
|
|
||||||
except sqlite3.Error as e:
|
except sqlite3.Error as e:
|
||||||
self._conn.rollback()
|
self._conn.rollback()
|
||||||
|
|||||||
@@ -108,12 +108,16 @@ class BoardService(BoardServiceABC):
|
|||||||
|
|
||||||
def get_dto(self, board_id: str) -> BoardDTO:
|
def get_dto(self, board_id: str) -> BoardDTO:
|
||||||
board_record = self._services.board_records.get(board_id)
|
board_record = self._services.board_records.get(board_id)
|
||||||
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||||
|
board_record.board_id
|
||||||
|
)
|
||||||
if cover_image:
|
if cover_image:
|
||||||
cover_image_name = cover_image.image_name
|
cover_image_name = cover_image.image_name
|
||||||
else:
|
else:
|
||||||
cover_image_name = None
|
cover_image_name = None
|
||||||
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
|
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||||
|
board_id
|
||||||
|
)
|
||||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||||
|
|
||||||
def update(
|
def update(
|
||||||
@@ -122,44 +126,60 @@ class BoardService(BoardServiceABC):
|
|||||||
changes: BoardChanges,
|
changes: BoardChanges,
|
||||||
) -> BoardDTO:
|
) -> BoardDTO:
|
||||||
board_record = self._services.board_records.update(board_id, changes)
|
board_record = self._services.board_records.update(board_id, changes)
|
||||||
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||||
|
board_record.board_id
|
||||||
|
)
|
||||||
if cover_image:
|
if cover_image:
|
||||||
cover_image_name = cover_image.image_name
|
cover_image_name = cover_image.image_name
|
||||||
else:
|
else:
|
||||||
cover_image_name = None
|
cover_image_name = None
|
||||||
|
|
||||||
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
|
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||||
|
board_id
|
||||||
|
)
|
||||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||||
|
|
||||||
def delete(self, board_id: str) -> None:
|
def delete(self, board_id: str) -> None:
|
||||||
self._services.board_records.delete(board_id)
|
self._services.board_records.delete(board_id)
|
||||||
|
|
||||||
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
|
def get_many(
|
||||||
|
self, offset: int = 0, limit: int = 10
|
||||||
|
) -> OffsetPaginatedResults[BoardDTO]:
|
||||||
board_records = self._services.board_records.get_many(offset, limit)
|
board_records = self._services.board_records.get_many(offset, limit)
|
||||||
board_dtos = []
|
board_dtos = []
|
||||||
for r in board_records.items:
|
for r in board_records.items:
|
||||||
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
|
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||||
|
r.board_id
|
||||||
|
)
|
||||||
if cover_image:
|
if cover_image:
|
||||||
cover_image_name = cover_image.image_name
|
cover_image_name = cover_image.image_name
|
||||||
else:
|
else:
|
||||||
cover_image_name = None
|
cover_image_name = None
|
||||||
|
|
||||||
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
|
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||||
|
r.board_id
|
||||||
|
)
|
||||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||||
|
|
||||||
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
|
return OffsetPaginatedResults[BoardDTO](
|
||||||
|
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
|
||||||
|
)
|
||||||
|
|
||||||
def get_all(self) -> list[BoardDTO]:
|
def get_all(self) -> list[BoardDTO]:
|
||||||
board_records = self._services.board_records.get_all()
|
board_records = self._services.board_records.get_all()
|
||||||
board_dtos = []
|
board_dtos = []
|
||||||
for r in board_records:
|
for r in board_records:
|
||||||
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
|
cover_image = self._services.image_records.get_most_recent_image_for_board(
|
||||||
|
r.board_id
|
||||||
|
)
|
||||||
if cover_image:
|
if cover_image:
|
||||||
cover_image_name = cover_image.image_name
|
cover_image_name = cover_image.image_name
|
||||||
else:
|
else:
|
||||||
cover_image_name = None
|
cover_image_name = None
|
||||||
|
|
||||||
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
|
image_count = self._services.board_image_records.get_image_count_for_board(
|
||||||
|
r.board_id
|
||||||
|
)
|
||||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||||
|
|
||||||
return board_dtos
|
return board_dtos
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
||||||
|
|
||||||
"""Invokeai configuration system.
|
'''Invokeai configuration system.
|
||||||
|
|
||||||
Arguments and fields are taken from the pydantic definition of the
|
Arguments and fields are taken from the pydantic definition of the
|
||||||
model. Defaults can be set by creating a yaml configuration file that
|
model. Defaults can be set by creating a yaml configuration file that
|
||||||
@@ -158,7 +158,7 @@ two configs are kept in separate sections of the config file:
|
|||||||
outdir: outputs
|
outdir: outputs
|
||||||
...
|
...
|
||||||
|
|
||||||
"""
|
'''
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import argparse
|
import argparse
|
||||||
import pydoc
|
import pydoc
|
||||||
@@ -170,67 +170,64 @@ from pathlib import Path
|
|||||||
from pydantic import BaseSettings, Field, parse_obj_as
|
from pydantic import BaseSettings, Field, parse_obj_as
|
||||||
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
|
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
|
||||||
|
|
||||||
INIT_FILE = Path("invokeai.yaml")
|
INIT_FILE = Path('invokeai.yaml')
|
||||||
DB_FILE = Path("invokeai.db")
|
MODEL_CORE = Path('models/core')
|
||||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
DB_FILE = Path('invokeai.db')
|
||||||
|
LEGACY_INIT_FILE = Path('invokeai.init')
|
||||||
|
|
||||||
class InvokeAISettings(BaseSettings):
|
class InvokeAISettings(BaseSettings):
|
||||||
"""
|
'''
|
||||||
Runtime configuration settings in which default values are
|
Runtime configuration settings in which default values are
|
||||||
read from an omegaconf .yaml file.
|
read from an omegaconf .yaml file.
|
||||||
"""
|
'''
|
||||||
|
initconf : ClassVar[DictConfig] = None
|
||||||
|
argparse_groups : ClassVar[Dict] = {}
|
||||||
|
|
||||||
initconf: ClassVar[DictConfig] = None
|
def parse_args(self, argv: list=sys.argv[1:]):
|
||||||
argparse_groups: ClassVar[Dict] = {}
|
|
||||||
|
|
||||||
def parse_args(self, argv: list = sys.argv[1:]):
|
|
||||||
parser = self.get_parser()
|
parser = self.get_parser()
|
||||||
opt = parser.parse_args(argv)
|
opt = parser.parse_args(argv)
|
||||||
for name in self.__fields__:
|
for name in self.__fields__:
|
||||||
if name not in self._excluded():
|
if name not in self._excluded():
|
||||||
setattr(self, name, getattr(opt, name))
|
setattr(self, name, getattr(opt,name))
|
||||||
|
|
||||||
def to_yaml(self) -> str:
|
def to_yaml(self)->str:
|
||||||
"""
|
"""
|
||||||
Return a YAML string representing our settings. This can be used
|
Return a YAML string representing our settings. This can be used
|
||||||
as the contents of `invokeai.yaml` to restore settings later.
|
as the contents of `invokeai.yaml` to restore settings later.
|
||||||
"""
|
"""
|
||||||
cls = self.__class__
|
cls = self.__class__
|
||||||
type = get_args(get_type_hints(cls)["type"])[0]
|
type = get_args(get_type_hints(cls)['type'])[0]
|
||||||
field_dict = dict({type: dict()})
|
field_dict = dict({type:dict()})
|
||||||
for name, field in self.__fields__.items():
|
for name,field in self.__fields__.items():
|
||||||
if name in cls._excluded_from_yaml():
|
if name in cls._excluded_from_yaml():
|
||||||
continue
|
continue
|
||||||
category = field.field_info.extra.get("category") or "Uncategorized"
|
category = field.field_info.extra.get("category") or "Uncategorized"
|
||||||
value = getattr(self, name)
|
value = getattr(self,name)
|
||||||
if category not in field_dict[type]:
|
if category not in field_dict[type]:
|
||||||
field_dict[type][category] = dict()
|
field_dict[type][category] = dict()
|
||||||
# keep paths as strings to make it easier to read
|
# keep paths as strings to make it easier to read
|
||||||
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
|
field_dict[type][category][name] = str(value) if isinstance(value,Path) else value
|
||||||
conf = OmegaConf.create(field_dict)
|
conf = OmegaConf.create(field_dict)
|
||||||
return OmegaConf.to_yaml(conf)
|
return OmegaConf.to_yaml(conf)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_parser_arguments(cls, parser):
|
def add_parser_arguments(cls, parser):
|
||||||
if "type" in get_type_hints(cls):
|
if 'type' in get_type_hints(cls):
|
||||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
settings_stanza = get_args(get_type_hints(cls)['type'])[0]
|
||||||
else:
|
else:
|
||||||
settings_stanza = "Uncategorized"
|
settings_stanza = "Uncategorized"
|
||||||
|
|
||||||
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
|
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper()
|
||||||
|
|
||||||
initconf = (
|
initconf = cls.initconf.get(settings_stanza) \
|
||||||
cls.initconf.get(settings_stanza)
|
if cls.initconf and settings_stanza in cls.initconf \
|
||||||
if cls.initconf and settings_stanza in cls.initconf
|
else OmegaConf.create()
|
||||||
else OmegaConf.create()
|
|
||||||
)
|
|
||||||
|
|
||||||
# create an upcase version of the environment in
|
# create an upcase version of the environment in
|
||||||
# order to achieve case-insensitive environment
|
# order to achieve case-insensitive environment
|
||||||
# variables (the way Windows does)
|
# variables (the way Windows does)
|
||||||
upcase_environ = dict()
|
upcase_environ = dict()
|
||||||
for key, value in os.environ.items():
|
for key,value in os.environ.items():
|
||||||
upcase_environ[key.upper()] = value
|
upcase_environ[key.upper()] = value
|
||||||
|
|
||||||
fields = cls.__fields__
|
fields = cls.__fields__
|
||||||
@@ -240,8 +237,8 @@ class InvokeAISettings(BaseSettings):
|
|||||||
if name not in cls._excluded():
|
if name not in cls._excluded():
|
||||||
current_default = field.default
|
current_default = field.default
|
||||||
|
|
||||||
category = field.field_info.extra.get("category", "Uncategorized")
|
category = field.field_info.extra.get("category","Uncategorized")
|
||||||
env_name = env_prefix + "_" + name
|
env_name = env_prefix + '_' + name
|
||||||
if category in initconf and name in initconf.get(category):
|
if category in initconf and name in initconf.get(category):
|
||||||
field.default = initconf.get(category).get(name)
|
field.default = initconf.get(category).get(name)
|
||||||
if env_name.upper() in upcase_environ:
|
if env_name.upper() in upcase_environ:
|
||||||
@@ -251,15 +248,15 @@ class InvokeAISettings(BaseSettings):
|
|||||||
field.default = current_default
|
field.default = current_default
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def cmd_name(self, command_field: str = "type") -> str:
|
def cmd_name(self, command_field: str='type')->str:
|
||||||
hints = get_type_hints(self)
|
hints = get_type_hints(self)
|
||||||
if command_field in hints:
|
if command_field in hints:
|
||||||
return get_args(hints[command_field])[0]
|
return get_args(hints[command_field])[0]
|
||||||
else:
|
else:
|
||||||
return "Uncategorized"
|
return 'Uncategorized'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_parser(cls) -> ArgumentParser:
|
def get_parser(cls)->ArgumentParser:
|
||||||
parser = PagingArgumentParser(
|
parser = PagingArgumentParser(
|
||||||
prog=cls.cmd_name(),
|
prog=cls.cmd_name(),
|
||||||
description=cls.__doc__,
|
description=cls.__doc__,
|
||||||
@@ -272,42 +269,24 @@ class InvokeAISettings(BaseSettings):
|
|||||||
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _excluded(self) -> List[str]:
|
def _excluded(self)->List[str]:
|
||||||
# internal fields that shouldn't be exposed as command line options
|
# internal fields that shouldn't be exposed as command line options
|
||||||
return ["type", "initconf", "cached_root"]
|
return ['type','initconf']
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _excluded_from_yaml(self) -> List[str]:
|
def _excluded_from_yaml(self)->List[str]:
|
||||||
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
|
||||||
return [
|
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore', 'root', 'nsfw_checker']
|
||||||
"type",
|
|
||||||
"initconf",
|
|
||||||
"gpu_mem_reserved",
|
|
||||||
"max_loaded_models",
|
|
||||||
"version",
|
|
||||||
"from_file",
|
|
||||||
"model",
|
|
||||||
"restore",
|
|
||||||
"root",
|
|
||||||
"nsfw_checker",
|
|
||||||
"cached_root",
|
|
||||||
]
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file_encoding = "utf-8"
|
env_file_encoding = 'utf-8'
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
case_sensitive = True
|
case_sensitive = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
def add_field_argument(cls, command_parser, name: str, field, default_override = None):
|
||||||
field_type = get_type_hints(cls).get(name)
|
field_type = get_type_hints(cls).get(name)
|
||||||
default = (
|
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
|
||||||
default_override
|
|
||||||
if default_override is not None
|
|
||||||
else field.default
|
|
||||||
if field.default_factory is None
|
|
||||||
else field.default_factory()
|
|
||||||
)
|
|
||||||
if category := field.field_info.extra.get("category"):
|
if category := field.field_info.extra.get("category"):
|
||||||
if category not in cls.argparse_groups:
|
if category not in cls.argparse_groups:
|
||||||
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
cls.argparse_groups[category] = command_parser.add_argument_group(category)
|
||||||
@@ -336,10 +315,10 @@ class InvokeAISettings(BaseSettings):
|
|||||||
argparse_group.add_argument(
|
argparse_group.add_argument(
|
||||||
f"--{name}",
|
f"--{name}",
|
||||||
dest=name,
|
dest=name,
|
||||||
nargs="*",
|
nargs='*',
|
||||||
type=field.type_,
|
type=field.type_,
|
||||||
default=default,
|
default=default,
|
||||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||||
help=field.field_info.description,
|
help=field.field_info.description,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -348,35 +327,31 @@ class InvokeAISettings(BaseSettings):
|
|||||||
dest=name,
|
dest=name,
|
||||||
type=field.type_,
|
type=field.type_,
|
||||||
default=default,
|
default=default,
|
||||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
|
||||||
help=field.field_info.description,
|
help=field.field_info.description,
|
||||||
)
|
)
|
||||||
|
def _find_root()->Path:
|
||||||
|
|
||||||
def _find_root() -> Path:
|
|
||||||
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
|
||||||
if os.environ.get("INVOKEAI_ROOT"):
|
if os.environ.get("INVOKEAI_ROOT"):
|
||||||
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
|
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
|
||||||
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
|
elif any([(venv.parent/x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE, MODEL_CORE]]):
|
||||||
root = (venv.parent).resolve()
|
root = (venv.parent).resolve()
|
||||||
else:
|
else:
|
||||||
root = Path("~/invokeai").expanduser().resolve()
|
root = Path("~/invokeai").expanduser().resolve()
|
||||||
return root
|
return root
|
||||||
|
|
||||||
|
|
||||||
class InvokeAIAppConfig(InvokeAISettings):
|
class InvokeAIAppConfig(InvokeAISettings):
|
||||||
"""
|
'''
|
||||||
Generate images using Stable Diffusion. Use "invokeai" to launch
|
Generate images using Stable Diffusion. Use "invokeai" to launch
|
||||||
the command-line client (recommended for experts only), or
|
the command-line client (recommended for experts only), or
|
||||||
"invokeai-web" to launch the web server. Global options
|
"invokeai-web" to launch the web server. Global options
|
||||||
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
|
||||||
setting environment variables INVOKEAI_<setting>.
|
setting environment variables INVOKEAI_<setting>.
|
||||||
"""
|
'''
|
||||||
|
|
||||||
singleton_config: ClassVar[InvokeAIAppConfig] = None
|
singleton_config: ClassVar[InvokeAIAppConfig] = None
|
||||||
singleton_init: ClassVar[Dict] = None
|
singleton_init: ClassVar[Dict] = None
|
||||||
|
|
||||||
# fmt: off
|
#fmt: off
|
||||||
type: Literal["InvokeAI"] = "InvokeAI"
|
type: Literal["InvokeAI"] = "InvokeAI"
|
||||||
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
|
||||||
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
|
||||||
@@ -424,17 +399,16 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
|
||||||
|
|
||||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
||||||
cached_root : Path = Field(default=None, description="internal use only", category="DEPRECATED")
|
#fmt: on
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
|
def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
|
||||||
"""
|
'''
|
||||||
Update settings with contents of init file, environment, and
|
Update settings with contents of init file, environment, and
|
||||||
command-line settings.
|
command-line settings.
|
||||||
:param conf: alternate Omegaconf dictionary object
|
:param conf: alternate Omegaconf dictionary object
|
||||||
:param argv: aternate sys.argv list
|
:param argv: aternate sys.argv list
|
||||||
:param clobber: ovewrite any initialization parameters passed during initialization
|
:param clobber: ovewrite any initialization parameters passed during initialization
|
||||||
"""
|
'''
|
||||||
# Set the runtime root directory. We parse command-line switches here
|
# Set the runtime root directory. We parse command-line switches here
|
||||||
# in order to pick up the --root_dir option.
|
# in order to pick up the --root_dir option.
|
||||||
super().parse_args(argv)
|
super().parse_args(argv)
|
||||||
@@ -451,144 +425,135 @@ class InvokeAIAppConfig(InvokeAISettings):
|
|||||||
if self.singleton_init and not clobber:
|
if self.singleton_init and not clobber:
|
||||||
hints = get_type_hints(self.__class__)
|
hints = get_type_hints(self.__class__)
|
||||||
for k in self.singleton_init:
|
for k in self.singleton_init:
|
||||||
setattr(self, k, parse_obj_as(hints[k], self.singleton_init[k]))
|
setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k]))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
|
def get_config(cls,**kwargs)->InvokeAIAppConfig:
|
||||||
"""
|
'''
|
||||||
This returns a singleton InvokeAIAppConfig configuration object.
|
This returns a singleton InvokeAIAppConfig configuration object.
|
||||||
"""
|
'''
|
||||||
if (
|
if cls.singleton_config is None \
|
||||||
cls.singleton_config is None
|
or type(cls.singleton_config)!=cls \
|
||||||
or type(cls.singleton_config) != cls
|
or (kwargs and cls.singleton_init != kwargs):
|
||||||
or (kwargs and cls.singleton_init != kwargs)
|
|
||||||
):
|
|
||||||
cls.singleton_config = cls(**kwargs)
|
cls.singleton_config = cls(**kwargs)
|
||||||
cls.singleton_init = kwargs
|
cls.singleton_init = kwargs
|
||||||
return cls.singleton_config
|
return cls.singleton_config
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def root_path(self) -> Path:
|
def root_path(self)->Path:
|
||||||
"""
|
'''
|
||||||
Path to the runtime root directory
|
Path to the runtime root directory
|
||||||
"""
|
'''
|
||||||
# we cache value of root to protect against it being '.' and the cwd changing
|
if self.root:
|
||||||
if self.cached_root:
|
return Path(self.root).expanduser().absolute()
|
||||||
root = self.cached_root
|
|
||||||
elif self.root:
|
|
||||||
root = Path(self.root).expanduser().absolute()
|
|
||||||
else:
|
else:
|
||||||
root = self.find_root()
|
return self.find_root()
|
||||||
self.cached_root = root
|
|
||||||
return self.cached_root
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def root_dir(self) -> Path:
|
def root_dir(self)->Path:
|
||||||
"""
|
'''
|
||||||
Alias for above.
|
Alias for above.
|
||||||
"""
|
'''
|
||||||
return self.root_path
|
return self.root_path
|
||||||
|
|
||||||
def _resolve(self, partial_path: Path) -> Path:
|
def _resolve(self,partial_path:Path)->Path:
|
||||||
return (self.root_path / partial_path).resolve()
|
return (self.root_path / partial_path).resolve()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def init_file_path(self) -> Path:
|
def init_file_path(self)->Path:
|
||||||
"""
|
'''
|
||||||
Path to invokeai.yaml
|
Path to invokeai.yaml
|
||||||
"""
|
'''
|
||||||
return self._resolve(INIT_FILE)
|
return self._resolve(INIT_FILE)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_path(self) -> Path:
|
def output_path(self)->Path:
|
||||||
"""
|
'''
|
||||||
Path to defaults outputs directory.
|
Path to defaults outputs directory.
|
||||||
"""
|
'''
|
||||||
return self._resolve(self.outdir)
|
return self._resolve(self.outdir)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db_path(self) -> Path:
|
def db_path(self)->Path:
|
||||||
"""
|
'''
|
||||||
Path to the invokeai.db file.
|
Path to the invokeai.db file.
|
||||||
"""
|
'''
|
||||||
return self._resolve(self.db_dir) / DB_FILE
|
return self._resolve(self.db_dir) / DB_FILE
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_conf_path(self) -> Path:
|
def model_conf_path(self)->Path:
|
||||||
"""
|
'''
|
||||||
Path to models configuration file.
|
Path to models configuration file.
|
||||||
"""
|
'''
|
||||||
return self._resolve(self.conf_path)
|
return self._resolve(self.conf_path)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def legacy_conf_path(self) -> Path:
|
def legacy_conf_path(self)->Path:
|
||||||
"""
|
'''
|
||||||
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
|
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
|
||||||
"""
|
'''
|
||||||
return self._resolve(self.legacy_conf_dir)
|
return self._resolve(self.legacy_conf_dir)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def models_path(self) -> Path:
|
def models_path(self)->Path:
|
||||||
"""
|
'''
|
||||||
Path to the models directory
|
Path to the models directory
|
||||||
"""
|
'''
|
||||||
return self._resolve(self.models_dir)
|
return self._resolve(self.models_dir)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def autoconvert_path(self) -> Path:
|
def autoconvert_path(self)->Path:
|
||||||
"""
|
'''
|
||||||
Path to the directory containing models to be imported automatically at startup.
|
Path to the directory containing models to be imported automatically at startup.
|
||||||
"""
|
'''
|
||||||
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
|
||||||
|
|
||||||
# the following methods support legacy calls leftover from the Globals era
|
# the following methods support legacy calls leftover from the Globals era
|
||||||
@property
|
@property
|
||||||
def full_precision(self) -> bool:
|
def full_precision(self)->bool:
|
||||||
"""Return true if precision set to float32"""
|
"""Return true if precision set to float32"""
|
||||||
return self.precision == "float32"
|
return self.precision=='float32'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def disable_xformers(self) -> bool:
|
def disable_xformers(self)->bool:
|
||||||
"""Return true if xformers_enabled is false"""
|
"""Return true if xformers_enabled is false"""
|
||||||
return not self.xformers_enabled
|
return not self.xformers_enabled
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def try_patchmatch(self) -> bool:
|
def try_patchmatch(self)->bool:
|
||||||
"""Return true if patchmatch true"""
|
"""Return true if patchmatch true"""
|
||||||
return self.patchmatch
|
return self.patchmatch
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nsfw_checker(self) -> bool:
|
def nsfw_checker(self)->bool:
|
||||||
"""NSFW node is always active and disabled from Web UIe"""
|
""" NSFW node is always active and disabled from Web UIe"""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def invisible_watermark(self) -> bool:
|
def invisible_watermark(self)->bool:
|
||||||
"""invisible watermark node is always active and disabled from Web UIe"""
|
""" invisible watermark node is always active and disabled from Web UIe"""
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_root() -> Path:
|
def find_root()->Path:
|
||||||
"""
|
'''
|
||||||
Choose the runtime root directory when not specified on command line or
|
Choose the runtime root directory when not specified on command line or
|
||||||
init file.
|
init file.
|
||||||
"""
|
'''
|
||||||
return _find_root()
|
return _find_root()
|
||||||
|
|
||||||
|
|
||||||
class PagingArgumentParser(argparse.ArgumentParser):
|
class PagingArgumentParser(argparse.ArgumentParser):
|
||||||
"""
|
'''
|
||||||
A custom ArgumentParser that uses pydoc to page its output.
|
A custom ArgumentParser that uses pydoc to page its output.
|
||||||
It also supports reading defaults from an init file.
|
It also supports reading defaults from an init file.
|
||||||
"""
|
'''
|
||||||
|
|
||||||
def print_help(self, file=None):
|
def print_help(self, file=None):
|
||||||
text = self.format_help()
|
text = self.format_help()
|
||||||
pydoc.pager(text)
|
pydoc.pager(text)
|
||||||
|
|
||||||
|
def get_invokeai_config(**kwargs)->InvokeAIAppConfig:
|
||||||
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
|
'''
|
||||||
"""
|
|
||||||
Legacy function which returns InvokeAIAppConfig.get_config()
|
Legacy function which returns InvokeAIAppConfig.get_config()
|
||||||
"""
|
'''
|
||||||
return InvokeAIAppConfig.get_config(**kwargs)
|
return InvokeAIAppConfig.get_config(**kwargs)
|
||||||
|
|||||||
@@ -7,70 +7,47 @@ from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Gr
|
|||||||
from .item_storage import ItemStorageABC
|
from .item_storage import ItemStorageABC
|
||||||
|
|
||||||
|
|
||||||
default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"
|
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
|
||||||
|
|
||||||
|
|
||||||
def create_text_to_image() -> LibraryGraph:
|
def create_text_to_image() -> LibraryGraph:
|
||||||
return LibraryGraph(
|
return LibraryGraph(
|
||||||
id=default_text_to_image_graph_id,
|
id=default_text_to_image_graph_id,
|
||||||
name="t2i",
|
name='t2i',
|
||||||
description="Converts text to an image",
|
description='Converts text to an image',
|
||||||
graph=Graph(
|
graph=Graph(
|
||||||
nodes={
|
nodes={
|
||||||
"width": ParamIntInvocation(id="width", a=512),
|
'width': ParamIntInvocation(id='width', a=512),
|
||||||
"height": ParamIntInvocation(id="height", a=512),
|
'height': ParamIntInvocation(id='height', a=512),
|
||||||
"seed": ParamIntInvocation(id="seed", a=-1),
|
'seed': ParamIntInvocation(id='seed', a=-1),
|
||||||
"3": NoiseInvocation(id="3"),
|
'3': NoiseInvocation(id='3'),
|
||||||
"4": CompelInvocation(id="4"),
|
'4': CompelInvocation(id='4'),
|
||||||
"5": CompelInvocation(id="5"),
|
'5': CompelInvocation(id='5'),
|
||||||
"6": TextToLatentsInvocation(id="6"),
|
'6': TextToLatentsInvocation(id='6'),
|
||||||
"7": LatentsToImageInvocation(id="7"),
|
'7': LatentsToImageInvocation(id='7'),
|
||||||
"8": ImageNSFWBlurInvocation(id="8"),
|
'8': ImageNSFWBlurInvocation(id='8'),
|
||||||
},
|
},
|
||||||
edges=[
|
edges=[
|
||||||
Edge(
|
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
|
||||||
source=EdgeConnection(node_id="width", field="a"),
|
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
|
||||||
destination=EdgeConnection(node_id="3", field="width"),
|
Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
|
||||||
),
|
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
|
||||||
Edge(
|
Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
|
||||||
source=EdgeConnection(node_id="height", field="a"),
|
Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
|
||||||
destination=EdgeConnection(node_id="3", field="height"),
|
Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
|
||||||
),
|
Edge(source=EdgeConnection(node_id='7', field='image'), destination=EdgeConnection(node_id='8', field='image')),
|
||||||
Edge(
|
]
|
||||||
source=EdgeConnection(node_id="seed", field="a"),
|
|
||||||
destination=EdgeConnection(node_id="3", field="seed"),
|
|
||||||
),
|
|
||||||
Edge(
|
|
||||||
source=EdgeConnection(node_id="3", field="noise"),
|
|
||||||
destination=EdgeConnection(node_id="6", field="noise"),
|
|
||||||
),
|
|
||||||
Edge(
|
|
||||||
source=EdgeConnection(node_id="6", field="latents"),
|
|
||||||
destination=EdgeConnection(node_id="7", field="latents"),
|
|
||||||
),
|
|
||||||
Edge(
|
|
||||||
source=EdgeConnection(node_id="4", field="conditioning"),
|
|
||||||
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
|
|
||||||
),
|
|
||||||
Edge(
|
|
||||||
source=EdgeConnection(node_id="5", field="conditioning"),
|
|
||||||
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
|
|
||||||
),
|
|
||||||
Edge(
|
|
||||||
source=EdgeConnection(node_id="7", field="image"),
|
|
||||||
destination=EdgeConnection(node_id="8", field="image"),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
exposed_inputs=[
|
exposed_inputs=[
|
||||||
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
|
ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
|
||||||
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
|
ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
|
||||||
ExposedNodeInput(node_path="width", field="a", alias="width"),
|
ExposedNodeInput(node_path='width', field='a', alias='width'),
|
||||||
ExposedNodeInput(node_path="height", field="a", alias="height"),
|
ExposedNodeInput(node_path='height', field='a', alias='height'),
|
||||||
ExposedNodeInput(node_path="seed", field="a", alias="seed"),
|
ExposedNodeInput(node_path='seed', field='a', alias='seed'),
|
||||||
],
|
],
|
||||||
exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
|
exposed_outputs=[
|
||||||
)
|
ExposedNodeOutput(node_path='8', field='image', alias='image')
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
|
||||||
|
|||||||
@@ -44,7 +44,9 @@ class EventServiceBase:
|
|||||||
graph_execution_state_id=graph_execution_state_id,
|
graph_execution_state_id=graph_execution_state_id,
|
||||||
node=node,
|
node=node,
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
progress_image=progress_image.dict() if progress_image is not None else None,
|
progress_image=progress_image.dict()
|
||||||
|
if progress_image is not None
|
||||||
|
else None,
|
||||||
step=step,
|
step=step,
|
||||||
total_steps=total_steps,
|
total_steps=total_steps,
|
||||||
),
|
),
|
||||||
@@ -88,7 +90,9 @@ class EventServiceBase:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def emit_invocation_started(self, graph_execution_state_id: str, node: dict, source_node_id: str) -> None:
|
def emit_invocation_started(
|
||||||
|
self, graph_execution_state_id: str, node: dict, source_node_id: str
|
||||||
|
) -> None:
|
||||||
"""Emitted when an invocation has started"""
|
"""Emitted when an invocation has started"""
|
||||||
self.__emit_session_event(
|
self.__emit_session_event(
|
||||||
event_name="invocation_started",
|
event_name="invocation_started",
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ from ..invocations.baseinvocation import (
|
|||||||
# in 3.10 this would be "from types import NoneType"
|
# in 3.10 this would be "from types import NoneType"
|
||||||
NoneType = type(None)
|
NoneType = type(None)
|
||||||
|
|
||||||
|
|
||||||
class EdgeConnection(BaseModel):
|
class EdgeConnection(BaseModel):
|
||||||
node_id: str = Field(description="The id of the node for this edge connection")
|
node_id: str = Field(description="The id of the node for this edge connection")
|
||||||
field: str = Field(description="The field for this connection")
|
field: str = Field(description="The field for this connection")
|
||||||
@@ -62,7 +61,6 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
|
|||||||
node_input_field = node_inputs.get(field) or None
|
node_input_field = node_inputs.get(field) or None
|
||||||
return node_input_field
|
return node_input_field
|
||||||
|
|
||||||
|
|
||||||
def is_union_subtype(t1, t2):
|
def is_union_subtype(t1, t2):
|
||||||
t1_args = get_args(t1)
|
t1_args = get_args(t1)
|
||||||
t2_args = get_args(t2)
|
t2_args = get_args(t2)
|
||||||
@@ -73,7 +71,6 @@ def is_union_subtype(t1, t2):
|
|||||||
# t1 is a Union, check that all of its types are in t2_args
|
# t1 is a Union, check that all of its types are in t2_args
|
||||||
return all(arg in t2_args for arg in t1_args)
|
return all(arg in t2_args for arg in t1_args)
|
||||||
|
|
||||||
|
|
||||||
def is_list_or_contains_list(t):
|
def is_list_or_contains_list(t):
|
||||||
t_args = get_args(t)
|
t_args = get_args(t)
|
||||||
|
|
||||||
@@ -157,17 +154,15 @@ class GraphInvocationOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"required": [
|
'required': [
|
||||||
"type",
|
'type',
|
||||||
"image",
|
'image',
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# TODO: Fill this out and move to invocations
|
# TODO: Fill this out and move to invocations
|
||||||
class GraphInvocation(BaseInvocation):
|
class GraphInvocation(BaseInvocation):
|
||||||
"""Execute a graph"""
|
"""Execute a graph"""
|
||||||
|
|
||||||
type: Literal["graph"] = "graph"
|
type: Literal["graph"] = "graph"
|
||||||
|
|
||||||
# TODO: figure out how to create a default here
|
# TODO: figure out how to create a default here
|
||||||
@@ -187,21 +182,23 @@ class IterateInvocationOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"required": [
|
'required': [
|
||||||
"type",
|
'type',
|
||||||
"item",
|
'item',
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# TODO: Fill this out and move to invocations
|
# TODO: Fill this out and move to invocations
|
||||||
class IterateInvocation(BaseInvocation):
|
class IterateInvocation(BaseInvocation):
|
||||||
"""Iterates over a list of items"""
|
"""Iterates over a list of items"""
|
||||||
|
|
||||||
type: Literal["iterate"] = "iterate"
|
type: Literal["iterate"] = "iterate"
|
||||||
|
|
||||||
collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list)
|
collection: list[Any] = Field(
|
||||||
index: int = Field(description="The index, will be provided on executed iterators", default=0)
|
description="The list of items to iterate over", default_factory=list
|
||||||
|
)
|
||||||
|
index: int = Field(
|
||||||
|
description="The index, will be provided on executed iterators", default=0
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
|
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
|
||||||
"""Produces the outputs as values"""
|
"""Produces the outputs as values"""
|
||||||
@@ -215,13 +212,12 @@ class CollectInvocationOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"required": [
|
'required': [
|
||||||
"type",
|
'type',
|
||||||
"collection",
|
'collection',
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class CollectInvocation(BaseInvocation):
|
class CollectInvocation(BaseInvocation):
|
||||||
"""Collects values into a collection"""
|
"""Collects values into a collection"""
|
||||||
|
|
||||||
@@ -273,7 +269,9 @@ class Graph(BaseModel):
|
|||||||
if node_path in self.nodes:
|
if node_path in self.nodes:
|
||||||
return (self, node_path)
|
return (self, node_path)
|
||||||
|
|
||||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
node_id = (
|
||||||
|
node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||||
|
)
|
||||||
if node_id not in self.nodes:
|
if node_id not in self.nodes:
|
||||||
raise NodeNotFoundError(f"Node {node_path} not found in graph")
|
raise NodeNotFoundError(f"Node {node_path} not found in graph")
|
||||||
|
|
||||||
@@ -335,7 +333,9 @@ class Graph(BaseModel):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Validate all edges reference nodes in the graph
|
# Validate all edges reference nodes in the graph
|
||||||
node_ids = set([e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges])
|
node_ids = set(
|
||||||
|
[e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]
|
||||||
|
)
|
||||||
if not all((self.has_node(node_id) for node_id in node_ids)):
|
if not all((self.has_node(node_id) for node_id in node_ids)):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -361,14 +361,22 @@ class Graph(BaseModel):
|
|||||||
# Validate all iterators
|
# Validate all iterators
|
||||||
# TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
|
# TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
|
||||||
if not all(
|
if not all(
|
||||||
(self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))
|
(
|
||||||
|
self._is_iterator_connection_valid(n.id)
|
||||||
|
for n in self.nodes.values()
|
||||||
|
if isinstance(n, IterateInvocation)
|
||||||
|
)
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Validate all collectors
|
# Validate all collectors
|
||||||
# TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
|
# TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
|
||||||
if not all(
|
if not all(
|
||||||
(self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))
|
(
|
||||||
|
self._is_collector_connection_valid(n.id)
|
||||||
|
for n in self.nodes.values()
|
||||||
|
if isinstance(n, CollectInvocation)
|
||||||
|
)
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -387,51 +395,48 @@ class Graph(BaseModel):
|
|||||||
# Validate that an edge to this node+field doesn't already exist
|
# Validate that an edge to this node+field doesn't already exist
|
||||||
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||||
raise InvalidEdgeError(
|
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
|
||||||
f"Edge to node {edge.destination.node_id} field {edge.destination.field} already exists"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate that no cycles would be created
|
# Validate that no cycles would be created
|
||||||
g = self.nx_graph_flat()
|
g = self.nx_graph_flat()
|
||||||
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||||
if not nx.is_directed_acyclic_graph(g):
|
if not nx.is_directed_acyclic_graph(g):
|
||||||
raise InvalidEdgeError(
|
raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
|
||||||
f"Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate that the field types are compatible
|
# Validate that the field types are compatible
|
||||||
if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field):
|
if not are_connections_compatible(
|
||||||
raise InvalidEdgeError(
|
from_node, edge.source.field, to_node, edge.destination.field
|
||||||
f"Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
):
|
||||||
)
|
raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||||
|
|
||||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||||
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||||
if not self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source):
|
if not self._is_iterator_connection_valid(
|
||||||
raise InvalidEdgeError(
|
edge.destination.node_id, new_input=edge.source
|
||||||
f"Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
):
|
||||||
)
|
raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||||
|
|
||||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||||
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||||
if not self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination):
|
if not self._is_iterator_connection_valid(
|
||||||
raise InvalidEdgeError(
|
edge.source.node_id, new_output=edge.destination
|
||||||
f"Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
):
|
||||||
)
|
raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||||
|
|
||||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||||
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||||
if not self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source):
|
if not self._is_collector_connection_valid(
|
||||||
raise InvalidEdgeError(
|
edge.destination.node_id, new_input=edge.source
|
||||||
f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
):
|
||||||
)
|
raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||||
|
|
||||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||||
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||||
if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
|
if not self._is_collector_connection_valid(
|
||||||
raise InvalidEdgeError(
|
edge.source.node_id, new_output=edge.destination
|
||||||
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
):
|
||||||
)
|
raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
|
||||||
|
|
||||||
|
|
||||||
def has_node(self, node_path: str) -> bool:
|
def has_node(self, node_path: str) -> bool:
|
||||||
"""Determines whether or not a node exists in the graph."""
|
"""Determines whether or not a node exists in the graph."""
|
||||||
@@ -460,13 +465,17 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
# Ensure the node type matches the new node
|
# Ensure the node type matches the new node
|
||||||
if type(node) != type(new_node):
|
if type(node) != type(new_node):
|
||||||
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
|
raise TypeError(
|
||||||
|
f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}"
|
||||||
|
)
|
||||||
|
|
||||||
# Ensure the new id is either the same or is not in the graph
|
# Ensure the new id is either the same or is not in the graph
|
||||||
prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
|
prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
|
||||||
new_path = self._get_node_path(new_node.id, prefix=prefix)
|
new_path = self._get_node_path(new_node.id, prefix=prefix)
|
||||||
if new_node.id != node.id and self.has_node(new_path):
|
if new_node.id != node.id and self.has_node(new_path):
|
||||||
raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph")
|
raise NodeAlreadyInGraphError(
|
||||||
|
"Node with id {new_node.id} already exists in graph"
|
||||||
|
)
|
||||||
|
|
||||||
# Set the new node in the graph
|
# Set the new node in the graph
|
||||||
graph.nodes[new_node.id] = new_node
|
graph.nodes[new_node.id] = new_node
|
||||||
@@ -488,7 +497,9 @@ class Graph(BaseModel):
|
|||||||
graph.add_edge(
|
graph.add_edge(
|
||||||
Edge(
|
Edge(
|
||||||
source=edge.source,
|
source=edge.source,
|
||||||
destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field),
|
destination=EdgeConnection(
|
||||||
|
node_id=new_graph_node_path, field=edge.destination.field
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -501,12 +512,16 @@ class Graph(BaseModel):
|
|||||||
)
|
)
|
||||||
graph.add_edge(
|
graph.add_edge(
|
||||||
Edge(
|
Edge(
|
||||||
source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field),
|
source=EdgeConnection(
|
||||||
destination=edge.destination,
|
node_id=new_graph_node_path, field=edge.source.field
|
||||||
|
),
|
||||||
|
destination=edge.destination
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]:
|
def _get_input_edges(
|
||||||
|
self, node_path: str, field: Optional[str] = None
|
||||||
|
) -> list[Edge]:
|
||||||
"""Gets all input edges for a node"""
|
"""Gets all input edges for a node"""
|
||||||
edges = self._get_input_edges_and_graphs(node_path)
|
edges = self._get_input_edges_and_graphs(node_path)
|
||||||
|
|
||||||
@@ -523,7 +538,7 @@ class Graph(BaseModel):
|
|||||||
destination=EdgeConnection(
|
destination=EdgeConnection(
|
||||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||||
field=e.destination.field,
|
field=e.destination.field,
|
||||||
),
|
)
|
||||||
)
|
)
|
||||||
for _, prefix, e in filtered_edges
|
for _, prefix, e in filtered_edges
|
||||||
]
|
]
|
||||||
@@ -535,20 +550,32 @@ class Graph(BaseModel):
|
|||||||
edges = list()
|
edges = list()
|
||||||
|
|
||||||
# Return any input edges that appear in this graph
|
# Return any input edges that appear in this graph
|
||||||
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
|
edges.extend(
|
||||||
|
[(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]
|
||||||
|
)
|
||||||
|
|
||||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
node_id = (
|
||||||
|
node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||||
|
)
|
||||||
node = self.nodes[node_id]
|
node = self.nodes[node_id]
|
||||||
|
|
||||||
if isinstance(node, GraphInvocation):
|
if isinstance(node, GraphInvocation):
|
||||||
graph = node.graph
|
graph = node.graph
|
||||||
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
|
graph_path = (
|
||||||
graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
|
node.id
|
||||||
|
if prefix is None or prefix == ""
|
||||||
|
else self._get_node_path(node.id, prefix=prefix)
|
||||||
|
)
|
||||||
|
graph_edges = graph._get_input_edges_and_graphs(
|
||||||
|
node_path[(len(node_id) + 1) :], prefix=graph_path
|
||||||
|
)
|
||||||
edges.extend(graph_edges)
|
edges.extend(graph_edges)
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
|
|
||||||
def _get_output_edges(self, node_path: str, field: str) -> list[Edge]:
|
def _get_output_edges(
|
||||||
|
self, node_path: str, field: str
|
||||||
|
) -> list[Edge]:
|
||||||
"""Gets all output edges for a node"""
|
"""Gets all output edges for a node"""
|
||||||
edges = self._get_output_edges_and_graphs(node_path)
|
edges = self._get_output_edges_and_graphs(node_path)
|
||||||
|
|
||||||
@@ -565,7 +592,7 @@ class Graph(BaseModel):
|
|||||||
destination=EdgeConnection(
|
destination=EdgeConnection(
|
||||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||||
field=e.destination.field,
|
field=e.destination.field,
|
||||||
),
|
)
|
||||||
)
|
)
|
||||||
for _, prefix, e in filtered_edges
|
for _, prefix, e in filtered_edges
|
||||||
]
|
]
|
||||||
@@ -577,15 +604,25 @@ class Graph(BaseModel):
|
|||||||
edges = list()
|
edges = list()
|
||||||
|
|
||||||
# Return any input edges that appear in this graph
|
# Return any input edges that appear in this graph
|
||||||
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
|
edges.extend(
|
||||||
|
[(self, prefix, e) for e in self.edges if e.source.node_id == node_path]
|
||||||
|
)
|
||||||
|
|
||||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
node_id = (
|
||||||
|
node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||||
|
)
|
||||||
node = self.nodes[node_id]
|
node = self.nodes[node_id]
|
||||||
|
|
||||||
if isinstance(node, GraphInvocation):
|
if isinstance(node, GraphInvocation):
|
||||||
graph = node.graph
|
graph = node.graph
|
||||||
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
|
graph_path = (
|
||||||
graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
|
node.id
|
||||||
|
if prefix is None or prefix == ""
|
||||||
|
else self._get_node_path(node.id, prefix=prefix)
|
||||||
|
)
|
||||||
|
graph_edges = graph._get_output_edges_and_graphs(
|
||||||
|
node_path[(len(node_id) + 1) :], prefix=graph_path
|
||||||
|
)
|
||||||
edges.extend(graph_edges)
|
edges.extend(graph_edges)
|
||||||
|
|
||||||
return edges
|
return edges
|
||||||
@@ -609,8 +646,12 @@ class Graph(BaseModel):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||||
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
|
input_field = get_output_field(
|
||||||
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
self.get_node(inputs[0].node_id), inputs[0].field
|
||||||
|
)
|
||||||
|
output_fields = list(
|
||||||
|
[get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
||||||
|
)
|
||||||
|
|
||||||
# Input type must be a list
|
# Input type must be a list
|
||||||
if get_origin(input_field) != list:
|
if get_origin(input_field) != list:
|
||||||
@@ -618,7 +659,12 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
# Validate that all outputs match the input type
|
# Validate that all outputs match the input type
|
||||||
input_field_item_type = get_args(input_field)[0]
|
input_field_item_type = get_args(input_field)[0]
|
||||||
if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
|
if not all(
|
||||||
|
(
|
||||||
|
are_connection_types_compatible(input_field_item_type, f)
|
||||||
|
for f in output_fields
|
||||||
|
)
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@@ -638,21 +684,35 @@ class Graph(BaseModel):
|
|||||||
outputs.append(new_output)
|
outputs.append(new_output)
|
||||||
|
|
||||||
# Get input and output fields (the fields linked to the iterator's input/output)
|
# Get input and output fields (the fields linked to the iterator's input/output)
|
||||||
input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
|
input_fields = list(
|
||||||
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
|
[get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
|
||||||
|
)
|
||||||
|
output_fields = list(
|
||||||
|
[get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
|
||||||
|
)
|
||||||
|
|
||||||
# Validate that all inputs are derived from or match a single type
|
# Validate that all inputs are derived from or match a single type
|
||||||
input_field_types = set(
|
input_field_types = set(
|
||||||
[
|
[
|
||||||
t
|
t
|
||||||
for input_field in input_fields
|
for input_field in input_fields
|
||||||
for t in ([input_field] if get_origin(input_field) == None else get_args(input_field))
|
for t in (
|
||||||
|
[input_field]
|
||||||
|
if get_origin(input_field) == None
|
||||||
|
else get_args(input_field)
|
||||||
|
)
|
||||||
if t != NoneType
|
if t != NoneType
|
||||||
]
|
]
|
||||||
) # Get unique types
|
) # Get unique types
|
||||||
type_tree = nx.DiGraph()
|
type_tree = nx.DiGraph()
|
||||||
type_tree.add_nodes_from(input_field_types)
|
type_tree.add_nodes_from(input_field_types)
|
||||||
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
|
type_tree.add_edges_from(
|
||||||
|
[
|
||||||
|
e
|
||||||
|
for e in itertools.permutations(input_field_types, 2)
|
||||||
|
if issubclass(e[1], e[0])
|
||||||
|
]
|
||||||
|
)
|
||||||
type_degrees = type_tree.in_degree(type_tree.nodes)
|
type_degrees = type_tree.in_degree(type_tree.nodes)
|
||||||
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
|
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
|
||||||
return False # There is more than one root type
|
return False # There is more than one root type
|
||||||
@@ -669,7 +729,9 @@ class Graph(BaseModel):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Verify that all outputs match the input type (are a base class or the same class)
|
# Verify that all outputs match the input type (are a base class or the same class)
|
||||||
if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)):
|
if not all(
|
||||||
|
(issubclass(input_root_type, get_args(f)[0]) for f in output_fields)
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@@ -689,7 +751,9 @@ class Graph(BaseModel):
|
|||||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||||
return g
|
return g
|
||||||
|
|
||||||
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
|
def nx_graph_flat(
|
||||||
|
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
|
||||||
|
) -> nx.DiGraph:
|
||||||
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
|
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
|
||||||
g = nx_graph or nx.DiGraph()
|
g = nx_graph or nx.DiGraph()
|
||||||
|
|
||||||
@@ -698,18 +762,26 @@ class Graph(BaseModel):
|
|||||||
[
|
[
|
||||||
self._get_node_path(n.id, prefix)
|
self._get_node_path(n.id, prefix)
|
||||||
for n in self.nodes.values()
|
for n in self.nodes.values()
|
||||||
if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)
|
if not isinstance(n, GraphInvocation)
|
||||||
|
and not isinstance(n, IterateInvocation)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Expand graph nodes
|
# Expand graph nodes
|
||||||
for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
|
for sgn in (
|
||||||
|
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
|
||||||
|
):
|
||||||
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||||
|
|
||||||
# TODO: figure out if iteration nodes need to be expanded
|
# TODO: figure out if iteration nodes need to be expanded
|
||||||
|
|
||||||
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
|
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
|
||||||
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
|
g.add_edges_from(
|
||||||
|
[
|
||||||
|
(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
|
||||||
|
for e in unique_edges
|
||||||
|
]
|
||||||
|
)
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
@@ -728,19 +800,23 @@ class GraphExecutionState(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Nodes that have been executed
|
# Nodes that have been executed
|
||||||
executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set)
|
executed: set[str] = Field(
|
||||||
|
description="The set of node ids that have been executed", default_factory=set
|
||||||
|
)
|
||||||
executed_history: list[str] = Field(
|
executed_history: list[str] = Field(
|
||||||
description="The list of node ids that have been executed, in order of execution",
|
description="The list of node ids that have been executed, in order of execution",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
|
|
||||||
# The results of executed nodes
|
# The results of executed nodes
|
||||||
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
|
results: dict[
|
||||||
description="The results of node executions", default_factory=dict
|
str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]
|
||||||
)
|
] = Field(description="The results of node executions", default_factory=dict)
|
||||||
|
|
||||||
# Errors raised when executing nodes
|
# Errors raised when executing nodes
|
||||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
errors: dict[str, str] = Field(
|
||||||
|
description="Errors raised when executing nodes", default_factory=dict
|
||||||
|
)
|
||||||
|
|
||||||
# Map of prepared/executed nodes to their original nodes
|
# Map of prepared/executed nodes to their original nodes
|
||||||
prepared_source_mapping: dict[str, str] = Field(
|
prepared_source_mapping: dict[str, str] = Field(
|
||||||
@@ -756,16 +832,16 @@ class GraphExecutionState(BaseModel):
|
|||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"required": [
|
'required': [
|
||||||
"id",
|
'id',
|
||||||
"graph",
|
'graph',
|
||||||
"execution_graph",
|
'execution_graph',
|
||||||
"executed",
|
'executed',
|
||||||
"executed_history",
|
'executed_history',
|
||||||
"results",
|
'results',
|
||||||
"errors",
|
'errors',
|
||||||
"prepared_source_mapping",
|
'prepared_source_mapping',
|
||||||
"source_prepared_mapping",
|
'source_prepared_mapping',
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -823,7 +899,9 @@ class GraphExecutionState(BaseModel):
|
|||||||
"""Returns true if the graph has any errors"""
|
"""Returns true if the graph has any errors"""
|
||||||
return len(self.errors) > 0
|
return len(self.errors) > 0
|
||||||
|
|
||||||
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
|
def _create_execution_node(
|
||||||
|
self, node_path: str, iteration_node_map: list[tuple[str, str]]
|
||||||
|
) -> list[str]:
|
||||||
"""Prepares an iteration node and connects all edges, returning the new node id"""
|
"""Prepares an iteration node and connects all edges, returning the new node id"""
|
||||||
|
|
||||||
node = self.graph.get_node(node_path)
|
node = self.graph.get_node(node_path)
|
||||||
@@ -833,12 +911,20 @@ class GraphExecutionState(BaseModel):
|
|||||||
# If this is an iterator node, we must create a copy for each iteration
|
# If this is an iterator node, we must create a copy for each iteration
|
||||||
if isinstance(node, IterateInvocation):
|
if isinstance(node, IterateInvocation):
|
||||||
# Get input collection edge (should error if there are no inputs)
|
# Get input collection edge (should error if there are no inputs)
|
||||||
input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection")))
|
input_collection_edge = next(
|
||||||
input_collection_prepared_node_id = next(
|
iter(self.graph._get_input_edges(node_path, "collection"))
|
||||||
n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
|
)
|
||||||
|
input_collection_prepared_node_id = next(
|
||||||
|
n[1]
|
||||||
|
for n in iteration_node_map
|
||||||
|
if n[0] == input_collection_edge.source.node_id
|
||||||
|
)
|
||||||
|
input_collection_prepared_node_output = self.results[
|
||||||
|
input_collection_prepared_node_id
|
||||||
|
]
|
||||||
|
input_collection = getattr(
|
||||||
|
input_collection_prepared_node_output, input_collection_edge.source.field
|
||||||
)
|
)
|
||||||
input_collection_prepared_node_output = self.results[input_collection_prepared_node_id]
|
|
||||||
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
|
|
||||||
self_iteration_count = len(input_collection)
|
self_iteration_count = len(input_collection)
|
||||||
|
|
||||||
new_nodes = list()
|
new_nodes = list()
|
||||||
@@ -853,7 +939,9 @@ class GraphExecutionState(BaseModel):
|
|||||||
# For collect nodes, this may contain multiple inputs to the same field
|
# For collect nodes, this may contain multiple inputs to the same field
|
||||||
new_edges = list()
|
new_edges = list()
|
||||||
for edge in input_edges:
|
for edge in input_edges:
|
||||||
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
|
for input_node_id in (
|
||||||
|
n[1] for n in iteration_node_map if n[0] == edge.source.node_id
|
||||||
|
):
|
||||||
new_edge = Edge(
|
new_edge = Edge(
|
||||||
source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
|
source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
|
||||||
destination=EdgeConnection(node_id="", field=edge.destination.field),
|
destination=EdgeConnection(node_id="", field=edge.destination.field),
|
||||||
@@ -894,7 +982,11 @@ class GraphExecutionState(BaseModel):
|
|||||||
def _iterator_graph(self) -> nx.DiGraph:
|
def _iterator_graph(self) -> nx.DiGraph:
|
||||||
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
|
||||||
g = self.graph.nx_graph_flat()
|
g = self.graph.nx_graph_flat()
|
||||||
collectors = (n for n in self.graph.nodes if isinstance(self.graph.get_node(n), CollectInvocation))
|
collectors = (
|
||||||
|
n
|
||||||
|
for n in self.graph.nodes
|
||||||
|
if isinstance(self.graph.get_node(n), CollectInvocation)
|
||||||
|
)
|
||||||
for c in collectors:
|
for c in collectors:
|
||||||
g.remove_edges_from(list(g.in_edges(c)))
|
g.remove_edges_from(list(g.in_edges(c)))
|
||||||
return g
|
return g
|
||||||
@@ -902,7 +994,11 @@ class GraphExecutionState(BaseModel):
|
|||||||
def _get_node_iterators(self, node_id: str) -> list[str]:
|
def _get_node_iterators(self, node_id: str) -> list[str]:
|
||||||
"""Gets iterators for a node"""
|
"""Gets iterators for a node"""
|
||||||
g = self._iterator_graph()
|
g = self._iterator_graph()
|
||||||
iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]
|
iterators = [
|
||||||
|
n
|
||||||
|
for n in nx.ancestors(g, node_id)
|
||||||
|
if isinstance(self.graph.get_node(n), IterateInvocation)
|
||||||
|
]
|
||||||
return iterators
|
return iterators
|
||||||
|
|
||||||
def _prepare(self) -> Optional[str]:
|
def _prepare(self) -> Optional[str]:
|
||||||
@@ -949,18 +1045,29 @@ class GraphExecutionState(BaseModel):
|
|||||||
if isinstance(next_node, CollectInvocation):
|
if isinstance(next_node, CollectInvocation):
|
||||||
# Collapse all iterator input mappings and create a single execution node for the collect invocation
|
# Collapse all iterator input mappings and create a single execution node for the collect invocation
|
||||||
all_iteration_mappings = list(
|
all_iteration_mappings = list(
|
||||||
itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))
|
itertools.chain(
|
||||||
|
*(
|
||||||
|
((s, p) for p in self.source_prepared_mapping[s])
|
||||||
|
for s in next_node_parents
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
# all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
|
# all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
|
||||||
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
|
create_results = self._create_execution_node(
|
||||||
|
next_node_id, all_iteration_mappings
|
||||||
|
)
|
||||||
if create_results is not None:
|
if create_results is not None:
|
||||||
new_node_ids.extend(create_results)
|
new_node_ids.extend(create_results)
|
||||||
else: # Iterators or normal nodes
|
else: # Iterators or normal nodes
|
||||||
# Get all iterator combinations for this node
|
# Get all iterator combinations for this node
|
||||||
# Will produce a list of lists of prepared iterator nodes, from which results can be iterated
|
# Will produce a list of lists of prepared iterator nodes, from which results can be iterated
|
||||||
iterator_nodes = self._get_node_iterators(next_node_id)
|
iterator_nodes = self._get_node_iterators(next_node_id)
|
||||||
iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes]
|
iterator_nodes_prepared = [
|
||||||
iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared))
|
list(self.source_prepared_mapping[n]) for n in iterator_nodes
|
||||||
|
]
|
||||||
|
iterator_node_prepared_combinations = list(
|
||||||
|
itertools.product(*iterator_nodes_prepared)
|
||||||
|
)
|
||||||
|
|
||||||
# Select the correct prepared parents for each iteration
|
# Select the correct prepared parents for each iteration
|
||||||
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
|
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
|
||||||
@@ -989,16 +1096,31 @@ class GraphExecutionState(BaseModel):
|
|||||||
return next(iter(prepared_nodes))
|
return next(iter(prepared_nodes))
|
||||||
|
|
||||||
# Check if the requested node is an iterator
|
# Check if the requested node is an iterator
|
||||||
prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
|
prepared_iterator = next(
|
||||||
|
(n for n in prepared_nodes if n in prepared_iterator_nodes), None
|
||||||
|
)
|
||||||
if prepared_iterator is not None:
|
if prepared_iterator is not None:
|
||||||
return prepared_iterator
|
return prepared_iterator
|
||||||
|
|
||||||
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
|
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
|
||||||
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
|
iterator_source_node_mapping = [
|
||||||
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
|
(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes
|
||||||
|
]
|
||||||
|
parent_iterators = [
|
||||||
|
itn
|
||||||
|
for itn in iterator_source_node_mapping
|
||||||
|
if nx.has_path(graph, itn[1], source_node_path)
|
||||||
|
]
|
||||||
|
|
||||||
return next(
|
return next(
|
||||||
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
|
(
|
||||||
|
n
|
||||||
|
for n in prepared_nodes
|
||||||
|
if all(
|
||||||
|
nx.has_path(execution_graph, pit[0], n)
|
||||||
|
for pit in parent_iterators
|
||||||
|
)
|
||||||
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1013,8 +1135,8 @@ class GraphExecutionState(BaseModel):
|
|||||||
(
|
(
|
||||||
n
|
n
|
||||||
for n in sorted_nodes
|
for n in sorted_nodes
|
||||||
if n not in self.executed # the node must not already be executed...
|
if n not in self.executed # the node must not already be executed...
|
||||||
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
|
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@@ -1099,18 +1221,15 @@ class ExposedNodeOutput(BaseModel):
|
|||||||
field: str = Field(description="The field name of the output")
|
field: str = Field(description="The field name of the output")
|
||||||
alias: str = Field(description="The alias of the output")
|
alias: str = Field(description="The alias of the output")
|
||||||
|
|
||||||
|
|
||||||
class LibraryGraph(BaseModel):
|
class LibraryGraph(BaseModel):
|
||||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
|
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
|
||||||
graph: Graph = Field(description="The graph")
|
graph: Graph = Field(description="The graph")
|
||||||
name: str = Field(description="The name of the graph")
|
name: str = Field(description="The name of the graph")
|
||||||
description: str = Field(description="The description of the graph")
|
description: str = Field(description="The description of the graph")
|
||||||
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
|
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
|
||||||
exposed_outputs: list[ExposedNodeOutput] = Field(
|
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
|
||||||
description="The outputs exposed by this graph", default_factory=list
|
|
||||||
)
|
|
||||||
|
|
||||||
@validator("exposed_inputs", "exposed_outputs")
|
@validator('exposed_inputs', 'exposed_outputs')
|
||||||
def validate_exposed_aliases(cls, v):
|
def validate_exposed_aliases(cls, v):
|
||||||
if len(v) != len(set(i.alias for i in v)):
|
if len(v) != len(set(i.alias for i in v)):
|
||||||
raise ValueError("Duplicate exposed alias")
|
raise ValueError("Duplicate exposed alias")
|
||||||
@@ -1118,27 +1237,23 @@ class LibraryGraph(BaseModel):
|
|||||||
|
|
||||||
@root_validator
|
@root_validator
|
||||||
def validate_exposed_nodes(cls, values):
|
def validate_exposed_nodes(cls, values):
|
||||||
graph = values["graph"]
|
graph = values['graph']
|
||||||
|
|
||||||
# Validate exposed inputs
|
# Validate exposed inputs
|
||||||
for exposed_input in values["exposed_inputs"]:
|
for exposed_input in values['exposed_inputs']:
|
||||||
if not graph.has_node(exposed_input.node_path):
|
if not graph.has_node(exposed_input.node_path):
|
||||||
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
||||||
node = graph.get_node(exposed_input.node_path)
|
node = graph.get_node(exposed_input.node_path)
|
||||||
if get_input_field(node, exposed_input.field) is None:
|
if get_input_field(node, exposed_input.field) is None:
|
||||||
raise ValueError(
|
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
|
||||||
f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate exposed outputs
|
# Validate exposed outputs
|
||||||
for exposed_output in values["exposed_outputs"]:
|
for exposed_output in values['exposed_outputs']:
|
||||||
if not graph.has_node(exposed_output.node_path):
|
if not graph.has_node(exposed_output.node_path):
|
||||||
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
||||||
node = graph.get_node(exposed_output.node_path)
|
node = graph.get_node(exposed_output.node_path)
|
||||||
if get_output_field(node, exposed_output.field) is None:
|
if get_output_field(node, exposed_output.field) is None:
|
||||||
raise ValueError(
|
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
|
||||||
f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|||||||
@@ -85,7 +85,9 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
self.__cache_ids = Queue()
|
self.__cache_ids = Queue()
|
||||||
self.__max_cache_size = 10 # TODO: get this from config
|
self.__max_cache_size = 10 # TODO: get this from config
|
||||||
|
|
||||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
self.__output_folder: Path = (
|
||||||
|
output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||||
|
)
|
||||||
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
||||||
|
|
||||||
# Validate required output folders at launch
|
# Validate required output folders at launch
|
||||||
@@ -181,7 +183,9 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
def __set_cache(self, image_name: Path, image: PILImageType):
|
def __set_cache(self, image_name: Path, image: PILImageType):
|
||||||
if not image_name in self.__cache:
|
if not image_name in self.__cache:
|
||||||
self.__cache[image_name] = image
|
self.__cache[image_name] = image
|
||||||
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
|
self.__cache_ids.put(
|
||||||
|
image_name
|
||||||
|
) # TODO: this should refresh position for LRU cache
|
||||||
if len(self.__cache) > self.__max_cache_size:
|
if len(self.__cache) > self.__max_cache_size:
|
||||||
cache_id = self.__cache_ids.get()
|
cache_id = self.__cache_ids.get()
|
||||||
if cache_id in self.__cache:
|
if cache_id in self.__cache:
|
||||||
|
|||||||
@@ -426,7 +426,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
return OffsetPaginatedResults(
|
||||||
|
items=images, offset=offset, limit=limit, total=count
|
||||||
|
)
|
||||||
|
|
||||||
def delete(self, image_name: str) -> None:
|
def delete(self, image_name: str) -> None:
|
||||||
try:
|
try:
|
||||||
@@ -464,6 +466,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
|
|
||||||
def delete_intermediates(self) -> list[str]:
|
def delete_intermediates(self) -> list[str]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
@@ -502,7 +505,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
|||||||
is_intermediate: bool = False,
|
is_intermediate: bool = False,
|
||||||
) -> datetime:
|
) -> datetime:
|
||||||
try:
|
try:
|
||||||
metadata_json = None if metadata is None else json.dumps(metadata)
|
metadata_json = (
|
||||||
|
None if metadata is None else json.dumps(metadata)
|
||||||
|
)
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
|
|||||||
@@ -217,8 +217,12 @@ class ImageService(ImageServiceABC):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
if board_id is not None:
|
if board_id is not None:
|
||||||
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
self._services.board_image_records.add_image_to_board(
|
||||||
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, graph=graph)
|
board_id=board_id, image_name=image_name
|
||||||
|
)
|
||||||
|
self._services.image_files.save(
|
||||||
|
image_name=image_name, image=image, metadata=metadata, graph=graph
|
||||||
|
)
|
||||||
image_dto = self.get_dto(image_name)
|
image_dto = self.get_dto(image_name)
|
||||||
|
|
||||||
return image_dto
|
return image_dto
|
||||||
@@ -293,7 +297,9 @@ class ImageService(ImageServiceABC):
|
|||||||
if not image_record.session_id:
|
if not image_record.session_id:
|
||||||
return ImageMetadata()
|
return ImageMetadata()
|
||||||
|
|
||||||
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
|
session_raw = self._services.graph_execution_manager.get_raw(
|
||||||
|
image_record.session_id
|
||||||
|
)
|
||||||
graph = None
|
graph = None
|
||||||
|
|
||||||
if session_raw:
|
if session_raw:
|
||||||
@@ -358,7 +364,9 @@ class ImageService(ImageServiceABC):
|
|||||||
r,
|
r,
|
||||||
self._services.urls.get_image_url(r.image_name),
|
self._services.urls.get_image_url(r.image_name),
|
||||||
self._services.urls.get_image_url(r.image_name, True),
|
self._services.urls.get_image_url(r.image_name, True),
|
||||||
self._services.board_image_records.get_board_for_image(r.image_name),
|
self._services.board_image_records.get_board_for_image(
|
||||||
|
r.image_name
|
||||||
|
),
|
||||||
),
|
),
|
||||||
results.items,
|
results.items,
|
||||||
)
|
)
|
||||||
@@ -390,7 +398,11 @@ class ImageService(ImageServiceABC):
|
|||||||
|
|
||||||
def delete_images_on_board(self, board_id: str):
|
def delete_images_on_board(self, board_id: str):
|
||||||
try:
|
try:
|
||||||
image_names = self._services.board_image_records.get_all_board_image_names_for_board(board_id)
|
image_names = (
|
||||||
|
self._services.board_image_records.get_all_board_image_names_for_board(
|
||||||
|
board_id
|
||||||
|
)
|
||||||
|
)
|
||||||
for image_name in image_names:
|
for image_name in image_names:
|
||||||
self._services.image_files.delete(image_name)
|
self._services.image_files.delete(image_name)
|
||||||
self._services.image_records.delete_many(image_names)
|
self._services.image_records.delete_many(image_names)
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from queue import Queue
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class InvocationQueueItem(BaseModel):
|
class InvocationQueueItem(BaseModel):
|
||||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||||
@@ -46,11 +45,9 @@ class MemoryInvocationQueue(InvocationQueueABC):
|
|||||||
def get(self) -> InvocationQueueItem:
|
def get(self) -> InvocationQueueItem:
|
||||||
item = self.__queue.get()
|
item = self.__queue.get()
|
||||||
|
|
||||||
while (
|
while isinstance(item, InvocationQueueItem) \
|
||||||
isinstance(item, InvocationQueueItem)
|
and item.graph_execution_state_id in self.__cancellations \
|
||||||
and item.graph_execution_state_id in self.__cancellations
|
and self.__cancellations[item.graph_execution_state_id] > item.timestamp:
|
||||||
and self.__cancellations[item.graph_execution_state_id] > item.timestamp
|
|
||||||
):
|
|
||||||
item = self.__queue.get()
|
item = self.__queue.get()
|
||||||
|
|
||||||
# Clear old items
|
# Clear old items
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from .graph import Graph, GraphExecutionState
|
|||||||
from .invocation_queue import InvocationQueueItem
|
from .invocation_queue import InvocationQueueItem
|
||||||
from .invocation_services import InvocationServices
|
from .invocation_services import InvocationServices
|
||||||
|
|
||||||
|
|
||||||
class Invoker:
|
class Invoker:
|
||||||
"""The invoker, used to execute invocations"""
|
"""The invoker, used to execute invocations"""
|
||||||
|
|
||||||
@@ -17,7 +16,9 @@ class Invoker:
|
|||||||
self.services = services
|
self.services = services
|
||||||
self._start()
|
self._start()
|
||||||
|
|
||||||
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
|
def invoke(
|
||||||
|
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
|
||||||
|
) -> Optional[str]:
|
||||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||||
|
|
||||||
|
|||||||
@@ -9,15 +9,13 @@ T = TypeVar("T", bound=BaseModel)
|
|||||||
|
|
||||||
class PaginatedResults(GenericModel, Generic[T]):
|
class PaginatedResults(GenericModel, Generic[T]):
|
||||||
"""Paginated results"""
|
"""Paginated results"""
|
||||||
|
#fmt: off
|
||||||
# fmt: off
|
|
||||||
items: list[T] = Field(description="Items")
|
items: list[T] = Field(description="Items")
|
||||||
page: int = Field(description="Current Page")
|
page: int = Field(description="Current Page")
|
||||||
pages: int = Field(description="Total number of pages")
|
pages: int = Field(description="Total number of pages")
|
||||||
per_page: int = Field(description="Number of items per page")
|
per_page: int = Field(description="Number of items per page")
|
||||||
total: int = Field(description="Total number of items in result")
|
total: int = Field(description="Total number of items in result")
|
||||||
# fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
class ItemStorageABC(ABC, Generic[T]):
|
class ItemStorageABC(ABC, Generic[T]):
|
||||||
_on_changed_callbacks: list[Callable[[T], None]]
|
_on_changed_callbacks: list[Callable[[T], None]]
|
||||||
@@ -50,7 +48,9 @@ class ItemStorageABC(ABC, Generic[T]):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
def search(
|
||||||
|
self, query: str, page: int = 0, per_page: int = 10
|
||||||
|
) -> PaginatedResults[T]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_changed(self, on_changed: Callable[[T], None]) -> None:
|
def on_changed(self, on_changed: Callable[[T], None]) -> None:
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from typing import Dict, Union, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class LatentsStorageBase(ABC):
|
class LatentsStorageBase(ABC):
|
||||||
"""Responsible for storing and retrieving latents."""
|
"""Responsible for storing and retrieving latents."""
|
||||||
|
|
||||||
@@ -89,5 +88,7 @@ class DiskLatentsStorage(LatentsStorageBase):
|
|||||||
latent_path = self.get_path(name)
|
latent_path = self.get_path(name)
|
||||||
latent_path.unlink()
|
latent_path.unlink()
|
||||||
|
|
||||||
|
|
||||||
def get_path(self, name: str) -> Path:
|
def get_path(self, name: str) -> Path:
|
||||||
return self.__output_folder / name
|
return self.__output_folder / name
|
||||||
|
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_attributes: dict,
|
model_attributes: dict,
|
||||||
clobber: bool = False,
|
clobber: bool = False
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
@@ -169,20 +169,21 @@ class ModelManagerServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def rename_model(
|
def rename_model(self,
|
||||||
self,
|
model_name: str,
|
||||||
model_name: str,
|
base_model: BaseModelType,
|
||||||
base_model: BaseModelType,
|
model_type: ModelType,
|
||||||
model_type: ModelType,
|
new_name: str,
|
||||||
new_name: str,
|
):
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Rename the indicated model.
|
Rename the indicated model.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list_checkpoint_configs(self) -> List[Path]:
|
def list_checkpoint_configs(
|
||||||
|
self
|
||||||
|
)->List[Path]:
|
||||||
"""
|
"""
|
||||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||||
"""
|
"""
|
||||||
@@ -193,7 +194,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: Union[ModelType.Main, ModelType.Vae],
|
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||||
@@ -210,12 +211,11 @@ class ModelManagerServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def heuristic_import(
|
def heuristic_import(self,
|
||||||
self,
|
items_to_import: set[str],
|
||||||
items_to_import: set[str],
|
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
)->dict[str, AddModelResult]:
|
||||||
) -> dict[str, AddModelResult]:
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
|
||||||
successfully imported items.
|
successfully imported items.
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||||
@@ -230,23 +230,19 @@ class ModelManagerServiceBase(ABC):
|
|||||||
The result is a set of successfully installed models. Each element
|
The result is a set of successfully installed models. Each element
|
||||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||||
that model.
|
that model.
|
||||||
"""
|
'''
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def merge_models(
|
def merge_models(
|
||||||
self,
|
self,
|
||||||
model_names: List[str] = Field(
|
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
||||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
||||||
),
|
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||||
base_model: Union[BaseModelType, str] = Field(
|
alpha: Optional[float] = 0.5,
|
||||||
default=None, description="Base model shared by all models to be merged"
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
),
|
force: Optional[bool] = False,
|
||||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
merge_dest_directory: Optional[Path] = None
|
||||||
alpha: Optional[float] = 0.5,
|
|
||||||
interp: Optional[MergeInterpolationMethod] = None,
|
|
||||||
force: Optional[bool] = False,
|
|
||||||
merge_dest_directory: Optional[Path] = None,
|
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Merge two to three diffusrs pipeline models and save as a new model.
|
Merge two to three diffusrs pipeline models and save as a new model.
|
||||||
@@ -260,7 +256,7 @@ class ModelManagerServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def search_for_models(self, directory: Path) -> List[Path]:
|
def search_for_models(self, directory: Path)->List[Path]:
|
||||||
"""
|
"""
|
||||||
Return list of all models found in the designated directory.
|
Return list of all models found in the designated directory.
|
||||||
"""
|
"""
|
||||||
@@ -284,11 +280,9 @@ class ModelManagerServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# simple implementation
|
# simple implementation
|
||||||
class ModelManagerService(ModelManagerServiceBase):
|
class ModelManagerService(ModelManagerServiceBase):
|
||||||
"""Responsible for managing models on disk and in memory"""
|
"""Responsible for managing models on disk and in memory"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: InvokeAIAppConfig,
|
config: InvokeAIAppConfig,
|
||||||
@@ -305,16 +299,16 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
else:
|
else:
|
||||||
config_file = config.root_dir / "configs/models.yaml"
|
config_file = config.root_dir / "configs/models.yaml"
|
||||||
|
|
||||||
logger.debug(f"Config file={config_file}")
|
logger.debug(f'Config file={config_file}')
|
||||||
|
|
||||||
device = torch.device(choose_torch_device())
|
device = torch.device(choose_torch_device())
|
||||||
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
|
device_name = torch.cuda.get_device_name() if device==torch.device('cuda') else ''
|
||||||
logger.info(f"GPU device = {device} {device_name}")
|
logger.info(f'GPU device = {device} {device_name}')
|
||||||
|
|
||||||
precision = config.precision
|
precision = config.precision
|
||||||
if precision == "auto":
|
if precision == "auto":
|
||||||
precision = choose_precision(device)
|
precision = choose_precision(device)
|
||||||
dtype = torch.float32 if precision == "float32" else torch.float16
|
dtype = torch.float32 if precision == 'float32' else torch.float16
|
||||||
|
|
||||||
# this is transitional backward compatibility
|
# this is transitional backward compatibility
|
||||||
# support for the deprecated `max_loaded_models`
|
# support for the deprecated `max_loaded_models`
|
||||||
@@ -322,7 +316,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
# cache size is set to 2.5 GB times
|
# cache size is set to 2.5 GB times
|
||||||
# the number of max_loaded_models. Otherwise
|
# the number of max_loaded_models. Otherwise
|
||||||
# use new `max_cache_size` config setting
|
# use new `max_cache_size` config setting
|
||||||
max_cache_size = config.max_cache_size if hasattr(config, "max_cache_size") else config.max_loaded_models * 2.5
|
max_cache_size = config.max_cache_size \
|
||||||
|
if hasattr(config,'max_cache_size') \
|
||||||
|
else config.max_loaded_models * 2.5
|
||||||
|
|
||||||
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
|
||||||
|
|
||||||
@@ -336,7 +332,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
sequential_offload=sequential_offload,
|
sequential_offload=sequential_offload,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
logger.info("Model manager service initialized")
|
logger.info('Model manager service initialized')
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
@@ -375,7 +371,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
model_info=model_info,
|
model_info=model_info
|
||||||
)
|
)
|
||||||
|
|
||||||
return model_info
|
return model_info
|
||||||
@@ -409,7 +405,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
return self.mgr.model_names()
|
return self.mgr.model_names()
|
||||||
|
|
||||||
def list_models(
|
def list_models(
|
||||||
self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
|
self,
|
||||||
|
base_model: Optional[BaseModelType] = None,
|
||||||
|
model_type: Optional[ModelType] = None
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Return a list of models.
|
Return a list of models.
|
||||||
@@ -420,7 +418,9 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
"""
|
"""
|
||||||
Return information about the model using the same format as list_models()
|
Return information about the model using the same format as list_models()
|
||||||
"""
|
"""
|
||||||
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
|
return self.mgr.list_model(model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type)
|
||||||
|
|
||||||
def add_model(
|
def add_model(
|
||||||
self,
|
self,
|
||||||
@@ -429,7 +429,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
model_attributes: dict,
|
model_attributes: dict,
|
||||||
clobber: bool = False,
|
clobber: bool = False,
|
||||||
) -> None:
|
)->None:
|
||||||
"""
|
"""
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
@@ -437,7 +437,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
with an assertion error if provided attributes are incorrect or
|
with an assertion error if provided attributes are incorrect or
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
"""
|
"""
|
||||||
self.logger.debug(f"add/update model {model_name}")
|
self.logger.debug(f'add/update model {model_name}')
|
||||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||||
|
|
||||||
def update_model(
|
def update_model(
|
||||||
@@ -454,7 +454,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
with an assertion error if provided attributes are incorrect or
|
with an assertion error if provided attributes are incorrect or
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
"""
|
"""
|
||||||
self.logger.debug(f"update model {model_name}")
|
self.logger.debug(f'update model {model_name}')
|
||||||
if not self.model_exists(model_name, base_model, model_type):
|
if not self.model_exists(model_name, base_model, model_type):
|
||||||
raise ModelNotFoundException(f"Unknown model {model_name}")
|
raise ModelNotFoundException(f"Unknown model {model_name}")
|
||||||
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
|
||||||
@@ -470,7 +470,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
then the underlying weight file or diffusers directory will be deleted
|
then the underlying weight file or diffusers directory will be deleted
|
||||||
as well.
|
as well.
|
||||||
"""
|
"""
|
||||||
self.logger.debug(f"delete model {model_name}")
|
self.logger.debug(f'delete model {model_name}')
|
||||||
self.mgr.del_model(model_name, base_model, model_type)
|
self.mgr.del_model(model_name, base_model, model_type)
|
||||||
self.mgr.commit()
|
self.mgr.commit()
|
||||||
|
|
||||||
@@ -478,10 +478,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: Union[ModelType.Main, ModelType.Vae],
|
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||||
convert_dest_directory: Optional[Path] = Field(
|
convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
||||||
default=None, description="Optional directory location for merged model"
|
|
||||||
),
|
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||||
@@ -496,10 +494,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||||
directory already in place.
|
directory already in place.
|
||||||
"""
|
"""
|
||||||
self.logger.debug(f"convert model {model_name}")
|
self.logger.debug(f'convert model {model_name}')
|
||||||
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
|
||||||
|
|
||||||
def commit(self, conf_file: Optional[Path] = None):
|
def commit(self, conf_file: Optional[Path]=None):
|
||||||
"""
|
"""
|
||||||
Write current configuration out to the indicated file.
|
Write current configuration out to the indicated file.
|
||||||
If no conf_file is provided, then replaces the
|
If no conf_file is provided, then replaces the
|
||||||
@@ -526,7 +524,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
model_info=model_info,
|
model_info=model_info
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
context.services.events.emit_model_load_started(
|
context.services.events.emit_model_load_started(
|
||||||
@@ -537,16 +535,16 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
submodel=submodel,
|
submodel=submodel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def logger(self):
|
def logger(self):
|
||||||
return self.mgr.logger
|
return self.mgr.logger
|
||||||
|
|
||||||
def heuristic_import(
|
def heuristic_import(self,
|
||||||
self,
|
items_to_import: set[str],
|
||||||
items_to_import: set[str],
|
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
)->dict[str, AddModelResult]:
|
||||||
) -> dict[str, AddModelResult]:
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
|
||||||
successfully imported items.
|
successfully imported items.
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||||
@@ -561,24 +559,18 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
The result is a set of successfully installed models. Each element
|
The result is a set of successfully installed models. Each element
|
||||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||||
that model.
|
that model.
|
||||||
"""
|
'''
|
||||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||||
|
|
||||||
def merge_models(
|
def merge_models(
|
||||||
self,
|
self,
|
||||||
model_names: List[str] = Field(
|
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
|
||||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
|
||||||
),
|
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||||
base_model: Union[BaseModelType, str] = Field(
|
alpha: Optional[float] = 0.5,
|
||||||
default=None, description="Base model shared by all models to be merged"
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
),
|
force: Optional[bool] = False,
|
||||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
|
||||||
alpha: Optional[float] = 0.5,
|
|
||||||
interp: Optional[MergeInterpolationMethod] = None,
|
|
||||||
force: Optional[bool] = False,
|
|
||||||
merge_dest_directory: Optional[Path] = Field(
|
|
||||||
default=None, description="Optional directory location for merged model"
|
|
||||||
),
|
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
"""
|
||||||
Merge two to three diffusrs pipeline models and save as a new model.
|
Merge two to three diffusrs pipeline models and save as a new model.
|
||||||
@@ -592,19 +584,19 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
merger = ModelMerger(self.mgr)
|
merger = ModelMerger(self.mgr)
|
||||||
try:
|
try:
|
||||||
result = merger.merge_diffusion_models_and_save(
|
result = merger.merge_diffusion_models_and_save(
|
||||||
model_names=model_names,
|
model_names = model_names,
|
||||||
base_model=base_model,
|
base_model = base_model,
|
||||||
merged_model_name=merged_model_name,
|
merged_model_name = merged_model_name,
|
||||||
alpha=alpha,
|
alpha = alpha,
|
||||||
interp=interp,
|
interp = interp,
|
||||||
force=force,
|
force = force,
|
||||||
merge_dest_directory=merge_dest_directory,
|
merge_dest_directory=merge_dest_directory,
|
||||||
)
|
)
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
raise ValueError(e)
|
raise ValueError(e)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def search_for_models(self, directory: Path) -> List[Path]:
|
def search_for_models(self, directory: Path)->List[Path]:
|
||||||
"""
|
"""
|
||||||
Return list of all models found in the designated directory.
|
Return list of all models found in the designated directory.
|
||||||
"""
|
"""
|
||||||
@@ -619,23 +611,22 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
"""
|
"""
|
||||||
return self.mgr.sync_to_config()
|
return self.mgr.sync_to_config()
|
||||||
|
|
||||||
def list_checkpoint_configs(self) -> List[Path]:
|
def list_checkpoint_configs(self)->List[Path]:
|
||||||
"""
|
"""
|
||||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||||
"""
|
"""
|
||||||
config = self.mgr.app_config
|
config = self.mgr.app_config
|
||||||
conf_path = config.legacy_conf_path
|
conf_path = config.legacy_conf_path
|
||||||
root_path = config.root_path
|
root_path = config.root_path
|
||||||
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
|
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')]
|
||||||
|
|
||||||
def rename_model(
|
def rename_model(self,
|
||||||
self,
|
model_name: str,
|
||||||
model_name: str,
|
base_model: BaseModelType,
|
||||||
base_model: BaseModelType,
|
model_type: ModelType,
|
||||||
model_type: ModelType,
|
new_name: str = None,
|
||||||
new_name: str = None,
|
new_base: BaseModelType = None,
|
||||||
new_base: BaseModelType = None,
|
):
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Rename the indicated model. Can provide a new name and/or a new base.
|
Rename the indicated model. Can provide a new name and/or a new base.
|
||||||
:param model_name: Current name of the model
|
:param model_name: Current name of the model
|
||||||
@@ -644,10 +635,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
:param new_name: New name for the model
|
:param new_name: New name for the model
|
||||||
:param new_base: New base for the model
|
:param new_base: New base for the model
|
||||||
"""
|
"""
|
||||||
self.mgr.rename_model(
|
self.mgr.rename_model(base_model = base_model,
|
||||||
base_model=base_model,
|
model_type = model_type,
|
||||||
model_type=model_type,
|
model_name = model_name,
|
||||||
model_name=model_name,
|
new_name = new_name,
|
||||||
new_name=new_name,
|
new_base = new_base,
|
||||||
new_base=new_base,
|
)
|
||||||
)
|
|
||||||
|
|||||||
@@ -11,20 +11,30 @@ class BoardRecord(BaseModel):
|
|||||||
"""The unique ID of the board."""
|
"""The unique ID of the board."""
|
||||||
board_name: str = Field(description="The name of the board.")
|
board_name: str = Field(description="The name of the board.")
|
||||||
"""The name of the board."""
|
"""The name of the board."""
|
||||||
created_at: Union[datetime, str] = Field(description="The created timestamp of the board.")
|
created_at: Union[datetime, str] = Field(
|
||||||
|
description="The created timestamp of the board."
|
||||||
|
)
|
||||||
"""The created timestamp of the image."""
|
"""The created timestamp of the image."""
|
||||||
updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.")
|
updated_at: Union[datetime, str] = Field(
|
||||||
|
description="The updated timestamp of the board."
|
||||||
|
)
|
||||||
"""The updated timestamp of the image."""
|
"""The updated timestamp of the image."""
|
||||||
deleted_at: Union[datetime, str, None] = Field(description="The deleted timestamp of the board.")
|
deleted_at: Union[datetime, str, None] = Field(
|
||||||
|
description="The deleted timestamp of the board."
|
||||||
|
)
|
||||||
"""The updated timestamp of the image."""
|
"""The updated timestamp of the image."""
|
||||||
cover_image_name: Optional[str] = Field(description="The name of the cover image of the board.")
|
cover_image_name: Optional[str] = Field(
|
||||||
|
description="The name of the cover image of the board."
|
||||||
|
)
|
||||||
"""The name of the cover image of the board."""
|
"""The name of the cover image of the board."""
|
||||||
|
|
||||||
|
|
||||||
class BoardDTO(BoardRecord):
|
class BoardDTO(BoardRecord):
|
||||||
"""Deserialized board record with cover image URL and image count."""
|
"""Deserialized board record with cover image URL and image count."""
|
||||||
|
|
||||||
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
|
cover_image_name: Optional[str] = Field(
|
||||||
|
description="The name of the board's cover image."
|
||||||
|
)
|
||||||
"""The URL of the thumbnail of the most recent image in the board."""
|
"""The URL of the thumbnail of the most recent image in the board."""
|
||||||
image_count: int = Field(description="The number of images in the board.")
|
image_count: int = Field(description="The number of images in the board.")
|
||||||
"""The number of images in the board."""
|
"""The number of images in the board."""
|
||||||
|
|||||||
@@ -20,11 +20,17 @@ class ImageRecord(BaseModel):
|
|||||||
"""The actual width of the image in px. This may be different from the width in metadata."""
|
"""The actual width of the image in px. This may be different from the width in metadata."""
|
||||||
height: int = Field(description="The height of the image in px.")
|
height: int = Field(description="The height of the image in px.")
|
||||||
"""The actual height of the image in px. This may be different from the height in metadata."""
|
"""The actual height of the image in px. This may be different from the height in metadata."""
|
||||||
created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the image.")
|
created_at: Union[datetime.datetime, str] = Field(
|
||||||
|
description="The created timestamp of the image."
|
||||||
|
)
|
||||||
"""The created timestamp of the image."""
|
"""The created timestamp of the image."""
|
||||||
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the image.")
|
updated_at: Union[datetime.datetime, str] = Field(
|
||||||
|
description="The updated timestamp of the image."
|
||||||
|
)
|
||||||
"""The updated timestamp of the image."""
|
"""The updated timestamp of the image."""
|
||||||
deleted_at: Union[datetime.datetime, str, None] = Field(description="The deleted timestamp of the image.")
|
deleted_at: Union[datetime.datetime, str, None] = Field(
|
||||||
|
description="The deleted timestamp of the image."
|
||||||
|
)
|
||||||
"""The deleted timestamp of the image."""
|
"""The deleted timestamp of the image."""
|
||||||
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
|
||||||
"""Whether this is an intermediate image."""
|
"""Whether this is an intermediate image."""
|
||||||
@@ -49,14 +55,18 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
|||||||
- `is_intermediate`: change the image's `is_intermediate` flag
|
- `is_intermediate`: change the image's `is_intermediate` flag
|
||||||
"""
|
"""
|
||||||
|
|
||||||
image_category: Optional[ImageCategory] = Field(description="The image's new category.")
|
image_category: Optional[ImageCategory] = Field(
|
||||||
|
description="The image's new category."
|
||||||
|
)
|
||||||
"""The image's new category."""
|
"""The image's new category."""
|
||||||
session_id: Optional[StrictStr] = Field(
|
session_id: Optional[StrictStr] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The image's new session ID.",
|
description="The image's new session ID.",
|
||||||
)
|
)
|
||||||
"""The image's new session ID."""
|
"""The image's new session ID."""
|
||||||
is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.")
|
is_intermediate: Optional[StrictBool] = Field(
|
||||||
|
default=None, description="The image's new `is_intermediate` flag."
|
||||||
|
)
|
||||||
"""The image's new `is_intermediate` flag."""
|
"""The image's new `is_intermediate` flag."""
|
||||||
|
|
||||||
|
|
||||||
@@ -74,7 +84,9 @@ class ImageUrlsDTO(BaseModel):
|
|||||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||||
"""Deserialized image record, enriched for the frontend."""
|
"""Deserialized image record, enriched for the frontend."""
|
||||||
|
|
||||||
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
|
board_id: Optional[str] = Field(
|
||||||
|
description="The id of the board the image belongs to, if one exists."
|
||||||
|
)
|
||||||
"""The id of the board the image belongs to, if one exists."""
|
"""The id of the board the image belongs to, if one exists."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -98,8 +110,12 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
|||||||
|
|
||||||
# TODO: do we really need to handle default values here? ideally the data is the correct shape...
|
# TODO: do we really need to handle default values here? ideally the data is the correct shape...
|
||||||
image_name = image_dict.get("image_name", "unknown")
|
image_name = image_dict.get("image_name", "unknown")
|
||||||
image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value))
|
image_origin = ResourceOrigin(
|
||||||
image_category = ImageCategory(image_dict.get("image_category", ImageCategory.GENERAL.value))
|
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
|
||||||
|
)
|
||||||
|
image_category = ImageCategory(
|
||||||
|
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
||||||
|
)
|
||||||
width = image_dict.get("width", 0)
|
width = image_dict.get("width", 0)
|
||||||
height = image_dict.get("height", 0)
|
height = image_dict.get("height", 0)
|
||||||
session_id = image_dict.get("session_id", None)
|
session_id = image_dict.get("session_id", None)
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ from .invoker import InvocationProcessorABC, Invoker
|
|||||||
from ..models.exceptions import CanceledException
|
from ..models.exceptions import CanceledException
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
|
|
||||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||||
__invoker_thread: Thread
|
__invoker_thread: Thread
|
||||||
__stop_event: Event
|
__stop_event: Event
|
||||||
@@ -26,7 +24,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
target=self.__process,
|
target=self.__process,
|
||||||
kwargs=dict(stop_event=self.__stop_event),
|
kwargs=dict(stop_event=self.__stop_event),
|
||||||
)
|
)
|
||||||
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
|
self.__invoker_thread.daemon = (
|
||||||
|
True # TODO: make async and do not use threads
|
||||||
|
)
|
||||||
self.__invoker_thread.start()
|
self.__invoker_thread.start()
|
||||||
|
|
||||||
def stop(self, *args, **kwargs) -> None:
|
def stop(self, *args, **kwargs) -> None:
|
||||||
@@ -47,8 +47,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
|
graph_execution_state = (
|
||||||
queue_item.graph_execution_state_id
|
self.__invoker.services.graph_execution_manager.get(
|
||||||
|
queue_item.graph_execution_state_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
|
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
|
||||||
@@ -60,7 +62,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
|
invocation = graph_execution_state.execution_graph.get_node(
|
||||||
|
queue_item.invocation_id
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
|
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
|
||||||
self.__invoker.services.events.emit_invocation_retrieval_error(
|
self.__invoker.services.events.emit_invocation_retrieval_error(
|
||||||
@@ -78,7 +82,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
self.__invoker.services.events.emit_invocation_started(
|
self.__invoker.services.events.emit_invocation_started(
|
||||||
graph_execution_state_id=graph_execution_state.id,
|
graph_execution_state_id=graph_execution_state.id,
|
||||||
node=invocation.dict(),
|
node=invocation.dict(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Invoke
|
# Invoke
|
||||||
@@ -91,14 +95,18 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check queue to see if this is canceled, and skip if so
|
# Check queue to see if this is canceled, and skip if so
|
||||||
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
if self.__invoker.services.queue.is_canceled(
|
||||||
|
graph_execution_state.id
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Save outputs and history
|
# Save outputs and history
|
||||||
graph_execution_state.complete(invocation.id, outputs)
|
graph_execution_state.complete(invocation.id, outputs)
|
||||||
|
|
||||||
# Save the state changes
|
# Save the state changes
|
||||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
self.__invoker.services.graph_execution_manager.set(
|
||||||
|
graph_execution_state
|
||||||
|
)
|
||||||
|
|
||||||
# Send complete event
|
# Send complete event
|
||||||
self.__invoker.services.events.emit_invocation_complete(
|
self.__invoker.services.events.emit_invocation_complete(
|
||||||
@@ -122,7 +130,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
graph_execution_state.set_node_error(invocation.id, error)
|
graph_execution_state.set_node_error(invocation.id, error)
|
||||||
|
|
||||||
# Save the state changes
|
# Save the state changes
|
||||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
self.__invoker.services.graph_execution_manager.set(
|
||||||
|
graph_execution_state
|
||||||
|
)
|
||||||
|
|
||||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||||
# Send error event
|
# Send error event
|
||||||
@@ -137,7 +147,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# Check queue to see if this is canceled, and skip if so
|
# Check queue to see if this is canceled, and skip if so
|
||||||
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
if self.__invoker.services.queue.is_canceled(
|
||||||
|
graph_execution_state.id
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Queue any further commands if invoking all
|
# Queue any further commands if invoking all
|
||||||
@@ -152,10 +164,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
node=invocation.dict(),
|
node=invocation.dict(),
|
||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=traceback.format_exc(),
|
error=traceback.format_exc()
|
||||||
)
|
)
|
||||||
elif is_complete:
|
elif is_complete:
|
||||||
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
|
self.__invoker.services.events.emit_graph_execution_complete(
|
||||||
|
graph_execution_state.id
|
||||||
|
)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||||
|
|||||||
@@ -66,7 +66,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
def get(self, id: str) -> Optional[T]:
|
def get(self, id: str) -> Optional[T]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
self._cursor.execute(
|
||||||
|
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||||
|
)
|
||||||
result = self._cursor.fetchone()
|
result = self._cursor.fetchone()
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
@@ -79,7 +81,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
def get_raw(self, id: str) -> Optional[str]:
|
def get_raw(self, id: str) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
self._cursor.execute(
|
||||||
|
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||||
|
)
|
||||||
result = self._cursor.fetchone()
|
result = self._cursor.fetchone()
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
@@ -92,7 +96,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
def delete(self, id: str):
|
def delete(self, id: str):
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),))
|
self._cursor.execute(
|
||||||
|
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
|
||||||
|
)
|
||||||
self._conn.commit()
|
self._conn.commit()
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
@@ -116,9 +122,13 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
|
|
||||||
pageCount = int(count / per_page) + 1
|
pageCount = int(count / per_page) + 1
|
||||||
|
|
||||||
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
return PaginatedResults[T](
|
||||||
|
items=items, page=page, pages=pageCount, per_page=per_page, total=count
|
||||||
|
)
|
||||||
|
|
||||||
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
|
def search(
|
||||||
|
self, query: str, page: int = 0, per_page: int = 10
|
||||||
|
) -> PaginatedResults[T]:
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
@@ -139,4 +149,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
|
|
||||||
pageCount = int(count / per_page) + 1
|
pageCount = int(count / per_page) + 1
|
||||||
|
|
||||||
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
|
return PaginatedResults[T](
|
||||||
|
items=items, page=page, pages=pageCount, per_page=per_page, total=count
|
||||||
|
)
|
||||||
|
|||||||
@@ -17,8 +17,16 @@ from controlnet_aux.util import HWC3, resize_image
|
|||||||
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
|
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
|
||||||
|
|
||||||
lvmin_kernels_raw = [
|
lvmin_kernels_raw = [
|
||||||
np.array([[-1, -1, -1], [0, 1, 0], [1, 1, 1]], dtype=np.int32),
|
np.array([
|
||||||
np.array([[0, -1, -1], [1, 1, -1], [0, 1, 0]], dtype=np.int32),
|
[-1, -1, -1],
|
||||||
|
[0, 1, 0],
|
||||||
|
[1, 1, 1]
|
||||||
|
], dtype=np.int32),
|
||||||
|
np.array([
|
||||||
|
[0, -1, -1],
|
||||||
|
[1, 1, -1],
|
||||||
|
[0, 1, 0]
|
||||||
|
], dtype=np.int32)
|
||||||
]
|
]
|
||||||
|
|
||||||
lvmin_kernels = []
|
lvmin_kernels = []
|
||||||
@@ -28,8 +36,16 @@ lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw]
|
|||||||
lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]
|
lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]
|
||||||
|
|
||||||
lvmin_prunings_raw = [
|
lvmin_prunings_raw = [
|
||||||
np.array([[-1, -1, -1], [-1, 1, -1], [0, 0, -1]], dtype=np.int32),
|
np.array([
|
||||||
np.array([[-1, -1, -1], [-1, 1, -1], [-1, 0, 0]], dtype=np.int32),
|
[-1, -1, -1],
|
||||||
|
[-1, 1, -1],
|
||||||
|
[0, 0, -1]
|
||||||
|
], dtype=np.int32),
|
||||||
|
np.array([
|
||||||
|
[-1, -1, -1],
|
||||||
|
[-1, 1, -1],
|
||||||
|
[-1, 0, 0]
|
||||||
|
], dtype=np.int32)
|
||||||
]
|
]
|
||||||
|
|
||||||
lvmin_prunings = []
|
lvmin_prunings = []
|
||||||
@@ -83,10 +99,10 @@ def nake_nms(x):
|
|||||||
################################################################################
|
################################################################################
|
||||||
# FIXME: not using yet, if used in the future will most likely require modification of preprocessors
|
# FIXME: not using yet, if used in the future will most likely require modification of preprocessors
|
||||||
def pixel_perfect_resolution(
|
def pixel_perfect_resolution(
|
||||||
image: np.ndarray,
|
image: np.ndarray,
|
||||||
target_H: int,
|
target_H: int,
|
||||||
target_W: int,
|
target_W: int,
|
||||||
resize_mode: str,
|
resize_mode: str,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Calculate the estimated resolution for resizing an image while preserving aspect ratio.
|
Calculate the estimated resolution for resizing an image while preserving aspect ratio.
|
||||||
@@ -119,7 +135,7 @@ def pixel_perfect_resolution(
|
|||||||
|
|
||||||
if resize_mode == "fill_resize":
|
if resize_mode == "fill_resize":
|
||||||
estimation = min(k0, k1) * float(min(raw_H, raw_W))
|
estimation = min(k0, k1) * float(min(raw_H, raw_W))
|
||||||
else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?)
|
else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?)
|
||||||
estimation = max(k0, k1) * float(min(raw_H, raw_W))
|
estimation = max(k0, k1) * float(min(raw_H, raw_W))
|
||||||
|
|
||||||
# print(f"Pixel Perfect Computation:")
|
# print(f"Pixel Perfect Computation:")
|
||||||
@@ -138,7 +154,13 @@ def pixel_perfect_resolution(
|
|||||||
# modified for InvokeAI
|
# modified for InvokeAI
|
||||||
###########################################################################
|
###########################################################################
|
||||||
# def detectmap_proc(detected_map, module, resize_mode, h, w):
|
# def detectmap_proc(detected_map, module, resize_mode, h, w):
|
||||||
def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device: torch.device = torch.device("cpu")):
|
def np_img_resize(
|
||||||
|
np_img: np.ndarray,
|
||||||
|
resize_mode: str,
|
||||||
|
h: int,
|
||||||
|
w: int,
|
||||||
|
device: torch.device = torch.device('cpu')
|
||||||
|
):
|
||||||
# if 'inpaint' in module:
|
# if 'inpaint' in module:
|
||||||
# np_img = np_img.astype(np.float32)
|
# np_img = np_img.astype(np.float32)
|
||||||
# else:
|
# else:
|
||||||
@@ -162,14 +184,15 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
|
|||||||
# below is very boring but do not change these. If you change these Apple or Mac may fail.
|
# below is very boring but do not change these. If you change these Apple or Mac may fail.
|
||||||
y = torch.from_numpy(y)
|
y = torch.from_numpy(y)
|
||||||
y = y.float() / 255.0
|
y = y.float() / 255.0
|
||||||
y = rearrange(y, "h w c -> 1 c h w")
|
y = rearrange(y, 'h w c -> 1 c h w')
|
||||||
y = y.clone()
|
y = y.clone()
|
||||||
# y = y.to(devices.get_device_for("controlnet"))
|
# y = y.to(devices.get_device_for("controlnet"))
|
||||||
y = y.to(device)
|
y = y.to(device)
|
||||||
y = y.clone()
|
y = y.clone()
|
||||||
return y
|
return y
|
||||||
|
|
||||||
def high_quality_resize(x: np.ndarray, size):
|
def high_quality_resize(x: np.ndarray,
|
||||||
|
size):
|
||||||
# Written by lvmin
|
# Written by lvmin
|
||||||
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
|
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
|
||||||
inpaint_mask = None
|
inpaint_mask = None
|
||||||
@@ -221,7 +244,7 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
|
|||||||
return y
|
return y
|
||||||
|
|
||||||
# if resize_mode == external_code.ResizeMode.RESIZE:
|
# if resize_mode == external_code.ResizeMode.RESIZE:
|
||||||
if resize_mode == "just_resize": # RESIZE
|
if resize_mode == "just_resize": # RESIZE
|
||||||
np_img = high_quality_resize(np_img, (w, h))
|
np_img = high_quality_resize(np_img, (w, h))
|
||||||
np_img = safe_numpy(np_img)
|
np_img = safe_numpy(np_img)
|
||||||
return get_pytorch_control(np_img), np_img
|
return get_pytorch_control(np_img), np_img
|
||||||
@@ -247,21 +270,20 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
|
|||||||
new_h, new_w, _ = np_img.shape
|
new_h, new_w, _ = np_img.shape
|
||||||
pad_h = max(0, (h - new_h) // 2)
|
pad_h = max(0, (h - new_h) // 2)
|
||||||
pad_w = max(0, (w - new_w) // 2)
|
pad_w = max(0, (w - new_w) // 2)
|
||||||
high_quality_background[pad_h : pad_h + new_h, pad_w : pad_w + new_w] = np_img
|
high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = np_img
|
||||||
np_img = high_quality_background
|
np_img = high_quality_background
|
||||||
np_img = safe_numpy(np_img)
|
np_img = safe_numpy(np_img)
|
||||||
return get_pytorch_control(np_img), np_img
|
return get_pytorch_control(np_img), np_img
|
||||||
else: # resize_mode == "crop_resize" (INNER_FIT)
|
else: # resize_mode == "crop_resize" (INNER_FIT)
|
||||||
k = max(k0, k1)
|
k = max(k0, k1)
|
||||||
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
|
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
|
||||||
new_h, new_w, _ = np_img.shape
|
new_h, new_w, _ = np_img.shape
|
||||||
pad_h = max(0, (new_h - h) // 2)
|
pad_h = max(0, (new_h - h) // 2)
|
||||||
pad_w = max(0, (new_w - w) // 2)
|
pad_w = max(0, (new_w - w) // 2)
|
||||||
np_img = np_img[pad_h : pad_h + h, pad_w : pad_w + w]
|
np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w]
|
||||||
np_img = safe_numpy(np_img)
|
np_img = safe_numpy(np_img)
|
||||||
return get_pytorch_control(np_img), np_img
|
return get_pytorch_control(np_img), np_img
|
||||||
|
|
||||||
|
|
||||||
def prepare_control_image(
|
def prepare_control_image(
|
||||||
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
|
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
|
||||||
# but now should be able to assume that image is a single PIL.Image, which simplifies things
|
# but now should be able to assume that image is a single PIL.Image, which simplifies things
|
||||||
@@ -279,17 +301,15 @@ def prepare_control_image(
|
|||||||
resize_mode="just_resize_simple",
|
resize_mode="just_resize_simple",
|
||||||
):
|
):
|
||||||
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
|
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
|
||||||
if (
|
if (resize_mode == "just_resize_simple" or
|
||||||
resize_mode == "just_resize_simple"
|
resize_mode == "crop_resize_simple" or
|
||||||
or resize_mode == "crop_resize_simple"
|
resize_mode == "fill_resize_simple"):
|
||||||
or resize_mode == "fill_resize_simple"
|
|
||||||
):
|
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
if resize_mode == "just_resize_simple":
|
if (resize_mode == "just_resize_simple"):
|
||||||
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||||
elif resize_mode == "crop_resize_simple": # not yet implemented
|
elif (resize_mode == "crop_resize_simple"): # not yet implemented
|
||||||
pass
|
pass
|
||||||
elif resize_mode == "fill_resize_simple": # not yet implemented
|
elif (resize_mode == "fill_resize_simple"): # not yet implemented
|
||||||
pass
|
pass
|
||||||
nimage = np.array(image)
|
nimage = np.array(image)
|
||||||
nimage = nimage[None, :]
|
nimage = nimage[None, :]
|
||||||
@@ -300,7 +320,7 @@ def prepare_control_image(
|
|||||||
timage = torch.from_numpy(nimage)
|
timage = torch.from_numpy(nimage)
|
||||||
|
|
||||||
# use fancy lvmin controlnet resizing
|
# use fancy lvmin controlnet resizing
|
||||||
elif resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize":
|
elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"):
|
||||||
nimage = np.array(image)
|
nimage = np.array(image)
|
||||||
timage, nimage = np_img_resize(
|
timage, nimage = np_img_resize(
|
||||||
np_img=nimage,
|
np_img=nimage,
|
||||||
@@ -316,7 +336,7 @@ def prepare_control_image(
|
|||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
timage = timage.to(device=device, dtype=dtype)
|
timage = timage.to(device=device, dtype=dtype)
|
||||||
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
|
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
|
||||||
if do_classifier_free_guidance and not cfg_injection:
|
if do_classifier_free_guidance and not cfg_injection:
|
||||||
timage = torch.cat([timage] * 2)
|
timage = torch.cat([timage] * 2)
|
||||||
return timage
|
return timage
|
||||||
|
|||||||
@@ -9,16 +9,19 @@ from ...backend.stable_diffusion import PipelineIntermediateState
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
|
|
||||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix = None):
|
||||||
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
|
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
|
||||||
|
|
||||||
if smooth_matrix is not None:
|
if smooth_matrix is not None:
|
||||||
latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
|
latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
|
||||||
latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1, 1, 3, 3)), padding=1)
|
latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1,1,3,3)), padding=1)
|
||||||
latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0)
|
latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0)
|
||||||
|
|
||||||
latents_ubyte = (
|
latents_ubyte = (
|
||||||
((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte() # change scale from -1..1 to 0..1 # to 0..255
|
((latent_image + 1) / 2)
|
||||||
|
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||||
|
.mul(0xFF) # to 0..255
|
||||||
|
.byte()
|
||||||
).cpu()
|
).cpu()
|
||||||
|
|
||||||
return Image.fromarray(latents_ubyte.numpy())
|
return Image.fromarray(latents_ubyte.numpy())
|
||||||
@@ -89,7 +92,6 @@ def stable_diffusion_step_callback(
|
|||||||
total_steps=node["steps"],
|
total_steps=node["steps"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def stable_diffusion_xl_step_callback(
|
def stable_diffusion_xl_step_callback(
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
node: dict,
|
node: dict,
|
||||||
@@ -104,9 +106,9 @@ def stable_diffusion_xl_step_callback(
|
|||||||
sdxl_latent_rgb_factors = torch.tensor(
|
sdxl_latent_rgb_factors = torch.tensor(
|
||||||
[
|
[
|
||||||
# R G B
|
# R G B
|
||||||
[0.3816, 0.4930, 0.5320],
|
[ 0.3816, 0.4930, 0.5320],
|
||||||
[-0.3753, 0.1631, 0.1739],
|
[-0.3753, 0.1631, 0.1739],
|
||||||
[0.1770, 0.3588, -0.2048],
|
[ 0.1770, 0.3588, -0.2048],
|
||||||
[-0.4350, -0.2644, -0.4289],
|
[-0.4350, -0.2644, -0.4289],
|
||||||
],
|
],
|
||||||
dtype=sample.dtype,
|
dtype=sample.dtype,
|
||||||
@@ -115,9 +117,9 @@ def stable_diffusion_xl_step_callback(
|
|||||||
|
|
||||||
sdxl_smooth_matrix = torch.tensor(
|
sdxl_smooth_matrix = torch.tensor(
|
||||||
[
|
[
|
||||||
# [ 0.0478, 0.1285, 0.0478],
|
#[ 0.0478, 0.1285, 0.0478],
|
||||||
# [ 0.1285, 0.2948, 0.1285],
|
#[ 0.1285, 0.2948, 0.1285],
|
||||||
# [ 0.0478, 0.1285, 0.0478],
|
#[ 0.0478, 0.1285, 0.0478],
|
||||||
[0.0358, 0.0964, 0.0358],
|
[0.0358, 0.0964, 0.0358],
|
||||||
[0.0964, 0.4711, 0.0964],
|
[0.0964, 0.4711, 0.0964],
|
||||||
[0.0358, 0.0964, 0.0358],
|
[0.0358, 0.0964, 0.0358],
|
||||||
|
|||||||
@@ -1,6 +1,15 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for invokeai.backend
|
Initialization file for invokeai.backend
|
||||||
"""
|
"""
|
||||||
from .generator import InvokeAIGeneratorBasicParams, InvokeAIGenerator, InvokeAIGeneratorOutput, Img2Img, Inpaint
|
from .generator import (
|
||||||
from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo
|
InvokeAIGeneratorBasicParams,
|
||||||
|
InvokeAIGenerator,
|
||||||
|
InvokeAIGeneratorOutput,
|
||||||
|
Img2Img,
|
||||||
|
Inpaint
|
||||||
|
)
|
||||||
|
from .model_management import (
|
||||||
|
ModelManager, ModelCache, BaseModelType,
|
||||||
|
ModelType, SubModelType, ModelInfo
|
||||||
|
)
|
||||||
from .model_management.models import SilenceWarnings
|
from .model_management.models import SilenceWarnings
|
||||||
|
|||||||
@@ -33,66 +33,61 @@ from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
|||||||
|
|
||||||
downsampling = 8
|
downsampling = 8
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InvokeAIGeneratorBasicParams:
|
class InvokeAIGeneratorBasicParams:
|
||||||
seed: Optional[int] = None
|
seed: Optional[int]=None
|
||||||
width: int = 512
|
width: int=512
|
||||||
height: int = 512
|
height: int=512
|
||||||
cfg_scale: float = 7.5
|
cfg_scale: float=7.5
|
||||||
steps: int = 20
|
steps: int=20
|
||||||
ddim_eta: float = 0.0
|
ddim_eta: float=0.0
|
||||||
scheduler: str = "ddim"
|
scheduler: str='ddim'
|
||||||
precision: str = "float16"
|
precision: str='float16'
|
||||||
perlin: float = 0.0
|
perlin: float=0.0
|
||||||
threshold: float = 0.0
|
threshold: float=0.0
|
||||||
seamless: bool = False
|
seamless: bool=False
|
||||||
seamless_axes: List[str] = field(default_factory=lambda: ["x", "y"])
|
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
|
||||||
h_symmetry_time_pct: Optional[float] = None
|
h_symmetry_time_pct: Optional[float]=None
|
||||||
v_symmetry_time_pct: Optional[float] = None
|
v_symmetry_time_pct: Optional[float]=None
|
||||||
variation_amount: float = 0.0
|
variation_amount: float = 0.0
|
||||||
with_variations: list = field(default_factory=list)
|
with_variations: list=field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InvokeAIGeneratorOutput:
|
class InvokeAIGeneratorOutput:
|
||||||
"""
|
'''
|
||||||
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
|
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
|
||||||
operation, including the image, its seed, the model name used to generate the image
|
operation, including the image, its seed, the model name used to generate the image
|
||||||
and the model hash, as well as all the generate() parameters that went into
|
and the model hash, as well as all the generate() parameters that went into
|
||||||
generating the image (in .params, also available as attributes)
|
generating the image (in .params, also available as attributes)
|
||||||
"""
|
'''
|
||||||
|
|
||||||
image: Image.Image
|
image: Image.Image
|
||||||
seed: int
|
seed: int
|
||||||
model_hash: str
|
model_hash: str
|
||||||
attention_maps_images: List[Image.Image]
|
attention_maps_images: List[Image.Image]
|
||||||
params: Namespace
|
params: Namespace
|
||||||
|
|
||||||
|
|
||||||
# we are interposing a wrapper around the original Generator classes so that
|
# we are interposing a wrapper around the original Generator classes so that
|
||||||
# old code that calls Generate will continue to work.
|
# old code that calls Generate will continue to work.
|
||||||
class InvokeAIGenerator(metaclass=ABCMeta):
|
class InvokeAIGenerator(metaclass=ABCMeta):
|
||||||
def __init__(
|
def __init__(self,
|
||||||
self,
|
model_info: dict,
|
||||||
model_info: dict,
|
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||||
params: InvokeAIGeneratorBasicParams = InvokeAIGeneratorBasicParams(),
|
**kwargs,
|
||||||
**kwargs,
|
):
|
||||||
):
|
self.model_info=model_info
|
||||||
self.model_info = model_info
|
self.params=params
|
||||||
self.params = params
|
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
conditioning: tuple,
|
conditioning: tuple,
|
||||||
scheduler,
|
scheduler,
|
||||||
callback: Optional[Callable] = None,
|
callback: Optional[Callable]=None,
|
||||||
step_callback: Optional[Callable] = None,
|
step_callback: Optional[Callable]=None,
|
||||||
iterations: int = 1,
|
iterations: int=1,
|
||||||
**keyword_args,
|
**keyword_args,
|
||||||
) -> Iterator[InvokeAIGeneratorOutput]:
|
)->Iterator[InvokeAIGeneratorOutput]:
|
||||||
"""
|
'''
|
||||||
Return an iterator across the indicated number of generations.
|
Return an iterator across the indicated number of generations.
|
||||||
Each time the iterator is called it will return an InvokeAIGeneratorOutput
|
Each time the iterator is called it will return an InvokeAIGeneratorOutput
|
||||||
object. Use like this:
|
object. Use like this:
|
||||||
@@ -112,7 +107,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
for o in outputs:
|
for o in outputs:
|
||||||
print(o.image, o.seed)
|
print(o.image, o.seed)
|
||||||
|
|
||||||
"""
|
'''
|
||||||
generator_args = dataclasses.asdict(self.params)
|
generator_args = dataclasses.asdict(self.params)
|
||||||
generator_args.update(keyword_args)
|
generator_args.update(keyword_args)
|
||||||
|
|
||||||
@@ -123,21 +118,22 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
gen_class = self._generator_class()
|
gen_class = self._generator_class()
|
||||||
generator = gen_class(model, self.params.precision, **self.kwargs)
|
generator = gen_class(model, self.params.precision, **self.kwargs)
|
||||||
if self.params.variation_amount > 0:
|
if self.params.variation_amount > 0:
|
||||||
generator.set_variation(
|
generator.set_variation(generator_args.get('seed'),
|
||||||
generator_args.get("seed"),
|
generator_args.get('variation_amount'),
|
||||||
generator_args.get("variation_amount"),
|
generator_args.get('with_variations')
|
||||||
generator_args.get("with_variations"),
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(model, DiffusionPipeline):
|
if isinstance(model, DiffusionPipeline):
|
||||||
for component in [model.unet, model.vae]:
|
for component in [model.unet, model.vae]:
|
||||||
configure_model_padding(
|
configure_model_padding(component,
|
||||||
component, generator_args.get("seamless", False), generator_args.get("seamless_axes")
|
generator_args.get('seamless',False),
|
||||||
)
|
generator_args.get('seamless_axes')
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
configure_model_padding(
|
configure_model_padding(model,
|
||||||
model, generator_args.get("seamless", False), generator_args.get("seamless_axes")
|
generator_args.get('seamless',False),
|
||||||
)
|
generator_args.get('seamless_axes')
|
||||||
|
)
|
||||||
|
|
||||||
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
||||||
for i in iteration_count:
|
for i in iteration_count:
|
||||||
@@ -151,66 +147,66 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
image=results[0][0],
|
image=results[0][0],
|
||||||
seed=results[0][1],
|
seed=results[0][1],
|
||||||
attention_maps_images=results[0][2],
|
attention_maps_images=results[0][2],
|
||||||
model_hash=model_hash,
|
model_hash = model_hash,
|
||||||
params=Namespace(model_name=model_name, **generator_args),
|
params=Namespace(model_name=model_name,**generator_args),
|
||||||
)
|
)
|
||||||
if callback:
|
if callback:
|
||||||
callback(output)
|
callback(output)
|
||||||
yield output
|
yield output
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def schedulers(self) -> List[str]:
|
def schedulers(self)->List[str]:
|
||||||
"""
|
'''
|
||||||
Return list of all the schedulers that we currently handle.
|
Return list of all the schedulers that we currently handle.
|
||||||
"""
|
'''
|
||||||
return list(SCHEDULER_MAP.keys())
|
return list(SCHEDULER_MAP.keys())
|
||||||
|
|
||||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||||
return generator_class(model, self.params.precision)
|
return generator_class(model, self.params.precision)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _generator_class(cls) -> Type[Generator]:
|
def _generator_class(cls)->Type[Generator]:
|
||||||
"""
|
'''
|
||||||
In derived classes return the name of the generator to apply.
|
In derived classes return the name of the generator to apply.
|
||||||
If you don't override will return the name of the derived
|
If you don't override will return the name of the derived
|
||||||
class, which nicely parallels the generator class names.
|
class, which nicely parallels the generator class names.
|
||||||
"""
|
'''
|
||||||
return Generator
|
return Generator
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
class Img2Img(InvokeAIGenerator):
|
class Img2Img(InvokeAIGenerator):
|
||||||
def generate(
|
def generate(self,
|
||||||
self, init_image: Union[Image.Image, torch.FloatTensor], strength: float = 0.75, **keyword_args
|
init_image: Union[Image.Image, torch.FloatTensor],
|
||||||
) -> Iterator[InvokeAIGeneratorOutput]:
|
strength: float=0.75,
|
||||||
return super().generate(init_image=init_image, strength=strength, **keyword_args)
|
**keyword_args
|
||||||
|
)->Iterator[InvokeAIGeneratorOutput]:
|
||||||
|
return super().generate(init_image=init_image,
|
||||||
|
strength=strength,
|
||||||
|
**keyword_args
|
||||||
|
)
|
||||||
@classmethod
|
@classmethod
|
||||||
def _generator_class(cls):
|
def _generator_class(cls):
|
||||||
from .img2img import Img2Img
|
from .img2img import Img2Img
|
||||||
|
|
||||||
return Img2Img
|
return Img2Img
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
||||||
class Inpaint(Img2Img):
|
class Inpaint(Img2Img):
|
||||||
def generate(
|
def generate(self,
|
||||||
self,
|
mask_image: Union[Image.Image, torch.FloatTensor],
|
||||||
mask_image: Union[Image.Image, torch.FloatTensor],
|
# Seam settings - when 0, doesn't fill seam
|
||||||
# Seam settings - when 0, doesn't fill seam
|
seam_size: int = 96,
|
||||||
seam_size: int = 96,
|
seam_blur: int = 16,
|
||||||
seam_blur: int = 16,
|
seam_strength: float = 0.7,
|
||||||
seam_strength: float = 0.7,
|
seam_steps: int = 30,
|
||||||
seam_steps: int = 30,
|
tile_size: int = 32,
|
||||||
tile_size: int = 32,
|
inpaint_replace=False,
|
||||||
inpaint_replace=False,
|
infill_method=None,
|
||||||
infill_method=None,
|
inpaint_width=None,
|
||||||
inpaint_width=None,
|
inpaint_height=None,
|
||||||
inpaint_height=None,
|
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||||
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
**keyword_args
|
||||||
**keyword_args,
|
)->Iterator[InvokeAIGeneratorOutput]:
|
||||||
) -> Iterator[InvokeAIGeneratorOutput]:
|
|
||||||
return super().generate(
|
return super().generate(
|
||||||
mask_image=mask_image,
|
mask_image=mask_image,
|
||||||
seam_size=seam_size,
|
seam_size=seam_size,
|
||||||
@@ -223,16 +219,13 @@ class Inpaint(Img2Img):
|
|||||||
inpaint_width=inpaint_width,
|
inpaint_width=inpaint_width,
|
||||||
inpaint_height=inpaint_height,
|
inpaint_height=inpaint_height,
|
||||||
inpaint_fill=inpaint_fill,
|
inpaint_fill=inpaint_fill,
|
||||||
**keyword_args,
|
**keyword_args
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _generator_class(cls):
|
def _generator_class(cls):
|
||||||
from .inpaint import Inpaint
|
from .inpaint import Inpaint
|
||||||
|
|
||||||
return Inpaint
|
return Inpaint
|
||||||
|
|
||||||
|
|
||||||
class Generator:
|
class Generator:
|
||||||
downsampling_factor: int
|
downsampling_factor: int
|
||||||
latent_channels: int
|
latent_channels: int
|
||||||
@@ -258,7 +251,9 @@ class Generator:
|
|||||||
Returns a function returning an image derived from the prompt and the initial image
|
Returns a function returning an image derived from the prompt and the initial image
|
||||||
Return value depends on the seed at the time you call it
|
Return value depends on the seed at the time you call it
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("image_iterator() must be implemented in a descendent class")
|
raise NotImplementedError(
|
||||||
|
"image_iterator() must be implemented in a descendent class"
|
||||||
|
)
|
||||||
|
|
||||||
def set_variation(self, seed, variation_amount, with_variations):
|
def set_variation(self, seed, variation_amount, with_variations):
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
@@ -285,7 +280,9 @@ class Generator:
|
|||||||
scope = nullcontext
|
scope = nullcontext
|
||||||
self.free_gpu_mem = free_gpu_mem
|
self.free_gpu_mem = free_gpu_mem
|
||||||
attention_maps_images = []
|
attention_maps_images = []
|
||||||
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
|
attention_maps_callback = lambda saver: attention_maps_images.append(
|
||||||
|
saver.get_stacked_maps_image()
|
||||||
|
)
|
||||||
make_image = self.get_make_image(
|
make_image = self.get_make_image(
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
init_image=init_image,
|
init_image=init_image,
|
||||||
@@ -330,7 +327,11 @@ class Generator:
|
|||||||
results.append([image, seed, attention_maps_images])
|
results.append([image, seed, attention_maps_images])
|
||||||
|
|
||||||
if image_callback is not None:
|
if image_callback is not None:
|
||||||
attention_maps_image = None if len(attention_maps_images) == 0 else attention_maps_images[-1]
|
attention_maps_image = (
|
||||||
|
None
|
||||||
|
if len(attention_maps_images) == 0
|
||||||
|
else attention_maps_images[-1]
|
||||||
|
)
|
||||||
image_callback(
|
image_callback(
|
||||||
image,
|
image,
|
||||||
seed,
|
seed,
|
||||||
@@ -341,7 +342,9 @@ class Generator:
|
|||||||
seed = self.new_seed()
|
seed = self.new_seed()
|
||||||
|
|
||||||
# Free up memory from the last generation.
|
# Free up memory from the last generation.
|
||||||
clear_cuda_cache = kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
|
clear_cuda_cache = (
|
||||||
|
kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None
|
||||||
|
)
|
||||||
if clear_cuda_cache is not None:
|
if clear_cuda_cache is not None:
|
||||||
clear_cuda_cache()
|
clear_cuda_cache()
|
||||||
|
|
||||||
@@ -368,8 +371,14 @@ class Generator:
|
|||||||
|
|
||||||
# Get the original alpha channel of the mask if there is one.
|
# Get the original alpha channel of the mask if there is one.
|
||||||
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
||||||
pil_init_mask = init_mask.getchannel("A") if init_mask.mode == "RGBA" else init_mask.convert("L")
|
pil_init_mask = (
|
||||||
pil_init_image = init_image.convert("RGBA") # Add an alpha channel if one doesn't exist
|
init_mask.getchannel("A")
|
||||||
|
if init_mask.mode == "RGBA"
|
||||||
|
else init_mask.convert("L")
|
||||||
|
)
|
||||||
|
pil_init_image = init_image.convert(
|
||||||
|
"RGBA"
|
||||||
|
) # Add an alpha channel if one doesn't exist
|
||||||
|
|
||||||
# Build an image with only visible pixels from source to use as reference for color-matching.
|
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||||
init_rgb_pixels = np.asarray(init_image.convert("RGB"), dtype=np.uint8)
|
init_rgb_pixels = np.asarray(init_image.convert("RGB"), dtype=np.uint8)
|
||||||
@@ -395,7 +404,10 @@ class Generator:
|
|||||||
np_matched_result[:, :, :] = (
|
np_matched_result[:, :, :] = (
|
||||||
(
|
(
|
||||||
(
|
(
|
||||||
(np_matched_result[:, :, :].astype(np.float32) - gen_means[None, None, :])
|
(
|
||||||
|
np_matched_result[:, :, :].astype(np.float32)
|
||||||
|
- gen_means[None, None, :]
|
||||||
|
)
|
||||||
/ gen_std[None, None, :]
|
/ gen_std[None, None, :]
|
||||||
)
|
)
|
||||||
* init_std[None, None, :]
|
* init_std[None, None, :]
|
||||||
@@ -421,7 +433,9 @@ class Generator:
|
|||||||
else:
|
else:
|
||||||
blurred_init_mask = pil_init_mask
|
blurred_init_mask = pil_init_mask
|
||||||
|
|
||||||
multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1])
|
multiplied_blurred_init_mask = ImageChops.multiply(
|
||||||
|
blurred_init_mask, self.pil_image.split()[-1]
|
||||||
|
)
|
||||||
|
|
||||||
# Paste original on color-corrected generation (using blurred mask)
|
# Paste original on color-corrected generation (using blurred mask)
|
||||||
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
|
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
|
||||||
@@ -447,7 +461,10 @@ class Generator:
|
|||||||
|
|
||||||
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
|
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
|
||||||
latents_ubyte = (
|
latents_ubyte = (
|
||||||
((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte() # change scale from -1..1 to 0..1 # to 0..255
|
((latent_image + 1) / 2)
|
||||||
|
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||||
|
.mul(0xFF) # to 0..255
|
||||||
|
.byte()
|
||||||
).cpu()
|
).cpu()
|
||||||
|
|
||||||
return Image.fromarray(latents_ubyte.numpy())
|
return Image.fromarray(latents_ubyte.numpy())
|
||||||
@@ -477,7 +494,9 @@ class Generator:
|
|||||||
temp_height = int((height + 7) / 8) * 8
|
temp_height = int((height + 7) / 8) * 8
|
||||||
noise = torch.stack(
|
noise = torch.stack(
|
||||||
[
|
[
|
||||||
rand_perlin_2d((temp_height, temp_width), (8, 8), device=self.model.device).to(fixdevice)
|
rand_perlin_2d(
|
||||||
|
(temp_height, temp_width), (8, 8), device=self.model.device
|
||||||
|
).to(fixdevice)
|
||||||
for _ in range(input_channels)
|
for _ in range(input_channels)
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
@@ -554,6 +573,8 @@ class Generator:
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
if self.perlin > 0.0:
|
if self.perlin > 0.0:
|
||||||
perlin_noise = self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
perlin_noise = self.get_perlin_noise(
|
||||||
|
width // self.downsampling_factor, height // self.downsampling_factor
|
||||||
|
)
|
||||||
x = (1 - self.perlin) * x + self.perlin * perlin_noise
|
x = (1 - self.perlin) * x + self.perlin * perlin_noise
|
||||||
return x
|
return x
|
||||||
|
|||||||
@@ -77,7 +77,10 @@ class Img2Img(Generator):
|
|||||||
callback=step_callback,
|
callback=step_callback,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
)
|
)
|
||||||
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
if (
|
||||||
|
pipeline_output.attention_map_saver is not None
|
||||||
|
and attention_maps_callback is not None
|
||||||
|
):
|
||||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||||
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
return pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||||
|
|
||||||
@@ -88,5 +91,7 @@ class Img2Img(Generator):
|
|||||||
x = torch.randn_like(like, device=device)
|
x = torch.randn_like(like, device=device)
|
||||||
if self.perlin > 0.0:
|
if self.perlin > 0.0:
|
||||||
shape = like.shape
|
shape = like.shape
|
||||||
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2])
|
x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(
|
||||||
|
shape[3], shape[2]
|
||||||
|
)
|
||||||
return x
|
return x
|
||||||
|
|||||||
@@ -68,11 +68,15 @@ class Inpaint(Img2Img):
|
|||||||
return im
|
return im
|
||||||
|
|
||||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||||
im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
|
im_patched_np = PatchMatch.inpaint(
|
||||||
|
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
|
||||||
|
)
|
||||||
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
||||||
return im_patched
|
return im_patched
|
||||||
|
|
||||||
def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
|
def tile_fill_missing(
|
||||||
|
self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
|
||||||
|
) -> Image.Image:
|
||||||
# Only fill if there's an alpha layer
|
# Only fill if there's an alpha layer
|
||||||
if im.mode != "RGBA":
|
if im.mode != "RGBA":
|
||||||
return im
|
return im
|
||||||
@@ -123,11 +127,15 @@ class Inpaint(Img2Img):
|
|||||||
|
|
||||||
return si
|
return si
|
||||||
|
|
||||||
def mask_edge(self, mask: Image.Image, edge_size: int, edge_blur: int) -> Image.Image:
|
def mask_edge(
|
||||||
|
self, mask: Image.Image, edge_size: int, edge_blur: int
|
||||||
|
) -> Image.Image:
|
||||||
npimg = np.asarray(mask, dtype=np.uint8)
|
npimg = np.asarray(mask, dtype=np.uint8)
|
||||||
|
|
||||||
# Detect any partially transparent regions
|
# Detect any partially transparent regions
|
||||||
npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)))
|
npgradient = np.uint8(
|
||||||
|
255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0))
|
||||||
|
)
|
||||||
|
|
||||||
# Detect hard edges
|
# Detect hard edges
|
||||||
npedge = cv2.Canny(npimg, threshold1=100, threshold2=200)
|
npedge = cv2.Canny(npimg, threshold1=100, threshold2=200)
|
||||||
@@ -136,7 +144,9 @@ class Inpaint(Img2Img):
|
|||||||
npmask = npgradient + npedge
|
npmask = npgradient + npedge
|
||||||
|
|
||||||
# Expand
|
# Expand
|
||||||
npmask = cv2.dilate(npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2))
|
npmask = cv2.dilate(
|
||||||
|
npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2)
|
||||||
|
)
|
||||||
|
|
||||||
new_mask = Image.fromarray(npmask)
|
new_mask = Image.fromarray(npmask)
|
||||||
|
|
||||||
@@ -232,19 +242,25 @@ class Inpaint(Img2Img):
|
|||||||
if infill_method == "patchmatch" and PatchMatch.patchmatch_available():
|
if infill_method == "patchmatch" and PatchMatch.patchmatch_available():
|
||||||
init_filled = self.infill_patchmatch(self.pil_image.copy())
|
init_filled = self.infill_patchmatch(self.pil_image.copy())
|
||||||
elif infill_method == "tile":
|
elif infill_method == "tile":
|
||||||
init_filled = self.tile_fill_missing(self.pil_image.copy(), seed=self.seed, tile_size=tile_size)
|
init_filled = self.tile_fill_missing(
|
||||||
|
self.pil_image.copy(), seed=self.seed, tile_size=tile_size
|
||||||
|
)
|
||||||
elif infill_method == "solid":
|
elif infill_method == "solid":
|
||||||
solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
|
solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
|
||||||
init_filled = Image.alpha_composite(solid_bg, init_image)
|
init_filled = Image.alpha_composite(solid_bg, init_image)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Non-supported infill type {infill_method}", infill_method)
|
raise ValueError(
|
||||||
|
f"Non-supported infill type {infill_method}", infill_method
|
||||||
|
)
|
||||||
init_filled.paste(init_image, (0, 0), init_image.split()[-1])
|
init_filled.paste(init_image, (0, 0), init_image.split()[-1])
|
||||||
|
|
||||||
# Resize if requested for inpainting
|
# Resize if requested for inpainting
|
||||||
if inpaint_width and inpaint_height:
|
if inpaint_width and inpaint_height:
|
||||||
init_filled = init_filled.resize((inpaint_width, inpaint_height))
|
init_filled = init_filled.resize((inpaint_width, inpaint_height))
|
||||||
|
|
||||||
debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging)
|
debug_image(
|
||||||
|
init_filled, "init_filled", debug_status=self.enable_image_debugging
|
||||||
|
)
|
||||||
|
|
||||||
# Create init tensor
|
# Create init tensor
|
||||||
init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
|
init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
|
||||||
@@ -273,7 +289,9 @@ class Inpaint(Img2Img):
|
|||||||
"mask_image AFTER multiply with pil_image",
|
"mask_image AFTER multiply with pil_image",
|
||||||
debug_status=self.enable_image_debugging,
|
debug_status=self.enable_image_debugging,
|
||||||
)
|
)
|
||||||
mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
mask: torch.FloatTensor = image_resized_to_grid_as_tensor(
|
||||||
|
mask_image, normalize=False
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
mask: torch.FloatTensor = mask_image
|
mask: torch.FloatTensor = mask_image
|
||||||
|
|
||||||
@@ -284,9 +302,9 @@ class Inpaint(Img2Img):
|
|||||||
|
|
||||||
# todo: support cross-attention control
|
# todo: support cross-attention control
|
||||||
uc, c, _ = conditioning
|
uc, c, _ = conditioning
|
||||||
conditioning_data = ConditioningData(uc, c, cfg_scale).add_scheduler_args_if_applicable(
|
conditioning_data = ConditioningData(
|
||||||
pipeline.scheduler, eta=ddim_eta
|
uc, c, cfg_scale
|
||||||
)
|
).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)
|
||||||
|
|
||||||
def make_image(x_T: torch.Tensor, seed: int):
|
def make_image(x_T: torch.Tensor, seed: int):
|
||||||
pipeline_output = pipeline.inpaint_from_embeddings(
|
pipeline_output = pipeline.inpaint_from_embeddings(
|
||||||
@@ -300,10 +318,15 @@ class Inpaint(Img2Img):
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
|
if (
|
||||||
|
pipeline_output.attention_map_saver is not None
|
||||||
|
and attention_maps_callback is not None
|
||||||
|
):
|
||||||
attention_maps_callback(pipeline_output.attention_map_saver)
|
attention_maps_callback(pipeline_output.attention_map_saver)
|
||||||
|
|
||||||
result = self.postprocess_size_and_mask(pipeline.numpy_to_pil(pipeline_output.images)[0])
|
result = self.postprocess_size_and_mask(
|
||||||
|
pipeline.numpy_to_pil(pipeline_output.images)[0]
|
||||||
|
)
|
||||||
|
|
||||||
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
|
# Seam paint if this is our first pass (seam_size set to 0 during seam painting)
|
||||||
if seam_size > 0:
|
if seam_size > 0:
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ from .txt2mask import Txt2Mask
|
|||||||
from .util import InitImageResizer, make_grid
|
from .util import InitImageResizer, make_grid
|
||||||
|
|
||||||
|
|
||||||
def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False):
|
def debug_image(
|
||||||
|
debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False
|
||||||
|
):
|
||||||
if not debug_status:
|
if not debug_status:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -9,10 +9,8 @@ from PIL import Image
|
|||||||
from imwatermark import WatermarkEncoder
|
from imwatermark import WatermarkEncoder
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
|
|
||||||
class InvisibleWatermark:
|
class InvisibleWatermark:
|
||||||
"""
|
"""
|
||||||
Wrapper around InvisibleWatermark module.
|
Wrapper around InvisibleWatermark module.
|
||||||
@@ -23,12 +21,14 @@ class InvisibleWatermark:
|
|||||||
return config.invisible_watermark
|
return config.invisible_watermark
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_watermark(self, image: Image, watermark_text: str) -> Image:
|
def add_watermark(self, image: Image, watermark_text:str) -> Image:
|
||||||
if not self.invisible_watermark_available():
|
if not self.invisible_watermark_available():
|
||||||
return image
|
return image
|
||||||
logger.debug(f'Applying invisible watermark "{watermark_text}"')
|
logger.debug(f'Applying invisible watermark "{watermark_text}"')
|
||||||
bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
||||||
encoder = WatermarkEncoder()
|
encoder = WatermarkEncoder()
|
||||||
encoder.set_watermark("bytes", watermark_text.encode("utf-8"))
|
encoder.set_watermark('bytes', watermark_text.encode('utf-8'))
|
||||||
bgr_encoded = encoder.encode(bgr, "dwtDct")
|
bgr_encoded = encoder.encode(bgr, 'dwtDct')
|
||||||
return Image.fromarray(cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
return Image.fromarray(
|
||||||
|
cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)
|
||||||
|
).convert("RGBA")
|
||||||
|
|||||||
@@ -7,10 +7,8 @@ be suppressed or deferred
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
|
|
||||||
class PatchMatch:
|
class PatchMatch:
|
||||||
"""
|
"""
|
||||||
Thin class wrapper around the patchmatch function.
|
Thin class wrapper around the patchmatch function.
|
||||||
|
|||||||
@@ -34,7 +34,9 @@ class PngWriter:
|
|||||||
|
|
||||||
# saves image named _image_ to outdir/name, writing metadata from prompt
|
# saves image named _image_ to outdir/name, writing metadata from prompt
|
||||||
# returns full path of output
|
# returns full path of output
|
||||||
def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None, compress_level=6):
|
def save_image_and_prompt_to_png(
|
||||||
|
self, image, dream_prompt, name, metadata=None, compress_level=6
|
||||||
|
):
|
||||||
path = os.path.join(self.outdir, name)
|
path = os.path.join(self.outdir, name)
|
||||||
info = PngImagePlugin.PngInfo()
|
info = PngImagePlugin.PngInfo()
|
||||||
info.add_text("Dream", dream_prompt)
|
info.add_text("Dream", dream_prompt)
|
||||||
@@ -112,6 +114,8 @@ class PromptFormatter:
|
|||||||
if opt.variation_amount > 0:
|
if opt.variation_amount > 0:
|
||||||
switches.append(f"-v{opt.variation_amount}")
|
switches.append(f"-v{opt.variation_amount}")
|
||||||
if opt.with_variations:
|
if opt.with_variations:
|
||||||
formatted_variations = ",".join(f"{seed}:{weight}" for seed, weight in opt.with_variations)
|
formatted_variations = ",".join(
|
||||||
|
f"{seed}:{weight}" for seed, weight in opt.with_variations
|
||||||
|
)
|
||||||
switches.append(f"-V{formatted_variations}")
|
switches.append(f"-V{formatted_variations}")
|
||||||
return " ".join(switches)
|
return " ".join(switches)
|
||||||
|
|||||||
@@ -9,17 +9,14 @@ from invokeai.backend import SilenceWarnings
|
|||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
|
CHECKER_PATH = 'core/convert/stable-diffusion-safety-checker'
|
||||||
|
|
||||||
|
|
||||||
class SafetyChecker:
|
class SafetyChecker:
|
||||||
"""
|
"""
|
||||||
Wrapper around SafetyChecker model.
|
Wrapper around SafetyChecker model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
safety_checker = None
|
safety_checker = None
|
||||||
feature_extractor = None
|
feature_extractor = None
|
||||||
tried_load: bool = False
|
tried_load: bool = False
|
||||||
@@ -33,14 +30,16 @@ class SafetyChecker:
|
|||||||
try:
|
try:
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor
|
||||||
|
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
|
config.models_path / CHECKER_PATH
|
||||||
self.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
|
)
|
||||||
logger.info("NSFW checker initialized")
|
self.feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||||
|
config.models_path / CHECKER_PATH)
|
||||||
|
logger.info('NSFW checker initialized')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not load NSFW checker: {str(e)}")
|
logger.warning(f'Could not load NSFW checker: {str(e)}')
|
||||||
else:
|
else:
|
||||||
logger.info("NSFW checker loading disabled")
|
logger.info('NSFW checker loading disabled')
|
||||||
self.tried_load = True
|
self.tried_load = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -5,8 +5,12 @@ def _conv_forward_asymmetric(self, input, weight, bias):
|
|||||||
"""
|
"""
|
||||||
Patch for Conv2d._conv_forward that supports asymmetric padding
|
Patch for Conv2d._conv_forward that supports asymmetric padding
|
||||||
"""
|
"""
|
||||||
working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
|
working = nn.functional.pad(
|
||||||
working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
|
input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]
|
||||||
|
)
|
||||||
|
working = nn.functional.pad(
|
||||||
|
working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]
|
||||||
|
)
|
||||||
return nn.functional.conv2d(
|
return nn.functional.conv2d(
|
||||||
working,
|
working,
|
||||||
weight,
|
weight,
|
||||||
@@ -28,14 +32,18 @@ def configure_model_padding(model, seamless, seamless_axes):
|
|||||||
if seamless:
|
if seamless:
|
||||||
m.asymmetric_padding_mode = {}
|
m.asymmetric_padding_mode = {}
|
||||||
m.asymmetric_padding = {}
|
m.asymmetric_padding = {}
|
||||||
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
|
m.asymmetric_padding_mode["x"] = (
|
||||||
|
"circular" if ("x" in seamless_axes) else "constant"
|
||||||
|
)
|
||||||
m.asymmetric_padding["x"] = (
|
m.asymmetric_padding["x"] = (
|
||||||
m._reversed_padding_repeated_twice[0],
|
m._reversed_padding_repeated_twice[0],
|
||||||
m._reversed_padding_repeated_twice[1],
|
m._reversed_padding_repeated_twice[1],
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
|
m.asymmetric_padding_mode["y"] = (
|
||||||
|
"circular" if ("y" in seamless_axes) else "constant"
|
||||||
|
)
|
||||||
m.asymmetric_padding["y"] = (
|
m.asymmetric_padding["y"] = (
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
|||||||
@@ -39,18 +39,23 @@ CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
|
|||||||
CLIPSEG_SIZE = 352
|
CLIPSEG_SIZE = 352
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
|
|
||||||
class SegmentedGrayscale(object):
|
class SegmentedGrayscale(object):
|
||||||
def __init__(self, image: Image, heatmap: torch.Tensor):
|
def __init__(self, image: Image, heatmap: torch.Tensor):
|
||||||
self.heatmap = heatmap
|
self.heatmap = heatmap
|
||||||
self.image = image
|
self.image = image
|
||||||
|
|
||||||
def to_grayscale(self, invert: bool = False) -> Image:
|
def to_grayscale(self, invert: bool = False) -> Image:
|
||||||
return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)))
|
return self._rescale(
|
||||||
|
Image.fromarray(
|
||||||
|
np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def to_mask(self, threshold: float = 0.5) -> Image:
|
def to_mask(self, threshold: float = 0.5) -> Image:
|
||||||
discrete_heatmap = self.heatmap.lt(threshold).int()
|
discrete_heatmap = self.heatmap.lt(threshold).int()
|
||||||
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L"))
|
return self._rescale(
|
||||||
|
Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L")
|
||||||
|
)
|
||||||
|
|
||||||
def to_transparent(self, invert: bool = False) -> Image:
|
def to_transparent(self, invert: bool = False) -> Image:
|
||||||
transparent_image = self.image.copy()
|
transparent_image = self.image.copy()
|
||||||
@@ -62,7 +67,11 @@ class SegmentedGrayscale(object):
|
|||||||
|
|
||||||
# unscales and uncrops the 352x352 heatmap so that it matches the image again
|
# unscales and uncrops the 352x352 heatmap so that it matches the image again
|
||||||
def _rescale(self, heatmap: Image) -> Image:
|
def _rescale(self, heatmap: Image) -> Image:
|
||||||
size = self.image.width if (self.image.width > self.image.height) else self.image.height
|
size = (
|
||||||
|
self.image.width
|
||||||
|
if (self.image.width > self.image.height)
|
||||||
|
else self.image.height
|
||||||
|
)
|
||||||
resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS)
|
resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS)
|
||||||
return resized_image.crop((0, 0, self.image.width, self.image.height))
|
return resized_image.crop((0, 0, self.image.width, self.image.height))
|
||||||
|
|
||||||
@@ -78,8 +87,12 @@ class Txt2Mask(object):
|
|||||||
|
|
||||||
# BUG: we are not doing anything with the device option at this time
|
# BUG: we are not doing anything with the device option at this time
|
||||||
self.device = device
|
self.device = device
|
||||||
self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
|
self.processor = AutoProcessor.from_pretrained(
|
||||||
self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir)
|
CLIPSEG_MODEL, cache_dir=config.cache_dir
|
||||||
|
)
|
||||||
|
self.model = CLIPSegForImageSegmentation.from_pretrained(
|
||||||
|
CLIPSEG_MODEL, cache_dir=config.cache_dir
|
||||||
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def segment(self, image, prompt: str) -> SegmentedGrayscale:
|
def segment(self, image, prompt: str) -> SegmentedGrayscale:
|
||||||
@@ -94,7 +107,9 @@ class Txt2Mask(object):
|
|||||||
image = ImageOps.exif_transpose(image)
|
image = ImageOps.exif_transpose(image)
|
||||||
img = self._scale_and_crop(image)
|
img = self._scale_and_crop(image)
|
||||||
|
|
||||||
inputs = self.processor(text=[prompt], images=[img], padding=True, return_tensors="pt")
|
inputs = self.processor(
|
||||||
|
text=[prompt], images=[img], padding=True, return_tensors="pt"
|
||||||
|
)
|
||||||
outputs = self.model(**inputs)
|
outputs = self.model(**inputs)
|
||||||
heatmap = torch.sigmoid(outputs.logits)
|
heatmap = torch.sigmoid(outputs.logits)
|
||||||
return SegmentedGrayscale(image, heatmap)
|
return SegmentedGrayscale(image, heatmap)
|
||||||
|
|||||||
@@ -6,31 +6,26 @@ from invokeai.app.services.config import (
|
|||||||
InvokeAIAppConfig,
|
InvokeAIAppConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_invokeai_root(config: InvokeAIAppConfig):
|
def check_invokeai_root(config: InvokeAIAppConfig):
|
||||||
try:
|
try:
|
||||||
assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
|
assert config.model_conf_path.exists()
|
||||||
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
|
assert config.db_path.exists()
|
||||||
assert config.models_path.exists(), f"{config.models_path} not found"
|
assert config.models_path.exists()
|
||||||
for model in [
|
for model in [
|
||||||
"CLIP-ViT-bigG-14-laion2B-39B-b160k",
|
'CLIP-ViT-bigG-14-laion2B-39B-b160k',
|
||||||
"bert-base-uncased",
|
'bert-base-uncased',
|
||||||
"clip-vit-large-patch14",
|
'clip-vit-large-patch14',
|
||||||
"sd-vae-ft-mse",
|
'sd-vae-ft-mse',
|
||||||
"stable-diffusion-2-clip",
|
'stable-diffusion-2-clip',
|
||||||
"stable-diffusion-safety-checker",
|
'stable-diffusion-safety-checker']:
|
||||||
]:
|
assert (config.models_path / f'core/convert/{model}').exists()
|
||||||
path = config.models_path / f"core/convert/{model}"
|
except:
|
||||||
assert path.exists(), f"{path} is missing"
|
|
||||||
except Exception as e:
|
|
||||||
print()
|
print()
|
||||||
print(f"An exception has occurred: {str(e)}")
|
print('== STARTUP ABORTED ==')
|
||||||
print("== STARTUP ABORTED ==")
|
print('** One or more necessary files is missing from your InvokeAI root directory **')
|
||||||
print("** One or more necessary files is missing from your InvokeAI root directory **")
|
print('** Please rerun the configuration script to fix this problem. **')
|
||||||
print("** Please rerun the configuration script to fix this problem. **")
|
print('** From the launcher, selection option [7]. **')
|
||||||
print("** From the launcher, selection option [7]. **")
|
print('** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **')
|
||||||
print(
|
input('Press any key to continue...')
|
||||||
'** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **'
|
|
||||||
)
|
|
||||||
input("Press any key to continue...")
|
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,9 @@ from invokeai.backend.install.model_install_backend import (
|
|||||||
InstallSelections,
|
InstallSelections,
|
||||||
ModelInstall,
|
ModelInstall,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_management.model_probe import ModelType, BaseModelType
|
from invokeai.backend.model_management.model_probe import (
|
||||||
|
ModelType, BaseModelType
|
||||||
|
)
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
@@ -75,7 +77,7 @@ Model_dir = "models"
|
|||||||
Default_config_file = config.model_conf_path
|
Default_config_file = config.model_conf_path
|
||||||
SD_Configs = config.legacy_conf_path
|
SD_Configs = config.legacy_conf_path
|
||||||
|
|
||||||
PRECISION_CHOICES = ["auto", "float16", "float32"]
|
PRECISION_CHOICES = ['auto','float16','float32']
|
||||||
|
|
||||||
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||||
# This is the InvokeAI initialization file, which contains command-line default values.
|
# This is the InvokeAI initialization file, which contains command-line default values.
|
||||||
@@ -83,8 +85,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
|||||||
# or renaming it and then running invokeai-configure again.
|
# or renaming it and then running invokeai-configure again.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logger = InvokeAILogger.getLogger()
|
logger=InvokeAILogger.getLogger()
|
||||||
|
|
||||||
|
|
||||||
# --------------------------------------------
|
# --------------------------------------------
|
||||||
def postscript(errors: None):
|
def postscript(errors: None):
|
||||||
@@ -107,9 +108,7 @@ Add the '--help' argument to see all of the command-line switches available for
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
else:
|
else:
|
||||||
message = (
|
message = "\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
|
||||||
"\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
|
|
||||||
)
|
|
||||||
for err in errors:
|
for err in errors:
|
||||||
message += f"\t - {err}\n"
|
message += f"\t - {err}\n"
|
||||||
message += "Please check the logs above and correct any issues."
|
message += "Please check the logs above and correct any issues."
|
||||||
@@ -170,7 +169,9 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
|
|||||||
logger.info(f"Installing {label} model file {model_url}...")
|
logger.info(f"Installing {label} model file {model_url}...")
|
||||||
if not os.path.exists(model_dest):
|
if not os.path.exists(model_dest):
|
||||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||||
request.urlretrieve(model_url, model_dest, ProgressBar(os.path.basename(model_dest)))
|
request.urlretrieve(
|
||||||
|
model_url, model_dest, ProgressBar(os.path.basename(model_dest))
|
||||||
|
)
|
||||||
logger.info("...downloaded successfully")
|
logger.info("...downloaded successfully")
|
||||||
else:
|
else:
|
||||||
logger.info("...exists")
|
logger.info("...exists")
|
||||||
@@ -181,33 +182,33 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
|
|||||||
|
|
||||||
|
|
||||||
def download_conversion_models():
|
def download_conversion_models():
|
||||||
target_dir = config.models_path / "core/convert"
|
target_dir = config.root_path / 'models/core/convert'
|
||||||
kwargs = dict() # for future use
|
kwargs = dict() # for future use
|
||||||
try:
|
try:
|
||||||
logger.info("Downloading core tokenizers and text encoders")
|
logger.info('Downloading core tokenizers and text encoders')
|
||||||
|
|
||||||
# bert
|
# bert
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
||||||
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
|
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
|
||||||
bert.save_pretrained(target_dir / "bert-base-uncased", safe_serialization=True)
|
bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)
|
||||||
|
|
||||||
# sd-1
|
# sd-1
|
||||||
repo_id = "openai/clip-vit-large-patch14"
|
repo_id = 'openai/clip-vit-large-patch14'
|
||||||
hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / "clip-vit-large-patch14")
|
hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14')
|
||||||
hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / "clip-vit-large-patch14")
|
hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14')
|
||||||
|
|
||||||
# sd-2
|
# sd-2
|
||||||
repo_id = "stabilityai/stable-diffusion-2"
|
repo_id = "stabilityai/stable-diffusion-2"
|
||||||
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
|
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
|
||||||
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "tokenizer", safe_serialization=True)
|
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
|
||||||
|
|
||||||
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
|
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
|
||||||
pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "text_encoder", safe_serialization=True)
|
pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
|
||||||
|
|
||||||
# sd-xl - tokenizer_2
|
# sd-xl - tokenizer_2
|
||||||
repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||||
_, model_name = repo_id.split("/")
|
_, model_name = repo_id.split('/')
|
||||||
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
|
pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
|
||||||
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
|
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
|
||||||
|
|
||||||
@@ -215,59 +216,56 @@ def download_conversion_models():
|
|||||||
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
|
pipeline.save_pretrained(target_dir / model_name, safe_serialization=True)
|
||||||
|
|
||||||
# VAE
|
# VAE
|
||||||
logger.info("Downloading stable diffusion VAE")
|
logger.info('Downloading stable diffusion VAE')
|
||||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", **kwargs)
|
vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
|
||||||
vae.save_pretrained(target_dir / "sd-vae-ft-mse", safe_serialization=True)
|
vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)
|
||||||
|
|
||||||
# safety checking
|
# safety checking
|
||||||
logger.info("Downloading safety checker")
|
logger.info('Downloading safety checker')
|
||||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
pipeline = AutoFeatureExtractor.from_pretrained(repo_id, **kwargs)
|
pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
|
||||||
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
|
||||||
|
|
||||||
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id, **kwargs)
|
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
|
||||||
pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True)
|
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_realesrgan():
|
def download_realesrgan():
|
||||||
logger.info("Installing ESRGAN Upscaling models...")
|
logger.info("Installing ESRGAN Upscaling models...")
|
||||||
URLs = [
|
URLs = [
|
||||||
dict(
|
dict(
|
||||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||||
dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||||
description="RealESRGAN_x4plus.pth",
|
description = "RealESRGAN_x4plus.pth",
|
||||||
),
|
),
|
||||||
dict(
|
dict(
|
||||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||||
dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
dest = "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||||
description="RealESRGAN_x4plus_anime_6B.pth",
|
description = "RealESRGAN_x4plus_anime_6B.pth",
|
||||||
),
|
),
|
||||||
dict(
|
dict(
|
||||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
description="ESRGAN_SRx4_DF2KOST_official.pth",
|
description = "ESRGAN_SRx4_DF2KOST_official.pth",
|
||||||
),
|
),
|
||||||
dict(
|
dict(
|
||||||
url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||||
dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
dest= "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
||||||
description="RealESRGAN_x2plus.pth",
|
description = "RealESRGAN_x2plus.pth",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
for model in URLs:
|
for model in URLs:
|
||||||
download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"])
|
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_support_models():
|
def download_support_models():
|
||||||
download_realesrgan()
|
download_realesrgan()
|
||||||
download_conversion_models()
|
download_conversion_models()
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def get_root(root: str = None) -> str:
|
def get_root(root: str = None) -> str:
|
||||||
if root:
|
if root:
|
||||||
@@ -277,7 +275,6 @@ def get_root(root: str = None) -> str:
|
|||||||
else:
|
else:
|
||||||
return str(config.root_path)
|
return str(config.root_path)
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||||
# for responsive resizing - disabled
|
# for responsive resizing - disabled
|
||||||
@@ -286,14 +283,14 @@ class editOptsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
def create(self):
|
def create(self):
|
||||||
program_opts = self.parentApp.program_opts
|
program_opts = self.parentApp.program_opts
|
||||||
old_opts = self.parentApp.invokeai_opts
|
old_opts = self.parentApp.invokeai_opts
|
||||||
first_time = not (config.root_path / "invokeai.yaml").exists()
|
first_time = not (config.root_path / 'invokeai.yaml').exists()
|
||||||
access_token = HfFolder.get_token()
|
access_token = HfFolder.get_token()
|
||||||
window_width, window_height = get_terminal_size()
|
window_width, window_height = get_terminal_size()
|
||||||
label = """Configure startup settings. You can come back and change these later.
|
label = """Configure startup settings. You can come back and change these later.
|
||||||
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
|
Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields.
|
||||||
Use cursor arrows to make a checkbox selection, and space to toggle.
|
Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||||
"""
|
"""
|
||||||
for i in textwrap.wrap(label, width=window_width - 6):
|
for i in textwrap.wrap(label,width=window_width-6):
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.FixedText,
|
npyscreen.FixedText,
|
||||||
value=i,
|
value=i,
|
||||||
@@ -303,7 +300,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
|||||||
|
|
||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens."""
|
label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens."""
|
||||||
for line in textwrap.wrap(label, width=window_width - 6):
|
for line in textwrap.wrap(label,width=window_width-6):
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.FixedText,
|
npyscreen.FixedText,
|
||||||
value=line,
|
value=line,
|
||||||
@@ -346,7 +343,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
|||||||
relx=50,
|
relx=50,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.nextrely -= 1
|
self.nextrely -=1
|
||||||
self.always_use_cpu = self.add_widget_intelligent(
|
self.always_use_cpu = self.add_widget_intelligent(
|
||||||
npyscreen.Checkbox,
|
npyscreen.Checkbox,
|
||||||
name="Force CPU to be used on GPU systems",
|
name="Force CPU to be used on GPU systems",
|
||||||
@@ -354,8 +351,10 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
|||||||
relx=80,
|
relx=80,
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
precision = old_opts.precision or ("float32" if program_opts.full_precision else "auto")
|
precision = old_opts.precision or (
|
||||||
self.nextrely += 1
|
"float32" if program_opts.full_precision else "auto"
|
||||||
|
)
|
||||||
|
self.nextrely +=1
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.TitleFixedText,
|
npyscreen.TitleFixedText,
|
||||||
name="Floating Point Precision",
|
name="Floating Point Precision",
|
||||||
@@ -364,10 +363,10 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
|||||||
color="CONTROL",
|
color="CONTROL",
|
||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.nextrely -= 1
|
self.nextrely -=1
|
||||||
self.precision = self.add_widget_intelligent(
|
self.precision = self.add_widget_intelligent(
|
||||||
SingleSelectColumns,
|
SingleSelectColumns,
|
||||||
columns=3,
|
columns = 3,
|
||||||
name="Precision",
|
name="Precision",
|
||||||
values=PRECISION_CHOICES,
|
values=PRECISION_CHOICES,
|
||||||
value=PRECISION_CHOICES.index(precision),
|
value=PRECISION_CHOICES.index(precision),
|
||||||
@@ -399,25 +398,25 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
|||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.autoimport_dirs = {}
|
self.autoimport_dirs = {}
|
||||||
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
|
self.autoimport_dirs['autoimport_dir'] = self.add_widget_intelligent(
|
||||||
FileBox,
|
FileBox,
|
||||||
name=f"Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
|
name=f'Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models',
|
||||||
value=str(config.root_path / config.autoimport_dir),
|
value=str(config.root_path / config.autoimport_dir),
|
||||||
select_dir=True,
|
select_dir=True,
|
||||||
must_exist=False,
|
must_exist=False,
|
||||||
use_two_lines=False,
|
use_two_lines=False,
|
||||||
labelColor="GOOD",
|
labelColor="GOOD",
|
||||||
begin_entry_at=32,
|
begin_entry_at=32,
|
||||||
max_height=3,
|
max_height = 3,
|
||||||
scroll_exit=True,
|
scroll_exit=True
|
||||||
)
|
)
|
||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ
|
label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ
|
||||||
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSES LOCATED AT
|
AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSES LOCATED AT
|
||||||
https://huggingface.co/spaces/CompVis/stable-diffusion-license and
|
https://huggingface.co/spaces/CompVis/stable-diffusion-license and
|
||||||
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md
|
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md
|
||||||
"""
|
"""
|
||||||
for i in textwrap.wrap(label, width=window_width - 6):
|
for i in textwrap.wrap(label,width=window_width-6):
|
||||||
self.add_widget_intelligent(
|
self.add_widget_intelligent(
|
||||||
npyscreen.FixedText,
|
npyscreen.FixedText,
|
||||||
value=i,
|
value=i,
|
||||||
@@ -432,7 +431,11 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
|||||||
scroll_exit=True,
|
scroll_exit=True,
|
||||||
)
|
)
|
||||||
self.nextrely += 1
|
self.nextrely += 1
|
||||||
label = "DONE" if program_opts.skip_sd_weights or program_opts.default_only else "NEXT"
|
label = (
|
||||||
|
"DONE"
|
||||||
|
if program_opts.skip_sd_weights or program_opts.default_only
|
||||||
|
else "NEXT"
|
||||||
|
)
|
||||||
self.ok_button = self.add_widget_intelligent(
|
self.ok_button = self.add_widget_intelligent(
|
||||||
CenteredButtonPress,
|
CenteredButtonPress,
|
||||||
name=label,
|
name=label,
|
||||||
@@ -455,7 +458,9 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
|||||||
def validate_field_values(self, opt: Namespace) -> bool:
|
def validate_field_values(self, opt: Namespace) -> bool:
|
||||||
bad_fields = []
|
bad_fields = []
|
||||||
if not opt.license_acceptance:
|
if not opt.license_acceptance:
|
||||||
bad_fields.append("Please accept the license terms before proceeding to model downloads")
|
bad_fields.append(
|
||||||
|
"Please accept the license terms before proceeding to model downloads"
|
||||||
|
)
|
||||||
if not Path(opt.outdir).parent.exists():
|
if not Path(opt.outdir).parent.exists():
|
||||||
bad_fields.append(
|
bad_fields.append(
|
||||||
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
|
f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory."
|
||||||
@@ -473,11 +478,11 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
|||||||
new_opts = Namespace()
|
new_opts = Namespace()
|
||||||
|
|
||||||
for attr in [
|
for attr in [
|
||||||
"outdir",
|
"outdir",
|
||||||
"free_gpu_mem",
|
"free_gpu_mem",
|
||||||
"max_cache_size",
|
"max_cache_size",
|
||||||
"xformers_enabled",
|
"xformers_enabled",
|
||||||
"always_use_cpu",
|
"always_use_cpu",
|
||||||
]:
|
]:
|
||||||
setattr(new_opts, attr, getattr(self, attr).value)
|
setattr(new_opts, attr, getattr(self, attr).value)
|
||||||
|
|
||||||
@@ -529,17 +534,16 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
|
|||||||
editApp.run()
|
editApp.run()
|
||||||
return editApp.new_opts()
|
return editApp.new_opts()
|
||||||
|
|
||||||
|
|
||||||
def default_startup_options(init_file: Path) -> Namespace:
|
def default_startup_options(init_file: Path) -> Namespace:
|
||||||
opts = InvokeAIAppConfig.get_config()
|
opts = InvokeAIAppConfig.get_config()
|
||||||
return opts
|
return opts
|
||||||
|
|
||||||
|
|
||||||
def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
installer = ModelInstall(config)
|
installer = ModelInstall(config)
|
||||||
except omegaconf.errors.ConfigKeyError:
|
except omegaconf.errors.ConfigKeyError:
|
||||||
logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing")
|
logger.warning('Your models.yaml file is corrupt or out of date. Reinitializing')
|
||||||
initialize_rootdir(config.root_path, True)
|
initialize_rootdir(config.root_path, True)
|
||||||
installer = ModelInstall(config)
|
installer = ModelInstall(config)
|
||||||
|
|
||||||
@@ -552,46 +556,55 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
|||||||
else list(),
|
else list(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||||
logger.info("Initializing InvokeAI runtime directory")
|
logger.info("Initializing InvokeAI runtime directory")
|
||||||
for name in ("models", "databases", "text-inversion-output", "text-inversion-training-data", "configs"):
|
for name in (
|
||||||
|
"models",
|
||||||
|
"databases",
|
||||||
|
"text-inversion-output",
|
||||||
|
"text-inversion-training-data",
|
||||||
|
"configs"
|
||||||
|
):
|
||||||
os.makedirs(os.path.join(root, name), exist_ok=True)
|
os.makedirs(os.path.join(root, name), exist_ok=True)
|
||||||
for model_type in ModelType:
|
for model_type in ModelType:
|
||||||
Path(root, "autoimport", model_type.value).mkdir(parents=True, exist_ok=True)
|
Path(root, 'autoimport', model_type.value).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
configs_src = Path(configs.__path__[0])
|
configs_src = Path(configs.__path__[0])
|
||||||
configs_dest = root / "configs"
|
configs_dest = root / "configs"
|
||||||
if not os.path.samefile(configs_src, configs_dest):
|
if not os.path.samefile(configs_src, configs_dest):
|
||||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||||
|
|
||||||
dest = root / "models"
|
dest = root / 'models'
|
||||||
for model_base in BaseModelType:
|
for model_base in BaseModelType:
|
||||||
for model_type in ModelType:
|
for model_type in ModelType:
|
||||||
path = dest / model_base.value / model_type.value
|
path = dest / model_base.value / model_type.value
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
path = dest / "core"
|
path = dest / 'core'
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
maybe_create_models_yaml(root)
|
maybe_create_models_yaml(root)
|
||||||
|
|
||||||
|
|
||||||
def maybe_create_models_yaml(root: Path):
|
def maybe_create_models_yaml(root: Path):
|
||||||
models_yaml = root / "configs" / "models.yaml"
|
models_yaml = root / 'configs' / 'models.yaml'
|
||||||
if models_yaml.exists():
|
if models_yaml.exists():
|
||||||
if OmegaConf.load(models_yaml).get("__metadata__"): # up to date
|
if OmegaConf.load(models_yaml).get('__metadata__'): # up to date
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
logger.info("Creating new models.yaml, original saved as models.yaml.orig")
|
logger.info('Creating new models.yaml, original saved as models.yaml.orig')
|
||||||
models_yaml.rename(models_yaml.parent / "models.yaml.orig")
|
models_yaml.rename(models_yaml.parent / 'models.yaml.orig')
|
||||||
|
|
||||||
with open(models_yaml, "w") as yaml_file:
|
|
||||||
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
|
||||||
|
|
||||||
|
with open(models_yaml,'w') as yaml_file:
|
||||||
|
yaml_file.write(yaml.dump({'__metadata__':
|
||||||
|
{'version':'3.0.0'}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
|
def run_console_ui(
|
||||||
|
program_opts: Namespace, initfile: Path = None
|
||||||
|
) -> (Namespace, Namespace):
|
||||||
# parse_args() will read from init file if present
|
# parse_args() will read from init file if present
|
||||||
invokeai_opts = default_startup_options(initfile)
|
invokeai_opts = default_startup_options(initfile)
|
||||||
invokeai_opts.root = program_opts.root
|
invokeai_opts.root = program_opts.root
|
||||||
@@ -603,7 +616,6 @@ def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace
|
|||||||
# the install-models application spawns a subprocess to install
|
# the install-models application spawns a subprocess to install
|
||||||
# models, and will crash unless this is set before running.
|
# models, and will crash unless this is set before running.
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
torch.multiprocessing.set_start_method("spawn")
|
torch.multiprocessing.set_start_method("spawn")
|
||||||
|
|
||||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||||
@@ -623,41 +635,38 @@ def write_opts(opts: Namespace, init_file: Path):
|
|||||||
new_config = InvokeAIAppConfig.get_config()
|
new_config = InvokeAIAppConfig.get_config()
|
||||||
new_config.root = config.root
|
new_config.root = config.root
|
||||||
|
|
||||||
for key, value in opts.__dict__.items():
|
for key,value in opts.__dict__.items():
|
||||||
if hasattr(new_config, key):
|
if hasattr(new_config,key):
|
||||||
setattr(new_config, key, value)
|
setattr(new_config,key,value)
|
||||||
|
|
||||||
with open(init_file, "w", encoding="utf-8") as file:
|
with open(init_file,'w', encoding='utf-8') as file:
|
||||||
file.write(new_config.to_yaml())
|
file.write(new_config.to_yaml())
|
||||||
|
|
||||||
if hasattr(opts, "hf_token") and opts.hf_token:
|
if hasattr(opts,'hf_token') and opts.hf_token:
|
||||||
HfLogin(opts.hf_token)
|
HfLogin(opts.hf_token)
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def default_output_dir() -> Path:
|
def default_output_dir() -> Path:
|
||||||
return config.root_path / "outputs"
|
return config.root_path / "outputs"
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def write_default_options(program_opts: Namespace, initfile: Path):
|
def write_default_options(program_opts: Namespace, initfile: Path):
|
||||||
opt = default_startup_options(initfile)
|
opt = default_startup_options(initfile)
|
||||||
write_opts(opt, initfile)
|
write_opts(opt, initfile)
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
# Here we bring in
|
# Here we bring in
|
||||||
# the legacy Args object in order to parse
|
# the legacy Args object in order to parse
|
||||||
# the old init file and write out the new
|
# the old init file and write out the new
|
||||||
# yaml format.
|
# yaml format.
|
||||||
def migrate_init_file(legacy_format: Path):
|
def migrate_init_file(legacy_format:Path):
|
||||||
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
|
old = legacy_parser.parse_args([f'@{str(legacy_format)}'])
|
||||||
new = InvokeAIAppConfig.get_config()
|
new = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
fields = list(get_type_hints(InvokeAIAppConfig).keys())
|
||||||
for attr in fields:
|
for attr in fields:
|
||||||
if hasattr(old, attr):
|
if hasattr(old,attr):
|
||||||
setattr(new, attr, getattr(old, attr))
|
setattr(new,attr,getattr(old,attr))
|
||||||
|
|
||||||
# a few places where the field names have changed and we have to
|
# a few places where the field names have changed and we have to
|
||||||
# manually add in the new names/values
|
# manually add in the new names/values
|
||||||
@@ -665,39 +674,36 @@ def migrate_init_file(legacy_format: Path):
|
|||||||
new.conf_path = old.conf
|
new.conf_path = old.conf
|
||||||
new.root = legacy_format.parent.resolve()
|
new.root = legacy_format.parent.resolve()
|
||||||
|
|
||||||
invokeai_yaml = legacy_format.parent / "invokeai.yaml"
|
invokeai_yaml = legacy_format.parent / 'invokeai.yaml'
|
||||||
with open(invokeai_yaml, "w", encoding="utf-8") as outfile:
|
with open(invokeai_yaml,"w", encoding="utf-8") as outfile:
|
||||||
outfile.write(new.to_yaml())
|
outfile.write(new.to_yaml())
|
||||||
|
|
||||||
legacy_format.replace(legacy_format.parent / "invokeai.init.orig")
|
legacy_format.replace(legacy_format.parent / 'invokeai.init.orig')
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def migrate_models(root: Path):
|
def migrate_models(root: Path):
|
||||||
from invokeai.backend.install.migrate_to_3 import do_migrate
|
from invokeai.backend.install.migrate_to_3 import do_migrate
|
||||||
|
|
||||||
do_migrate(root, root)
|
do_migrate(root, root)
|
||||||
|
|
||||||
|
def migrate_if_needed(opt: Namespace, root: Path)->bool:
|
||||||
def migrate_if_needed(opt: Namespace, root: Path) -> bool:
|
|
||||||
# We check for to see if the runtime directory is correctly initialized.
|
# We check for to see if the runtime directory is correctly initialized.
|
||||||
old_init_file = root / "invokeai.init"
|
old_init_file = root / 'invokeai.init'
|
||||||
new_init_file = root / "invokeai.yaml"
|
new_init_file = root / 'invokeai.yaml'
|
||||||
old_hub = root / "models/hub"
|
old_hub = root / 'models/hub'
|
||||||
migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists()
|
migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists()
|
||||||
|
|
||||||
if migration_needed:
|
if migration_needed:
|
||||||
if opt.yes_to_all or yes_or_no(
|
if opt.yes_to_all or \
|
||||||
f"{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?"
|
yes_or_no(f'{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?'):
|
||||||
):
|
|
||||||
logger.info("** Migrating invokeai.init to invokeai.yaml")
|
logger.info('** Migrating invokeai.init to invokeai.yaml')
|
||||||
migrate_init_file(old_init_file)
|
migrate_init_file(old_init_file)
|
||||||
config.parse_args(argv=[], conf=OmegaConf.load(new_init_file))
|
config.parse_args(argv=[],conf=OmegaConf.load(new_init_file))
|
||||||
|
|
||||||
if old_hub.exists():
|
if old_hub.exists():
|
||||||
migrate_models(config.root_path)
|
migrate_models(config.root_path)
|
||||||
else:
|
else:
|
||||||
print("Cannot continue without conversion. Aborting.")
|
print('Cannot continue without conversion. Aborting.')
|
||||||
|
|
||||||
return migration_needed
|
return migration_needed
|
||||||
|
|
||||||
@@ -758,9 +764,9 @@ def main():
|
|||||||
|
|
||||||
invoke_args = []
|
invoke_args = []
|
||||||
if opt.root:
|
if opt.root:
|
||||||
invoke_args.extend(["--root", opt.root])
|
invoke_args.extend(['--root',opt.root])
|
||||||
if opt.full_precision:
|
if opt.full_precision:
|
||||||
invoke_args.extend(["--precision", "float32"])
|
invoke_args.extend(['--precision','float32'])
|
||||||
config.parse_args(invoke_args)
|
config.parse_args(invoke_args)
|
||||||
logger = InvokeAILogger().getLogger(config=config)
|
logger = InvokeAILogger().getLogger(config=config)
|
||||||
|
|
||||||
@@ -776,16 +782,20 @@ def main():
|
|||||||
initialize_rootdir(config.root_path, opt.yes_to_all)
|
initialize_rootdir(config.root_path, opt.yes_to_all)
|
||||||
|
|
||||||
models_to_download = default_user_selections(opt)
|
models_to_download = default_user_selections(opt)
|
||||||
new_init_file = config.root_path / "invokeai.yaml"
|
new_init_file = config.root_path / 'invokeai.yaml'
|
||||||
if opt.yes_to_all:
|
if opt.yes_to_all:
|
||||||
write_default_options(opt, new_init_file)
|
write_default_options(opt, new_init_file)
|
||||||
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
|
init_options = Namespace(
|
||||||
|
precision="float32" if opt.full_precision else "float16"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
||||||
if init_options:
|
if init_options:
|
||||||
write_opts(init_options, new_init_file)
|
write_opts(init_options, new_init_file)
|
||||||
else:
|
else:
|
||||||
logger.info('\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n')
|
logger.info(
|
||||||
|
'\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n'
|
||||||
|
)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
if opt.skip_support_models:
|
if opt.skip_support_models:
|
||||||
@@ -801,7 +811,7 @@ def main():
|
|||||||
|
|
||||||
postscript(errors=errors)
|
postscript(errors=errors)
|
||||||
if not opt.yes_to_all:
|
if not opt.yes_to_all:
|
||||||
input("Press any key to continue...")
|
input('Press any key to continue...')
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\nGoodbye! Come back soon.")
|
print("\nGoodbye! Come back soon.")
|
||||||
|
|
||||||
|
|||||||
@@ -47,18 +47,17 @@ PRECISION_CHOICES = [
|
|||||||
"float16",
|
"float16",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class FileArgumentParser(ArgumentParser):
|
class FileArgumentParser(ArgumentParser):
|
||||||
"""
|
"""
|
||||||
Supports reading defaults from an init file.
|
Supports reading defaults from an init file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def convert_arg_line_to_args(self, arg_line):
|
def convert_arg_line_to_args(self, arg_line):
|
||||||
return shlex.split(arg_line, comments=True)
|
return shlex.split(arg_line, comments=True)
|
||||||
|
|
||||||
|
|
||||||
legacy_parser = FileArgumentParser(
|
legacy_parser = FileArgumentParser(
|
||||||
description="""
|
description=
|
||||||
|
"""
|
||||||
Generate images using Stable Diffusion.
|
Generate images using Stable Diffusion.
|
||||||
Use --web to launch the web interface.
|
Use --web to launch the web interface.
|
||||||
Use --from_file to load prompts from a file path or standard input ("-").
|
Use --from_file to load prompts from a file path or standard input ("-").
|
||||||
@@ -66,279 +65,304 @@ Generate images using Stable Diffusion.
|
|||||||
Other command-line arguments are defaults that can usually be overridden
|
Other command-line arguments are defaults that can usually be overridden
|
||||||
prompt the command prompt.
|
prompt the command prompt.
|
||||||
""",
|
""",
|
||||||
fromfile_prefix_chars="@",
|
fromfile_prefix_chars='@',
|
||||||
)
|
)
|
||||||
general_group = legacy_parser.add_argument_group("General")
|
general_group = legacy_parser.add_argument_group('General')
|
||||||
model_group = legacy_parser.add_argument_group("Model selection")
|
model_group = legacy_parser.add_argument_group('Model selection')
|
||||||
file_group = legacy_parser.add_argument_group("Input/output")
|
file_group = legacy_parser.add_argument_group('Input/output')
|
||||||
web_server_group = legacy_parser.add_argument_group("Web server")
|
web_server_group = legacy_parser.add_argument_group('Web server')
|
||||||
render_group = legacy_parser.add_argument_group("Rendering")
|
render_group = legacy_parser.add_argument_group('Rendering')
|
||||||
postprocessing_group = legacy_parser.add_argument_group("Postprocessing")
|
postprocessing_group = legacy_parser.add_argument_group('Postprocessing')
|
||||||
deprecated_group = legacy_parser.add_argument_group("Deprecated options")
|
deprecated_group = legacy_parser.add_argument_group('Deprecated options')
|
||||||
|
|
||||||
deprecated_group.add_argument("--laion400m")
|
deprecated_group.add_argument('--laion400m')
|
||||||
deprecated_group.add_argument("--weights") # deprecated
|
deprecated_group.add_argument('--weights') # deprecated
|
||||||
general_group.add_argument("--version", "-V", action="store_true", help="Print InvokeAI version number")
|
general_group.add_argument(
|
||||||
|
'--version','-V',
|
||||||
|
action='store_true',
|
||||||
|
help='Print InvokeAI version number'
|
||||||
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--root_dir",
|
'--root_dir',
|
||||||
default=None,
|
default=None,
|
||||||
help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
|
help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--config",
|
'--config',
|
||||||
"-c",
|
'-c',
|
||||||
"-config",
|
'-config',
|
||||||
dest="conf",
|
dest='conf',
|
||||||
default="./configs/models.yaml",
|
default='./configs/models.yaml',
|
||||||
help="Path to configuration file for alternate models.",
|
help='Path to configuration file for alternate models.',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--model",
|
'--model',
|
||||||
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
|
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--weight_dirs",
|
'--weight_dirs',
|
||||||
nargs="+",
|
nargs='+',
|
||||||
type=str,
|
type=str,
|
||||||
help="List of one or more directories that will be auto-scanned for new model weights to import",
|
help='List of one or more directories that will be auto-scanned for new model weights to import',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--png_compression",
|
'--png_compression','-z',
|
||||||
"-z",
|
|
||||||
type=int,
|
type=int,
|
||||||
default=6,
|
default=6,
|
||||||
choices=range(0, 9),
|
choices=range(0,9),
|
||||||
dest="png_compression",
|
dest='png_compression',
|
||||||
help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.",
|
help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.'
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"-F",
|
'-F',
|
||||||
"--full_precision",
|
'--full_precision',
|
||||||
dest="full_precision",
|
dest='full_precision',
|
||||||
action="store_true",
|
action='store_true',
|
||||||
help="Deprecated way to set --precision=float32",
|
help='Deprecated way to set --precision=float32',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--max_loaded_models",
|
'--max_loaded_models',
|
||||||
dest="max_loaded_models",
|
dest='max_loaded_models',
|
||||||
type=int,
|
type=int,
|
||||||
default=2,
|
default=2,
|
||||||
help="Maximum number of models to keep in memory for fast switching, including the one in GPU",
|
help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--free_gpu_mem",
|
'--free_gpu_mem',
|
||||||
dest="free_gpu_mem",
|
dest='free_gpu_mem',
|
||||||
action="store_true",
|
action='store_true',
|
||||||
help="Force free gpu memory before final decoding",
|
help='Force free gpu memory before final decoding',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--sequential_guidance",
|
'--sequential_guidance',
|
||||||
dest="sequential_guidance",
|
dest='sequential_guidance',
|
||||||
action="store_true",
|
action='store_true',
|
||||||
help="Calculate guidance in serial instead of in parallel, lowering memory requirement " "at the expense of speed",
|
help="Calculate guidance in serial instead of in parallel, lowering memory requirement "
|
||||||
|
"at the expense of speed",
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--xformers",
|
'--xformers',
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
default=True,
|
default=True,
|
||||||
help="Enable/disable xformers support (default enabled if installed)",
|
help='Enable/disable xformers support (default enabled if installed)',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--always_use_cpu", dest="always_use_cpu", action="store_true", help="Force use of CPU even if GPU is available"
|
"--always_use_cpu",
|
||||||
|
dest="always_use_cpu",
|
||||||
|
action="store_true",
|
||||||
|
help="Force use of CPU even if GPU is available"
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--precision",
|
'--precision',
|
||||||
dest="precision",
|
dest='precision',
|
||||||
type=str,
|
type=str,
|
||||||
choices=PRECISION_CHOICES,
|
choices=PRECISION_CHOICES,
|
||||||
metavar="PRECISION",
|
metavar='PRECISION',
|
||||||
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
||||||
default="auto",
|
default='auto',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--ckpt_convert",
|
'--ckpt_convert',
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
dest="ckpt_convert",
|
dest='ckpt_convert',
|
||||||
default=True,
|
default=True,
|
||||||
help="Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.",
|
help='Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.'
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--internet",
|
'--internet',
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
dest="internet_available",
|
dest='internet_available',
|
||||||
default=True,
|
default=True,
|
||||||
help="Indicate whether internet is available for just-in-time model downloading (default: probe automatically).",
|
help='Indicate whether internet is available for just-in-time model downloading (default: probe automatically).',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--nsfw_checker",
|
'--nsfw_checker',
|
||||||
"--safety_checker",
|
'--safety_checker',
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
dest="safety_checker",
|
dest='safety_checker',
|
||||||
default=False,
|
default=False,
|
||||||
help="Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.",
|
help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--autoimport",
|
'--autoimport',
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
|
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--autoconvert",
|
'--autoconvert',
|
||||||
default=None,
|
default=None,
|
||||||
type=str,
|
type=str,
|
||||||
help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models",
|
help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models',
|
||||||
)
|
)
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
"--patchmatch",
|
'--patchmatch',
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
default=True,
|
default=True,
|
||||||
help="Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.",
|
help='Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.',
|
||||||
)
|
)
|
||||||
file_group.add_argument(
|
file_group.add_argument(
|
||||||
"--from_file",
|
'--from_file',
|
||||||
dest="infile",
|
dest='infile',
|
||||||
type=str,
|
type=str,
|
||||||
help="If specified, load prompts from this file",
|
help='If specified, load prompts from this file',
|
||||||
)
|
)
|
||||||
file_group.add_argument(
|
file_group.add_argument(
|
||||||
"--outdir",
|
'--outdir',
|
||||||
"-o",
|
'-o',
|
||||||
type=str,
|
type=str,
|
||||||
help="Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs",
|
help='Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs',
|
||||||
default="outputs",
|
default='outputs',
|
||||||
)
|
)
|
||||||
file_group.add_argument(
|
file_group.add_argument(
|
||||||
"--prompt_as_dir",
|
'--prompt_as_dir',
|
||||||
"-p",
|
'-p',
|
||||||
action="store_true",
|
action='store_true',
|
||||||
help="Place images in subdirectories named after the prompt.",
|
help='Place images in subdirectories named after the prompt.',
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
"--fnformat",
|
'--fnformat',
|
||||||
default="{prefix}.{seed}.png",
|
default='{prefix}.{seed}.png',
|
||||||
type=str,
|
type=str,
|
||||||
help="Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png",
|
help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
|
||||||
)
|
)
|
||||||
render_group.add_argument("-s", "--steps", type=int, default=50, help="Number of steps")
|
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
"-W",
|
'-s',
|
||||||
"--width",
|
'--steps',
|
||||||
type=int,
|
type=int,
|
||||||
help="Image width, multiple of 64",
|
default=50,
|
||||||
|
help='Number of steps'
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
"-H",
|
'-W',
|
||||||
"--height",
|
'--width',
|
||||||
type=int,
|
type=int,
|
||||||
help="Image height, multiple of 64",
|
help='Image width, multiple of 64',
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
"-C",
|
'-H',
|
||||||
"--cfg_scale",
|
'--height',
|
||||||
|
type=int,
|
||||||
|
help='Image height, multiple of 64',
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'-C',
|
||||||
|
'--cfg_scale',
|
||||||
default=7.5,
|
default=7.5,
|
||||||
type=float,
|
type=float,
|
||||||
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
|
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
"--sampler",
|
'--sampler',
|
||||||
"-A",
|
'-A',
|
||||||
"-m",
|
'-m',
|
||||||
dest="sampler_name",
|
dest='sampler_name',
|
||||||
type=str,
|
type=str,
|
||||||
choices=SAMPLER_CHOICES,
|
choices=SAMPLER_CHOICES,
|
||||||
metavar="SAMPLER_NAME",
|
metavar='SAMPLER_NAME',
|
||||||
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
|
||||||
default="k_lms",
|
default='k_lms',
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
"--log_tokenization", "-t", action="store_true", help="shows how the prompt is split into tokens"
|
'--log_tokenization',
|
||||||
|
'-t',
|
||||||
|
action='store_true',
|
||||||
|
help='shows how the prompt is split into tokens'
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
"-f",
|
'-f',
|
||||||
"--strength",
|
'--strength',
|
||||||
type=float,
|
type=float,
|
||||||
help="img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely",
|
help='img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
"-T",
|
'-T',
|
||||||
"-fit",
|
'-fit',
|
||||||
"--fit",
|
'--fit',
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
help="If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)",
|
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
|
||||||
)
|
)
|
||||||
|
|
||||||
render_group.add_argument("--grid", "-g", action=argparse.BooleanOptionalAction, help="generate a grid")
|
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
"--embedding_directory",
|
'--grid',
|
||||||
"--embedding_path",
|
'-g',
|
||||||
dest="embedding_path",
|
action=argparse.BooleanOptionalAction,
|
||||||
default="embeddings",
|
help='generate a grid'
|
||||||
type=str,
|
|
||||||
help="Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)",
|
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
"--lora_directory",
|
'--embedding_directory',
|
||||||
dest="lora_path",
|
'--embedding_path',
|
||||||
default="loras",
|
dest='embedding_path',
|
||||||
|
default='embeddings',
|
||||||
type=str,
|
type=str,
|
||||||
help="Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)",
|
help='Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)'
|
||||||
)
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
"--embeddings",
|
'--lora_directory',
|
||||||
|
dest='lora_path',
|
||||||
|
default='loras',
|
||||||
|
type=str,
|
||||||
|
help='Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)'
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'--embeddings',
|
||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
default=True,
|
default=True,
|
||||||
help="Enable embedding directory (default). Use --no-embeddings to disable.",
|
help='Enable embedding directory (default). Use --no-embeddings to disable.',
|
||||||
)
|
)
|
||||||
render_group.add_argument("--enable_image_debugging", action="store_true", help="Generates debugging image to display")
|
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
"--karras_max",
|
'--enable_image_debugging',
|
||||||
|
action='store_true',
|
||||||
|
help='Generates debugging image to display'
|
||||||
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'--karras_max',
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29].",
|
help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]."
|
||||||
)
|
)
|
||||||
# Restoration related args
|
# Restoration related args
|
||||||
postprocessing_group.add_argument(
|
postprocessing_group.add_argument(
|
||||||
"--no_restore",
|
'--no_restore',
|
||||||
dest="restore",
|
dest='restore',
|
||||||
action="store_false",
|
action='store_false',
|
||||||
help="Disable face restoration with GFPGAN or codeformer",
|
help='Disable face restoration with GFPGAN or codeformer',
|
||||||
)
|
)
|
||||||
postprocessing_group.add_argument(
|
postprocessing_group.add_argument(
|
||||||
"--no_upscale",
|
'--no_upscale',
|
||||||
dest="esrgan",
|
dest='esrgan',
|
||||||
action="store_false",
|
action='store_false',
|
||||||
help="Disable upscaling with ESRGAN",
|
help='Disable upscaling with ESRGAN',
|
||||||
)
|
)
|
||||||
postprocessing_group.add_argument(
|
postprocessing_group.add_argument(
|
||||||
"--esrgan_bg_tile",
|
'--esrgan_bg_tile',
|
||||||
type=int,
|
type=int,
|
||||||
default=400,
|
default=400,
|
||||||
help="Tile size for background sampler, 0 for no tile during testing. Default: 400.",
|
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
|
||||||
)
|
)
|
||||||
postprocessing_group.add_argument(
|
postprocessing_group.add_argument(
|
||||||
"--esrgan_denoise_str",
|
'--esrgan_denoise_str',
|
||||||
type=float,
|
type=float,
|
||||||
default=0.75,
|
default=0.75,
|
||||||
help="esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75",
|
help='esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75',
|
||||||
)
|
)
|
||||||
postprocessing_group.add_argument(
|
postprocessing_group.add_argument(
|
||||||
"--gfpgan_model_path",
|
'--gfpgan_model_path',
|
||||||
type=str,
|
type=str,
|
||||||
default="./models/gfpgan/GFPGANv1.4.pth",
|
default='./models/gfpgan/GFPGANv1.4.pth',
|
||||||
help="Indicates the path to the GFPGAN model",
|
help='Indicates the path to the GFPGAN model',
|
||||||
)
|
)
|
||||||
web_server_group.add_argument(
|
web_server_group.add_argument(
|
||||||
"--web",
|
'--web',
|
||||||
dest="web",
|
dest='web',
|
||||||
action="store_true",
|
action='store_true',
|
||||||
help="Start in web server mode.",
|
help='Start in web server mode.',
|
||||||
)
|
)
|
||||||
web_server_group.add_argument(
|
web_server_group.add_argument(
|
||||||
"--web_develop",
|
'--web_develop',
|
||||||
dest="web_develop",
|
dest='web_develop',
|
||||||
action="store_true",
|
action='store_true',
|
||||||
help="Start in web server development mode.",
|
help='Start in web server development mode.',
|
||||||
)
|
)
|
||||||
web_server_group.add_argument(
|
web_server_group.add_argument(
|
||||||
"--web_verbose",
|
"--web_verbose",
|
||||||
@@ -352,27 +376,32 @@ web_server_group.add_argument(
|
|||||||
help="Additional allowed origins, comma-separated",
|
help="Additional allowed origins, comma-separated",
|
||||||
)
|
)
|
||||||
web_server_group.add_argument(
|
web_server_group.add_argument(
|
||||||
"--host",
|
'--host',
|
||||||
type=str,
|
type=str,
|
||||||
default="127.0.0.1",
|
default='127.0.0.1',
|
||||||
help="Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.",
|
help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
|
||||||
)
|
)
|
||||||
web_server_group.add_argument("--port", type=int, default="9090", help="Web server: Port to listen on")
|
|
||||||
web_server_group.add_argument(
|
web_server_group.add_argument(
|
||||||
"--certfile",
|
'--port',
|
||||||
|
type=int,
|
||||||
|
default='9090',
|
||||||
|
help='Web server: Port to listen on'
|
||||||
|
)
|
||||||
|
web_server_group.add_argument(
|
||||||
|
'--certfile',
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Web server: Path to certificate file to use for SSL. Use together with --keyfile",
|
help='Web server: Path to certificate file to use for SSL. Use together with --keyfile'
|
||||||
)
|
)
|
||||||
web_server_group.add_argument(
|
web_server_group.add_argument(
|
||||||
"--keyfile",
|
'--keyfile',
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="Web server: Path to private key file to use for SSL. Use together with --certfile",
|
help='Web server: Path to private key file to use for SSL. Use together with --certfile'
|
||||||
)
|
)
|
||||||
web_server_group.add_argument(
|
web_server_group.add_argument(
|
||||||
"--gui",
|
'--gui',
|
||||||
dest="gui",
|
dest='gui',
|
||||||
action="store_true",
|
action='store_true',
|
||||||
help="Start InvokeAI GUI",
|
help='Start InvokeAI GUI',
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""
|
'''
|
||||||
Migrate the models directory and models.yaml file from an existing
|
Migrate the models directory and models.yaml file from an existing
|
||||||
InvokeAI 2.3 installation to 3.0.0.
|
InvokeAI 2.3 installation to 3.0.0.
|
||||||
"""
|
'''
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
@@ -29,13 +29,14 @@ from transformers import (
|
|||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.model_management import ModelManager
|
from invokeai.backend.model_management import ModelManager
|
||||||
from invokeai.backend.model_management.model_probe import ModelProbe, ModelType, BaseModelType, ModelProbeInfo
|
from invokeai.backend.model_management.model_probe import (
|
||||||
|
ModelProbe, ModelType, BaseModelType, ModelProbeInfo
|
||||||
|
)
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
diffusers.logging.set_verbosity_error()
|
diffusers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
|
||||||
# holder for paths that we will migrate
|
# holder for paths that we will migrate
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelPaths:
|
class ModelPaths:
|
||||||
@@ -44,15 +45,13 @@ class ModelPaths:
|
|||||||
loras: Path
|
loras: Path
|
||||||
controlnets: Path
|
controlnets: Path
|
||||||
|
|
||||||
|
|
||||||
class MigrateTo3(object):
|
class MigrateTo3(object):
|
||||||
def __init__(
|
def __init__(self,
|
||||||
self,
|
from_root: Path,
|
||||||
from_root: Path,
|
to_models: Path,
|
||||||
to_models: Path,
|
model_manager: ModelManager,
|
||||||
model_manager: ModelManager,
|
src_paths: ModelPaths,
|
||||||
src_paths: ModelPaths,
|
):
|
||||||
):
|
|
||||||
self.root_directory = from_root
|
self.root_directory = from_root
|
||||||
self.dest_models = to_models
|
self.dest_models = to_models
|
||||||
self.mgr = model_manager
|
self.mgr = model_manager
|
||||||
@@ -60,66 +59,67 @@ class MigrateTo3(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize_yaml(cls, yaml_file: Path):
|
def initialize_yaml(cls, yaml_file: Path):
|
||||||
with open(yaml_file, "w") as file:
|
with open(yaml_file, 'w') as file:
|
||||||
file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
file.write(
|
||||||
|
yaml.dump(
|
||||||
|
{
|
||||||
|
'__metadata__': {'version':'3.0.0'}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def create_directory_structure(self):
|
def create_directory_structure(self):
|
||||||
"""
|
'''
|
||||||
Create the basic directory structure for the models folder.
|
Create the basic directory structure for the models folder.
|
||||||
"""
|
'''
|
||||||
for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:
|
for model_base in [BaseModelType.StableDiffusion1,BaseModelType.StableDiffusion2]:
|
||||||
for model_type in [
|
for model_type in [ModelType.Main, ModelType.Vae, ModelType.Lora,
|
||||||
ModelType.Main,
|
ModelType.ControlNet,ModelType.TextualInversion]:
|
||||||
ModelType.Vae,
|
|
||||||
ModelType.Lora,
|
|
||||||
ModelType.ControlNet,
|
|
||||||
ModelType.TextualInversion,
|
|
||||||
]:
|
|
||||||
path = self.dest_models / model_base.value / model_type.value
|
path = self.dest_models / model_base.value / model_type.value
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
path = self.dest_models / "core"
|
path = self.dest_models / 'core'
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def copy_file(src: Path, dest: Path):
|
def copy_file(src:Path,dest:Path):
|
||||||
"""
|
'''
|
||||||
copy a single file with logging
|
copy a single file with logging
|
||||||
"""
|
'''
|
||||||
if dest.exists():
|
if dest.exists():
|
||||||
logger.info(f"Skipping existing {str(dest)}")
|
logger.info(f'Skipping existing {str(dest)}')
|
||||||
return
|
return
|
||||||
logger.info(f"Copying {str(src)} to {str(dest)}")
|
logger.info(f'Copying {str(src)} to {str(dest)}')
|
||||||
try:
|
try:
|
||||||
shutil.copy(src, dest)
|
shutil.copy(src, dest)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"COPY FAILED: {str(e)}")
|
logger.error(f'COPY FAILED: {str(e)}')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def copy_dir(src: Path, dest: Path):
|
def copy_dir(src:Path,dest:Path):
|
||||||
"""
|
'''
|
||||||
Recursively copy a directory with logging
|
Recursively copy a directory with logging
|
||||||
"""
|
'''
|
||||||
if dest.exists():
|
if dest.exists():
|
||||||
logger.info(f"Skipping existing {str(dest)}")
|
logger.info(f'Skipping existing {str(dest)}')
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"Copying {str(src)} to {str(dest)}")
|
logger.info(f'Copying {str(src)} to {str(dest)}')
|
||||||
try:
|
try:
|
||||||
shutil.copytree(src, dest)
|
shutil.copytree(src, dest)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"COPY FAILED: {str(e)}")
|
logger.error(f'COPY FAILED: {str(e)}')
|
||||||
|
|
||||||
def migrate_models(self, src_dir: Path):
|
def migrate_models(self, src_dir: Path):
|
||||||
"""
|
'''
|
||||||
Recursively walk through src directory, probe anything
|
Recursively walk through src directory, probe anything
|
||||||
that looks like a model, and copy the model into the
|
that looks like a model, and copy the model into the
|
||||||
appropriate location within the destination models directory.
|
appropriate location within the destination models directory.
|
||||||
"""
|
'''
|
||||||
directories_scanned = set()
|
directories_scanned = set()
|
||||||
for root, dirs, files in os.walk(src_dir):
|
for root, dirs, files in os.walk(src_dir):
|
||||||
for d in dirs:
|
for d in dirs:
|
||||||
try:
|
try:
|
||||||
model = Path(root, d)
|
model = Path(root,d)
|
||||||
info = ModelProbe().heuristic_probe(model)
|
info = ModelProbe().heuristic_probe(model)
|
||||||
if not info:
|
if not info:
|
||||||
continue
|
continue
|
||||||
@@ -136,9 +136,9 @@ class MigrateTo3(object):
|
|||||||
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
|
# don't copy raw learned_embeds.bin or pytorch_lora_weights.bin
|
||||||
# let them be copied as part of a tree copy operation
|
# let them be copied as part of a tree copy operation
|
||||||
try:
|
try:
|
||||||
if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}:
|
if f in {'learned_embeds.bin','pytorch_lora_weights.bin'}:
|
||||||
continue
|
continue
|
||||||
model = Path(root, f)
|
model = Path(root,f)
|
||||||
if model.parent in directories_scanned:
|
if model.parent in directories_scanned:
|
||||||
continue
|
continue
|
||||||
info = ModelProbe().heuristic_probe(model)
|
info = ModelProbe().heuristic_probe(model)
|
||||||
@@ -154,146 +154,148 @@ class MigrateTo3(object):
|
|||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
|
|
||||||
def migrate_support_models(self):
|
def migrate_support_models(self):
|
||||||
"""
|
'''
|
||||||
Copy the clipseg, upscaler, and restoration models to their new
|
Copy the clipseg, upscaler, and restoration models to their new
|
||||||
locations.
|
locations.
|
||||||
"""
|
'''
|
||||||
dest_directory = self.dest_models
|
dest_directory = self.dest_models
|
||||||
if (self.root_directory / "models/clipseg").exists():
|
if (self.root_directory / 'models/clipseg').exists():
|
||||||
self.copy_dir(self.root_directory / "models/clipseg", dest_directory / "core/misc/clipseg")
|
self.copy_dir(self.root_directory / 'models/clipseg', dest_directory / 'core/misc/clipseg')
|
||||||
if (self.root_directory / "models/realesrgan").exists():
|
if (self.root_directory / 'models/realesrgan').exists():
|
||||||
self.copy_dir(self.root_directory / "models/realesrgan", dest_directory / "core/upscaling/realesrgan")
|
self.copy_dir(self.root_directory / 'models/realesrgan', dest_directory / 'core/upscaling/realesrgan')
|
||||||
for d in ["codeformer", "gfpgan"]:
|
for d in ['codeformer','gfpgan']:
|
||||||
path = self.root_directory / "models" / d
|
path = self.root_directory / 'models' / d
|
||||||
if path.exists():
|
if path.exists():
|
||||||
self.copy_dir(path, dest_directory / f"core/face_restoration/{d}")
|
self.copy_dir(path,dest_directory / f'core/face_restoration/{d}')
|
||||||
|
|
||||||
def migrate_tuning_models(self):
|
def migrate_tuning_models(self):
|
||||||
"""
|
'''
|
||||||
Migrate the embeddings, loras and controlnets directories to their new homes.
|
Migrate the embeddings, loras and controlnets directories to their new homes.
|
||||||
"""
|
'''
|
||||||
for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]:
|
for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]:
|
||||||
if not src:
|
if not src:
|
||||||
continue
|
continue
|
||||||
if src.is_dir():
|
if src.is_dir():
|
||||||
logger.info(f"Scanning {src}")
|
logger.info(f'Scanning {src}')
|
||||||
self.migrate_models(src)
|
self.migrate_models(src)
|
||||||
else:
|
else:
|
||||||
logger.info(f"{src} directory not found; skipping")
|
logger.info(f'{src} directory not found; skipping')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
def migrate_conversion_models(self):
|
def migrate_conversion_models(self):
|
||||||
"""
|
'''
|
||||||
Migrate all the models that are needed by the ckpt_to_diffusers conversion
|
Migrate all the models that are needed by the ckpt_to_diffusers conversion
|
||||||
script.
|
script.
|
||||||
"""
|
'''
|
||||||
|
|
||||||
dest_directory = self.dest_models
|
dest_directory = self.dest_models
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
cache_dir=self.root_directory / "models/hub",
|
cache_dir = self.root_directory / 'models/hub',
|
||||||
# local_files_only = True
|
#local_files_only = True
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
logger.info("Migrating core tokenizers and text encoders")
|
logger.info('Migrating core tokenizers and text encoders')
|
||||||
target_dir = dest_directory / "core" / "convert"
|
target_dir = dest_directory / 'core' / 'convert'
|
||||||
|
|
||||||
self._migrate_pretrained(
|
self._migrate_pretrained(BertTokenizerFast,
|
||||||
BertTokenizerFast, repo_id="bert-base-uncased", dest=target_dir / "bert-base-uncased", **kwargs
|
repo_id='bert-base-uncased',
|
||||||
)
|
dest = target_dir / 'bert-base-uncased',
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
# sd-1
|
# sd-1
|
||||||
repo_id = "openai/clip-vit-large-patch14"
|
repo_id = 'openai/clip-vit-large-patch14'
|
||||||
self._migrate_pretrained(
|
self._migrate_pretrained(CLIPTokenizer,
|
||||||
CLIPTokenizer, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", **kwargs
|
repo_id= repo_id,
|
||||||
)
|
dest= target_dir / 'clip-vit-large-patch14',
|
||||||
self._migrate_pretrained(
|
**kwargs)
|
||||||
CLIPTextModel, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", force=True, **kwargs
|
self._migrate_pretrained(CLIPTextModel,
|
||||||
)
|
repo_id = repo_id,
|
||||||
|
dest = target_dir / 'clip-vit-large-patch14',
|
||||||
|
force = True,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
# sd-2
|
# sd-2
|
||||||
repo_id = "stabilityai/stable-diffusion-2"
|
repo_id = "stabilityai/stable-diffusion-2"
|
||||||
self._migrate_pretrained(
|
self._migrate_pretrained(CLIPTokenizer,
|
||||||
CLIPTokenizer,
|
repo_id = repo_id,
|
||||||
repo_id=repo_id,
|
dest = target_dir / 'stable-diffusion-2-clip' / 'tokenizer',
|
||||||
dest=target_dir / "stable-diffusion-2-clip" / "tokenizer",
|
**{'subfolder':'tokenizer',**kwargs}
|
||||||
**{"subfolder": "tokenizer", **kwargs},
|
)
|
||||||
)
|
self._migrate_pretrained(CLIPTextModel,
|
||||||
self._migrate_pretrained(
|
repo_id = repo_id,
|
||||||
CLIPTextModel,
|
dest = target_dir / 'stable-diffusion-2-clip' / 'text_encoder',
|
||||||
repo_id=repo_id,
|
**{'subfolder':'text_encoder',**kwargs}
|
||||||
dest=target_dir / "stable-diffusion-2-clip" / "text_encoder",
|
)
|
||||||
**{"subfolder": "text_encoder", **kwargs},
|
|
||||||
)
|
|
||||||
|
|
||||||
# VAE
|
# VAE
|
||||||
logger.info("Migrating stable diffusion VAE")
|
logger.info('Migrating stable diffusion VAE')
|
||||||
self._migrate_pretrained(
|
self._migrate_pretrained(AutoencoderKL,
|
||||||
AutoencoderKL, repo_id="stabilityai/sd-vae-ft-mse", dest=target_dir / "sd-vae-ft-mse", **kwargs
|
repo_id = 'stabilityai/sd-vae-ft-mse',
|
||||||
)
|
dest = target_dir / 'sd-vae-ft-mse',
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
# safety checking
|
# safety checking
|
||||||
logger.info("Migrating safety checker")
|
logger.info('Migrating safety checker')
|
||||||
repo_id = "CompVis/stable-diffusion-safety-checker"
|
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
self._migrate_pretrained(
|
self._migrate_pretrained(AutoFeatureExtractor,
|
||||||
AutoFeatureExtractor, repo_id=repo_id, dest=target_dir / "stable-diffusion-safety-checker", **kwargs
|
repo_id = repo_id,
|
||||||
)
|
dest = target_dir / 'stable-diffusion-safety-checker',
|
||||||
self._migrate_pretrained(
|
**kwargs)
|
||||||
StableDiffusionSafetyChecker,
|
self._migrate_pretrained(StableDiffusionSafetyChecker,
|
||||||
repo_id=repo_id,
|
repo_id = repo_id,
|
||||||
dest=target_dir / "stable-diffusion-safety-checker",
|
dest = target_dir / 'stable-diffusion-safety-checker',
|
||||||
**kwargs,
|
**kwargs)
|
||||||
)
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
|
|
||||||
def _model_probe_to_path(self, info: ModelProbeInfo) -> Path:
|
def _model_probe_to_path(self, info: ModelProbeInfo)->Path:
|
||||||
return Path(self.dest_models, info.base_type.value, info.model_type.value)
|
return Path(self.dest_models, info.base_type.value, info.model_type.value)
|
||||||
|
|
||||||
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force: bool = False, **kwargs):
|
def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force:bool=False, **kwargs):
|
||||||
if dest.exists() and not force:
|
if dest.exists() and not force:
|
||||||
logger.info(f"Skipping existing {dest}")
|
logger.info(f'Skipping existing {dest}')
|
||||||
return
|
return
|
||||||
model = model_class.from_pretrained(repo_id, **kwargs)
|
model = model_class.from_pretrained(repo_id, **kwargs)
|
||||||
self._save_pretrained(model, dest, overwrite=force)
|
self._save_pretrained(model, dest, overwrite=force)
|
||||||
|
|
||||||
def _save_pretrained(self, model, dest: Path, overwrite: bool = False):
|
def _save_pretrained(self, model, dest: Path, overwrite: bool=False):
|
||||||
model_name = dest.name
|
model_name = dest.name
|
||||||
if overwrite:
|
if overwrite:
|
||||||
model.save_pretrained(dest, safe_serialization=True)
|
model.save_pretrained(dest, safe_serialization=True)
|
||||||
else:
|
else:
|
||||||
download_path = dest.with_name(f"{model_name}.downloading")
|
download_path = dest.with_name(f'{model_name}.downloading')
|
||||||
model.save_pretrained(download_path, safe_serialization=True)
|
model.save_pretrained(download_path, safe_serialization=True)
|
||||||
download_path.replace(dest)
|
download_path.replace(dest)
|
||||||
|
|
||||||
def _download_vae(self, repo_id: str, subfolder: str = None) -> Path:
|
def _download_vae(self, repo_id: str, subfolder:str=None)->Path:
|
||||||
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder)
|
vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder)
|
||||||
info = ModelProbe().heuristic_probe(vae)
|
info = ModelProbe().heuristic_probe(vae)
|
||||||
_, model_name = repo_id.split("/")
|
_, model_name = repo_id.split('/')
|
||||||
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
|
dest = self._model_probe_to_path(info) / self.unique_name(model_name, info)
|
||||||
vae.save_pretrained(dest, safe_serialization=True)
|
vae.save_pretrained(dest, safe_serialization=True)
|
||||||
return dest
|
return dest
|
||||||
|
|
||||||
def _vae_path(self, vae: Union[str, dict]) -> Path:
|
def _vae_path(self, vae: Union[str,dict])->Path:
|
||||||
"""
|
'''
|
||||||
Convert 2.3 VAE stanza to a straight path.
|
Convert 2.3 VAE stanza to a straight path.
|
||||||
"""
|
'''
|
||||||
vae_path = None
|
vae_path = None
|
||||||
|
|
||||||
# First get a path
|
# First get a path
|
||||||
if isinstance(vae, str):
|
if isinstance(vae,str):
|
||||||
vae_path = vae
|
vae_path = vae
|
||||||
|
|
||||||
elif isinstance(vae, DictConfig):
|
elif isinstance(vae,DictConfig):
|
||||||
if p := vae.get("path"):
|
if p := vae.get('path'):
|
||||||
vae_path = p
|
vae_path = p
|
||||||
elif repo_id := vae.get("repo_id"):
|
elif repo_id := vae.get('repo_id'):
|
||||||
if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded
|
if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded
|
||||||
vae_path = "models/core/convert/sd-vae-ft-mse"
|
vae_path = 'models/core/convert/sd-vae-ft-mse'
|
||||||
return vae_path
|
return vae_path
|
||||||
else:
|
else:
|
||||||
vae_path = self._download_vae(repo_id, vae.get("subfolder"))
|
vae_path = self._download_vae(repo_id, vae.get('subfolder'))
|
||||||
|
|
||||||
assert vae_path is not None, "Couldn't find VAE for this model"
|
assert vae_path is not None, "Couldn't find VAE for this model"
|
||||||
|
|
||||||
@@ -305,60 +307,63 @@ class MigrateTo3(object):
|
|||||||
dest = self._model_probe_to_path(info) / vae_path.name
|
dest = self._model_probe_to_path(info) / vae_path.name
|
||||||
if not dest.exists():
|
if not dest.exists():
|
||||||
if vae_path.is_dir():
|
if vae_path.is_dir():
|
||||||
self.copy_dir(vae_path, dest)
|
self.copy_dir(vae_path,dest)
|
||||||
else:
|
else:
|
||||||
self.copy_file(vae_path, dest)
|
self.copy_file(vae_path,dest)
|
||||||
vae_path = dest
|
vae_path = dest
|
||||||
|
|
||||||
if vae_path.is_relative_to(self.dest_models):
|
if vae_path.is_relative_to(self.dest_models):
|
||||||
rel_path = vae_path.relative_to(self.dest_models)
|
rel_path = vae_path.relative_to(self.dest_models)
|
||||||
return Path("models", rel_path)
|
return Path('models',rel_path)
|
||||||
else:
|
else:
|
||||||
return vae_path
|
return vae_path
|
||||||
|
|
||||||
def migrate_repo_id(self, repo_id: str, model_name: str = None, **extra_config):
|
def migrate_repo_id(self, repo_id: str, model_name: str=None, **extra_config):
|
||||||
"""
|
'''
|
||||||
Migrate a locally-cached diffusers pipeline identified with a repo_id
|
Migrate a locally-cached diffusers pipeline identified with a repo_id
|
||||||
"""
|
'''
|
||||||
dest_dir = self.dest_models
|
dest_dir = self.dest_models
|
||||||
|
|
||||||
cache = self.root_directory / "models/hub"
|
cache = self.root_directory / 'models/hub'
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
cache_dir=cache,
|
cache_dir = cache,
|
||||||
safety_checker=None,
|
safety_checker = None,
|
||||||
# local_files_only = True,
|
# local_files_only = True,
|
||||||
)
|
)
|
||||||
|
|
||||||
owner, repo_name = repo_id.split("/")
|
owner,repo_name = repo_id.split('/')
|
||||||
model_name = model_name or repo_name
|
model_name = model_name or repo_name
|
||||||
model = cache / "--".join(["models", owner, repo_name])
|
model = cache / '--'.join(['models',owner,repo_name])
|
||||||
|
|
||||||
if len(list(model.glob("snapshots/**/model_index.json"))) == 0:
|
if len(list(model.glob('snapshots/**/model_index.json')))==0:
|
||||||
return
|
return
|
||||||
revisions = [x.name for x in model.glob("refs/*")]
|
revisions = [x.name for x in model.glob('refs/*')]
|
||||||
|
|
||||||
# if an fp16 is available we use that
|
# if an fp16 is available we use that
|
||||||
revision = "fp16" if len(revisions) > 1 and "fp16" in revisions else revisions[0]
|
revision = 'fp16' if len(revisions) > 1 and 'fp16' in revisions else revisions[0]
|
||||||
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, revision=revision, **kwargs)
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
|
repo_id,
|
||||||
|
revision=revision,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
info = ModelProbe().heuristic_probe(pipeline)
|
info = ModelProbe().heuristic_probe(pipeline)
|
||||||
if not info:
|
if not info:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||||
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
|
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.')
|
||||||
return
|
return
|
||||||
|
|
||||||
dest = self._model_probe_to_path(info) / model_name
|
dest = self._model_probe_to_path(info) / model_name
|
||||||
self._save_pretrained(pipeline, dest)
|
self._save_pretrained(pipeline, dest)
|
||||||
|
|
||||||
rel_path = Path("models", dest.relative_to(dest_dir))
|
rel_path = Path('models',dest.relative_to(dest_dir))
|
||||||
self._add_model(model_name, info, rel_path, **extra_config)
|
self._add_model(model_name, info, rel_path, **extra_config)
|
||||||
|
|
||||||
def migrate_path(self, location: Path, model_name: str = None, **extra_config):
|
def migrate_path(self, location: Path, model_name: str=None, **extra_config):
|
||||||
"""
|
'''
|
||||||
Migrate a model referred to using 'weights' or 'path'
|
Migrate a model referred to using 'weights' or 'path'
|
||||||
"""
|
'''
|
||||||
|
|
||||||
# handle relative paths
|
# handle relative paths
|
||||||
dest_dir = self.dest_models
|
dest_dir = self.dest_models
|
||||||
@@ -370,72 +375,77 @@ class MigrateTo3(object):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||||
logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.")
|
logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.')
|
||||||
return
|
return
|
||||||
|
|
||||||
# uh oh, weights is in the old models directory - move it into the new one
|
# uh oh, weights is in the old models directory - move it into the new one
|
||||||
if Path(location).is_relative_to(self.src_paths.models):
|
if Path(location).is_relative_to(self.src_paths.models):
|
||||||
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
|
dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name)
|
||||||
if location.is_dir():
|
if location.is_dir():
|
||||||
self.copy_dir(location, dest)
|
self.copy_dir(location,dest)
|
||||||
else:
|
else:
|
||||||
self.copy_file(location, dest)
|
self.copy_file(location,dest)
|
||||||
location = Path("models", info.base_type.value, info.model_type.value, location.name)
|
location = Path('models', info.base_type.value, info.model_type.value, location.name)
|
||||||
|
|
||||||
self._add_model(model_name, info, location, **extra_config)
|
self._add_model(model_name, info, location, **extra_config)
|
||||||
|
|
||||||
def _add_model(self, model_name: str, info: ModelProbeInfo, location: Path, **extra_config):
|
def _add_model(self,
|
||||||
|
model_name: str,
|
||||||
|
info: ModelProbeInfo,
|
||||||
|
location: Path,
|
||||||
|
**extra_config):
|
||||||
if info.model_type != ModelType.Main:
|
if info.model_type != ModelType.Main:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.mgr.add_model(
|
self.mgr.add_model(
|
||||||
model_name=model_name,
|
model_name = model_name,
|
||||||
base_model=info.base_type,
|
base_model = info.base_type,
|
||||||
model_type=info.model_type,
|
model_type = info.model_type,
|
||||||
clobber=True,
|
clobber = True,
|
||||||
model_attributes={
|
model_attributes = {
|
||||||
"path": str(location),
|
'path': str(location),
|
||||||
"description": f"A {info.base_type.value} {info.model_type.value} model",
|
'description': f'A {info.base_type.value} {info.model_type.value} model',
|
||||||
"model_format": info.format,
|
'model_format': info.format,
|
||||||
"variant": info.variant_type.value,
|
'variant': info.variant_type.value,
|
||||||
**extra_config,
|
**extra_config,
|
||||||
},
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def migrate_defined_models(self):
|
def migrate_defined_models(self):
|
||||||
"""
|
'''
|
||||||
Migrate models defined in models.yaml
|
Migrate models defined in models.yaml
|
||||||
"""
|
'''
|
||||||
# find any models referred to in old models.yaml
|
# find any models referred to in old models.yaml
|
||||||
conf = OmegaConf.load(self.root_directory / "configs/models.yaml")
|
conf = OmegaConf.load(self.root_directory / 'configs/models.yaml')
|
||||||
|
|
||||||
for model_name, stanza in conf.items():
|
for model_name, stanza in conf.items():
|
||||||
|
|
||||||
try:
|
try:
|
||||||
passthru_args = {}
|
passthru_args = {}
|
||||||
|
|
||||||
if vae := stanza.get("vae"):
|
if vae := stanza.get('vae'):
|
||||||
try:
|
try:
|
||||||
passthru_args["vae"] = str(self._vae_path(vae))
|
passthru_args['vae'] = str(self._vae_path(vae))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
|
logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"')
|
||||||
logger.warning(str(e))
|
logger.warning(str(e))
|
||||||
|
|
||||||
if config := stanza.get("config"):
|
if config := stanza.get('config'):
|
||||||
passthru_args["config"] = config
|
passthru_args['config'] = config
|
||||||
|
|
||||||
if description := stanza.get("description"):
|
if description:= stanza.get('description'):
|
||||||
passthru_args["description"] = description
|
passthru_args['description'] = description
|
||||||
|
|
||||||
if repo_id := stanza.get("repo_id"):
|
if repo_id := stanza.get('repo_id'):
|
||||||
logger.info(f"Migrating diffusers model {model_name}")
|
logger.info(f'Migrating diffusers model {model_name}')
|
||||||
self.migrate_repo_id(repo_id, model_name, **passthru_args)
|
self.migrate_repo_id(repo_id, model_name, **passthru_args)
|
||||||
|
|
||||||
elif location := stanza.get("weights"):
|
elif location := stanza.get('weights'):
|
||||||
logger.info(f"Migrating checkpoint model {model_name}")
|
logger.info(f'Migrating checkpoint model {model_name}')
|
||||||
self.migrate_path(Path(location), model_name, **passthru_args)
|
self.migrate_path(Path(location), model_name, **passthru_args)
|
||||||
|
|
||||||
elif location := stanza.get("path"):
|
elif location := stanza.get('path'):
|
||||||
logger.info(f"Migrating diffusers model {model_name}")
|
logger.info(f'Migrating diffusers model {model_name}')
|
||||||
self.migrate_path(Path(location), model_name, **passthru_args)
|
self.migrate_path(Path(location), model_name, **passthru_args)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@@ -451,71 +461,67 @@ class MigrateTo3(object):
|
|||||||
self.migrate_tuning_models()
|
self.migrate_tuning_models()
|
||||||
self.migrate_defined_models()
|
self.migrate_defined_models()
|
||||||
|
|
||||||
|
def _parse_legacy_initfile(root: Path, initfile: Path)->ModelPaths:
|
||||||
def _parse_legacy_initfile(root: Path, initfile: Path) -> ModelPaths:
|
'''
|
||||||
"""
|
|
||||||
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
||||||
"""
|
'''
|
||||||
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
|
parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--embedding_directory",
|
'--embedding_directory',
|
||||||
"--embedding_path",
|
'--embedding_path',
|
||||||
type=Path,
|
type=Path,
|
||||||
dest="embedding_path",
|
dest='embedding_path',
|
||||||
default=Path("embeddings"),
|
default=Path('embeddings'),
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora_directory",
|
'--lora_directory',
|
||||||
dest="lora_path",
|
dest='lora_path',
|
||||||
type=Path,
|
type=Path,
|
||||||
default=Path("loras"),
|
default=Path('loras'),
|
||||||
)
|
)
|
||||||
opt, _ = parser.parse_known_args([f"@{str(initfile)}"])
|
opt,_ = parser.parse_known_args([f'@{str(initfile)}'])
|
||||||
return ModelPaths(
|
return ModelPaths(
|
||||||
models=root / "models",
|
models = root / 'models',
|
||||||
embeddings=root / str(opt.embedding_path).strip('"'),
|
embeddings = root / str(opt.embedding_path).strip('"'),
|
||||||
loras=root / str(opt.lora_path).strip('"'),
|
loras = root / str(opt.lora_path).strip('"'),
|
||||||
controlnets=root / "controlnets",
|
controlnets = root / 'controlnets',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _parse_legacy_yamlfile(root: Path, initfile: Path)->ModelPaths:
|
||||||
def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths:
|
'''
|
||||||
"""
|
|
||||||
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
Returns tuple of (embedding_path, lora_path, controlnet_path)
|
||||||
"""
|
'''
|
||||||
# Don't use the config object because it is unforgiving of version updates
|
# Don't use the config object because it is unforgiving of version updates
|
||||||
# Just use omegaconf directly
|
# Just use omegaconf directly
|
||||||
opt = OmegaConf.load(initfile)
|
opt = OmegaConf.load(initfile)
|
||||||
paths = opt.InvokeAI.Paths
|
paths = opt.InvokeAI.Paths
|
||||||
models = paths.get("models_dir", "models")
|
models = paths.get('models_dir','models')
|
||||||
embeddings = paths.get("embedding_dir", "embeddings")
|
embeddings = paths.get('embedding_dir','embeddings')
|
||||||
loras = paths.get("lora_dir", "loras")
|
loras = paths.get('lora_dir','loras')
|
||||||
controlnets = paths.get("controlnet_dir", "controlnets")
|
controlnets = paths.get('controlnet_dir','controlnets')
|
||||||
return ModelPaths(
|
return ModelPaths(
|
||||||
models=root / models,
|
models = root / models,
|
||||||
embeddings=root / embeddings,
|
embeddings = root / embeddings,
|
||||||
loras=root / loras,
|
loras = root /loras,
|
||||||
controlnets=root / controlnets,
|
controlnets = root / controlnets,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_legacy_embeddings(root: Path) -> ModelPaths:
|
def get_legacy_embeddings(root: Path) -> ModelPaths:
|
||||||
path = root / "invokeai.init"
|
path = root / 'invokeai.init'
|
||||||
if path.exists():
|
if path.exists():
|
||||||
return _parse_legacy_initfile(root, path)
|
return _parse_legacy_initfile(root, path)
|
||||||
path = root / "invokeai.yaml"
|
path = root / 'invokeai.yaml'
|
||||||
if path.exists():
|
if path.exists():
|
||||||
return _parse_legacy_yamlfile(root, path)
|
return _parse_legacy_yamlfile(root, path)
|
||||||
|
|
||||||
|
|
||||||
def do_migrate(src_directory: Path, dest_directory: Path):
|
def do_migrate(src_directory: Path, dest_directory: Path):
|
||||||
"""
|
"""
|
||||||
Migrate models from src to dest InvokeAI root directories
|
Migrate models from src to dest InvokeAI root directories
|
||||||
"""
|
"""
|
||||||
config_file = dest_directory / "configs" / "models.yaml.3"
|
config_file = dest_directory / 'configs' / 'models.yaml.3'
|
||||||
dest_models = dest_directory / "models.3"
|
dest_models = dest_directory / 'models.3'
|
||||||
|
|
||||||
version_3 = (dest_directory / "models" / "core").exists()
|
version_3 = (dest_directory / 'models' / 'core').exists()
|
||||||
|
|
||||||
# Here we create the destination models.yaml file.
|
# Here we create the destination models.yaml file.
|
||||||
# If we are writing into a version 3 directory and the
|
# If we are writing into a version 3 directory and the
|
||||||
@@ -524,80 +530,80 @@ def do_migrate(src_directory: Path, dest_directory: Path):
|
|||||||
# create a new empty one.
|
# create a new empty one.
|
||||||
if version_3: # write into the dest directory
|
if version_3: # write into the dest directory
|
||||||
try:
|
try:
|
||||||
shutil.copy(dest_directory / "configs" / "models.yaml", config_file)
|
shutil.copy(dest_directory / 'configs' / 'models.yaml', config_file)
|
||||||
except:
|
except:
|
||||||
MigrateTo3.initialize_yaml(config_file)
|
MigrateTo3.initialize_yaml(config_file)
|
||||||
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
|
mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory
|
||||||
(dest_directory / "models").replace(dest_models)
|
(dest_directory / 'models').replace(dest_models)
|
||||||
else:
|
else:
|
||||||
MigrateTo3.initialize_yaml(config_file)
|
MigrateTo3.initialize_yaml(config_file)
|
||||||
mgr = ModelManager(config_file)
|
mgr = ModelManager(config_file)
|
||||||
|
|
||||||
paths = get_legacy_embeddings(src_directory)
|
paths = get_legacy_embeddings(src_directory)
|
||||||
migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths)
|
migrator = MigrateTo3(
|
||||||
|
from_root = src_directory,
|
||||||
|
to_models = dest_models,
|
||||||
|
model_manager = mgr,
|
||||||
|
src_paths = paths
|
||||||
|
)
|
||||||
migrator.migrate()
|
migrator.migrate()
|
||||||
print("Migration successful.")
|
print("Migration successful.")
|
||||||
|
|
||||||
if not version_3:
|
if not version_3:
|
||||||
(dest_directory / "models").replace(src_directory / "models.orig")
|
(dest_directory / 'models').replace(src_directory / 'models.orig')
|
||||||
print(f"Original models directory moved to {dest_directory}/models.orig")
|
print(f'Original models directory moved to {dest_directory}/models.orig')
|
||||||
|
|
||||||
(dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig")
|
(dest_directory / 'configs' / 'models.yaml').replace(src_directory / 'configs' / 'models.yaml.orig')
|
||||||
print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig")
|
print(f'Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig')
|
||||||
|
|
||||||
config_file.replace(config_file.with_suffix(""))
|
|
||||||
dest_models.replace(dest_models.with_suffix(""))
|
|
||||||
|
|
||||||
|
config_file.replace(config_file.with_suffix(''))
|
||||||
|
dest_models.replace(dest_models.with_suffix(''))
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(prog="invokeai-migrate3",
|
||||||
prog="invokeai-migrate3",
|
description="""
|
||||||
description="""
|
|
||||||
This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
|
This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format
|
||||||
'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a
|
'--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a
|
||||||
|
|
||||||
The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively.
|
The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively.
|
||||||
It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure
|
It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure
|
||||||
script, which will perform a full upgrade in place.""",
|
script, which will perform a full upgrade in place."""
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument('--from-directory',
|
||||||
"--from-directory",
|
dest='src_root',
|
||||||
dest="src_root",
|
type=Path,
|
||||||
type=Path,
|
required=True,
|
||||||
required=True,
|
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")'
|
||||||
help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")',
|
)
|
||||||
)
|
parser.add_argument('--to-directory',
|
||||||
parser.add_argument(
|
dest='dest_root',
|
||||||
"--to-directory",
|
type=Path,
|
||||||
dest="dest_root",
|
required=True,
|
||||||
type=Path,
|
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")'
|
||||||
required=True,
|
)
|
||||||
help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")',
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
src_root = args.src_root
|
src_root = args.src_root
|
||||||
assert src_root.is_dir(), f"{src_root} is not a valid directory"
|
assert src_root.is_dir(), f"{src_root} is not a valid directory"
|
||||||
assert (src_root / "models").is_dir(), f"{src_root} does not contain a 'models' subdirectory"
|
assert (src_root / 'models').is_dir(), f"{src_root} does not contain a 'models' subdirectory"
|
||||||
assert (src_root / "models" / "hub").exists(), f"{src_root} does not contain a version 2.3 models directory"
|
assert (src_root / 'models' / 'hub').exists(), f"{src_root} does not contain a version 2.3 models directory"
|
||||||
assert (src_root / "invokeai.init").exists() or (
|
assert (src_root / 'invokeai.init').exists() or (src_root / 'invokeai.yaml').exists(), f"{src_root} does not contain an InvokeAI init file."
|
||||||
src_root / "invokeai.yaml"
|
|
||||||
).exists(), f"{src_root} does not contain an InvokeAI init file."
|
|
||||||
|
|
||||||
dest_root = args.dest_root
|
dest_root = args.dest_root
|
||||||
assert dest_root.is_dir(), f"{dest_root} is not a valid directory"
|
assert dest_root.is_dir(), f"{dest_root} is not a valid directory"
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
config.parse_args(["--root", str(dest_root)])
|
config.parse_args(['--root',str(dest_root)])
|
||||||
|
|
||||||
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
|
# TODO: revisit - don't rely on invokeai.yaml to exist yet!
|
||||||
dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists()
|
dest_is_setup = (dest_root / 'models/core').exists() and (dest_root / 'databases').exists()
|
||||||
if not dest_is_setup:
|
if not dest_is_setup:
|
||||||
import invokeai.frontend.install.invokeai_configure
|
import invokeai.frontend.install.invokeai_configure
|
||||||
from invokeai.backend.install.invokeai_configure import initialize_rootdir
|
from invokeai.backend.install.invokeai_configure import initialize_rootdir
|
||||||
|
|
||||||
initialize_rootdir(dest_root, True)
|
initialize_rootdir(dest_root, True)
|
||||||
|
|
||||||
do_migrate(src_root, dest_root)
|
do_migrate(src_root,dest_root)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,15 +4,14 @@ Utility (backend) functions used by model_install.py
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass,field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Optional, List, Dict, Callable, Union, Set
|
from typing import List, Dict, Callable, Union, Set
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers import logging as dlogging
|
from diffusers import logging as dlogging
|
||||||
import onnx
|
|
||||||
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -29,7 +28,7 @@ warnings.filterwarnings("ignore")
|
|||||||
|
|
||||||
# --------------------------globals-----------------------
|
# --------------------------globals-----------------------
|
||||||
config = InvokeAIAppConfig.get_config()
|
config = InvokeAIAppConfig.get_config()
|
||||||
logger = InvokeAILogger.getLogger(name="InvokeAI")
|
logger = InvokeAILogger.getLogger(name='InvokeAI')
|
||||||
|
|
||||||
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
# the initial "configs" dir is now bundled in the `invokeai.configs` package
|
||||||
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
||||||
@@ -46,63 +45,59 @@ Config_preamble = """
|
|||||||
|
|
||||||
LEGACY_CONFIGS = {
|
LEGACY_CONFIGS = {
|
||||||
BaseModelType.StableDiffusion1: {
|
BaseModelType.StableDiffusion1: {
|
||||||
ModelVariantType.Normal: "v1-inference.yaml",
|
ModelVariantType.Normal: 'v1-inference.yaml',
|
||||||
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
|
ModelVariantType.Inpaint: 'v1-inpainting-inference.yaml',
|
||||||
},
|
},
|
||||||
|
|
||||||
BaseModelType.StableDiffusion2: {
|
BaseModelType.StableDiffusion2: {
|
||||||
ModelVariantType.Normal: {
|
ModelVariantType.Normal: {
|
||||||
SchedulerPredictionType.Epsilon: "v2-inference.yaml",
|
SchedulerPredictionType.Epsilon: 'v2-inference.yaml',
|
||||||
SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
|
SchedulerPredictionType.VPrediction: 'v2-inference-v.yaml',
|
||||||
},
|
},
|
||||||
ModelVariantType.Inpaint: {
|
ModelVariantType.Inpaint: {
|
||||||
SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
|
SchedulerPredictionType.Epsilon: 'v2-inpainting-inference.yaml',
|
||||||
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
|
SchedulerPredictionType.VPrediction: 'v2-inpainting-inference-v.yaml',
|
||||||
},
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
BaseModelType.StableDiffusionXL: {
|
BaseModelType.StableDiffusionXL: {
|
||||||
ModelVariantType.Normal: "sd_xl_base.yaml",
|
ModelVariantType.Normal: 'sd_xl_base.yaml',
|
||||||
},
|
},
|
||||||
|
|
||||||
BaseModelType.StableDiffusionXLRefiner: {
|
BaseModelType.StableDiffusionXLRefiner: {
|
||||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
ModelVariantType.Normal: 'sd_xl_refiner.yaml',
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelInstallList:
|
class ModelInstallList:
|
||||||
"""Class for listing models to be installed/removed"""
|
'''Class for listing models to be installed/removed'''
|
||||||
|
|
||||||
install_models: List[str] = field(default_factory=list)
|
install_models: List[str] = field(default_factory=list)
|
||||||
remove_models: List[str] = field(default_factory=list)
|
remove_models: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InstallSelections():
|
||||||
|
install_models: List[str]= field(default_factory=list)
|
||||||
|
remove_models: List[str]=field(default_factory=list)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InstallSelections:
|
class ModelLoadInfo():
|
||||||
install_models: List[str] = field(default_factory=list)
|
|
||||||
remove_models: List[str] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelLoadInfo:
|
|
||||||
name: str
|
name: str
|
||||||
model_type: ModelType
|
model_type: ModelType
|
||||||
base_type: BaseModelType
|
base_type: BaseModelType
|
||||||
path: Optional[Path] = None
|
path: Path = None
|
||||||
repo_id: Optional[str] = None
|
repo_id: str = None
|
||||||
description: str = ""
|
description: str = ''
|
||||||
installed: bool = False
|
installed: bool = False
|
||||||
recommended: bool = False
|
recommended: bool = False
|
||||||
default: bool = False
|
default: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ModelInstall(object):
|
class ModelInstall(object):
|
||||||
def __init__(
|
def __init__(self,
|
||||||
self,
|
config:InvokeAIAppConfig,
|
||||||
config: InvokeAIAppConfig,
|
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||||
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
model_manager: ModelManager = None,
|
||||||
model_manager: ModelManager = None,
|
access_token:str = None):
|
||||||
access_token: str = None,
|
|
||||||
):
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
||||||
self.datasets = OmegaConf.load(Dataset_path)
|
self.datasets = OmegaConf.load(Dataset_path)
|
||||||
@@ -110,79 +105,66 @@ class ModelInstall(object):
|
|||||||
self.access_token = access_token or HfFolder.get_token()
|
self.access_token = access_token or HfFolder.get_token()
|
||||||
self.reverse_paths = self._reverse_paths(self.datasets)
|
self.reverse_paths = self._reverse_paths(self.datasets)
|
||||||
|
|
||||||
def all_models(self) -> Dict[str, ModelLoadInfo]:
|
def all_models(self)->Dict[str,ModelLoadInfo]:
|
||||||
"""
|
'''
|
||||||
Return dict of model_key=>ModelLoadInfo objects.
|
Return dict of model_key=>ModelLoadInfo objects.
|
||||||
This method consolidates and simplifies the entries in both
|
This method consolidates and simplifies the entries in both
|
||||||
models.yaml and INITIAL_MODELS.yaml so that they can
|
models.yaml and INITIAL_MODELS.yaml so that they can
|
||||||
be treated uniformly. It also sorts the models alphabetically
|
be treated uniformly. It also sorts the models alphabetically
|
||||||
by their name, to improve the display somewhat.
|
by their name, to improve the display somewhat.
|
||||||
"""
|
'''
|
||||||
model_dict = dict()
|
model_dict = dict()
|
||||||
|
|
||||||
# first populate with the entries in INITIAL_MODELS.yaml
|
# first populate with the entries in INITIAL_MODELS.yaml
|
||||||
for key, value in self.datasets.items():
|
for key, value in self.datasets.items():
|
||||||
name, base, model_type = ModelManager.parse_key(key)
|
name,base,model_type = ModelManager.parse_key(key)
|
||||||
value["name"] = name
|
value['name'] = name
|
||||||
value["base_type"] = base
|
value['base_type'] = base
|
||||||
value["model_type"] = model_type
|
value['model_type'] = model_type
|
||||||
model_dict[key] = ModelLoadInfo(**value)
|
model_dict[key] = ModelLoadInfo(**value)
|
||||||
|
|
||||||
# supplement with entries in models.yaml
|
# supplement with entries in models.yaml
|
||||||
installed_models = [x for x in self.mgr.list_models()]
|
installed_models = self.mgr.list_models()
|
||||||
# suppresses autoloaded models
|
|
||||||
# installed_models = [x for x in self.mgr.list_models() if not self._is_autoloaded(x)]
|
|
||||||
|
|
||||||
for md in installed_models:
|
for md in installed_models:
|
||||||
base = md["base_model"]
|
base = md['base_model']
|
||||||
model_type = md["model_type"]
|
model_type = md['model_type']
|
||||||
name = md["model_name"]
|
name = md['model_name']
|
||||||
key = ModelManager.create_key(name, base, model_type)
|
key = ModelManager.create_key(name, base, model_type)
|
||||||
if key in model_dict:
|
if key in model_dict:
|
||||||
model_dict[key].installed = True
|
model_dict[key].installed = True
|
||||||
else:
|
else:
|
||||||
model_dict[key] = ModelLoadInfo(
|
model_dict[key] = ModelLoadInfo(
|
||||||
name=name,
|
name = name,
|
||||||
base_type=base,
|
base_type = base,
|
||||||
model_type=model_type,
|
model_type = model_type,
|
||||||
path=value.get("path"),
|
path = value.get('path'),
|
||||||
installed=True,
|
installed = True,
|
||||||
)
|
)
|
||||||
return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())}
|
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
|
||||||
|
|
||||||
def _is_autoloaded(self, model_info: dict) -> bool:
|
|
||||||
path = model_info.get("path")
|
|
||||||
if not path:
|
|
||||||
return False
|
|
||||||
for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]:
|
|
||||||
if autodir_path := getattr(self.config, autodir):
|
|
||||||
autodir_path = self.config.root_path / autodir_path
|
|
||||||
if Path(path).is_relative_to(autodir_path):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def list_models(self, model_type):
|
def list_models(self, model_type):
|
||||||
installed = self.mgr.list_models(model_type=model_type)
|
installed = self.mgr.list_models(model_type=model_type)
|
||||||
print(f"Installed models of type `{model_type}`:")
|
print(f'Installed models of type `{model_type}`:')
|
||||||
for i in installed:
|
for i in installed:
|
||||||
print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
|
print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")
|
||||||
|
|
||||||
# logic here a little reversed to maintain backward compatibility
|
# logic here a little reversed to maintain backward compatibility
|
||||||
def starter_models(self, all_models: bool = False) -> Set[str]:
|
def starter_models(self, all_models: bool=False)->Set[str]:
|
||||||
models = set()
|
models = set()
|
||||||
for key, value in self.datasets.items():
|
for key, value in self.datasets.items():
|
||||||
name, base, model_type = ModelManager.parse_key(key)
|
name,base,model_type = ModelManager.parse_key(key)
|
||||||
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
|
if all_models or model_type in [ModelType.Main, ModelType.Vae]:
|
||||||
models.add(key)
|
models.add(key)
|
||||||
return models
|
return models
|
||||||
|
|
||||||
def recommended_models(self) -> Set[str]:
|
def recommended_models(self)->Set[str]:
|
||||||
starters = self.starter_models(all_models=True)
|
starters = self.starter_models(all_models=True)
|
||||||
return set([x for x in starters if self.datasets[x].get("recommended", False)])
|
return set([x for x in starters if self.datasets[x].get('recommended',False)])
|
||||||
|
|
||||||
def default_model(self) -> str:
|
def default_model(self)->str:
|
||||||
starters = self.starter_models()
|
starters = self.starter_models()
|
||||||
defaults = [x for x in starters if self.datasets[x].get("default", False)]
|
defaults = [x for x in starters if self.datasets[x].get('default',False)]
|
||||||
return defaults[0]
|
return defaults[0]
|
||||||
|
|
||||||
def install(self, selections: InstallSelections):
|
def install(self, selections: InstallSelections):
|
||||||
@@ -194,17 +176,17 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
# remove requested models
|
# remove requested models
|
||||||
for key in selections.remove_models:
|
for key in selections.remove_models:
|
||||||
name, base, mtype = self.mgr.parse_key(key)
|
name,base,mtype = self.mgr.parse_key(key)
|
||||||
logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]")
|
logger.info(f'Deleting {mtype} model {name} [{job}/{jobs}]')
|
||||||
try:
|
try:
|
||||||
self.mgr.del_model(name, base, mtype)
|
self.mgr.del_model(name,base,mtype)
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
logger.warning(e)
|
logger.warning(e)
|
||||||
job += 1
|
job += 1
|
||||||
|
|
||||||
# add requested models
|
# add requested models
|
||||||
for path in selections.install_models:
|
for path in selections.install_models:
|
||||||
logger.info(f"Installing {path} [{job}/{jobs}]")
|
logger.info(f'Installing {path} [{job}/{jobs}]')
|
||||||
try:
|
try:
|
||||||
self.heuristic_import(path)
|
self.heuristic_import(path)
|
||||||
except (ValueError, KeyError) as e:
|
except (ValueError, KeyError) as e:
|
||||||
@@ -214,16 +196,15 @@ class ModelInstall(object):
|
|||||||
dlogging.set_verbosity(verbosity)
|
dlogging.set_verbosity(verbosity)
|
||||||
self.mgr.commit()
|
self.mgr.commit()
|
||||||
|
|
||||||
def heuristic_import(
|
def heuristic_import(self,
|
||||||
self,
|
model_path_id_or_url: Union[str,Path],
|
||||||
model_path_id_or_url: Union[str, Path],
|
models_installed: Set[Path]=None,
|
||||||
models_installed: Set[Path] = None,
|
)->Dict[str, AddModelResult]:
|
||||||
) -> Dict[str, AddModelResult]:
|
'''
|
||||||
"""
|
|
||||||
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
:param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
|
||||||
:param models_installed: Set of installed models, used for recursive invocation
|
:param models_installed: Set of installed models, used for recursive invocation
|
||||||
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
||||||
"""
|
'''
|
||||||
|
|
||||||
if not models_installed:
|
if not models_installed:
|
||||||
models_installed = dict()
|
models_installed = dict()
|
||||||
@@ -233,15 +214,13 @@ class ModelInstall(object):
|
|||||||
path = Path(model_path_id_or_url)
|
path = Path(model_path_id_or_url)
|
||||||
# checkpoint file, or similar
|
# checkpoint file, or similar
|
||||||
if path.is_file():
|
if path.is_file():
|
||||||
models_installed.update({str(path): self._install_path(path)})
|
models_installed.update({str(path):self._install_path(path)})
|
||||||
|
|
||||||
# folders style or similar
|
# folders style or similar
|
||||||
elif path.is_dir() and any(
|
elif path.is_dir() and any([(path/x).exists() for x in \
|
||||||
[
|
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
|
||||||
(path / x).exists()
|
]
|
||||||
for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"}
|
):
|
||||||
]
|
|
||||||
):
|
|
||||||
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
|
models_installed.update({str(model_path_id_or_url): self._install_path(path)})
|
||||||
|
|
||||||
# recursive scan
|
# recursive scan
|
||||||
@@ -250,7 +229,7 @@ class ModelInstall(object):
|
|||||||
self.heuristic_import(child, models_installed=models_installed)
|
self.heuristic_import(child, models_installed=models_installed)
|
||||||
|
|
||||||
# huggingface repo
|
# huggingface repo
|
||||||
elif len(str(model_path_id_or_url).split("/")) == 2:
|
elif len(str(model_path_id_or_url).split('/')) == 2:
|
||||||
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
||||||
|
|
||||||
# a URL
|
# a URL
|
||||||
@@ -258,42 +237,40 @@ class ModelInstall(object):
|
|||||||
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping")
|
raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||||
|
|
||||||
return models_installed
|
return models_installed
|
||||||
|
|
||||||
# install a model from a local path. The optional info parameter is there to prevent
|
# install a model from a local path. The optional info parameter is there to prevent
|
||||||
# the model from being probed twice in the event that it has already been probed.
|
# the model from being probed twice in the event that it has already been probed.
|
||||||
def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult:
|
def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult:
|
||||||
info = info or ModelProbe().heuristic_probe(path, self.prediction_helper)
|
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||||
if not info:
|
if not info:
|
||||||
logger.warning(f"Unable to parse format of {path}")
|
logger.warning(f'Unable to parse format of {path}')
|
||||||
return None
|
return None
|
||||||
model_name = path.stem if path.is_file() else path.name
|
model_name = path.stem if path.is_file() else path.name
|
||||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||||
attributes = self._make_attributes(path, info)
|
attributes = self._make_attributes(path,info)
|
||||||
return self.mgr.add_model(
|
return self.mgr.add_model(model_name = model_name,
|
||||||
model_name=model_name,
|
base_model = info.base_type,
|
||||||
base_model=info.base_type,
|
model_type = info.model_type,
|
||||||
model_type=info.model_type,
|
model_attributes = attributes,
|
||||||
model_attributes=attributes,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
def _install_url(self, url: str) -> AddModelResult:
|
def _install_url(self, url: str)->AddModelResult:
|
||||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||||
location = download_with_resume(url, Path(staging))
|
location = download_with_resume(url,Path(staging))
|
||||||
if not location:
|
if not location:
|
||||||
logger.error(f"Unable to download {url}. Skipping.")
|
logger.error(f'Unable to download {url}. Skipping.')
|
||||||
info = ModelProbe().heuristic_probe(location)
|
info = ModelProbe().heuristic_probe(location)
|
||||||
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
|
dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
|
||||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
models_path = shutil.move(location,dest)
|
||||||
models_path = shutil.move(location, dest)
|
|
||||||
|
|
||||||
# staged version will be garbage-collected at this time
|
# staged version will be garbage-collected at this time
|
||||||
return self._install_path(Path(models_path), info)
|
return self._install_path(Path(models_path), info)
|
||||||
|
|
||||||
def _install_repo(self, repo_id: str) -> AddModelResult:
|
def _install_repo(self, repo_id: str)->AddModelResult:
|
||||||
hinfo = HfApi().model_info(repo_id)
|
hinfo = HfApi().model_info(repo_id)
|
||||||
|
|
||||||
# we try to figure out how to download this most economically
|
# we try to figure out how to download this most economically
|
||||||
@@ -303,51 +280,42 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||||
staging = Path(staging)
|
staging = Path(staging)
|
||||||
if "model_index.json" in files and "unet/model.onnx" not in files:
|
if 'model_index.json' in files:
|
||||||
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
||||||
elif "unet/model.onnx" in files:
|
|
||||||
location = self._download_hf_model(repo_id, files, staging)
|
|
||||||
else:
|
else:
|
||||||
for suffix in ["safetensors", "bin"]:
|
for suffix in ['safetensors','bin']:
|
||||||
if f"pytorch_lora_weights.{suffix}" in files:
|
if f'pytorch_lora_weights.{suffix}' in files:
|
||||||
location = self._download_hf_model(repo_id, ["pytorch_lora_weights.bin"], staging) # LoRA
|
location = self._download_hf_model(repo_id, ['pytorch_lora_weights.bin'], staging) # LoRA
|
||||||
break
|
break
|
||||||
elif (
|
elif self.config.precision=='float16' and f'diffusion_pytorch_model.fp16.{suffix}' in files: # vae, controlnet or some other standalone
|
||||||
self.config.precision == "float16" and f"diffusion_pytorch_model.fp16.{suffix}" in files
|
files = ['config.json', f'diffusion_pytorch_model.fp16.{suffix}']
|
||||||
): # vae, controlnet or some other standalone
|
|
||||||
files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
|
|
||||||
location = self._download_hf_model(repo_id, files, staging)
|
location = self._download_hf_model(repo_id, files, staging)
|
||||||
break
|
break
|
||||||
elif f"diffusion_pytorch_model.{suffix}" in files:
|
elif f'diffusion_pytorch_model.{suffix}' in files:
|
||||||
files = ["config.json", f"diffusion_pytorch_model.{suffix}"]
|
files = ['config.json', f'diffusion_pytorch_model.{suffix}']
|
||||||
location = self._download_hf_model(repo_id, files, staging)
|
location = self._download_hf_model(repo_id, files, staging)
|
||||||
break
|
break
|
||||||
elif f"learned_embeds.{suffix}" in files:
|
elif f'learned_embeds.{suffix}' in files:
|
||||||
location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
|
location = self._download_hf_model(repo_id, [f'learned_embeds.{suffix}'], staging)
|
||||||
break
|
break
|
||||||
if not location:
|
if not location:
|
||||||
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
|
logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.')
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
info = ModelProbe().heuristic_probe(location, self.prediction_helper)
|
||||||
if not info:
|
if not info:
|
||||||
logger.warning(f"Could not probe {location}. Skipping install.")
|
logger.warning(f'Could not probe {location}. Skipping install.')
|
||||||
return {}
|
return {}
|
||||||
dest = (
|
dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location)
|
||||||
self.config.models_path
|
|
||||||
/ info.base_type.value
|
|
||||||
/ info.model_type.value
|
|
||||||
/ self._get_model_name(repo_id, location)
|
|
||||||
)
|
|
||||||
if dest.exists():
|
if dest.exists():
|
||||||
shutil.rmtree(dest)
|
shutil.rmtree(dest)
|
||||||
shutil.copytree(location, dest)
|
shutil.copytree(location,dest)
|
||||||
return self._install_path(dest, info)
|
return self._install_path(dest, info)
|
||||||
|
|
||||||
def _get_model_name(self, path_name: str, location: Path) -> str:
|
def _get_model_name(self,path_name: str, location: Path)->str:
|
||||||
"""
|
'''
|
||||||
Calculate a name for the model - primitive implementation.
|
Calculate a name for the model - primitive implementation.
|
||||||
"""
|
'''
|
||||||
if key := self.reverse_paths.get(path_name):
|
if key := self.reverse_paths.get(path_name):
|
||||||
(name, base, mtype) = ModelManager.parse_key(key)
|
(name, base, mtype) = ModelManager.parse_key(key)
|
||||||
return name
|
return name
|
||||||
@@ -356,108 +324,99 @@ class ModelInstall(object):
|
|||||||
else:
|
else:
|
||||||
return location.stem
|
return location.stem
|
||||||
|
|
||||||
def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict:
|
def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict:
|
||||||
model_name = path.name if path.is_dir() else path.stem
|
model_name = path.name if path.is_dir() else path.stem
|
||||||
description = f"{info.base_type.value} {info.model_type.value} model {model_name}"
|
description = f'{info.base_type.value} {info.model_type.value} model {model_name}'
|
||||||
if key := self.reverse_paths.get(self.current_id):
|
if key := self.reverse_paths.get(self.current_id):
|
||||||
if key in self.datasets:
|
if key in self.datasets:
|
||||||
description = self.datasets[key].get("description") or description
|
description = self.datasets[key].get('description') or description
|
||||||
|
|
||||||
rel_path = self.relative_to_root(path, self.config.models_path)
|
rel_path = self.relative_to_root(path)
|
||||||
|
|
||||||
attributes = dict(
|
attributes = dict(
|
||||||
path=str(rel_path),
|
path = str(rel_path),
|
||||||
description=str(description),
|
description = str(description),
|
||||||
model_format=info.format,
|
model_format = info.format,
|
||||||
)
|
|
||||||
legacy_conf = None
|
|
||||||
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
|
|
||||||
attributes.update(
|
|
||||||
dict(
|
|
||||||
variant=info.variant_type,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if info.format == "checkpoint":
|
legacy_conf = None
|
||||||
|
if info.model_type == ModelType.Main:
|
||||||
|
attributes.update(dict(variant = info.variant_type,))
|
||||||
|
if info.format=="checkpoint":
|
||||||
try:
|
try:
|
||||||
possible_conf = path.with_suffix(".yaml")
|
possible_conf = path.with_suffix('.yaml')
|
||||||
if possible_conf.exists():
|
if possible_conf.exists():
|
||||||
legacy_conf = str(self.relative_to_root(possible_conf))
|
legacy_conf = str(self.relative_to_root(possible_conf))
|
||||||
elif info.base_type == BaseModelType.StableDiffusion2:
|
elif info.base_type == BaseModelType.StableDiffusion2:
|
||||||
legacy_conf = Path(
|
legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type])
|
||||||
self.config.legacy_conf_dir,
|
|
||||||
LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
legacy_conf = Path(
|
legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type])
|
||||||
self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]
|
|
||||||
)
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess
|
legacy_conf = Path(self.config.legacy_conf_dir, 'v1-inference.yaml') # best guess
|
||||||
|
|
||||||
if info.model_type == ModelType.ControlNet and info.format == "checkpoint":
|
if info.model_type == ModelType.ControlNet and info.format=="checkpoint":
|
||||||
possible_conf = path.with_suffix(".yaml")
|
possible_conf = path.with_suffix('.yaml')
|
||||||
if possible_conf.exists():
|
if possible_conf.exists():
|
||||||
legacy_conf = str(self.relative_to_root(possible_conf))
|
legacy_conf = str(self.relative_to_root(possible_conf))
|
||||||
|
|
||||||
if legacy_conf:
|
if legacy_conf:
|
||||||
attributes.update(dict(config=str(legacy_conf)))
|
attributes.update(
|
||||||
|
dict(
|
||||||
|
config = str(legacy_conf)
|
||||||
|
)
|
||||||
|
)
|
||||||
return attributes
|
return attributes
|
||||||
|
|
||||||
def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
|
def relative_to_root(self, path: Path)->Path:
|
||||||
root = root or self.config.root_path
|
root = self.config.root_path
|
||||||
if path.is_relative_to(root):
|
if path.is_relative_to(root):
|
||||||
return path.relative_to(root)
|
return path.relative_to(root)
|
||||||
else:
|
else:
|
||||||
return path
|
return path
|
||||||
|
|
||||||
def _download_hf_pipeline(self, repo_id: str, staging: Path) -> Path:
|
def _download_hf_pipeline(self, repo_id: str, staging: Path)->Path:
|
||||||
"""
|
'''
|
||||||
This retrieves a StableDiffusion model from cache or remote and then
|
This retrieves a StableDiffusion model from cache or remote and then
|
||||||
does a save_pretrained() to the indicated staging area.
|
does a save_pretrained() to the indicated staging area.
|
||||||
"""
|
'''
|
||||||
_, name = repo_id.split("/")
|
_,name = repo_id.split("/")
|
||||||
revisions = ["fp16", "main"] if self.config.precision == "float16" else ["main"]
|
revisions = ['fp16','main'] if self.config.precision=='float16' else ['main']
|
||||||
model = None
|
model = None
|
||||||
for revision in revisions:
|
for revision in revisions:
|
||||||
try:
|
try:
|
||||||
model = DiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None)
|
model = DiffusionPipeline.from_pretrained(repo_id,revision=revision,safety_checker=None)
|
||||||
except: # most errors are due to fp16 not being present. Fix this to catch other errors
|
except: # most errors are due to fp16 not being present. Fix this to catch other errors
|
||||||
pass
|
pass
|
||||||
if model:
|
if model:
|
||||||
break
|
break
|
||||||
if not model:
|
if not model:
|
||||||
logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
|
logger.error(f'Diffusers model {repo_id} could not be downloaded. Skipping.')
|
||||||
return None
|
return None
|
||||||
model.save_pretrained(staging / name, safe_serialization=True)
|
model.save_pretrained(staging / name, safe_serialization=True)
|
||||||
return staging / name
|
return staging / name
|
||||||
|
|
||||||
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path) -> Path:
|
def _download_hf_model(self, repo_id: str, files: List[str], staging: Path)->Path:
|
||||||
_, name = repo_id.split("/")
|
_,name = repo_id.split("/")
|
||||||
location = staging / name
|
location = staging / name
|
||||||
paths = list()
|
paths = list()
|
||||||
for filename in files:
|
for filename in files:
|
||||||
filePath = Path(filename)
|
p = hf_download_with_resume(repo_id,
|
||||||
p = hf_download_with_resume(
|
model_dir=location,
|
||||||
repo_id,
|
model_name=filename,
|
||||||
model_dir=location / filePath.parent,
|
access_token = self.access_token
|
||||||
model_name=filePath.name,
|
)
|
||||||
access_token=self.access_token,
|
|
||||||
subfolder=filePath.parent,
|
|
||||||
)
|
|
||||||
if p:
|
if p:
|
||||||
paths.append(p)
|
paths.append(p)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Could not download {filename} from {repo_id}.")
|
logger.warning(f'Could not download {filename} from {repo_id}.')
|
||||||
|
|
||||||
return location if len(paths) > 0 else None
|
return location if len(paths)>0 else None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _reverse_paths(cls, datasets) -> dict:
|
def _reverse_paths(cls,datasets)->dict:
|
||||||
"""
|
'''
|
||||||
Reverse mapping from repo_id/path to destination name.
|
Reverse mapping from repo_id/path to destination name.
|
||||||
"""
|
'''
|
||||||
return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()}
|
return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()}
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def yes_or_no(prompt: str, default_yes=True):
|
def yes_or_no(prompt: str, default_yes=True):
|
||||||
@@ -468,11 +427,12 @@ def yes_or_no(prompt: str, default_yes=True):
|
|||||||
else:
|
else:
|
||||||
return response[0] in ("y", "Y")
|
return response[0] in ("y", "Y")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
|
def hf_download_from_pretrained(
|
||||||
logger = InvokeAILogger.getLogger("InvokeAI")
|
model_class: object, model_name: str, destination: Path, **kwargs
|
||||||
logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage())
|
):
|
||||||
|
logger = InvokeAILogger.getLogger('InvokeAI')
|
||||||
|
logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage())
|
||||||
|
|
||||||
model = model_class.from_pretrained(
|
model = model_class.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
@@ -482,20 +442,18 @@ def hf_download_from_pretrained(model_class: object, model_name: str, destinatio
|
|||||||
model.save_pretrained(destination, safe_serialization=True)
|
model.save_pretrained(destination, safe_serialization=True)
|
||||||
return destination
|
return destination
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def hf_download_with_resume(
|
def hf_download_with_resume(
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
model_dir: str,
|
model_dir: str,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_dest: Path = None,
|
model_dest: Path = None,
|
||||||
access_token: str = None,
|
access_token: str = None,
|
||||||
subfolder: str = None,
|
|
||||||
) -> Path:
|
) -> Path:
|
||||||
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
|
||||||
url = hf_hub_url(repo_id, model_name, subfolder=subfolder)
|
url = hf_hub_url(repo_id, model_name)
|
||||||
|
|
||||||
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
||||||
open_mode = "wb"
|
open_mode = "wb"
|
||||||
@@ -509,7 +467,9 @@ def hf_download_with_resume(
|
|||||||
resp = requests.get(url, headers=header, stream=True)
|
resp = requests.get(url, headers=header, stream=True)
|
||||||
total = int(resp.headers.get("content-length", 0))
|
total = int(resp.headers.get("content-length", 0))
|
||||||
|
|
||||||
if resp.status_code == 416: # "range not satisfiable", which means nothing to return
|
if (
|
||||||
|
resp.status_code == 416
|
||||||
|
): # "range not satisfiable", which means nothing to return
|
||||||
logger.info(f"{model_name}: complete file found. Skipping.")
|
logger.info(f"{model_name}: complete file found. Skipping.")
|
||||||
return model_dest
|
return model_dest
|
||||||
elif resp.status_code == 404:
|
elif resp.status_code == 404:
|
||||||
@@ -538,3 +498,5 @@ def hf_download_with_resume(
|
|||||||
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||||
return None
|
return None
|
||||||
return model_dest
|
return model_dest
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,13 +3,6 @@ Initialization file for invokeai.backend.model_management
|
|||||||
"""
|
"""
|
||||||
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .lora import ModelPatcher, ONNXModelPatcher
|
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException, DuplicateModelException
|
||||||
from .models import (
|
|
||||||
BaseModelType,
|
|
||||||
ModelType,
|
|
||||||
SubModelType,
|
|
||||||
ModelVariantType,
|
|
||||||
ModelNotFoundException,
|
|
||||||
DuplicateModelException,
|
|
||||||
)
|
|
||||||
from .model_merge import ModelMerger, MergeInterpolationMethod
|
from .model_merge import ModelMerger, MergeInterpolationMethod
|
||||||
|
|
||||||
|
|||||||
@@ -56,14 +56,16 @@ from diffusers.schedulers import (
|
|||||||
)
|
)
|
||||||
from diffusers.utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available
|
from diffusers.utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available
|
||||||
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
||||||
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
|
||||||
|
LDMBertConfig, LDMBertModel
|
||||||
|
)
|
||||||
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
|
||||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
||||||
|
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig, MODEL_CORE
|
||||||
|
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
from .models import BaseModelType, ModelVariantType
|
from .models import BaseModelType, ModelVariantType
|
||||||
@@ -81,8 +83,7 @@ if is_accelerate_available():
|
|||||||
from accelerate.utils import set_module_tensor_to_device
|
from accelerate.utils import set_module_tensor_to_device
|
||||||
|
|
||||||
logger = InvokeAILogger.getLogger(__name__)
|
logger = InvokeAILogger.getLogger(__name__)
|
||||||
CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().models_path / "core/convert"
|
CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().root_path / MODEL_CORE / "convert"
|
||||||
|
|
||||||
|
|
||||||
def shave_segments(path, n_shave_prefix_segments=1):
|
def shave_segments(path, n_shave_prefix_segments=1):
|
||||||
"""
|
"""
|
||||||
@@ -508,7 +509,9 @@ def convert_ldm_unet_checkpoint(
|
|||||||
|
|
||||||
paths = renew_resnet_paths(resnets)
|
paths = renew_resnet_paths(resnets)
|
||||||
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
||||||
assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
|
assign_to_checkpoint(
|
||||||
|
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
||||||
|
)
|
||||||
|
|
||||||
if len(attentions):
|
if len(attentions):
|
||||||
paths = renew_attention_paths(attentions)
|
paths = renew_attention_paths(attentions)
|
||||||
@@ -793,7 +796,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
|
|||||||
|
|
||||||
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
|
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
|
||||||
if text_encoder is None:
|
if text_encoder is None:
|
||||||
config = CLIPTextConfig.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14")
|
config = CLIPTextConfig.from_pretrained(CONVERT_MODEL_ROOT / 'clip-vit-large-patch14')
|
||||||
|
|
||||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||||
with ctx():
|
with ctx():
|
||||||
@@ -1005,9 +1008,7 @@ def stable_unclip_image_encoder(original_config):
|
|||||||
elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
|
elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
|
||||||
feature_extractor = CLIPImageProcessor()
|
feature_extractor = CLIPImageProcessor()
|
||||||
# InvokeAI doesn't use CLIPVisionModelWithProjection so it isn't in the core - if this code is hit a download will occur
|
# InvokeAI doesn't use CLIPVisionModelWithProjection so it isn't in the core - if this code is hit a download will occur
|
||||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
image_encoder = CLIPVisionModelWithProjection.from_pretrained(CONVERT_MODEL_ROOT / "CLIP-ViT-H-14-laion2B-s32B-b79K")
|
||||||
CONVERT_MODEL_ROOT / "CLIP-ViT-H-14-laion2B-s32B-b79K"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
|
f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
|
||||||
@@ -1070,7 +1071,7 @@ def convert_controlnet_checkpoint(
|
|||||||
extract_ema,
|
extract_ema,
|
||||||
use_linear_projection=None,
|
use_linear_projection=None,
|
||||||
cross_attention_dim=None,
|
cross_attention_dim=None,
|
||||||
precision: Optional[torch.dtype] = None,
|
precision: torch.dtype=torch.float32,
|
||||||
):
|
):
|
||||||
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
|
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
|
||||||
ctrlnet_config["upcast_attention"] = upcast_attention
|
ctrlnet_config["upcast_attention"] = upcast_attention
|
||||||
@@ -1078,9 +1079,9 @@ def convert_controlnet_checkpoint(
|
|||||||
ctrlnet_config.pop("sample_size")
|
ctrlnet_config.pop("sample_size")
|
||||||
original_config = ctrlnet_config.copy()
|
original_config = ctrlnet_config.copy()
|
||||||
|
|
||||||
ctrlnet_config.pop("addition_embed_type")
|
ctrlnet_config.pop('addition_embed_type')
|
||||||
ctrlnet_config.pop("addition_time_embed_dim")
|
ctrlnet_config.pop('addition_time_embed_dim')
|
||||||
ctrlnet_config.pop("transformer_layers_per_block")
|
ctrlnet_config.pop('transformer_layers_per_block')
|
||||||
|
|
||||||
if use_linear_projection is not None:
|
if use_linear_projection is not None:
|
||||||
ctrlnet_config["use_linear_projection"] = use_linear_projection
|
ctrlnet_config["use_linear_projection"] = use_linear_projection
|
||||||
@@ -1110,7 +1111,7 @@ def convert_controlnet_checkpoint(
|
|||||||
|
|
||||||
return controlnet.to(precision)
|
return controlnet.to(precision)
|
||||||
|
|
||||||
|
# TO DO - PASS PRECISION
|
||||||
def download_from_original_stable_diffusion_ckpt(
|
def download_from_original_stable_diffusion_ckpt(
|
||||||
checkpoint_path: str,
|
checkpoint_path: str,
|
||||||
model_version: BaseModelType,
|
model_version: BaseModelType,
|
||||||
@@ -1120,7 +1121,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
prediction_type: str = None,
|
prediction_type: str = None,
|
||||||
model_type: str = None,
|
model_type: str = None,
|
||||||
extract_ema: bool = False,
|
extract_ema: bool = False,
|
||||||
precision: Optional[torch.dtype] = None,
|
precision: torch.dtype = torch.float32,
|
||||||
scheduler_type: str = "pndm",
|
scheduler_type: str = "pndm",
|
||||||
num_in_channels: Optional[int] = None,
|
num_in_channels: Optional[int] = None,
|
||||||
upcast_attention: Optional[bool] = None,
|
upcast_attention: Optional[bool] = None,
|
||||||
@@ -1193,8 +1194,6 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
|
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer)
|
||||||
to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if
|
to use. If this parameter is `None`, the function will load a new instance of [CLIPTokenizer] by itself, if
|
||||||
needed.
|
needed.
|
||||||
precision (`torch.dtype`, *optional*, defauts to `None`):
|
|
||||||
If not provided the precision will be set to the precision of the original file.
|
|
||||||
return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
|
return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -1251,11 +1250,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
while "state_dict" in checkpoint:
|
while "state_dict" in checkpoint:
|
||||||
checkpoint = checkpoint["state_dict"]
|
checkpoint = checkpoint["state_dict"]
|
||||||
|
|
||||||
logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}")
|
logger.debug(f'model_type = {model_type}; original_config_file = {original_config_file}')
|
||||||
|
|
||||||
precision_probing_key = "model.diffusion_model.input_blocks.0.0.bias"
|
|
||||||
logger.debug(f"original checkpoint precision == {checkpoint[precision_probing_key].dtype}")
|
|
||||||
precision = precision or checkpoint[precision_probing_key].dtype
|
|
||||||
|
|
||||||
if original_config_file is None:
|
if original_config_file is None:
|
||||||
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
@@ -1263,9 +1258,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
|
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
|
||||||
|
|
||||||
# model_type = "v1"
|
# model_type = "v1"
|
||||||
config_url = (
|
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
|
||||||
)
|
|
||||||
|
|
||||||
if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
|
if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
|
||||||
# model_type = "v2"
|
# model_type = "v2"
|
||||||
@@ -1284,13 +1277,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
original_config_file = BytesIO(requests.get(config_url).content)
|
original_config_file = BytesIO(requests.get(config_url).content)
|
||||||
|
|
||||||
original_config = OmegaConf.load(original_config_file)
|
original_config = OmegaConf.load(original_config_file)
|
||||||
if original_config["model"]["params"].get("use_ema") is not None:
|
if model_version == BaseModelType.StableDiffusion2 and original_config["model"]["params"]["parameterization"] == "v":
|
||||||
extract_ema = original_config["model"]["params"]["use_ema"]
|
|
||||||
|
|
||||||
if (
|
|
||||||
model_version == BaseModelType.StableDiffusion2
|
|
||||||
and original_config["model"]["params"].get("parameterization") == "v"
|
|
||||||
):
|
|
||||||
prediction_type = "v_prediction"
|
prediction_type = "v_prediction"
|
||||||
upcast_attention = True
|
upcast_attention = True
|
||||||
image_size = 768
|
image_size = 768
|
||||||
@@ -1449,13 +1436,13 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
config_kwargs = {"subfolder": "text_encoder"}
|
config_kwargs = {"subfolder": "text_encoder"}
|
||||||
|
|
||||||
text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
|
text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer")
|
tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / 'stable-diffusion-2-clip', subfolder="tokenizer")
|
||||||
|
|
||||||
if stable_unclip is None:
|
if stable_unclip is None:
|
||||||
if controlnet:
|
if controlnet:
|
||||||
pipe = pipeline_class(
|
pipe = pipeline_class(
|
||||||
vae=vae.to(precision),
|
vae=vae.to(precision),
|
||||||
text_encoder=text_model.to(precision),
|
text_encoder=text_model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
unet=unet.to(precision),
|
unet=unet.to(precision),
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
@@ -1467,7 +1454,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
else:
|
else:
|
||||||
pipe = pipeline_class(
|
pipe = pipeline_class(
|
||||||
vae=vae.to(precision),
|
vae=vae.to(precision),
|
||||||
text_encoder=text_model.to(precision),
|
text_encoder=text_model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
unet=unet.to(precision),
|
unet=unet.to(precision),
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
@@ -1492,8 +1479,8 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
image_noising_scheduler=image_noising_scheduler,
|
image_noising_scheduler=image_noising_scheduler,
|
||||||
# regular denoising components
|
# regular denoising components
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
text_encoder=text_model.to(precision),
|
text_encoder=text_model,
|
||||||
unet=unet.to(precision),
|
unet=unet,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
# vae
|
# vae
|
||||||
vae=vae,
|
vae=vae,
|
||||||
@@ -1504,9 +1491,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior")
|
prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior")
|
||||||
|
|
||||||
prior_tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14")
|
prior_tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14")
|
||||||
prior_text_model = CLIPTextModelWithProjection.from_pretrained(
|
prior_text_model = CLIPTextModelWithProjection.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14")
|
||||||
CONVERT_MODEL_ROOT / "clip-vit-large-patch14"
|
|
||||||
)
|
|
||||||
|
|
||||||
prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler")
|
prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler")
|
||||||
prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
|
prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config)
|
||||||
@@ -1548,19 +1533,11 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
text_model = convert_ldm_clip_checkpoint(
|
text_model = convert_ldm_clip_checkpoint(
|
||||||
checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
|
checkpoint, local_files_only=local_files_only, text_encoder=text_encoder
|
||||||
)
|
)
|
||||||
tokenizer = (
|
tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") if tokenizer is None else tokenizer
|
||||||
CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14")
|
|
||||||
if tokenizer is None
|
|
||||||
else tokenizer
|
|
||||||
)
|
|
||||||
|
|
||||||
if load_safety_checker:
|
if load_safety_checker:
|
||||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
safety_checker = StableDiffusionSafetyChecker.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker")
|
||||||
CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker"
|
feature_extractor = AutoFeatureExtractor.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker")
|
||||||
)
|
|
||||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
|
||||||
CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
safety_checker = None
|
safety_checker = None
|
||||||
feature_extractor = None
|
feature_extractor = None
|
||||||
@@ -1568,7 +1545,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
if controlnet:
|
if controlnet:
|
||||||
pipe = pipeline_class(
|
pipe = pipeline_class(
|
||||||
vae=vae.to(precision),
|
vae=vae.to(precision),
|
||||||
text_encoder=text_model.to(precision),
|
text_encoder=text_model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
unet=unet.to(precision),
|
unet=unet.to(precision),
|
||||||
controlnet=controlnet,
|
controlnet=controlnet,
|
||||||
@@ -1579,7 +1556,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
else:
|
else:
|
||||||
pipe = pipeline_class(
|
pipe = pipeline_class(
|
||||||
vae=vae.to(precision),
|
vae=vae.to(precision),
|
||||||
text_encoder=text_model.to(precision),
|
text_encoder=text_model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
unet=unet.to(precision),
|
unet=unet.to(precision),
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
@@ -1600,11 +1577,11 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
|
checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
pipe = StableDiffusionXLPipeline(
|
pipe = StableDiffusionXLPipeline (
|
||||||
vae=vae.to(precision),
|
vae=vae.to(precision),
|
||||||
text_encoder=text_encoder.to(precision),
|
text_encoder=text_encoder,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
text_encoder_2=text_encoder_2.to(precision),
|
text_encoder_2=text_encoder_2,
|
||||||
tokenizer_2=tokenizer_2,
|
tokenizer_2=tokenizer_2,
|
||||||
unet=unet.to(precision),
|
unet=unet.to(precision),
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
@@ -1647,7 +1624,7 @@ def download_controlnet_from_original_ckpt(
|
|||||||
original_config_file: str,
|
original_config_file: str,
|
||||||
image_size: int = 512,
|
image_size: int = 512,
|
||||||
extract_ema: bool = False,
|
extract_ema: bool = False,
|
||||||
precision: Optional[torch.dtype] = None,
|
precision: torch.dtype = torch.float32,
|
||||||
num_in_channels: Optional[int] = None,
|
num_in_channels: Optional[int] = None,
|
||||||
upcast_attention: Optional[bool] = None,
|
upcast_attention: Optional[bool] = None,
|
||||||
device: str = None,
|
device: str = None,
|
||||||
@@ -1688,12 +1665,6 @@ def download_controlnet_from_original_ckpt(
|
|||||||
while "state_dict" in checkpoint:
|
while "state_dict" in checkpoint:
|
||||||
checkpoint = checkpoint["state_dict"]
|
checkpoint = checkpoint["state_dict"]
|
||||||
|
|
||||||
# use original precision
|
|
||||||
precision_probing_key = "input_blocks.0.0.bias"
|
|
||||||
ckpt_precision = checkpoint[precision_probing_key].dtype
|
|
||||||
logger.debug(f"original controlnet precision = {ckpt_precision}")
|
|
||||||
precision = precision or ckpt_precision
|
|
||||||
|
|
||||||
original_config = OmegaConf.load(original_config_file)
|
original_config = OmegaConf.load(original_config_file)
|
||||||
|
|
||||||
if num_in_channels is not None:
|
if num_in_channels is not None:
|
||||||
@@ -1713,24 +1684,26 @@ def download_controlnet_from_original_ckpt(
|
|||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
return controlnet.to(precision)
|
return controlnet
|
||||||
|
|
||||||
|
|
||||||
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL:
|
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL:
|
||||||
vae_config = create_vae_diffusers_config(vae_config, image_size=image_size)
|
vae_config = create_vae_diffusers_config(
|
||||||
|
vae_config, image_size=image_size
|
||||||
|
)
|
||||||
|
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
||||||
|
checkpoint, vae_config
|
||||||
|
)
|
||||||
|
|
||||||
vae = AutoencoderKL(**vae_config)
|
vae = AutoencoderKL(**vae_config)
|
||||||
vae.load_state_dict(converted_vae_checkpoint)
|
vae.load_state_dict(converted_vae_checkpoint)
|
||||||
return vae
|
return vae
|
||||||
|
|
||||||
|
|
||||||
def convert_ckpt_to_diffusers(
|
def convert_ckpt_to_diffusers(
|
||||||
checkpoint_path: Union[str, Path],
|
checkpoint_path: Union[str, Path],
|
||||||
dump_path: Union[str, Path],
|
dump_path: Union[str, Path],
|
||||||
use_safetensors: bool = True,
|
use_safetensors: bool=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
|
Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
|
||||||
@@ -1744,11 +1717,10 @@ def convert_ckpt_to_diffusers(
|
|||||||
safe_serialization=use_safetensors and is_safetensors_available(),
|
safe_serialization=use_safetensors and is_safetensors_available(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def convert_controlnet_to_diffusers(
|
def convert_controlnet_to_diffusers(
|
||||||
checkpoint_path: Union[str, Path],
|
checkpoint_path: Union[str, Path],
|
||||||
dump_path: Union[str, Path],
|
dump_path: Union[str, Path],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Takes all the arguments of download_controlnet_from_original_ckpt(),
|
Takes all the arguments of download_controlnet_from_original_ckpt(),
|
||||||
|
|||||||
@@ -6,31 +6,19 @@ from typing import Optional, Dict, Tuple, Any, Union, List
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file
|
|
||||||
from torch.utils.hooks import RemovableHandle
|
|
||||||
|
|
||||||
from diffusers.models import UNet2DConditionModel
|
|
||||||
from transformers import CLIPTextModel
|
|
||||||
from onnx import numpy_helper
|
|
||||||
from onnxruntime import OrtValue
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from compel.embeddings_provider import BaseTextualInversionManager
|
from compel.embeddings_provider import BaseTextualInversionManager
|
||||||
from diffusers.models import UNet2DConditionModel
|
from diffusers.models import UNet2DConditionModel
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
# TODO: rename and split this file
|
|
||||||
|
|
||||||
|
|
||||||
class LoRALayerBase:
|
class LoRALayerBase:
|
||||||
# rank: Optional[int]
|
#rank: Optional[int]
|
||||||
# alpha: Optional[float]
|
#alpha: Optional[float]
|
||||||
# bias: Optional[torch.Tensor]
|
#bias: Optional[torch.Tensor]
|
||||||
# layer_key: str
|
#layer_key: str
|
||||||
|
|
||||||
# @property
|
#@property
|
||||||
# def scale(self):
|
#def scale(self):
|
||||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -43,7 +31,11 @@ class LoRALayerBase:
|
|||||||
else:
|
else:
|
||||||
self.alpha = None
|
self.alpha = None
|
||||||
|
|
||||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
if (
|
||||||
|
"bias_indices" in values
|
||||||
|
and "bias_values" in values
|
||||||
|
and "bias_size" in values
|
||||||
|
):
|
||||||
self.bias = torch.sparse_coo_tensor(
|
self.bias = torch.sparse_coo_tensor(
|
||||||
values["bias_indices"],
|
values["bias_indices"],
|
||||||
values["bias_values"],
|
values["bias_values"],
|
||||||
@@ -53,13 +45,13 @@ class LoRALayerBase:
|
|||||||
else:
|
else:
|
||||||
self.bias = None
|
self.bias = None
|
||||||
|
|
||||||
self.rank = None # set in layer implementation
|
self.rank = None # set in layer implementation
|
||||||
self.layer_key = layer_key
|
self.layer_key = layer_key
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
module: torch.nn.Module,
|
module: torch.nn.Module,
|
||||||
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
|
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
|
||||||
multiplier: float,
|
multiplier: float,
|
||||||
):
|
):
|
||||||
if type(module) == torch.nn.Conv2d:
|
if type(module) == torch.nn.Conv2d:
|
||||||
@@ -79,16 +71,12 @@ class LoRALayerBase:
|
|||||||
|
|
||||||
bias = self.bias if self.bias is not None else 0
|
bias = self.bias if self.bias is not None else 0
|
||||||
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||||
return (
|
return op(
|
||||||
op(
|
*input_h,
|
||||||
*input_h,
|
(weight + bias).view(module.weight.shape),
|
||||||
(weight + bias).view(module.weight.shape),
|
None,
|
||||||
None,
|
**extra_args,
|
||||||
**extra_args,
|
) * multiplier * scale
|
||||||
)
|
|
||||||
* multiplier
|
|
||||||
* scale
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_weight(self):
|
def get_weight(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -111,9 +99,9 @@ class LoRALayerBase:
|
|||||||
|
|
||||||
# TODO: find and debug lora/locon with bias
|
# TODO: find and debug lora/locon with bias
|
||||||
class LoRALayer(LoRALayerBase):
|
class LoRALayer(LoRALayerBase):
|
||||||
# up: torch.Tensor
|
#up: torch.Tensor
|
||||||
# mid: Optional[torch.Tensor]
|
#mid: Optional[torch.Tensor]
|
||||||
# down: torch.Tensor
|
#down: torch.Tensor
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -163,12 +151,12 @@ class LoRALayer(LoRALayerBase):
|
|||||||
|
|
||||||
|
|
||||||
class LoHALayer(LoRALayerBase):
|
class LoHALayer(LoRALayerBase):
|
||||||
# w1_a: torch.Tensor
|
#w1_a: torch.Tensor
|
||||||
# w1_b: torch.Tensor
|
#w1_b: torch.Tensor
|
||||||
# w2_a: torch.Tensor
|
#w2_a: torch.Tensor
|
||||||
# w2_b: torch.Tensor
|
#w2_b: torch.Tensor
|
||||||
# t1: Optional[torch.Tensor] = None
|
#t1: Optional[torch.Tensor] = None
|
||||||
# t2: Optional[torch.Tensor] = None
|
#t2: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -199,8 +187,12 @@ class LoHALayer(LoRALayerBase):
|
|||||||
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
rebuild1 = torch.einsum(
|
||||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
"i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a
|
||||||
|
)
|
||||||
|
rebuild2 = torch.einsum(
|
||||||
|
"i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a
|
||||||
|
)
|
||||||
weight = rebuild1 * rebuild2
|
weight = rebuild1 * rebuild2
|
||||||
|
|
||||||
return weight
|
return weight
|
||||||
@@ -231,13 +223,13 @@ class LoHALayer(LoRALayerBase):
|
|||||||
|
|
||||||
|
|
||||||
class LoKRLayer(LoRALayerBase):
|
class LoKRLayer(LoRALayerBase):
|
||||||
# w1: Optional[torch.Tensor] = None
|
#w1: Optional[torch.Tensor] = None
|
||||||
# w1_a: Optional[torch.Tensor] = None
|
#w1_a: Optional[torch.Tensor] = None
|
||||||
# w1_b: Optional[torch.Tensor] = None
|
#w1_b: Optional[torch.Tensor] = None
|
||||||
# w2: Optional[torch.Tensor] = None
|
#w2: Optional[torch.Tensor] = None
|
||||||
# w2_a: Optional[torch.Tensor] = None
|
#w2_a: Optional[torch.Tensor] = None
|
||||||
# w2_b: Optional[torch.Tensor] = None
|
#w2_b: Optional[torch.Tensor] = None
|
||||||
# t2: Optional[torch.Tensor] = None
|
#t2: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -274,7 +266,7 @@ class LoKRLayer(LoRALayerBase):
|
|||||||
elif "lokr_w2_b" in values:
|
elif "lokr_w2_b" in values:
|
||||||
self.rank = values["lokr_w2_b"].shape[0]
|
self.rank = values["lokr_w2_b"].shape[0]
|
||||||
else:
|
else:
|
||||||
self.rank = None # unscaled
|
self.rank = None # unscaled
|
||||||
|
|
||||||
def get_weight(self):
|
def get_weight(self):
|
||||||
w1 = self.w1
|
w1 = self.w1
|
||||||
@@ -286,7 +278,7 @@ class LoKRLayer(LoRALayerBase):
|
|||||||
if self.t2 is None:
|
if self.t2 is None:
|
||||||
w2 = self.w2_a @ self.w2_b
|
w2 = self.w2_a @ self.w2_b
|
||||||
else:
|
else:
|
||||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b)
|
||||||
|
|
||||||
if len(w2.shape) == 4:
|
if len(w2.shape) == 4:
|
||||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||||
@@ -325,7 +317,7 @@ class LoKRLayer(LoRALayerBase):
|
|||||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
class LoRAModel: # (torch.nn.Module):
|
class LoRAModel: #(torch.nn.Module):
|
||||||
_name: str
|
_name: str
|
||||||
layers: Dict[str, LoRALayer]
|
layers: Dict[str, LoRALayer]
|
||||||
_device: torch.device
|
_device: torch.device
|
||||||
@@ -388,7 +380,7 @@ class LoRAModel: # (torch.nn.Module):
|
|||||||
model = cls(
|
model = cls(
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
name=file_path.stem, # TODO:
|
name=file_path.stem, # TODO:
|
||||||
layers=dict(),
|
layers=dict(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -400,6 +392,7 @@ class LoRAModel: # (torch.nn.Module):
|
|||||||
state_dict = cls._group_state(state_dict)
|
state_dict = cls._group_state(state_dict)
|
||||||
|
|
||||||
for layer_key, values in state_dict.items():
|
for layer_key, values in state_dict.items():
|
||||||
|
|
||||||
# lora and locon
|
# lora and locon
|
||||||
if "lora_down.weight" in values:
|
if "lora_down.weight" in values:
|
||||||
layer = LoRALayer(layer_key, values)
|
layer = LoRALayer(layer_key, values)
|
||||||
@@ -414,7 +407,9 @@ class LoRAModel: # (torch.nn.Module):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
# TODO: diff/ia3/... format
|
# TODO: diff/ia3/... format
|
||||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
|
print(
|
||||||
|
f">> Encountered unknown lora layer module in {model.name}: {layer_key}"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# lower memory consumption by removing already parsed layer values
|
# lower memory consumption by removing already parsed layer values
|
||||||
@@ -448,10 +443,9 @@ with LoRAHelper.apply_lora_unet(unet, loras):
|
|||||||
# unmodified unet
|
# unmodified unet
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
# TODO: rename smth like ModelPatcher and add TI method?
|
# TODO: rename smth like ModelPatcher and add TI method?
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
|
||||||
assert "." not in lora_key
|
assert "." not in lora_key
|
||||||
@@ -461,7 +455,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
module = model
|
module = model
|
||||||
module_key = ""
|
module_key = ""
|
||||||
key_parts = lora_key[len(prefix) :].split("_")
|
key_parts = lora_key[len(prefix):].split('_')
|
||||||
|
|
||||||
submodule_name = key_parts.pop(0)
|
submodule_name = key_parts.pop(0)
|
||||||
|
|
||||||
@@ -483,6 +477,7 @@ class ModelPatcher:
|
|||||||
applied_loras: List[Tuple[LoRAModel, float]],
|
applied_loras: List[Tuple[LoRAModel, float]],
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
):
|
):
|
||||||
|
|
||||||
def lora_forward(module, input_h, output):
|
def lora_forward(module, input_h, output):
|
||||||
if len(applied_loras) == 0:
|
if len(applied_loras) == 0:
|
||||||
return output
|
return output
|
||||||
@@ -496,6 +491,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
return lora_forward
|
return lora_forward
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def apply_lora_unet(
|
def apply_lora_unet(
|
||||||
@@ -506,6 +502,7 @@ class ModelPatcher:
|
|||||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
with cls.apply_lora(unet, loras, "lora_unet_"):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def apply_lora_text_encoder(
|
def apply_lora_text_encoder(
|
||||||
@@ -516,6 +513,7 @@ class ModelPatcher:
|
|||||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def apply_lora(
|
def apply_lora(
|
||||||
@@ -528,7 +526,7 @@ class ModelPatcher:
|
|||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for lora, lora_weight in loras:
|
for lora, lora_weight in loras:
|
||||||
# assert lora.device.type == "cpu"
|
#assert lora.device.type == "cpu"
|
||||||
for layer_key, layer in lora.layers.items():
|
for layer_key, layer in lora.layers.items():
|
||||||
if not layer_key.startswith(prefix):
|
if not layer_key.startswith(prefix):
|
||||||
continue
|
continue
|
||||||
@@ -538,7 +536,7 @@ class ModelPatcher:
|
|||||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||||
|
|
||||||
# enable autocast to calc fp16 loras on cpu
|
# enable autocast to calc fp16 loras on cpu
|
||||||
# with torch.autocast(device_type="cpu"):
|
#with torch.autocast(device_type="cpu"):
|
||||||
layer.to(dtype=torch.float32)
|
layer.to(dtype=torch.float32)
|
||||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||||
layer_weight = layer.get_weight() * lora_weight * layer_scale
|
layer_weight = layer.get_weight() * lora_weight * layer_scale
|
||||||
@@ -549,13 +547,14 @@ class ModelPatcher:
|
|||||||
|
|
||||||
module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
|
module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
|
||||||
|
|
||||||
yield # wait for context manager exit
|
yield # wait for context manager exit
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for module_key, weight in original_weights.items():
|
for module_key, weight in original_weights.items():
|
||||||
model.get_submodule(module_key).weight.copy_(weight)
|
model.get_submodule(module_key).weight.copy_(weight)
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def apply_ti(
|
def apply_ti(
|
||||||
@@ -603,9 +602,7 @@ class ModelPatcher:
|
|||||||
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
|
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
|
||||||
)
|
)
|
||||||
|
|
||||||
model_embeddings.weight.data[token_id] = embedding.to(
|
model_embeddings.weight.data[token_id] = embedding.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||||
device=text_encoder.device, dtype=text_encoder.dtype
|
|
||||||
)
|
|
||||||
ti_tokens.append(token_id)
|
ti_tokens.append(token_id)
|
||||||
|
|
||||||
if len(ti_tokens) > 1:
|
if len(ti_tokens) > 1:
|
||||||
@@ -617,6 +614,7 @@ class ModelPatcher:
|
|||||||
if init_tokens_count and new_tokens_added:
|
if init_tokens_count and new_tokens_added:
|
||||||
text_encoder.resize_token_embeddings(init_tokens_count)
|
text_encoder.resize_token_embeddings(init_tokens_count)
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def apply_clip_skip(
|
def apply_clip_skip(
|
||||||
@@ -635,10 +633,9 @@ class ModelPatcher:
|
|||||||
while len(skipped_layers) > 0:
|
while len(skipped_layers) > 0:
|
||||||
text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
|
text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionModel:
|
class TextualInversionModel:
|
||||||
name: str
|
name: str
|
||||||
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
embedding: torch.Tensor # [n, 768]|[n, 1280]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_checkpoint(
|
def from_checkpoint(
|
||||||
@@ -650,11 +647,8 @@ class TextualInversionModel:
|
|||||||
if not isinstance(file_path, Path):
|
if not isinstance(file_path, Path):
|
||||||
file_path = Path(file_path)
|
file_path = Path(file_path)
|
||||||
|
|
||||||
result = cls() # TODO:
|
result = cls() # TODO:
|
||||||
if file_path.name == "learned_embeds.bin":
|
result.name = file_path.stem # TODO:
|
||||||
result.name = file_path.parent.name
|
|
||||||
else:
|
|
||||||
result.name = file_path.stem
|
|
||||||
|
|
||||||
if file_path.suffix == ".safetensors":
|
if file_path.suffix == ".safetensors":
|
||||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||||
@@ -665,9 +659,7 @@ class TextualInversionModel:
|
|||||||
# difference mostly in metadata
|
# difference mostly in metadata
|
||||||
if "string_to_param" in state_dict:
|
if "string_to_param" in state_dict:
|
||||||
if len(state_dict["string_to_param"]) > 1:
|
if len(state_dict["string_to_param"]) > 1:
|
||||||
print(
|
print(f"Warn: Embedding \"{file_path.name}\" contains multiple tokens, which is not supported. The first token will be used.")
|
||||||
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first token will be used.'
|
|
||||||
)
|
|
||||||
|
|
||||||
result.embedding = next(iter(state_dict["string_to_param"].values()))
|
result.embedding = next(iter(state_dict["string_to_param"].values()))
|
||||||
|
|
||||||
@@ -696,7 +688,10 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
self.pad_tokens = dict()
|
self.pad_tokens = dict()
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
|
def expand_textual_inversion_token_ids_if_necessary(
|
||||||
|
self, token_ids: list[int]
|
||||||
|
) -> list[int]:
|
||||||
|
|
||||||
if len(self.pad_tokens) == 0:
|
if len(self.pad_tokens) == 0:
|
||||||
return token_ids
|
return token_ids
|
||||||
|
|
||||||
@@ -713,185 +708,3 @@ class TextualInversionManager(BaseTextualInversionManager):
|
|||||||
|
|
||||||
return new_token_ids
|
return new_token_ids
|
||||||
|
|
||||||
|
|
||||||
class ONNXModelPatcher:
|
|
||||||
from .models.base import IAIOnnxRuntimeModel, OnnxRuntimeModel
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_lora_unet(
|
|
||||||
cls,
|
|
||||||
unet: OnnxRuntimeModel,
|
|
||||||
loras: List[Tuple[LoRAModel, float]],
|
|
||||||
):
|
|
||||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
|
||||||
yield
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_lora_text_encoder(
|
|
||||||
cls,
|
|
||||||
text_encoder: OnnxRuntimeModel,
|
|
||||||
loras: List[Tuple[LoRAModel, float]],
|
|
||||||
):
|
|
||||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
|
||||||
yield
|
|
||||||
|
|
||||||
# based on
|
|
||||||
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_lora(
|
|
||||||
cls,
|
|
||||||
model: IAIOnnxRuntimeModel,
|
|
||||||
loras: List[Tuple[LoraModel, float]],
|
|
||||||
prefix: str,
|
|
||||||
):
|
|
||||||
from .models.base import IAIOnnxRuntimeModel
|
|
||||||
|
|
||||||
if not isinstance(model, IAIOnnxRuntimeModel):
|
|
||||||
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
|
||||||
|
|
||||||
orig_weights = dict()
|
|
||||||
|
|
||||||
try:
|
|
||||||
blended_loras = dict()
|
|
||||||
|
|
||||||
for lora, lora_weight in loras:
|
|
||||||
for layer_key, layer in lora.layers.items():
|
|
||||||
if not layer_key.startswith(prefix):
|
|
||||||
continue
|
|
||||||
|
|
||||||
layer.to(dtype=torch.float32)
|
|
||||||
layer_key = layer_key.replace(prefix, "")
|
|
||||||
layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
|
|
||||||
if layer_key is blended_loras:
|
|
||||||
blended_loras[layer_key] += layer_weight
|
|
||||||
else:
|
|
||||||
blended_loras[layer_key] = layer_weight
|
|
||||||
|
|
||||||
node_names = dict()
|
|
||||||
for node in model.nodes.values():
|
|
||||||
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
|
|
||||||
|
|
||||||
for layer_key, lora_weight in blended_loras.items():
|
|
||||||
conv_key = layer_key + "_Conv"
|
|
||||||
gemm_key = layer_key + "_Gemm"
|
|
||||||
matmul_key = layer_key + "_MatMul"
|
|
||||||
|
|
||||||
if conv_key in node_names or gemm_key in node_names:
|
|
||||||
if conv_key in node_names:
|
|
||||||
conv_node = model.nodes[node_names[conv_key]]
|
|
||||||
else:
|
|
||||||
conv_node = model.nodes[node_names[gemm_key]]
|
|
||||||
|
|
||||||
weight_name = [n for n in conv_node.input if ".weight" in n][0]
|
|
||||||
orig_weight = model.tensors[weight_name]
|
|
||||||
|
|
||||||
if orig_weight.shape[-2:] == (1, 1):
|
|
||||||
if lora_weight.shape[-2:] == (1, 1):
|
|
||||||
new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
|
|
||||||
else:
|
|
||||||
new_weight = orig_weight.squeeze((3, 2)) + lora_weight
|
|
||||||
|
|
||||||
new_weight = np.expand_dims(new_weight, (2, 3))
|
|
||||||
else:
|
|
||||||
if orig_weight.shape != lora_weight.shape:
|
|
||||||
new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
|
|
||||||
else:
|
|
||||||
new_weight = orig_weight + lora_weight
|
|
||||||
|
|
||||||
orig_weights[weight_name] = orig_weight
|
|
||||||
model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)
|
|
||||||
|
|
||||||
elif matmul_key in node_names:
|
|
||||||
weight_node = model.nodes[node_names[matmul_key]]
|
|
||||||
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
|
|
||||||
|
|
||||||
orig_weight = model.tensors[matmul_name]
|
|
||||||
new_weight = orig_weight + lora_weight.transpose()
|
|
||||||
|
|
||||||
orig_weights[matmul_name] = orig_weight
|
|
||||||
model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# warn? err?
|
|
||||||
pass
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# restore original weights
|
|
||||||
for name, orig_weight in orig_weights.items():
|
|
||||||
model.tensors[name] = orig_weight
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_ti(
|
|
||||||
cls,
|
|
||||||
tokenizer: CLIPTokenizer,
|
|
||||||
text_encoder: IAIOnnxRuntimeModel,
|
|
||||||
ti_list: List[Any],
|
|
||||||
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
|
|
||||||
from .models.base import IAIOnnxRuntimeModel
|
|
||||||
|
|
||||||
if not isinstance(text_encoder, IAIOnnxRuntimeModel):
|
|
||||||
raise Exception("Only IAIOnnxRuntimeModel models supported")
|
|
||||||
|
|
||||||
orig_embeddings = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
ti_tokenizer = copy.deepcopy(tokenizer)
|
|
||||||
ti_manager = TextualInversionManager(ti_tokenizer)
|
|
||||||
|
|
||||||
def _get_trigger(ti, index):
|
|
||||||
trigger = ti.name
|
|
||||||
if index > 0:
|
|
||||||
trigger += f"-!pad-{i}"
|
|
||||||
return f"<{trigger}>"
|
|
||||||
|
|
||||||
# modify tokenizer
|
|
||||||
new_tokens_added = 0
|
|
||||||
for ti in ti_list:
|
|
||||||
for i in range(ti.embedding.shape[0]):
|
|
||||||
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
|
|
||||||
|
|
||||||
# modify text_encoder
|
|
||||||
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
|
|
||||||
|
|
||||||
embeddings = np.concatenate(
|
|
||||||
(np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
|
|
||||||
axis=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
for ti in ti_list:
|
|
||||||
ti_tokens = []
|
|
||||||
for i in range(ti.embedding.shape[0]):
|
|
||||||
embedding = ti.embedding[i].detach().numpy()
|
|
||||||
trigger = _get_trigger(ti, i)
|
|
||||||
|
|
||||||
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
|
|
||||||
if token_id == ti_tokenizer.unk_token_id:
|
|
||||||
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
|
|
||||||
|
|
||||||
if embeddings[token_id].shape != embedding.shape:
|
|
||||||
raise ValueError(
|
|
||||||
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
|
|
||||||
)
|
|
||||||
|
|
||||||
embeddings[token_id] = embedding
|
|
||||||
ti_tokens.append(token_id)
|
|
||||||
|
|
||||||
if len(ti_tokens) > 1:
|
|
||||||
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
|
|
||||||
|
|
||||||
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(
|
|
||||||
orig_embeddings.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
yield ti_tokenizer, ti_manager
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# restore
|
|
||||||
if orig_embeddings is not None:
|
|
||||||
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings
|
|
||||||
|
|||||||
@@ -37,22 +37,19 @@ from .models import BaseModelType, ModelType, SubModelType, ModelBase
|
|||||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||||
|
|
||||||
# amount of GPU memory to hold in reserve for use by generations (GB)
|
# amount of GPU memory to hold in reserve for use by generations (GB)
|
||||||
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
|
DEFAULT_MAX_VRAM_CACHE_SIZE= 2.75
|
||||||
|
|
||||||
# actual size of a gig
|
# actual size of a gig
|
||||||
GIG = 1073741824
|
GIG = 1073741824
|
||||||
|
|
||||||
|
|
||||||
class ModelLocker(object):
|
class ModelLocker(object):
|
||||||
"Forward declaration"
|
"Forward declaration"
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ModelCache(object):
|
class ModelCache(object):
|
||||||
"Forward declaration"
|
"Forward declaration"
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class _CacheRecord:
|
class _CacheRecord:
|
||||||
size: int
|
size: int
|
||||||
model: Any
|
model: Any
|
||||||
@@ -87,17 +84,17 @@ class _CacheRecord:
|
|||||||
class ModelCache(object):
|
class ModelCache(object):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
|
max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
|
||||||
max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
|
max_vram_cache_size: float=DEFAULT_MAX_VRAM_CACHE_SIZE,
|
||||||
execution_device: torch.device = torch.device("cuda"),
|
execution_device: torch.device=torch.device('cuda'),
|
||||||
storage_device: torch.device = torch.device("cpu"),
|
storage_device: torch.device=torch.device('cpu'),
|
||||||
precision: torch.dtype = torch.float16,
|
precision: torch.dtype=torch.float16,
|
||||||
sequential_offload: bool = False,
|
sequential_offload: bool=False,
|
||||||
lazy_offloading: bool = True,
|
lazy_offloading: bool=True,
|
||||||
sha_chunksize: int = 16777216,
|
sha_chunksize: int = 16777216,
|
||||||
logger: types.ModuleType = logger,
|
logger: types.ModuleType = logger
|
||||||
):
|
):
|
||||||
"""
|
'''
|
||||||
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
|
||||||
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
:param execution_device: Torch device to load active model into [torch.device('cuda')]
|
||||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||||
@@ -105,16 +102,16 @@ class ModelCache(object):
|
|||||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||||
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
|
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
|
||||||
"""
|
'''
|
||||||
self.model_infos: Dict[str, ModelBase] = dict()
|
self.model_infos: Dict[str, ModelBase] = dict()
|
||||||
# allow lazy offloading only when vram cache enabled
|
# allow lazy offloading only when vram cache enabled
|
||||||
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
|
||||||
self.precision: torch.dtype = precision
|
self.precision: torch.dtype=precision
|
||||||
self.max_cache_size: float = max_cache_size
|
self.max_cache_size: float=max_cache_size
|
||||||
self.max_vram_cache_size: float = max_vram_cache_size
|
self.max_vram_cache_size: float=max_vram_cache_size
|
||||||
self.execution_device: torch.device = execution_device
|
self.execution_device: torch.device=execution_device
|
||||||
self.storage_device: torch.device = storage_device
|
self.storage_device: torch.device=storage_device
|
||||||
self.sha_chunksize = sha_chunksize
|
self.sha_chunksize=sha_chunksize
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
self._cached_models = dict()
|
self._cached_models = dict()
|
||||||
@@ -127,6 +124,7 @@ class ModelCache(object):
|
|||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel_type: Optional[SubModelType] = None,
|
submodel_type: Optional[SubModelType] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
key = f"{model_path}:{base_model}:{model_type}"
|
key = f"{model_path}:{base_model}:{model_type}"
|
||||||
if submodel_type:
|
if submodel_type:
|
||||||
key += f":{submodel_type}"
|
key += f":{submodel_type}"
|
||||||
@@ -165,6 +163,7 @@ class ModelCache(object):
|
|||||||
submodel: Optional[SubModelType] = None,
|
submodel: Optional[SubModelType] = None,
|
||||||
gpu_load: bool = True,
|
gpu_load: bool = True,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
|
|
||||||
if not isinstance(model_path, Path):
|
if not isinstance(model_path, Path):
|
||||||
model_path = Path(model_path)
|
model_path = Path(model_path)
|
||||||
|
|
||||||
@@ -187,9 +186,7 @@ class ModelCache(object):
|
|||||||
# TODO: lock for no copies on simultaneous calls?
|
# TODO: lock for no copies on simultaneous calls?
|
||||||
cache_entry = self._cached_models.get(key, None)
|
cache_entry = self._cached_models.get(key, None)
|
||||||
if cache_entry is None:
|
if cache_entry is None:
|
||||||
self.logger.info(
|
self.logger.info(f'Loading model {model_path}, type {base_model}:{model_type}:{submodel}')
|
||||||
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# this will remove older cached models until
|
# this will remove older cached models until
|
||||||
# there is sufficient room to load the requested model
|
# there is sufficient room to load the requested model
|
||||||
@@ -199,7 +196,7 @@ class ModelCache(object):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
|
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
|
||||||
if mem_used := model_info.get_size(submodel):
|
if mem_used := model_info.get_size(submodel):
|
||||||
self.logger.debug(f"CPU RAM used for load: {(mem_used/GIG):.2f} GB")
|
self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
|
||||||
|
|
||||||
cache_entry = _CacheRecord(self, model, mem_used)
|
cache_entry = _CacheRecord(self, model, mem_used)
|
||||||
self._cached_models[key] = cache_entry
|
self._cached_models[key] = cache_entry
|
||||||
@@ -212,13 +209,13 @@ class ModelCache(object):
|
|||||||
|
|
||||||
class ModelLocker(object):
|
class ModelLocker(object):
|
||||||
def __init__(self, cache, key, model, gpu_load, size_needed):
|
def __init__(self, cache, key, model, gpu_load, size_needed):
|
||||||
"""
|
'''
|
||||||
:param cache: The model_cache object
|
:param cache: The model_cache object
|
||||||
:param key: The key of the model to lock in GPU
|
:param key: The key of the model to lock in GPU
|
||||||
:param model: The model to lock
|
:param model: The model to lock
|
||||||
:param gpu_load: True if load into gpu
|
:param gpu_load: True if load into gpu
|
||||||
:param size_needed: Size of the model to load
|
:param size_needed: Size of the model to load
|
||||||
"""
|
'''
|
||||||
self.gpu_load = gpu_load
|
self.gpu_load = gpu_load
|
||||||
self.cache = cache
|
self.cache = cache
|
||||||
self.key = key
|
self.key = key
|
||||||
@@ -227,7 +224,7 @@ class ModelCache(object):
|
|||||||
self.cache_entry = self.cache._cached_models[self.key]
|
self.cache_entry = self.cache._cached_models[self.key]
|
||||||
|
|
||||||
def __enter__(self) -> Any:
|
def __enter__(self) -> Any:
|
||||||
if not hasattr(self.model, "to"):
|
if not hasattr(self.model, 'to'):
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
# NOTE that the model has to have the to() method in order for this
|
# NOTE that the model has to have the to() method in order for this
|
||||||
@@ -237,21 +234,22 @@ class ModelCache(object):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if self.cache.lazy_offloading:
|
if self.cache.lazy_offloading:
|
||||||
self.cache._offload_unlocked_models(self.size_needed)
|
self.cache._offload_unlocked_models(self.size_needed)
|
||||||
|
|
||||||
if self.model.device != self.cache.execution_device:
|
if self.model.device != self.cache.execution_device:
|
||||||
self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}")
|
self.cache.logger.debug(f'Moving {self.key} into {self.cache.execution_device}')
|
||||||
with VRAMUsage() as mem:
|
with VRAMUsage() as mem:
|
||||||
self.model.to(self.cache.execution_device) # move into GPU
|
self.model.to(self.cache.execution_device) # move into GPU
|
||||||
self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
|
self.cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB')
|
||||||
|
|
||||||
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
|
self.cache.logger.debug(f'Locking {self.key} in {self.cache.execution_device}')
|
||||||
self.cache._print_cuda_stats()
|
self.cache._print_cuda_stats()
|
||||||
|
|
||||||
except:
|
except:
|
||||||
self.cache_entry.unlock()
|
self.cache_entry.unlock()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
# TODO: not fully understand
|
# TODO: not fully understand
|
||||||
# in the event that the caller wants the model in RAM, we
|
# in the event that the caller wants the model in RAM, we
|
||||||
# move it into CPU if it is in GPU and not locked
|
# move it into CPU if it is in GPU and not locked
|
||||||
@@ -261,7 +259,7 @@ class ModelCache(object):
|
|||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
def __exit__(self, type, value, traceback):
|
||||||
if not hasattr(self.model, "to"):
|
if not hasattr(self.model, 'to'):
|
||||||
return
|
return
|
||||||
|
|
||||||
self.cache_entry.unlock()
|
self.cache_entry.unlock()
|
||||||
@@ -279,11 +277,11 @@ class ModelCache(object):
|
|||||||
self,
|
self,
|
||||||
model_path: Union[str, Path],
|
model_path: Union[str, Path],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
'''
|
||||||
Given the HF repo id or path to a model on disk, returns a unique
|
Given the HF repo id or path to a model on disk, returns a unique
|
||||||
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
|
||||||
:param model_path: Path to model file/directory on disk.
|
:param model_path: Path to model file/directory on disk.
|
||||||
"""
|
'''
|
||||||
return self._local_model_hash(model_path)
|
return self._local_model_hash(model_path)
|
||||||
|
|
||||||
def cache_size(self) -> float:
|
def cache_size(self) -> float:
|
||||||
@@ -292,7 +290,7 @@ class ModelCache(object):
|
|||||||
return current_cache_size / GIG
|
return current_cache_size / GIG
|
||||||
|
|
||||||
def _has_cuda(self) -> bool:
|
def _has_cuda(self) -> bool:
|
||||||
return self.execution_device.type == "cuda"
|
return self.execution_device.type == 'cuda'
|
||||||
|
|
||||||
def _print_cuda_stats(self):
|
def _print_cuda_stats(self):
|
||||||
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
|
||||||
@@ -308,21 +306,18 @@ class ModelCache(object):
|
|||||||
if model_info.locked:
|
if model_info.locked:
|
||||||
locked_models += 1
|
locked_models += 1
|
||||||
|
|
||||||
self.logger.debug(
|
self.logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}")
|
||||||
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _make_cache_room(self, model_size):
|
def _make_cache_room(self, model_size):
|
||||||
# calculate how much memory this model will require
|
# calculate how much memory this model will require
|
||||||
# multiplier = 2 if self.precision==torch.float32 else 1
|
#multiplier = 2 if self.precision==torch.float32 else 1
|
||||||
bytes_needed = model_size
|
bytes_needed = model_size
|
||||||
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
|
||||||
current_size = sum([m.size for m in self._cached_models.values()])
|
current_size = sum([m.size for m in self._cached_models.values()])
|
||||||
|
|
||||||
if current_size + bytes_needed > maximum_size:
|
if current_size + bytes_needed > maximum_size:
|
||||||
self.logger.debug(
|
self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')
|
||||||
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
|
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
|
||||||
|
|
||||||
@@ -344,7 +339,7 @@ class ModelCache(object):
|
|||||||
with suppress(RuntimeError):
|
with suppress(RuntimeError):
|
||||||
referrer.clear()
|
referrer.clear()
|
||||||
cleared = True
|
cleared = True
|
||||||
# break
|
#break
|
||||||
|
|
||||||
# repeat if referrers changes(due to frame clear), else exit loop
|
# repeat if referrers changes(due to frame clear), else exit loop
|
||||||
if cleared:
|
if cleared:
|
||||||
@@ -353,18 +348,13 @@ class ModelCache(object):
|
|||||||
break
|
break
|
||||||
|
|
||||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||||
self.logger.debug(
|
self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")
|
||||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2 refs:
|
# 2 refs:
|
||||||
# 1 from cache_entry
|
# 1 from cache_entry
|
||||||
# 1 from getrefcount function
|
# 1 from getrefcount function
|
||||||
# 1 from onnx runtime object
|
if not cache_entry.locked and refs <= 2:
|
||||||
if not cache_entry.locked and refs <= 3 if "onnx" in model_key else 2:
|
self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
|
||||||
self.logger.debug(
|
|
||||||
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
|
||||||
)
|
|
||||||
current_size -= cache_entry.size
|
current_size -= cache_entry.size
|
||||||
del self._cache_stack[pos]
|
del self._cache_stack[pos]
|
||||||
del self._cached_models[model_key]
|
del self._cached_models[model_key]
|
||||||
@@ -378,20 +368,20 @@ class ModelCache(object):
|
|||||||
|
|
||||||
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
||||||
|
|
||||||
def _offload_unlocked_models(self, size_needed: int = 0):
|
def _offload_unlocked_models(self, size_needed: int=0):
|
||||||
reserved = self.max_vram_cache_size * GIG
|
reserved = self.max_vram_cache_size * GIG
|
||||||
vram_in_use = torch.cuda.memory_allocated()
|
vram_in_use = torch.cuda.memory_allocated()
|
||||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
|
self.logger.debug(f'{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB')
|
||||||
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x:x[1].size):
|
||||||
if vram_in_use <= reserved:
|
if vram_in_use <= reserved:
|
||||||
break
|
break
|
||||||
if not cache_entry.locked and cache_entry.loaded:
|
if not cache_entry.locked and cache_entry.loaded:
|
||||||
self.logger.debug(f"Offloading {model_key} from {self.execution_device} into {self.storage_device}")
|
self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
|
||||||
with VRAMUsage() as mem:
|
with VRAMUsage() as mem:
|
||||||
cache_entry.model.to(self.storage_device)
|
cache_entry.model.to(self.storage_device)
|
||||||
self.logger.debug(f"GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB")
|
self.logger.debug(f'GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB')
|
||||||
vram_in_use += mem.vram_used # note vram_used is negative
|
vram_in_use += mem.vram_used # note vram_used is negative
|
||||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
|
self.logger.debug(f'{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB')
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@@ -406,8 +396,10 @@ class ModelCache(object):
|
|||||||
hash = f.read()
|
hash = f.read()
|
||||||
return hash
|
return hash
|
||||||
|
|
||||||
self.logger.debug(f"computing hash of model {path.name}")
|
self.logger.debug(f'computing hash of model {path.name}')
|
||||||
for file in list(path.rglob("*.ckpt")) + list(path.rglob("*.safetensors")) + list(path.rglob("*.pth")):
|
for file in list(path.rglob("*.ckpt")) \
|
||||||
|
+ list(path.rglob("*.safetensors")) \
|
||||||
|
+ list(path.rglob("*.pth")):
|
||||||
with open(file, "rb") as f:
|
with open(file, "rb") as f:
|
||||||
while chunk := f.read(self.sha_chunksize):
|
while chunk := f.read(self.sha_chunksize):
|
||||||
sha.update(chunk)
|
sha.update(chunk)
|
||||||
@@ -416,7 +408,6 @@ class ModelCache(object):
|
|||||||
f.write(hash)
|
f.write(hash)
|
||||||
return hash
|
return hash
|
||||||
|
|
||||||
|
|
||||||
class VRAMUsage(object):
|
class VRAMUsage(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.vram = None
|
self.vram = None
|
||||||
|
|||||||
@@ -249,26 +249,20 @@ from invokeai.backend.util import CUDA_DEVICE, Chdir
|
|||||||
from .model_cache import ModelCache, ModelLocker
|
from .model_cache import ModelCache, ModelLocker
|
||||||
from .model_search import ModelSearch
|
from .model_search import ModelSearch
|
||||||
from .models import (
|
from .models import (
|
||||||
BaseModelType,
|
BaseModelType, ModelType, SubModelType,
|
||||||
ModelType,
|
ModelError, SchedulerPredictionType, MODEL_CLASSES,
|
||||||
SubModelType,
|
|
||||||
ModelError,
|
|
||||||
SchedulerPredictionType,
|
|
||||||
MODEL_CLASSES,
|
|
||||||
ModelConfigBase,
|
ModelConfigBase,
|
||||||
ModelNotFoundException,
|
ModelNotFoundException, InvalidModelException,
|
||||||
InvalidModelException,
|
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We are only starting to number the config file with release 3.
|
# We are only starting to number the config file with release 3.
|
||||||
# The config file version doesn't have to start at release version, but it will help
|
# The config file version doesn't have to start at release version, but it will help
|
||||||
# reduce confusion.
|
# reduce confusion.
|
||||||
CONFIG_FILE_VERSION = "3.0.0"
|
CONFIG_FILE_VERSION='3.0.0'
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelInfo:
|
class ModelInfo():
|
||||||
context: ModelLocker
|
context: ModelLocker
|
||||||
name: str
|
name: str
|
||||||
base_model: BaseModelType
|
base_model: BaseModelType
|
||||||
@@ -276,29 +270,25 @@ class ModelInfo:
|
|||||||
hash: str
|
hash: str
|
||||||
location: Union[Path, str]
|
location: Union[Path, str]
|
||||||
precision: torch.dtype
|
precision: torch.dtype
|
||||||
_cache: Optional[ModelCache] = None
|
_cache: ModelCache = None
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self.context.__enter__()
|
return self.context.__enter__()
|
||||||
|
|
||||||
def __exit__(self, *args, **kwargs):
|
def __exit__(self,*args, **kwargs):
|
||||||
self.context.__exit__(*args, **kwargs)
|
self.context.__exit__(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class AddModelResult(BaseModel):
|
class AddModelResult(BaseModel):
|
||||||
name: str = Field(description="The name of the model after installation")
|
name: str = Field(description="The name of the model after installation")
|
||||||
model_type: ModelType = Field(description="The type of model")
|
model_type: ModelType = Field(description="The type of model")
|
||||||
base_model: BaseModelType = Field(description="The base model")
|
base_model: BaseModelType = Field(description="The base model")
|
||||||
config: ModelConfigBase = Field(description="The configuration of the model")
|
config: ModelConfigBase = Field(description="The configuration of the model")
|
||||||
|
|
||||||
|
|
||||||
MAX_CACHE_SIZE = 6.0 # GB
|
MAX_CACHE_SIZE = 6.0 # GB
|
||||||
|
|
||||||
|
|
||||||
class ConfigMeta(BaseModel):
|
class ConfigMeta(BaseModel):
|
||||||
version: str
|
version: str
|
||||||
|
|
||||||
|
|
||||||
class ModelManager(object):
|
class ModelManager(object):
|
||||||
"""
|
"""
|
||||||
High-level interface to model management.
|
High-level interface to model management.
|
||||||
@@ -325,12 +315,12 @@ class ModelManager(object):
|
|||||||
if isinstance(config, (str, Path)):
|
if isinstance(config, (str, Path)):
|
||||||
self.config_path = Path(config)
|
self.config_path = Path(config)
|
||||||
if not self.config_path.exists():
|
if not self.config_path.exists():
|
||||||
logger.warning(f"The file {self.config_path} was not found. Initializing a new file")
|
logger.warning(f'The file {self.config_path} was not found. Initializing a new file')
|
||||||
self.initialize_model_config(self.config_path)
|
self.initialize_model_config(self.config_path)
|
||||||
config = OmegaConf.load(self.config_path)
|
config = OmegaConf.load(self.config_path)
|
||||||
|
|
||||||
elif not isinstance(config, DictConfig):
|
elif not isinstance(config, DictConfig):
|
||||||
raise ValueError("config argument must be an OmegaConf object, a Path or a string")
|
raise ValueError('config argument must be an OmegaConf object, a Path or a string')
|
||||||
|
|
||||||
self.config_meta = ConfigMeta(**config.pop("__metadata__"))
|
self.config_meta = ConfigMeta(**config.pop("__metadata__"))
|
||||||
# TODO: metadata not found
|
# TODO: metadata not found
|
||||||
@@ -340,11 +330,11 @@ class ModelManager(object):
|
|||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.cache = ModelCache(
|
self.cache = ModelCache(
|
||||||
max_cache_size=max_cache_size,
|
max_cache_size=max_cache_size,
|
||||||
max_vram_cache_size=self.app_config.max_vram_cache_size,
|
max_vram_cache_size = self.app_config.max_vram_cache_size,
|
||||||
execution_device=device_type,
|
execution_device = device_type,
|
||||||
precision=precision,
|
precision = precision,
|
||||||
sequential_offload=sequential_offload,
|
sequential_offload = sequential_offload,
|
||||||
logger=logger,
|
logger = logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._read_models(config)
|
self._read_models(config)
|
||||||
@@ -358,7 +348,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
self.models = dict()
|
self.models = dict()
|
||||||
for model_key, model_config in config.items():
|
for model_key, model_config in config.items():
|
||||||
if model_key.startswith("_"):
|
if model_key.startswith('_'):
|
||||||
continue
|
continue
|
||||||
model_name, base_model, model_type = self.parse_key(model_key)
|
model_name, base_model, model_type = self.parse_key(model_key)
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
@@ -401,15 +391,11 @@ class ModelManager(object):
|
|||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
) -> str:
|
) -> str:
|
||||||
# In 3.11, the behavior of (str,enum) when interpolated into a
|
return f"{base_model}/{model_type}/{model_name}"
|
||||||
# string has changed. The next two lines are defensive.
|
|
||||||
base_model = BaseModelType(base_model)
|
|
||||||
model_type = ModelType(model_type)
|
|
||||||
return f"{base_model.value}/{model_type.value}/{model_name}"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]:
|
def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]:
|
||||||
base_model_str, model_type_str, model_name = model_key.split("/", 2)
|
base_model_str, model_type_str, model_name = model_key.split('/', 2)
|
||||||
try:
|
try:
|
||||||
model_type = ModelType(model_type_str)
|
model_type = ModelType(model_type_str)
|
||||||
except:
|
except:
|
||||||
@@ -423,21 +409,25 @@ class ModelManager(object):
|
|||||||
return (model_name, base_model, model_type)
|
return (model_name, base_model, model_type)
|
||||||
|
|
||||||
def _get_model_cache_path(self, model_path):
|
def _get_model_cache_path(self, model_path):
|
||||||
return self.resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest())
|
return self.app_config.models_path / ".cache" / hashlib.md5(str(model_path).encode()).hexdigest()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize_model_config(cls, config_path: Path):
|
def initialize_model_config(cls, config_path: Path):
|
||||||
"""Create empty config file"""
|
"""Create empty config file"""
|
||||||
with open(config_path, "w") as yaml_file:
|
with open(config_path,'w') as yaml_file:
|
||||||
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
yaml_file.write(yaml.dump({'__metadata__':
|
||||||
|
{'version':'3.0.0'}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def get_model(
|
def get_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
submodel_type: Optional[SubModelType] = None,
|
submodel_type: Optional[SubModelType] = None
|
||||||
) -> ModelInfo:
|
)->ModelInfo:
|
||||||
"""Given a model named identified in models.yaml, return
|
"""Given a model named identified in models.yaml, return
|
||||||
an ModelInfo object describing it.
|
an ModelInfo object describing it.
|
||||||
:param model_name: symbolic name of the model in models.yaml
|
:param model_name: symbolic name of the model in models.yaml
|
||||||
@@ -456,12 +446,12 @@ class ModelManager(object):
|
|||||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
raise ModelNotFoundException(f"Model not found - {model_key}")
|
||||||
|
|
||||||
model_config = self.models[model_key]
|
model_config = self.models[model_key]
|
||||||
model_path = self.resolve_model_path(model_config.path)
|
model_path = self.app_config.root_path / model_config.path
|
||||||
|
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
if model_class.save_to_config:
|
if model_class.save_to_config:
|
||||||
self.models[model_key].error = ModelError.NotFound
|
self.models[model_key].error = ModelError.NotFound
|
||||||
raise Exception(f'Files for model "{model_key}" not found')
|
raise Exception(f"Files for model \"{model_key}\" not found")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.models.pop(model_key, None)
|
self.models.pop(model_key, None)
|
||||||
@@ -472,7 +462,7 @@ class ModelManager(object):
|
|||||||
if submodel_type is not None and hasattr(model_config, submodel_type):
|
if submodel_type is not None and hasattr(model_config, submodel_type):
|
||||||
override_path = getattr(model_config, submodel_type)
|
override_path = getattr(model_config, submodel_type)
|
||||||
if override_path:
|
if override_path:
|
||||||
model_path = self.resolve_path(override_path)
|
model_path = self.app_config.root_path / override_path
|
||||||
model_type = submodel_type
|
model_type = submodel_type
|
||||||
submodel_type = None
|
submodel_type = None
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
@@ -483,7 +473,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
model_path = model_class.convert_if_required(
|
model_path = model_class.convert_if_required(
|
||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_path=str(model_path), # TODO: refactor str/Path types logic
|
model_path=str(model_path), # TODO: refactor str/Path types logic
|
||||||
output_path=dst_convert_path,
|
output_path=dst_convert_path,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
)
|
)
|
||||||
@@ -500,17 +490,17 @@ class ModelManager(object):
|
|||||||
self.cache_keys[model_key] = set()
|
self.cache_keys[model_key] = set()
|
||||||
self.cache_keys[model_key].add(model_context.key)
|
self.cache_keys[model_key].add(model_context.key)
|
||||||
|
|
||||||
model_hash = "<NO_HASH>" # TODO:
|
model_hash = "<NO_HASH>" # TODO:
|
||||||
|
|
||||||
return ModelInfo(
|
return ModelInfo(
|
||||||
context=model_context,
|
context = model_context,
|
||||||
name=model_name,
|
name = model_name,
|
||||||
base_model=base_model,
|
base_model = base_model,
|
||||||
type=submodel_type or model_type,
|
type = submodel_type or model_type,
|
||||||
hash=model_hash,
|
hash = model_hash,
|
||||||
location=model_path, # TODO:
|
location = model_path, # TODO:
|
||||||
precision=self.cache.precision,
|
precision = self.cache.precision,
|
||||||
_cache=self.cache,
|
_cache = self.cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
def model_info(
|
def model_info(
|
||||||
@@ -526,7 +516,7 @@ class ModelManager(object):
|
|||||||
if model_key in self.models:
|
if model_key in self.models:
|
||||||
return self.models[model_key].dict(exclude_defaults=True)
|
return self.models[model_key].dict(exclude_defaults=True)
|
||||||
else:
|
else:
|
||||||
return None # TODO: None or empty dict on not found
|
return None # TODO: None or empty dict on not found
|
||||||
|
|
||||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||||
"""
|
"""
|
||||||
@@ -536,16 +526,16 @@ class ModelManager(object):
|
|||||||
return [(self.parse_key(x)) for x in self.models.keys()]
|
return [(self.parse_key(x)) for x in self.models.keys()]
|
||||||
|
|
||||||
def list_model(
|
def list_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Returns a dict describing one installed model, using
|
Returns a dict describing one installed model, using
|
||||||
the combined format of the list_models() method.
|
the combined format of the list_models() method.
|
||||||
"""
|
"""
|
||||||
models = self.list_models(base_model, model_type, model_name)
|
models = self.list_models(base_model,model_type,model_name)
|
||||||
return models[0] if models else None
|
return models[0] if models else None
|
||||||
|
|
||||||
def list_models(
|
def list_models(
|
||||||
@@ -558,17 +548,13 @@ class ModelManager(object):
|
|||||||
Return a list of models.
|
Return a list of models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_keys = (
|
model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold)
|
||||||
[self.create_key(model_name, base_model, model_type)]
|
|
||||||
if model_name
|
|
||||||
else sorted(self.models, key=str.casefold)
|
|
||||||
)
|
|
||||||
models = []
|
models = []
|
||||||
for model_key in model_keys:
|
for model_key in model_keys:
|
||||||
model_config = self.models.get(model_key)
|
model_config = self.models.get(model_key)
|
||||||
if not model_config:
|
if not model_config:
|
||||||
self.logger.error(f"Unknown model {model_name}")
|
self.logger.error(f'Unknown model {model_name}')
|
||||||
raise ModelNotFoundException(f"Unknown model {model_name}")
|
raise ModelNotFoundException(f'Unknown model {model_name}')
|
||||||
|
|
||||||
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||||
if base_model is not None and cur_base_model != base_model:
|
if base_model is not None and cur_base_model != base_model:
|
||||||
@@ -585,8 +571,8 @@ class ModelManager(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# expose paths as absolute to help web UI
|
# expose paths as absolute to help web UI
|
||||||
if path := model_dict.get("path"):
|
if path := model_dict.get('path'):
|
||||||
model_dict["path"] = str(self.resolve_model_path(path))
|
model_dict['path'] = str(self.app_config.root_path / path)
|
||||||
models.append(model_dict)
|
models.append(model_dict)
|
||||||
|
|
||||||
return models
|
return models
|
||||||
@@ -623,7 +609,7 @@ class ModelManager(object):
|
|||||||
self.cache.uncache_model(cache_id)
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
# if model inside invoke models folder - delete files
|
# if model inside invoke models folder - delete files
|
||||||
model_path = self.resolve_model_path(model_cfg.path)
|
model_path = self.app_config.root_path / model_cfg.path
|
||||||
cache_path = self._get_model_cache_path(model_path)
|
cache_path = self._get_model_cache_path(model_path)
|
||||||
if cache_path.exists():
|
if cache_path.exists():
|
||||||
rmtree(str(cache_path))
|
rmtree(str(cache_path))
|
||||||
@@ -654,15 +640,16 @@ class ModelManager(object):
|
|||||||
The returned dict has the same format as the dict returned by
|
The returned dict has the same format as the dict returned by
|
||||||
model_info().
|
model_info().
|
||||||
"""
|
"""
|
||||||
# relativize paths as they go in - this makes it easier to move the models directory around
|
# relativize paths as they go in - this makes it easier to move the root directory around
|
||||||
if path := model_attributes.get("path"):
|
if path := model_attributes.get('path'):
|
||||||
model_attributes["path"] = str(self.relative_model_path(Path(path)))
|
if Path(path).is_relative_to(self.app_config.root_path):
|
||||||
|
model_attributes['path'] = str(Path(path).relative_to(self.app_config.root_path))
|
||||||
|
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
model_config = model_class.create_config(**model_attributes)
|
model_config = model_class.create_config(**model_attributes)
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
|
|
||||||
if model_key in self.models and not clobber:
|
if model_key in self.models and not clobber:
|
||||||
raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
|
raise Exception(f'Attempt to overwrite existing model definition "{model_key}"')
|
||||||
|
|
||||||
old_model = self.models.pop(model_key, None)
|
old_model = self.models.pop(model_key, None)
|
||||||
@@ -670,7 +657,7 @@ class ModelManager(object):
|
|||||||
# TODO: if path changed and old_model.path inside models folder should we delete this too?
|
# TODO: if path changed and old_model.path inside models folder should we delete this too?
|
||||||
|
|
||||||
# remove conversion cache as config changed
|
# remove conversion cache as config changed
|
||||||
old_model_path = self.resolve_model_path(old_model.path)
|
old_model_path = self.app_config.root_path / old_model.path
|
||||||
old_model_cache = self._get_model_cache_path(old_model_path)
|
old_model_cache = self._get_model_cache_path(old_model_path)
|
||||||
if old_model_cache.exists():
|
if old_model_cache.exists():
|
||||||
if old_model_cache.is_dir():
|
if old_model_cache.is_dir():
|
||||||
@@ -688,23 +675,23 @@ class ModelManager(object):
|
|||||||
self.commit()
|
self.commit()
|
||||||
|
|
||||||
return AddModelResult(
|
return AddModelResult(
|
||||||
name=model_name,
|
name = model_name,
|
||||||
model_type=model_type,
|
model_type = model_type,
|
||||||
base_model=base_model,
|
base_model = base_model,
|
||||||
config=model_config,
|
config = model_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
def rename_model(
|
def rename_model(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: ModelType,
|
model_type: ModelType,
|
||||||
new_name: str = None,
|
new_name: str = None,
|
||||||
new_base: BaseModelType = None,
|
new_base: BaseModelType = None,
|
||||||
):
|
):
|
||||||
"""
|
'''
|
||||||
Rename or rebase a model.
|
Rename or rebase a model.
|
||||||
"""
|
'''
|
||||||
if new_name is None and new_base is None:
|
if new_name is None and new_base is None:
|
||||||
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
|
self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.")
|
||||||
return
|
return
|
||||||
@@ -714,7 +701,7 @@ class ModelManager(object):
|
|||||||
if not model_cfg:
|
if not model_cfg:
|
||||||
raise ModelNotFoundException(f"Unknown model: {model_key}")
|
raise ModelNotFoundException(f"Unknown model: {model_key}")
|
||||||
|
|
||||||
old_path = self.resolve_model_path(model_cfg.path)
|
old_path = self.app_config.root_path / model_cfg.path
|
||||||
new_name = new_name or model_name
|
new_name = new_name or model_name
|
||||||
new_base = new_base or base_model
|
new_base = new_base or base_model
|
||||||
new_key = self.create_key(new_name, new_base, model_type)
|
new_key = self.create_key(new_name, new_base, model_type)
|
||||||
@@ -723,15 +710,9 @@ class ModelManager(object):
|
|||||||
|
|
||||||
# if this is a model file/directory that we manage ourselves, we need to move it
|
# if this is a model file/directory that we manage ourselves, we need to move it
|
||||||
if old_path.is_relative_to(self.app_config.models_path):
|
if old_path.is_relative_to(self.app_config.models_path):
|
||||||
new_path = self.resolve_model_path(
|
new_path = self.app_config.root_path / 'models' / BaseModelType(new_base).value / ModelType(model_type).value / new_name
|
||||||
Path(
|
|
||||||
BaseModelType(new_base).value,
|
|
||||||
ModelType(model_type).value,
|
|
||||||
new_name,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
move(old_path, new_path)
|
move(old_path, new_path)
|
||||||
model_cfg.path = str(new_path.relative_to(self.app_config.models_path))
|
model_cfg.path = str(new_path.relative_to(self.app_config.root_path))
|
||||||
|
|
||||||
# clean up caches
|
# clean up caches
|
||||||
old_model_cache = self._get_model_cache_path(old_path)
|
old_model_cache = self._get_model_cache_path(old_path)
|
||||||
@@ -745,18 +726,18 @@ class ModelManager(object):
|
|||||||
for cache_id in cache_ids:
|
for cache_id in cache_ids:
|
||||||
self.cache.uncache_model(cache_id)
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
self.models.pop(model_key, None) # delete
|
self.models.pop(model_key, None) # delete
|
||||||
self.models[new_key] = model_cfg
|
self.models[new_key] = model_cfg
|
||||||
self.commit()
|
self.commit()
|
||||||
|
|
||||||
def convert_model(
|
def convert_model (
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
model_type: Union[ModelType.Main, ModelType.Vae],
|
model_type: Union[ModelType.Main,ModelType.Vae],
|
||||||
dest_directory: Optional[Path] = None,
|
dest_directory: Optional[Path]=None,
|
||||||
) -> AddModelResult:
|
) -> AddModelResult:
|
||||||
"""
|
'''
|
||||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||||
version and deleting the original checkpoint file if it is in the models
|
version and deleting the original checkpoint file if it is in the models
|
||||||
directory.
|
directory.
|
||||||
@@ -765,7 +746,7 @@ class ModelManager(object):
|
|||||||
:param model_type: Type of model ['vae' or 'main']
|
:param model_type: Type of model ['vae' or 'main']
|
||||||
|
|
||||||
This will raise a ValueError unless the model is a checkpoint.
|
This will raise a ValueError unless the model is a checkpoint.
|
||||||
"""
|
'''
|
||||||
info = self.model_info(model_name, base_model, model_type)
|
info = self.model_info(model_name, base_model, model_type)
|
||||||
if info["model_format"] != "checkpoint":
|
if info["model_format"] != "checkpoint":
|
||||||
raise ValueError(f"not a checkpoint format model: {model_name}")
|
raise ValueError(f"not a checkpoint format model: {model_name}")
|
||||||
@@ -773,32 +754,27 @@ class ModelManager(object):
|
|||||||
# We are taking advantage of a side effect of get_model() that converts check points
|
# We are taking advantage of a side effect of get_model() that converts check points
|
||||||
# into cached diffusers directories stored at `location`. It doesn't matter
|
# into cached diffusers directories stored at `location`. It doesn't matter
|
||||||
# what submodeltype we request here, so we get the smallest.
|
# what submodeltype we request here, so we get the smallest.
|
||||||
submodel = {"submodel_type": SubModelType.Scheduler} if model_type == ModelType.Main else {}
|
submodel = {"submodel_type": SubModelType.Tokenizer} if model_type==ModelType.Main else {}
|
||||||
model = self.get_model(
|
model = self.get_model(model_name,
|
||||||
model_name,
|
base_model,
|
||||||
base_model,
|
model_type,
|
||||||
model_type,
|
**submodel,
|
||||||
**submodel,
|
)
|
||||||
)
|
checkpoint_path = self.app_config.root_path / info["path"]
|
||||||
checkpoint_path = self.resolve_model_path(info["path"])
|
old_diffusers_path = self.app_config.models_path / model.location
|
||||||
old_diffusers_path = self.resolve_model_path(model.location)
|
new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name
|
||||||
new_diffusers_path = (
|
|
||||||
dest_directory or self.app_config.models_path / base_model.value / model_type.value
|
|
||||||
) / model_name
|
|
||||||
if new_diffusers_path.exists():
|
if new_diffusers_path.exists():
|
||||||
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
raise ValueError(f"A diffusers model already exists at {new_diffusers_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
move(old_diffusers_path, new_diffusers_path)
|
move(old_diffusers_path,new_diffusers_path)
|
||||||
info["model_format"] = "diffusers"
|
info["model_format"] = "diffusers"
|
||||||
info["path"] = (
|
info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path))
|
||||||
str(new_diffusers_path)
|
info.pop('config')
|
||||||
if dest_directory
|
|
||||||
else str(new_diffusers_path.relative_to(self.app_config.models_path))
|
|
||||||
)
|
|
||||||
info.pop("config")
|
|
||||||
|
|
||||||
result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True)
|
result = self.add_model(model_name, base_model, model_type,
|
||||||
|
model_attributes = info,
|
||||||
|
clobber=True)
|
||||||
except:
|
except:
|
||||||
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
|
# something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error!
|
||||||
rmtree(new_diffusers_path)
|
rmtree(new_diffusers_path)
|
||||||
@@ -809,15 +785,6 @@ class ModelManager(object):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def resolve_model_path(self, path: Union[Path, str]) -> Path:
|
|
||||||
"""return relative paths based on configured models_path"""
|
|
||||||
return self.app_config.models_path / path
|
|
||||||
|
|
||||||
def relative_model_path(self, model_path: Path) -> Path:
|
|
||||||
if model_path.is_relative_to(self.app_config.models_path):
|
|
||||||
model_path = model_path.relative_to(self.app_config.models_path)
|
|
||||||
return model_path
|
|
||||||
|
|
||||||
def search_models(self, search_folder):
|
def search_models(self, search_folder):
|
||||||
self.logger.info(f"Finding Models In: {search_folder}")
|
self.logger.info(f"Finding Models In: {search_folder}")
|
||||||
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
models_folder_ckpt = Path(search_folder).glob("**/*.ckpt")
|
||||||
@@ -831,12 +798,15 @@ class ModelManager(object):
|
|||||||
found_models = []
|
found_models = []
|
||||||
for file in files:
|
for file in files:
|
||||||
location = str(file.resolve()).replace("\\", "/")
|
location = str(file.resolve()).replace("\\", "/")
|
||||||
if "model.safetensors" not in location and "diffusion_pytorch_model.safetensors" not in location:
|
if (
|
||||||
|
"model.safetensors" not in location
|
||||||
|
and "diffusion_pytorch_model.safetensors" not in location
|
||||||
|
):
|
||||||
found_models.append({"name": file.stem, "location": location})
|
found_models.append({"name": file.stem, "location": location})
|
||||||
|
|
||||||
return search_folder, found_models
|
return search_folder, found_models
|
||||||
|
|
||||||
def commit(self, conf_file: Path = None) -> None:
|
def commit(self, conf_file: Path=None) -> None:
|
||||||
"""
|
"""
|
||||||
Write current configuration out to the indicated file.
|
Write current configuration out to the indicated file.
|
||||||
"""
|
"""
|
||||||
@@ -854,7 +824,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
yaml_str = OmegaConf.to_yaml(data_to_save)
|
yaml_str = OmegaConf.to_yaml(data_to_save)
|
||||||
config_file_path = conf_file or self.config_path
|
config_file_path = conf_file or self.config_path
|
||||||
assert config_file_path is not None, "no config file path to write to"
|
assert config_file_path is not None,'no config file path to write to'
|
||||||
config_file_path = self.app_config.root_path / config_file_path
|
config_file_path = self.app_config.root_path / config_file_path
|
||||||
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
||||||
try:
|
try:
|
||||||
@@ -887,21 +857,15 @@ class ModelManager(object):
|
|||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
model_type: Optional[ModelType] = None,
|
model_type: Optional[ModelType] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
loaded_files = set()
|
loaded_files = set()
|
||||||
new_models_found = False
|
new_models_found = False
|
||||||
|
|
||||||
self.logger.info(f"Scanning {self.app_config.models_path} for new models")
|
self.logger.info(f'Scanning {self.app_config.models_path} for new models')
|
||||||
with Chdir(self.app_config.models_path):
|
with Chdir(self.app_config.root_path):
|
||||||
for model_key, model_config in list(self.models.items()):
|
for model_key, model_config in list(self.models.items()):
|
||||||
model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
|
||||||
|
model_path = self.app_config.root_path.absolute() / model_config.path
|
||||||
# Patch for relative path bug in older models.yaml - paths should not
|
|
||||||
# be starting with a hard-coded 'models'. This will also fix up
|
|
||||||
# models.yaml when committed.
|
|
||||||
if model_config.path.startswith("models"):
|
|
||||||
model_config.path = str(Path(*Path(model_config.path).parts[1:]))
|
|
||||||
|
|
||||||
model_path = self.resolve_model_path(model_config.path).absolute()
|
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||||
if model_class.save_to_config:
|
if model_class.save_to_config:
|
||||||
@@ -920,13 +884,13 @@ class ModelManager(object):
|
|||||||
if model_type is not None and cur_model_type != model_type:
|
if model_type is not None and cur_model_type != model_type:
|
||||||
continue
|
continue
|
||||||
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
model_class = MODEL_CLASSES[cur_base_model][cur_model_type]
|
||||||
models_dir = self.resolve_model_path(Path(cur_base_model.value, cur_model_type.value))
|
models_dir = self.app_config.models_path / cur_base_model.value / cur_model_type.value
|
||||||
|
|
||||||
if not models_dir.exists():
|
if not models_dir.exists():
|
||||||
continue # TODO: or create all folders?
|
continue # TODO: or create all folders?
|
||||||
|
|
||||||
for model_path in models_dir.iterdir():
|
for model_path in models_dir.iterdir():
|
||||||
if model_path not in loaded_files: # TODO: check
|
if model_path not in loaded_files: # TODO: check
|
||||||
model_name = model_path.name if model_path.is_dir() else model_path.stem
|
model_name = model_path.name if model_path.is_dir() else model_path.stem
|
||||||
model_key = self.create_key(model_name, cur_base_model, cur_model_type)
|
model_key = self.create_key(model_name, cur_base_model, cur_model_type)
|
||||||
|
|
||||||
@@ -934,7 +898,9 @@ class ModelManager(object):
|
|||||||
if model_key in self.models:
|
if model_key in self.models:
|
||||||
raise DuplicateModelException(f"Model with key {model_key} added twice")
|
raise DuplicateModelException(f"Model with key {model_key} added twice")
|
||||||
|
|
||||||
model_path = self.relative_model_path(model_path)
|
if model_path.is_relative_to(self.app_config.root_path):
|
||||||
|
model_path = model_path.relative_to(self.app_config.root_path)
|
||||||
|
|
||||||
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
model_config: ModelConfigBase = model_class.probe_config(str(model_path))
|
||||||
self.models[model_key] = model_config
|
self.models[model_key] = model_config
|
||||||
new_models_found = True
|
new_models_found = True
|
||||||
@@ -945,14 +911,16 @@ class ModelManager(object):
|
|||||||
except NotImplementedError as e:
|
except NotImplementedError as e:
|
||||||
self.logger.warning(e)
|
self.logger.warning(e)
|
||||||
|
|
||||||
imported_models = self.scan_autoimport_directory()
|
imported_models = self.autoimport()
|
||||||
|
|
||||||
if (new_models_found or imported_models) and self.config_path:
|
if (new_models_found or imported_models) and self.config_path:
|
||||||
self.commit()
|
self.commit()
|
||||||
|
|
||||||
def scan_autoimport_directory(self) -> Dict[str, AddModelResult]:
|
|
||||||
"""
|
def autoimport(self)->Dict[str, AddModelResult]:
|
||||||
|
'''
|
||||||
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
Scan the autoimport directory (if defined) and import new models, delete defunct models.
|
||||||
"""
|
'''
|
||||||
# avoid circular import
|
# avoid circular import
|
||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
from invokeai.frontend.install.model_install import ask_user_for_prediction_type
|
||||||
@@ -971,9 +939,7 @@ class ModelManager(object):
|
|||||||
self.new_models_found.update(self.installer.heuristic_import(model))
|
self.new_models_found.update(self.installer.heuristic_import(model))
|
||||||
|
|
||||||
def on_search_completed(self):
|
def on_search_completed(self):
|
||||||
self.logger.info(
|
self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models')
|
||||||
f"Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models"
|
|
||||||
)
|
|
||||||
|
|
||||||
def models_found(self):
|
def models_found(self):
|
||||||
return self.new_models_found
|
return self.new_models_found
|
||||||
@@ -983,37 +949,31 @@ class ModelManager(object):
|
|||||||
# LS: hacky
|
# LS: hacky
|
||||||
# Patch in the SD VAE from core so that it is available for use by the UI
|
# Patch in the SD VAE from core so that it is available for use by the UI
|
||||||
try:
|
try:
|
||||||
self.heuristic_import({self.resolve_model_path("core/convert/sd-vae-ft-mse")})
|
self.heuristic_import({config.root_path / 'models/core/convert/sd-vae-ft-mse'})
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
installer = ModelInstall(
|
installer = ModelInstall(config = self.app_config,
|
||||||
config=self.app_config,
|
model_manager = self,
|
||||||
model_manager=self,
|
prediction_type_helper = ask_user_for_prediction_type,
|
||||||
prediction_type_helper=ask_user_for_prediction_type,
|
)
|
||||||
)
|
known_paths = {config.root_path / x['path'] for x in self.list_models()}
|
||||||
known_paths = {self.resolve_model_path(x["path"]) for x in self.list_models()}
|
directories = {config.root_path / x for x in [config.autoimport_dir,
|
||||||
directories = {
|
config.lora_dir,
|
||||||
config.root_path / x
|
config.embedding_dir,
|
||||||
for x in [
|
config.controlnet_dir,
|
||||||
config.autoimport_dir,
|
] if x
|
||||||
config.lora_dir,
|
}
|
||||||
config.embedding_dir,
|
|
||||||
config.controlnet_dir,
|
|
||||||
]
|
|
||||||
if x
|
|
||||||
}
|
|
||||||
scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
|
scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer)
|
||||||
scanner.search()
|
scanner.search()
|
||||||
|
|
||||||
return scanner.models_found()
|
return scanner.models_found()
|
||||||
|
|
||||||
def heuristic_import(
|
def heuristic_import(self,
|
||||||
self,
|
items_to_import: Set[str],
|
||||||
items_to_import: Set[str],
|
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||||
prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
|
)->Dict[str, AddModelResult]:
|
||||||
) -> Dict[str, AddModelResult]:
|
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
|
||||||
successfully imported items.
|
successfully imported items.
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||||
@@ -1032,15 +992,14 @@ class ModelManager(object):
|
|||||||
May return the following exceptions:
|
May return the following exceptions:
|
||||||
- ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL
|
- ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL
|
||||||
- ValueError - a corresponding model already exists
|
- ValueError - a corresponding model already exists
|
||||||
"""
|
'''
|
||||||
# avoid circular import here
|
# avoid circular import here
|
||||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||||
|
|
||||||
successfully_installed = dict()
|
successfully_installed = dict()
|
||||||
|
|
||||||
installer = ModelInstall(
|
installer = ModelInstall(config = self.app_config,
|
||||||
config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self
|
prediction_type_helper = prediction_type_helper,
|
||||||
)
|
model_manager = self)
|
||||||
for thing in items_to_import:
|
for thing in items_to_import:
|
||||||
installed = installer.heuristic_import(thing)
|
installed = installer.heuristic_import(thing)
|
||||||
successfully_installed.update(installed)
|
successfully_installed.update(installed)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user