Compare commits

..

89 Commits

Author SHA1 Message Date
psychedelicious
8a3848e7b6 chore(ui): update whats new copy 2025-05-22 14:25:02 +10:00
psychedelicious
3f8486b480 chore: bump version to v5.12.0 2025-05-22 14:25:02 +10:00
Hosted Weblate
b80be4f639 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2025-05-22 14:11:52 +10:00
Linos
adb3a849b9 translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1910 of 1910 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-05-22 14:11:52 +10:00
Riccardo Giovanetti
798499fda6 translationBot(ui): update translation (Italian)
Currently translated at 98.9% (1889 of 1910 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.9% (1889 of 1910 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-05-22 14:11:52 +10:00
psychedelicious
02fc5a165c chore(ui): typegen 2025-05-22 13:50:15 +10:00
psychedelicious
b1b8edecfb fix(ui): minor ts issue 2025-05-22 13:50:15 +10:00
Mary Hipp
3cd8d48809 lint 2025-05-22 13:50:15 +10:00
Mary Hipp
f4672ad8c1 more cleanup 2025-05-22 13:50:15 +10:00
Mary Hipp
5a86490845 cleanup and refactor into hooks 2025-05-22 13:50:15 +10:00
Mary Hipp
27dc843046 Imagen4 working in UI 2025-05-22 13:50:15 +10:00
Mary Hipp
2f35d74902 backend updates 2025-05-22 13:50:15 +10:00
Kevin Turner
8bd52ed744 fix: improve gguf performance with torch.compile
pytorch 2.7 does not implement `set.__contains__`, so make this a list instead.

See https://github.com/pytorch/pytorch/issues/145761
2025-05-22 13:42:09 +10:00
psychedelicious
f3e2a3c384 gh: update CODEOWNERS
- Remove brandon
- Consolidate two entries for `invokeai/backend`
2025-05-22 13:37:24 +10:00
psychedelicious
ecc6e8a532 fix(nodes): transformers bug with SAM
Upstream bug in `transformers` breaks use of `AutoModelForMaskGeneration` class to load SAM models

Simple fix - directly load the model with `SamModel` class instead.

See upstream issue https://github.com/huggingface/transformers/issues/38228
2025-05-22 11:32:37 +10:00
Mary Hipp
9170576a38 make logic more straight forward 2025-05-21 10:52:04 -04:00
Mary Hipp
f26baa0341 use hook instead 2025-05-21 10:52:04 -04:00
psychedelicious
99dad953a4 chore: bump version to v5.12.0rc2 2025-05-20 14:50:03 +10:00
jazzhaiku
c39bcdffd3 Re-enable classification API as fallback (#8007)
## Summary

- Fallback to new classification API if legacy probe fails
- Method to read model metadata
- Created `StrippedModelOnDisk` class for testing
- Test to verify only a single config `matches` with a model

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-05-20 11:25:38 +10:00
Billy
32f2223237 Warning comment 2025-05-20 11:19:59 +10:00
Billy
6176941853 Warning comment 2025-05-20 11:19:59 +10:00
Billy
af41dc83f7 Make ruff happy 2025-05-20 11:19:59 +10:00
Billy
a17e771eba Re-enable classification API as fallback 2025-05-20 11:19:59 +10:00
psychedelicious
19ecdb196e chore: ruff 2025-05-20 10:47:02 +10:00
psychedelicious
15880e6ea7 fix(ui): invocation parsing for optional enum fields
For example:
```py
my_field: Literal["foo", "bar"] | None = InputField(default=None)
```

Previously, this would cause a field parsing error and prevent the app from loading.

Two fixes:
- This type annotation and resultant schema are now parsed correctly
- Error handling added to template building logic to prevent the hang at startup when an error does occur
2025-05-20 10:47:02 +10:00
psychedelicious
53ffa98662 chore(ui): typegen 2025-05-20 10:47:02 +10:00
psychedelicious
021a334240 fix(nodes): fix spots where default of None was provided for non-optional fields 2025-05-20 10:47:02 +10:00
psychedelicious
cfed293d48 fix(nodes): do not make invocation field defaults None when they are not provided 2025-05-20 10:47:02 +10:00
Mary Hipp
d36bc185c8 only use client side uploads if more than one image to retain metadata for single uploads 2025-05-20 08:03:00 +10:00
psychedelicious
7878203b03 chore(ui): update whats new copy 2025-05-19 23:28:40 +10:00
psychedelicious
3352220d39 chore: bump version to v5.12.0rc1 2025-05-19 23:28:40 +10:00
Riccardo Giovanetti
bcfb1e7e52 translationBot(ui): update translation (Italian)
Currently translated at 98.7% (1887 of 1910 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-05-19 23:23:07 +10:00
psychedelicious
e84b3c142c chore(ui): typegen 2025-05-19 13:50:04 +10:00
Kent Keirsey
22f637b647 ruff ruff 2025-05-19 13:50:04 +10:00
Kent Keirsey
5d192ab6e5 Fix SD precise in patcher. 2025-05-19 13:50:04 +10:00
Kent Keirsey
9273d1629e UX Copy Clean-up 2025-05-19 13:50:04 +10:00
Kent Keirsey
27a12f080b missing translation values 2025-05-19 13:50:04 +10:00
Kent Keirsey
3bfb497764 ruff fixes 2025-05-19 13:50:04 +10:00
Kent Keirsey
b849c7d382 ruff fix 2025-05-19 13:50:04 +10:00
Kent Keirsey
8d4120583d update schema pt 2 2025-05-19 13:50:04 +10:00
Kent Keirsey
402cdc7eda update schema 2025-05-19 13:50:04 +10:00
Kent Keirsey
b02ea1a898 Expanded styles & updated UI 2025-05-19 13:50:04 +10:00
Kent Keirsey
d709040f4b Matt3o base changes 2025-05-19 13:50:04 +10:00
psychedelicious
8a7a498da3 chore: update uv lock 2025-05-19 12:29:51 +10:00
psychedelicious
699736486b chore: bump torch to 2.7.0
- Update `pyproject.toml`
- Update `pins.json` so launcher installs latest CUDA 12.8 & ROCm 6.3
2025-05-19 12:29:51 +10:00
psychedelicious
37e790ae19 fix(app): address pydantic deprecation warning for accessing BaseModel.model_fields 2025-05-19 12:22:59 +10:00
David Burnett
6c0bd7d150 fix import ordering, remove code I reverted that the resync added back 2025-05-19 11:16:23 +10:00
David Burnett
99e154d773 fix picky ruff issue 2025-05-19 11:16:23 +10:00
David Burnett
e4e43ae126 fix missing bracket 2025-05-19 11:16:23 +10:00
David Burnett
a07fac6180 raise exected exception when attempting to change dtype 2025-05-19 11:16:23 +10:00
David Burnett
93d4b00082 Add to overload for GGMLTensor, so calling to on the model moves the quantized data as well 2025-05-19 11:16:23 +10:00
David Burnett
8abcc99ced add check for state_dict, required to load TI's 2025-05-19 11:16:23 +10:00
David Burnett
73ab4b8895 fix offload device 2025-05-19 11:16:23 +10:00
David Burnett
86719f2065 revert to overload due to failing tests, use Torch futures instead 2025-05-19 11:16:23 +10:00
David Burnett
5271fc1cac fix picky ruff issue 2025-05-19 11:16:23 +10:00
David Burnett
96ff7d9093 fix missing bracket 2025-05-19 11:16:23 +10:00
David Burnett
6f73d9e9c6 raise exected exception when attempting to change dtype 2025-05-19 11:16:23 +10:00
David Burnett
29b406a84b Add to overload for GGMLTensor, so calling to on the model moves the quantized data as well 2025-05-19 11:16:23 +10:00
psychedelicious
2b1e4b88d3 tests: add new service to mocks 2025-05-19 10:29:07 +10:00
psychedelicious
0f0085a776 chore(ui): typegen 2025-05-19 10:29:07 +10:00
psychedelicious
ea28ed8261 chore: ruff 2025-05-19 10:29:07 +10:00
Lucian Hardy
c0e6327d3a chore(ui): Refactor RelatedModels.tsx
Major cleanup of RelatedModels.tsx for improved readability, structure, and maintainability.
Dried out repetitive logic
Consolidated model type sorting into reusable helpers
Added disallowed model type relationships to prevent broken connections (e.g. VAE ↔ LoRA)
- Aware this introduces a new constraint—open to feedback (see PR comment)
Some naming and types may still need refinement; happy to revisit
2025-05-19 10:29:07 +10:00
Lucian Hardy
459491e402 chore(backend): Removed unused model_relationship methods
removed unused AnyModelConfig related methods,
removed unused get_related_model_key_count method.
2025-05-19 10:29:07 +10:00
Lucian Hardy
a4cddfa47d feat(ui): model relationship management
Adds full support for managing model-to-model relationships in the UI and backend.

Introduces RelatedModels subpanel for linking and unlinking models in model management.
 - Adds REST API routes for adding, removing, and retrieving model relationships.
 - New database migration: creates model_relationships table for bidirectional links.
 - New service layer (model_relationships) for relationship management.
 - Updated frontend: Related models float to top of LoRA/Main grouped model comboboxes for quick access.
     - Added 'Show Only Related' toggle badge to MainModelPicker filter bar

**Amended commit to remove changes to ParamMainModelSelect.tsx and MainModelPicker.tsx to avoid conflict with upstream deletion/ rewrite**
2025-05-19 10:29:07 +10:00
jazzhaiku
9a822bcfe8 Jazzhaiku/stats (#8006)
## Summary

- Modify stats reset to be on a per session basis, rather than a "full
reset", to allow for parallel session execution
- Add "aider" to gitignore

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-05-16 07:51:23 +10:00
psychedelicious
5f12b9185f feat(mm): add cache_snapshot to model cache clear callback 2025-05-15 16:06:47 +10:00
psychedelicious
d958d2e5a0 feat(mm): iterate on cache callbacks API 2025-05-15 14:37:22 +10:00
psychedelicious
823ca214e6 feat(mm): iterate on cache callbacks API 2025-05-15 13:28:51 +10:00
psychedelicious
a33da450fd feat(mm): support cache callbacks 2025-05-15 11:23:58 +10:00
Billy
8b5f4d190c Restore Schema 2025-05-15 10:38:01 +10:00
Billy
f1f3b7965a Schema 2025-05-15 10:26:45 +10:00
Billy
987be3507c Merge branch 'main' into jazzhaiku/stats 2025-05-15 10:22:56 +10:00
Billy
1f4090fe0e Reset invocation stats on per session basis 2025-05-15 10:19:05 +10:00
Billy
029e2d2c46 Add aider to gitignore 2025-05-15 10:18:42 +10:00
Riku
7722f479e8 translationBot(ui): update translation (German)
Currently translated at 64.9% (1236 of 1902 strings)

Co-authored-by: Riku <riku.block@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2025-05-14 10:32:24 +10:00
Linos
3ad4072183 translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1904 of 1904 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 100.0% (1902 of 1902 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-05-14 10:32:24 +10:00
Hosted Weblate
6dfb9a1906 translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2025-05-14 10:32:24 +10:00
RyoKoba
ad2924350d translationBot(ui): update translation (Japanese)
Currently translated at 67.1% (1279 of 1904 strings)

translationBot(ui): update translation (Japanese)

Currently translated at 64.9% (1231 of 1895 strings)

translationBot(ui): update translation (Japanese)

Currently translated at 60.2% (1141 of 1895 strings)

translationBot(ui): update translation (Japanese)

Currently translated at 56.7% (1075 of 1895 strings)

Co-authored-by: RyoKoba <kobayashi_ryo@cyberagent.co.jp>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/ja/
Translation: InvokeAI/Web UI
2025-05-14 10:32:24 +10:00
Linos
3bf51ee0c2 translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1896 of 1896 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 100.0% (1895 of 1895 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 100.0% (1886 of 1886 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-05-14 10:32:24 +10:00
Hosted Weblate
fce5051dcc translationBot(ui): update translation files
Updated by "Remove blank strings" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2025-05-14 10:32:24 +10:00
Riccardo Giovanetti
446d8818b9 translationBot(ui): update translation (Italian)
Currently translated at 98.8% (1883 of 1904 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.8% (1882 of 1903 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.8% (1881 of 1902 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.8% (1878 of 1899 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.8% (1874 of 1895 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.8% (1873 of 1895 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.8% (1864 of 1886 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-05-14 10:32:24 +10:00
psychedelicious
1566e29c19 feat(nodes): tidy some type annotations in baseinvocation 2025-05-14 06:55:15 +10:00
psychedelicious
6a2e35f2c4 feat(nodes): store original field annotation & FieldInfo in invocations 2025-05-14 06:55:15 +10:00
psychedelicious
b6d58774f4 feat(nodes): improved error messages for invalid defaults 2025-05-14 06:55:15 +10:00
psychedelicious
758f94d3c6 chore(ui): typegen 2025-05-14 06:55:15 +10:00
psychedelicious
9df0871754 fix(nodes): do not provide invalid defaults for batch nodes 2025-05-14 06:55:15 +10:00
psychedelicious
3011150a3a feat(nodes): validate default values for all fields
This prevents issues where the node is defined with an invalid default value, which would guarantee an error during a ser/de roundtrip.

- Upstream issue requesting this functionality be built-in to pydantic: https://github.com/pydantic/pydantic/issues/8722
- Upstream PR that implements the functionality: https://github.com/pydantic/pydantic-core/pull/1593
2025-05-14 06:55:15 +10:00
psychedelicious
05aa1fce71 chore(ui): typegen 2025-05-14 06:55:15 +10:00
psychedelicious
df81f3274a feat(nodes): improved pydantic type annotation massaging
When we do our field type overrides to allow invocations to be instantiated without all required fields, we were not modifying the annotation of the field but did set the default value of the field to `None`.

This results in an error when doing a ser/de round trip. Here's what we end up doing:

```py
from pydantic import BaseModel, Field

class MyModel(BaseModel):
    foo: str = Field(default=None)
```

And here is a simple round-trip, which should not error but which does:

```py
MyModel(**MyModel().model_dump())
# ValidationError: 1 validation error for MyModel
# foo
#   Input should be a valid string [type=string_type, input_value=None, input_type=NoneType]
#     For further information visit https://errors.pydantic.dev/2.11/v/string_type
```

To fix this, we now check every incoming field and update its annotation to match its default value. In other words, when we override the default field value to `None`, we make its type annotation `<original type> | None`.

This prevents the error during deserialization.

This slightly alters the schema for all invocations and outputs - the values of all fields without default values are now typed as `<original type> | None`, reflecting the overrides.

This means the autogenerated types for fields have also changed for fields without defaults:

```ts
// Old
image?: components["schemas"]["ImageField"];

// New
image?: components["schemas"]["ImageField"] | null;
```

This does not break anything on the frontend.
2025-05-14 06:55:15 +10:00
99 changed files with 2849 additions and 584 deletions

5
.github/CODEOWNERS vendored
View File

@@ -6,7 +6,7 @@
/mkdocs.yml @lstein @blessedcoolant @hipsterusername @psychedelicious
# nodes
/invokeai/app/ @blessedcoolant @psychedelicious @brandonrising @hipsterusername @jazzhaiku
/invokeai/app/ @blessedcoolant @psychedelicious @hipsterusername @jazzhaiku
# installation and configuration
/pyproject.toml @lstein @blessedcoolant @hipsterusername
@@ -19,10 +19,9 @@
# web ui
/invokeai/frontend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
/invokeai/backend @blessedcoolant @psychedelicious @lstein @maryhipp @hipsterusername
# generation, model management, postprocessing
/invokeai/backend @lstein @blessedcoolant @brandonrising @hipsterusername @jazzhaiku
/invokeai/backend @lstein @blessedcoolant @hipsterusername @jazzhaiku @psychedelicious @maryhipp
# front ends
/invokeai/frontend/CLI @lstein @hipsterusername

1
.gitignore vendored
View File

@@ -188,3 +188,4 @@ installer/install.sh
installer/update.bat
installer/update.sh
installer/InvokeAI-Installer/
.aider*

View File

@@ -23,6 +23,10 @@ from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_images.model_images_default import ModelImageFileStorageDisk
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService
from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
from invokeai.app.services.model_relationship_records.model_relationship_records_sqlite import (
SqliteModelRelationshipRecordStorage,
)
from invokeai.app.services.model_relationships.model_relationships_default import ModelRelationshipsService
from invokeai.app.services.names.names_default import SimpleNameService
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
@@ -136,6 +140,8 @@ class ApiDependencies:
download_queue=download_queue_service,
events=events,
)
model_relationships = ModelRelationshipsService()
model_relationship_records = SqliteModelRelationshipRecordStorage(db=db)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
@@ -161,6 +167,8 @@ class ApiDependencies:
logger=logger,
model_images=model_images_service,
model_manager=model_manager,
model_relationships=model_relationships,
model_relationship_records=model_relationship_records,
download_queue=download_queue_service,
names=names,
performance_statistics=performance_statistics,

View File

@@ -0,0 +1,215 @@
"""FastAPI route for model relationship records."""
from typing import List
from fastapi import APIRouter, Body, HTTPException, Path, status
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
model_relationships_router = APIRouter(prefix="/v1/model_relationships", tags=["model_relationships"])
# === Schemas ===
class ModelRelationshipCreateRequest(BaseModel):
model_key_1: str = Field(
...,
description="The key of the first model in the relationship",
examples=[
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
"ac32b914-10ab-496e-a24a-3068724b9c35",
"d944abfd-c7c3-42e2-a4ff-da640b29b8b4",
"b1c2d3e4-f5a6-7890-abcd-ef1234567890",
"12345678-90ab-cdef-1234-567890abcdef",
"fedcba98-7654-3210-fedc-ba9876543210",
],
)
model_key_2: str = Field(
...,
description="The key of the second model in the relationship",
examples=[
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
"f0c3da4e-d9ff-42b5-a45c-23be75c887c9",
"38170dd8-f1e5-431e-866c-2c81f1277fcc",
"c57fea2d-7646-424c-b9ad-c0ba60fc68be",
"10f7807b-ab54-46a9-ab03-600e88c630a1",
"f6c1d267-cf87-4ee0-bee0-37e791eacab7",
],
)
class ModelRelationshipBatchRequest(BaseModel):
model_keys: List[str] = Field(
...,
description="List of model keys to fetch related models for",
examples=[
[
"aa3b247f-90c9-4416-bfcd-aeaa57a5339e",
"ac32b914-10ab-496e-a24a-3068724b9c35",
],
[
"b1c2d3e4-f5a6-7890-abcd-ef1234567890",
"12345678-90ab-cdef-1234-567890abcdef",
"fedcba98-7654-3210-fedc-ba9876543210",
],
[
"3bb7c0eb-b6c8-469c-ad8c-4d69c06075e4",
],
],
)
# === Routes ===
@model_relationships_router.get(
"/i/{model_key}",
operation_id="get_related_models",
response_model=list[str],
responses={
200: {
"description": "A list of related model keys was retrieved successfully",
"content": {
"application/json": {
"example": [
"15e9eb28-8cfe-47c9-b610-37907a79fc3c",
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2",
]
}
},
},
404: {"description": "The specified model could not be found"},
422: {"description": "Validation error"},
},
)
async def get_related_models(
model_key: str = Path(..., description="The key of the model to get relationships for"),
) -> list[str]:
"""
Get a list of model keys related to a given model.
"""
try:
return ApiDependencies.invoker.services.model_relationships.get_related_model_keys(model_key)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@model_relationships_router.post(
"/",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "The relationship was successfully created"},
400: {"description": "Invalid model keys or self-referential relationship"},
409: {"description": "The relationship already exists"},
422: {"description": "Validation error"},
500: {"description": "Internal server error"},
},
summary="Add Model Relationship",
description="Creates a **bidirectional** relationship between two models, allowing each to reference the other as related.",
)
async def add_model_relationship(
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to relate"),
) -> None:
"""
Add a relationship between two models.
Relationships are bidirectional and will be accessible from both models.
- Raises 400 if keys are invalid or identical.
- Raises 409 if the relationship already exists.
"""
try:
if req.model_key_1 == req.model_key_2:
raise HTTPException(status_code=400, detail="Cannot relate a model to itself.")
ApiDependencies.invoker.services.model_relationships.add_model_relationship(
req.model_key_1,
req.model_key_2,
)
except ValueError as e:
raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@model_relationships_router.delete(
"/",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "The relationship was successfully removed"},
400: {"description": "Invalid model keys or self-referential relationship"},
404: {"description": "The relationship does not exist"},
422: {"description": "Validation error"},
500: {"description": "Internal server error"},
},
summary="Remove Model Relationship",
description="Removes a **bidirectional** relationship between two models. The relationship must already exist.",
)
async def remove_model_relationship(
req: ModelRelationshipCreateRequest = Body(..., description="The model keys to disconnect"),
) -> None:
"""
Removes a bidirectional relationship between two model keys.
- Raises 400 if attempting to unlink a model from itself.
- Raises 404 if the relationship was not found.
"""
try:
if req.model_key_1 == req.model_key_2:
raise HTTPException(status_code=400, detail="Cannot unlink a model from itself.")
ApiDependencies.invoker.services.model_relationships.remove_model_relationship(
req.model_key_1,
req.model_key_2,
)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@model_relationships_router.post(
"/batch",
operation_id="get_related_models_batch",
response_model=List[str],
responses={
200: {
"description": "Related model keys retrieved successfully",
"content": {
"application/json": {
"example": [
"ca562b14-995e-4a42-90c1-9528f1a5921d",
"cc0c2b8a-c62e-41d6-878e-cc74dde5ca8f",
"18ca7649-6a9e-47d5-bc17-41ab1e8cec81",
"7c12d1b2-0ef9-4bec-ba55-797b2d8f2ee1",
"c382eaa3-0e28-4ab0-9446-408667699aeb",
"71272e82-0e5f-46d5-bca9-9a61f4bd8a82",
"a5d7cd49-1b98-4534-a475-aeee4ccf5fa2",
]
}
},
},
422: {"description": "Validation error"},
500: {"description": "Internal server error"},
},
summary="Get Related Model Keys (Batch)",
description="Retrieves all **unique related model keys** for a list of given models. This is useful for contextual suggestions or filtering.",
)
async def get_related_models_batch(
req: ModelRelationshipBatchRequest = Body(..., description="Model keys to check for related connections"),
) -> list[str]:
"""
Accepts multiple model keys and returns a flat list of all unique related keys.
Useful when working with multiple selections in the UI or cross-model comparisons.
"""
try:
all_related: set[str] = set()
for key in req.model_keys:
related = ApiDependencies.invoker.services.model_relationships.get_related_model_keys(key)
all_related.update(related)
return list(all_related)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -22,6 +22,7 @@ from invokeai.app.api.routers import (
download_queue,
images,
model_manager,
model_relationships,
session_queue,
style_presets,
utilities,
@@ -125,6 +126,7 @@ app.include_router(download_queue.download_queue_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(model_relationships.model_relationships_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")

View File

@@ -5,6 +5,8 @@ from __future__ import annotations
import inspect
import re
import sys
import types
import typing
import warnings
from abc import ABC, abstractmethod
from enum import Enum
@@ -20,8 +22,10 @@ from typing import (
Literal,
Optional,
Type,
TypedDict,
TypeVar,
Union,
cast,
)
import semver
@@ -104,6 +108,11 @@ class UIConfigBase(BaseModel):
)
class OriginalModelField(TypedDict):
annotation: Any
field_info: FieldInfo
class BaseInvocationOutput(BaseModel):
"""
Base class for all invocation outputs.
@@ -132,6 +141,9 @@ class BaseInvocationOutput(BaseModel):
"""Gets the invocation output's type, as provided by the `@invocation_output` decorator."""
return cls.model_fields["type"].default
_original_model_fields: ClassVar[dict[str, OriginalModelField]] = {}
"""The original model fields, before any modifications were made by the @invocation_output decorator."""
model_config = ConfigDict(
protected_namespaces=(),
validate_assignment=True,
@@ -165,7 +177,7 @@ class BaseInvocation(ABC, BaseModel):
return cls.model_fields["type"].default
@classmethod
def get_output_annotation(cls) -> BaseInvocationOutput:
def get_output_annotation(cls) -> Type[BaseInvocationOutput]:
"""Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
return signature(cls.invoke).return_annotation
@@ -197,7 +209,7 @@ class BaseInvocation(ABC, BaseModel):
Internal invoke method, calls `invoke()` after some prep.
Handles optional fields that are required to call `invoke()` and invocation cache.
"""
for field_name, field in self.model_fields.items():
for field_name, field in type(self).model_fields.items():
if not field.json_schema_extra or callable(field.json_schema_extra):
# something has gone terribly awry, we should always have this and it should be a dict
continue
@@ -212,9 +224,9 @@ class BaseInvocation(ABC, BaseModel):
setattr(self, field_name, orig_default)
if orig_required and orig_default is PydanticUndefined and getattr(self, field_name) is None:
if input_ == Input.Connection:
raise RequiredConnectionException(self.model_fields["type"].default, field_name)
raise RequiredConnectionException(type(self).model_fields["type"].default, field_name)
elif input_ == Input.Any:
raise MissingInputException(self.model_fields["type"].default, field_name)
raise MissingInputException(type(self).model_fields["type"].default, field_name)
# skip node cache codepath if it's disabled
if services.configuration.node_cache_size == 0:
@@ -264,6 +276,9 @@ class BaseInvocation(ABC, BaseModel):
coerce_numbers_to_str=True,
)
_original_model_fields: ClassVar[dict[str, OriginalModelField]] = {}
"""The original model fields, before any modifications were made by the @invocation decorator."""
TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)
@@ -489,6 +504,48 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
return None
class NoDefaultSentinel:
pass
def validate_field_default(
cls_name: str, field_name: str, invocation_type: str, annotation: Any, field_info: FieldInfo
) -> None:
"""Validates the default value of a field against its pydantic field definition."""
assert isinstance(field_info.json_schema_extra, dict), "json_schema_extra is not a dict"
# By the time we are doing this, we've already done some pydantic magic by overriding the original default value.
# We store the original default value in the json_schema_extra dict, so we can validate it here.
orig_default = field_info.json_schema_extra.get("orig_default", NoDefaultSentinel)
if orig_default is NoDefaultSentinel:
return
# To validate the default value, we can create a temporary pydantic model with the field we are validating as its
# only field. Then validate the default value against this temporary model.
TempDefaultValidator = cast(BaseModel, create_model(cls_name, **{field_name: (annotation, field_info)}))
try:
TempDefaultValidator.model_validate({field_name: orig_default})
except Exception as e:
raise InvalidFieldError(
f'Default value for field "{field_name}" on invocation "{invocation_type}" is invalid, {e}'
) from e
def is_optional(annotation: Any) -> bool:
"""
Checks if the given annotation is optional (i.e. Optional[X], Union[X, None] or X | None).
"""
origin = typing.get_origin(annotation)
# PEP 604 unions (int|None) have origin types.UnionType
is_union = origin is typing.Union or origin is types.UnionType
if not is_union:
return False
return any(arg is type(None) for arg in typing.get_args(annotation))
def invocation(
invocation_type: str,
title: Optional[str] = None,
@@ -523,6 +580,24 @@ def invocation(
validate_fields(cls.model_fields, invocation_type)
fields: dict[str, tuple[Any, FieldInfo]] = {}
for field_name, field_info in cls.model_fields.items():
annotation = field_info.annotation
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
assert isinstance(field_info.json_schema_extra, dict), (
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
)
cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
validate_field_default(cls.__name__, field_name, invocation_type, annotation, field_info)
if field_info.default is None and not is_optional(annotation):
annotation = annotation | None
fields[field_name] = (annotation, field_info)
# Add OpenAPI schema extras
uiconfig: dict[str, Any] = {}
uiconfig["title"] = title
@@ -557,11 +632,17 @@ def invocation(
# Unfortunately, because the `GraphInvocation` uses a forward ref in its `graph` field's annotation, this does
# not work. Instead, we have to create a new class with the type field and patch the original class with it.
invocation_type_annotation = Literal[invocation_type] # type: ignore
invocation_type_field = Field(
title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
invocation_type_annotation = Literal[invocation_type]
# Field() returns an instance of FieldInfo, but thanks to a pydantic implementation detail, it is _typed_ as Any.
# This cast makes the type annotation match the class's true type.
invocation_type_field_info = cast(
FieldInfo,
Field(title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}),
)
fields["type"] = (invocation_type_annotation, invocation_type_field_info)
# Validate the `invoke()` method is implemented
if "invoke" in cls.__abstractmethods__:
raise ValueError(f'Invocation "{invocation_type}" must implement the "invoke" method')
@@ -583,17 +664,12 @@ def invocation(
)
docstring = cls.__doc__
cls = create_model(
cls.__qualname__,
__base__=cls,
__module__=cls.__module__,
type=(invocation_type_annotation, invocation_type_field),
)
cls.__doc__ = docstring
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields) # type: ignore
new_class.__doc__ = docstring
InvocationRegistry.register_invocation(cls)
InvocationRegistry.register_invocation(new_class)
return cls
return new_class
return wrapper
@@ -618,23 +694,39 @@ def invocation_output(
validate_fields(cls.model_fields, output_type)
fields: dict[str, tuple[Any, FieldInfo]] = {}
for field_name, field_info in cls.model_fields.items():
annotation = field_info.annotation
assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation."
assert isinstance(field_info.json_schema_extra, dict), (
f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
)
cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
if field_info.default is not PydanticUndefined and is_optional(annotation):
annotation = annotation | None
fields[field_name] = (annotation, field_info)
# Add the output type to the model.
output_type_annotation = Literal[output_type] # type: ignore
output_type_field = Field(
title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
output_type_annotation = Literal[output_type]
# Field() returns an instance of FieldInfo, but thanks to a pydantic implementation detail, it is _typed_ as Any.
# This cast makes the type annotation match the class's true type.
output_type_field_info = cast(
FieldInfo,
Field(title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}),
)
fields["type"] = (output_type_annotation, output_type_field_info)
docstring = cls.__doc__
cls = create_model(
cls.__qualname__,
__base__=cls,
__module__=cls.__module__,
type=(output_type_annotation, output_type_field),
)
cls.__doc__ = docstring
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields)
new_class.__doc__ = docstring
InvocationRegistry.register_output(cls)
InvocationRegistry.register_output(new_class)
return cls
return new_class
return wrapper

View File

@@ -64,7 +64,6 @@ class ImageBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
images: list[ImageField] = InputField(
default=[],
min_length=1,
description="The images to batch over",
)
@@ -120,7 +119,6 @@ class StringBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each string in the batch."""
strings: list[str] = InputField(
default=[],
min_length=1,
description="The strings to batch over",
)
@@ -176,7 +174,6 @@ class IntegerBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""
integers: list[int] = InputField(
default=[],
min_length=1,
description="The integers to batch over",
)
@@ -230,7 +227,6 @@ class FloatBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each float in the batch."""
floats: list[float] = InputField(
default=[],
min_length=1,
description="The floats to batch over",
)

View File

@@ -274,12 +274,12 @@ class InvokeAdjustImageHuePlusInvocation(BaseInvocation, WithMetadata, WithBoard
title="Enhance Image",
tags=["enhance", "image"],
category="image",
version="1.2.0",
version="1.2.1",
)
class InvokeImageEnhanceInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Applies processing from PIL's ImageEnhance module. Originally created by @dwringer"""
image: ImageField = InputField(default=None, description="The image for which to apply processing")
image: ImageField = InputField(description="The image for which to apply processing")
invert: bool = InputField(default=False, description="Whether to invert the image colors")
color: float = InputField(ge=0, default=1.0, description="Color enhancement factor")
contrast: float = InputField(ge=0, default=1.0, description="Contrast enhancement factor")

View File

@@ -42,12 +42,12 @@ class GradientMaskOutput(BaseInvocationOutput):
title="Create Gradient Mask",
tags=["mask", "denoise"],
category="latents",
version="1.2.0",
version="1.2.1",
)
class CreateGradientMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
mask: ImageField = InputField(description="Image which will be masked", ui_order=1)
edge_radius: int = InputField(
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
)

View File

@@ -608,6 +608,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
end_step_percent=single_ip_adapter.end_step_percent,
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
mask=mask,
method=single_ip_adapter.method,
)
)

View File

@@ -62,6 +62,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
FluxReduxModel = "FluxReduxModelField"
LlavaOnevisionModel = "LLaVAModelField"
Imagen3Model = "Imagen3ModelField"
Imagen4Model = "Imagen4ModelField"
ChatGPT4oModel = "ChatGPT4oModelField"
# endregion
@@ -400,8 +401,8 @@ class InputFieldJSONSchemaExtra(BaseModel):
"""
input: Input
orig_required: bool
field_kind: FieldKind
orig_required: bool = True
default: Optional[Any] = None
orig_default: Optional[Any] = None
ui_hidden: bool = False
@@ -498,7 +499,7 @@ def InputField(
input: Input = Input.Any,
ui_type: Optional[UIType] = None,
ui_component: Optional[UIComponent] = None,
ui_hidden: bool = False,
ui_hidden: Optional[bool] = None,
ui_order: Optional[int] = None,
ui_choice_labels: Optional[dict[str, str]] = None,
) -> Any:
@@ -534,15 +535,20 @@ def InputField(
json_schema_extra_ = InputFieldJSONSchemaExtra(
input=input,
ui_type=ui_type,
ui_component=ui_component,
ui_hidden=ui_hidden,
ui_order=ui_order,
ui_choice_labels=ui_choice_labels,
field_kind=FieldKind.Input,
orig_required=True,
)
if ui_type is not None:
json_schema_extra_.ui_type = ui_type
if ui_component is not None:
json_schema_extra_.ui_component = ui_component
if ui_hidden is not None:
json_schema_extra_.ui_hidden = ui_hidden
if ui_order is not None:
json_schema_extra_.ui_order = ui_order
if ui_choice_labels is not None:
json_schema_extra_.ui_choice_labels = ui_choice_labels
"""
There is a conflict between the typing of invocation definitions and the typing of an invocation's
`invoke()` function.
@@ -614,7 +620,7 @@ def InputField(
return Field(
**provided_args,
json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
json_schema_extra=json_schema_extra_.model_dump(exclude_unset=True),
)

View File

@@ -21,14 +21,14 @@ class IdealSizeOutput(BaseInvocationOutput):
"ideal_size",
title="Ideal Size - SD1.5, SDXL",
tags=["latents", "math", "ideal_size"],
version="1.0.5",
version="1.0.6",
)
class IdealSizeInvocation(BaseInvocation):
"""Calculates the ideal size for generation to avoid duplication"""
width: int = InputField(default=1024, description="Final image width")
height: int = InputField(default=576, description="Final image height")
unet: UNetField = InputField(default=None, description=FieldDescriptions.unet)
unet: UNetField = InputField(description=FieldDescriptions.unet)
multiplier: float = InputField(
default=1.0,
description="Amount to multiply the model's dimensions by when calculating the ideal size (may result in "

View File

@@ -975,13 +975,13 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Canvas Paste Back",
tags=["image", "combine"],
category="image",
version="1.0.0",
version="1.0.1",
)
class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Combines two images by using the mask provided. Intended for use on the Unified Canvas."""
source_image: ImageField = InputField(description="The source image")
target_image: ImageField = InputField(default=None, description="The target image")
target_image: ImageField = InputField(description="The target image")
mask: ImageField = InputField(
description="The mask to use when pasting",
)

View File

@@ -31,6 +31,7 @@ class IPAdapterField(BaseModel):
image_encoder_model: ModelIdentifierField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter.")
target_blocks: List[str] = Field(default=[], description="The IP Adapter blocks to apply")
method: str = Field(default="full", description="Weight apply method")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
@@ -94,7 +95,7 @@ class IPAdapterInvocation(BaseInvocation):
weight: Union[float, List[float]] = InputField(
default=1, description="The weight given to the IP-Adapter", title="Weight"
)
method: Literal["full", "style", "composition"] = InputField(
method: Literal["full", "style", "composition", "style_strong", "style_precise"] = InputField(
default="full", description="The method to apply the IP-Adapter"
)
begin_step_percent: float = InputField(
@@ -147,6 +148,38 @@ class IPAdapterInvocation(BaseInvocation):
target_blocks = ["down_blocks.2.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "style_precise":
if ip_adapter_info.base == "sd-1":
target_blocks = ["up_blocks.1", "down_blocks.2", "mid_block"]
elif ip_adapter_info.base == "sdxl":
target_blocks = ["up_blocks.0.attentions.1", "down_blocks.2.attentions.1"]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "style_strong":
if ip_adapter_info.base == "sd-1":
target_blocks = ["up_blocks.0", "up_blocks.1", "up_blocks.2", "down_blocks.0", "down_blocks.1"]
elif ip_adapter_info.base == "sdxl":
target_blocks = [
"up_blocks.0.attentions.1",
"up_blocks.1.attentions.1",
"up_blocks.2.attentions.1",
"up_blocks.0.attentions.2",
"up_blocks.1.attentions.2",
"up_blocks.2.attentions.2",
"up_blocks.0.attentions.0",
"up_blocks.1.attentions.0",
"up_blocks.2.attentions.0",
"down_blocks.0.attentions.0",
"down_blocks.0.attentions.1",
"down_blocks.0.attentions.2",
"down_blocks.1.attentions.0",
"down_blocks.1.attentions.1",
"down_blocks.1.attentions.2",
"down_blocks.2.attentions.0",
"down_blocks.2.attentions.2",
]
else:
raise ValueError(f"Unsupported IP-Adapter base type: '{ip_adapter_info.base}'.")
elif self.method == "full":
target_blocks = ["block"]
else:
@@ -162,6 +195,7 @@ class IPAdapterInvocation(BaseInvocation):
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
mask=self.mask,
method=self.method,
),
)

View File

@@ -6,7 +6,7 @@ import numpy as np
import torch
from PIL import Image
from pydantic import BaseModel, Field
from transformers import AutoModelForMaskGeneration, AutoProcessor
from transformers import AutoProcessor
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
@@ -104,14 +104,13 @@ class SegmentAnythingInvocation(BaseInvocation):
@staticmethod
def _load_sam_model(model_path: Path):
sam_model = AutoModelForMaskGeneration.from_pretrained(
sam_model = SamModel.from_pretrained(
model_path,
local_files_only=True,
# TODO(ryand): Setting the torch_dtype here doesn't work. Investigate whether fp16 is supported by the
# model, and figure out how to make it work in the pipeline.
# torch_dtype=TorchDevice.choose_torch_dtype(),
)
assert isinstance(sam_model, SamModel)
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
assert isinstance(sam_processor, SamProcessor)

View File

@@ -27,6 +27,10 @@ if TYPE_CHECKING:
from invokeai.app.services.invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from invokeai.app.services.model_images.model_images_base import ModelImageFileStorageBase
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
ModelRelationshipRecordStorageBase,
)
from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC
from invokeai.app.services.names.names_base import NameServiceBase
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
@@ -54,6 +58,8 @@ class InvocationServices:
logger: "Logger",
model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase",
model_relationships: "ModelRelationshipsServiceABC",
model_relationship_records: "ModelRelationshipRecordStorageBase",
download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase",
session_queue: "SessionQueueBase",
@@ -81,6 +87,8 @@ class InvocationServices:
self.logger = logger
self.model_images = model_images
self.model_manager = model_manager
self.model_relationships = model_relationships
self.model_relationship_records = model_relationship_records
self.download_queue = download_queue
self.performance_statistics = performance_statistics
self.session_queue = session_queue

View File

@@ -60,7 +60,7 @@ class InvocationStatsServiceBase(ABC):
pass
@abstractmethod
def reset_stats(self):
def reset_stats(self, graph_execution_state_id: str) -> None:
"""Reset all stored statistics."""
pass

View File

@@ -73,9 +73,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
)
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
def reset_stats(self):
self._stats = {}
self._cache_stats = {}
def reset_stats(self, graph_execution_state_id: str) -> None:
self._stats.pop(graph_execution_state_id, None)
self._cache_stats.pop(graph_execution_state_id, None)
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)

View File

@@ -38,6 +38,7 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
InvalidModelConfigException,
ModelConfigBase,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.metadata import (
@@ -646,14 +647,18 @@ class ModelInstallService(ModelInstallServiceBase):
hash_algo = self._app_config.hashing_algorithm
fields = config.model_dump()
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)
# New model probe API is disabled pending resolution of issue caused by a change of the ordering of checks.
# See commit message for details.
# try:
# return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
# except InvalidModelConfigException:
# return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
# WARNING!
# The legacy probe relies on the implicit order of tests to determine model classification.
# This can lead to regressions between the legacy and new probes.
# Do NOT change the order of `probe` and `classify` without implementing one of the following fixes:
# Short-term fix: `classify` tests `matches` in the same order as the legacy probe.
# Long-term fix: Improve `matches` to be more specific so that only one config matches
# any given model - eliminating ambiguity and removing reliance on order.
# After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe`
try:
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
except InvalidModelConfigException:
return ModelConfigBase.classify(model_path, hash_algo, **fields)
def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None

View File

@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
class ModelRelationshipRecordStorageBase(ABC):
"""Abstract base class for model-to-model relationship record storage."""
@abstractmethod
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
"""Creates a relationship between two models by keys."""
pass
@abstractmethod
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
"""Removes a relationship between two models by keys."""
pass
@abstractmethod
def get_related_model_keys(self, model_key: str) -> list[str]:
"""Gets all models keys related to a given model key."""
pass
@abstractmethod
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
"""Get related model keys for multiple models given a list of keys."""
pass

View File

@@ -0,0 +1,66 @@
import sqlite3
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
ModelRelationshipRecordStorageBase,
)
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._conn = db.conn
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
try:
cursor = self._conn.cursor()
cursor.execute(
"INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
(a, b),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
a, b = sorted([model_key_1, model_key_2])
try:
cursor = self._conn.cursor()
cursor.execute(
"DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
(a, b),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def get_related_model_keys(self, model_key: str) -> list[str]:
cursor = self._conn.cursor()
cursor.execute(
"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
""",
(model_key, model_key),
)
return [row[0] for row in cursor.fetchall()]
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
cursor = self._conn.cursor()
key_list = ",".join("?" for _ in model_keys)
cursor.execute(
f"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
""",
model_keys + model_keys,
)
return [row[0] for row in cursor.fetchall()]

View File

@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
class ModelRelationshipsServiceABC(ABC):
"""High-level service for managing model-to-model relationships."""
@abstractmethod
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
"""Creates a relationship between two models keys."""
pass
@abstractmethod
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
"""Removes a relationship between two models keys."""
pass
@abstractmethod
def get_related_model_keys(self, model_key: str) -> list[str]:
"""Gets all models keys related to a given model key."""
pass
@abstractmethod
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
"""Get related model keys for multiple models."""
pass

View File

@@ -0,0 +1,9 @@
from datetime import datetime
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
class ModelRelationship(BaseModelExcludeNull):
model_key_1: str
model_key_2: str
created_at: datetime

View File

@@ -0,0 +1,31 @@
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_relationships.model_relationships_base import ModelRelationshipsServiceABC
from invokeai.backend.model_manager.config import AnyModelConfig
class ModelRelationshipsService(ModelRelationshipsServiceABC):
__invoker: Invoker
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
self.__invoker.services.model_relationship_records.add_model_relationship(model_key_1, model_key_2)
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
self.__invoker.services.model_relationship_records.remove_model_relationship(model_key_1, model_key_2)
def get_related_model_keys(self, model_key: str) -> list[str]:
return self.__invoker.services.model_relationship_records.get_related_model_keys(model_key)
def add_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None:
self.add_model_relationship(model_1.key, model_2.key)
def remove_relationship_from_models(self, model_1: AnyModelConfig, model_2: AnyModelConfig) -> None:
self.remove_model_relationship(model_1.key, model_2.key)
def get_related_keys_from_model(self, model: AnyModelConfig) -> list[str]:
return self.get_related_model_keys(model.key)
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
return self.__invoker.services.model_relationship_records.get_related_model_keys_batch(model_keys)

View File

@@ -210,7 +210,7 @@ class DefaultSessionRunner(SessionRunnerBase):
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._services.performance_statistics.log_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats()
self._services.performance_statistics.reset_stats(queue_item.session.id)
for callback in self._on_after_run_session_callbacks:
callback(queue_item=queue_item)

View File

@@ -148,7 +148,7 @@ class Batch(BaseModel):
node = cast(BaseInvocation, graph.get_node(batch_data.node_path))
except NodeNotFoundError:
raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph")
if batch_data.field_name not in node.model_fields:
if batch_data.field_name not in type(node).model_fields:
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
return values

View File

@@ -424,7 +424,7 @@ class Graph(BaseModel):
)
# input fields are on the node
if edge.destination.field not in destination_node.model_fields:
if edge.destination.field not in type(destination_node).model_fields:
raise NodeFieldNotFoundError(
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
)

View File

@@ -22,6 +22,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_16 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_17 import build_migration_17
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_18 import build_migration_18
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -61,6 +62,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_17())
migrator.register_migration(build_migration_18())
migrator.register_migration(build_migration_19(app_config=config))
migrator.register_migration(build_migration_20())
migrator.run_migrations()
return db

View File

@@ -0,0 +1,37 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration20Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
cursor.execute(
"""
-- many-to-many relationship table for models
CREATE TABLE IF NOT EXISTS model_relationships (
-- model_key_1 and model_key_2 are the same as the key(primary key) in the models table
model_key_1 TEXT NOT NULL,
model_key_2 TEXT NOT NULL,
created_at TEXT DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
PRIMARY KEY (model_key_1, model_key_2),
-- model_key_1 < model_key_2, to ensure uniqueness and prevent duplicates
FOREIGN KEY (model_key_1) REFERENCES models(id) ON DELETE CASCADE,
FOREIGN KEY (model_key_2) REFERENCES models(id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""
-- Creates an index to keep performance equal when searching for model_key_1 or model_key_2
CREATE INDEX IF NOT EXISTS keyx_model_relationships_model_key_2
ON model_relationships(model_key_2)
"""
)
def build_migration_20() -> Migration:
return Migration(
from_version=19,
to_version=20,
callback=Migration20Callback(),
)

View File

@@ -146,33 +146,35 @@ class ModelConfigBase(ABC, BaseModel):
)
usage_info: Optional[str] = Field(default=None, description="Usage information for this model")
_USING_LEGACY_PROBE: ClassVar[set] = set()
_USING_CLASSIFY_API: ClassVar[set] = set()
USING_LEGACY_PROBE: ClassVar[set] = set()
USING_CLASSIFY_API: ClassVar[set] = set()
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if issubclass(cls, LegacyProbeMixin):
ModelConfigBase._USING_LEGACY_PROBE.add(cls)
ModelConfigBase.USING_LEGACY_PROBE.add(cls)
else:
ModelConfigBase._USING_CLASSIFY_API.add(cls)
ModelConfigBase.USING_CLASSIFY_API.add(cls)
@staticmethod
def all_config_classes():
subclasses = ModelConfigBase._USING_LEGACY_PROBE | ModelConfigBase._USING_CLASSIFY_API
subclasses = ModelConfigBase.USING_LEGACY_PROBE | ModelConfigBase.USING_CLASSIFY_API
concrete = {cls for cls in subclasses if not isabstract(cls)}
return concrete
@staticmethod
def classify(model_path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
def classify(mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
"""
Returns the best matching ModelConfig instance from a model's file/folder path.
Raises InvalidModelConfigException if no valid configuration is found.
Created to deprecate ModelProbe.probe
"""
candidates = ModelConfigBase._USING_CLASSIFY_API
if isinstance(mod, Path | str):
mod = ModelOnDisk(mod, hash_algo)
candidates = ModelConfigBase.USING_CLASSIFY_API
sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
mod = ModelOnDisk(model_path, hash_algo)
for config_cls in sorted_by_match_speed:
try:

View File

@@ -2,6 +2,8 @@ from typing import Any
import torch
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
class CachedModelOnlyFullLoad:
"""A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
@@ -76,7 +78,15 @@ class CachedModelOnlyFullLoad:
for k, v in self._cpu_state_dict.items():
new_state_dict[k] = v.to(self._compute_device, copy=True)
self._model.load_state_dict(new_state_dict, assign=True)
self._model.to(self._compute_device)
check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
if isinstance(check_for_gguf, GGMLTensor):
old_value = torch.__future__.get_overwrite_module_params_on_conversion()
torch.__future__.set_overwrite_module_params_on_conversion(True)
self._model.to(self._compute_device)
torch.__future__.set_overwrite_module_params_on_conversion(old_value)
else:
self._model.to(self._compute_device)
self._is_in_vram = True
return self._total_bytes
@@ -92,7 +102,15 @@ class CachedModelOnlyFullLoad:
if self._cpu_state_dict is not None:
self._model.load_state_dict(self._cpu_state_dict, assign=True)
self._model.to(self._offload_device)
check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
if isinstance(check_for_gguf, GGMLTensor):
old_value = torch.__future__.get_overwrite_module_params_on_conversion()
torch.__future__.set_overwrite_module_params_on_conversion(True)
self._model.to(self._offload_device)
torch.__future__.set_overwrite_module_params_on_conversion(old_value)
else:
self._model.to(self._offload_device)
self._is_in_vram = False
return self._total_bytes

View File

@@ -2,9 +2,10 @@ import gc
import logging
import threading
import time
from dataclasses import dataclass
from functools import wraps
from logging import Logger
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Protocol
import psutil
import torch
@@ -54,6 +55,39 @@ def synchronized(method: Callable[..., Any]) -> Callable[..., Any]:
return wrapper
@dataclass
class CacheEntrySnapshot:
cache_key: str
total_bytes: int
current_vram_bytes: int
class CacheMissCallback(Protocol):
def __call__(
self,
model_key: str,
cache_snapshot: dict[str, CacheEntrySnapshot],
) -> None: ...
class CacheHitCallback(Protocol):
def __call__(
self,
model_key: str,
cache_snapshot: dict[str, CacheEntrySnapshot],
) -> None: ...
class CacheModelsClearedCallback(Protocol):
def __call__(
self,
models_cleared: int,
bytes_requested: int,
bytes_freed: int,
cache_snapshot: dict[str, CacheEntrySnapshot],
) -> None: ...
class ModelCache:
"""A cache for managing models in memory.
@@ -144,6 +178,34 @@ class ModelCache:
# - Requests to empty the cache from a separate thread
self._lock = threading.RLock()
self._on_cache_hit_callbacks: set[CacheHitCallback] = set()
self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
self._on_cache_models_cleared_callbacks: set[CacheModelsClearedCallback] = set()
def on_cache_hit(self, cb: CacheHitCallback) -> Callable[[], None]:
self._on_cache_hit_callbacks.add(cb)
def unsubscribe() -> None:
self._on_cache_hit_callbacks.discard(cb)
return unsubscribe
def on_cache_miss(self, cb: CacheHitCallback) -> Callable[[], None]:
self._on_cache_miss_callbacks.add(cb)
def unsubscribe() -> None:
self._on_cache_miss_callbacks.discard(cb)
return unsubscribe
def on_cache_models_cleared(self, cb: CacheModelsClearedCallback) -> Callable[[], None]:
self._on_cache_models_cleared_callbacks.add(cb)
def unsubscribe() -> None:
self._on_cache_models_cleared_callbacks.discard(cb)
return unsubscribe
@property
@synchronized
def stats(self) -> Optional[CacheStats]:
@@ -195,6 +257,20 @@ class ModelCache:
f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size / MB:.2f}MB)"
)
@synchronized
def _get_cache_snapshot(self) -> dict[str, CacheEntrySnapshot]:
overview: dict[str, CacheEntrySnapshot] = {}
for cache_key, cache_entry in self._cached_models.items():
total_bytes = cache_entry.cached_model.total_bytes()
current_vram_bytes = cache_entry.cached_model.cur_vram_bytes()
overview[cache_key] = CacheEntrySnapshot(
cache_key=cache_key,
total_bytes=total_bytes,
current_vram_bytes=current_vram_bytes,
)
return overview
@synchronized
def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
"""Retrieve a model from the cache.
@@ -208,6 +284,8 @@ class ModelCache:
if self.stats:
self.stats.hits += 1
else:
for cb in self._on_cache_miss_callbacks:
cb(model_key=key, cache_snapshot=self._get_cache_snapshot())
if self.stats:
self.stats.misses += 1
self._logger.debug(f"Cache miss: {key}")
@@ -229,6 +307,8 @@ class ModelCache:
self._cache_stack.append(key)
self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
for cb in self._on_cache_hit_callbacks:
cb(model_key=key, cache_snapshot=self._get_cache_snapshot())
return cache_entry
@synchronized
@@ -649,6 +729,13 @@ class ModelCache:
# immediately when their reference count hits 0.
if self.stats:
self.stats.cleared = models_cleared
for cb in self._on_cache_models_cleared_callbacks:
cb(
models_cleared=models_cleared,
bytes_requested=bytes_needed,
bytes_freed=ram_bytes_freed,
cache_snapshot=self._get_cache_snapshot(),
)
gc.collect()
TorchDevice.empty_cache()

View File

@@ -4,6 +4,7 @@ from typing import Any, Optional, TypeAlias
import safetensors.torch
import torch
from picklescan.scanner import scan_file_path
from safetensors import safe_open
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
@@ -35,12 +36,21 @@ class ModelOnDisk:
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))
def component_paths(self) -> set[Path]:
def weight_files(self) -> set[Path]:
if self.path.is_file():
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}
def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
try:
with safe_open(self.path, framework="pt", device="cpu") as f:
metadata = f.metadata()
assert isinstance(metadata, dict)
return metadata
except Exception:
return {}
def repo_variant(self) -> Optional[ModelRepoVariant]:
if self.path.is_file():
return None
@@ -64,18 +74,7 @@ class ModelOnDisk:
if path in sd_cache:
return sd_cache[path]
if not path:
components = list(self.component_paths())
match components:
case []:
raise ValueError("No weight files found for this model")
case [p]:
path = p
case ps if len(ps) >= 2:
raise ValueError(
f"Multiple weight files found for this model: {ps}. "
f"Please specify the intended file using the 'path' argument"
)
path = self.resolve_weight_file(path)
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
@@ -94,3 +93,18 @@ class ModelOnDisk:
state_dict = checkpoint.get("state_dict", checkpoint)
sd_cache[path] = state_dict
return state_dict
def resolve_weight_file(self, path: Optional[Path] = None) -> Path:
if not path:
weight_files = list(self.weight_files())
match weight_files:
case []:
raise ValueError("No weight files found for this model")
case [p]:
return p
case ps if len(ps) >= 2:
raise ValueError(
f"Multiple weight files found for this model: {ps}. "
f"Please specify the intended file using the 'path' argument"
)
return path

View File

@@ -27,6 +27,7 @@ class BaseModelType(str, Enum):
Flux = "flux"
CogView4 = "cogview4"
Imagen3 = "imagen3"
Imagen4 = "imagen4"
ChatGPT4o = "chatgpt-4o"

View File

@@ -5,7 +5,8 @@ from typing import Callable, Optional, Union
import gguf
import torch
TORCH_COMPATIBLE_QTYPES = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}
# should not be a Set until this is resolved: https://github.com/pytorch/pytorch/issues/145761
TORCH_COMPATIBLE_QTYPES = [None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
# K Quants #
QK_K = 256

View File

@@ -371,7 +371,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if use_ip_adapter or use_regional_prompting:
ip_adapters: Optional[List[UNetIPAdapterData]] = (
[{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data]
[
{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks, "method": ipa.method}
for ipa in ip_adapter_data
]
if use_ip_adapter
else None
)

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import math
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
@@ -104,15 +104,29 @@ class IPAdapterConditioningInfo:
@dataclass
class IPAdapterData:
"""Data class for IP-Adapter configuration.
Attributes:
ip_adapter_model: The IP-Adapter model to use.
ip_adapter_conditioning: The IP-Adapter conditioning data.
mask: The mask to apply to the IP-Adapter conditioning.
target_blocks: List of target attention block names to apply IP-Adapter to.
negative_blocks: List of target attention block names that should use negative attention.
weight: The weight to apply to the IP-Adapter conditioning.
begin_step_percent: The percentage of steps at which to start applying the IP-Adapter.
end_step_percent: The percentage of steps at which to stop applying the IP-Adapter.
method: The method to use for applying the IP-Adapter ('full', 'style', 'composition').
"""
ip_adapter_model: IPAdapter
ip_adapter_conditioning: IPAdapterConditioningInfo
mask: torch.Tensor
target_blocks: List[str]
# Either a single weight applied to all steps, or a list of weights for each step.
negative_blocks: List[str] = field(default_factory=list)
weight: Union[float, List[float]] = 1.0
begin_step_percent: float = 0.0
end_step_percent: float = 1.0
method: str = "full"
def scale_for_step(self, step_index: int, total_steps: int) -> float:
first_adapter_step = math.floor(self.begin_step_percent * total_steps)

View File

@@ -14,6 +14,7 @@ from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import Reg
class IPAdapterAttentionWeights:
ip_adapter_weights: IPAttentionProcessorWeights
skip: bool
negative: bool
class CustomAttnProcessor2_0(AttnProcessor2_0):
@@ -162,6 +163,10 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
if not self._ip_adapter_attention_weights[ipa_index].skip:
# apply the IP-Adapter weights to the negative embeds
if self._ip_adapter_attention_weights[ipa_index].negative:
ip_hidden_states = torch.cat([ip_hidden_states[1], ip_hidden_states[0] * 0], dim=0)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)

View File

@@ -12,7 +12,8 @@ from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
class UNetIPAdapterData(TypedDict):
ip_adapter: IPAdapter
target_blocks: List[str]
target_blocks: List[str] # Blocks where IP-Adapter should be applied
method: str # Style or other method type
class UNetAttentionPatcher:
@@ -39,12 +40,18 @@ class UNetAttentionPatcher:
for ip_adapter in self._ip_adapters:
ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
skip = True
negative = False
for block in ip_adapter["target_blocks"]:
if block in name:
skip = False
negative = ip_adapter["method"] == "style_precise" and (
block == "down_blocks.2.attentions.1"
or block == "down_blocks.2"
or block == "mid_block"
)
break
ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
ip_adapter_weights=ip_adapter_weights, skip=skip
ip_adapter_weights=ip_adapter_weights, skip=skip, negative=negative
)
ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)

View File

@@ -119,7 +119,17 @@
"error_withCount_other": "{{count}} Fehler",
"value": "Wert",
"label": "Label",
"systemInformation": "Systeminformationen"
"systemInformation": "Systeminformationen",
"search": "Suche",
"clear": "Zurücksetzen",
"fullView": "Vollansicht",
"compactView": "Kompaktansicht",
"options_withCount_one": "{{count}} Option",
"options_withCount_other": "{{count}} Optionen",
"noOptions": "Keine Optionen",
"noMatches": "Keine Treffer",
"model_withCount_one": "{{count}} Modell",
"model_withCount_other": "{{count}} Modelle"
},
"gallery": {
"galleryImageSize": "Bildgröße",

View File

@@ -846,6 +846,8 @@
"predictionType": "Prediction Type",
"prune": "Prune",
"pruneTooltip": "Prune finished imports from queue",
"relatedModels": "Related Models",
"showOnlyRelatedModels": "Related",
"repo_id": "Repo ID",
"repoVariant": "Repo Variant",
"scanFolder": "Scan Folder",
@@ -1330,7 +1332,7 @@
"unableToCopyDesc": "Your browser does not support clipboard access. Firefox users may be able to fix this by following ",
"unableToCopyDesc_theseSteps": "these steps",
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
"imagen3IncompatibleGenerationMode": "Google Imagen3 supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"imagenIncompatibleGenerationMode": "Google {{model}} supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models Inpainting and Outpainting tasks.",
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
@@ -2040,10 +2042,14 @@
"ipAdapterMethod": "Mode",
"full": "Style and Composition",
"fullDesc": "Applies visual style (colors, textures) & composition (layout, structure).",
"style": "Style Only",
"styleDesc": "Applies visual style (colors, textures) without considering its layout.",
"style": "Style (Simple)",
"styleDesc": "Applies visual style (colors, textures) without considering its layout. Previously called Style Only.",
"composition": "Composition Only",
"compositionDesc": "Replicates layout & structure while ignoring the reference's style."
"compositionDesc": "Replicates layout & structure while ignoring the reference's style.",
"styleStrong": "Style (Strong)",
"styleStrongDesc": "Applies a strong visual style, with a slightly reduced composition influence.",
"stylePrecise": "Style (Precise)",
"stylePreciseDesc": "Applies a precise visual style, eliminating subject influence."
},
"fluxReduxImageInfluence": {
"imageInfluence": "Image Influence",
@@ -2413,8 +2419,9 @@
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"CogView4: Support for CogView4 models in Canvas and Workflows.",
"Updated Dependencies: Invoke now runs on the latest version of its dependencies, including Python 3.12 and Pytorch 2.6.0."
"Nvidia 50xx GPUs: Invoke uses PyTorch 2.7.0, which is required for these GPUs.",
"Model Relationships: Link LoRAs to main models, and the LoRAs will show up first in the list.",
"IP Adapter: New Style (Strong) and Style (Precise) methods for SDXL and SD1.5 models."
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",

View File

@@ -116,7 +116,19 @@
"error_withCount_other": "{{count}} errori",
"value": "Valore",
"label": "Etichetta",
"systemInformation": "Informazioni di sistema"
"systemInformation": "Informazioni di sistema",
"noMatches": "Nessuna corrispondenza",
"noOptions": "Nessuna opzione",
"model_withCount_one": "{{count}} modello",
"model_withCount_many": "{{count}} modelli",
"model_withCount_other": "{{count}} modelli",
"options_withCount_one": "{{count}} opzione",
"options_withCount_many": "{{count}} opzioni",
"options_withCount_other": "{{count}} opzioni",
"search": "Cerca",
"clear": "Cancella",
"compactView": "Vista compatta",
"fullView": "Vista completa"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -637,7 +649,14 @@
"urlForbidden": "Non hai accesso a questo modello",
"urlForbiddenErrorMessage": "Potrebbe essere necessario richiedere l'autorizzazione al sito che distribuisce il modello.",
"urlUnauthorizedErrorMessage": "Potrebbe essere necessario configurare un gettone API per accedere a questo modello.",
"fileSize": "Dimensione del file"
"fileSize": "Dimensione del file",
"filterModels": "Filtra i modelli",
"modelPickerFallbackNoModelsInstalled": "Nessun modello installato.",
"modelPickerFallbackNoModelsInstalled2": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare i modelli.",
"manageModels": "Gestione modelli",
"hfTokenReset": "Ripristino del gettone HF",
"relatedModels": "Modelli correlati",
"showOnlyRelatedModels": "Correlati"
},
"parameters": {
"images": "Immagini",
@@ -719,7 +738,11 @@
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (excl min)",
"collectionEmpty": "raccolta vuota",
"batchNodeCollectionSizeMismatchNoGroupId": "Dimensione della raccolta di gruppo nel Lotto non corrisponde",
"modelIncompatibleBboxWidth": "La larghezza del riquadro di delimitazione è {{width}} ma {{model}} richiede multipli di {{multiple}}"
"modelIncompatibleBboxWidth": "La larghezza del riquadro di delimitazione è {{width}} ma {{model}} richiede multipli di {{multiple}}",
"modelIncompatibleBboxHeight": "L'altezza del riquadro è {{height}} ma {{model}} richiede multipli di {{multiple}}",
"modelIncompatibleScaledBboxWidth": "La larghezza scalata del riquadro è {{width}} ma {{model}} richiede multipli di {{multiple}}",
"modelIncompatibleScaledBboxHeight": "L'altezza scalata del riquadro è {{height}} ma {{model}} richiede multipli di {{multiple}}",
"modelDisabledForTrial": "La generazione con {{modelName}} non è disponibile per gli account di prova. Accedi alle impostazioni del tuo account per effettuare l'upgrade."
},
"useCpuNoise": "Usa la CPU per generare rumore",
"iterations": "Iterazioni",
@@ -746,7 +769,8 @@
"sendToCanvas": "Invia alla Tela",
"coherenceMinDenoise": "Min rid. rumore",
"recallMetadata": "Richiama i metadati",
"disabledNoRasterContent": "Disabilitato (nessun contenuto Raster)"
"disabledNoRasterContent": "Disabilitato (nessun contenuto Raster)",
"modelDisabledForTrial": "La generazione con {{modelName}} non è disponibile per gli account di prova. Visita le <LinkComponent>impostazioni account</LinkComponent> per effettuare l'upgrade."
},
"settings": {
"models": "Modelli",
@@ -855,7 +879,11 @@
"unableToCopy": "Impossibile copiare",
"unableToCopyDesc": "Il tuo browser non supporta l'accesso agli appunti. Gli utenti di Firefox potrebbero risolvere il problema seguendo ",
"unableToCopyDesc_theseSteps": "questi passaggi",
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill non è compatibile con Testo a Immagine o Immagine a Immagine. Per queste attività, utilizzare altri modelli FLUX."
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill non è compatibile con Testo a Immagine o Immagine a Immagine. Per queste attività, utilizzare altri modelli FLUX.",
"problemUnpublishingWorkflow": "Problema durante l'annullamento della pubblicazione del flusso di lavoro",
"problemUnpublishingWorkflowDescription": "Si è verificato un problema durante l'annullamento della pubblicazione del flusso di lavoro. Riprova.",
"workflowUnpublished": "Flusso di lavoro non pubblicato",
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supporta solo la conversione da testo a immagine e da immagine a immagine. Utilizza altri modelli per le attività di Inpainting e Outpainting."
},
"accessibility": {
"invokeProgressBar": "Barra di avanzamento generazione",
@@ -1049,7 +1077,8 @@
"unknownField_withName": "Campo \"{{name}}\" sconosciuto",
"missingField_withName": "Campo \"{{name}}\" mancante",
"unknownFieldEditWorkflowToFix_withName": "Il flusso di lavoro contiene un campo \"{{name}}\" sconosciuto .\nModifica il flusso di lavoro per risolvere il problema.",
"unexpectedField_withName": "Campo \"{{name}}\" inaspettato"
"unexpectedField_withName": "Campo \"{{name}}\" inaspettato",
"missingSourceOrTargetHandle": "Identificatore del nodo sorgente o di destinazione mancante"
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
@@ -1178,7 +1207,8 @@
"cancelAllExceptCurrentTooltip": "Annulla tutto tranne l'elemento corrente",
"retrySucceeded": "Elemento rieseguito",
"retryItem": "Riesegui elemento",
"retryFailed": "Problema riesecuzione elemento"
"retryFailed": "Problema riesecuzione elemento",
"credits": "Crediti"
},
"models": {
"noMatchingModels": "Nessun modello corrispondente",
@@ -1821,7 +1851,10 @@
"publishingValidationRunInProgress": "È in corso la convalida della pubblicazione.",
"publishedWorkflowsLocked": "I flussi di lavoro pubblicati sono bloccati e non possono essere modificati o eseguiti. Annulla la pubblicazione del flusso di lavoro o salva una copia per modificare o eseguire questo flusso di lavoro.",
"warningWorkflowHasNoPublishableInputFields": "Nessun campo di ingresso pubblicabile selezionato: il flusso di lavoro pubblicato verrà eseguito solo con i valori predefiniti",
"publishInProgress": "Pubblicazione in corso"
"publishInProgress": "Pubblicazione in corso",
"selectingOutputNode": "Selezione del nodo di uscita",
"selectingOutputNodeDesc": "Fare clic su un nodo per selezionarlo come nodo di uscita del flusso di lavoro.",
"errorWorkflowHasUnpublishableNodes": "Il flusso di lavoro ha nodi di estrazione lotto, generatore o metadati"
},
"loadMore": "Carica altro",
"searchPlaceholder": "Cerca per nome, descrizione o etichetta",
@@ -1971,12 +2004,16 @@
"stagingOnCanvas": "Genera immagini nella",
"ipAdapterMethod": {
"full": "Stile e Composizione",
"style": "Solo Stile",
"style": "Stile (semplice)",
"composition": "Solo Composizione",
"ipAdapterMethod": "Modalità",
"fullDesc": "Applica lo stile visivo (colori, texture) e la composizione (disposizione, struttura).",
"styleDesc": "Applica lo stile visivo (colori, texture) senza considerare la disposizione.",
"compositionDesc": "Replica disposizione e struttura ignorando lo stile di riferimento."
"styleDesc": "Applica lo stile visivo (colori, texture) senza considerare la disposizione. Precedentemente chiamato \"Solo stile\".",
"compositionDesc": "Replica disposizione e struttura ignorando lo stile di riferimento.",
"styleStrong": "Stile (forte)",
"styleStrongDesc": "Applica uno stile visivo forte, con un'influenza sulla composizione leggermente ridotta.",
"stylePrecise": "Stile (preciso)",
"stylePreciseDesc": "Applica uno stile visivo preciso, eliminando l'influenza del soggetto."
},
"showingType": "Mostra {{type}}",
"dynamicGrid": "Griglia dinamica",
@@ -2299,6 +2336,14 @@
"errors": {
"unableToFindImage": "Impossibile trovare l'immagine",
"unableToLoadImage": "Impossibile caricare l'immagine"
},
"fluxReduxImageInfluence": {
"high": "Alta",
"low": "Basso",
"imageInfluence": "Influenza dell'immagine",
"lowest": "Il più basso",
"medium": "Medio",
"highest": "La più alta"
}
},
"ui": {
@@ -2399,8 +2444,8 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
"items": [
"Flussi di lavoro: supporto per menu a discesa di stringhe personalizzate nel Generatore di Flussi di lavoro.",
"FLUX: supporto per FLUX Fill in Flussi di lavoro e Tela."
"GPU Nvidia 50xx: Invoke utilizza PyTorch 2.7.0, necessario per queste GPU.",
"Relazioni tra modelli: collega i LoRA ai modelli principali e i LoRA verranno visualizzati per primi nell'elenco."
]
},
"system": {

View File

@@ -118,7 +118,15 @@
"value": "値",
"label": "ラベル",
"saveChanges": "変更を保存",
"error_withCount_other": "{{count}} 個のエラー"
"error_withCount_other": "{{count}} 個のエラー",
"noMatches": "合致しません",
"model_withCount_other": "{{count}}個のモデル",
"noOptions": "オプションがありません",
"search": "検索",
"clear": "クリア",
"compactView": "コンパクトビュー",
"fullView": "フルビュー",
"options_withCount_other": "{{count}}個のオプション"
},
"gallery": {
"galleryImageSize": "画像のサイズ",
@@ -583,7 +591,7 @@
"deleteModelImage": "モデル画像を削除",
"hfTokenInvalid": "ハギングフェイストークンが無効または見つかりません",
"hfForbiddenErrorMessage": "リポジトリにアクセスすることを勧めます.所有者はダウンロードにあたり利用規約への同意を要求する場合があります.",
"noModelsInstalled": "インストールされているモデルなし",
"noModelsInstalled": "インストールされているモデルがありません",
"pathToConfig": "設定へのパス",
"noModelsInstalledDesc1": "モデルを一緒にインストール",
"pruneTooltip": "完了したインポートをキューから削除",
@@ -639,7 +647,12 @@
"urlUnauthorizedErrorMessage": "このモデルにアクセスするためにAPIトークンを構成する必要があるかもしれません.",
"urlUnauthorizedErrorMessage2": "ここでどうやるか学びます.",
"inplaceInstall": "定位置にインストール",
"fileSize": "ファイルサイズ"
"fileSize": "ファイルサイズ",
"modelPickerFallbackNoModelsInstalled2": "<LinkComponent>モデルマネージャー</LinkComponent> にアクセスしてモデルをインストールしてください.",
"filterModels": "フィルターモデル",
"modelPickerFallbackNoModelsInstalled": "モデルがインストールされていません.",
"manageModels": "モデル管理",
"hfTokenReset": "ハギングフェイストークンリセット"
},
"parameters": {
"images": "画像",
@@ -684,7 +697,28 @@
"collectionNumberGTMax": "{{value}} > {{maximum}} (最大増加)",
"missingNodeTemplate": "ノードテンプレートの欠落",
"batchNodeNotConnected": "バッチノードが: {{label}}につながっていない",
"collectionNumberLTMin": "{{value}} < {{minimum}} (最小増加)"
"collectionNumberLTMin": "{{value}} < {{minimum}} (最小増加)",
"fluxModelIncompatibleScaledBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), スケーリングされたbboxの高さは{{height}}です",
"fluxModelMultipleControlLoRAs": "コントロールLoRAは1度に1つしか使用できません",
"noPrompts": "プロンプトが生成されません",
"noNodesInGraph": "グラフにノードがありません",
"noCLIPEmbedModelSelected": "FLUX生成にCLIPエンベッドモデルが選択されていません",
"canvasIsFiltering": "キャンバスがビジー状態(フィルタリング)",
"canvasIsCompositing": "キャンバスがビジー状態(合成)",
"systemDisconnected": "システムが切断されました",
"fluxModelIncompatibleScaledBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), 拡大縮小されたbboxの幅は{{width}}です",
"canvasIsTransforming": "キャンバスがビジー状態(変換)",
"canvasIsRasterizing": "キャンバスがビジー状態(ラスタライズ)",
"modelIncompatibleBboxHeight": "Bboxの高さは{{height}}ですが,{{model}}は{{multiple}}の倍数が必要です",
"modelIncompatibleScaledBboxHeight": "bboxの高さは{{height}}ですが,{{model}}は{{multiple}}の倍数を必要です",
"modelIncompatibleBboxWidth": "Bboxの幅は{{width}}ですが, {{model}}は{{multiple}}の倍数が必要です",
"modelIncompatibleScaledBboxWidth": "bboxの幅は{{width}}ですが,{{model}}は{{multiple}}の倍数が必要です",
"canvasIsSelectingObject": "キャンバスがビジー状態(オブジェクトの選択)",
"fluxModelIncompatibleBboxWidth": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), bboxの幅は{{width}}です",
"fluxModelIncompatibleBboxHeight": "$t(parameters.invoke.fluxRequiresDimensionsToBeMultipleOf16), bboxの高さは{{height}}です",
"noFLUXVAEModelSelected": "FLUX生成にVAEモデルが選択されていません",
"noT5EncoderModelSelected": "FLUX生成にT5エンコーダモデルが選択されていません",
"modelDisabledForTrial": "{{modelName}} を使用した生成はトライアルアカウントではご利用いただけません.アカウント設定にアクセスしてアップグレードしてください。"
},
"aspect": "縦横比",
"lockAspectRatio": "縦横比を固定",
@@ -716,7 +750,24 @@
"cfgRescaleMultiplier": "CFGリスケール倍率",
"clipSkip": "クリップスキップ",
"guidance": "ガイダンス",
"infillMethod": "充填法"
"infillMethod": "充填法",
"patchmatchDownScaleSize": "ダウンスケール",
"boxBlur": "ボックスぼかし",
"remixImage": "リミックス画像",
"processImage": "プロセス画像",
"useCpuNoise": "CPUイズの使用",
"staged": "ステージ",
"perlinNoise": "パーリン・ノイズ(グラデーションノイズ)",
"imageActions": "画像処理",
"gaussianBlur": "ガウスぼかし",
"noiseThreshold": "ノイズの閾値",
"maskBlur": "マスクぼかし",
"seamlessYAxis": "シームレスなY軸",
"optimizedImageToImage": "イメージ to イメージの最適化",
"symmetry": "左右対称",
"seamlessXAxis": "シームレスなX軸",
"sendToCanvas": "キャンバスに送る",
"modelDisabledForTrial": "{{modelName}} を使用した生成はトライアルアカウントではご利用いただけません.アップグレードするには,<LinkComponent>アカウント設定</LinkComponent> にアクセスしてください."
},
"settings": {
"models": "モデル",
@@ -728,16 +779,100 @@
"resetComplete": "WebUIはリセットされました。",
"ui": "ユーザーインターフェイス",
"beta": "ベータ",
"developer": "開発者"
"developer": "開発者",
"antialiasProgressImages": "アンチエイリアスの経過画像",
"enableInformationalPopovers": "情報ポップオーバーを有効にする",
"enableModelDescriptions": "ドロップダウンでモデルの説明を有効にする",
"confirmOnNewSession": "新しいセッションで確認する",
"informationalPopoversDisabled": "情報ポップオーバーが無効になっています",
"informationalPopoversDisabledDesc": "情報ポップオーバーが無効になっています.設定で有効にしてください.",
"enableNSFWChecker": "NSFWチェッカーを有効にする",
"enableInvisibleWatermark": "目に見えない透かしを有効にする",
"enableHighlightFocusedRegions": "重点領域を強調表示",
"clearIntermediatesDesc1": "中間物をクリアすると、キャンバスとコントロールネットの状態がリセットされます.",
"showProgressInViewer": "ビューアで進行状況画像を表示する",
"modelDescriptionsDisabled": "ドロップダウンのモデル説明が無効になっています",
"modelDescriptionsDisabledDesc": "ドロップダウンのモデル説明が無効になっています.設定で有効にしてください.",
"clearIntermediatesDisabled": "中間物をクリアするにはキューが空でなければなりません",
"clearIntermediatesDesc2": "中間画像は生成時に生成される副産物であり、ギャラリーに表示される結果画像とは異なります.中間画像を削除するとディスク容量が解放されます.",
"intermediatesClearedFailed": "中間物をクリアする問題",
"reloadingIn": "リロード中",
"clearIntermediatesDesc3": "ギャラリー画像は削除されません.",
"clearIntermediates": "中間物をクリア",
"clearIntermediatesWithCount_other": "{{count}} 個の中間物をクリア",
"intermediatesCleared_other": "{{count}}個の中間物がクリアされました",
"general": "一般",
"generation": "生成",
"showDetailedInvocationProgress": "進捗状況の詳細を表示"
},
"toast": {
"uploadFailed": "アップロード失敗",
"imageCopied": "画像をコピー",
"imageUploadFailed": "画像のアップロードに失敗しました",
"uploadFailedInvalidUploadDesc": "画像はPNGかJPGである必要があります",
"uploadFailedInvalidUploadDesc": "画像はPNGかJPGかWEBPである必要があります .",
"sentToUpscale": "アップスケーラーに転送しました",
"imageUploaded": "画像をアップロードしました",
"serverError": "サーバーエラー"
"serverError": "サーバーエラー",
"prunedQueue": "キューを破棄",
"workflowDeleted": "ワークフローが削除されました",
"unableToLoadStylePreset": "スタイルプリセットをロードできません",
"loadedWithWarnings": "ワークフローが警告付きでロードされました",
"parameters": "パラメーター",
"parameterSet": "パラメーターが呼び出されました",
"pasteSuccess": "{{destination}} に貼り付けました",
"imagesWillBeAddedTo": "アップロードされた画像はボード {{boardName}} のアセットに追加されます.",
"layerCopiedToClipboard": "レイヤーがクリップボードにコピーされました",
"pasteFailed": "貼り付け失敗",
"imageSavingFailed": "画像保存に失敗しました",
"importSuccessful": "インポートが成功しました",
"problemDownloadingImage": "画像をダウンロードできません",
"modelAddedSimple": "モデルがキューに追加されました",
"uploadFailedInvalidUploadDesc_withCount_other": "PNG、JPEG、または WEBP 画像は最大 1 つにする必要があります.",
"outOfMemoryErrorDesc": "現在の生成設定はシステム容量を超えています.設定を調整してもう一度お試しください.",
"parametersSet": "パラメーターが呼び出されました",
"modelImportCanceled": "モデルのインポートがキャンセルされました",
"problemRetrievingWorkflow": "ワークフローを取得した問題",
"problemUnpublishingWorkflow": "取り消されたワークフローの問題",
"parametersNotSet": "パラメーターが呼び出されていません",
"problemCopyingImage": "画像をコピーできません",
"baseModelChanged": "ベースモデルが変更されました",
"baseModelChangedCleared_other": "{{count}} 個の互換性のないサブモデルをクリア,または無効にしました",
"canceled": "処理がキャンセルされました",
"connected": "サーバーに接続されました",
"linkCopied": "リンクがコピーされました",
"unableToLoadImage": "画像をロードできません",
"unableToLoadImageMetadata": "画像のメタデータをロードできません",
"imageSaved": "画像が保存されました",
"importFailed": "インポートに失敗しました",
"invalidUpload": "無効なアップロードです",
"outOfMemoryError": "メモリ不足エラー",
"parameterSetDesc": "{{parameter}}を呼び出し",
"errorCopied": "エラーがコピーされました",
"sentToCanvas": "キャンバスに送信",
"setControlImage": "コントロール画像としてセット",
"workflowLoaded": "ワークフローがロードされました",
"unableToCopy": "コピーできません",
"unableToCopyDesc": "あなたのブラウザはクリップボードアクセスをサポートしていません.Firefoxユーザーの場合は、以下の手順で修正できる可能性があります. ",
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fillは、テキストから画像へ、または画像から画像へ変換機能と互換性がありません.これらのタスクには、他のFLUXモデルをご利用ください.",
"problemUnpublishingWorkflowDescription": "取り下げられたワークフローの問題がありました.もう一度試してください.",
"workflowUnpublished": "ワークフローが取り消されました",
"sessionRef": "セッション: {{sessionId}}",
"somethingWentWrong": "問題が発生しました",
"unableToCopyDesc_theseSteps": "これらのステップ数",
"stylePresetLoaded": "スタイルプリセットがロードされました",
"parameterNotSetDescWithMessage": "{{parameter}}: {{message}}を呼び出せません",
"problemCopyingLayer": "レイヤーをコピーできません",
"problemSavingLayer": "レイヤー保存ができません",
"setNodeField": "ノードフィールドとしてセット",
"layerSavedToAssets": "レイヤーがアセットに保存されました",
"outOfMemoryErrorDescLocal": "OOM を削減するには、<LinkComponent>低 VRAM ガイド</LinkComponent> に従ってください.",
"parameterNotSet": "パラメーターが呼び出されていません",
"addedToBoard": "{{name}} 個の資産をボードに追加しました",
"addedToUncategorized": "$t(boards.uncategorized)個のアセットがボードに追加されました",
"problemDeletingWorkflow": "ワークフローが削除された問題",
"imageNotLoadedDesc": "画像を見つけられません",
"parameterNotSetDesc": "{{parameter}}を呼び出せません",
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4oは,テキストから画像への生成と画像から画像への生成のみをサポートしています.インペインティングおよび,アウトペインティングタスクには他のモデルを使用してください."
},
"accessibility": {
"invokeProgressBar": "進捗バー",
@@ -862,7 +997,8 @@
"batchSize": "バッチサイズ",
"retryFailed": "項目のリトライに問題があります",
"cancelAllExceptCurrentQueueItemAlertDialog": "現在の項目を除くすべてのキュー項目をキャンセルすると、保留中の項目は停止しますが、進行中の項目は完了します。",
"retrySucceeded": "項目がリトライされました"
"retrySucceeded": "項目がリトライされました",
"credits": "クレジット"
},
"models": {
"noMatchingModels": "一致するモデルがありません",
@@ -1114,22 +1250,42 @@
]
},
"regionalGuidanceAndReferenceImage": {
"heading": "領域ガイダンスと領域参照画像"
"heading": "領域ガイダンスと領域参照画像",
"paragraphs": [
"領域ガイダンスの場合は,ブラシを使用して,グローバルプロンプトの要素が表示される場所をガイドします.",
"領域参照画像の場合は,ブラシを使用して特定の領域に参照画像を適用します."
]
},
"regionalReferenceImage": {
"heading": "領域参照画像"
"heading": "領域参照画像",
"paragraphs": [
"特定の領域に参照画像を適用するためのブラシ."
]
},
"paramScheduler": {
"heading": "スケジューラー"
"heading": "スケジューラー",
"paragraphs": [
"スケジューラーは生成中のプロセスで使用されます.",
"各スケジューラは、画像にノイズを反復的に追加する方法や、モデルの出力に基づいてサンプルを更新する方法を定義します."
]
},
"regionalGuidance": {
"heading": "領域ガイダンス"
"heading": "領域ガイダンス",
"paragraphs": [
"グローバルプロンプトの要素が表示される場所をガイドするブラシ."
]
},
"rasterLayer": {
"heading": "ラスターレイヤー"
"heading": "ラスターレイヤー",
"paragraphs": [
"画像生成中に使用される,キャンバスのピクセルベースのコンテンツ."
]
},
"globalReferenceImage": {
"heading": "全域参照画像"
"heading": "全域参照画像",
"paragraphs": [
"参照画像を適用して,生成全体に影響を及ぼします."
]
},
"paramUpscaleMethod": {
"heading": "アップスケール手法"
@@ -1153,7 +1309,10 @@
"heading": "スケジューラー"
},
"compositingCoherenceMode": {
"heading": "モード"
"heading": "モード",
"paragraphs": [
"新しく生成されたマスク領域と,一貫性のある画像を作成するために使用される方法."
]
},
"paramModel": {
"heading": "モデル"
@@ -1165,7 +1324,10 @@
"heading": "ステップ"
},
"ipAdapterMethod": {
"heading": "モード"
"heading": "モード",
"paragraphs": [
"モードは参照画像が生成プロセスをどのようにガイドするかを定義します."
]
},
"paramSeed": {
"heading": "シード"
@@ -1174,7 +1336,10 @@
"heading": "生成回数"
},
"controlNet": {
"heading": "ControlNet"
"heading": "ControlNet",
"paragraphs": [
"コントロールネットは生成プロセスへのガイダンスを提供し,選択したモデルに応じて制御された構成,構造,またはスタイルを持つ画像の作成に役立ちます."
]
},
"paramWidth": {
"heading": "幅"
@@ -1189,7 +1354,109 @@
"heading": "Downscale"
},
"controlNetWeight": {
"heading": "重み"
"heading": "重み",
"paragraphs": [
"レイヤーが生成プロセスにどの程度影響を与えるかを調整します",
"• 高いウエイト (.75-2): 最終結果にさらに大きな影響を及ぼします.",
"• 低いウエイト (0-.75): 最終結果への影響が小さくなります."
]
},
"paramNegativeConditioning": {
"paragraphs": [
"生成プロセスでは、ネガティブプロンプトに含まれる概念を回避します.これを使用して、出力から特定の性質やオブジェクトを除外します.",
"強制された構文と埋め込みをサポート."
],
"heading": "ネガティブプロンプト"
},
"clipSkip": {
"paragraphs": [
"スキップする CLIP モデルのレイヤー数.",
"特定のモデルは、CLIP Skip と併用するとより適しています."
],
"heading": "クリップスキップ"
},
"compositingMaskBlur": {
"heading": "マスクぼかし",
"paragraphs": [
"マスクのぼかし半径."
]
},
"paramPositiveConditioning": {
"paragraphs": [
"生成プロセスをガイドします.任意の単語やフレーズを使用できます.",
"強制とダイナミックプロンプトの構文と埋め込み."
],
"heading": "ポジティブプロンプト"
},
"compositingMaskAdjustments": {
"heading": "マスク調整",
"paragraphs": [
"マスクを調整する."
]
},
"compositingCoherenceMinDenoise": {
"paragraphs": [
"コヒーレンスモードの最小ノイズ除去強度",
"インペインティングまたはアウトペインティング時のコヒーレンス領域の最小ノイズ除去強度"
],
"heading": "最小ノイズ除去"
},
"compositingCoherencePass": {
"paragraphs": [
"2 回目のノイズ除去は,インペイント/アウトペイントされた画像の合成に役立ちます."
],
"heading": "コヒーレンスパス"
},
"controlNetBeginEnd": {
"paragraphs": [
"この設定は,ノイズ除去 (生成) プロセスのどの部分にこのレイヤーからのガイダンスが組み込まれるかを決定します.",
"• 開始ステップ (%): 生成プロセス中にこのレイヤーからのガイダンスの適用を開始するタイミングを指定します.",
"• 終了ステップ (%): このレイヤーのガイダンスの適用を停止し,モデルやその他の設定からの一般的なガイダンスを元に戻すタイミングを指定します."
],
"heading": "開始/終了ステップの割合"
},
"compositingCoherenceEdgeSize": {
"heading": "エッジサイズ",
"paragraphs": [
"コヒーレンスパスのエッジサイズ."
]
},
"compositingBlurMethod": {
"paragraphs": [
"マスクされた領域に適用されるぼかし方法."
],
"heading": "ぼかし方法"
},
"inpainting": {
"heading": "インペインティング",
"paragraphs": [
"ノイズ除去の強度に応じて,変更する領域を制御します."
]
},
"dynamicPrompts": {
"heading": "ダイナミックプロンプト",
"paragraphs": [
"ダイナミック プロンプトは,単一のプロンプトを複数のプロンプトに解析します.",
"基本的な構文は「{赤|緑|青}のボール」です.これにより,「赤いボール」「緑のボール」「青いボール」という3つのプロンプトが生成されます."
]
},
"controlNetResizeMode": {
"heading": "リサイズモード",
"paragraphs": [
"コントロールアダプタの入力画像サイズを出力生成サイズに適合させるメソッド."
]
},
"controlNetProcessor": {
"heading": "プロセッサー",
"paragraphs": [
"入力画像を処理する生成プロセスをガイドするメソッド.プロセッサによって,生成される画像に異なる効果やスタイルが与えられます。"
]
},
"controlNetControlMode": {
"heading": "コントロールモード",
"paragraphs": [
"プロンプトまたは コントロールネットのいずれかを重視します."
]
}
},
"accordions": {
@@ -1340,7 +1607,18 @@
"scheduler": "スケジューラー",
"loading": "ロード中...",
"steps": "ステップ",
"refiner": "Refiner"
"refiner": "Refiner",
"negStylePrompt": "ネガティブスタイルプロンプト",
"noModelsAvailable": "利用できるモデルがありません",
"posStylePrompt": "ポジティブスタイルプロンプト",
"cfgScale": "CFGスケール",
"concatPromptStyle": "リンキングプロンプトとスタイル",
"freePromptStyle": "手動スタイルプロンプト",
"posAestheticScore": "ポジティブ美的スコア",
"refinerSteps": "リファイナーステップ",
"refinerStart": "リファイナースタート",
"refinermodel": "リファイナーモデル",
"negAestheticScore": "ネガティブ美的スコア"
},
"modelCache": {
"clear": "モデルキャッシュを消去",
@@ -1370,5 +1648,20 @@
"fatal": "Fatal",
"warn": "Warn"
}
},
"dynamicPrompts": {
"promptsPreview": "プロンプトプレビュー",
"seedBehaviour": {
"label": "シードの挙動",
"perPromptLabel": "画像ごとのシード",
"perIterationLabel": "いてレーションごとのシード",
"perPromptDesc": "それぞれの画像に足して別のシードを使う",
"perIterationDesc": "それぞれのいてレーションに別のシードを使う"
},
"showDynamicPrompts": "ダイナミックプロンプトを表示する",
"promptsToGenerate": "生成するプロンプト",
"dynamicPrompts": "ダイナミックプロンプト",
"loading": "ダイナミックプロンプトを生成...",
"maxPrompts": "最大プロンプト"
}
}

View File

@@ -240,7 +240,15 @@
"error_withCount_other": "{{count}} lỗi",
"value": "Giá Trị",
"label": "Nhãn Tên",
"systemInformation": "Thông Tin Hệ Thống"
"systemInformation": "Thông Tin Hệ Thống",
"model_withCount_other": "{{count}} model",
"noOptions": "Không Có Lựa Chọn",
"noMatches": "Không Có Mục Phù Hợp",
"search": "Tìm Kiếm",
"clear": "Dọn Dẹp",
"compactView": "Chế Độ Xem Gọn",
"fullView": "Chế Độ Xem Đầy Đủ",
"options_withCount_other": "{{count}} thiết lập"
},
"prompt": {
"addPromptTrigger": "Thêm Prompt Trigger",
@@ -321,7 +329,8 @@
"confirm": "Đồng Ý",
"retrySucceeded": "Mục Đã Thử Lại",
"retryFailed": "Có Vấn Đề Khi Thử Lại Mục",
"retryItem": "Thử Lại Mục"
"retryItem": "Thử Lại Mục",
"credits": "Nguồn"
},
"hotkeys": {
"canvas": {
@@ -775,7 +784,14 @@
"fluxRedux": "FLUX Redux",
"sigLip": "SigLIP",
"llavaOnevision": "LLaVA OneVision",
"fileSize": "Kích Thước Tệp"
"fileSize": "Kích Thước Tệp",
"filterModels": "Lọc Model",
"modelPickerFallbackNoModelsInstalled2": "Nhấp vào <LinkComponent>Trình Quản Lý Model</LinkComponent> để tải.",
"modelPickerFallbackNoModelsInstalled": "Không Có Sẵn Model.",
"manageModels": "Quản Lý Model",
"hfTokenReset": "Làm Mới HF Token",
"relatedModels": "Model Liên Quan",
"showOnlyRelatedModels": "Liên Quan"
},
"metadata": {
"guidance": "Hướng Dẫn",
@@ -1518,7 +1534,8 @@
"modelIncompatibleBboxWidth": "Chiều rộng hộp giới hạn là {{width}} nhưng {{model}} yêu cầu bội số của {{multiple}}",
"modelIncompatibleBboxHeight": "Chiều dài hộp giới hạn là {{height}} nhưng {{model}} yêu cầu bội số của {{multiple}}",
"modelIncompatibleScaledBboxHeight": "Chiều dài hộp giới hạn theo tỉ lệ là {{height}} nhưng {{model}} yêu cầu bội số của {{multiple}}",
"modelIncompatibleScaledBboxWidth": "Chiều rộng hộp giới hạn theo tỉ lệ là {{width}} nhưng {{model}} yêu cầu bội số của {{multiple}}"
"modelIncompatibleScaledBboxWidth": "Chiều rộng hộp giới hạn theo tỉ lệ là {{width}} nhưng {{model}} yêu cầu bội số của {{multiple}}",
"modelDisabledForTrial": "Tạo sinh với {{modelName}} là không thể với tài khoản trial. Vào phần thiết lập tài khoản để nâng cấp."
},
"cfgScale": "Thang CFG",
"useSeed": "Dùng Hạt Giống",
@@ -1581,7 +1598,8 @@
"usePrompt": "Dùng Lệnh",
"upscaling": "Upscale",
"tileSize": "Kích Thước Khối",
"disabledNoRasterContent": "Đã Tắt (Không Có Nội Dung Dạng Raster)"
"disabledNoRasterContent": "Đã Tắt (Không Có Nội Dung Dạng Raster)",
"modelDisabledForTrial": "Tạo sinh với {{modelName}} là không thể với tài khoản trial. Vào phần <LinkComponent>thiết lập tài khoản</LinkComponent> để nâng cấp."
},
"dynamicPrompts": {
"seedBehaviour": {
@@ -1699,12 +1717,16 @@
"fitBboxToLayers": "Xếp Vừa Hộp Giới Hạn Vào Layer",
"ipAdapterMethod": {
"full": "Phong Cách Và Thành Phần",
"style": "Chỉ Lấy Phong Cách",
"style": "Phong Cách (Đơn Giản)",
"composition": "Chỉ Lấy Thành Phần",
"ipAdapterMethod": "Cách Thức",
"compositionDesc": "Áp dụng cách trình bày và bỏ qua phong cách mẫu.",
"fullDesc": "Áp dụng phong cách trực quan (màu, cấu tạo) & thành phần (cách trình bày).",
"styleDesc": "Áp dụng phong cách trực quan (màu, cấu tạo) và bỏ qua cách trình bày."
"styleDesc": "Áp dụng phong cách trực quan (màu, cấu tạo) và bỏ qua cách trình bày. Tên trước đây là Chỉ Lấy Phong Cách.",
"styleStrong": "Phong Cách (Mạnh Mẽ)",
"styleStrongDesc": "Áp dụng cách trình bày mạnh mẽ, với một chút giảm nhẹ ảnh hưởng lên thành phần.",
"stylePrecise": "Phong Cách (Chính Xác)",
"stylePreciseDesc": "Áp dụng cách trình bày chính xác, loại bỏ các chủ thể ảnh hưởng."
},
"deletePrompt": "Xoá Lệnh",
"rasterLayer": "Layer Dạng Raster",
@@ -2226,7 +2248,8 @@
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill không tương tích với Từ Ngữ Sang Hình Ảnh và Hình Ảnh Sang Hình Ảnh. Dùng model FLUX khác cho các tính năng này.",
"problemUnpublishingWorkflowDescription": "Có vấn đề khi ngừng đăng tải workflow. Vui lòng thử lại sau.",
"workflowUnpublished": "Workflow Đã Được Ngừng Đăng Tải",
"problemUnpublishingWorkflow": "Có Vấn Đề Khi Ngừng Đăng Tải Workflow"
"problemUnpublishingWorkflow": "Có Vấn Đề Khi Ngừng Đăng Tải Workflow",
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o chỉ hỗ trợ Từ Ngữ Sang Hình Ảnh và Hình Ảnh Sang Hình Ảnh. Hãy dùng model khác cho các tác vụ Inpaint và Outpaint."
},
"ui": {
"tabs": {
@@ -2408,8 +2431,8 @@
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
"items": [
"Workflow: Hỗ trợ xâu ký tự thả xuống tùy chỉnh trong Trình Tạo Vùng Nhập.",
"FLUX: Hỗ trợ FLUX Fill trong Workflow và Canvas."
"Nvidia 50xx GPUs: Invoke sử dụng PyTorch 2.7.0, thứ tối quan trọng cho những GPU trên.",
"Mối Quan Hệ Model: Kết nối LoRA với model chính, và LoRA đó sẽ được hiển thị đầu danh sách."
]
},
"upsell": {

View File

@@ -11,6 +11,7 @@ import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildC
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
@@ -54,6 +55,8 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
return await buildCogView4Graph(state, manager);
case 'imagen3':
return await buildImagen3Graph(state, manager);
case 'imagen4':
return await buildImagen4Graph(state, manager);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, manager);
default:

View File

@@ -29,7 +29,7 @@ export type AppFeature =
| 'hfToken'
| 'retryQueueItem'
| 'cancelAndClearAll'
| 'chatGPT4oModels';
| 'chatGPT4oHigh';
/**
* A disable-able Stable Diffusion feature
*/

View File

@@ -83,7 +83,7 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
}
} else {
let imageDTOs: ImageDTO[] = [];
if (isClientSideUploadEnabled) {
if (isClientSideUploadEnabled && files.length > 1) {
imageDTOs = await Promise.all(files.map((file, i) => clientSideUpload(file, i)));
} else {
imageDTOs = await uploadImages(

View File

@@ -0,0 +1,92 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import type { GroupBase } from 'chakra-react-select';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { useTranslation } from 'react-i18next';
import type { AnyModelConfig } from 'services/api/types';
import { useGroupedModelCombobox } from './useGroupedModelCombobox';
import { useRelatedModelKeys } from './useRelatedModelKeys';
import { useSelectedModelKeys } from './useSelectedModelKeys';
type UseRelatedGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelConfigs: T[];
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
isLoading?: boolean;
groupByType?: boolean;
};
// Custom hook to overlay the grouped model combobox with related models on top!
// Cleaner than hooking into useGroupedModelCombobox with a flag to enable/disable the related models
// Also allows for related models to be shown conditionally with some pretty simple logic if it ends up as a config flag.
type UseRelatedGroupedModelComboboxReturn = {
value: ComboboxOption | undefined | null;
options: GroupBase<ComboboxOption>[];
onChange: ComboboxOnChange;
placeholder: string;
noOptionsMessage: () => string;
};
export function useRelatedGroupedModelCombobox<T extends AnyModelConfig>({
modelConfigs,
selectedModel,
onChange,
isLoading = false,
getIsDisabled,
groupByType,
}: UseRelatedGroupedModelComboboxArg<T>): UseRelatedGroupedModelComboboxReturn {
const { t } = useTranslation();
const selectedKeys = useSelectedModelKeys();
const relatedKeys = useRelatedModelKeys(selectedKeys);
// Base grouped options
const base = useGroupedModelCombobox({
modelConfigs,
selectedModel,
onChange,
getIsDisabled,
isLoading,
groupByType,
});
// If no related models selected, just return base
if (relatedKeys.size === 0) {
return base;
}
const relatedOptions: ComboboxOption[] = [];
const updatedGroups: GroupBase<ComboboxOption>[] = [];
for (const group of base.options) {
const remainingOptions: ComboboxOption[] = [];
for (const option of group.options) {
if (relatedKeys.has(option.value)) {
relatedOptions.push({ ...option, label: `* ${option.label}` });
} else {
remainingOptions.push(option);
}
}
if (remainingOptions.length > 0) {
updatedGroups.push({
label: group.label,
options: remainingOptions,
});
}
}
const finalOptions: GroupBase<ComboboxOption>[] =
relatedOptions.length > 0
? [{ label: t('modelManager.relatedModels'), options: relatedOptions }, ...updatedGroups]
: updatedGroups;
return {
...base,
options: finalOptions,
};
}

View File

@@ -0,0 +1,14 @@
import { useMemo } from 'react';
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
/**
* Fetches related model keys for a given set of selected model keys.
* Returns a Set<string> for fast lookup.
*/
export const useRelatedModelKeys = (selectedKeys: Set<string>) => {
const { data: related = [] } = useGetRelatedModelIdsBatchQuery([...selectedKeys], {
skip: selectedKeys.size === 0,
});
return useMemo(() => new Set(related), [related]);
};

View File

@@ -0,0 +1,34 @@
import { useAppSelector } from 'app/store/storeHooks';
/**
* Gathers all currently selected model keys from parameters and loras.
* This includes the main model, VAE, refiner model, controlnet, and loras.
*/
export const useSelectedModelKeys = () => {
return useAppSelector((state) => {
const keys = new Set<string>();
const main = state.params.model;
const vae = state.params.vae;
const refiner = state.params.refinerModel;
const controlnet = state.params.controlLora;
const loras = state.loras.loras.map((l) => l.model);
if (main) {
keys.add(main.key);
}
if (vae) {
keys.add(vae.key);
}
if (refiner) {
keys.add(refiner.key);
}
if (controlnet) {
keys.add(controlnet.key);
}
for (const lora of loras) {
keys.add(lora.key);
}
return keys;
});
};

View File

@@ -30,6 +30,16 @@ export const IPAdapterMethod = memo(({ method, onChange }: Props) => {
value: 'style',
description: shouldShowModelDescriptions ? t('controlLayers.ipAdapterMethod.styleDesc') : undefined,
},
{
label: t('controlLayers.ipAdapterMethod.styleStrong'),
value: 'style_strong',
description: shouldShowModelDescriptions ? t('controlLayers.ipAdapterMethod.styleStrongDesc') : undefined,
},
{
label: t('controlLayers.ipAdapterMethod.stylePrecise'),
value: 'style_precise',
description: shouldShowModelDescriptions ? t('controlLayers.ipAdapterMethod.stylePreciseDesc') : undefined,
},
{
label: t('controlLayers.ipAdapterMethod.composition'),
value: 'composition',

View File

@@ -3,6 +3,7 @@ import {
selectIsChatGTP4o,
selectIsCogView4,
selectIsImagen3,
selectIsImagen4,
selectIsSD3,
} from 'features/controlLayers/store/paramsSlice';
import type { CanvasEntityType } from 'features/controlLayers/store/types';
@@ -14,24 +15,25 @@ export const useIsEntityTypeEnabled = (entityType: CanvasEntityType) => {
const isSD3 = useAppSelector(selectIsSD3);
const isCogView4 = useAppSelector(selectIsCogView4);
const isImagen3 = useAppSelector(selectIsImagen3);
const isImagen4 = useAppSelector(selectIsImagen4);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isEntityTypeEnabled = useMemo<boolean>(() => {
switch (entityType) {
case 'reference_image':
return !isSD3 && !isCogView4 && !isImagen3;
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4;
case 'regional_guidance':
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isChatGPT4o;
case 'control_layer':
return !isSD3 && !isCogView4 && !isImagen3 && !isChatGPT4o;
return !isSD3 && !isCogView4 && !isImagen3 && !isImagen4 && !isChatGPT4o;
case 'inpaint_mask':
return !isImagen3 && !isChatGPT4o;
return !isImagen3 && !isImagen4 && !isChatGPT4o;
case 'raster_layer':
return !isImagen3 && !isChatGPT4o;
return !isImagen3 && !isImagen4 && !isChatGPT4o;
default:
assert<Equals<typeof entityType, never>>(false);
}
}, [entityType, isSD3, isCogView4, isImagen3, isChatGPT4o]);
}, [entityType, isSD3, isCogView4, isImagen3, isImagen4, isChatGPT4o]);
return isEntityTypeEnabled;
};

View File

@@ -8,6 +8,7 @@ import { selectModel } from 'features/controlLayers/store/paramsSlice';
import { selectBbox } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect, Tool } from 'features/controlLayers/store/types';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { API_BASE_MODELS } from 'features/parameters/types/constants';
import Konva from 'konva';
import { noop } from 'lodash-es';
import { atom } from 'nanostores';
@@ -235,7 +236,7 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
if (tool !== 'bbox') {
return NO_ANCHORS;
}
if (model?.base === 'imagen3' || model?.base === 'chatgpt-4o') {
if (model?.base && API_BASE_MODELS.includes(model.base)) {
// The bbox is not resizable in these modes
return NO_ANCHORS;
}

View File

@@ -32,6 +32,7 @@ import {
import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify';
import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
import { API_BASE_MODELS } from 'features/parameters/types/constants';
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect } from 'konva/lib/types';
import { isEqual, merge } from 'lodash-es';
@@ -68,7 +69,7 @@ import type {
IPMethodV2,
T2IAdapterConfig,
} from './types';
import { getEntityIdentifier, isChatGPT4oAspectRatioID, isImagen3AspectRatioID, isRenderableEntity } from './types';
import { getEntityIdentifier, isChatGPT4oAspectRatioID, isImagenAspectRatioID, isRenderableEntity } from './types';
import {
converters,
getControlLayerState,
@@ -1236,7 +1237,10 @@ export const canvasSlice = createSlice({
state.bbox.aspectRatio.id = id;
if (id === 'Free') {
state.bbox.aspectRatio.isLocked = false;
} else if (state.bbox.modelBase === 'imagen3' && isImagen3AspectRatioID(id)) {
} else if (
(state.bbox.modelBase === 'imagen3' || state.bbox.modelBase === 'imagen4') &&
isImagenAspectRatioID(id)
) {
// Imagen3 has specific output sizes that are not exactly the same as the aspect ratio. Need special handling.
if (id === '16:9') {
state.bbox.rect.width = 1408;
@@ -1742,7 +1746,7 @@ export const canvasSlice = createSlice({
const base = model?.base;
if (isMainModelBase(base) && state.bbox.modelBase !== base) {
state.bbox.modelBase = base;
if (base === 'imagen3' || base === 'chatgpt-4o') {
if (API_BASE_MODELS.includes(base)) {
state.bbox.aspectRatio.isLocked = true;
state.bbox.aspectRatio.value = 1;
state.bbox.aspectRatio.id = '1:1';
@@ -1881,7 +1885,7 @@ export const canvasPersistConfig: PersistConfig<CanvasState> = {
};
const syncScaledSize = (state: CanvasState) => {
if (state.bbox.modelBase === 'imagen3' || state.bbox.modelBase === 'chatgpt-4o') {
if (API_BASE_MODELS.includes(state.bbox.modelBase)) {
// Imagen3 has fixed sizes. Scaled bbox is not supported.
return;
}

View File

@@ -381,6 +381,7 @@ export const selectIsFLUX = createParamsSelector((params) => params.model?.base
export const selectIsSD3 = createParamsSelector((params) => params.model?.base === 'sd-3');
export const selectIsCogView4 = createParamsSelector((params) => params.model?.base === 'cogview4');
export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3');
export const selectIsImagen4 = createParamsSelector((params) => params.model?.base === 'imagen4');
export const selectIsChatGTP4o = createParamsSelector((params) => params.model?.base === 'chatgpt-4o');
export const selectModel = createParamsSelector((params) => params.model);

View File

@@ -50,7 +50,7 @@ const zCLIPVisionModelV2 = z.enum(['ViT-H', 'ViT-G', 'ViT-L']);
export type CLIPVisionModelV2 = z.infer<typeof zCLIPVisionModelV2>;
export const isCLIPVisionModelV2 = (v: unknown): v is CLIPVisionModelV2 => zCLIPVisionModelV2.safeParse(v).success;
const zIPMethodV2 = z.enum(['full', 'style', 'composition']);
const zIPMethodV2 = z.enum(['full', 'style', 'composition', 'style_strong', 'style_precise']);
export type IPMethodV2 = z.infer<typeof zIPMethodV2>;
export const isIPMethodV2 = (v: unknown): v is IPMethodV2 => zIPMethodV2.safeParse(v).success;
@@ -406,7 +406,7 @@ export type StagingAreaImage = {
export const zAspectRatioID = z.enum(['Free', '16:9', '3:2', '4:3', '1:1', '3:4', '2:3', '9:16']);
export const zImagen3AspectRatioID = z.enum(['16:9', '4:3', '1:1', '3:4', '9:16']);
export const isImagen3AspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
export const isImagenAspectRatioID = (v: unknown): v is z.infer<typeof zImagen3AspectRatioID> =>
zImagen3AspectRatioID.safeParse(v).success;
export const zChatGPT4oAspectRatioID = z.enum(['3:2', '1:1', '2:3']);

View File

@@ -109,7 +109,7 @@ export const FullscreenDropzone = memo(() => {
const autoAddBoardId = selectAutoAddBoardId(getState());
if (isClientSideUploadEnabled) {
if (isClientSideUploadEnabled && files.length > 1) {
for (const [i, file] of files.entries()) {
await clientSideUpload(file, i);
}

View File

@@ -3,7 +3,7 @@ import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useRelatedGroupedModelCombobox } from 'common/hooks/useRelatedGroupedModelCombobox';
import { loraAdded, selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback, useMemo } from 'react';
@@ -37,7 +37,7 @@ const LoRASelect = () => {
[dispatch]
);
const { options, onChange } = useGroupedModelCombobox({
const { options, onChange } = useRelatedGroupedModelCombobox({
modelConfigs,
getIsDisabled,
onChange: _onChange,

View File

@@ -17,6 +17,7 @@ export const BASE_COLOR_MAP: Record<BaseModelType, string> = {
flux: 'gold',
cogview4: 'red',
imagen3: 'pink',
imagen4: 'pink',
'chatgpt-4o': 'pink',
};

View File

@@ -11,6 +11,7 @@ import type { AnyModelConfig } from 'services/api/types';
import { MainModelDefaultSettings } from './MainModelDefaultSettings/MainModelDefaultSettings';
import { ModelAttrView } from './ModelAttrView';
import { RelatedModels } from './RelatedModels';
type Props = {
modelConfig: AnyModelConfig;
@@ -83,6 +84,9 @@ export const ModelView = memo(({ modelConfig }: Props) => {
)}
</Box>
)}
<Box maxH="200px" overflowY="auto" layerStyle="second" borderRadius="base" p={4}>
<RelatedModels modelConfig={modelConfig} />
</Box>
</Flex>
</Flex>
);

View File

@@ -0,0 +1,351 @@
/**
* RelatedModels.tsx
*
* Panel for managing and displaying model-to-model relationships.
*
* Allows adding/removing bidirectional links between models, organized visually
* with color-coded tags, dividers between types, and sorted dropdown selection.
*/
import {
Box,
Button,
Combobox,
Divider,
Flex,
FormControl,
FormErrorMessage,
FormLabel,
Tag,
TagCloseButton,
TagLabel,
Tooltip,
} from '@invoke-ai/ui-library';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
import {
useAddModelRelationshipMutation,
useGetRelatedModelIdsQuery,
useRemoveModelRelationshipMutation,
} from 'services/api/endpoints/modelRelationships';
import { useGetModelConfigsQuery } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
type Props = {
modelConfig: AnyModelConfig;
};
type ModelGroup = {
type: string;
label: string;
color: string;
models: AnyModelConfig[];
};
// Determines if two models are compatible for relationship linking based on their base type.
//
// Models with a base of 'any' are considered universally compatible.
// This is a known flaw: 'any'-based links may allow relationships that are
// meaningless in practice and could bloat the database over time.
//
// TODO: In the future, refine this logic to more strictly validate
// relationships based on model types or actual usage patterns.
const isBaseCompatible = (a: AnyModelConfig, b: AnyModelConfig): boolean => {
if (a.base === 'any' || b.base === 'any') {
return true;
}
return a.base === b.base;
};
// Drying out and setting up for potential export
// Defines custom tag colors for model types in the UI.
//
// The default UI color scheme (mostly grey and orange) felt too flat,
// so this mapping provides a slightly more expressive color flow.
//
// Note: This is purely aesthetic. Safe to remove if project preferences change.
const getModelTagColor = (type: string): string => {
switch (type) {
case 'main':
case 'checkpoint':
return 'orange';
case 'lora':
case 'lycoris':
return 'purple';
case 'embedding':
case 'embedding_file':
return 'teal';
case 'vae':
return 'blue';
case 'controlnet':
case 'ip_adapter':
case 't2i_adapter':
return 'cyan';
case 'onnx':
case 'bnb_quantized_int8b':
case 'bnb_quantized_nf4b':
case 'gguf_quantized':
return 'pink';
case 't5_encoder':
case 'clip_embed':
case 'clip_vision':
case 'siglip':
return 'green';
default:
return 'base';
}
};
// Extracts model type from a label string (e.g., 'Base/LoRA' → 'lora')
const getTypeFromLabel = (label: string): string => label.split('/')[1]?.trim().toLowerCase() || '';
export const RelatedModels = memo(({ modelConfig }: Props) => {
const { t } = useTranslation();
const [addModelRelationship, { isLoading: isAdding }] = useAddModelRelationshipMutation();
const [removeModelRelationship, { isLoading: isRemoving }] = useRemoveModelRelationshipMutation();
const isLoading = isAdding || isRemoving;
const [selectedKey, setSelectedKey] = useState('');
const { data: modelConfigs } = useGetModelConfigsQuery();
const { data: relatedModels = [] } = useGetRelatedModelIdsQuery(modelConfig.key);
const relatedIDs = useMemo(() => new Set(relatedModels), [relatedModels]);
// Defines model types to prioritize first in UI sorting.
// Types not listed here will appear afterward in default order.
const MODEL_TYPE_PRIORITY = useMemo(() => ['main', 'lora'], []);
// Defines disallowed connection types.
const DISALLOWED_RELATIONSHIPS = useMemo(
() =>
new Set([
'main|main',
'vae|vae',
'controlnet|controlnet',
'clip_vision|clip_vision',
'control_lora|control_lora',
'clip_embed|clip_embed',
'spandrel_image_to_image|spandrel_image_to_image',
'siglip|siglip',
'flux_redux|flux_redux',
]),
[]
);
// Drying out sorting
const prioritySort = useCallback(
(a: string, b: string): number => {
const aIndex = MODEL_TYPE_PRIORITY.indexOf(a);
const bIndex = MODEL_TYPE_PRIORITY.indexOf(b);
const aScore = aIndex === -1 ? 99 : aIndex;
const bScore = bIndex === -1 ? 99 : bIndex;
return aScore - bScore;
},
[MODEL_TYPE_PRIORITY]
);
//Get all modelConfigs that are not already related to the current model.
const availableModels = useMemo(() => {
if (!modelConfigs) {
return [];
}
const isDisallowedRelationship = (a: string, b: string): boolean =>
DISALLOWED_RELATIONSHIPS.has(`${a}|${b}`) || DISALLOWED_RELATIONSHIPS.has(`${b}|${a}`);
return Object.values(modelConfigs.entities).filter(
(m): m is AnyModelConfig =>
!!m &&
m.key !== modelConfig.key &&
!relatedIDs.has(m.key) &&
isBaseCompatible(modelConfig, m) &&
!isDisallowedRelationship(modelConfig.type, m.type)
);
}, [modelConfigs, modelConfig, relatedIDs, DISALLOWED_RELATIONSHIPS]);
// Tracks validation errors for current input (e.g., duplicate key or no selection).
const errors = useMemo(() => {
const errs: string[] = [];
if (!selectedKey) {
return errs;
}
if (relatedIDs.has(selectedKey)) {
errs.push('Item already promoted');
}
return errs;
}, [selectedKey, relatedIDs]);
// Handles linking a selected model to the current one via API.
const handleAdd = useCallback(async () => {
const target = availableModels.find((m) => m.key === selectedKey);
if (!target) {
return;
}
setSelectedKey('');
await addModelRelationship({ model_key_1: modelConfig.key, model_key_2: target.key });
}, [modelConfig, availableModels, addModelRelationship, selectedKey]);
const {
options,
onChange: comboboxOnChange,
placeholder,
noOptionsMessage,
} = useGroupedModelCombobox({
modelConfigs: availableModels,
selectedModel: null,
onChange: (model) => {
if (!model) {
return;
}
setSelectedKey(model.key);
},
groupByType: true,
});
// Finds the selected model's combobox option to control current dropdown state.
const selectedOption = useMemo(() => {
return options.flatMap((group) => group.options).find((o) => o.value === selectedKey) ?? null;
}, [selectedKey, options]);
const sortedOptions = useMemo(() => {
return [...options].sort((a, b) => prioritySort(getTypeFromLabel(a.label ?? ''), getTypeFromLabel(b.label ?? '')));
}, [options, prioritySort]);
const groupedModelConfigs = useMemo(() => {
if (!modelConfigs) {
return [];
}
const models = [...relatedModels].map((id) => modelConfigs.entities[id]).filter((m): m is AnyModelConfig => !!m);
models.sort((a, b) => prioritySort(a.type, b.type) || a.type.localeCompare(b.type) || a.name.localeCompare(b.name));
const groupsMap = new Map<string, ModelGroup>();
for (const model of models) {
if (!groupsMap.has(model.type)) {
groupsMap.set(model.type, {
type: model.type,
label: model.type.replace(/_/g, ' ').replace(/\b\w/g, (c) => c.toUpperCase()),
color: getModelTagColor(model.type),
models: [],
});
}
groupsMap.get(model.type)!.models.push(model);
}
return Array.from(groupsMap.values());
}, [modelConfigs, relatedModels, prioritySort]);
const removeHandlers = useMemo(() => {
const map = new Map<string, () => void>();
if (!modelConfigs) {
return map;
}
for (const group of groupedModelConfigs) {
for (const model of group.models) {
map.set(model.key, () => {
const target = modelConfigs.entities[model.key];
if (!target) {
return;
}
removeModelRelationship({
model_key_1: modelConfig.key,
model_key_2: model.key,
}).unwrap();
});
}
}
return map;
}, [groupedModelConfigs, modelConfig.key, modelConfigs, removeModelRelationship]);
return (
<Flex direction="column" gap="5" w="full">
<FormLabel>{t('modelManager.relatedModels')}</FormLabel>
<FormControl isInvalid={errors.length > 0}>
<Flex gap="3" alignItems="center" w="full">
<Combobox
value={selectedOption}
placeholder={placeholder}
options={sortedOptions}
onChange={comboboxOnChange}
noOptionsMessage={noOptionsMessage}
/>
<Button
leftIcon={<PiPlusBold />}
size="sm"
onClick={handleAdd}
isDisabled={!selectedKey || errors.length > 0}
isLoading={isLoading}
>
{t('common.add')}
</Button>
</Flex>
{errors.map((error) => (
<FormErrorMessage key={error}>{error}</FormErrorMessage>
))}
</FormControl>
<Box>
<Flex gap="2" flexWrap="wrap">
{groupedModelConfigs.map((group, i) => {
const withDivider = i < groupedModelConfigs.length - 1;
return (
<Box key={group.type} mb={4}>
<ModelTagGroup group={group} isLoading={isLoading} removeHandlers={removeHandlers} />
{withDivider && <Divider my={4} opacity={0.3} />}
</Box>
);
})}
</Flex>
</Box>
</Flex>
);
});
const ModelTag = ({
model,
onRemove,
isLoading,
}: {
model: AnyModelConfig;
onRemove: () => void;
isLoading: boolean;
}) => {
return (
<Tag py={2} px={4} bg={`${getModelTagColor(model.type)}.700`}>
<Tooltip label={`${model.type}: ${model.name}`} hasArrow>
<TagLabel maxWidth="50px" overflow="hidden" textOverflow="ellipsis" whiteSpace="nowrap">
{model.name}
</TagLabel>
</Tooltip>
<TagCloseButton onClick={onRemove} isDisabled={isLoading} />
</Tag>
);
};
const ModelTagGroup = ({
group,
isLoading,
removeHandlers,
}: {
group: ModelGroup;
isLoading: boolean;
removeHandlers: Map<string, () => void>;
}) => {
return (
<Flex gap="2" flexWrap="wrap" alignItems="center">
{group.models.map((model) => (
<ModelTag key={model.key} model={model} onRemove={removeHandlers.get(model.key)!} isLoading={isLoading} />
))}
</Flex>
);
};
RelatedModels.displayName = 'RelatedModels';

View File

@@ -7,6 +7,7 @@ import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flo
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import { ImageGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageGeneratorFieldComponent';
import Imagen3ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen3ModelFieldInputComponent';
import Imagen4ModelFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/Imagen4ModelFieldInputComponent';
import { IntegerFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerFieldCollectionInputComponent';
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
@@ -63,6 +64,8 @@ import {
isImageGeneratorFieldInputTemplate,
isImagen3ModelFieldInputInstance,
isImagen3ModelFieldInputTemplate,
isImagen4ModelFieldInputInstance,
isImagen4ModelFieldInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isIntegerFieldInputInstance,
@@ -407,6 +410,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <Imagen3ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isImagen4ModelFieldInputTemplate(template)) {
if (!isImagen4ModelFieldInputInstance(field)) {
return null;
}
return <Imagen4ModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isChatGPT4oModelFieldInputTemplate(template)) {
if (!isChatGPT4oModelFieldInputInstance(field)) {
return null;

View File

@@ -0,0 +1,46 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { ModelFieldCombobox } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelFieldCombobox';
import { fieldImagen4ModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useImagen4Models } from 'services/api/hooks/modelsByType';
import type { ApiModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const Imagen4ModelFieldInputComponent = (
props: FieldComponentProps<Imagen4ModelFieldInputInstance, Imagen4ModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useImagen4Models();
const onChange = useCallback(
(value: ApiModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldImagen4ModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<ModelFieldCombobox
value={field.value}
modelConfigs={modelConfigs}
isLoadingConfigs={isLoading}
onChange={onChange}
required={props.fieldTemplate.required}
/>
);
};
export default memo(Imagen4ModelFieldInputComponent);

View File

@@ -123,6 +123,7 @@ const NODE_TYPE_PUBLISH_DENYLIST = [
'metadata_to_t2i_adapters',
'google_imagen3_generate',
'google_imagen3_edit',
'google_imagen4_generate',
'chatgpt_create_image',
'chatgpt_edit_image',
];

View File

@@ -40,6 +40,7 @@ import type {
ImageFieldValue,
ImageGeneratorFieldValue,
Imagen3ModelFieldValue,
Imagen4ModelFieldValue,
IntegerFieldCollectionValue,
IntegerFieldValue,
IntegerGeneratorFieldValue,
@@ -80,6 +81,7 @@ import {
zImageFieldValue,
zImageGeneratorFieldValue,
zImagen3ModelFieldValue,
zImagen4ModelFieldValue,
zIntegerFieldCollectionValue,
zIntegerFieldValue,
zIntegerGeneratorFieldValue,
@@ -519,6 +521,9 @@ export const nodesSlice = createSlice({
fieldImagen3ModelValueChanged: (state, action: FieldValueAction<Imagen3ModelFieldValue>) => {
fieldValueReducer(state, action, zImagen3ModelFieldValue);
},
fieldImagen4ModelValueChanged: (state, action: FieldValueAction<Imagen4ModelFieldValue>) => {
fieldValueReducer(state, action, zImagen4ModelFieldValue);
},
fieldChatGPT4oModelValueChanged: (state, action: FieldValueAction<ChatGPT4oModelFieldValue>) => {
fieldValueReducer(state, action, zChatGPT4oModelFieldValue);
},
@@ -690,6 +695,7 @@ export const {
fieldSigLipModelValueChanged,
fieldFluxReduxModelValueChanged,
fieldImagen3ModelValueChanged,
fieldImagen4ModelValueChanged,
fieldChatGPT4oModelValueChanged,
fieldFloatGeneratorValueChanged,
fieldIntegerGeneratorValueChanged,

View File

@@ -76,10 +76,21 @@ const zBaseModel = z.enum([
'flux',
'cogview4',
'imagen3',
'imagen4',
'chatgpt-4o',
]);
export type BaseModelType = z.infer<typeof zBaseModel>;
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'cogview4', 'imagen3', 'chatgpt-4o']);
export const zMainModelBase = z.enum([
'sd-1',
'sd-2',
'sd-3',
'sdxl',
'flux',
'cogview4',
'imagen3',
'imagen4',
'chatgpt-4o',
]);
export type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
const zModelType = z.enum([
@@ -147,7 +158,7 @@ export const zIPAdapterField = z.object({
image: zImageField,
ip_adapter_model: zModelIdentifierField,
weight: z.number(),
method: z.enum(['full', 'style', 'composition']),
method: z.enum(['full', 'style', 'composition', 'style_strong', 'style_precise']),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
});

View File

@@ -252,6 +252,10 @@ const zImagen3ModelFieldType = zFieldTypeBase.extend({
name: z.literal('Imagen3ModelField'),
originalType: zStatelessFieldType.optional(),
});
const zImagen4ModelFieldType = zFieldTypeBase.extend({
name: z.literal('Imagen4ModelField'),
originalType: zStatelessFieldType.optional(),
});
const zChatGPT4oModelFieldType = zFieldTypeBase.extend({
name: z.literal('ChatGPT4oModelField'),
originalType: zStatelessFieldType.optional(),
@@ -307,6 +311,7 @@ const zStatefulFieldType = z.union([
zSigLipModelFieldType,
zFluxReduxModelFieldType,
zImagen3ModelFieldType,
zImagen4ModelFieldType,
zChatGPT4oModelFieldType,
zColorFieldType,
zSchedulerFieldType,
@@ -347,6 +352,7 @@ const modelFieldTypeNames = [
zSigLipModelFieldType.shape.name.value,
zFluxReduxModelFieldType.shape.name.value,
zImagen3ModelFieldType.shape.name.value,
zImagen4ModelFieldType.shape.name.value,
zChatGPT4oModelFieldType.shape.name.value,
// Stateless model fields
'UNetField',
@@ -1207,6 +1213,24 @@ export const isImagen3ModelFieldInputTemplate =
buildTemplateTypeGuard<Imagen3ModelFieldInputTemplate>('Imagen3ModelField');
// #endregion
// #region Imagen4ModelField
export const zImagen4ModelFieldValue = zModelIdentifierField.optional();
const zImagen4ModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zImagen4ModelFieldValue,
});
const zImagen4ModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zImagen4ModelFieldType,
originalType: zFieldType.optional(),
default: zImagen4ModelFieldValue,
});
export type Imagen4ModelFieldValue = z.infer<typeof zImagen4ModelFieldValue>;
export type Imagen4ModelFieldInputInstance = z.infer<typeof zImagen4ModelFieldInputInstance>;
export type Imagen4ModelFieldInputTemplate = z.infer<typeof zImagen4ModelFieldInputTemplate>;
export const isImagen4ModelFieldInputInstance = buildInstanceTypeGuard(zImagen4ModelFieldInputInstance);
export const isImagen4ModelFieldInputTemplate =
buildTemplateTypeGuard<Imagen4ModelFieldInputTemplate>('Imagen4ModelField');
// #endregion
// #region ChatGPT4oModelField
export const zChatGPT4oModelFieldValue = zModelIdentifierField.optional();
const zChatGPT4oModelFieldInputInstance = zFieldInputInstanceBase.extend({
@@ -1857,6 +1881,7 @@ export const zStatefulFieldValue = z.union([
zSigLipModelFieldValue,
zFluxReduxModelFieldValue,
zImagen3ModelFieldValue,
zImagen4ModelFieldValue,
zChatGPT4oModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
@@ -1949,6 +1974,7 @@ const zStatefulFieldInputTemplate = z.union([
zSigLipModelFieldInputTemplate,
zFluxReduxModelFieldInputTemplate,
zImagen3ModelFieldInputTemplate,
zImagen4ModelFieldInputTemplate,
zChatGPT4oModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,

View File

@@ -52,7 +52,11 @@ export const prepareLinearUIBatch = (arg: {
count: prompts.length * iterations,
// Imagen3's support for seeded generation is iffy, we are just not going too use it in linear UI generations.
start:
model.base === 'imagen3' ? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX) : shouldRandomizeSeed ? undefined : seed,
model.base === 'imagen3' || model.base === 'imagen4'
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)
: shouldRandomizeSeed
? undefined
: seed,
});
firstBatchDatumList.push({
@@ -74,7 +78,11 @@ export const prepareLinearUIBatch = (arg: {
count: iterations,
// Imagen3's support for seeded generation is iffy, we are just not going too use in in linear UI generations.
start:
model.base === 'imagen3' ? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX) : shouldRandomizeSeed ? undefined : seed,
model.base === 'imagen3' || model.base === 'imagen4'
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)
: shouldRandomizeSeed
? undefined
: seed,
});
secondBatchDatumList.push({

View File

@@ -4,7 +4,7 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isImagen3AspectRatioID } from 'features/controlLayers/store/types';
import { isImagenAspectRatioID } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
@@ -24,7 +24,7 @@ export const buildImagen3Graph = async (state: RootState, manager: CanvasManager
const generationMode = await manager.compositor.getGenerationMode();
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.imagen3IncompatibleGenerationMode'));
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen3' }));
}
log.debug({ generationMode }, 'Building Imagen3 graph');
@@ -38,7 +38,7 @@ export const buildImagen3Graph = async (state: RootState, manager: CanvasManager
assert(model, 'No model found for Imagen3 graph');
assert(model.base === 'imagen3', 'Imagen3 graph requires Imagen3 model');
assert(isImagen3AspectRatioID(bbox.aspectRatio.id), 'Imagen3 does not support this aspect ratio');
assert(isImagenAspectRatioID(bbox.aspectRatio.id), 'Imagen3 does not support this aspect ratio');
assert(positivePrompt.length > 0, 'Imagen3 requires positive prompt to have at least one character');
const is_intermediate = canvasSettings.sendToCanvas;

View File

@@ -0,0 +1,78 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { isImagenAspectRatioID } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
CANVAS_OUTPUT_PREFIX,
getBoardField,
selectPresetModifiedPrompts,
} from 'features/nodes/util/graph/graphBuilderUtils';
import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { t } from 'i18next';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
export const buildImagen4Graph = async (state: RootState, manager: CanvasManager): Promise<GraphBuilderReturn> => {
const generationMode = await manager.compositor.getGenerationMode();
if (generationMode !== 'txt2img') {
throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen4' }));
}
log.debug({ generationMode }, 'Building Imagen4 graph');
const canvas = selectCanvasSlice(state);
const canvasSettings = selectCanvasSettingsSlice(state);
const { bbox } = canvas;
const { positivePrompt, negativePrompt } = selectPresetModifiedPrompts(state);
const model = selectMainModelConfig(state);
assert(model, 'No model found for Imagen4 graph');
assert(model.base === 'imagen4', 'Imagen4 graph requires Imagen4 model');
assert(isImagenAspectRatioID(bbox.aspectRatio.id), 'Imagen4 does not support this aspect ratio');
assert(positivePrompt.length > 0, 'Imagen4 requires positive prompt to have at least one character');
const is_intermediate = canvasSettings.sendToCanvas;
const board = canvasSettings.sendToCanvas ? undefined : getBoardField(state);
if (generationMode === 'txt2img') {
const g = new Graph(getPrefixedId('imagen4_txt2img_graph'));
const imagen4 = g.addNode({
// @ts-expect-error: These nodes are not available in the OSS application
type: 'google_imagen4_generate_image',
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
model: zModelIdentifierField.parse(model),
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
aspect_ratio: bbox.aspectRatio.id,
enhance_prompt: true,
// When enhance_prompt is true, Imagen4 will return a new image every time, ignoring the seed.
use_cache: false,
is_intermediate,
board,
});
g.upsertMetadata({
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
width: bbox.rect.width,
height: bbox.rect.height,
model: Graph.getModelMetadataField(model),
});
return {
g,
seedFieldIdentifier: { nodeId: imagen4.id, fieldName: 'seed' },
positivePromptFieldIdentifier: { nodeId: imagen4.id, fieldName: 'positive_prompt' },
};
}
assert<Equals<typeof generationMode, never>>(false, 'Invalid generation mode for Imagen4');
};

View File

@@ -34,6 +34,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
SigLipModelField: undefined,
FluxReduxModelField: undefined,
Imagen3ModelField: undefined,
Imagen4ModelField: undefined,
ChatGPT4oModelField: undefined,
FloatGeneratorField: undefined,
IntegerGeneratorField: undefined,

View File

@@ -23,6 +23,7 @@ import type {
ImageFieldInputTemplate,
ImageGeneratorFieldInputTemplate,
Imagen3ModelFieldInputTemplate,
Imagen4ModelFieldInputTemplate,
IntegerFieldCollectionInputTemplate,
IntegerFieldInputTemplate,
IntegerGeneratorFieldInputTemplate,
@@ -600,6 +601,18 @@ const buildImagen3ModelFieldInputTemplate: FieldInputTemplateBuilder<Imagen3Mode
return template;
};
const buildImagen4ModelFieldInputTemplate: FieldInputTemplateBuilder<Imagen4ModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: Imagen4ModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildChatGPT4oModelFieldInputTemplate: FieldInputTemplateBuilder<ChatGPT4oModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -682,7 +695,7 @@ const buildEnumFieldInputTemplate: FieldInputTemplateBuilder<EnumFieldInputTempl
if (filteredAnyOf.length !== 1 || !isSchemaObject(firstAnyOf)) {
options = [];
} else {
options = firstAnyOf.enum ?? [];
options = firstAnyOf.const ? [firstAnyOf.const] : (firstAnyOf.enum ?? []);
}
} else if (schemaObject.const) {
options = [schemaObject.const];
@@ -820,6 +833,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
SigLipModelField: buildSigLipModelFieldInputTemplate,
FluxReduxModelField: buildFluxReduxModelFieldInputTemplate,
Imagen3ModelField: buildImagen3ModelFieldInputTemplate,
Imagen4ModelField: buildImagen4ModelFieldInputTemplate,
ChatGPT4oModelField: buildChatGPT4oModelFieldInputTemplate,
FloatGeneratorField: buildFloatGeneratorFieldInputTemplate,
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,

View File

@@ -161,8 +161,15 @@ export const parseSchema = (
fieldType.batch = true;
}
const fieldInputTemplate = buildFieldInputTemplate(property, propertyName, fieldType);
inputsAccumulator[propertyName] = fieldInputTemplate;
try {
const fieldInputTemplate = buildFieldInputTemplate(property, propertyName, fieldType);
inputsAccumulator[propertyName] = fieldInputTemplate;
} catch {
log.error(
{ node: type, field: propertyName, schema: parseify(property) },
'Problem building input field template'
);
}
return inputsAccumulator;
},
@@ -226,9 +233,16 @@ export const parseSchema = (
fieldType.batch = true;
}
const fieldOutputTemplate = buildFieldOutputTemplate(property, propertyName, fieldType);
try {
const fieldOutputTemplate = buildFieldOutputTemplate(property, propertyName, fieldType);
outputsAccumulator[propertyName] = fieldOutputTemplate;
} catch {
log.error(
{ node: type, field: propertyName, schema: parseify(property) },
'Problem building output field template'
);
}
outputsAccumulator[propertyName] = fieldOutputTemplate;
return outputsAccumulator;
},
{} as Record<string, FieldOutputTemplate>

View File

@@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { bboxAspectRatioIdChanged } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectIsChatGTP4o, selectIsImagen3 } from 'features/controlLayers/store/paramsSlice';
import { selectIsChatGTP4o, selectIsImagen3, selectIsImagen4 } from 'features/controlLayers/store/paramsSlice';
import { selectAspectRatioID } from 'features/controlLayers/store/selectors';
import {
isAspectRatioID,
@@ -23,10 +23,10 @@ export const BboxAspectRatioSelect = memo(() => {
const isStaging = useAppSelector(selectIsStaging);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isImagen4 = useAppSelector(selectIsImagen4);
const options = useMemo(() => {
// Imagen3 and ChatGPT4o have different aspect ratio options, and do not support freeform sizes
if (isImagen3) {
if (isImagen3 || isImagen4) {
return zImagen3AspectRatioID.options;
}
if (isChatGPT4o) {
@@ -34,7 +34,7 @@ export const BboxAspectRatioSelect = memo(() => {
}
// All other models
return zAspectRatioID.options;
}, [isImagen3, isChatGPT4o]);
}, [isImagen3, isChatGPT4o, isImagen4]);
const onChange = useCallback<ChangeEventHandler<HTMLSelectElement>>(
(e) => {

View File

@@ -1,10 +1,10 @@
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectIsChatGTP4o, selectIsImagen3 } from 'features/controlLayers/store/paramsSlice';
import { useIsApiModel } from 'features/parameters/hooks/useIsApiModel';
export const useIsBboxSizeLocked = () => {
const isStaging = useAppSelector(selectIsStaging);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
return isImagen3 || isChatGPT4o || isStaging;
const isApiModel = useIsApiModel();
return isApiModel || isStaging;
};

View File

@@ -2,23 +2,18 @@ import { Flex, Link, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $accountSettingsLink } from 'app/store/nanostores/accountSettingsLink';
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsChatGTP4o, selectModel } from 'features/controlLayers/store/paramsSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useMemo } from 'react';
import { selectModel } from 'features/controlLayers/store/paramsSlice';
import { useIsModelDisabled } from 'features/parameters/hooks/useIsModelDisabled';
import { Trans, useTranslation } from 'react-i18next';
export const DisabledModelWarning = () => {
const { t } = useTranslation();
const model = useAppSelector(selectModel);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const areChatGPT4oModelsEnabled = useFeatureStatus('chatGPT4oModels');
const accountSettingsLink = useStore($accountSettingsLink);
const { isChatGPT4oHighModelDisabled } = useIsModelDisabled();
const isModelDisabled = useMemo(() => {
return isChatGPT4o && !areChatGPT4oModelsEnabled;
}, [isChatGPT4o, areChatGPT4oModelsEnabled]);
if (!isModelDisabled) {
if (!model || !isChatGPT4oHighModelDisabled(model)) {
return null;
}

View File

@@ -21,7 +21,7 @@ import { $installModelsTab } from 'features/modelManagerV2/subpanels/InstallMode
import { BASE_COLOR_MAP } from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelBaseBadge';
import ModelImage from 'features/modelManagerV2/subpanels/ModelManagerPanel/ModelImage';
import { NavigateToModelManagerButton } from 'features/parameters/components/MainModel/NavigateToModelManagerButton';
import { MODEL_TYPE_MAP, MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { API_BASE_MODELS, MODEL_TYPE_MAP, MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { filesize } from 'filesize';
import { memo, useCallback, useMemo, useRef } from 'react';
@@ -59,28 +59,28 @@ const NoOptionsFallback = memo(() => {
NoOptionsFallback.displayName = 'NoOptionsFallback';
const getGroupIDFromModelConfig = (modelConfig: AnyModelConfig): string => {
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3') {
if (API_BASE_MODELS.includes(modelConfig.base)) {
return 'api';
}
return modelConfig.base;
};
const getGroupNameFromModelConfig = (modelConfig: AnyModelConfig): string => {
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3') {
if (API_BASE_MODELS.includes(modelConfig.base)) {
return 'External API';
}
return MODEL_TYPE_MAP[modelConfig.base];
};
const getGroupShortNameFromModelConfig = (modelConfig: AnyModelConfig): string => {
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3') {
if (API_BASE_MODELS.includes(modelConfig.base)) {
return 'api';
}
return MODEL_TYPE_SHORT_MAP[modelConfig.base];
};
const getGroupColorSchemeFromModelConfig = (modelConfig: AnyModelConfig): string => {
if (modelConfig.base === 'chatgpt-4o' || modelConfig.base === 'imagen3') {
if (API_BASE_MODELS.includes(modelConfig.base)) {
return 'pink';
}
return BASE_COLOR_MAP[modelConfig.base];

View File

@@ -0,0 +1,10 @@
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsChatGTP4o, selectIsImagen3, selectIsImagen4 } from 'features/controlLayers/store/paramsSlice';
export const useIsApiModel = () => {
const isImagen3 = useAppSelector(selectIsImagen3);
const isImagen4 = useAppSelector(selectIsImagen4);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
return isImagen3 || isImagen4 || isChatGPT4o;
};

View File

@@ -0,0 +1,16 @@
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useCallback } from 'react';
export const useIsModelDisabled = () => {
const isChatGPT4oHighEnabled = useFeatureStatus('chatGPT4oHigh');
const isChatGPT4oHighModelDisabled = useCallback(
(model: ParameterModel) => {
return model?.base === 'chatgpt-4o' && model.name.toLowerCase().includes('high') && !isChatGPT4oHighEnabled;
},
[isChatGPT4oHighEnabled]
);
return { isChatGPT4oHighModelDisabled };
};

View File

@@ -14,6 +14,7 @@ export const MODEL_TYPE_MAP: Record<BaseModelType, string> = {
flux: 'FLUX',
cogview4: 'CogView4',
imagen3: 'Imagen3',
imagen4: 'Imagen4',
'chatgpt-4o': 'ChatGPT 4o',
};
@@ -30,6 +31,7 @@ export const MODEL_TYPE_SHORT_MAP: Record<BaseModelType, string> = {
flux: 'FLUX',
cogview4: 'CogView4',
imagen3: 'Imagen3',
imagen4: 'Imagen4',
'chatgpt-4o': 'ChatGPT 4o',
};
@@ -73,6 +75,10 @@ export const CLIP_SKIP_MAP: Record<BaseModelType, { maxClip: number; markers: nu
maxClip: 0,
markers: [],
},
imagen4: {
maxClip: 0,
markers: [],
},
'chatgpt-4o': {
maxClip: 0,
markers: [],
@@ -114,3 +120,8 @@ export const SCHEDULER_OPTIONS: ComboboxOption[] = [
{ value: 'unipc', label: 'UniPC' },
{ value: 'unipc_k', label: 'UniPC Karras' },
];
/**
* List of base models that make API requests
*/
export const API_BASE_MODELS = ['imagen3', 'imagen4', 'chatgpt-4o'];

View File

@@ -19,6 +19,7 @@ export const getOptimalDimension = (base?: BaseModelType | null): number => {
case 'sd-3':
case 'cogview4':
case 'imagen3':
case 'imagen4':
case 'chatgpt-4o':
default:
return 1024;

View File

@@ -31,10 +31,11 @@ import type { WorkflowSettingsState } from 'features/nodes/store/workflowSetting
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { isBatchNode, isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation';
import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue';
import { useIsModelDisabled } from 'features/parameters/hooks/useIsModelDisabled';
import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { getGridSize } from 'features/parameters/util/optimalDimension';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectConfigSlice } from 'features/system/store/configSlice';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import type { TabName } from 'features/ui/store/uiTypes';
@@ -89,7 +90,7 @@ const debouncedUpdateReasons = debounce(
config: AppConfig,
store: AppStore,
isInPublishFlow: boolean,
areChatGPT4oModelsEnabled: boolean
isChatGPT4oHighModelDisabled: (model: ParameterModel) => boolean
) => {
if (tab === 'canvas') {
const model = selectMainModelConfig(store.getState());
@@ -104,7 +105,7 @@ const debouncedUpdateReasons = debounce(
canvasIsRasterizing,
canvasIsCompositing,
canvasIsSelectingObject,
areChatGPT4oModelsEnabled,
isChatGPT4oHighModelDisabled,
});
$reasonsWhyCannotEnqueue.set(reasons);
} else if (tab === 'workflows') {
@@ -152,7 +153,7 @@ export const useReadinessWatcher = () => {
const canvasIsSelectingObject = useStore(canvasManager?.stateApi.$isSegmenting ?? $true);
const canvasIsCompositing = useStore(canvasManager?.compositor.$isBusy ?? $true);
const isInPublishFlow = useStore($isInPublishFlow);
const areChatGPT4oModelsEnabled = useFeatureStatus('chatGPT4oModels');
const { isChatGPT4oHighModelDisabled } = useIsModelDisabled();
useEffect(() => {
debouncedUpdateReasons(
@@ -173,7 +174,7 @@ export const useReadinessWatcher = () => {
config,
store,
isInPublishFlow,
areChatGPT4oModelsEnabled
isChatGPT4oHighModelDisabled
);
}, [
store,
@@ -193,7 +194,7 @@ export const useReadinessWatcher = () => {
upscale,
workflowSettings,
isInPublishFlow,
areChatGPT4oModelsEnabled,
isChatGPT4oHighModelDisabled,
]);
};
@@ -341,7 +342,7 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
canvasIsRasterizing: boolean;
canvasIsCompositing: boolean;
canvasIsSelectingObject: boolean;
areChatGPT4oModelsEnabled: boolean;
isChatGPT4oHighModelDisabled: (model: ParameterModel) => boolean;
}) => {
const {
isConnected,
@@ -354,7 +355,7 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
canvasIsRasterizing,
canvasIsCompositing,
canvasIsSelectingObject,
areChatGPT4oModelsEnabled,
isChatGPT4oHighModelDisabled,
} = arg;
const { positivePrompt } = params;
const reasons: Reason[] = [];
@@ -487,7 +488,7 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
}
}
if (model?.base === 'chatgpt-4o' && !areChatGPT4oModelsEnabled) {
if (model && isChatGPT4oHighModelDisabled(model)) {
reasons.push({ content: i18n.t('parameters.invoke.modelDisabledForTrial', { modelName: model.name }) });
}

View File

@@ -4,13 +4,7 @@ import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import {
selectIsChatGTP4o,
selectIsCogView4,
selectIsFLUX,
selectIsImagen3,
selectIsSD3,
} from 'features/controlLayers/store/paramsSlice';
import { selectIsCogView4, selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { LoRAList } from 'features/lora/components/LoRAList';
import LoRASelect from 'features/lora/components/LoRASelect';
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
@@ -20,6 +14,8 @@ import ParamSteps from 'features/parameters/components/Core/ParamSteps';
import { DisabledModelWarning } from 'features/parameters/components/MainModel/DisabledModelWarning';
import ParamUpscaleCFGScale from 'features/parameters/components/Upscale/ParamUpscaleCFGScale';
import ParamUpscaleScheduler from 'features/parameters/components/Upscale/ParamUpscaleScheduler';
import { useIsApiModel } from 'features/parameters/hooks/useIsApiModel';
import { API_BASE_MODELS } from 'features/parameters/types/constants';
import { MainModelPicker } from 'features/settingsAccordions/components/GenerationSettingsAccordion/MainModelPicker';
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
@@ -40,12 +36,8 @@ export const GenerationSettingsAccordion = memo(() => {
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
const isCogView4 = useAppSelector(selectIsCogView4);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isApiModel = useMemo(() => {
return isImagen3 || isChatGPT4o;
}, [isImagen3, isChatGPT4o]);
const isApiModel = useIsApiModel();
const isUpscaling = useMemo(() => {
return activeTabName === 'upscaling';
@@ -56,7 +48,7 @@ export const GenerationSettingsAccordion = memo(() => {
const enabledLoRAsCount = loras.loras.filter((l) => l.isEnabled).length;
const loraTabBadges = enabledLoRAsCount ? [`${enabledLoRAsCount} ${t('models.concepts')}`] : EMPTY_ARRAY;
const accordionBadges =
modelConfig?.base === 'imagen3' || modelConfig?.base === 'chatgpt-4o'
modelConfig && API_BASE_MODELS.includes(modelConfig.base)
? [modelConfig.name]
: modelConfig
? [modelConfig.name, modelConfig.base]

View File

@@ -3,13 +3,7 @@ import { Expander, Flex, FormControlGroup, StandaloneAccordion } from '@invoke-a
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectIsChatGTP4o,
selectIsFLUX,
selectIsImagen3,
selectIsSD3,
selectParamsSlice,
} from 'features/controlLayers/store/paramsSlice';
import { selectIsFLUX, selectIsSD3, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectScaleMethod } from 'features/controlLayers/store/selectors';
import { ParamOptimizedDenoisingToggle } from 'features/parameters/components/Advanced/ParamOptimizedDenoisingToggle';
import BboxScaledHeight from 'features/parameters/components/Bbox/BboxScaledHeight';
@@ -17,9 +11,10 @@ import BboxScaledWidth from 'features/parameters/components/Bbox/BboxScaledWidth
import BboxScaleMethod from 'features/parameters/components/Bbox/BboxScaleMethod';
import { BboxSettings } from 'features/parameters/components/Bbox/BboxSettings';
import { ParamSeed } from 'features/parameters/components/Seed/ParamSeed';
import { useIsApiModel } from 'features/parameters/hooks/useIsApiModel';
import { useExpanderToggle } from 'features/settingsAccordions/hooks/useExpanderToggle';
import { useStandaloneAccordionToggle } from 'features/settingsAccordions/hooks/useStandaloneAccordionToggle';
import { memo, useMemo } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
const selectBadges = createMemoizedSelector([selectCanvasSlice, selectParamsSlice], (canvas, params) => {
@@ -65,12 +60,7 @@ export const ImageSettingsAccordion = memo(() => {
});
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isApiModel = useMemo(() => {
return isImagen3 || isChatGPT4o;
}, [isImagen3, isChatGPT4o]);
const isApiModel = useIsApiModel();
return (
<StandaloneAccordion

View File

@@ -2,13 +2,9 @@ import { Box, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import {
selectIsChatGTP4o,
selectIsCogView4,
selectIsImagen3,
selectIsSDXL,
} from 'features/controlLayers/store/paramsSlice';
import { selectIsCogView4, selectIsSDXL } from 'features/controlLayers/store/paramsSlice';
import { Prompts } from 'features/parameters/components/Prompts/Prompts';
import { useIsApiModel } from 'features/parameters/hooks/useIsApiModel';
import { AdvancedSettingsAccordion } from 'features/settingsAccordions/components/AdvancedSettingsAccordion/AdvancedSettingsAccordion';
import { CompositingSettingsAccordion } from 'features/settingsAccordions/components/CompositingSettingsAccordion/CompositingSettingsAccordion';
import { GenerationSettingsAccordion } from 'features/settingsAccordions/components/GenerationSettingsAccordion/GenerationSettingsAccordion';
@@ -19,7 +15,7 @@ import { StylePresetMenuTrigger } from 'features/stylePresets/components/StylePr
import { $isStylePresetsMenuOpen } from 'features/stylePresets/store/stylePresetSlice';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import { memo } from 'react';
const overlayScrollbarsStyles: CSSProperties = {
height: '100%',
@@ -29,13 +25,9 @@ const overlayScrollbarsStyles: CSSProperties = {
const ParametersPanelTextToImage = () => {
const isSDXL = useAppSelector(selectIsSDXL);
const isCogview4 = useAppSelector(selectIsCogView4);
const isImagen3 = useAppSelector(selectIsImagen3);
const isChatGPT4o = useAppSelector(selectIsChatGTP4o);
const isStylePresetsMenuOpen = useStore($isStylePresetsMenuOpen);
const isApiModel = useMemo(() => {
return isImagen3 || isChatGPT4o;
}, [isImagen3, isChatGPT4o]);
const isApiModel = useIsApiModel();
return (
<Flex w="full" h="full" flexDir="column" gap={2}>

View File

@@ -0,0 +1,67 @@
/**
* modelRelationships.ts
*
* RTK Query API slice for managing model-to-model relationships.
*
* Endpoints provided:
* - Fetch related models for a single model
* - Add a relationship between two models
* - Remove a relationship between two models
* - Fetch related models for multiple models in batch
*
* Provides and invalidates cache tags for seamless UI updates after add/remove operations.
*/
import { api } from '..';
const REL_TAG = 'ModelRelationships'; // Needed for UI updates on relationship changes.
const modelRelationshipsApi = api.injectEndpoints({
endpoints: (build) => ({
getRelatedModelIds: build.query<string[], string>({
query: (model_key) => `/api/v1/model_relationships/i/${model_key}`,
providesTags: (result, error, model_key) => [{ type: REL_TAG, id: model_key }],
}),
addModelRelationship: build.mutation<void, { model_key_1: string; model_key_2: string }>({
query: (payload) => ({
url: `/api/v1/model_relationships/`,
method: 'POST',
body: payload,
}),
invalidatesTags: (result, error, { model_key_1, model_key_2 }) => [
{ type: REL_TAG, id: model_key_1 },
{ type: REL_TAG, id: model_key_2 },
],
}),
removeModelRelationship: build.mutation<void, { model_key_1: string; model_key_2: string }>({
query: (payload) => ({
url: `/api/v1/model_relationships/`,
method: 'DELETE',
body: payload,
}),
invalidatesTags: (result, error, { model_key_1, model_key_2 }) => [
{ type: REL_TAG, id: model_key_1 },
{ type: REL_TAG, id: model_key_2 },
],
}),
getRelatedModelIdsBatch: build.query<string[], string[]>({
query: (model_keys) => ({
url: `/api/v1/model_relationships/batch`,
method: 'POST',
body: { model_keys },
}),
providesTags: (result, error, model_keys) => model_keys.map((key) => ({ type: 'ModelRelationships', id: key })),
}),
}),
overrideExisting: false,
});
export const {
useGetRelatedModelIdsQuery,
useAddModelRelationshipMutation,
useRemoveModelRelationshipMutation,
useGetRelatedModelIdsBatchQuery,
} = modelRelationshipsApi;

View File

@@ -20,6 +20,7 @@ import {
isFluxReduxModelConfig,
isFluxVAEModelConfig,
isImagen3ModelConfig,
isImagen4ModelConfig,
isIPAdapterModelConfig,
isLLaVAModelConfig,
isLoRAModelConfig,
@@ -91,6 +92,7 @@ export const useRegionalReferenceImageModels = buildModelsHook(
);
export const useLLaVAModels = buildModelsHook(isLLaVAModelConfig);
export const useImagen3Models = buildModelsHook(isImagen3ModelConfig);
export const useImagen4Models = buildModelsHook(isImagen4ModelConfig);
export const useChatGPT4oModels = buildModelsHook(isChatGPT4oModelConfig);
// const buildModelsSelector =

View File

@@ -34,6 +34,7 @@ const tagTypes = [
'InvocationCacheStatus',
'ModelConfig',
'ModelInstalls',
'ModelRelationships',
'ModelScanFolderResults',
'T2IAdapterModel',
'MainModel',

File diff suppressed because it is too large Load Diff

View File

@@ -236,6 +236,10 @@ export const isImagen3ModelConfig = (config: AnyModelConfig): config is ApiModel
return config.type === 'main' && config.base === 'imagen3';
};
export const isImagen4ModelConfig = (config: AnyModelConfig): config is ApiModelConfig => {
return config.type === 'main' && config.base === 'imagen4';
};
export const isNonRefinerMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base !== 'sdxl-refiner';
};

View File

@@ -1 +1 @@
__version__ = "5.11.0"
__version__ = "5.12.0"

View File

@@ -2,12 +2,12 @@
"python": "3.12",
"torchIndexUrl": {
"win32": {
"cuda": "https://download.pytorch.org/whl/cu126"
"cuda": "https://download.pytorch.org/whl/cu128"
},
"linux": {
"cpu": "https://download.pytorch.org/whl/cpu",
"rocm": "https://download.pytorch.org/whl/rocm6.2.4",
"cuda": "https://download.pytorch.org/whl/cu126"
"rocm": "https://download.pytorch.org/whl/rocm6.3",
"cuda": "https://download.pytorch.org/whl/cu128"
},
"darwin": {}
}

View File

@@ -47,7 +47,7 @@ dependencies = [
"safetensors",
"sentencepiece",
"spandrel",
"torch~=2.6.0", # torch and related dependencies are loosely pinned, will respect requirement of `diffusers[torch]`
"torch~=2.7.0", # torch and related dependencies are loosely pinned, will respect requirement of `diffusers[torch]`
"torchsde", # diffusers needs this for SDE solvers, but it is not an explicit dep of diffusers
"torchvision",
"transformers",

View File

@@ -28,9 +28,9 @@ args = parser.parse_args()
def classify_with_fallback(path: Path, hash_algo: HASHING_ALGORITHMS):
try:
return ModelConfigBase.classify(path, hash_algo)
except InvalidModelConfigException:
return ModelProbe.probe(path, hash_algo=hash_algo)
except InvalidModelConfigException:
return ModelConfigBase.classify(path, hash_algo)
for path in args.model_path:

View File

@@ -18,13 +18,16 @@ import json
import shutil
import sys
from pathlib import Path
from typing import Optional
import humanize
import torch
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk, StateDict
from invokeai.backend.model_manager.search import ModelSearch
METADATA_KEY = "metadata_key_for_stripped_models"
def strip(v):
match v:
@@ -57,9 +60,22 @@ def dress(v):
def load_stripped_model(path: Path, *args, **kwargs):
with open(path, "r") as f:
contents = json.load(f)
contents.pop(METADATA_KEY, None)
return dress(contents)
class StrippedModelOnDisk(ModelOnDisk):
def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
path = self.resolve_weight_file(path)
return load_stripped_model(path)
def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
path = self.resolve_weight_file(path)
with open(path, "r") as f:
contents = json.load(f)
return contents.get(METADATA_KEY, {})
def create_stripped_model(original_model_path: Path, stripped_model_path: Path) -> ModelOnDisk:
original = ModelOnDisk(original_model_path)
if original.path.is_file():
@@ -69,11 +85,14 @@ def create_stripped_model(original_model_path: Path, stripped_model_path: Path)
stripped = ModelOnDisk(stripped_model_path)
print(f"Created clone of {original.name} at {stripped.path}")
for component_path in stripped.component_paths():
for component_path in stripped.weight_files():
original_state_dict = stripped.load_state_dict(component_path)
stripped_state_dict = strip(original_state_dict) # type: ignore
metadata = stripped.metadata()
contents = {**stripped_state_dict, METADATA_KEY: metadata}
with open(component_path, "w") as f:
json.dump(stripped_state_dict, f, indent=4)
json.dump(contents, f, indent=4)
before_size = humanize.naturalsize(original.size())
after_size = humanize.naturalsize(stripped.size())

View File

@@ -0,0 +1,46 @@
from typing import Any, Literal, Optional, Union
import pytest
from pydantic import BaseModel
class TestModel(BaseModel):
foo: Literal["bar"] = "bar"
@pytest.mark.parametrize(
"input_type, expected",
[
(str, False),
(list[str], False),
(list[dict[str, Any]], False),
(list[None], False),
(list[dict[str, None]], False),
(Any, False),
(True, False),
(False, False),
(Union[str, False], False),
(Union[str, True], False),
(None, False),
(str | None, True),
(Union[str, None], True),
(Optional[str], True),
(str | int | None, True),
(None | str | int, True),
(Union[None, str], True),
(Optional[str], True),
(Optional[int], True),
(Optional[str], True),
(TestModel | None, True),
(Union[TestModel, None], True),
(Optional[TestModel], True),
],
)
def test_is_optional(input_type: Any, expected: bool) -> None:
"""
Test the is_optional function.
"""
from invokeai.app.invocations.baseinvocation import is_optional
result = is_optional(input_type)
assert result == expected, f"Expected {expected} but got {result} for input type {input_type}"

View File

@@ -65,6 +65,8 @@ def mock_services() -> InvocationServices:
style_preset_records=None, # type: ignore
style_preset_image_files=None, # type: ignore
workflow_thumbnails=None, # type: ignore
model_relationship_records=None, # type: ignore
model_relationships=None, # type: ignore
)

View File

@@ -29,6 +29,9 @@ from invokeai.backend.model_manager.legacy_probe import (
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util.logging import InvokeAILogger
from scripts.strip_models import StrippedModelOnDisk
logger = InvokeAILogger.get_logger(__file__)
@pytest.mark.parametrize(
@@ -156,7 +159,8 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
pass
try:
new_config = ModelConfigBase.classify(path, hash=fake_hash, key=fake_key)
stripped_mod = StrippedModelOnDisk(path)
new_config = ModelConfigBase.classify(stripped_mod, hash=fake_hash, key=fake_key)
except InvalidModelConfigException:
pass
@@ -165,10 +169,10 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
assert legacy_config.model_dump_json() == new_config.model_dump_json()
elif legacy_config:
assert type(legacy_config) in ModelConfigBase._USING_LEGACY_PROBE
assert type(legacy_config) in ModelConfigBase.USING_LEGACY_PROBE
elif new_config:
assert type(new_config) in ModelConfigBase._USING_CLASSIFY_API
assert type(new_config) in ModelConfigBase.USING_CLASSIFY_API
else:
raise ValueError(f"Both probe and classify failed to classify model at path {path}.")
@@ -177,7 +181,6 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
configs_with_tests.add(config_type)
untested_configs = ModelConfigBase.all_config_classes() - configs_with_tests - {MinimalConfigExample}
logger = InvokeAILogger.get_logger(__file__)
logger.warning(f"Function test_regression_against_model_probe missing test case for: {untested_configs}")
@@ -255,3 +258,13 @@ def test_any_model_config_includes_all_config_classes():
expected = set(ModelConfigBase.all_config_classes()) - {MinimalConfigExample}
assert extracted == expected
def test_config_uniquely_matches_model(datadir: Path):
model_paths = ModelSearch().search(datadir / "stripped_models")
for path in model_paths:
mod = StrippedModelOnDisk(path)
matches = {cls for cls in ModelConfigBase.USING_CLASSIFY_API if cls.matches(mod)}
assert len(matches) <= 1, f"Model at path {path} matches multiple config classes: {matches}"
if not matches:
logger.warning(f"Model at path {path} does not match any config classes using classify API.")

158
uv.lock generated
View File

@@ -1104,7 +1104,7 @@ requires-dist = [
{ name = "sentencepiece" },
{ name = "snakeviz", marker = "extra == 'dev'" },
{ name = "spandrel" },
{ name = "torch", specifier = "~=2.6.0" },
{ name = "torch", specifier = "~=2.7.0" },
{ name = "torchsde" },
{ name = "torchvision" },
{ name = "transformers" },
@@ -1794,69 +1794,81 @@ wheels = [
[[package]]
name = "nvidia-cublas-cu12"
version = "12.4.5.8"
version = "12.6.4.1"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805 },
{ url = "https://files.pythonhosted.org/packages/af/eb/ff4b8c503fa1f1796679dce648854d58751982426e4e4b37d6fce49d259c/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08ed2686e9875d01b58e3cb379c6896df8e76c75e0d4a7f7dace3d7b6d9ef8eb", size = 393138322 },
]
[[package]]
name = "nvidia-cuda-cupti-cu12"
version = "12.4.127"
version = "12.6.80"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957 },
{ url = "https://files.pythonhosted.org/packages/49/60/7b6497946d74bcf1de852a21824d63baad12cd417db4195fc1bfe59db953/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6768bad6cab4f19e8292125e5f1ac8aa7d1718704012a0e3272a6f61c4bce132", size = 8917980 },
{ url = "https://files.pythonhosted.org/packages/a5/24/120ee57b218d9952c379d1e026c4479c9ece9997a4fb46303611ee48f038/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a3eff6cdfcc6a4c35db968a06fcadb061cbc7d6dde548609a941ff8701b98b73", size = 8917972 },
]
[[package]]
name = "nvidia-cuda-nvrtc-cu12"
version = "12.4.127"
version = "12.6.77"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306 },
{ url = "https://files.pythonhosted.org/packages/75/2e/46030320b5a80661e88039f59060d1790298b4718944a65a7f2aeda3d9e9/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53", size = 23650380 },
]
[[package]]
name = "nvidia-cuda-runtime-cu12"
version = "12.4.127"
version = "12.6.77"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737 },
{ url = "https://files.pythonhosted.org/packages/e1/23/e717c5ac26d26cf39a27fbc076240fad2e3b817e5889d671b67f4f9f49c5/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ba3b56a4f896141e25e19ab287cd71e52a6a0f4b29d0d31609f60e3b4d5219b7", size = 897690 },
{ url = "https://files.pythonhosted.org/packages/f0/62/65c05e161eeddbafeca24dc461f47de550d9fa8a7e04eb213e32b55cfd99/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a84d15d5e1da416dd4774cb42edf5e954a3e60cc945698dc1d5be02321c44dc8", size = 897678 },
]
[[package]]
name = "nvidia-cudnn-cu12"
version = "9.1.0.70"
version = "9.5.1.17"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
{ url = "https://files.pythonhosted.org/packages/2a/78/4535c9c7f859a64781e43c969a3a7e84c54634e319a996d43ef32ce46f83/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2", size = 570988386 },
]
[[package]]
name = "nvidia-cufft-cu12"
version = "11.2.1.3"
version = "11.3.0.4"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 },
{ url = "https://files.pythonhosted.org/packages/8f/16/73727675941ab8e6ffd86ca3a4b7b47065edcca7a997920b831f8147c99d/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ccba62eb9cef5559abd5e0d54ceed2d9934030f51163df018532142a8ec533e5", size = 200221632 },
{ url = "https://files.pythonhosted.org/packages/60/de/99ec247a07ea40c969d904fc14f3a356b3e2a704121675b75c366b694ee1/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.whl", hash = "sha256:768160ac89f6f7b459bee747e8d175dbf53619cfe74b2a5636264163138013ca", size = 200221622 },
]
[[package]]
name = "nvidia-cufile-cu12"
version = "1.11.1.6"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b2/66/cc9876340ac68ae71b15c743ddb13f8b30d5244af344ec8322b449e35426/nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc23469d1c7e52ce6c1d55253273d32c565dd22068647f3aa59b3c6b005bf159", size = 1142103 },
]
[[package]]
name = "nvidia-curand-cu12"
version = "10.3.5.147"
version = "10.3.7.77"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206 },
{ url = "https://files.pythonhosted.org/packages/73/1b/44a01c4e70933637c93e6e1a8063d1e998b50213a6b65ac5a9169c47e98e/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a42cd1344297f70b9e39a1e4f467a4e1c10f1da54ff7a85c12197f6c652c8bdf", size = 56279010 },
{ url = "https://files.pythonhosted.org/packages/4a/aa/2c7ff0b5ee02eaef890c0ce7d4f74bc30901871c5e45dee1ae6d0083cd80/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:99f1a32f1ac2bd134897fc7a203f779303261268a65762a623bf30cc9fe79117", size = 56279000 },
]
[[package]]
name = "nvidia-cusolver-cu12"
version = "11.6.1.9"
version = "11.7.1.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
@@ -1864,50 +1876,53 @@ dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 },
{ url = "https://files.pythonhosted.org/packages/f0/6e/c2cf12c9ff8b872e92b4a5740701e51ff17689c4d726fca91875b07f655d/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9e49843a7707e42022babb9bcfa33c29857a93b88020c4e4434656a655b698c", size = 158229790 },
{ url = "https://files.pythonhosted.org/packages/9f/81/baba53585da791d043c10084cf9553e074548408e04ae884cfe9193bd484/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6cf28f17f64107a0c4d7802be5ff5537b2130bfc112f25d5a30df227058ca0e6", size = 158229780 },
]
[[package]]
name = "nvidia-cusparse-cu12"
version = "12.3.1.170"
version = "12.5.4.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 },
{ url = "https://files.pythonhosted.org/packages/06/1e/b8b7c2f4099a37b96af5c9bb158632ea9e5d9d27d7391d7eb8fc45236674/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7556d9eca156e18184b94947ade0fba5bb47d69cec46bf8660fd2c71a4b48b73", size = 216561367 },
{ url = "https://files.pythonhosted.org/packages/43/ac/64c4316ba163e8217a99680c7605f779accffc6a4bcd0c778c12948d3707/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:23749a6571191a215cb74d1cdbff4a86e7b19f1200c071b3fcf844a5bea23a2f", size = 216561357 },
]
[[package]]
name = "nvidia-cusparselt-cu12"
version = "0.6.2"
version = "0.6.3"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/78/a8/bcbb63b53a4b1234feeafb65544ee55495e1bb37ec31b999b963cbccfd1d/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9", size = 150057751 },
{ url = "https://files.pythonhosted.org/packages/3b/9a/72ef35b399b0e183bc2e8f6f558036922d453c4d8237dab26c666a04244b/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46", size = 156785796 },
]
[[package]]
name = "nvidia-nccl-cu12"
version = "2.21.5"
version = "2.26.2"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", size = 188654414 },
{ url = "https://files.pythonhosted.org/packages/67/ca/f42388aed0fddd64ade7493dbba36e1f534d4e6fdbdd355c6a90030ae028/nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:694cf3879a206553cc9d7dbda76b13efaf610fdb70a50cba303de1b0d1530ac6", size = 201319755 },
]
[[package]]
name = "nvidia-nvjitlink-cu12"
version = "12.4.127"
version = "12.6.85"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810 },
{ url = "https://files.pythonhosted.org/packages/9d/d7/c5383e47c7e9bf1c99d5bd2a8c935af2b6d705ad831a7ec5c97db4d82f4f/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a", size = 19744971 },
]
[[package]]
name = "nvidia-nvtx-cu12"
version = "12.4.127"
version = "12.6.77"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 },
{ url = "https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b90bed3df379fa79afbd21be8e04a0314336b8ae16768b58f2d34cb1d04cd7d2", size = 89276 },
{ url = "https://files.pythonhosted.org/packages/9e/4e/0d0c945463719429b7bd21dece907ad0bde437a2ff12b9b12fee94722ab0/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1", size = 89265 },
]
[[package]]
@@ -3075,14 +3090,14 @@ wheels = [
[[package]]
name = "sympy"
version = "1.13.1"
version = "1.14.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mpmath" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ca/99/5a5b6f19ff9f083671ddf7b9632028436167cd3d33e11015754e41b249a4/sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f", size = 7533040 }
sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8", size = 6189177 },
{ url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353 },
]
[[package]]
@@ -3141,7 +3156,7 @@ wheels = [
[[package]]
name = "torch"
version = "2.6.0"
version = "2.7.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock" },
@@ -3154,6 +3169,7 @@ dependencies = [
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
@@ -3167,18 +3183,18 @@ dependencies = [
{ name = "typing-extensions" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/37/81/aa9ab58ec10264c1abe62c8b73f5086c3c558885d6beecebf699f0dbeaeb/torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:6860df13d9911ac158f4c44031609700e1eba07916fff62e21e6ffa0a9e01961", size = 766685561 },
{ url = "https://files.pythonhosted.org/packages/86/86/e661e229df2f5bfc6eab4c97deb1286d598bbeff31ab0cdb99b3c0d53c6f/torch-2.6.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c4f103a49830ce4c7561ef4434cc7926e5a5fe4e5eb100c19ab36ea1e2b634ab", size = 95751887 },
{ url = "https://files.pythonhosted.org/packages/20/e0/5cb2f8493571f0a5a7273cd7078f191ac252a402b5fb9cb6091f14879109/torch-2.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:56eeaf2ecac90da5d9e35f7f35eb286da82673ec3c582e310a8d1631a1c02341", size = 204165139 },
{ url = "https://files.pythonhosted.org/packages/e5/16/ea1b7842413a7b8a5aaa5e99e8eaf3da3183cc3ab345ad025a07ff636301/torch-2.6.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:09e06f9949e1a0518c5b09fe95295bc9661f219d9ecb6f9893e5123e10696628", size = 66520221 },
{ url = "https://files.pythonhosted.org/packages/78/a9/97cbbc97002fff0de394a2da2cdfa859481fdca36996d7bd845d50aa9d8d/torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:7979834102cd5b7a43cc64e87f2f3b14bd0e1458f06e9f88ffa386d07c7446e1", size = 766715424 },
{ url = "https://files.pythonhosted.org/packages/6d/fa/134ce8f8a7ea07f09588c9cc2cea0d69249efab977707cf67669431dcf5c/torch-2.6.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ccbd0320411fe1a3b3fec7b4d3185aa7d0c52adac94480ab024b5c8f74a0bf1d", size = 95759416 },
{ url = "https://files.pythonhosted.org/packages/11/c5/2370d96b31eb1841c3a0883a492c15278a6718ccad61bb6a649c80d1d9eb/torch-2.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:46763dcb051180ce1ed23d1891d9b1598e07d051ce4c9d14307029809c4d64f7", size = 204164970 },
{ url = "https://files.pythonhosted.org/packages/0b/fa/f33a4148c6fb46ca2a3f8de39c24d473822d5774d652b66ed9b1214da5f7/torch-2.6.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:94fc63b3b4bedd327af588696559f68c264440e2503cc9e6954019473d74ae21", size = 66530713 },
{ url = "https://files.pythonhosted.org/packages/e5/35/0c52d708144c2deb595cd22819a609f78fdd699b95ff6f0ebcd456e3c7c1/torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9", size = 766624563 },
{ url = "https://files.pythonhosted.org/packages/01/d6/455ab3fbb2c61c71c8842753b566012e1ed111e7a4c82e0e1c20d0c76b62/torch-2.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb", size = 95607867 },
{ url = "https://files.pythonhosted.org/packages/18/cf/ae99bd066571656185be0d88ee70abc58467b76f2f7c8bfeb48735a71fe6/torch-2.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239", size = 204120469 },
{ url = "https://files.pythonhosted.org/packages/81/b4/605ae4173aa37fb5aa14605d100ff31f4f5d49f617928c9f486bb3aaec08/torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989", size = 66532538 },
{ url = "https://files.pythonhosted.org/packages/46/c2/3fb87940fa160d956ee94d644d37b99a24b9c05a4222bf34f94c71880e28/torch-2.7.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c9afea41b11e1a1ab1b258a5c31afbd646d6319042bfe4f231b408034b51128b", size = 99158447 },
{ url = "https://files.pythonhosted.org/packages/cc/2c/91d1de65573fce563f5284e69d9c56b57289625cffbbb6d533d5d56c36a5/torch-2.7.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:0b9960183b6e5b71239a3e6c883d8852c304e691c0b2955f7045e8a6d05b9183", size = 865164221 },
{ url = "https://files.pythonhosted.org/packages/7f/7e/1b1cc4e0e7cc2666cceb3d250eef47a205f0821c330392cf45eb08156ce5/torch-2.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:2ad79d0d8c2a20a37c5df6052ec67c2078a2c4e9a96dd3a8b55daaff6d28ea29", size = 212521189 },
{ url = "https://files.pythonhosted.org/packages/dc/0b/b2b83f30b8e84a51bf4f96aa3f5f65fdf7c31c591cc519310942339977e2/torch-2.7.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:34e0168ed6de99121612d72224e59b2a58a83dae64999990eada7260c5dd582d", size = 68559462 },
{ url = "https://files.pythonhosted.org/packages/40/da/7378d16cc636697f2a94f791cb496939b60fb8580ddbbef22367db2c2274/torch-2.7.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:2b7813e904757b125faf1a9a3154e1d50381d539ced34da1992f52440567c156", size = 99159397 },
{ url = "https://files.pythonhosted.org/packages/0e/6b/87fcddd34df9f53880fa1f0c23af7b6b96c935856473faf3914323588c40/torch-2.7.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:fd5cfbb4c3bbadd57ad1b27d56a28008f8d8753733411a140fcfb84d7f933a25", size = 865183681 },
{ url = "https://files.pythonhosted.org/packages/13/85/6c1092d4b06c3db1ed23d4106488750917156af0b24ab0a2d9951830b0e9/torch-2.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:58df8d5c2eeb81305760282b5069ea4442791a6bbf0c74d9069b7b3304ff8a37", size = 212520100 },
{ url = "https://files.pythonhosted.org/packages/aa/3f/85b56f7e2abcfa558c5fbf7b11eb02d78a4a63e6aeee2bbae3bb552abea5/torch-2.7.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:0a8d43caa342b9986101ec5feb5bbf1d86570b5caa01e9cb426378311258fdde", size = 68569377 },
{ url = "https://files.pythonhosted.org/packages/aa/5e/ac759f4c0ab7c01feffa777bd68b43d2ac61560a9770eeac074b450f81d4/torch-2.7.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:36a6368c7ace41ad1c0f69f18056020b6a5ca47bedaca9a2f3b578f5a104c26c", size = 99013250 },
{ url = "https://files.pythonhosted.org/packages/9c/58/2d245b6f1ef61cf11dfc4aceeaacbb40fea706ccebac3f863890c720ab73/torch-2.7.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:15aab3e31c16feb12ae0a88dba3434a458874636f360c567caa6a91f6bfba481", size = 865042157 },
{ url = "https://files.pythonhosted.org/packages/44/80/b353c024e6b624cd9ce1d66dcb9d24e0294680f95b369f19280e241a0159/torch-2.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:f56d4b2510934e072bab3ab8987e00e60e1262fb238176168f5e0c43a1320c6d", size = 212482262 },
{ url = "https://files.pythonhosted.org/packages/ee/8d/b2939e5254be932db1a34b2bd099070c509e8887e0c5a90c498a917e4032/torch-2.7.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:30b7688a87239a7de83f269333651d8e582afffce6f591fff08c046f7787296e", size = 68574294 },
]
[[package]]
@@ -3198,7 +3214,7 @@ wheels = [
[[package]]
name = "torchvision"
version = "0.21.0"
version = "0.22.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy" },
@@ -3206,21 +3222,18 @@ dependencies = [
{ name = "torch" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/a9/20/72eb0b5b08fa293f20fc41c374e37cf899f0033076f0144d2cdc48f9faee/torchvision-0.21.0-1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:5568c5a1ff1b2ec33127b629403adb530fab81378d9018ca4ed6508293f76e2b", size = 2327643 },
{ url = "https://files.pythonhosted.org/packages/4e/3d/b7241abfa3e6651c6e00796f5de2bd1ce4d500bf5159bcbfeea47e711b93/torchvision-0.21.0-1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ff96666b94a55e802ea6796cabe788541719e6f4905fc59c380fed3517b6a64d", size = 2329320 },
{ url = "https://files.pythonhosted.org/packages/52/5b/76ca113a853b19c7b1da761f8a72cb6429b3bd0bf932537d8df4657f47c3/torchvision-0.21.0-1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:ffa2a16499508fe6798323e455f312c7c55f2a88901c9a7c0fb1efa86cf7e327", size = 2329878 },
{ url = "https://files.pythonhosted.org/packages/8e/0d/143bd264876fad17c82096b6c2d433f1ac9b29cdc69ee45023096976ee3d/torchvision-0.21.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:044ea420b8c6c3162a234cada8e2025b9076fa82504758cd11ec5d0f8cd9fa37", size = 1784140 },
{ url = "https://files.pythonhosted.org/packages/5e/44/32e2d2d174391374d5ff3c4691b802e8efda9ae27ab9062eca2255b006af/torchvision-0.21.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:b0c0b264b89ab572888244f2e0bad5b7eaf5b696068fc0b93e96f7c3c198953f", size = 7237187 },
{ url = "https://files.pythonhosted.org/packages/0e/6b/4fca9373eda42c1b04096758306b7bd55f7d8f78ba273446490855a0f25d/torchvision-0.21.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:54815e0a56dde95cc6ec952577f67e0dc151eadd928e8d9f6a7f821d69a4a734", size = 14699067 },
{ url = "https://files.pythonhosted.org/packages/aa/f7/799ddd538b21017cbf80294c92e9efbf6db08dff6efee37c3be114a81845/torchvision-0.21.0-cp310-cp310-win_amd64.whl", hash = "sha256:abbf1d7b9d52c00d2af4afa8dac1fb3e2356f662a4566bd98dfaaa3634f4eb34", size = 1560542 },
{ url = "https://files.pythonhosted.org/packages/29/88/00c69db213ee2443ada8886ec60789b227e06bb869d85ee324578221a7f7/torchvision-0.21.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:110d115333524d60e9e474d53c7d20f096dbd8a080232f88dddb90566f90064c", size = 1784141 },
{ url = "https://files.pythonhosted.org/packages/be/a2/b0cedf0a411f1a5d75cfc0b87cde56dd1ddc1878be46a42c905cd8580220/torchvision-0.21.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:3891cd086c5071bda6b4ee9d266bb2ac39c998c045c2ebcd1e818b8316fb5d41", size = 7237719 },
{ url = "https://files.pythonhosted.org/packages/8c/a1/ee962ef9d0b2bf7a6f8b14cb95acb70e05cd2101af521032a09e43f8582f/torchvision-0.21.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:54454923a50104c66a9ab6bd8b73a11c2fc218c964b1006d5d1fe5b442c3dcb6", size = 14700617 },
{ url = "https://files.pythonhosted.org/packages/88/53/4ad334b9b1d8dd99836869fec139cb74a27781298360b91b9506c53f1d10/torchvision-0.21.0-cp311-cp311-win_amd64.whl", hash = "sha256:49bcfad8cfe2c27dee116c45d4f866d7974bcf14a5a9fbef893635deae322f2f", size = 1560523 },
{ url = "https://files.pythonhosted.org/packages/6e/1b/28f527b22d5e8800184d0bc847f801ae92c7573a8c15979d92b7091c0751/torchvision-0.21.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:97a5814a93c793aaf0179cfc7f916024f4b63218929aee977b645633d074a49f", size = 1784140 },
{ url = "https://files.pythonhosted.org/packages/36/63/0722e153fd27d64d5b0af45b5c8cb0e80b35a68cf0130303bc9a8bb095c7/torchvision-0.21.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:b578bcad8a4083b40d34f689b19ca9f7c63e511758d806510ea03c29ac568f7b", size = 7238673 },
{ url = "https://files.pythonhosted.org/packages/bb/ea/03541ed901cdc30b934f897060d09bbf7a98466a08ad1680320f9ce0cbe0/torchvision-0.21.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5083a5b1fec2351bf5ea9900a741d54086db75baec4b1d21e39451e00977f1b1", size = 14701186 },
{ url = "https://files.pythonhosted.org/packages/4c/6a/c7752603060d076dfed95135b78b047dc71792630cbcb022e3693d6f32ef/torchvision-0.21.0-cp312-cp312-win_amd64.whl", hash = "sha256:6eb75d41e3bbfc2f7642d0abba9383cc9ae6c5a4ca8d6b00628c225e1eaa63b3", size = 1560520 },
{ url = "https://files.pythonhosted.org/packages/eb/03/a514766f068b088180f273913e539d08e830be3ae46ef8577ea62584a27c/torchvision-0.22.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:72256f1d7ff510b16c9fb4dd488584d0693f40c792f286a9620674438a81ccca", size = 1947829 },
{ url = "https://files.pythonhosted.org/packages/a3/e5/ec4b52041cd8c440521b75864376605756bd2d112d6351ea6a1ab25008c1/torchvision-0.22.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:810ea4af3bc63cf39e834f91f4218ff5999271caaffe2456247df905002bd6c0", size = 2512604 },
{ url = "https://files.pythonhosted.org/packages/e7/9e/e898a377e674da47e95227f3d7be2c49550ce381eebd8c7831c1f8bb7d39/torchvision-0.22.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:6fbca169c690fa2b9b8c39c0ad76d5b8992296d0d03df01e11df97ce12b4e0ac", size = 7446399 },
{ url = "https://files.pythonhosted.org/packages/c7/ec/2cdb90c6d9d61410b3df9ca67c210b60bf9b07aac31f800380b20b90386c/torchvision-0.22.0-cp310-cp310-win_amd64.whl", hash = "sha256:8c869df2e8e00f7b1d80a34439e6d4609b50fe3141032f50b38341ec2b59404e", size = 1716700 },
{ url = "https://files.pythonhosted.org/packages/b1/43/28bc858b022f6337326d75f4027d2073aad5432328f01ee1236d847f1b82/torchvision-0.22.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:191ea28321fc262d8aa1a7fe79c41ff2848864bf382f9f6ea45c41dde8313792", size = 1947828 },
{ url = "https://files.pythonhosted.org/packages/7e/71/ce9a303b94e64fe25d534593522ffc76848c4e64c11e4cbe9f6b8d537210/torchvision-0.22.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6c5620e10ffe388eb6f4744962106ed7cf1508d26e6fdfa0c10522d3249aea24", size = 2514016 },
{ url = "https://files.pythonhosted.org/packages/09/42/6908bff012a1dcc4fc515e52339652d7f488e208986542765c02ea775c2f/torchvision-0.22.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ce292701c77c64dd3935e3e31c722c3b8b176a75f76dc09b804342efc1db5494", size = 7447546 },
{ url = "https://files.pythonhosted.org/packages/e4/cf/8f9305cc0ea26badbbb3558ecae54c04a245429f03168f7fad502f8a5b25/torchvision-0.22.0-cp311-cp311-win_amd64.whl", hash = "sha256:e4017b5685dbab4250df58084f07d95e677b2f3ed6c2e507a1afb8eb23b580ca", size = 1716472 },
{ url = "https://files.pythonhosted.org/packages/cb/ea/887d1d61cf4431a46280972de665f350af1898ce5006cd046326e5d0a2f2/torchvision-0.22.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:31c3165418fe21c3d81fe3459e51077c2f948801b8933ed18169f54652796a0f", size = 1947826 },
{ url = "https://files.pythonhosted.org/packages/72/ef/21f8b6122e13ae045b8e49658029c695fd774cd21083b3fa5c3f9c5d3e35/torchvision-0.22.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:8f116bc82e0c076e70ba7776e611ed392b9666aa443662e687808b08993d26af", size = 2514571 },
{ url = "https://files.pythonhosted.org/packages/7c/48/5f7617f6c60d135f86277c53f9d5682dfa4e66f4697f505f1530e8b69fb1/torchvision-0.22.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:ce4dc334ebd508de2c534817c9388e928bc2500cf981906ae8d6e2ca3bf4727a", size = 7446522 },
{ url = "https://files.pythonhosted.org/packages/99/94/a015e93955f5d3a68689cc7c385a3cfcd2d62b84655d18b61f32fb04eb67/torchvision-0.22.0-cp312-cp312-win_amd64.whl", hash = "sha256:24b8c9255c209ca419cc7174906da2791c8b557b75c23496663ec7d73b55bebf", size = 1716664 },
]
[[package]]
@@ -3284,12 +3297,15 @@ wheels = [
[[package]]
name = "triton"
version = "3.2.0"
version = "3.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "setuptools", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/01/65/3ffa90e158a2c82f0716eee8d26a725d241549b7d7aaf7e4f44ac03ebd89/triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3e54983cd51875855da7c68ec05c05cf8bb08df361b1d5b69e05e40b0c9bd62", size = 253090354 },
{ url = "https://files.pythonhosted.org/packages/a7/2e/757d2280d4fefe7d33af7615124e7e298ae7b8e3bc4446cdb8e88b0f9bab/triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8009a1fb093ee8546495e96731336a33fb8856a38e45bb4ab6affd6dbc3ba220", size = 253157636 },
{ url = "https://files.pythonhosted.org/packages/06/00/59500052cb1cf8cf5316be93598946bc451f14072c6ff256904428eaf03c/triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d9b215efc1c26fa7eefb9a157915c92d52e000d2bf83e5f69704047e63f125c", size = 253159365 },
{ url = "https://files.pythonhosted.org/packages/76/04/d54d3a6d077c646624dc9461b0059e23fd5d30e0dbe67471e3654aec81f9/triton-3.3.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fad99beafc860501d7fcc1fb7045d9496cbe2c882b1674640304949165a916e7", size = 156441993 },
{ url = "https://files.pythonhosted.org/packages/3c/c5/4874a81131cc9e934d88377fbc9d24319ae1fb540f3333b4e9c696ebc607/triton-3.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3161a2bf073d6b22c4e2f33f951f3e5e3001462b2570e6df9cd57565bdec2984", size = 156528461 },
{ url = "https://files.pythonhosted.org/packages/11/53/ce18470914ab6cfbec9384ee565d23c4d1c55f0548160b1c7b33000b11fd/triton-3.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b68c778f6c4218403a6bd01be7484f6dc9e20fe2083d22dd8aef33e3b87a10a3", size = 156504509 },
]
[[package]]
@@ -3642,20 +3658,20 @@ wheels = [
[[package]]
name = "xformers"
version = "0.0.29.post3"
version = "0.0.30"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "numpy", marker = "sys_platform != 'darwin'" },
{ name = "torch", marker = "sys_platform != 'darwin'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/c1/fd/e9201fbee6a1a6d7a9c67c24a256ad4c2377bc67a634f7dbeaea23bd668a/xformers-0.0.29.post3.tar.gz", hash = "sha256:0b77c67ecc3c9fdd8a0e4399e675adf12e2ff40285e00974cca2d09108157f60", size = 8461348 }
sdist = { url = "https://files.pythonhosted.org/packages/bf/f7/dd2269cce89fd1221947dd7cc3a60707ffe721ef55c1803ac3b1a1f7ae5c/xformers-0.0.30.tar.gz", hash = "sha256:a12bf3eb39e294cdbe8a7253ac9b665f41bac61d6d98df174e34ef7bdb6f2fc4", size = 10214139 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/55/4f/ef63f866ec7d3a23f78629604deaf379cd833fa4fd0cf7e6f8a77906f125/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:59c64379c015f36eb69947f1ee559a0ac8742159ef3d9b8e5a6d1519806d2101", size = 43339848 },
{ url = "https://files.pythonhosted.org/packages/a5/6f/036bfe1fa98f0c48cc38b389dd808f01b34a017f62afcb6cf4bb5177cd5d/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl", hash = "sha256:34f13d69ad9404e44ae11169e99842c3fccdc2c75b1fc4831a70c78392d990db", size = 167739592 },
{ url = "https://files.pythonhosted.org/packages/55/05/9c9faf1c7b3b7b986bbf7a488a185eb67670a8435d0eae94aa59f56181cd/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:bbf2f500dfdbcf4649bf568cc2c9f434399f704dc4064fd1fbdbef2b524a8139", size = 43362399 },
{ url = "https://files.pythonhosted.org/packages/e0/9f/8195d17a5ad1b601bb487f24e54331d102df7f1649e2ced6375eef272e28/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl", hash = "sha256:00f2dfd94c894ff6372e21bee3f09e96bce75b55649df366649c43f049eb7a1e", size = 167742633 },
{ url = "https://files.pythonhosted.org/packages/2d/4a/20b2d9ac50efa0d40fbdb13283fd168cc2db28a2f21a159abbdd17a24213/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:08fa92f3e06372c4ce2a5306c54ae3d4a3a399fc7e24e02aac3761112ec3aeed", size = 43364118 },
{ url = "https://files.pythonhosted.org/packages/d9/ec/7846937d26b2601e40cd6e64583657f753415b94ad318e4ca350270e77d2/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl", hash = "sha256:3706eca371767ff9709595185910d809fc817ec3cf4234ef44d70d2b8844d7e2", size = 167743565 },
{ url = "https://files.pythonhosted.org/packages/45/d0/4ed66b2d46bef4373f106b58361364cbd8ce53c85e60c8ea57ea254887bb/xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:f9c9476fb7bd5d60c396ce096e36ae3e7c3461101da7a228ab1d2b7e64fb2318", size = 31503158 },
{ url = "https://files.pythonhosted.org/packages/ee/16/cc10aa84bfd02ceaf16f4341704fd3023790322059b147f546c3c814f8e7/xformers-0.0.30-cp310-cp310-win_amd64.whl", hash = "sha256:9e54eed6080e65455213174ad6b26c5e361715ca2d52759fde26055188802d92", size = 108010789 },
{ url = "https://files.pythonhosted.org/packages/1e/b3/9a850d949093b15ff283acae58c4f5adaf8776c57386b688c7f241f4dfbf/xformers-0.0.30-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:60396dff69a04071249809885962b7365afe650a7910f094d67b045b47a60388", size = 31518717 },
{ url = "https://files.pythonhosted.org/packages/c4/37/7df25e7cb29be5620d41f8d8cc71fc160f52c3b02d67de1feac1a5812537/xformers-0.0.30-cp311-cp311-win_amd64.whl", hash = "sha256:7b2e2aa615bce02ac20d58232b0e17304c62ec533ac0db2040a948df0155858d", size = 108011177 },
{ url = "https://files.pythonhosted.org/packages/e6/c6/6f2c364881da54e51a23c17c50db0518d30353bb6da8b1751be9174df538/xformers-0.0.30-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:357875986f50f105f445dc9a002c8450623cd4a6a469865c463285d0376fe77b", size = 31521318 },
{ url = "https://files.pythonhosted.org/packages/49/85/28d96d090733ba6859e4195f7c9dcb28196fc2e89197bba5de8d36f1a082/xformers-0.0.30-cp312-cp312-win_amd64.whl", hash = "sha256:8549ca30700d70dae904ec4407c6188cd73fd551e585f862c1d3aca3b7bc371c", size = 108011356 },
]
[[package]]