mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-18 07:27:57 -05:00
Compare commits
60 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0ad0016c2d | ||
|
|
c2a3c66e49 | ||
|
|
c0a0d20935 | ||
|
|
028d8d8ead | ||
|
|
657095d2e2 | ||
|
|
1c47dc997e | ||
|
|
a3de6b6165 | ||
|
|
e57f0ff055 | ||
|
|
feee4c49a2 | ||
|
|
42e052d6f2 | ||
|
|
b03e429b26 | ||
|
|
7399909029 | ||
|
|
c8aaf5e76b | ||
|
|
0cdf7a7048 | ||
|
|
41985487d3 | ||
|
|
41d5a17114 | ||
|
|
14f9d5b6bc | ||
|
|
eec4bdb038 | ||
|
|
f3dd44044a | ||
|
|
61a22eb8cb | ||
|
|
03ca83fe13 | ||
|
|
8f1e25c387 | ||
|
|
29cf4bc002 | ||
|
|
9428642806 | ||
|
|
8620572524 | ||
|
|
f44c7e824d | ||
|
|
c5b8bde285 | ||
|
|
4c86a7ecbf | ||
|
|
b9f9d1c152 | ||
|
|
7567ee2adf | ||
|
|
0e632dbc5c | ||
|
|
49191709a0 | ||
|
|
3af7fc26fa | ||
|
|
a36a627f83 | ||
|
|
b31c71f302 | ||
|
|
5302d4890f | ||
|
|
766b752572 | ||
|
|
7feae5e5ce | ||
|
|
26730ca702 | ||
|
|
1e2c7c51b5 | ||
|
|
da2b6815ac | ||
|
|
68d14de3ee | ||
|
|
38991ffc35 | ||
|
|
f345c0fabc | ||
|
|
ca23b5337e | ||
|
|
35910d3952 | ||
|
|
6f1dcf385b | ||
|
|
84c9ecc83f | ||
|
|
52aa839b7e | ||
|
|
316ed1d478 | ||
|
|
3519e8ae39 | ||
|
|
82f645c7a1 | ||
|
|
cc36cfb617 | ||
|
|
ded8a84284 | ||
|
|
94771ea626 | ||
|
|
51d661023e | ||
|
|
d215829b91 | ||
|
|
fad6c67f01 | ||
|
|
f366640d46 | ||
|
|
36a3fba8cb |
11
.github/workflows/build-container.yml
vendored
11
.github/workflows/build-container.yml
vendored
@@ -76,9 +76,6 @@ jobs:
|
||||
latest=${{ matrix.gpu-driver == 'cuda' && github.ref == 'refs/heads/main' }}
|
||||
suffix=-${{ matrix.gpu-driver }},onlatest=false
|
||||
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
@@ -103,7 +100,7 @@ jobs:
|
||||
push: ${{ github.ref == 'refs/heads/main' || github.ref_type == 'tag' || github.event.inputs.push-to-registry }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: |
|
||||
type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
type=gha,scope=main-${{ matrix.gpu-driver }}
|
||||
cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
# cache-from: |
|
||||
# type=gha,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
# type=gha,scope=main-${{ matrix.gpu-driver }}
|
||||
# cache-to: type=gha,mode=max,scope=${{ github.ref_name }}-${{ matrix.gpu-driver }}
|
||||
|
||||
138
docs/RELEASE.md
138
docs/RELEASE.md
@@ -1,41 +1,50 @@
|
||||
# Release Process
|
||||
|
||||
The app is published in twice, in different build formats.
|
||||
The Invoke application is published as a python package on [PyPI]. This includes both a source distribution and built distribution (a wheel).
|
||||
|
||||
- A [PyPI] distribution. This includes both a source distribution and built distribution (a wheel). Users install with `pip install invokeai`. The updater uses this build.
|
||||
- An installer on the [InvokeAI Releases Page]. This is a zip file with install scripts and a wheel. This is only used for new installs.
|
||||
Most users install it with the [Launcher](https://github.com/invoke-ai/launcher/), others with `pip`.
|
||||
|
||||
The launcher uses GitHub as the source of truth for available releases.
|
||||
|
||||
## Broad Strokes
|
||||
|
||||
- Merge all changes and bump the version in the codebase.
|
||||
- Tag the release commit.
|
||||
- Wait for the release workflow to complete.
|
||||
- Approve the PyPI publish jobs.
|
||||
- Write GH release notes.
|
||||
|
||||
## General Prep
|
||||
|
||||
Make a developer call-out for PRs to merge. Merge and test things out.
|
||||
|
||||
While the release workflow does not include end-to-end tests, it does pause before publishing so you can download and test the final build.
|
||||
Make a developer call-out for PRs to merge. Merge and test things out. Bump the version by editing `invokeai/version/invokeai_version.py`.
|
||||
|
||||
## Release Workflow
|
||||
|
||||
The `release.yml` workflow runs a number of jobs to handle code checks, tests, build and publish on PyPI.
|
||||
|
||||
It is triggered on **tag push**, when the tag matches `v*`. It doesn't matter if you've prepped a release branch like `release/v3.5.0` or are releasing from `main` - it works the same.
|
||||
|
||||
> Because commits are reference-counted, it is safe to create a release branch, tag it, let the workflow run, then delete the branch. So long as the tag exists, that commit will exist.
|
||||
It is triggered on **tag push**, when the tag matches `v*`.
|
||||
|
||||
### Triggering the Workflow
|
||||
|
||||
Run `make tag-release` to tag the current commit and kick off the workflow.
|
||||
Ensure all commits that should be in the release are merged, and you have pulled them locally.
|
||||
|
||||
The release may also be dispatched [manually].
|
||||
Double-check that you have checked out the commit that will represent the release (typically the latest commit on `main`).
|
||||
|
||||
Run `make tag-release` to tag the current commit and kick off the workflow. You will be prompted to provide a message - use the version specifier.
|
||||
|
||||
If this version's tag already exists for some reason (maybe you had to make a last minute change), the script will overwrite it.
|
||||
|
||||
> In case you cannot use the Make target, the release may also be dispatched [manually] via GH.
|
||||
|
||||
### Workflow Jobs and Process
|
||||
|
||||
The workflow consists of a number of concurrently-run jobs, and two final publish jobs.
|
||||
The workflow consists of a number of concurrently-run checks and tests, then two final publish jobs.
|
||||
|
||||
The publish jobs require manual approval and are only run if the other jobs succeed.
|
||||
|
||||
#### `check-version` Job
|
||||
|
||||
This job checks that the git ref matches the app version. It matches the ref against the `__version__` variable in `invokeai/version/invokeai_version.py`.
|
||||
|
||||
When the workflow is triggered by tag push, the ref is the tag. If the workflow is run manually, the ref is the target selected from the **Use workflow from** dropdown.
|
||||
This job ensures that the `invokeai` python package version specifier matches the tag for the release. The version specifier is pulled from the `__version__` variable in `invokeai/version/invokeai_version.py`.
|
||||
|
||||
This job uses [samuelcolvin/check-python-version].
|
||||
|
||||
@@ -43,62 +52,52 @@ This job uses [samuelcolvin/check-python-version].
|
||||
|
||||
#### Check and Test Jobs
|
||||
|
||||
Next, these jobs run and must pass. They are the same jobs that are run for every PR.
|
||||
|
||||
- **`python-tests`**: runs `pytest` on matrix of platforms
|
||||
- **`python-checks`**: runs `ruff` (format and lint)
|
||||
- **`frontend-tests`**: runs `vitest`
|
||||
- **`frontend-checks`**: runs `prettier` (format), `eslint` (lint), `dpdm` (circular refs), `tsc` (static type check) and `knip` (unused imports)
|
||||
|
||||
> **TODO** We should add `mypy` or `pyright` to the **`check-python`** job.
|
||||
|
||||
> **TODO** We should add an end-to-end test job that generates an image.
|
||||
- **`typegen-checks`**: ensures the frontend and backend types are synced
|
||||
|
||||
#### `build-installer` Job
|
||||
|
||||
This sets up both python and frontend dependencies and builds the python package. Internally, this runs `installer/create_installer.sh` and uploads two artifacts:
|
||||
|
||||
- **`dist`**: the python distribution, to be published on PyPI
|
||||
- **`InvokeAI-installer-${VERSION}.zip`**: the installer to be included in the GitHub release
|
||||
- **`InvokeAI-installer-${VERSION}.zip`**: the legacy install scripts
|
||||
|
||||
You don't need to download either of these files.
|
||||
|
||||
> The legacy install scripts are no longer used, but we haven't updated the workflow to skip building them.
|
||||
|
||||
#### Sanity Check & Smoke Test
|
||||
|
||||
At this point, the release workflow pauses as the remaining publish jobs require approval. Time to test the installer.
|
||||
At this point, the release workflow pauses as the remaining publish jobs require approval.
|
||||
|
||||
Because the installer pulls from PyPI, and we haven't published to PyPI yet, you will need to install from the wheel:
|
||||
It's possible to test the python package before it gets published to PyPI. We've never had problems with it, so it's not necessary to do this.
|
||||
|
||||
- Download and unzip `dist.zip` and the installer from the **Summary** tab of the workflow
|
||||
- Run the installer script using the `--wheel` CLI arg, pointing at the wheel:
|
||||
But, if you want to be extra-super careful, here's how to test it:
|
||||
|
||||
```sh
|
||||
./install.sh --wheel ../InvokeAI-4.0.0rc6-py3-none-any.whl
|
||||
```
|
||||
|
||||
- Install to a temporary directory so you get the new user experience
|
||||
- Download a model and generate
|
||||
|
||||
> The same wheel file is bundled in the installer and in the `dist` artifact, which is uploaded to PyPI. You should end up with the exactly the same installation as if the installer got the wheel from PyPI.
|
||||
- Download the `dist.zip` build artifact from the `build-installer` job
|
||||
- Unzip it and find the wheel file
|
||||
- Create a fresh Invoke install by following the [manual install guide](https://invoke-ai.github.io/InvokeAI/installation/manual/) - but instead of installing from PyPI, install from the wheel
|
||||
- Test the app
|
||||
|
||||
##### Something isn't right
|
||||
|
||||
If testing reveals any issues, no worries. Cancel the workflow, which will cancel the pending publish jobs (you didn't approve them prematurely, right?).
|
||||
|
||||
Now you can start from the top:
|
||||
|
||||
- Fix the issues and PR the fixes per usual
|
||||
- Get the PR approved and merged per usual
|
||||
- Switch to `main` and pull in the fixes
|
||||
- Run `make tag-release` to move the tag to `HEAD` (which has the fixes) and kick off the release workflow again
|
||||
- Re-do the sanity check
|
||||
If testing reveals any issues, no worries. Cancel the workflow, which will cancel the pending publish jobs (you didn't approve them prematurely, right?) and start over.
|
||||
|
||||
#### PyPI Publish Jobs
|
||||
|
||||
The publish jobs will run if any of the previous jobs fail.
|
||||
The publish jobs will not run if any of the previous jobs fail.
|
||||
|
||||
They use [GitHub environments], which are configured as [trusted publishers] on PyPI.
|
||||
|
||||
Both jobs require a maintainer to approve them from the workflow's **Summary** tab.
|
||||
Both jobs require a @hipsterusername or @psychedelicious to approve them from the workflow's **Summary** tab.
|
||||
|
||||
- Click the **Review deployments** button
|
||||
- Select the environment (either `testpypi` or `pypi`)
|
||||
- Select the environment (either `testpypi` or `pypi` - typically you select both)
|
||||
- Click **Approve and deploy**
|
||||
|
||||
> **If the version already exists on PyPI, the publish jobs will fail.** PyPI only allows a given version to be published once - you cannot change it. If version published on PyPI has a problem, you'll need to "fail forward" by bumping the app version and publishing a followup release.
|
||||
@@ -113,46 +112,33 @@ If there are no incidents, contact @hipsterusername or @lstein, who have owner a
|
||||
|
||||
Publishes the distribution on the [Test PyPI] index, using the `testpypi` GitHub environment.
|
||||
|
||||
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release.
|
||||
This job is not required for the production PyPI publish, but included just in case you want to test the PyPI release for some reason:
|
||||
|
||||
If approved and successful, you could try out the test release like this:
|
||||
|
||||
```sh
|
||||
# Create a new virtual environment
|
||||
python -m venv ~/.test-invokeai-dist --prompt test-invokeai-dist
|
||||
# Install the distribution from Test PyPI
|
||||
pip install --index-url https://test.pypi.org/simple/ invokeai
|
||||
# Run and test the app
|
||||
invokeai-web
|
||||
# Cleanup
|
||||
deactivate
|
||||
rm -rf ~/.test-invokeai-dist
|
||||
```
|
||||
- Approve this publish job without approving the prod publish
|
||||
- Let it finish
|
||||
- Create a fresh Invoke install by following the [manual install guide](https://invoke-ai.github.io/InvokeAI/installation/manual/), making sure to use the Test PyPI index URL: `https://test.pypi.org/simple/`
|
||||
- Test the app
|
||||
|
||||
#### `publish-pypi` Job
|
||||
|
||||
Publishes the distribution on the production PyPI index, using the `pypi` GitHub environment.
|
||||
|
||||
## Publish the GitHub Release with installer
|
||||
It's a good idea to wait to approve and run this job until you have the release notes ready!
|
||||
|
||||
Once the release is published to PyPI, it's time to publish the GitHub release.
|
||||
## Prep and publish the GitHub Release
|
||||
|
||||
1. [Draft a new release] on GitHub, choosing the tag that triggered the release.
|
||||
1. Write the release notes, describing important changes. The **Generate release notes** button automatically inserts the changelog and new contributors, and you can copy/paste the intro from previous releases.
|
||||
1. Use `scripts/get_external_contributions.py` to get a list of external contributions to shout out in the release notes.
|
||||
1. Upload the zip file created in **`build`** job into the Assets section of the release notes.
|
||||
1. Check **Set as a pre-release** if it's a pre-release.
|
||||
1. Check **Create a discussion for this release**.
|
||||
1. Publish the release.
|
||||
1. Announce the release in Discord.
|
||||
|
||||
> **TODO** Workflows can create a GitHub release from a template and upload release assets. One popular action to handle this is [ncipollo/release-action]. A future enhancement to the release process could set this up.
|
||||
|
||||
## Manual Build
|
||||
|
||||
The `build installer` workflow can be dispatched manually. This is useful to test the installer for a given branch or tag.
|
||||
|
||||
No checks are run, it just builds.
|
||||
2. The **Generate release notes** button automatically inserts the changelog and new contributors. Make sure to select the correct tags for this release and the last stable release. GH often selects the wrong tags - do this manually.
|
||||
3. Write the release notes, describing important changes. Contributions from community members should be shouted out. Use the GH-generated changelog to see all contributors. If there are Weblate translation updates, open that PR and shout out every person who contributed a translation.
|
||||
4. Check **Set as a pre-release** if it's a pre-release.
|
||||
5. Approve and wait for the `publish-pypi` job to finish if you haven't already.
|
||||
6. Publish the GH release.
|
||||
7. Post the release in Discord in the [releases](https://discord.com/channels/1020123559063990373/1149260708098359327) channel with abbreviated notes. For example:
|
||||
> Invoke v5.7.0 (stable): <https://github.com/invoke-ai/InvokeAI/releases/tag/v5.7.0>
|
||||
>
|
||||
> It's a pretty big one - Form Builder, Metadata Nodes (thanks @SkunkWorxDark!), and much more.
|
||||
8. Right click the message in releases and copy the link to it. Then, post that link in the [new-release-discussion](https://discord.com/channels/1020123559063990373/1149506274971631688) channel. For example:
|
||||
> Invoke v5.7.0 (stable): <https://discord.com/channels/1020123559063990373/1149260708098359327/1344521744916021248>
|
||||
|
||||
## Manual Release
|
||||
|
||||
@@ -160,12 +146,10 @@ The `release` workflow can be dispatched manually. You must dispatch the workflo
|
||||
|
||||
This functionality is available as a fallback in case something goes wonky. Typically, releases should be triggered via tag push as described above.
|
||||
|
||||
[InvokeAI Releases Page]: https://github.com/invoke-ai/InvokeAI/releases
|
||||
[PyPI]: https://pypi.org/
|
||||
[Draft a new release]: https://github.com/invoke-ai/InvokeAI/releases/new
|
||||
[Test PyPI]: https://test.pypi.org/
|
||||
[version specifier]: https://packaging.python.org/en/latest/specifications/version-specifiers/
|
||||
[ncipollo/release-action]: https://github.com/ncipollo/release-action
|
||||
[GitHub environments]: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
|
||||
[trusted publishers]: https://docs.pypi.org/trusted-publishers/
|
||||
[samuelcolvin/check-python-version]: https://github.com/samuelcolvin/check-python-version
|
||||
|
||||
@@ -31,6 +31,7 @@ It is possible to fine-tune the settings for best performance or if you still ge
|
||||
Low-VRAM mode involves 4 features, each of which can be configured or fine-tuned:
|
||||
|
||||
- Partial model loading (`enable_partial_loading`)
|
||||
- PyTorch CUDA allocator config (`pytorch_cuda_alloc_conf`)
|
||||
- Dynamic RAM and VRAM cache sizes (`max_cache_ram_gb`, `max_cache_vram_gb`)
|
||||
- Working memory (`device_working_mem_gb`)
|
||||
- Keeping a RAM weight copy (`keep_ram_copy_of_weights`)
|
||||
@@ -51,6 +52,16 @@ As described above, you can enable partial model loading by adding this line to
|
||||
enable_partial_loading: true
|
||||
```
|
||||
|
||||
### PyTorch CUDA allocator config
|
||||
|
||||
The PyTorch CUDA allocator's behavior can be configured using the `pytorch_cuda_alloc_conf` config. Tuning the allocator configuration can help to reduce the peak reserved VRAM. The optimal configuration is dependent on many factors (e.g. device type, VRAM, CUDA driver version, etc.), but switching from PyTorch's native allocator to using CUDA's built-in allocator works well on many systems. To try this, add the following line to your `invokeai.yaml` file:
|
||||
|
||||
```yaml
|
||||
pytorch_cuda_alloc_conf: "backend:cudaMallocAsync"
|
||||
```
|
||||
|
||||
A more complete explanation of the available configuration options is [here](https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf).
|
||||
|
||||
### Dynamic RAM and VRAM cache sizes
|
||||
|
||||
Loading models from disk is slow and can be a major bottleneck for performance. Invoke uses two model caches - RAM and VRAM - to reduce loading from disk to a minimum.
|
||||
@@ -75,24 +86,26 @@ But, if your GPU has enough VRAM to hold models fully, you might get a perf boos
|
||||
# As an example, if your system has 32GB of RAM and no other heavy processes, setting the `max_cache_ram_gb` to 28GB
|
||||
# might be a good value to achieve aggressive model caching.
|
||||
max_cache_ram_gb: 28
|
||||
|
||||
# The default max cache VRAM size is adjusted dynamically based on the amount of available VRAM (taking into
|
||||
# consideration the VRAM used by other processes).
|
||||
# You can override the default value by setting `max_cache_vram_gb`. Note that this value takes precedence over the
|
||||
# `device_working_mem_gb`.
|
||||
# It is recommended to set the VRAM cache size to be as large as possible while leaving enough room for the working
|
||||
# memory of the tasks you will be doing. For example, on a 24GB GPU that will be running unquantized FLUX without any
|
||||
# auxiliary models, 18GB might be a good value.
|
||||
max_cache_vram_gb: 18
|
||||
# You can override the default value by setting `max_cache_vram_gb`.
|
||||
# CAUTION: Most users should not manually set this value. See warning below.
|
||||
max_cache_vram_gb: 16
|
||||
```
|
||||
|
||||
!!! tip "Max safe value for `max_cache_vram_gb`"
|
||||
!!! warning "Max safe value for `max_cache_vram_gb`"
|
||||
|
||||
To determine the max safe value for `max_cache_vram_gb`, subtract `device_working_mem_gb` from your GPU's VRAM. As described below, the default for `device_working_mem_gb` is 3GB.
|
||||
Most users should not manually configure the `max_cache_vram_gb`. This configuration value takes precedence over the `device_working_mem_gb` and any operations that explicitly reserve additional working memory (e.g. VAE decode). As such, manually configuring it increases the likelihood of encountering out-of-memory errors.
|
||||
|
||||
For users who wish to configure `max_cache_vram_gb`, the max safe value can be determined by subtracting `device_working_mem_gb` from your GPU's VRAM. As described below, the default for `device_working_mem_gb` is 3GB.
|
||||
|
||||
For example, if you have a 12GB GPU, the max safe value for `max_cache_vram_gb` is `12GB - 3GB = 9GB`.
|
||||
|
||||
If you had increased `device_working_mem_gb` to 4GB, then the max safe value for `max_cache_vram_gb` is `12GB - 4GB = 8GB`.
|
||||
|
||||
Most users who override `max_cache_vram_gb` are doing so because they wish to use significantly less VRAM, and should be setting `max_cache_vram_gb` to a value significantly less than the 'max safe value'.
|
||||
|
||||
### Working memory
|
||||
|
||||
Invoke cannot use _all_ of your VRAM for model caching and loading. It requires some VRAM to use as working memory for various operations.
|
||||
|
||||
@@ -48,7 +48,9 @@ async def enqueue_batch(
|
||||
) -> EnqueueBatchResult:
|
||||
"""Processes a batch and enqueues the output graphs for execution."""
|
||||
|
||||
return ApiDependencies.invoker.services.session_queue.enqueue_batch(queue_id=queue_id, batch=batch, prepend=prepend)
|
||||
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
|
||||
queue_id=queue_id, batch=batch, prepend=prepend
|
||||
)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
@@ -15,11 +11,7 @@ from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from torch.backends.mps import is_available as is_mps_available
|
||||
|
||||
# for PyCharm:
|
||||
# noinspection PyUnresolvedReferences
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.dependencies import ApiDependencies
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
@@ -36,39 +28,15 @@ from invokeai.app.api.routers import (
|
||||
workflows,
|
||||
)
|
||||
from invokeai.app.api.sockets import SocketIO
|
||||
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.custom_openapi import get_openapi_func
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
app_config = get_config()
|
||||
|
||||
|
||||
if is_mps_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
# fix for windows mimetypes registry entries being borked
|
||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
|
||||
torch_device_name = TorchDevice.get_torch_device_name()
|
||||
logger.info(f"Using torch device: {torch_device_name}")
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
# We may change the port if the default is in use, this global variable is used to store the port so that we can log
|
||||
# the correct port when the server starts in the lifespan handler.
|
||||
port = app_config.port
|
||||
|
||||
# Load custom nodes. This must be done after importing the Graph class, which itself imports all modules from the
|
||||
# invocations module. The ordering here is implicit, but important - we want to load custom nodes after all the
|
||||
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
|
||||
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
@@ -77,7 +45,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# Log the server address when it starts - in case the network log level is not high enough to see the startup log
|
||||
proto = "https" if app_config.ssl_certfile else "http"
|
||||
msg = f"Invoke running on {proto}://{app_config.host}:{port} (Press CTRL+C to quit)"
|
||||
msg = f"Invoke running on {proto}://{app_config.host}:{app_config.port} (Press CTRL+C to quit)"
|
||||
|
||||
# Logging this way ignores the logger's log level and _always_ logs the message
|
||||
record = logger.makeRecord(
|
||||
@@ -192,73 +160,3 @@ except RuntimeError:
|
||||
app.mount(
|
||||
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
|
||||
) # docs favicon is in here
|
||||
|
||||
|
||||
def check_cudnn(logger: logging.Logger) -> None:
|
||||
"""Check for cuDNN issues that could be causing degraded performance."""
|
||||
if torch.backends.cudnn.is_available():
|
||||
try:
|
||||
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
|
||||
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
|
||||
cudnn_version = torch.backends.cudnn.version()
|
||||
logger.info(f"cuDNN version: {cudnn_version}")
|
||||
except RuntimeError as e:
|
||||
logger.warning(
|
||||
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
|
||||
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
|
||||
f"system. Full error message:\n{e}"
|
||||
)
|
||||
|
||||
|
||||
def invoke_api() -> None:
|
||||
def find_port(port: int) -> int:
|
||||
"""Find a port not in use starting at given port"""
|
||||
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
|
||||
# https://github.com/WaylonWalker
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.settimeout(1)
|
||||
if s.connect_ex(("localhost", port)) == 0:
|
||||
return find_port(port=port + 1)
|
||||
else:
|
||||
return port
|
||||
|
||||
if app_config.dev_reload:
|
||||
try:
|
||||
import jurigged
|
||||
except ImportError as e:
|
||||
logger.error(
|
||||
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.',
|
||||
exc_info=e,
|
||||
)
|
||||
else:
|
||||
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
|
||||
|
||||
global port
|
||||
port = find_port(app_config.port)
|
||||
if port != app_config.port:
|
||||
logger.warn(f"Port {app_config.port} in use, using port {port}")
|
||||
|
||||
check_cudnn(logger)
|
||||
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
host=app_config.host,
|
||||
port=port,
|
||||
loop="asyncio",
|
||||
log_level=app_config.log_level_network,
|
||||
ssl_certfile=app_config.ssl_certfile,
|
||||
ssl_keyfile=app_config.ssl_keyfile,
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# replace uvicorn's loggers with InvokeAI's for consistent appearance
|
||||
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
|
||||
uvicorn_logger.handlers.clear()
|
||||
for hdlr in logger.handlers:
|
||||
uvicorn_logger.addHandler(hdlr)
|
||||
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_api()
|
||||
|
||||
@@ -41,16 +41,11 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
|
||||
# element size (precision).
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 1090 # Determined experimentally.
|
||||
scaling_constant = 2200 # Determined experimentally.
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
# We add a 20% buffer to the working memory estimate to be safe.
|
||||
working_memory = working_memory * 1.2
|
||||
return int(working_memory)
|
||||
|
||||
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
|
||||
|
||||
@@ -60,7 +60,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
|
||||
# element size (precision). This estimate is accurate for both SD1 and SDXL.
|
||||
element_size = 4 if self.fp32 else 2
|
||||
scaling_constant = 960 # Determined experimentally.
|
||||
scaling_constant = 2200 # Determined experimentally.
|
||||
|
||||
if use_tiling:
|
||||
tile_size = self.tile_size
|
||||
@@ -84,9 +84,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# If we are running in FP32, then we should account for the likely increase in model size (~250MB).
|
||||
working_memory += 250 * 2**20
|
||||
|
||||
# We add 20% to the working memory estimate to be safe.
|
||||
working_memory = int(working_memory * 1.2)
|
||||
return working_memory
|
||||
return int(working_memory)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
|
||||
@@ -361,7 +361,7 @@ class MetadataToIntegerInvocation(BaseInvocation, WithMetadata):
|
||||
title="Metadata To Float",
|
||||
tags=["metadata"],
|
||||
category="metadata",
|
||||
version="1.0.0",
|
||||
version="1.1.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class MetadataToFloatInvocation(BaseInvocation, WithMetadata):
|
||||
@@ -377,7 +377,7 @@ class MetadataToFloatInvocation(BaseInvocation, WithMetadata):
|
||||
description=FieldDescriptions.metadata_item_label,
|
||||
input=Input.Direct,
|
||||
)
|
||||
default_value: int = InputField(description="The default float to use if not found in the metadata")
|
||||
default_value: float = InputField(description="The default float to use if not found in the metadata")
|
||||
|
||||
_validate_custom_label = model_validator(mode="after")(validate_custom_label)
|
||||
|
||||
|
||||
@@ -43,16 +43,11 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoencoderKL) -> int:
|
||||
"""Estimate the working memory required by the invocation in bytes."""
|
||||
# It was found experimentally that the peak working memory scales linearly with the number of pixels and the
|
||||
# element size (precision).
|
||||
out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
|
||||
out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
|
||||
element_size = next(vae.parameters()).element_size()
|
||||
scaling_constant = 1230 # Determined experimentally.
|
||||
scaling_constant = 2200 # Determined experimentally.
|
||||
working_memory = out_h * out_w * element_size * scaling_constant
|
||||
|
||||
# We add a 20% buffer to the working memory estimate to be safe.
|
||||
working_memory = working_memory * 1.2
|
||||
return int(working_memory)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -1,12 +1,82 @@
|
||||
"""This is a wrapper around the main app entrypoint, to allow for CLI args to be parsed before running the app."""
|
||||
import uvicorn
|
||||
|
||||
from invokeai.app.invocations.load_custom_nodes import load_custom_nodes
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
|
||||
def get_app():
|
||||
"""Import the app and event loop. We wrap this in a function to more explicitly control when it happens, because
|
||||
importing from api_app does a bunch of stuff - it's more like calling a function than importing a module.
|
||||
"""
|
||||
from invokeai.app.api_app import app, loop
|
||||
|
||||
return app, loop
|
||||
|
||||
|
||||
def run_app() -> None:
|
||||
# Before doing _anything_, parse CLI args!
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
"""The main entrypoint for the app."""
|
||||
# Parse the CLI arguments.
|
||||
InvokeAIArgs.parse_args()
|
||||
|
||||
from invokeai.app.api_app import invoke_api
|
||||
# Load config.
|
||||
app_config = get_config()
|
||||
|
||||
invoke_api()
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
|
||||
# Configure the torch CUDA memory allocator.
|
||||
# NOTE: It is important that this happens before torch is imported.
|
||||
if app_config.pytorch_cuda_alloc_conf:
|
||||
configure_torch_cuda_allocator(app_config.pytorch_cuda_alloc_conf, logger)
|
||||
|
||||
# Import from startup_utils here to avoid importing torch before configure_torch_cuda_allocator() is called.
|
||||
from invokeai.app.util.startup_utils import (
|
||||
apply_monkeypatches,
|
||||
check_cudnn,
|
||||
enable_dev_reload,
|
||||
find_open_port,
|
||||
register_mime_types,
|
||||
)
|
||||
|
||||
# Find an open port, and modify the config accordingly.
|
||||
orig_config_port = app_config.port
|
||||
app_config.port = find_open_port(app_config.port)
|
||||
if orig_config_port != app_config.port:
|
||||
logger.warning(f"Port {orig_config_port} is already in use. Using port {app_config.port}.")
|
||||
|
||||
# Miscellaneous startup tasks.
|
||||
apply_monkeypatches()
|
||||
register_mime_types()
|
||||
if app_config.dev_reload:
|
||||
enable_dev_reload()
|
||||
check_cudnn(logger)
|
||||
|
||||
# Initialize the app and event loop.
|
||||
app, loop = get_app()
|
||||
|
||||
# Load custom nodes. This must be done after importing the Graph class, which itself imports all modules from the
|
||||
# invocations module. The ordering here is implicit, but important - we want to load custom nodes after all the
|
||||
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
|
||||
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path)
|
||||
|
||||
# Start the server.
|
||||
config = uvicorn.Config(
|
||||
app=app,
|
||||
host=app_config.host,
|
||||
port=app_config.port,
|
||||
loop="asyncio",
|
||||
log_level=app_config.log_level_network,
|
||||
ssl_certfile=app_config.ssl_certfile,
|
||||
ssl_keyfile=app_config.ssl_keyfile,
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# replace uvicorn's loggers with InvokeAI's for consistent appearance
|
||||
uvicorn_logger = InvokeAILogger.get_logger("uvicorn")
|
||||
uvicorn_logger.handlers.clear()
|
||||
for hdlr in logger.handlers:
|
||||
uvicorn_logger.addHandler(hdlr)
|
||||
|
||||
loop.run_until_complete(server.serve())
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, cast
|
||||
|
||||
from invokeai.app.services.board_image_records.board_image_records_base import BoardImageRecordStorageBase
|
||||
@@ -13,15 +12,9 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def add_image_to_board(
|
||||
self,
|
||||
@@ -29,8 +22,8 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
image_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO board_images (board_id, image_name)
|
||||
VALUES (?, ?)
|
||||
@@ -42,16 +35,14 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def remove_image_from_board(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM board_images
|
||||
WHERE image_name = ?;
|
||||
@@ -62,8 +53,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_images_for_board(
|
||||
self,
|
||||
@@ -72,33 +61,27 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
limit: int = 10,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
# TODO: this isn't paginated yet?
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY board_images.updated_at DESC;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
ORDER BY board_images.updated_at DESC;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images WHERE 1=1;
|
||||
"""
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
def get_all_board_image_names_for_board(
|
||||
@@ -107,98 +90,79 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
categories: list[ImageCategory] | None,
|
||||
is_intermediate: bool | None,
|
||||
) -> list[str]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
params: list[str | bool] = []
|
||||
|
||||
params: list[str | bool] = []
|
||||
|
||||
# Base query is a join between images and board_images
|
||||
stmt = """
|
||||
# Base query is a join between images and board_images
|
||||
stmt = """
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
params.append(board_id)
|
||||
params.append(board_id)
|
||||
|
||||
# Add the category filter
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
stmt += f"""--sql
|
||||
# Add the category filter
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
stmt += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
params.append(c)
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
params.append(c)
|
||||
|
||||
# Add the is_intermediate filter
|
||||
if is_intermediate is not None:
|
||||
stmt += """--sql
|
||||
# Add the is_intermediate filter
|
||||
if is_intermediate is not None:
|
||||
stmt += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
params.append(is_intermediate)
|
||||
params.append(is_intermediate)
|
||||
|
||||
# Put a ring on it
|
||||
stmt += ";"
|
||||
# Put a ring on it
|
||||
stmt += ";"
|
||||
|
||||
# Execute the query
|
||||
self._cursor.execute(stmt, params)
|
||||
# Execute the query
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(stmt, params)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
image_names = [r[0] for r in result]
|
||||
return image_names
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [r[0] for r in result]
|
||||
return image_names
|
||||
|
||||
def get_board_for_image(
|
||||
self,
|
||||
image_name: str,
|
||||
) -> Optional[str]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT board_id
|
||||
FROM board_images
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(image_name,),
|
||||
)
|
||||
result = self._cursor.fetchone()
|
||||
if result is None:
|
||||
return None
|
||||
return cast(str, result[0])
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
(image_name,),
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result is None:
|
||||
return None
|
||||
return cast(str, result[0])
|
||||
|
||||
def get_image_count_for_board(self, board_id: str) -> int:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM board_images
|
||||
INNER JOIN images ON board_images.image_name = images.image_name
|
||||
WHERE images.is_intermediate = FALSE
|
||||
AND board_images.board_id = ?;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
return count
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
(board_id,),
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
return count
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Union, cast
|
||||
|
||||
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
|
||||
@@ -19,20 +18,14 @@ from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def delete(self, board_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM boards
|
||||
WHERE board_id = ?;
|
||||
@@ -40,14 +33,9 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
(board_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordDeleteException from e
|
||||
except Exception as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def save(
|
||||
self,
|
||||
@@ -55,8 +43,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
board_id = uuid_string()
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO boards (board_id, board_name)
|
||||
VALUES (?, ?);
|
||||
@@ -67,8 +55,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(board_id)
|
||||
|
||||
def get(
|
||||
@@ -76,8 +62,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
board_id: str,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM boards
|
||||
@@ -86,12 +72,9 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
raise BoardRecordNotFoundException
|
||||
return BoardRecord(**dict(result))
|
||||
@@ -102,11 +85,10 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
changes: BoardChanges,
|
||||
) -> BoardRecord:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
cursor = self._conn.cursor()
|
||||
# Change the name of a board
|
||||
if changes.board_name is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET board_name = ?
|
||||
@@ -117,7 +99,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
|
||||
# Change the cover image of a board
|
||||
if changes.cover_image_name is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET cover_image_name = ?
|
||||
@@ -128,7 +110,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
|
||||
# Change the archived status of a board
|
||||
if changes.archived is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE boards
|
||||
SET archived = ?
|
||||
@@ -141,8 +123,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise BoardRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(board_id)
|
||||
|
||||
def get_many(
|
||||
@@ -153,11 +133,10 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
limit: int = 10,
|
||||
include_archived: bool = False,
|
||||
) -> OffsetPaginatedResults[BoardRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Build base query
|
||||
base_query = """
|
||||
# Build base query
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
@@ -165,81 +144,67 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||
LIMIT ? OFFSET ?;
|
||||
"""
|
||||
|
||||
# Determine archived filter condition
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
# Determine archived filter condition
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
# Execute query to fetch boards
|
||||
self._cursor.execute(final_query, (limit, offset))
|
||||
# Execute query to fetch boards
|
||||
cursor.execute(final_query, (limit, offset))
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
# Determine count query
|
||||
if include_archived:
|
||||
count_query = """
|
||||
# Determine count query
|
||||
if include_archived:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards;
|
||||
"""
|
||||
else:
|
||||
count_query = """
|
||||
else:
|
||||
count_query = """
|
||||
SELECT COUNT(*)
|
||||
FROM boards
|
||||
WHERE archived = 0;
|
||||
"""
|
||||
|
||||
# Execute count query
|
||||
self._cursor.execute(count_query)
|
||||
# Execute count query
|
||||
cursor.execute(count_query)
|
||||
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
|
||||
|
||||
def get_all(
|
||||
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
|
||||
) -> list[BoardRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
if order_by == BoardRecordOrderBy.Name:
|
||||
base_query = """
|
||||
cursor = self._conn.cursor()
|
||||
if order_by == BoardRecordOrderBy.Name:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY LOWER(board_name) {direction}
|
||||
"""
|
||||
else:
|
||||
base_query = """
|
||||
else:
|
||||
base_query = """
|
||||
SELECT *
|
||||
FROM boards
|
||||
{archived_filter}
|
||||
ORDER BY {order_by} {direction}
|
||||
"""
|
||||
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
archived_filter = "" if include_archived else "WHERE archived = 0"
|
||||
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
final_query = base_query.format(
|
||||
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
|
||||
)
|
||||
|
||||
self._cursor.execute(final_query)
|
||||
cursor.execute(final_query)
|
||||
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
boards = [deserialize_board_record(dict(r)) for r in result]
|
||||
|
||||
return boards
|
||||
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
return boards
|
||||
|
||||
@@ -91,6 +91,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
ram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
|
||||
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
|
||||
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
|
||||
pytorch_cuda_alloc_conf: Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to "backend:cudaMallocAsync" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.
|
||||
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
|
||||
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
|
||||
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
|
||||
@@ -169,6 +170,9 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
|
||||
lazy_offload: bool = Field(default=True, description="DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.")
|
||||
|
||||
# PyTorch Memory Allocator
|
||||
pytorch_cuda_alloc_conf: Optional[str] = Field(default=None, description="Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to \"backend:cudaMallocAsync\" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.")
|
||||
|
||||
# DEVICE
|
||||
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
|
||||
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
@@ -22,21 +21,14 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def get(self, image_name: str) -> ImageRecord:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS} FROM images
|
||||
WHERE image_name = ?;
|
||||
@@ -44,12 +36,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
if not result:
|
||||
raise ImageRecordNotFoundException
|
||||
@@ -58,9 +47,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata FROM images
|
||||
WHERE image_name = ?;
|
||||
@@ -68,7 +56,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
|
||||
if not result:
|
||||
raise ImageRecordNotFoundException
|
||||
@@ -77,10 +65,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
|
||||
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordNotFoundException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -88,10 +73,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
changes: ImageRecordChanges,
|
||||
) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
cursor = self._conn.cursor()
|
||||
# Change the category of the image
|
||||
if changes.image_category is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET image_category = ?
|
||||
@@ -102,7 +87,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
# Change the session associated with the image
|
||||
if changes.session_id is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET session_id = ?
|
||||
@@ -113,7 +98,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
# Change the image's `is_intermediate`` flag
|
||||
if changes.is_intermediate is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET is_intermediate = ?
|
||||
@@ -124,7 +109,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
# Change the image's `starred`` state
|
||||
if changes.starred is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE images
|
||||
SET starred = ?
|
||||
@@ -137,8 +122,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
@@ -152,110 +135,104 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
# Manually build two queries - one for the count, one for the records
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
images_query = f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS}
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
|
||||
images_query = f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS}
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
# board_id of "none" is reserved for images without a board
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
if categories is not None:
|
||||
# Convert the enum values to unique list of strings
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
# Search term condition
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND images.metadata LIKE ?
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
if starred_first:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
else:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
# Add the pagination parameters
|
||||
images_params.extend([limit, offset])
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
# Build the list of images, deserializing each row
|
||||
cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
# board_id of "none" is reserved for images without a board
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
# Search term condition
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND images.metadata LIKE ?
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
if starred_first:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
else:
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Final images query with pagination
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
# Add the pagination parameters
|
||||
images_params.extend([limit, offset])
|
||||
|
||||
# Build the list of images, deserializing each row
|
||||
self._cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
self._cursor.execute(count_query, count_params)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
finally:
|
||||
self._lock.release()
|
||||
# Set up and execute the count query, without pagination
|
||||
count_query += query_conditions + ";"
|
||||
count_params = query_params.copy()
|
||||
cursor.execute(count_query, count_params)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
def delete(self, image_name: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE image_name = ?;
|
||||
@@ -266,58 +243,48 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def delete_many(self, image_names: list[str]) -> None:
|
||||
try:
|
||||
placeholders = ",".join("?" for _ in image_names)
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
self._lock.acquire()
|
||||
placeholders = ",".join("?" for _ in image_names)
|
||||
|
||||
# Construct the SQLite query with the placeholders
|
||||
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
|
||||
|
||||
# Execute the query with the list of IDs as parameters
|
||||
self._cursor.execute(query, image_names)
|
||||
cursor.execute(query, image_names)
|
||||
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_intermediates_count(self) -> int:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
count = cast(int, self._cursor.fetchone()[0])
|
||||
self._conn.commit()
|
||||
return count
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*) FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
self._conn.commit()
|
||||
return count
|
||||
|
||||
def delete_intermediates(self) -> list[str]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT image_name FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
"""
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [r[0] for r in result]
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM images
|
||||
WHERE is_intermediate = TRUE;
|
||||
@@ -328,8 +295,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordDeleteException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def save(
|
||||
self,
|
||||
@@ -346,8 +311,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
metadata: Optional[str] = None,
|
||||
) -> datetime:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO images (
|
||||
image_name,
|
||||
@@ -380,7 +345,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT created_at
|
||||
FROM images
|
||||
@@ -389,34 +354,30 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
(image_name,),
|
||||
)
|
||||
|
||||
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
|
||||
created_at = datetime.fromisoformat(cursor.fetchone()[0])
|
||||
|
||||
return created_at
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise ImageRecordSaveException from e
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
JOIN board_images ON images.image_name = board_images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
AND images.is_intermediate = FALSE
|
||||
ORDER BY images.starred DESC, images.created_at DESC
|
||||
LIMIT 1;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT images.*
|
||||
FROM images
|
||||
JOIN board_images ON images.image_name = board_images.image_name
|
||||
WHERE board_images.board_id = ?
|
||||
AND images.is_intermediate = FALSE
|
||||
ORDER BY images.starred DESC, images.created_at DESC
|
||||
LIMIT 1;
|
||||
""",
|
||||
(board_id,),
|
||||
)
|
||||
|
||||
result = cast(Optional[sqlite3.Row], cursor.fetchone())
|
||||
|
||||
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
|
||||
finally:
|
||||
self._lock.release()
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
|
||||
@@ -78,7 +78,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""
|
||||
super().__init__()
|
||||
self._db = db
|
||||
self._cursor = db.conn.cursor()
|
||||
self._logger = logger
|
||||
|
||||
@property
|
||||
@@ -96,38 +95,38 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
|
||||
"""
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO models (
|
||||
id,
|
||||
config
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
),
|
||||
)
|
||||
self._db.conn.commit()
|
||||
try:
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO models (
|
||||
id,
|
||||
config
|
||||
)
|
||||
VALUES (?,?);
|
||||
""",
|
||||
(
|
||||
config.key,
|
||||
config.model_dump_json(),
|
||||
),
|
||||
)
|
||||
self._db.conn.commit()
|
||||
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
else:
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{config.path}' is already installed"
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
|
||||
else:
|
||||
raise e
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
msg = f"A model with key '{config.key}' is already installed"
|
||||
raise DuplicateModelException(msg) from e
|
||||
else:
|
||||
raise e
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(config.key)
|
||||
|
||||
@@ -139,21 +138,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise an UnknownModelException
|
||||
"""
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
try:
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
|
||||
record = self.get_model(key)
|
||||
@@ -164,23 +163,23 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
json_serialized = record.model_dump_json()
|
||||
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE models
|
||||
SET
|
||||
config=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, key),
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
try:
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE models
|
||||
SET
|
||||
config=?
|
||||
WHERE id=?;
|
||||
""",
|
||||
(json_serialized, key),
|
||||
)
|
||||
if cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
|
||||
@@ -192,33 +191,33 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
rows = self._cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
rows = cursor.fetchone()
|
||||
if not rows:
|
||||
raise UnknownModelException("model not found")
|
||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||
return model
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
@@ -227,16 +226,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
count = 0
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
)
|
||||
count = cursor.fetchone()[0]
|
||||
return count > 0
|
||||
|
||||
def search_by_attr(
|
||||
@@ -284,17 +282,18 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
where_clause.append("format=?")
|
||||
bindings.append(model_format)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config, strftime('%s',updated_at)
|
||||
FROM models
|
||||
{where}
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
result = self._cursor.fetchall()
|
||||
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config, strftime('%s',updated_at)
|
||||
FROM models
|
||||
{where}
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
|
||||
""",
|
||||
tuple(bindings),
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
|
||||
# Parse the model configs.
|
||||
results: list[AnyModelConfig] = []
|
||||
@@ -313,34 +312,28 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated path."""
|
||||
results = []
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated hash."""
|
||||
results = []
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
results = [
|
||||
ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in self._cursor.fetchall()
|
||||
]
|
||||
cursor = self._db.conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
|
||||
return results
|
||||
|
||||
def list_models(
|
||||
@@ -356,33 +349,32 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
ModelRecordOrderBy.Format: "format",
|
||||
}
|
||||
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
with self._db.lock:
|
||||
# query1: get the total number of model configs
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from models;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(self._cursor.fetchone()[0])
|
||||
cursor = self._db.conn.cursor()
|
||||
|
||||
# query2: fetch key fields
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config
|
||||
FROM models
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
""",
|
||||
(
|
||||
per_page,
|
||||
page * per_page,
|
||||
),
|
||||
)
|
||||
rows = self._cursor.fetchall()
|
||||
items = [ModelSummary.model_validate(dict(x)) for x in rows]
|
||||
return PaginatedResults(
|
||||
page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items
|
||||
)
|
||||
# Lock so that the database isn't updated while we're doing the two queries.
|
||||
# query1: get the total number of model configs
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from models;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(cursor.fetchone()[0])
|
||||
|
||||
# query2: fetch key fields
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT config
|
||||
FROM models
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
OFFSET ?;
|
||||
""",
|
||||
(
|
||||
per_page,
|
||||
page * per_page,
|
||||
),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
items = [ModelSummary.model_validate(dict(x)) for x in rows]
|
||||
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Any, Coroutine, Optional
|
||||
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
@@ -33,7 +33,7 @@ class SessionQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> Coroutine[Any, Any, EnqueueBatchResult]:
|
||||
"""Enqueues all permutations of a batch for execution."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from pydantic_core import to_jsonable_python
|
||||
@@ -37,9 +37,6 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
class SqliteSessionQueue(SessionQueueBase):
|
||||
__invoker: Invoker
|
||||
__conn: sqlite3.Connection
|
||||
__cursor: sqlite3.Cursor
|
||||
__lock: threading.RLock
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
@@ -55,9 +52,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self.__lock = db.lock
|
||||
self.__conn = db.conn
|
||||
self.__cursor = self.__conn.cursor()
|
||||
self._conn = db.conn
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
|
||||
"""
|
||||
@@ -65,8 +60,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
This is necessary because the invoker may have been killed while processing a queue item.
|
||||
"""
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
@@ -74,14 +69,13 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
"""
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
|
||||
def _get_current_queue_size(self, queue_id: str) -> int:
|
||||
"""Gets the current number of pending queue items"""
|
||||
self.__cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
@@ -91,11 +85,12 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(int, self.__cursor.fetchone()[0])
|
||||
return cast(int, cursor.fetchone()[0])
|
||||
|
||||
def _get_highest_priority(self, queue_id: str) -> int:
|
||||
"""Gets the highest priority value in the queue"""
|
||||
self.__cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT MAX(priority)
|
||||
FROM session_queue
|
||||
@@ -105,12 +100,14 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0
|
||||
return cast(Union[int, None], cursor.fetchone()[0]) or 0
|
||||
|
||||
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
return await asyncio.to_thread(self._enqueue_batch, queue_id, batch, prepend)
|
||||
|
||||
def _enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
|
||||
cursor = self._conn.cursor()
|
||||
# TODO: how does this work in a multi-user scenario?
|
||||
current_queue_size = self._get_current_queue_size(queue_id)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
@@ -132,19 +129,17 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
if requested_count > enqueued_count:
|
||||
values_to_insert = values_to_insert[:max_new_queue_items]
|
||||
|
||||
self.__cursor.executemany(
|
||||
cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
enqueue_result = EnqueueBatchResult(
|
||||
queue_id=queue_id,
|
||||
requested=requested_count,
|
||||
@@ -156,25 +151,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return enqueue_result
|
||||
|
||||
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
return None
|
||||
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
@@ -182,52 +171,40 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'pending'
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
created_at ASC
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
|
||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'in_progress'
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = 'in_progress'
|
||||
LIMIT 1
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
@@ -241,8 +218,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
error_traceback: Optional[str] = None,
|
||||
) -> SessionQueueItem:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = ?, error_type = ?, error_message = ?, error_traceback = ?
|
||||
@@ -250,12 +227,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(status, error_type, error_message, error_traceback, item_id),
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
|
||||
@@ -263,48 +238,36 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return queue_item
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
is_empty = cast(int, self.__cursor.fetchone()[0]) == 0
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
is_empty = cast(int, cursor.fetchone()[0]) == 0
|
||||
return IsEmptyResult(is_empty=is_empty)
|
||||
|
||||
def is_full(self, queue_id: str) -> IsFullResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
is_full = cast(int, self.__cursor.fetchone()[0]) >= max_queue_size
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
|
||||
return IsFullResult(is_full=is_full)
|
||||
|
||||
def clear(self, queue_id: str) -> ClearResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
@@ -312,8 +275,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
@@ -321,17 +284,16 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
self.__invoker.services.events.emit_queue_cleared(queue_id)
|
||||
return ClearResult(deleted=count)
|
||||
|
||||
def prune(self, queue_id: str) -> PruneResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id = ?
|
||||
@@ -341,8 +303,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
OR status = 'canceled'
|
||||
)
|
||||
"""
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
@@ -350,8 +311,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
@@ -359,12 +320,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return PruneResult(deleted=count)
|
||||
|
||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
@@ -393,8 +352,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
placeholders = ", ".join(["?" for _ in batch_ids])
|
||||
where = f"""--sql
|
||||
WHERE
|
||||
@@ -405,7 +364,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = [queue_id] + batch_ids
|
||||
self.__cursor.execute(
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
@@ -413,8 +372,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
@@ -422,20 +381,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelByBatchIDsResult(canceled=count)
|
||||
|
||||
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
@@ -445,7 +402,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = (queue_id, destination)
|
||||
self.__cursor.execute(
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
@@ -453,8 +410,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
params,
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
@@ -462,20 +419,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
params,
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelByDestinationResult(canceled=count)
|
||||
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id is ?
|
||||
@@ -484,7 +439,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = [queue_id]
|
||||
self.__cursor.execute(
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
@@ -492,8 +447,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
@@ -501,7 +456,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
@@ -509,21 +464,19 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
current_queue_item, batch_status, queue_status
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelByQueueIDResult(canceled=count)
|
||||
|
||||
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id == ?
|
||||
AND status == 'pending'
|
||||
"""
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
@@ -531,8 +484,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
count = cursor.fetchone()[0]
|
||||
cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
@@ -540,43 +493,35 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelAllExceptCurrentResult(canceled=count)
|
||||
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
item_id = ?
|
||||
""",
|
||||
(item_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
|
||||
if result is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||
|
||||
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
|
||||
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
|
||||
# during execution.
|
||||
session_json = session.model_dump_json(warnings=False, exclude_none=True)
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET session = ?
|
||||
@@ -584,12 +529,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(session_json, item_id),
|
||||
)
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return self.get_queue_item(item_id)
|
||||
|
||||
def list_queue_items(
|
||||
@@ -600,83 +543,71 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
cursor: Optional[int] = None,
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
try:
|
||||
item_id = cursor
|
||||
self.__lock.acquire()
|
||||
query = """--sql
|
||||
SELECT item_id,
|
||||
status,
|
||||
priority,
|
||||
field_values,
|
||||
error_type,
|
||||
error_message,
|
||||
error_traceback,
|
||||
created_at,
|
||||
updated_at,
|
||||
completed_at,
|
||||
started_at,
|
||||
session_id,
|
||||
batch_id,
|
||||
queue_id,
|
||||
origin,
|
||||
destination
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if status is not None:
|
||||
query += """--sql
|
||||
AND status = ?
|
||||
"""
|
||||
params.append(status)
|
||||
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
"""
|
||||
params.extend([priority, priority, item_id])
|
||||
cursor_ = self._conn.cursor()
|
||||
item_id = cursor
|
||||
query = """--sql
|
||||
SELECT item_id,
|
||||
status,
|
||||
priority,
|
||||
field_values,
|
||||
error_type,
|
||||
error_message,
|
||||
error_traceback,
|
||||
created_at,
|
||||
updated_at,
|
||||
completed_at,
|
||||
started_at,
|
||||
session_id,
|
||||
batch_id,
|
||||
queue_id,
|
||||
origin,
|
||||
destination
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
"""
|
||||
params: list[Union[str, int]] = [queue_id]
|
||||
|
||||
if status is not None:
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT ?
|
||||
AND status = ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
self.__cursor.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
|
||||
has_more = False
|
||||
if len(items) > limit:
|
||||
# remove the extra item
|
||||
items.pop()
|
||||
has_more = True
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
params.append(status)
|
||||
|
||||
if item_id is not None:
|
||||
query += """--sql
|
||||
AND (priority < ?) OR (priority = ? AND item_id > ?)
|
||||
"""
|
||||
params.extend([priority, priority, item_id])
|
||||
|
||||
query += """--sql
|
||||
ORDER BY
|
||||
priority DESC,
|
||||
item_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
cursor_.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], cursor_.fetchall())
|
||||
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
|
||||
has_more = False
|
||||
if len(items) > limit:
|
||||
# remove the extra item
|
||||
items.pop()
|
||||
has_more = True
|
||||
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
|
||||
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
current_item = self.get_current(queue_id=queue_id)
|
||||
total = sum(row[1] for row in counts_result)
|
||||
@@ -695,29 +626,23 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
|
||||
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*), origin, destination
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND batch_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
total = sum(row[1] for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
origin = result[0]["origin"] if result else None
|
||||
destination = result[0]["destination"] if result else None
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*), origin, destination
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND batch_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
total = sum(row[1] for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
origin = result[0]["origin"] if result else None
|
||||
destination = result[0]["destination"] if result else None
|
||||
|
||||
return BatchStatus(
|
||||
batch_id=batch_id,
|
||||
@@ -733,24 +658,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
|
||||
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
AND destination = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, destination),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE queue_id = ?
|
||||
AND destination = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, destination),
|
||||
)
|
||||
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
total = sum(row[1] for row in counts_result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
|
||||
@@ -769,8 +688,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
|
||||
"""Retries the given queue items"""
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
|
||||
cursor = self._conn.cursor()
|
||||
values_to_insert: list[tuple] = []
|
||||
retried_item_ids: list[int] = []
|
||||
|
||||
@@ -813,7 +731,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
# TODO(psyche): Handle max queue size?
|
||||
|
||||
self.__cursor.executemany(
|
||||
cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
@@ -821,12 +739,10 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
values_to_insert,
|
||||
)
|
||||
|
||||
self.__conn.commit()
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
retry_result = RetryItemsResult(
|
||||
queue_id=queue_id,
|
||||
retried_item_ids=retried_item_ids,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
|
||||
@@ -38,14 +37,20 @@ class SqliteDatabase:
|
||||
self.logger.info(f"Initializing database at {self.db_path}")
|
||||
|
||||
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
|
||||
self.lock = threading.RLock()
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
|
||||
if self.verbose:
|
||||
self.conn.set_trace_callback(self.logger.debug)
|
||||
|
||||
# Enable foreign key constraints
|
||||
self.conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
# Enable Write-Ahead Logging (WAL) mode for better concurrency
|
||||
self.conn.execute("PRAGMA journal_mode = WAL;")
|
||||
|
||||
# Set a busy timeout to prevent database lockups during writes
|
||||
self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
|
||||
|
||||
def clean(self) -> None:
|
||||
"""
|
||||
Cleans the database by running the VACUUM command, reporting on the freed space.
|
||||
@@ -53,15 +58,14 @@ class SqliteDatabase:
|
||||
# No need to clean in-memory database
|
||||
if not self.db_path:
|
||||
return
|
||||
with self.lock:
|
||||
try:
|
||||
initial_db_size = Path(self.db_path).stat().st_size
|
||||
self.conn.execute("VACUUM;")
|
||||
self.conn.commit()
|
||||
final_db_size = Path(self.db_path).stat().st_size
|
||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||
if freed_space_in_mb > 0:
|
||||
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning database: {e}")
|
||||
raise
|
||||
try:
|
||||
initial_db_size = Path(self.db_path).stat().st_size
|
||||
self.conn.execute("VACUUM;")
|
||||
self.conn.commit()
|
||||
final_db_size = Path(self.db_path).stat().st_size
|
||||
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
||||
if freed_space_in_mb > 0:
|
||||
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning database: {e}")
|
||||
raise
|
||||
|
||||
@@ -43,46 +43,45 @@ class SqliteMigrator:
|
||||
|
||||
def run_migrations(self) -> bool:
|
||||
"""Migrates the database to the latest version."""
|
||||
with self._db.lock:
|
||||
# This throws if there is a problem.
|
||||
self._migration_set.validate_migration_chain()
|
||||
cursor = self._db.conn.cursor()
|
||||
self._create_migrations_table(cursor=cursor)
|
||||
# This throws if there is a problem.
|
||||
self._migration_set.validate_migration_chain()
|
||||
cursor = self._db.conn.cursor()
|
||||
self._create_migrations_table(cursor=cursor)
|
||||
|
||||
if self._migration_set.count == 0:
|
||||
self._logger.debug("No migrations registered")
|
||||
return False
|
||||
if self._migration_set.count == 0:
|
||||
self._logger.debug("No migrations registered")
|
||||
return False
|
||||
|
||||
if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
|
||||
self._logger.debug("Database is up to date, no migrations to run")
|
||||
return False
|
||||
if self._get_current_version(cursor=cursor) == self._migration_set.latest_version:
|
||||
self._logger.debug("Database is up to date, no migrations to run")
|
||||
return False
|
||||
|
||||
self._logger.info("Database update needed")
|
||||
self._logger.info("Database update needed")
|
||||
|
||||
# Make a backup of the db if it needs to be updated and is a file db
|
||||
if self._db.db_path is not None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
|
||||
self._logger.info(f"Backing up database to {str(self._backup_path)}")
|
||||
# Use SQLite to do the backup
|
||||
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
|
||||
self._db.conn.backup(backup_conn)
|
||||
else:
|
||||
self._logger.info("Using in-memory database, no backup needed")
|
||||
# Make a backup of the db if it needs to be updated and is a file db
|
||||
if self._db.db_path is not None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
|
||||
self._logger.info(f"Backing up database to {str(self._backup_path)}")
|
||||
# Use SQLite to do the backup
|
||||
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
|
||||
self._db.conn.backup(backup_conn)
|
||||
else:
|
||||
self._logger.info("Using in-memory database, no backup needed")
|
||||
|
||||
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
|
||||
while next_migration is not None:
|
||||
self._run_migration(next_migration)
|
||||
next_migration = self._migration_set.get(self._get_current_version(cursor))
|
||||
self._logger.info("Database updated successfully")
|
||||
return True
|
||||
next_migration = self._migration_set.get(from_version=self._get_current_version(cursor))
|
||||
while next_migration is not None:
|
||||
self._run_migration(next_migration)
|
||||
next_migration = self._migration_set.get(self._get_current_version(cursor))
|
||||
self._logger.info("Database updated successfully")
|
||||
return True
|
||||
|
||||
def _run_migration(self, migration: Migration) -> None:
|
||||
"""Runs a single migration."""
|
||||
try:
|
||||
# Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
|
||||
# exception is raised.
|
||||
with self._db.lock, self._db.conn as conn:
|
||||
with self._db.conn as conn:
|
||||
cursor = conn.cursor()
|
||||
if self._get_current_version(cursor) != migration.from_version:
|
||||
raise MigrationError(
|
||||
@@ -108,27 +107,26 @@ class SqliteMigrator:
|
||||
|
||||
def _create_migrations_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""Creates the migrations table for the database, if one does not already exist."""
|
||||
with self._db.lock:
|
||||
try:
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||
if cursor.fetchone() is not None:
|
||||
return
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
|
||||
cursor.connection.commit()
|
||||
self._logger.debug("Created migrations table")
|
||||
except sqlite3.Error as e:
|
||||
msg = f"Problem creating migrations table: {e}"
|
||||
self._logger.error(msg)
|
||||
cursor.connection.rollback()
|
||||
raise MigrationError(msg) from e
|
||||
try:
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='migrations';")
|
||||
if cursor.fetchone() is not None:
|
||||
return
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE migrations (
|
||||
version INTEGER PRIMARY KEY,
|
||||
migrated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW'))
|
||||
);
|
||||
"""
|
||||
)
|
||||
cursor.execute("INSERT INTO migrations (version) VALUES (0);")
|
||||
cursor.connection.commit()
|
||||
self._logger.debug("Created migrations table")
|
||||
except sqlite3.Error as e:
|
||||
msg = f"Problem creating migrations table: {e}"
|
||||
self._logger.error(msg)
|
||||
cursor.connection.rollback()
|
||||
raise MigrationError(msg) from e
|
||||
|
||||
@classmethod
|
||||
def _get_current_version(cls, cursor: sqlite3.Cursor) -> int:
|
||||
|
||||
@@ -17,9 +17,7 @@ from invokeai.app.util.misc import uuid_string
|
||||
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
@@ -27,31 +25,25 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
|
||||
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
|
||||
"""Gets a style preset by ID."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
row = self._cursor.fetchone()
|
||||
if row is None:
|
||||
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
|
||||
return StylePresetRecordDTO.from_dict(dict(row))
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM style_presets
|
||||
WHERE id = ?;
|
||||
""",
|
||||
(style_preset_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
|
||||
return StylePresetRecordDTO.from_dict(dict(row))
|
||||
|
||||
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
|
||||
style_preset_id = uuid_string()
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
id,
|
||||
@@ -72,18 +64,16 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
|
||||
style_preset_ids = []
|
||||
try:
|
||||
self._lock.acquire()
|
||||
cursor = self._conn.cursor()
|
||||
for style_preset in style_presets:
|
||||
style_preset_id = uuid_string()
|
||||
style_preset_ids.append(style_preset_id)
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO style_presets (
|
||||
id,
|
||||
@@ -104,17 +94,15 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
return None
|
||||
|
||||
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
cursor = self._conn.cursor()
|
||||
# Change the name of a style preset
|
||||
if changes.name is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE style_presets
|
||||
SET name = ?
|
||||
@@ -125,7 +113,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
|
||||
# Change the preset data for a style preset
|
||||
if changes.preset_data is not None:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE style_presets
|
||||
SET preset_data = ?
|
||||
@@ -138,14 +126,12 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(style_preset_id)
|
||||
|
||||
def delete(self, style_preset_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE from style_presets
|
||||
WHERE id = ?;
|
||||
@@ -156,46 +142,38 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return None
|
||||
|
||||
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
main_query = """
|
||||
SELECT
|
||||
*
|
||||
FROM style_presets
|
||||
"""
|
||||
main_query = """
|
||||
SELECT
|
||||
*
|
||||
FROM style_presets
|
||||
"""
|
||||
|
||||
if type is not None:
|
||||
main_query += "WHERE type = ? "
|
||||
if type is not None:
|
||||
main_query += "WHERE type = ? "
|
||||
|
||||
main_query += "ORDER BY LOWER(name) ASC"
|
||||
main_query += "ORDER BY LOWER(name) ASC"
|
||||
|
||||
if type is not None:
|
||||
self._cursor.execute(main_query, (type,))
|
||||
else:
|
||||
self._cursor.execute(main_query)
|
||||
cursor = self._conn.cursor()
|
||||
if type is not None:
|
||||
cursor.execute(main_query, (type,))
|
||||
else:
|
||||
cursor.execute(main_query)
|
||||
|
||||
rows = self._cursor.fetchall()
|
||||
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
|
||||
rows = cursor.fetchall()
|
||||
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
|
||||
|
||||
return style_presets
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return style_presets
|
||||
|
||||
def _sync_default_style_presets(self) -> None:
|
||||
"""Syncs default style presets to the database. Internal use only."""
|
||||
|
||||
# First delete all existing default style presets
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM style_presets
|
||||
WHERE type = "default";
|
||||
@@ -205,10 +183,8 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
# Next, parse and create the default style presets
|
||||
with self._lock, open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
|
||||
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
|
||||
presets = json.load(file)
|
||||
for preset in presets:
|
||||
style_preset = StylePresetWithoutId.model_validate(preset)
|
||||
|
||||
@@ -23,9 +23,7 @@ from invokeai.app.util.misc import uuid_string
|
||||
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
self._conn = db.conn
|
||||
self._cursor = self._conn.cursor()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
@@ -33,42 +31,36 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
|
||||
def get(self, workflow_id: str) -> WorkflowRecordDTO:
|
||||
"""Gets a workflow by ID. Updates the opened_at column."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE workflow_library
|
||||
SET opened_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
|
||||
FROM workflow_library
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
row = self._cursor.fetchone()
|
||||
if row is None:
|
||||
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
|
||||
return WorkflowRecordDTO.from_dict(dict(row))
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE workflow_library
|
||||
SET opened_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
|
||||
FROM workflow_library
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
|
||||
return WorkflowRecordDTO.from_dict(dict(row))
|
||||
|
||||
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
|
||||
try:
|
||||
# Only user workflows may be created by this method
|
||||
assert workflow.meta.category is WorkflowCategory.User
|
||||
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR IGNORE INTO workflow_library (
|
||||
workflow_id,
|
||||
@@ -82,14 +74,12 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(workflow_with_id.id)
|
||||
|
||||
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
UPDATE workflow_library
|
||||
SET workflow = ?
|
||||
@@ -101,14 +91,12 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(workflow.id)
|
||||
|
||||
def delete(self, workflow_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE from workflow_library
|
||||
WHERE workflow_id = ? AND category = 'user';
|
||||
@@ -119,8 +107,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return None
|
||||
|
||||
def get_many(
|
||||
@@ -132,66 +118,60 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
per_page: Optional[int] = None,
|
||||
query: Optional[str] = None,
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
assert direction in SQLiteDirection
|
||||
assert category in WorkflowCategory
|
||||
count_query = "SELECT COUNT(*) FROM workflow_library WHERE category = ?"
|
||||
main_query = """
|
||||
SELECT
|
||||
workflow_id,
|
||||
category,
|
||||
name,
|
||||
description,
|
||||
created_at,
|
||||
updated_at,
|
||||
opened_at
|
||||
FROM workflow_library
|
||||
WHERE category = ?
|
||||
"""
|
||||
main_params: list[int | str] = [category.value]
|
||||
count_params: list[int | str] = [category.value]
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
assert direction in SQLiteDirection
|
||||
assert category in WorkflowCategory
|
||||
count_query = "SELECT COUNT(*) FROM workflow_library WHERE category = ?"
|
||||
main_query = """
|
||||
SELECT
|
||||
workflow_id,
|
||||
category,
|
||||
name,
|
||||
description,
|
||||
created_at,
|
||||
updated_at,
|
||||
opened_at
|
||||
FROM workflow_library
|
||||
WHERE category = ?
|
||||
"""
|
||||
main_params: list[int | str] = [category.value]
|
||||
count_params: list[int | str] = [category.value]
|
||||
|
||||
stripped_query = query.strip() if query else None
|
||||
if stripped_query:
|
||||
wildcard_query = "%" + stripped_query + "%"
|
||||
main_query += " AND name LIKE ? OR description LIKE ? "
|
||||
count_query += " AND name LIKE ? OR description LIKE ?;"
|
||||
main_params.extend([wildcard_query, wildcard_query])
|
||||
count_params.extend([wildcard_query, wildcard_query])
|
||||
stripped_query = query.strip() if query else None
|
||||
if stripped_query:
|
||||
wildcard_query = "%" + stripped_query + "%"
|
||||
main_query += " AND name LIKE ? OR description LIKE ? "
|
||||
count_query += " AND name LIKE ? OR description LIKE ?;"
|
||||
main_params.extend([wildcard_query, wildcard_query])
|
||||
count_params.extend([wildcard_query, wildcard_query])
|
||||
|
||||
main_query += f" ORDER BY {order_by.value} {direction.value}"
|
||||
main_query += f" ORDER BY {order_by.value} {direction.value}"
|
||||
|
||||
if per_page:
|
||||
main_query += " LIMIT ? OFFSET ?"
|
||||
main_params.extend([per_page, page * per_page])
|
||||
if per_page:
|
||||
main_query += " LIMIT ? OFFSET ?"
|
||||
main_params.extend([per_page, page * per_page])
|
||||
|
||||
self._cursor.execute(main_query, main_params)
|
||||
rows = self._cursor.fetchall()
|
||||
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(main_query, main_params)
|
||||
rows = cursor.fetchall()
|
||||
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||
|
||||
self._cursor.execute(count_query, count_params)
|
||||
total = self._cursor.fetchone()[0]
|
||||
cursor.execute(count_query, count_params)
|
||||
total = cursor.fetchone()[0]
|
||||
|
||||
if per_page:
|
||||
pages = total // per_page + (total % per_page > 0)
|
||||
else:
|
||||
pages = 1 # If no pagination, there is only one page
|
||||
if per_page:
|
||||
pages = total // per_page + (total % per_page > 0)
|
||||
else:
|
||||
pages = 1 # If no pagination, there is only one page
|
||||
|
||||
return PaginatedResults(
|
||||
items=workflows,
|
||||
page=page,
|
||||
per_page=per_page if per_page else total,
|
||||
pages=pages,
|
||||
total=total,
|
||||
)
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return PaginatedResults(
|
||||
items=workflows,
|
||||
page=page,
|
||||
per_page=per_page if per_page else total,
|
||||
pages=pages,
|
||||
total=total,
|
||||
)
|
||||
|
||||
def _sync_default_workflows(self) -> None:
|
||||
"""Syncs default workflows to the database. Internal use only."""
|
||||
@@ -207,7 +187,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
"""
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
workflows: list[Workflow] = []
|
||||
workflows_dir = Path(__file__).parent / Path("default_workflows")
|
||||
workflow_paths = workflows_dir.glob("*.json")
|
||||
@@ -218,14 +197,15 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
workflows.append(workflow)
|
||||
# Only default workflows may be managed by this method
|
||||
assert all(w.meta.category is WorkflowCategory.Default for w in workflows)
|
||||
self._cursor.execute(
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM workflow_library
|
||||
WHERE category = 'default';
|
||||
"""
|
||||
)
|
||||
for w in workflows:
|
||||
self._cursor.execute(
|
||||
cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR REPLACE INTO workflow_library (
|
||||
workflow_id,
|
||||
@@ -239,5 +219,3 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
64
invokeai/app/util/startup_utils.py
Normal file
64
invokeai/app/util/startup_utils.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import logging
|
||||
import mimetypes
|
||||
import socket
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def find_open_port(port: int) -> int:
|
||||
"""Find a port not in use starting at given port"""
|
||||
# Taken from https://waylonwalker.com/python-find-available-port/, thanks Waylon!
|
||||
# https://github.com/WaylonWalker
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.settimeout(1)
|
||||
if s.connect_ex(("localhost", port)) == 0:
|
||||
return find_open_port(port=port + 1)
|
||||
else:
|
||||
return port
|
||||
|
||||
|
||||
def check_cudnn(logger: logging.Logger) -> None:
|
||||
"""Check for cuDNN issues that could be causing degraded performance."""
|
||||
if torch.backends.cudnn.is_available():
|
||||
try:
|
||||
# Note: At the time of writing (torch 2.2.1), torch.backends.cudnn.version() only raises an error the first
|
||||
# time it is called. Subsequent calls will return the version number without complaining about a mismatch.
|
||||
cudnn_version = torch.backends.cudnn.version()
|
||||
logger.info(f"cuDNN version: {cudnn_version}")
|
||||
except RuntimeError as e:
|
||||
logger.warning(
|
||||
"Encountered a cuDNN version issue. This may result in degraded performance. This issue is usually "
|
||||
"caused by an incompatible cuDNN version installed in your python environment, or on the host "
|
||||
f"system. Full error message:\n{e}"
|
||||
)
|
||||
|
||||
|
||||
def enable_dev_reload() -> None:
|
||||
"""Enable hot reloading on python file changes during development."""
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
try:
|
||||
import jurigged
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.'
|
||||
) from e
|
||||
else:
|
||||
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
|
||||
|
||||
|
||||
def apply_monkeypatches() -> None:
|
||||
"""Apply monkeypatches to fix issues with third-party libraries."""
|
||||
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
import invokeai.backend.util.mps_fixes # noqa: F401 (monkeypatching on import)
|
||||
|
||||
|
||||
def register_mime_types() -> None:
|
||||
"""Register additional mime types for windows."""
|
||||
# Fix for windows mimetypes registry entries being borked.
|
||||
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
42
invokeai/app/util/torch_cuda_allocator.py
Normal file
42
invokeai/app/util/torch_cuda_allocator.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
|
||||
def configure_torch_cuda_allocator(pytorch_cuda_alloc_conf: str, logger: logging.Logger | None = None):
|
||||
"""Configure the PyTorch CUDA memory allocator. See
|
||||
https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf for supported
|
||||
configurations.
|
||||
"""
|
||||
|
||||
# Raise if the PYTORCH_CUDA_ALLOC_CONF environment variable is already set.
|
||||
prev_cuda_alloc_conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None)
|
||||
if prev_cuda_alloc_conf is not None:
|
||||
raise RuntimeError(
|
||||
f"Attempted to configure the PyTorch CUDA memory allocator, but PYTORCH_CUDA_ALLOC_CONF is already set to "
|
||||
f"'{prev_cuda_alloc_conf}'."
|
||||
)
|
||||
|
||||
# Configure the PyTorch CUDA memory allocator.
|
||||
# NOTE: It is important that this happens before torch is imported.
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = pytorch_cuda_alloc_conf
|
||||
|
||||
import torch
|
||||
|
||||
# Relevant docs: https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError(
|
||||
"Attempted to configure the PyTorch CUDA memory allocator, but no CUDA devices are available."
|
||||
)
|
||||
|
||||
# Verify that the torch allocator was properly configured.
|
||||
allocator_backend = torch.cuda.get_allocator_backend()
|
||||
expected_backend = "cudaMallocAsync" if "cudaMallocAsync" in pytorch_cuda_alloc_conf else "native"
|
||||
if allocator_backend != expected_backend:
|
||||
raise RuntimeError(
|
||||
f"Failed to configure the PyTorch CUDA memory allocator. Expected backend: '{expected_backend}', but got "
|
||||
f"'{allocator_backend}'. Verify that 1) the pytorch_cuda_alloc_conf is set correctly, and 2) that torch is "
|
||||
"not imported before calling configure_torch_cuda_allocator()."
|
||||
)
|
||||
|
||||
if logger is not None:
|
||||
logger.info(f"PyTorch CUDA memory allocator: {torch.cuda.get_allocator_backend()}")
|
||||
@@ -75,6 +75,8 @@
|
||||
"idb-keyval": "^6.2.1",
|
||||
"jsondiffpatch": "^0.6.0",
|
||||
"konva": "^9.3.15",
|
||||
"linkify-react": "^4.2.0",
|
||||
"linkifyjs": "^4.2.0",
|
||||
"lodash-es": "^4.17.21",
|
||||
"lru-cache": "^11.0.1",
|
||||
"mtwist": "^1.0.2",
|
||||
|
||||
20
invokeai/frontend/web/pnpm-lock.yaml
generated
20
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -74,6 +74,12 @@ dependencies:
|
||||
konva:
|
||||
specifier: ^9.3.15
|
||||
version: 9.3.15
|
||||
linkify-react:
|
||||
specifier: ^4.2.0
|
||||
version: 4.2.0(linkifyjs@4.2.0)(react@18.3.1)
|
||||
linkifyjs:
|
||||
specifier: ^4.2.0
|
||||
version: 4.2.0
|
||||
lodash-es:
|
||||
specifier: ^4.17.21
|
||||
version: 4.17.21
|
||||
@@ -6714,6 +6720,20 @@ packages:
|
||||
resolution: {integrity: sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==}
|
||||
dev: false
|
||||
|
||||
/linkify-react@4.2.0(linkifyjs@4.2.0)(react@18.3.1):
|
||||
resolution: {integrity: sha512-dIcDGo+n4FP2FPIHDcqB7cUE+omkcEgQJpc7sNNP4+XZ9FUhFAkKjGnHMzsZM+B4yF93sK166z9K5cKTe/JpzA==}
|
||||
peerDependencies:
|
||||
linkifyjs: ^4.0.0
|
||||
react: '>= 15.0.0'
|
||||
dependencies:
|
||||
linkifyjs: 4.2.0
|
||||
react: 18.3.1
|
||||
dev: false
|
||||
|
||||
/linkifyjs@4.2.0:
|
||||
resolution: {integrity: sha512-pCj3PrQyATaoTYKHrgWRF3SJwsm61udVh+vuls/Rl6SptiDhgE7ziUIudAedRY9QEfynmM7/RmLEfPUyw1HPCw==}
|
||||
dev: false
|
||||
|
||||
/liqe@3.8.0:
|
||||
resolution: {integrity: sha512-cZ1rDx4XzxONBTskSPBp7/KwJ9qbUdF8EPnY4VjKXwHF1Krz9lgnlMTh1G7kd+KtPYvUte1mhuZeQSnk7KiSBg==}
|
||||
engines: {node: '>=12.0'}
|
||||
|
||||
@@ -921,6 +921,7 @@
|
||||
"currentImage": "Current Image",
|
||||
"currentImageDescription": "Displays the current image in the Node Editor",
|
||||
"downloadWorkflow": "Download Workflow JSON",
|
||||
"downloadWorkflowError": "Error downloading workflow",
|
||||
"edge": "Edge",
|
||||
"edit": "Edit",
|
||||
"editMode": "Edit in Workflow Editor",
|
||||
@@ -1725,6 +1726,10 @@
|
||||
"layout": "Layout",
|
||||
"row": "Row",
|
||||
"column": "Column",
|
||||
"container": "Container",
|
||||
"heading": "Heading",
|
||||
"text": "Text",
|
||||
"divider": "Divider",
|
||||
"nodeField": "Node Field",
|
||||
"zoomToNode": "Zoom to Node",
|
||||
"nodeFieldTooltip": "To add a node field, click the small plus sign button on the field in the Workflow Editor, or drag the field by its name into the form.",
|
||||
|
||||
@@ -98,7 +98,22 @@
|
||||
"close": "Fermer",
|
||||
"clipboard": "Presse-papier",
|
||||
"loadingModel": "Chargement du modèle",
|
||||
"generating": "En Génération"
|
||||
"generating": "En Génération",
|
||||
"warnings": "Alertes",
|
||||
"layout": "Disposition",
|
||||
"row": "Ligne",
|
||||
"column": "Colonne",
|
||||
"start": "Commencer",
|
||||
"board": "Planche",
|
||||
"count": "Quantité",
|
||||
"step": "Étape",
|
||||
"end": "Fin",
|
||||
"min": "Min",
|
||||
"max": "Max",
|
||||
"values": "Valeurs",
|
||||
"resetToDefaults": "Réinitialiser par défaut",
|
||||
"seed": "Graine",
|
||||
"combinatorial": "Combinatoire"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Taille de l'image",
|
||||
@@ -165,7 +180,9 @@
|
||||
"imagesSettings": "Paramètres des images de la galerie",
|
||||
"assetsTab": "Fichiers que vous avez importés pour vos projets.",
|
||||
"imagesTab": "Images que vous avez créées et enregistrées dans Invoke.",
|
||||
"boardsSettings": "Paramètres des planches"
|
||||
"boardsSettings": "Paramètres des planches",
|
||||
"assets": "Ressources",
|
||||
"images": "Images"
|
||||
},
|
||||
"modelManager": {
|
||||
"modelManager": "Gestionnaire de modèle",
|
||||
@@ -289,7 +306,7 @@
|
||||
"usingDefaultSettings": "Utilisation des paramètres par défaut du modèle",
|
||||
"defaultSettingsOutOfSync": "Certain paramètres ne correspondent pas aux valeurs par défaut du modèle :",
|
||||
"restoreDefaultSettings": "Cliquez pour utiliser les paramètres par défaut du modèle.",
|
||||
"hfForbiddenErrorMessage": "Nous vous recommandons de visiter la page du modèle sur HuggingFace.com. Le propriétaire peut exiger l'acceptation des conditions pour pouvoir télécharger.",
|
||||
"hfForbiddenErrorMessage": "Nous vous recommandons de visiter la page du modèle. Le propriétaire peut exiger l'acceptation des conditions pour pouvoir télécharger.",
|
||||
"hfTokenRequired": "Vous essayez de télécharger un modèle qui nécessite un token HuggingFace valide.",
|
||||
"clipLEmbed": "CLIP-L Embed",
|
||||
"hfTokenSaved": "Token HF enregistré",
|
||||
@@ -303,7 +320,10 @@
|
||||
"hfForbidden": "Vous n'avez pas accès à ce modèle HF.",
|
||||
"hfTokenInvalidErrorMessage2": "Mettre à jour dans le ",
|
||||
"controlLora": "Controle LoRA",
|
||||
"urlUnauthorizedErrorMessage2": "Découvrir comment ici."
|
||||
"urlUnauthorizedErrorMessage2": "Découvrir comment ici.",
|
||||
"urlUnauthorizedErrorMessage": "Vous devrez peut-être configurer un jeton API pour accéder à ce modèle.",
|
||||
"urlForbidden": "Vous n'avez pas accès à ce modèle",
|
||||
"urlForbiddenErrorMessage": "Vous devrez peut-être demander l'autorisation du site qui distribue le modèle."
|
||||
},
|
||||
"parameters": {
|
||||
"images": "Images",
|
||||
@@ -345,19 +365,31 @@
|
||||
"fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), la hauteur de la bounding box est {{height}}",
|
||||
"fluxModelIncompatibleScaledBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), la hauteur de la bounding box est {{height}}",
|
||||
"noFLUXVAEModelSelected": "Aucun modèle VAE sélectionné pour la génération FLUX",
|
||||
"canvasIsTransforming": "La Toile se transforme",
|
||||
"canvasIsRasterizing": "La Toile se rastérise",
|
||||
"canvasIsTransforming": "La Toile est occupée (en transformation)",
|
||||
"canvasIsRasterizing": "La Toile est occupée (en rastérisation)",
|
||||
"noCLIPEmbedModelSelected": "Aucun modèle CLIP Embed sélectionné pour la génération FLUX",
|
||||
"canvasIsFiltering": "La Toile est en train de filtrer",
|
||||
"canvasIsFiltering": "La Toile est occupée (en filtration)",
|
||||
"fluxModelIncompatibleBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), la largeur de la bounding box est {{width}}",
|
||||
"noT5EncoderModelSelected": "Aucun modèle T5 Encoder sélectionné pour la génération FLUX",
|
||||
"fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), la largeur de la bounding box mise à l'échelle est {{width}}",
|
||||
"canvasIsCompositing": "La toile est en train de composer",
|
||||
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}} : trop peu d'éléments, minimum {{minItems}}",
|
||||
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}} : trop d'éléments, maximum {{maxItems}}",
|
||||
"canvasIsCompositing": "La Toile est occupée (en composition)",
|
||||
"collectionTooFewItems": "trop peu d'éléments, minimum {{minItems}}",
|
||||
"collectionTooManyItems": "trop d'éléments, maximum {{maxItems}}",
|
||||
"canvasIsSelectingObject": "La toile est occupée (sélection d'objet)",
|
||||
"emptyBatches": "lots vides",
|
||||
"batchNodeNotConnected": "Noeud de lots non connecté : {{label}}"
|
||||
"batchNodeNotConnected": "Noeud de lots non connecté : {{label}}",
|
||||
"fluxModelMultipleControlLoRAs": "Vous ne pouvez utiliser qu'un seul Control LoRA à la fois",
|
||||
"collectionNumberLTMin": "{{value}} < {{minimum}} (incl. min)",
|
||||
"collectionNumberGTMax": "{{value}} > {{maximum}} (incl. max)",
|
||||
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (max exc)",
|
||||
"batchNodeEmptyCollection": "Certains nœuds de lot ont des collections vides",
|
||||
"batchNodeCollectionSizeMismatch": "Non-concordance de taille de collection sur le lot {{batchGroupId}}",
|
||||
"collectionStringTooLong": "trop long, max {{maxLength}}",
|
||||
"collectionNumberNotMultipleOf": "{{value}} n'est pas un multiple de {{multipleOf}}",
|
||||
"collectionEmpty": "collection vide",
|
||||
"collectionStringTooShort": "trop court, min {{minLength}}",
|
||||
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (min exc)",
|
||||
"batchNodeCollectionSizeMismatchNoGroupId": "Taille de collection de groupe par lot non conforme"
|
||||
},
|
||||
"negativePromptPlaceholder": "Prompt Négatif",
|
||||
"positivePromptPlaceholder": "Prompt Positif",
|
||||
@@ -501,7 +533,13 @@
|
||||
"uploadFailedInvalidUploadDesc_withCount_one": "Doit être au maximum une image PNG ou JPEG.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_many": "Doit être au maximum {{count}} images PNG ou JPEG.",
|
||||
"uploadFailedInvalidUploadDesc_withCount_other": "Doit être au maximum {{count}} images PNG ou JPEG.",
|
||||
"addedToUncategorized": "Ajouté aux ressources de la planche $t(boards.uncategorized)"
|
||||
"addedToUncategorized": "Ajouté aux ressources de la planche $t(boards.uncategorized)",
|
||||
"pasteSuccess": "Collé à {{destination}}",
|
||||
"pasteFailed": "Échec du collage",
|
||||
"outOfMemoryErrorDescLocal": "Suivez notre <LinkComponent>guide Low VRAM</LinkComponent> pour réduire les OOMs.",
|
||||
"unableToCopy": "Incapable de Copier",
|
||||
"unableToCopyDesc": "Votre navigateur ne prend pas en charge l'accès au presse-papiers. Les utilisateurs de Firefox peuvent peut-être résoudre ce problème en suivant ",
|
||||
"unableToCopyDesc_theseSteps": "ces étapes"
|
||||
},
|
||||
"accessibility": {
|
||||
"uploadImage": "Importer une image",
|
||||
@@ -659,7 +697,14 @@
|
||||
"iterations_many": "Itérations",
|
||||
"iterations_other": "Itérations",
|
||||
"back": "fin",
|
||||
"batchSize": "Taille de lot"
|
||||
"batchSize": "Taille de lot",
|
||||
"retryFailed": "Problème de nouvelle tentative de l'élément",
|
||||
"retrySucceeded": "Élément Retenté",
|
||||
"retryItem": "Réessayer l'élement",
|
||||
"cancelAllExceptCurrentQueueItemAlertDialog": "Annuler tous les éléments de la file d'attente, sauf celui en cours, arrêtera les éléments en attente mais permettra à celui en cours de se terminer.",
|
||||
"cancelAllExceptCurrentQueueItemAlertDialog2": "Êtes-vous sûr de vouloir annuler tous les éléments en attente dans la file d'attente ?",
|
||||
"cancelAllExceptCurrentTooltip": "Annuler tout sauf l'élément actuel",
|
||||
"confirm": "Confirmer"
|
||||
},
|
||||
"prompt": {
|
||||
"noMatchingTriggers": "Pas de déclancheurs correspondants",
|
||||
@@ -1031,7 +1076,9 @@
|
||||
"controlNetWeight": {
|
||||
"heading": "Poids",
|
||||
"paragraphs": [
|
||||
"Poids du Control Adapter. Un poids plus élevé aura un impact plus important sur l'image finale."
|
||||
"Poids du Control Adapter. Un poids plus élevé aura un impact plus important sur l'image finale.",
|
||||
"• Poids plus élevé (.75-2) : Crée un impact plus significatif sur le résultat final.",
|
||||
"• Poids inférieur (0-.75) : Crée un impact plus faible sur le résultat final."
|
||||
]
|
||||
},
|
||||
"compositingMaskAdjustments": {
|
||||
@@ -1076,8 +1123,9 @@
|
||||
"controlNetBeginEnd": {
|
||||
"heading": "Pourcentage de début / de fin d'étape",
|
||||
"paragraphs": [
|
||||
"La partie du processus de débruitage à laquelle le Control Adapter sera appliqué.",
|
||||
"En général, les Control Adapter appliqués au début du processus guident la composition, tandis que les Control Adapter appliqués à la fin guident les détails."
|
||||
"Ce paramètre détérmine quelle portion du processus de débruitage (génération) utilisera cette couche comme guide.",
|
||||
"En général, les Control Adapter appliqués au début du processus guident la composition, tandis que les Control Adapter appliqués à la fin guident les détails.",
|
||||
"• Étape de fin (%): Spécifie quand arrêter d'appliquer le guide de cette couche et revenir aux guides généraux du modèle et aux autres paramètres."
|
||||
]
|
||||
},
|
||||
"controlNetControlMode": {
|
||||
@@ -1442,7 +1490,8 @@
|
||||
"showDynamicPrompts": "Afficher les Prompts dynamiques",
|
||||
"dynamicPrompts": "Prompts Dynamiques",
|
||||
"promptsPreview": "Prévisualisation des Prompts",
|
||||
"loading": "Génération des Pompts Dynamiques..."
|
||||
"loading": "Génération des Pompts Dynamiques...",
|
||||
"promptsToGenerate": "Prompts à générer"
|
||||
},
|
||||
"metadata": {
|
||||
"positivePrompt": "Prompt Positif",
|
||||
@@ -1653,7 +1702,22 @@
|
||||
"internalDesc": "Cette invocation est utilisée internalement par Invoke. En fonction des mises à jours il est possible que des changements y soit effectués ou qu'elle soit supprimé sans prévention.",
|
||||
"splitOn": "Diviser sur",
|
||||
"generatorNoValues": "vide",
|
||||
"addItem": "Ajouter un élément"
|
||||
"addItem": "Ajouter un élément",
|
||||
"specialDesc": "Cette invocation nécessite un traitement spécial dans l'application. Par exemple, les nœuds Batch sont utilisés pour mettre en file d'attente plusieurs graphes à partir d'un seul workflow.",
|
||||
"unableToUpdateNode": "La mise à jour du nœud a échoué : nœud {{node}} de type {{type}} (peut nécessiter la suppression et la recréation).",
|
||||
"deletedMissingNodeFieldFormElement": "Champ de formulaire manquant supprimé : nœud {{nodeId}} champ {{fieldName}}",
|
||||
"nodeName": "Nom du nœud",
|
||||
"description": "Description",
|
||||
"loadWorkflowDesc": "Charger le workflow ?",
|
||||
"missingSourceOrTargetNode": "Nœud source ou cible manquant",
|
||||
"generatorImagesCategory": "Catégorie",
|
||||
"generatorImagesFromBoard": "Images de la Planche",
|
||||
"missingSourceOrTargetHandle": "Manque de gestionnaire source ou cible",
|
||||
"loadingTemplates": "Chargement de {{name}}",
|
||||
"loadWorkflowDesc2": "Votre workflow actuel contient des modifications non enregistrées.",
|
||||
"generatorImages_one": "{{count}} image",
|
||||
"generatorImages_many": "{{count}} images",
|
||||
"generatorImages_other": "{{count}} images"
|
||||
},
|
||||
"models": {
|
||||
"noMatchingModels": "Aucun modèle correspondant",
|
||||
@@ -1712,13 +1776,41 @@
|
||||
"deleteWorkflow2": "Êtes-vous sûr de vouloir supprimer ce Workflow ? Cette action ne peut pas être annulé.",
|
||||
"download": "Télécharger",
|
||||
"copyShareLinkForWorkflow": "Copier le lien de partage pour le Workflow",
|
||||
"delete": "Supprimer"
|
||||
"delete": "Supprimer",
|
||||
"builder": {
|
||||
"component": "Composant",
|
||||
"numberInput": "Entrée de nombre",
|
||||
"slider": "Curseur",
|
||||
"both": "Les deux",
|
||||
"singleLine": "Ligne unique",
|
||||
"multiLine": "Multi Ligne",
|
||||
"headingPlaceholder": "En-tête vide",
|
||||
"emptyRootPlaceholderEditMode": "Faites glisser un élément de formulaire ou un champ de nœud ici pour commencer.",
|
||||
"emptyRootPlaceholderViewMode": "Cliquez sur Modifier pour commencer à créer un formulaire pour ce workflow.",
|
||||
"containerPlaceholder": "Conteneur Vide",
|
||||
"row": "Ligne",
|
||||
"column": "Colonne",
|
||||
"layout": "Mise en page",
|
||||
"nodeField": "Champ de nœud",
|
||||
"zoomToNode": "Zoomer sur le nœud",
|
||||
"nodeFieldTooltip": "Pour ajouter un champ de nœud, cliquez sur le petit bouton plus sur le champ dans l'Éditeur de Workflow, ou faites glisser le champ par son nom dans le formulaire.",
|
||||
"addToForm": "Ajouter au formulaire",
|
||||
"label": "Étiquette",
|
||||
"textPlaceholder": "Texte vide",
|
||||
"builder": "Constructeur de Formulaire",
|
||||
"resetAllNodeFields": "Réinitialiser tous les champs de nœud",
|
||||
"deleteAllElements": "Supprimer tous les éléments de formulaire",
|
||||
"workflowBuilderAlphaWarning": "Le constructeur de workflow est actuellement en version alpha. Il peut y avoir des changements majeurs avant la version stable.",
|
||||
"showDescription": "Afficher la description"
|
||||
},
|
||||
"openLibrary": "Ouvrir la Bibliothèque"
|
||||
},
|
||||
"whatsNew": {
|
||||
"whatsNewInInvoke": "Quoi de neuf dans Invoke",
|
||||
"watchRecentReleaseVideos": "Regarder les vidéos des dernières versions",
|
||||
"items": [
|
||||
"<StrongComponent>FLUX Guidage Régional (bêta)</StrongComponent> : Notre version bêta de FLUX Guidage Régional est en ligne pour le contrôle des prompt régionaux."
|
||||
"<StrongComponent>FLUX Guidage Régional (bêta)</StrongComponent> : Notre version bêta de FLUX Guidage Régional est en ligne pour le contrôle des prompt régionaux.",
|
||||
"Autres améliorations : mise en file d'attente par lots plus rapide, meilleur redimensionnement, sélecteur de couleurs amélioré et nœuds de métadonnées."
|
||||
],
|
||||
"readReleaseNotes": "Notes de version",
|
||||
"watchUiUpdatesOverview": "Aperçu des mises à jour de l'interface utilisateur"
|
||||
@@ -1832,7 +1924,49 @@
|
||||
"cancel": "Annuler",
|
||||
"advanced": "Avancé",
|
||||
"processingLayerWith": "Calque de traitement avec le filtre {{type}}.",
|
||||
"forMoreControl": "Pour plus de contrôle, cliquez sur Avancé ci-dessous."
|
||||
"forMoreControl": "Pour plus de contrôle, cliquez sur Avancé ci-dessous.",
|
||||
"adjust_image": {
|
||||
"b": "B (LAB)",
|
||||
"blue": "Bleu (RGBA)",
|
||||
"alpha": "Alpha (RGBA)",
|
||||
"magenta": "Magenta (CMJN)",
|
||||
"yellow": "Jaune (CMJN)",
|
||||
"cb": "Cb (YCbCr)",
|
||||
"cr": "Cr (YCbCr)",
|
||||
"cyan": "Cyan (CMJN)",
|
||||
"label": "Ajuster l'image",
|
||||
"description": "Ajuste le canal sélectionné d'une image.",
|
||||
"channel": "Canal",
|
||||
"value_setting": "Valeur",
|
||||
"scale_values": "Valeurs d'échelle",
|
||||
"red": "Rouge (RGBA)",
|
||||
"green": "Vert (RGBA)",
|
||||
"black": "Noir (CMJN)",
|
||||
"hue": "Teinte (HSV)",
|
||||
"saturation": "Saturation (HSV)",
|
||||
"value": "Valeur (HSV)",
|
||||
"luminosity": "Luminosité (LAB)",
|
||||
"a": "A (LAB)",
|
||||
"y": "Y (YCbCr)"
|
||||
},
|
||||
"img_blur": {
|
||||
"label": "Flou de l'image",
|
||||
"blur_type": "Type de flou",
|
||||
"box_type": "Boîte",
|
||||
"description": "Floute la couche sélectionnée.",
|
||||
"blur_radius": "Rayon",
|
||||
"gaussian_type": "Gaussien"
|
||||
},
|
||||
"img_noise": {
|
||||
"label": "Image de bruit",
|
||||
"description": "Ajoute du bruit à la couche sélectionnée.",
|
||||
"gaussian_type": "Gaussien",
|
||||
"size": "Taille du bruit",
|
||||
"noise_amount": "Quantité",
|
||||
"noise_type": "Type de bruit",
|
||||
"salt_and_pepper_type": "Sel et Poivre",
|
||||
"noise_color": "Bruit coloré"
|
||||
}
|
||||
},
|
||||
"canvasContextMenu": {
|
||||
"saveToGalleryGroup": "Enregistrer dans la galerie",
|
||||
@@ -1846,7 +1980,10 @@
|
||||
"newGlobalReferenceImage": "Nouvelle image de référence globale",
|
||||
"newControlLayer": "Nouveau couche de contrôle",
|
||||
"newInpaintMask": "Nouveau Masque Inpaint",
|
||||
"newRegionalGuidance": "Nouveau Guide Régional"
|
||||
"newRegionalGuidance": "Nouveau Guide Régional",
|
||||
"copyToClipboard": "Copier dans le presse-papiers",
|
||||
"copyBboxToClipboard": "Copier Bbox dans le presse-papiers",
|
||||
"copyCanvasToClipboard": "Copier la Toile dans le presse-papiers"
|
||||
},
|
||||
"bookmark": "Marque-page pour Changement Rapide",
|
||||
"saveLayerToAssets": "Enregistrer la couche dans les ressources",
|
||||
@@ -2012,7 +2149,10 @@
|
||||
"ipAdapterMethod": "Méthode d'IP Adapter",
|
||||
"full": "Complet",
|
||||
"style": "Style uniquement",
|
||||
"composition": "Composition uniquement"
|
||||
"composition": "Composition uniquement",
|
||||
"fullDesc": "Applique le style visuel (couleurs, textures) et la composition (mise en page, structure).",
|
||||
"styleDesc": "Applique un style visuel (couleurs, textures) sans tenir compte de sa mise en page.",
|
||||
"compositionDesc": "Réplique la mise en page et la structure tout en ignorant le style de la référence."
|
||||
},
|
||||
"fitBboxToLayers": "Ajuster la bounding box aux calques",
|
||||
"regionIsEmpty": "La zone sélectionnée est vide",
|
||||
@@ -2095,7 +2235,40 @@
|
||||
"asRasterLayerResize": "En tant que $t(controlLayers.rasterLayer) (Redimensionner)",
|
||||
"asControlLayer": "En tant que $t(controlLayers.controlLayer)",
|
||||
"asControlLayerResize": "En $t(controlLayers.controlLayer) (Redimensionner)",
|
||||
"newSession": "Nouvelle session"
|
||||
"newSession": "Nouvelle session",
|
||||
"warnings": {
|
||||
"controlAdapterIncompatibleBaseModel": "modèle de base de la couche de contrôle incompatible",
|
||||
"controlAdapterNoControl": "aucun contrôle sélectionné/dessiné",
|
||||
"rgNoPromptsOrIPAdapters": "pas de textes d'instructions ni d'images de référence",
|
||||
"rgAutoNegativeNotSupported": "Auto-négatif non pris en charge pour le modèle de base sélectionné",
|
||||
"rgNoRegion": "aucune région dessinée",
|
||||
"ipAdapterNoModelSelected": "aucun modèle d'image de référence sélectionné",
|
||||
"rgReferenceImagesNotSupported": "Les images de référence régionales ne sont pas prises en charge pour le modèle de base sélectionné",
|
||||
"problemsFound": "Problèmes trouvés",
|
||||
"unsupportedModel": "couche non prise en charge pour le modèle de base sélectionné",
|
||||
"rgNegativePromptNotSupported": "Prompt négatif non pris en charge pour le modèle de base sélectionné",
|
||||
"ipAdapterIncompatibleBaseModel": "modèle de base d'image de référence incompatible",
|
||||
"controlAdapterNoModelSelected": "aucun modèle de couche de contrôle sélectionné",
|
||||
"ipAdapterNoImageSelected": "Aucune image de référence sélectionnée."
|
||||
},
|
||||
"pasteTo": "Coller vers",
|
||||
"pasteToAssets": "Ressources",
|
||||
"pasteToAssetsDesc": "Coller dans les ressources",
|
||||
"pasteToBbox": "Bbox",
|
||||
"regionCopiedToClipboard": "{{region}} Copié dans le presse-papiers",
|
||||
"copyRegionError": "Erreur de copie {{region}}",
|
||||
"pasteToCanvas": "Toile",
|
||||
"errors": {
|
||||
"unableToFindImage": "Impossible de trouver l'image",
|
||||
"unableToLoadImage": "Impossible de charger l'image"
|
||||
},
|
||||
"referenceImageRegional": "Image de référence (régionale)",
|
||||
"pasteToBboxDesc": "Nouvelle couche (dans Bbox)",
|
||||
"pasteToCanvasDesc": "Nouvelle couche (dans la Toile)",
|
||||
"useImage": "Utiliser l'image",
|
||||
"pastedTo": "Collé à {{destination}}",
|
||||
"referenceImageEmptyState": "<UploadButton>Séléctionner une image</UploadButton> ou faites glisser une image depuis la <GalleryButton>galerie</GalleryButton> sur cette couche pour commencer.",
|
||||
"referenceImageGlobal": "Image de référence (Globale)"
|
||||
},
|
||||
"upscaling": {
|
||||
"exceedsMaxSizeDetails": "La limite maximale d'agrandissement est de {{maxUpscaleDimension}}x{{maxUpscaleDimension}} pixels. Veuillez essayer une image plus petite ou réduire votre sélection d'échelle.",
|
||||
@@ -2175,7 +2348,8 @@
|
||||
"queue": "File d'attente",
|
||||
"events": "Événements",
|
||||
"metadata": "Métadonnées",
|
||||
"gallery": "Galerie"
|
||||
"gallery": "Galerie",
|
||||
"dnd": "Glisser et déposer"
|
||||
},
|
||||
"logLevel": {
|
||||
"trace": "Trace",
|
||||
@@ -2192,7 +2366,8 @@
|
||||
"toGetStarted": "Pour commencer, saisissez un prompt dans la boîte et cliquez sur <StrongComponent>Invoke</StrongComponent> pour générer votre première image. Sélectionnez un template de prompt pour améliorer les résultats. Vous pouvez choisir de sauvegarder vos images directement dans la <StrongComponent>Galerie</StrongComponent> ou de les modifier sur la <StrongComponent>Toile</StrongComponent>.",
|
||||
"gettingStartedSeries": "Vous souhaitez plus de conseils ? Consultez notre <LinkComponent>Série de démarrage</LinkComponent> pour des astuces sur l'exploitation du plein potentiel de l'Invoke Studio.",
|
||||
"noModelsInstalled": "Il semble qu'aucun modèle ne soit installé",
|
||||
"toGetStartedLocal": "Pour commencer, assurez-vous de télécharger ou d'importer des modèles nécessaires pour exécuter Invoke. Ensuite, saisissez le prompt dans la boîte et cliquez sur <StrongComponent>Invoke</StrongComponent> pour générer votre première image. Sélectionnez un template de prompt pour améliorer les résultats. Vous pouvez choisir de sauvegarder vos images directement sur <StrongComponent>Galerie</StrongComponent> ou les modifier sur la <StrongComponent>Toile</StrongComponent>."
|
||||
"toGetStartedLocal": "Pour commencer, assurez-vous de télécharger ou d'importer des modèles nécessaires pour exécuter Invoke. Ensuite, saisissez le prompt dans la boîte et cliquez sur <StrongComponent>Invoke</StrongComponent> pour générer votre première image. Sélectionnez un template de prompt pour améliorer les résultats. Vous pouvez choisir de sauvegarder vos images directement sur <StrongComponent>Galerie</StrongComponent> ou les modifier sur la <StrongComponent>Toile</StrongComponent>.",
|
||||
"lowVRAMMode": "Pour de meilleures performances, suivez notre <LinkComponent>guide Low VRAM</LinkComponent>."
|
||||
},
|
||||
"upsell": {
|
||||
"shareAccess": "Partager l'accès",
|
||||
@@ -2240,7 +2415,8 @@
|
||||
"description": "Introduction à l'ajout d'images de référence et IP Adapters globaux."
|
||||
},
|
||||
"howDoIUseInpaintMasks": {
|
||||
"title": "Comment utiliser les masques d'inpainting ?"
|
||||
"title": "Comment utiliser les masques d'inpainting ?",
|
||||
"description": "Comment appliquer des masques de retourche pour la correction et la variation d'image."
|
||||
},
|
||||
"creatingYourFirstImage": {
|
||||
"title": "Créer votre première image",
|
||||
@@ -2260,5 +2436,10 @@
|
||||
"studioSessionsDesc2": "Rejoignez notre <DiscordLink /> pour participer aux sessions en direct et poser vos questions. Les sessions sont ajoutée dans la playlist la semaine suivante.",
|
||||
"supportVideos": "Vidéos d'assistance",
|
||||
"controlCanvas": "Contrôler la toile"
|
||||
},
|
||||
"modelCache": {
|
||||
"clear": "Effacer le cache du modèle",
|
||||
"clearSucceeded": "Cache du modèle effacée",
|
||||
"clearFailed": "Problème de nettoyage du cache du modèle"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,7 +105,11 @@
|
||||
"resetToDefaults": "Ripristina le impostazioni predefinite",
|
||||
"seed": "Seme",
|
||||
"combinatorial": "Combinatorio",
|
||||
"count": "Quantità"
|
||||
"count": "Quantità",
|
||||
"board": "Bacheca",
|
||||
"layout": "Schema",
|
||||
"row": "Riga",
|
||||
"column": "Colonna"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Dimensione dell'immagine",
|
||||
@@ -173,7 +177,8 @@
|
||||
"assetsTab": "File che hai caricato per usarli nei tuoi progetti.",
|
||||
"boardsSettings": "Impostazioni Bacheche",
|
||||
"imagesSettings": "Impostazioni Immagini Galleria",
|
||||
"assets": "Risorse"
|
||||
"assets": "Risorse",
|
||||
"images": "Immagini"
|
||||
},
|
||||
"hotkeys": {
|
||||
"searchHotkeys": "Cerca tasti di scelta rapida",
|
||||
@@ -703,7 +708,8 @@
|
||||
"collectionNumberNotMultipleOf": "{{value}} non è multiplo di {{multipleOf}}",
|
||||
"collectionNumberLTMin": "{{value}} < {{minimum}} (incr min)",
|
||||
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (excl max)",
|
||||
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (excl min)"
|
||||
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (excl min)",
|
||||
"collectionEmpty": "raccolta vuota"
|
||||
},
|
||||
"useCpuNoise": "Usa la CPU per generare rumore",
|
||||
"iterations": "Iterazioni",
|
||||
@@ -920,8 +926,8 @@
|
||||
"boolean": "Booleani",
|
||||
"node": "Nodo",
|
||||
"collection": "Raccolta",
|
||||
"cannotConnectInputToInput": "Impossibile collegare Input a Input",
|
||||
"cannotConnectOutputToOutput": "Impossibile collegare Output ad Output",
|
||||
"cannotConnectInputToInput": "Impossibile collegare ingresso a ingresso",
|
||||
"cannotConnectOutputToOutput": "Impossibile collegare uscita ad uscita",
|
||||
"cannotConnectToSelf": "Impossibile connettersi a se stesso",
|
||||
"mismatchedVersion": "Nodo non valido: il nodo {{node}} di tipo {{type}} ha una versione non corrispondente (provare ad aggiornare?)",
|
||||
"loadingNodes": "Caricamento nodi...",
|
||||
@@ -1014,7 +1020,21 @@
|
||||
"generatorNRandomValues_one": "{{count}} valore casuale",
|
||||
"generatorNRandomValues_many": "{{count}} valori casuali",
|
||||
"generatorNRandomValues_other": "{{count}} valori casuali",
|
||||
"arithmeticSequence": "Sequenza aritmetica"
|
||||
"arithmeticSequence": "Sequenza aritmetica",
|
||||
"nodeName": "Nome del nodo",
|
||||
"loadWorkflowDesc": "Caricare il flusso di lavoro?",
|
||||
"loadWorkflowDesc2": "Il flusso di lavoro corrente presenta modifiche non salvate.",
|
||||
"downloadWorkflowError": "Errore durante lo scaricamento del flusso di lavoro",
|
||||
"deletedMissingNodeFieldFormElement": "Campo modulo mancante eliminato: nodo {{nodeId}} campo {{fieldName}}",
|
||||
"loadingTemplates": "Caricamento {{name}}",
|
||||
"unableToUpdateNode": "Aggiornamento del nodo non riuscito: nodo {{node}} di tipo {{type}} (potrebbe essere necessario eliminarlo e ricrearlo)",
|
||||
"description": "Descrizione",
|
||||
"generatorImagesCategory": "Categoria",
|
||||
"generatorImages_one": "{{count}} immagine",
|
||||
"generatorImages_many": "{{count}} immagini",
|
||||
"generatorImages_other": "{{count}} immagini",
|
||||
"generatorImagesFromBoard": "Immagini dalla Bacheca",
|
||||
"missingSourceOrTargetNode": "Nodo sorgente o di destinazione mancante"
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "Aggiungi automaticamente bacheca",
|
||||
@@ -1140,7 +1160,10 @@
|
||||
"cancelAllExceptCurrentQueueItemAlertDialog2": "Vuoi davvero annullare tutti gli elementi in coda in sospeso?",
|
||||
"confirm": "Conferma",
|
||||
"cancelAllExceptCurrentQueueItemAlertDialog": "L'annullamento di tutti gli elementi della coda, eccetto quello corrente, interromperà gli elementi in sospeso ma consentirà il completamento di quello in corso.",
|
||||
"cancelAllExceptCurrentTooltip": "Annulla tutto tranne l'elemento corrente"
|
||||
"cancelAllExceptCurrentTooltip": "Annulla tutto tranne l'elemento corrente",
|
||||
"retrySucceeded": "Elemento rieseguito",
|
||||
"retryItem": "Riesegui elemento",
|
||||
"retryFailed": "Problema riesecuzione elemento"
|
||||
},
|
||||
"models": {
|
||||
"noMatchingModels": "Nessun modello corrispondente",
|
||||
@@ -1718,7 +1741,33 @@
|
||||
"download": "Scarica",
|
||||
"copyShareLink": "Copia Condividi Link",
|
||||
"copyShareLinkForWorkflow": "Copia Condividi Link del Flusso di lavoro",
|
||||
"delete": "Elimina"
|
||||
"delete": "Elimina",
|
||||
"openLibrary": "Apri la libreria",
|
||||
"builder": {
|
||||
"resetAllNodeFields": "Reimposta tutti i campi del nodo",
|
||||
"row": "Riga",
|
||||
"nodeField": "Campo del nodo",
|
||||
"slider": "Cursore",
|
||||
"emptyRootPlaceholderEditMode": "Per iniziare, trascina qui un elemento del modulo o un campo nodo.",
|
||||
"containerPlaceholder": "Contenitore vuoto",
|
||||
"headingPlaceholder": "Titolo vuoto",
|
||||
"column": "Colonna",
|
||||
"nodeFieldTooltip": "Per aggiungere un campo nodo, fare clic sul piccolo pulsante con il segno più sul campo nell'editor del flusso di lavoro, oppure trascinare il campo in base al suo nome nel modulo.",
|
||||
"label": "Etichetta",
|
||||
"deleteAllElements": "Elimina tutti gli elementi del modulo",
|
||||
"addToForm": "Aggiungi al Modulo",
|
||||
"layout": "Schema",
|
||||
"builder": "Generatore Modulo",
|
||||
"zoomToNode": "Zoom sul nodo",
|
||||
"component": "Componente",
|
||||
"showDescription": "Mostra Descrizione",
|
||||
"singleLine": "Linea singola",
|
||||
"multiLine": "Linea Multipla",
|
||||
"both": "Entrambi",
|
||||
"emptyRootPlaceholderViewMode": "Fare clic su Modifica per iniziare a creare un modulo per questo flusso di lavoro.",
|
||||
"textPlaceholder": "Testo vuoto",
|
||||
"workflowBuilderAlphaWarning": "Il generatore di flussi di lavoro è attualmente in versione alpha. Potrebbero esserci cambiamenti radicali prima della versione stabile."
|
||||
}
|
||||
},
|
||||
"accordions": {
|
||||
"compositing": {
|
||||
@@ -2276,12 +2325,8 @@
|
||||
"watchRecentReleaseVideos": "Guarda i video su questa versione",
|
||||
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
|
||||
"items": [
|
||||
"Impostazioni predefinite VRAM migliorate",
|
||||
"Cancellazione della cache del modello su richiesta",
|
||||
"Compatibilità estesa FLUX LoRA",
|
||||
"Filtro Regola Immagine su Tela",
|
||||
"Annulla tutto tranne l'elemento della coda corrente",
|
||||
"Copia da e incolla sulla Tela"
|
||||
"Editor del flusso di lavoro: nuovo generatore di moduli trascina-e-rilascia per una creazione più facile del flusso di lavoro.",
|
||||
"Altri miglioramenti: messa in coda dei lotti più rapida, migliore ampliamento, selettore colore migliorato e nodi metadati."
|
||||
]
|
||||
},
|
||||
"system": {
|
||||
|
||||
@@ -329,7 +329,13 @@
|
||||
"redo": {
|
||||
"title": "やり直し"
|
||||
},
|
||||
"title": "ワークフロー"
|
||||
"title": "ワークフロー",
|
||||
"pasteSelection": {
|
||||
"title": "ペースト"
|
||||
},
|
||||
"copySelection": {
|
||||
"title": "コピー"
|
||||
}
|
||||
},
|
||||
"app": {
|
||||
"toggleLeftPanel": {
|
||||
@@ -390,7 +396,10 @@
|
||||
"desc": "カーソルをポジティブプロンプト欄に移動します。"
|
||||
}
|
||||
},
|
||||
"hotkeys": "ホットキー"
|
||||
"hotkeys": "ホットキー",
|
||||
"gallery": {
|
||||
"title": "ギャラリー"
|
||||
}
|
||||
},
|
||||
"modelManager": {
|
||||
"modelManager": "モデルマネージャ",
|
||||
@@ -452,7 +461,8 @@
|
||||
"loraModels": "LoRA",
|
||||
"edit": "編集",
|
||||
"install": "インストール",
|
||||
"huggingFacePlaceholder": "owner/model-name"
|
||||
"huggingFacePlaceholder": "owner/model-name",
|
||||
"variant": "Variant"
|
||||
},
|
||||
"parameters": {
|
||||
"images": "画像",
|
||||
@@ -507,7 +517,8 @@
|
||||
"resetWebUIDesc2": "もしギャラリーに画像が表示されないなど、何か問題が発生した場合はGitHubにissueを提出する前にリセットを試してください。",
|
||||
"resetComplete": "WebUIはリセットされました。",
|
||||
"ui": "ユーザーインターフェイス",
|
||||
"beta": "ベータ"
|
||||
"beta": "ベータ",
|
||||
"developer": "開発者"
|
||||
},
|
||||
"toast": {
|
||||
"uploadFailed": "アップロード失敗",
|
||||
@@ -556,7 +567,8 @@
|
||||
"negativePrompt": "ネガティブプロンプト",
|
||||
"generationMode": "生成モード",
|
||||
"vae": "VAE",
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)"
|
||||
"cfgRescaleMultiplier": "$t(parameters.cfgRescaleMultiplier)",
|
||||
"canvasV2Metadata": "キャンバス"
|
||||
},
|
||||
"queue": {
|
||||
"queueEmpty": "キューが空です",
|
||||
@@ -690,7 +702,8 @@
|
||||
"notes": "ノート",
|
||||
"workflow": "ワークフロー",
|
||||
"workflowName": "名前",
|
||||
"workflowNotes": "ノート"
|
||||
"workflowNotes": "ノート",
|
||||
"enum": "Enum"
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "自動追加するボード",
|
||||
@@ -823,6 +836,15 @@
|
||||
},
|
||||
"lora": {
|
||||
"heading": "LoRA"
|
||||
},
|
||||
"loraWeight": {
|
||||
"heading": "重み"
|
||||
},
|
||||
"patchmatchDownScaleSize": {
|
||||
"heading": "Downscale"
|
||||
},
|
||||
"controlNetWeight": {
|
||||
"heading": "重み"
|
||||
}
|
||||
},
|
||||
"accordions": {
|
||||
@@ -865,7 +887,8 @@
|
||||
"queue": "キュー",
|
||||
"canvas": "キャンバス",
|
||||
"workflows": "ワークフロー",
|
||||
"models": "モデル"
|
||||
"models": "モデル",
|
||||
"gallery": "ギャラリー"
|
||||
}
|
||||
},
|
||||
"controlLayers": {
|
||||
@@ -880,7 +903,8 @@
|
||||
"bboxGroup": "バウンディングボックスから作成",
|
||||
"cropCanvasToBbox": "キャンバスをバウンディングボックスでクロップ",
|
||||
"newGlobalReferenceImage": "新規全域参照画像",
|
||||
"newRegionalReferenceImage": "新規領域参照画像"
|
||||
"newRegionalReferenceImage": "新規領域参照画像",
|
||||
"canvasGroup": "キャンバス"
|
||||
},
|
||||
"regionalGuidance": "領域ガイダンス",
|
||||
"globalReferenceImage": "全域参照画像",
|
||||
@@ -901,7 +925,8 @@
|
||||
"brush": "ブラシ",
|
||||
"rectangle": "矩形",
|
||||
"move": "移動",
|
||||
"eraser": "消しゴム"
|
||||
"eraser": "消しゴム",
|
||||
"bbox": "Bbox"
|
||||
},
|
||||
"saveCanvasToGallery": "キャンバスをギャラリーに保存",
|
||||
"saveBboxToGallery": "バウンディングボックスをギャラリーへ保存",
|
||||
@@ -919,7 +944,27 @@
|
||||
"canvas": "キャンバス",
|
||||
"fitBboxToLayers": "バウンディングボックスをレイヤーにフィット",
|
||||
"removeBookmark": "ブックマークを外す",
|
||||
"savedToGalleryOk": "ギャラリーに保存しました"
|
||||
"savedToGalleryOk": "ギャラリーに保存しました",
|
||||
"controlMode": {
|
||||
"prompt": "プロンプト"
|
||||
},
|
||||
"prompt": "プロンプト",
|
||||
"settings": {
|
||||
"snapToGrid": {
|
||||
"off": "オフ",
|
||||
"on": "オン"
|
||||
}
|
||||
},
|
||||
"filter": {
|
||||
"filter": "フィルター",
|
||||
"spandrel_filter": {
|
||||
"model": "モデル"
|
||||
},
|
||||
"apply": "適用",
|
||||
"reset": "リセット",
|
||||
"cancel": "キャンセル"
|
||||
},
|
||||
"weight": "重み"
|
||||
},
|
||||
"stylePresets": {
|
||||
"clearTemplateSelection": "選択したテンプレートをクリア",
|
||||
@@ -934,7 +979,10 @@
|
||||
"toggleViewMode": "表示モードを切り替え",
|
||||
"negativePromptColumn": "'negative_prompt'",
|
||||
"preview": "プレビュー",
|
||||
"nameColumn": "'name'"
|
||||
"nameColumn": "'name'",
|
||||
"type": "タイプ",
|
||||
"private": "プライベート",
|
||||
"name": "名称"
|
||||
},
|
||||
"upscaling": {
|
||||
"upscaleModel": "アップスケールモデル",
|
||||
@@ -946,7 +994,8 @@
|
||||
"denoisingStrength": "ノイズ除去強度",
|
||||
"scheduler": "スケジューラー",
|
||||
"loading": "ロード中...",
|
||||
"steps": "ステップ"
|
||||
"steps": "ステップ",
|
||||
"refiner": "Refiner"
|
||||
},
|
||||
"modelCache": {
|
||||
"clear": "モデルキャッシュを消去",
|
||||
@@ -958,5 +1007,23 @@
|
||||
"ascending": "昇順",
|
||||
"name": "名前",
|
||||
"descending": "降順"
|
||||
},
|
||||
"system": {
|
||||
"logNamespaces": {
|
||||
"system": "システム",
|
||||
"gallery": "ギャラリー",
|
||||
"workflows": "ワークフロー",
|
||||
"models": "モデル",
|
||||
"canvas": "キャンバス",
|
||||
"metadata": "メタデータ",
|
||||
"queue": "キュー"
|
||||
},
|
||||
"logLevel": {
|
||||
"debug": "Debug",
|
||||
"info": "Info",
|
||||
"error": "Error",
|
||||
"fatal": "Fatal",
|
||||
"warn": "Warn"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,7 +118,8 @@
|
||||
"compareHelp2": "Nhấn <Kbd>M</Kbd> để tuần hoàn trong chế độ so sánh.",
|
||||
"boardsSettings": "Thiết Lập Bảng",
|
||||
"imagesSettings": "Cài Đặt Ảnh Trong Thư Viện Ảnh",
|
||||
"assets": "Tài Nguyên"
|
||||
"assets": "Tài Nguyên",
|
||||
"images": "Hình Ảnh"
|
||||
},
|
||||
"common": {
|
||||
"ipAdapter": "IP Adapter",
|
||||
@@ -233,7 +234,8 @@
|
||||
"combinatorial": "Tổ Hợp",
|
||||
"column": "Cột",
|
||||
"layout": "Bố Cục",
|
||||
"row": "Hàng"
|
||||
"row": "Hàng",
|
||||
"board": "Bảng"
|
||||
},
|
||||
"prompt": {
|
||||
"addPromptTrigger": "Thêm Prompt Trigger",
|
||||
@@ -873,7 +875,7 @@
|
||||
"removeLinearView": "Xoá Khỏi Chế Độ Xem Tuyến Tính",
|
||||
"unknownErrorValidatingWorkflow": "Lỗi không rõ khi xác thực workflow",
|
||||
"unableToLoadWorkflow": "Không Thể Tải Workflow",
|
||||
"workflowSettings": "Cài Đặt Trình Biên Tập Viên Workflow",
|
||||
"workflowSettings": "Cài Đặt Biên Tập Workflow",
|
||||
"workflowVersion": "Phiên Bản",
|
||||
"unableToGetWorkflowVersion": "Không thể tìm phiên bản của lược đồ workflow",
|
||||
"collection": "Đa tài nguyên",
|
||||
@@ -1010,7 +1012,11 @@
|
||||
"loadWorkflowDesc2": "Workflow hiện tại của bạn có những điều chỉnh chưa được lưu.",
|
||||
"loadingTemplates": "Đang Tải {{name}}",
|
||||
"nodeName": "Tên Node",
|
||||
"unableToUpdateNode": "Cập nhật node thất bại: node {{node}} thuộc dạng {{type}} (có thể cần xóa và tạo lại)"
|
||||
"unableToUpdateNode": "Cập nhật node thất bại: node {{node}} thuộc dạng {{type}} (có thể cần xóa và tạo lại)",
|
||||
"downloadWorkflowError": "Lỗi tải xuống workflow",
|
||||
"generatorImagesFromBoard": "Ảnh Từ Bảng",
|
||||
"generatorImagesCategory": "Phân Loại",
|
||||
"generatorImages_other": "{{count}} ảnh"
|
||||
},
|
||||
"popovers": {
|
||||
"paramCFGRescaleMultiplier": {
|
||||
@@ -1494,7 +1500,9 @@
|
||||
"batchNodeCollectionSizeMismatch": "Kích cỡ tài nguyên không phù hợp với Lô {{batchGroupId}}",
|
||||
"emptyBatches": "lô trống",
|
||||
"batchNodeNotConnected": "Node Hàng Loạt chưa được kết nối: {{label}}",
|
||||
"batchNodeEmptyCollection": "Một vài node hàng loạt có tài nguyên rỗng"
|
||||
"batchNodeEmptyCollection": "Một vài node hàng loạt có tài nguyên rỗng",
|
||||
"collectionEmpty": "tài nguyên trống",
|
||||
"batchNodeCollectionSizeMismatchNoGroupId": "tài nguyên theo nhóm có kích thước sai lệch"
|
||||
},
|
||||
"cfgScale": "Thang CFG",
|
||||
"useSeed": "Dùng Hạt Giống",
|
||||
@@ -2247,11 +2255,11 @@
|
||||
"workflowLibrary": "Thư Viện",
|
||||
"opened": "Ngày Mở",
|
||||
"deleteWorkflow": "Xoá Workflow",
|
||||
"workflowEditorMenu": "Menu Biên Tập Viên Workflow",
|
||||
"workflowEditorMenu": "Menu Biên Tập Workflow",
|
||||
"uploadAndSaveWorkflow": "Tải Lên Thư Viện",
|
||||
"openLibrary": "Mở Thư Viện",
|
||||
"builder": {
|
||||
"resetAllNodeFields": "Khởi Động Lại Tất Cả Vùng Cho Node",
|
||||
"resetAllNodeFields": "Tải Lại Các Vùng Cho Node",
|
||||
"builder": "Trình Tạo Vùng Nhập",
|
||||
"layout": "Bố Cục",
|
||||
"row": "Hàng",
|
||||
@@ -2271,10 +2279,10 @@
|
||||
"headingPlaceholder": "Đầu Dòng Trống",
|
||||
"textPlaceholder": "Mô Tả Trống",
|
||||
"column": "Cột",
|
||||
"deleteAllElements": "Xóa Tất Cả Thành Phần Vùng Nhập",
|
||||
"deleteAllElements": "Xóa Tất Cả Thành Phần",
|
||||
"nodeField": "Vùng Cho Node",
|
||||
"nodeFieldTooltip": "Để thêm vùng cho node, bấm vào dấu cộng nhỏ trên vùng trong Vùng Biên Tập Workflow, hoặc kéo vùng theo tên của nó vào vùng nhập.",
|
||||
"workflowBuilderAlphaWarning": "Trình tạo workflow đang trong giai đoạn alpha. Nó có thể xuất hiện những thay đổi đột ngột trước khi chính thức được phát hành."
|
||||
"nodeFieldTooltip": "Để thêm vùng cho node, bấm vào dấu cộng nhỏ trên vùng trong Trình Biên Tập Workflow, hoặc kéo vùng theo tên của nó vào vùng nhập.",
|
||||
"workflowBuilderAlphaWarning": "Trình tạo vùng nhập đang trong giai đoạn alpha. Nó có thể xuất hiện những thay đổi đột ngột trước khi chính thức được phát hành."
|
||||
}
|
||||
},
|
||||
"upscaling": {
|
||||
@@ -2310,12 +2318,8 @@
|
||||
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
|
||||
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
|
||||
"items": [
|
||||
"Cải thiện các thiết lập mặc định của VRAM",
|
||||
"Xoá bộ nhớ đệm của model theo yêu cầu",
|
||||
"Mở rộng khả năng tương thích LoRA trên FLUX",
|
||||
"Bộ lọc điều chỉnh ảnh trên Canvas",
|
||||
"Huỷ tất cả trừ mục đang xếp hàng hiện tại",
|
||||
"Sao chép và dán trên Canvas"
|
||||
"Trình Biên Tập Workflow: trình tạo vùng nhập dưới dạng kéo thả nhằm tạo dựng workflow dễ dàng hơn.",
|
||||
"Các nâng cấp khác: Xếp hàng tạo sinh theo nhóm nhanh hơn, upscale tốt hơn, trình chọn màu được cải thiện, và node chứa metadata."
|
||||
]
|
||||
},
|
||||
"upsell": {
|
||||
|
||||
@@ -3,6 +3,7 @@ import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
|
||||
@@ -13,7 +14,6 @@ import { toast } from 'features/toast/toast';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import { assert, AssertionError } from 'tsafe';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
@@ -80,16 +80,15 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(prepareBatchResult.value, enqueueMutationFixedCacheKeyOptions)
|
||||
);
|
||||
req.reset();
|
||||
|
||||
const enqueueResult = await withResultAsync(() => req.unwrap());
|
||||
|
||||
if (enqueueResult.isErr()) {
|
||||
log.error({ error: serializeError(enqueueResult.error) }, 'Failed to enqueue batch');
|
||||
return;
|
||||
try {
|
||||
await req.unwrap();
|
||||
log.debug(parseify({ batchConfig: prepareBatchResult.value }), 'Enqueued batch');
|
||||
} catch (error) {
|
||||
log.error({ error: serializeError(error) }, 'Failed to enqueue batch');
|
||||
} finally {
|
||||
req.reset();
|
||||
}
|
||||
|
||||
log.debug({ batchConfig: prepareBatchResult.value } as JsonObject, 'Enqueued batch');
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
@@ -7,9 +9,12 @@ import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
|
||||
import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue';
|
||||
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
|
||||
import { groupBy } from 'lodash-es';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import type { Batch, BatchConfig } from 'services/api/types';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||
@@ -101,6 +106,9 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
const req = dispatch(queueApi.endpoints.enqueueBatch.initiate(batchConfig, enqueueMutationFixedCacheKeyOptions));
|
||||
try {
|
||||
await req.unwrap();
|
||||
log.debug(parseify({ batchConfig }), 'Enqueued batch');
|
||||
} catch (error) {
|
||||
log.error({ error: serializeError(error) }, 'Failed to enqueue batch');
|
||||
} finally {
|
||||
req.reset();
|
||||
}
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||
@@ -19,6 +24,9 @@ export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening)
|
||||
const req = dispatch(queueApi.endpoints.enqueueBatch.initiate(batchConfig, enqueueMutationFixedCacheKeyOptions));
|
||||
try {
|
||||
await req.unwrap();
|
||||
log.debug(parseify({ batchConfig }), 'Enqueued batch');
|
||||
} catch (error) {
|
||||
log.error({ error: serializeError(error) }, 'Failed to enqueue batch');
|
||||
} finally {
|
||||
req.reset();
|
||||
}
|
||||
|
||||
17
invokeai/frontend/web/src/common/components/linkify.ts
Normal file
17
invokeai/frontend/web/src/common/components/linkify.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import type { Opts as LinkifyOpts } from 'linkifyjs';
|
||||
|
||||
export const linkifySx: SystemStyleObject = {
|
||||
a: {
|
||||
fontWeight: 'semibold',
|
||||
},
|
||||
'a:hover': {
|
||||
textDecoration: 'underline',
|
||||
},
|
||||
};
|
||||
|
||||
export const linkifyOptions: LinkifyOpts = {
|
||||
target: '_blank',
|
||||
rel: 'noopener noreferrer',
|
||||
validate: (value) => /^https?:\/\//.test(value),
|
||||
};
|
||||
@@ -128,7 +128,11 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
|
||||
getInputProps: getUploadInputProps,
|
||||
open: openUploader,
|
||||
} = useDropzone({
|
||||
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
|
||||
accept: {
|
||||
'image/png': ['.png'],
|
||||
'image/jpeg': ['.jpg', '.jpeg', '.png'],
|
||||
'image/webp': ['.webp'],
|
||||
},
|
||||
onDropAccepted,
|
||||
onDropRejected,
|
||||
disabled: isDisabled,
|
||||
|
||||
@@ -22,8 +22,8 @@ import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||
import type { UploadImageArg } from 'services/api/types';
|
||||
import { z } from 'zod';
|
||||
|
||||
const ACCEPTED_IMAGE_TYPES = ['image/png', 'image/jpg', 'image/jpeg'];
|
||||
const ACCEPTED_FILE_EXTENSIONS = ['.png', '.jpg', '.jpeg'];
|
||||
const ACCEPTED_IMAGE_TYPES = ['image/png', 'image/jpg', 'image/jpeg', 'image/webp'];
|
||||
const ACCEPTED_FILE_EXTENSIONS = ['.png', '.jpg', '.jpeg', '.webp'];
|
||||
|
||||
// const MAX_IMAGE_SIZE = 4; //In MegaBytes
|
||||
// const sizeInMB = (sizeInBytes: number, decimalsNum = 2) => {
|
||||
|
||||
@@ -72,7 +72,11 @@ const ModelImageUpload = ({ model_key, model_image }: Props) => {
|
||||
}, [model_key, t, deleteModelImage]);
|
||||
|
||||
const { getInputProps, getRootProps } = useDropzone({
|
||||
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
|
||||
accept: {
|
||||
'image/png': ['.png'],
|
||||
'image/jpeg': ['.jpg', '.jpeg', '.png'],
|
||||
'image/webp': ['.webp'],
|
||||
},
|
||||
onDropAccepted,
|
||||
noDrag: true,
|
||||
multiple: false,
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
import { Flex, Spacer } from '@invoke-ai/ui-library';
|
||||
import { Flex, IconButton, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import AddNodeButton from 'features/nodes/components/flow/panels/TopPanel/AddNodeButton';
|
||||
import ClearFlowButton from 'features/nodes/components/flow/panels/TopPanel/ClearFlowButton';
|
||||
import SaveWorkflowButton from 'features/nodes/components/flow/panels/TopPanel/SaveWorkflowButton';
|
||||
import UpdateNodesButton from 'features/nodes/components/flow/panels/TopPanel/UpdateNodesButton';
|
||||
import { useWorkflowEditorSettingsModal } from 'features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings';
|
||||
import { WorkflowName } from 'features/nodes/components/sidePanel/WorkflowName';
|
||||
import { selectWorkflowName } from 'features/nodes/store/workflowSlice';
|
||||
import WorkflowLibraryMenu from 'features/workflowLibrary/components/WorkflowLibraryMenu/WorkflowLibraryMenu';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiGearSixFill } from 'react-icons/pi';
|
||||
|
||||
const TopCenterPanel = () => {
|
||||
const name = useAppSelector(selectWorkflowName);
|
||||
const modal = useWorkflowEditorSettingsModal();
|
||||
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Flex gap={2} top={2} left={2} right={2} position="absolute" alignItems="flex-start" pointerEvents="none">
|
||||
<Flex gap="2">
|
||||
@@ -22,7 +27,12 @@ const TopCenterPanel = () => {
|
||||
<Spacer />
|
||||
<ClearFlowButton />
|
||||
<SaveWorkflowButton />
|
||||
<WorkflowLibraryMenu />
|
||||
<IconButton
|
||||
pointerEvents="auto"
|
||||
aria-label={t('workflows.workflowEditorMenu')}
|
||||
icon={<PiGearSixFill />}
|
||||
onClick={modal.setTrue}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { linkifyOptions, linkifySx } from 'common/components/linkify';
|
||||
import { selectWorkflowDescription } from 'features/nodes/store/workflowSlice';
|
||||
import Linkify from 'linkify-react';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const ActiveWorkflowDescription = memo(() => {
|
||||
@@ -11,8 +13,8 @@ export const ActiveWorkflowDescription = memo(() => {
|
||||
}
|
||||
|
||||
return (
|
||||
<Text color="base.300" fontStyle="italic" noOfLines={1} pb={2}>
|
||||
{description}
|
||||
<Text color="base.300" fontStyle="italic" pb={2} sx={linkifySx}>
|
||||
<Linkify options={linkifyOptions}>{description}</Linkify>
|
||||
</Text>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Flex, Spacer } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { NewWorkflowButton } from 'features/nodes/components/sidePanel/NewWorkflowButton';
|
||||
import { WorkflowListMenuTrigger } from 'features/nodes/components/sidePanel/WorkflowListMenu/WorkflowListMenuTrigger';
|
||||
import { WorkflowViewEditToggleButton } from 'features/nodes/components/sidePanel/WorkflowViewEditToggleButton';
|
||||
import { selectWorkflowMode } from 'features/nodes/store/workflowSlice';
|
||||
import { WorkflowLibraryMenu } from 'features/workflowLibrary/components/WorkflowLibraryMenu/WorkflowLibraryMenu';
|
||||
import { memo } from 'react';
|
||||
|
||||
import SaveWorkflowButton from './SaveWorkflowButton';
|
||||
@@ -17,7 +17,7 @@ export const ActiveWorkflowNameAndActions = memo(() => {
|
||||
<Spacer />
|
||||
{mode === 'edit' && <SaveWorkflowButton />}
|
||||
<WorkflowViewEditToggleButton />
|
||||
<NewWorkflowButton />
|
||||
<WorkflowLibraryMenu />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -6,7 +6,7 @@ import dateFormat, { masks } from 'dateformat';
|
||||
import { selectWorkflowId } from 'features/nodes/store/workflowSlice';
|
||||
import { useDeleteWorkflow } from 'features/workflowLibrary/components/DeleteLibraryWorkflowConfirmationAlertDialog';
|
||||
import { useLoadWorkflow } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
|
||||
import { useDownloadWorkflow } from 'features/workflowLibrary/hooks/useDownloadWorkflow';
|
||||
import { useDownloadWorkflowById } from 'features/workflowLibrary/hooks/useDownloadWorkflowById';
|
||||
import type { MouseEvent } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -30,7 +30,7 @@ export const WorkflowListItem = ({ workflow }: { workflow: WorkflowRecordListIte
|
||||
}, []);
|
||||
|
||||
const workflowId = useAppSelector(selectWorkflowId);
|
||||
const downloadWorkflow = useDownloadWorkflow();
|
||||
const { downloadWorkflow, isLoading: isLoadingDownloadWorkflow } = useDownloadWorkflowById();
|
||||
const shareWorkflow = useShareWorkflow();
|
||||
const deleteWorkflow = useDeleteWorkflow();
|
||||
const loadWorkflow = useLoadWorkflow();
|
||||
@@ -71,9 +71,9 @@ export const WorkflowListItem = ({ workflow }: { workflow: WorkflowRecordListIte
|
||||
(e: MouseEvent<HTMLButtonElement>) => {
|
||||
e.stopPropagation();
|
||||
setIsHovered(false);
|
||||
downloadWorkflow();
|
||||
downloadWorkflow(workflow.workflow_id);
|
||||
},
|
||||
[downloadWorkflow]
|
||||
[downloadWorkflow, workflow.workflow_id]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -144,6 +144,7 @@ export const WorkflowListItem = ({ workflow }: { workflow: WorkflowRecordListIte
|
||||
aria-label={t('workflows.download')}
|
||||
onClick={handleClickDownload}
|
||||
icon={<PiDownloadSimpleBold />}
|
||||
isLoading={isLoadingDownloadWorkflow}
|
||||
/>
|
||||
</Tooltip>
|
||||
{!!projectUrl && workflow.workflow_id && workflow.category !== 'user' && (
|
||||
|
||||
@@ -15,6 +15,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { useWorkflowListMenu } from 'features/nodes/store/workflowListMenu';
|
||||
import { selectWorkflowName } from 'features/nodes/store/workflowSlice';
|
||||
import { NewWorkflowButton } from 'features/workflowLibrary/components/NewWorkflowButton';
|
||||
import UploadWorkflowButton from 'features/workflowLibrary/components/UploadWorkflowButton';
|
||||
import { useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -63,6 +64,7 @@ export const WorkflowListMenuTrigger = () => {
|
||||
<WorkflowSearch searchInputRef={searchInputRef} />
|
||||
<WorkflowSortControl />
|
||||
<UploadWorkflowButton />
|
||||
<NewWorkflowButton />
|
||||
</Flex>
|
||||
<Box position="relative" w="full" h="full">
|
||||
<ScrollableContent>
|
||||
|
||||
@@ -50,7 +50,7 @@ const ContainerElement = memo(({ id }: { id: string }) => {
|
||||
ContainerElement.displayName = 'ContainerElementComponent';
|
||||
|
||||
const containerViewModeSx: SystemStyleObject = {
|
||||
gap: 4,
|
||||
gap: 2,
|
||||
'&[data-self-layout="column"]': {
|
||||
flexDir: 'column',
|
||||
alignItems: 'stretch',
|
||||
@@ -197,7 +197,7 @@ const rootViewModeSx: SystemStyleObject = {
|
||||
borderRadius: 'base',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
gap: 4,
|
||||
gap: 2,
|
||||
display: 'flex',
|
||||
flex: 1,
|
||||
maxW: '768px',
|
||||
@@ -232,6 +232,7 @@ RootContainerElementViewMode.displayName = 'RootContainerElementViewMode';
|
||||
|
||||
const rootEditModeSx: SystemStyleObject = {
|
||||
...rootViewModeSx,
|
||||
gap: 4,
|
||||
'&[data-is-dragging-over="true"]': {
|
||||
opacity: 1,
|
||||
bg: 'base.850',
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import type { HeadingProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import { linkifyOptions, linkifySx } from 'common/components/linkify';
|
||||
import Linkify from 'linkify-react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -9,13 +11,14 @@ const headingSx: SystemStyleObject = {
|
||||
'&[data-is-empty="true"]': {
|
||||
opacity: 0.3,
|
||||
},
|
||||
...linkifySx,
|
||||
};
|
||||
|
||||
export const HeadingElementContent = memo(({ content, ...rest }: { content: string } & HeadingProps) => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Text sx={headingSx} data-is-empty={content === ''} {...rest}>
|
||||
{content || t('workflows.builder.headingPlaceholder')}
|
||||
<Linkify options={linkifyOptions}>{content || t('workflows.builder.headingPlaceholder')}</Linkify>
|
||||
</Text>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import { FormHelperText, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { linkifyOptions, linkifySx } from 'common/components/linkify';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { useInputFieldDescription } from 'features/nodes/hooks/useInputFieldDescription';
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { fieldDescriptionChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import Linkify from 'linkify-react';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
|
||||
export const NodeFieldElementDescriptionEditable = memo(({ el }: { el: NodeFieldElement }) => {
|
||||
@@ -36,7 +38,11 @@ export const NodeFieldElementDescriptionEditable = memo(({ el }: { el: NodeField
|
||||
});
|
||||
|
||||
if (!editable.isEditing) {
|
||||
return <FormHelperText onDoubleClick={editable.startEditing}>{editable.value}</FormHelperText>;
|
||||
return (
|
||||
<FormHelperText onDoubleClick={editable.startEditing} sx={linkifySx}>
|
||||
<Linkify options={linkifyOptions}>{editable.value}</Linkify>
|
||||
</FormHelperText>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, FormControl, FormHelperText } from '@invoke-ai/ui-library';
|
||||
import { linkifyOptions, linkifySx } from 'common/components/linkify';
|
||||
import { InputFieldRenderer } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
|
||||
import { useContainerContext } from 'features/nodes/components/sidePanel/builder/contexts';
|
||||
import { NodeFieldElementLabel } from 'features/nodes/components/sidePanel/builder/NodeFieldElementLabel';
|
||||
@@ -7,6 +8,7 @@ import { useInputFieldDescription } from 'features/nodes/hooks/useInputFieldDesc
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import type { NodeFieldElement } from 'features/nodes/types/workflow';
|
||||
import { NODE_FIELD_CLASS_NAME } from 'features/nodes/types/workflow';
|
||||
import Linkify from 'linkify-react';
|
||||
import { memo, useMemo } from 'react';
|
||||
|
||||
const sx: SystemStyleObject = {
|
||||
@@ -18,6 +20,9 @@ const sx: SystemStyleObject = {
|
||||
flex: '1 1 0',
|
||||
minW: 32,
|
||||
},
|
||||
'&[data-with-description="false"]': {
|
||||
pb: 2,
|
||||
},
|
||||
};
|
||||
|
||||
export const NodeFieldElementViewMode = memo(({ el }: { el: NodeFieldElement }) => {
|
||||
@@ -33,7 +38,13 @@ export const NodeFieldElementViewMode = memo(({ el }: { el: NodeFieldElement })
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex id={id} className={NODE_FIELD_CLASS_NAME} sx={sx} data-parent-layout={containerCtx.layout}>
|
||||
<Flex
|
||||
id={id}
|
||||
className={NODE_FIELD_CLASS_NAME}
|
||||
sx={sx}
|
||||
data-parent-layout={containerCtx.layout}
|
||||
data-with-description={showDescription && !!_description}
|
||||
>
|
||||
<FormControl flex="1 1 0" orientation="vertical">
|
||||
<NodeFieldElementLabel el={el} />
|
||||
<Flex w="full" gap={4}>
|
||||
@@ -43,7 +54,11 @@ export const NodeFieldElementViewMode = memo(({ el }: { el: NodeFieldElement })
|
||||
settings={data.settings}
|
||||
/>
|
||||
</Flex>
|
||||
{showDescription && _description && <FormHelperText>{_description}</FormHelperText>}
|
||||
{showDescription && _description && (
|
||||
<FormHelperText sx={linkifySx}>
|
||||
<Linkify options={linkifyOptions}>{_description}</Linkify>
|
||||
</FormHelperText>
|
||||
)}
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import type { SystemStyleObject, TextProps } from '@invoke-ai/ui-library';
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import { linkifyOptions, linkifySx } from 'common/components/linkify';
|
||||
import Linkify from 'linkify-react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
@@ -9,13 +11,14 @@ const textSx: SystemStyleObject = {
|
||||
'&[data-is-empty="true"]': {
|
||||
opacity: 0.3,
|
||||
},
|
||||
...linkifySx,
|
||||
};
|
||||
|
||||
export const TextElementContent = memo(({ content, ...rest }: { content: string } & TextProps) => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Text sx={textSx} data-is-empty={content === ''} {...rest}>
|
||||
{content || t('workflows.builder.textPlaceholder')}
|
||||
<Linkify options={linkifyOptions}>{content || t('workflows.builder.textPlaceholder')}</Linkify>
|
||||
</Text>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -14,8 +14,7 @@ import { $hasTemplates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectIsFormEmpty } from 'features/nodes/store/workflowSlice';
|
||||
import type { FormElement } from 'features/nodes/types/workflow';
|
||||
import { buildContainer, buildDivider, buildHeading, buildText } from 'features/nodes/types/workflow';
|
||||
import { startCase } from 'lodash-es';
|
||||
import type { RefObject } from 'react';
|
||||
import type { PropsWithChildren, RefObject } from 'react';
|
||||
import { memo, useEffect, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
|
||||
@@ -39,10 +38,10 @@ export const WorkflowBuilder = memo(() => {
|
||||
<Flex justifyContent="center" w="full" h="full">
|
||||
<Flex flexDir="column" w="full" maxW="768px" gap={2}>
|
||||
<Flex w="full" alignItems="center" gap={2} pt={3}>
|
||||
<AddFormElementDndButton type="container" />
|
||||
<AddFormElementDndButton type="divider" />
|
||||
<AddFormElementDndButton type="heading" />
|
||||
<AddFormElementDndButton type="text" />
|
||||
<AddFormElementDndButton type="container">{t('workflows.builder.container')}</AddFormElementDndButton>
|
||||
<AddFormElementDndButton type="divider">{t('workflows.builder.divider')}</AddFormElementDndButton>
|
||||
<AddFormElementDndButton type="heading">{t('workflows.builder.heading')}</AddFormElementDndButton>
|
||||
<AddFormElementDndButton type="text">{t('workflows.builder.text')}</AddFormElementDndButton>
|
||||
<Button size="sm" variant="ghost" tooltip={t('workflows.builder.nodeFieldTooltip')}>
|
||||
{t('workflows.builder.nodeField')}
|
||||
</Button>
|
||||
@@ -130,14 +129,17 @@ const addFormElementButtonSx: SystemStyleObject = {
|
||||
_disabled: { borderStyle: 'dashed', opacity: 0.5 },
|
||||
};
|
||||
|
||||
const AddFormElementDndButton = ({ type }: { type: Parameters<typeof useAddFormElementDnd>[0] }) => {
|
||||
const AddFormElementDndButton = ({
|
||||
type,
|
||||
children,
|
||||
}: PropsWithChildren<{ type: Parameters<typeof useAddFormElementDnd>[0] }>) => {
|
||||
const draggableRef = useRef<HTMLDivElement>(null);
|
||||
const isDragging = useAddFormElementDnd(type, draggableRef);
|
||||
|
||||
return (
|
||||
// Must be as div for draggable to work correctly
|
||||
<Button as="div" ref={draggableRef} size="sm" isDisabled={isDragging} variant="outline" sx={addFormElementButtonSx}>
|
||||
{startCase(type)}
|
||||
{children}
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -331,14 +331,25 @@ const buildInstanceTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
|
||||
return (val: unknown): val is z.infer<T> => schema.safeParse(val).success;
|
||||
};
|
||||
|
||||
/**
|
||||
* Builds a type guard for a specific field input template type.
|
||||
*
|
||||
* The output type guards are primarily used for determining which input component to render for fields in the
|
||||
* <InputFieldRenderer/> component.
|
||||
*
|
||||
* @param name The name of the field type.
|
||||
* @param cardinalities The allowed cardinalities for the field type. If omitted, all cardinalities are allowed.
|
||||
*
|
||||
* @returns A type guard for the specified field type.
|
||||
*/
|
||||
const buildTemplateTypeGuard =
|
||||
<T extends FieldInputTemplate>(name: string, cardinality?: 'SINGLE' | 'COLLECTION' | 'SINGLE_OR_COLLECTION') =>
|
||||
<T extends FieldInputTemplate>(name: string, cardinalities?: FieldType['cardinality'][]) =>
|
||||
(template: FieldInputTemplate): template is T => {
|
||||
if (template.type.name !== name) {
|
||||
return false;
|
||||
}
|
||||
if (cardinality) {
|
||||
return template.type.cardinality === cardinality;
|
||||
if (cardinalities) {
|
||||
return cardinalities.includes(template.type.cardinality);
|
||||
}
|
||||
return true;
|
||||
};
|
||||
@@ -366,7 +377,10 @@ export type IntegerFieldValue = z.infer<typeof zIntegerFieldValue>;
|
||||
export type IntegerFieldInputInstance = z.infer<typeof zIntegerFieldInputInstance>;
|
||||
export type IntegerFieldInputTemplate = z.infer<typeof zIntegerFieldInputTemplate>;
|
||||
export const isIntegerFieldInputInstance = buildInstanceTypeGuard(zIntegerFieldInputInstance);
|
||||
export const isIntegerFieldInputTemplate = buildTemplateTypeGuard<IntegerFieldInputTemplate>('IntegerField', 'SINGLE');
|
||||
export const isIntegerFieldInputTemplate = buildTemplateTypeGuard<IntegerFieldInputTemplate>('IntegerField', [
|
||||
'SINGLE',
|
||||
'SINGLE_OR_COLLECTION',
|
||||
]);
|
||||
// #endregion
|
||||
|
||||
// #region IntegerField Collection
|
||||
@@ -406,7 +420,7 @@ export type IntegerFieldCollectionInputTemplate = z.infer<typeof zIntegerFieldCo
|
||||
export const isIntegerFieldCollectionInputInstance = buildInstanceTypeGuard(zIntegerFieldCollectionInputInstance);
|
||||
export const isIntegerFieldCollectionInputTemplate = buildTemplateTypeGuard<IntegerFieldCollectionInputTemplate>(
|
||||
'IntegerField',
|
||||
'COLLECTION'
|
||||
['COLLECTION']
|
||||
);
|
||||
// #endregion
|
||||
|
||||
@@ -432,7 +446,10 @@ export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
|
||||
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
|
||||
export type FloatFieldInputTemplate = z.infer<typeof zFloatFieldInputTemplate>;
|
||||
export const isFloatFieldInputInstance = buildInstanceTypeGuard(zFloatFieldInputInstance);
|
||||
export const isFloatFieldInputTemplate = buildTemplateTypeGuard<FloatFieldInputTemplate>('FloatField', 'SINGLE');
|
||||
export const isFloatFieldInputTemplate = buildTemplateTypeGuard<FloatFieldInputTemplate>('FloatField', [
|
||||
'SINGLE',
|
||||
'SINGLE_OR_COLLECTION',
|
||||
]);
|
||||
// #endregion
|
||||
|
||||
// #region FloatField Collection
|
||||
@@ -471,7 +488,7 @@ export type FloatFieldCollectionInputTemplate = z.infer<typeof zFloatFieldCollec
|
||||
export const isFloatFieldCollectionInputInstance = buildInstanceTypeGuard(zFloatFieldCollectionInputInstance);
|
||||
export const isFloatFieldCollectionInputTemplate = buildTemplateTypeGuard<FloatFieldCollectionInputTemplate>(
|
||||
'FloatField',
|
||||
'COLLECTION'
|
||||
['COLLECTION']
|
||||
);
|
||||
// #endregion
|
||||
|
||||
@@ -504,7 +521,10 @@ export type StringFieldValue = z.infer<typeof zStringFieldValue>;
|
||||
export type StringFieldInputInstance = z.infer<typeof zStringFieldInputInstance>;
|
||||
export type StringFieldInputTemplate = z.infer<typeof zStringFieldInputTemplate>;
|
||||
export const isStringFieldInputInstance = buildInstanceTypeGuard(zStringFieldInputInstance);
|
||||
export const isStringFieldInputTemplate = buildTemplateTypeGuard<StringFieldInputTemplate>('StringField', 'SINGLE');
|
||||
export const isStringFieldInputTemplate = buildTemplateTypeGuard<StringFieldInputTemplate>('StringField', [
|
||||
'SINGLE',
|
||||
'SINGLE_OR_COLLECTION',
|
||||
]);
|
||||
// #endregion
|
||||
|
||||
// #region StringField Collection
|
||||
@@ -550,7 +570,7 @@ export type StringFieldCollectionInputTemplate = z.infer<typeof zStringFieldColl
|
||||
export const isStringFieldCollectionInputInstance = buildInstanceTypeGuard(zStringFieldCollectionInputInstance);
|
||||
export const isStringFieldCollectionInputTemplate = buildTemplateTypeGuard<StringFieldCollectionInputTemplate>(
|
||||
'StringField',
|
||||
'COLLECTION'
|
||||
['COLLECTION']
|
||||
);
|
||||
// #endregion
|
||||
|
||||
@@ -613,7 +633,10 @@ export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
|
||||
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
|
||||
export type ImageFieldInputTemplate = z.infer<typeof zImageFieldInputTemplate>;
|
||||
export const isImageFieldInputInstance = buildInstanceTypeGuard(zImageFieldInputInstance);
|
||||
export const isImageFieldInputTemplate = buildTemplateTypeGuard<ImageFieldInputTemplate>('ImageField', 'SINGLE');
|
||||
export const isImageFieldInputTemplate = buildTemplateTypeGuard<ImageFieldInputTemplate>('ImageField', [
|
||||
'SINGLE',
|
||||
'SINGLE_OR_COLLECTION',
|
||||
]);
|
||||
// #endregion
|
||||
|
||||
// #region ImageField Collection
|
||||
@@ -648,7 +671,7 @@ export type ImageFieldCollectionInputTemplate = z.infer<typeof zImageFieldCollec
|
||||
export const isImageFieldCollectionInputInstance = buildInstanceTypeGuard(zImageFieldCollectionInputInstance);
|
||||
export const isImageFieldCollectionInputTemplate = buildTemplateTypeGuard<ImageFieldCollectionInputTemplate>(
|
||||
'ImageField',
|
||||
'COLLECTION'
|
||||
['COLLECTION']
|
||||
);
|
||||
// #endregion
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ export const NewWorkflowButton = memo(() => {
|
||||
<IconButton
|
||||
onClick={onClickNewWorkflow}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
aria-label={t('nodes.newWorkflow')}
|
||||
tooltip={t('nodes.newWorkflow')}
|
||||
icon={<PiFilePlusBold />}
|
||||
@@ -1,12 +1,12 @@
|
||||
import { MenuItem } from '@invoke-ai/ui-library';
|
||||
import { useDownloadWorkflow } from 'features/workflowLibrary/hooks/useDownloadWorkflow';
|
||||
import { useDownloadCurrentlyLoadedWorkflow } from 'features/workflowLibrary/hooks/useDownloadCurrentlyLoadedWorkflow';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiDownloadSimpleBold } from 'react-icons/pi';
|
||||
|
||||
const DownloadWorkflowMenuItem = () => {
|
||||
const { t } = useTranslation();
|
||||
const downloadWorkflow = useDownloadWorkflow();
|
||||
const downloadWorkflow = useDownloadCurrentlyLoadedWorkflow();
|
||||
|
||||
return (
|
||||
<MenuItem as="button" icon={<PiDownloadSimpleBold />} onClick={downloadWorkflow}>
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
import { MenuItem } from '@invoke-ai/ui-library';
|
||||
import { useWorkflowEditorSettingsModal } from 'features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiGearSixFill } from 'react-icons/pi';
|
||||
|
||||
const DownloadWorkflowMenuItem = () => {
|
||||
const { t } = useTranslation();
|
||||
const modal = useWorkflowEditorSettingsModal();
|
||||
|
||||
return (
|
||||
<MenuItem as="button" icon={<PiGearSixFill />} onClick={modal.setTrue}>
|
||||
{t('nodes.workflowSettings')}
|
||||
</MenuItem>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(DownloadWorkflowMenuItem);
|
||||
@@ -13,13 +13,12 @@ import LoadWorkflowFromGraphMenuItem from 'features/workflowLibrary/components/W
|
||||
import { NewWorkflowMenuItem } from 'features/workflowLibrary/components/WorkflowLibraryMenu/NewWorkflowMenuItem';
|
||||
import SaveWorkflowAsMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/SaveWorkflowAsMenuItem';
|
||||
import SaveWorkflowMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/SaveWorkflowMenuItem';
|
||||
import SettingsMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/SettingsMenuItem';
|
||||
import UploadWorkflowMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/UploadWorkflowMenuItem';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiDotsThreeOutlineFill } from 'react-icons/pi';
|
||||
|
||||
const WorkflowLibraryMenu = () => {
|
||||
export const WorkflowLibraryMenu = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
const shift = useShiftModifier();
|
||||
@@ -31,6 +30,8 @@ const WorkflowLibraryMenu = () => {
|
||||
aria-label={t('workflows.workflowEditorMenu')}
|
||||
icon={<PiDotsThreeOutlineFill />}
|
||||
pointerEvents="auto"
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
/>
|
||||
<MenuList pointerEvents="auto">
|
||||
<NewWorkflowMenuItem />
|
||||
@@ -39,13 +40,10 @@ const WorkflowLibraryMenu = () => {
|
||||
<SaveWorkflowMenuItem />
|
||||
<SaveWorkflowAsMenuItem />
|
||||
<DownloadWorkflowMenuItem />
|
||||
<MenuDivider />
|
||||
<SettingsMenuItem />
|
||||
{shift && <MenuDivider />}
|
||||
{shift && <LoadWorkflowFromGraphMenuItem />}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(WorkflowLibraryMenu);
|
||||
});
|
||||
WorkflowLibraryMenu.displayName = 'WorkflowLibraryMenu';
|
||||
|
||||
@@ -3,7 +3,7 @@ import { $builtWorkflow } from 'features/nodes/hooks/useWorkflowWatcher';
|
||||
import { workflowDownloaded } from 'features/workflowLibrary/store/actions';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
export const useDownloadWorkflow = () => {
|
||||
export const useDownloadCurrentlyLoadedWorkflow = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const downloadWorkflow = useCallback(() => {
|
||||
@@ -0,0 +1,42 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { workflowDownloaded } from 'features/workflowLibrary/store/actions';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useLazyGetWorkflowQuery } from 'services/api/endpoints/workflows';
|
||||
|
||||
export const useDownloadWorkflowById = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const [trigger, query] = useLazyGetWorkflowQuery();
|
||||
|
||||
const toastError = useCallback(() => {
|
||||
toast({ status: 'error', description: t('nodes.downloadWorkflowError') });
|
||||
}, [t]);
|
||||
|
||||
const downloadWorkflow = useCallback(
|
||||
async (workflowId: string) => {
|
||||
try {
|
||||
const { data } = await trigger(workflowId);
|
||||
if (!data) {
|
||||
toastError();
|
||||
return;
|
||||
}
|
||||
const { workflow } = data;
|
||||
const blob = new Blob([JSON.stringify(workflow, null, 2)]);
|
||||
const a = document.createElement('a');
|
||||
a.href = URL.createObjectURL(blob);
|
||||
a.download = `${workflow.name || 'My Workflow'}.json`;
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
a.remove();
|
||||
dispatch(workflowDownloaded());
|
||||
} catch {
|
||||
toastError();
|
||||
}
|
||||
},
|
||||
[dispatch, toastError, trigger]
|
||||
);
|
||||
|
||||
return { downloadWorkflow, isLoading: query.isLoading };
|
||||
};
|
||||
@@ -103,6 +103,7 @@ export const api = customCreateApi({
|
||||
reducerPath: 'api',
|
||||
tagTypes,
|
||||
endpoints: () => ({}),
|
||||
invalidationBehavior: 'immediately',
|
||||
});
|
||||
|
||||
function getCircularReplacer() {
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "5.7.0"
|
||||
__version__ = "5.7.2rc2"
|
||||
|
||||
13
tests/app/util/test_torch_cuda_allocator.py
Normal file
13
tests/app/util/test_torch_cuda_allocator.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.app.util.torch_cuda_allocator import configure_torch_cuda_allocator
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA device.")
|
||||
def test_configure_torch_cuda_allocator_raises_if_torch_is_already_imported():
|
||||
"""Test that configure_torch_cuda_allocator() raises a RuntimeError if torch is already imported."""
|
||||
import torch # noqa: F401
|
||||
|
||||
with pytest.raises(RuntimeError, match="Failed to configure the PyTorch CUDA memory allocator."):
|
||||
configure_torch_cuda_allocator("backend:cudaMallocAsync")
|
||||
Reference in New Issue
Block a user