Compare commits

...

148 Commits

Author SHA1 Message Date
psychedelicious
7596c07a64 chore: prep for v5.9.0rc2 2025-03-25 10:21:23 +11:00
Kevin Turner
98fd1d949b fix: make dev_reload work for files in nodes/ 2025-03-25 10:04:17 +11:00
Linos
6312e6aa8f translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1832 of 1832 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-03-25 08:00:45 +11:00
Riccardo Giovanetti
6435f11bae translationBot(ui): update translation (Italian)
Currently translated at 98.7% (1815 of 1838 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.7% (1809 of 1832 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-03-25 08:00:45 +11:00
psychedelicious
1c69b9b1fa fix(ui): restore display: flex to image viewer and node editor
This was inadvertently removed in #7786 and caused some minor layout overflow.
2025-03-25 07:44:07 +11:00
psychedelicious
731970ff88 fix(ui): use expanded mask for paste-back when inpainting 2025-03-25 00:03:13 +11:00
psychedelicious
038bac1614 feat(ui): make it clearer that we are doing scale before processing in graph builders 2025-03-25 00:03:13 +11:00
jazzhaiku
ed9efe7740 Port LLaVA to new API (#7817)
## Summary

- Port LLaVA model config to new classification API
- Add 2 test cases (stripped LLaVA model variants added via git-lfs)


## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-03-24 22:50:54 +11:00
jazzhaiku
ffa0beba7a Merge branch 'main' into llava 2025-03-24 15:17:33 +11:00
psychedelicious
75d793f1c4 fix(ui): siglip model translation key 2025-03-24 13:26:38 +11:00
psychedelicious
2b086917e0 chore(ui): lint 2025-03-24 13:24:13 +11:00
psychedelicious
a9f2738086 feat(ui): layout improvements for string field collection input 2025-03-24 13:24:13 +11:00
psychedelicious
3a56799ea5 tidy(ui): remove unused code 2025-03-24 13:24:13 +11:00
psychedelicious
3162ce94dc tidy(ui): use settings for node field settings instead of config
Non-functional naming change to clarify the logic
2025-03-24 13:24:13 +11:00
psychedelicious
c0dc6ac4e1 fix(ui): issue where string drop-down options are not removed when changing component to a different type 2025-03-24 13:24:13 +11:00
psychedelicious
fed1995525 chore(ui): lint 2025-03-24 13:24:13 +11:00
psychedelicious
5006e23456 feat(ui): added reset options button 2025-03-24 13:24:13 +11:00
psychedelicious
2f063bddda fix(ui): restore field-node overlay
Accidentally removed it
2025-03-24 13:24:13 +11:00
psychedelicious
23a26422fd feat(ui): support for custom string field dropdowns in builder 2025-03-24 13:24:13 +11:00
psychedelicious
434f195a96 feat(ui): add empty string placeholder to string fields 2025-03-24 13:24:13 +11:00
psychedelicious
6a4c2d692c chore(ui): typegen 2025-03-24 12:45:46 +11:00
psychedelicious
5127a07cf9 feat(nodes): clean up lora node names
I had named them wonkily and caused some user confusion.
2025-03-24 12:45:46 +11:00
psychedelicious
0b4c6f0ab4 fix(mm): flux model variant probing
In #7780 we added FLUX Fill support, and needed the probe to be able to distinguish between "normal" FLUX models and FLUX Fill models.

Logic was added to the probe to check a particular state dict key (input channels), which should be 384 for FLUX Fill and 64 for other FLUX models.

The new logic was stricter: instead of falling back on the "normal" variant, it raised when an unexpected value for input channels was detected.

This caused failures to probe for BNB-NF4 quantized FLUX Dev/Schnell, which apparently only have 1 input channel.

After checking a variety of FLUX models, I loosened the strictness of the variant probing logic to only special-case the new FLUX Fill model, and otherwise fall back to returning the "normal" variant. This better matches the old behaviour and fixes the import errors.

Closes #7822
2025-03-24 12:36:18 +11:00
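To make the loosened probing logic above concrete, here is a minimal sketch (not the actual InvokeAI probe code; the `img_in.weight` key is an assumed stand-in for whichever state dict key carries the input channel count):

```python
import torch

from invokeai.backend.model_manager.config import ModelVariantType


def probe_flux_variant(state_dict: dict[str, torch.Tensor]) -> ModelVariantType:
    # FLUX Fill has 384 input channels; "normal" Dev/Schnell have 64, and some
    # quantized checkpoints (e.g. BNB-NF4) report unexpected values such as 1.
    weight = state_dict.get("img_in.weight")  # assumed key name, for illustration only
    in_channels = weight.shape[-1] if weight is not None else None
    if in_channels == 384:
        return ModelVariantType.Inpaint  # FLUX Fill
    # Any other value falls back to the normal variant instead of raising,
    # matching the pre-#7780 behaviour described above.
    return ModelVariantType.Normal
```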
Billy
d8450033ea Fix 2025-03-21 17:46:18 +11:00
Billy
3938736bd8 Ruff formatting 2025-03-21 17:35:12 +11:00
Billy
fb2c7b9566 Defaults 2025-03-21 17:35:04 +11:00
Billy
29449ec27d Implement new api for LLaVA 2025-03-21 17:17:56 +11:00
Billy
e38f778d28 Extend ModelOnDisk 2025-03-21 17:17:15 +11:00
Billy
f5e78436a8 Update regression test 2025-03-21 17:14:02 +11:00
Billy
6a15b5d9be Add stripped models for testing llava 2025-03-21 15:34:20 +11:00
psychedelicious
a629102c87 chore(ui): update whatsnew 2025-03-21 13:09:27 +11:00
psychedelicious
848ade8ab8 chore: prep for v5.9.0rc1 2025-03-21 13:09:27 +11:00
Hosted Weblate
2110feb01c translationBot(ui): update translation files
Updated by "Cleanup translation files" hook in Weblate.

Co-authored-by: Hosted Weblate <hosted@weblate.org>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/
Translation: InvokeAI/Web UI
2025-03-21 12:55:07 +11:00
Riku
f3e1821957 translationBot(ui): update translation (German)
Currently translated at 67.0% (1224 of 1826 strings)

Co-authored-by: Riku <riku.block@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/de/
Translation: InvokeAI/Web UI
2025-03-21 12:55:07 +11:00
Linos
bbcf93089a translationBot(ui): update translation (Vietnamese)
Currently translated at 100.0% (1827 of 1827 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 100.0% (1826 of 1826 strings)

translationBot(ui): update translation (Vietnamese)

Currently translated at 100.0% (1825 of 1825 strings)

Co-authored-by: Linos <linos.coding@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/vi/
Translation: InvokeAI/Web UI
2025-03-21 12:55:07 +11:00
Riccardo Giovanetti
66f41aa307 translationBot(ui): update translation (Italian)
Currently translated at 98.7% (1804 of 1827 strings)

translationBot(ui): update translation (Italian)

Currently translated at 98.7% (1803 of 1825 strings)

Co-authored-by: Riccardo Giovanetti <riccardo.giovanetti@gmail.com>
Translate-URL: https://hosted.weblate.org/projects/invokeai/web-ui/it/
Translation: InvokeAI/Web UI
2025-03-21 12:55:07 +11:00
psychedelicious
8a709766b3 feat(ui): better error for unknown fields in builder view mode 2025-03-21 12:51:12 +11:00
psychedelicious
efaa20a7a1 feat(ui): better labels for missing/unexpected fields 2025-03-21 12:51:12 +11:00
psychedelicious
3e4c808b23 refactor(ui): organise useInputFieldTemplate hooks again & add useInputFieldTemplateSafe 2025-03-21 12:51:12 +11:00
psychedelicious
00e3931af4 chore(ui): "useInputFieldLabel" -> "useInputFieldLabelSafe"
Also update docstrings
2025-03-21 12:51:12 +11:00
psychedelicious
08bea07f8b chore(ui): "useInputFieldDescription" -> "useInputFieldDescriptionSafe"
Also update docstrings
2025-03-21 12:51:12 +11:00
psychedelicious
166d2f0e39 chore(ui): "useInputFieldTemplate" -> "useInputFieldTemplateOrThrow" 2025-03-21 12:51:12 +11:00
psychedelicious
21f346717a docs(ui): add docstring to useInputFieldTemplate 2025-03-21 12:51:12 +11:00
psychedelicious
f966fb8b9c docs(ui): add docstring to useInputFieldDescription 2025-03-21 12:51:12 +11:00
psychedelicious
c2b20a5387 feat(ui): hide guidance when FLUX Fill model selected 2025-03-21 10:24:03 +11:00
psychedelicious
bed9089fe6 refactor(ui): just always set guidance to 30 when using FLUX Fill 2025-03-21 10:24:03 +11:00
psychedelicious
d34a4f765c feat(ui): better error for FLUX Fill + t2i/i2i incompatibility 2025-03-21 10:24:03 +11:00
psychedelicious
efe4708b8b feat(ui): better error message/warning for FLUX Fill w/ Control LoRA 2025-03-21 10:24:03 +11:00
psychedelicious
7cb1f61a9e feat(ui): bump FLUX guidance up to 30 if it's too low during graph building 2025-03-21 10:24:03 +11:00
psychedelicious
6e2ef34cba feat(ui): add warning for FLUX Fill + Control LoRA 2025-03-21 10:24:03 +11:00
psychedelicious
d208b99a47 feat(ui): pass the full model config throughout validation logic 2025-03-21 10:24:03 +11:00
psychedelicious
47eeafa5cb feat(ui): add selector to select the main model full config object 2025-03-21 10:24:03 +11:00
psychedelicious
0cb00fbe53 refactor(ui): use new compositing nodes for inpaint/outpaint graphs 2025-03-21 10:24:03 +11:00
psychedelicious
a7e8ed3bc2 feat(ui): add FLUX Fill graph builder util 2025-03-21 10:24:03 +11:00
psychedelicious
22eb25be48 refactor(ui): use more succinct syntax to opt-out of RTKQ caching for model fetching utils 2025-03-21 10:24:03 +11:00
psychedelicious
a077f3fefc chore(ui): typegen 2025-03-21 10:24:03 +11:00
psychedelicious
c013a6e38d feat(nodes): deprecate canvas_v2_mask_and_crop 2025-03-21 10:24:03 +11:00
psychedelicious
6cfeb71bed feat(nodes): add expand_mask_with_fade to better handle canvas compositing needs
Previously we used erode/dilate and a Gaussian blur to expand and fade the edges of Canvas masks. The implementation had a number of problems:
- Erode/dilate kernel sizes were not calculated correctly, and extra iterations were run to compensate. The result was that the blur size, which should have been in pixels, was very inaccurate and unreliable.
- What we want is to add a "soft bleed" - like a drop shadow with no offset - starting from the edge of the mask, extending out by however many pixels. But Gaussian blur does not do this. The blurred area starts _inside_ the mask and extends outside it. So it kinda blurs inwards and outwards. We compensated for this by expanding the mask.
- Using a Gaussian blur can cause banding artifacts. Gaussian blur doesn't have a "size" or "radius" parameter in the sense that you think it should. It's a convolution matrix and there are _no zero values in the result_. This means that, far away from the mask, once compositing completes, we have some values that are very close to zero but not quite zero. These values are quantized by HTML Canvas, resulting in banding artifacts where you'd expect the blur to have faded to 0% alpha. At least, that is my understanding of why the banding artifacts occur.

The new node uses a better strategy to expand the mask and add the fade out effect:
- Calculate the distance from each white pixel to the nearest black pixel.
- Normalize this distance by dividing by the fade size in px, then clip the values to 0 - 1. The result represents the distance of each white pixel to its nearest black pixel as a percentage of the fade size. At this point, it is a linear distribution.
- Create a polynomial to describe the fade's intensity so that we can have a smooth transition from the masked region (black) to unmasked (white). There are some magic numbers here, determined experimentally.
- Evaluate the polynomial over the normalized distances, so we now have a matrix representing the fade intensity for every pixel
- Convert this matrix back to uint8 and apply it to the mask

This works soooo much better than the previous method. Not only does it fix the banding issues, but when we enable "output only generated regions", we get a much smaller image. Will add images to the PR to clarify.
2025-03-21 10:24:03 +11:00
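The strategy above, condensed into a standalone sketch (the full `expand_mask_with_fade` node appears in the diff below; the control points are the values used there):

```python
import cv2
import numpy as np


def expand_mask_with_fade(mask: np.ndarray, fade_size_px: int = 32, threshold: int = 0) -> np.ndarray:
    """mask: uint8, 0 = masked (black), 255 = unmasked (white)."""
    binary = np.where(mask > threshold, 255, 0).astype(np.uint8)
    black = (binary == 0).astype(np.uint8)
    # Distance from every non-black pixel to the nearest black pixel.
    dist = cv2.distanceTransform(1 - black, cv2.DIST_L2, 5)
    # Normalize: 0 at the mask edge, 1 at fade_size_px away (a linear ramp).
    d_norm = np.clip(dist / fade_size_px, 0, 1)
    # Fit a cubic through the experimentally chosen control points so the fade
    # eases smoothly from black (keep) to white (discard).
    x_control = np.array([0.0, 1.0 / fade_size_px, 0.2, 0.8, 1.0])
    y_control = np.array([0.0, 0.0, 0.2, 0.9, 1.0])
    poly = np.poly1d(np.polyfit(x_control, y_control, 3))
    fade = np.clip(poly(d_norm), 0, 1)
    # Masked region stays black; everything else fades out with distance.
    return np.where(black == 1, 0, (fade * 255).astype(np.uint8)).astype(np.uint8)
```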
psychedelicious
534f993023 feat(nodes): add apply_mask_to_image node
It simply applies the mask to an image.
2025-03-21 10:24:03 +11:00
psychedelicious
67f9b6420c fix(nodes): ensure alpha mask is opened as RGBA 2025-03-21 10:24:03 +11:00
psychedelicious
61bf065237 feat(nodes): rename "FLUX Fill" -> "FLUX Fill Conditioning" 2025-03-21 10:24:03 +11:00
psychedelicious
e78cf889ee fix(ui): clip shift-draw strokes to bbox when clip to bbox enabled
Closes #7809
2025-03-21 08:14:20 +11:00
psychedelicious
5d13f0ba15 tidy(ui): remove recommended flag from workflow (I believe it was for testing purposes) 2025-03-20 08:50:01 -04:00
psychedelicious
633b9afa46 fix(ui): recommended star stretches tag list layout 2025-03-20 08:50:01 -04:00
psychedelicious
f1889b259d tidy(ui): split browse workflows button into own component 2025-03-20 08:50:01 -04:00
psychedelicious
ed21d0b57e tidy(ui): remove noop useEffect 2025-03-20 08:50:01 -04:00
Mary Hipp
df90da28e1 tsc fix 2025-03-20 15:43:57 +11:00
Mary Hipp
702054aa62 make sure browse is selected 2025-03-20 15:43:57 +11:00
Mary Hipp
636ec1de6e add viewAllWorkflowsRecommended to studio init action to show library with only recommended workflows 2025-03-20 15:43:57 +11:00
Mary Hipp
063d07fd41 switch to using recommended with star instead of auto-selecting 2025-03-20 15:43:57 +11:00
Mary Hipp
c78eac624e update workflow tag/categories so that we can pass in 1+ selected tags to start with 2025-03-20 15:43:57 +11:00
Mary Hipp
05de3b7a84 workflow library UI updates: scrollbar to make it obvious it's overflowing, move deselect all tags to be next to browse button 2025-03-20 15:43:57 +11:00
Ryan Dick
9cc2232b6f Bump FluxDenoise invocation version and typegen. 2025-03-19 14:45:18 +11:00
Ryan Dick
9fdc06b447 Add FLUX Fill input validation and error/warning reporting. 2025-03-19 14:45:18 +11:00
Ryan Dick
5ea3ec5cc8 Get FLUX Fill working. Note: To use FLUX Fill, set guidance to ~30. 2025-03-19 14:45:18 +11:00
Ryan Dick
f13a07ba6a WIP on updating FluxDenoise to support FLUX Fill. 2025-03-19 14:45:18 +11:00
Ryan Dick
a913f0163d WIP - Add FluxFillInvocation 2025-03-19 14:45:18 +11:00
Ryan Dick
f7cfbd1323 Add FLUX Fill starter model. 2025-03-19 14:45:18 +11:00
Ryan Dick
2806b60701 Add logic to probe FLUX variant (NORMAL vs INPAINT). 2025-03-19 14:45:18 +11:00
psychedelicious
d8c3af624b Use git-lfs for larger assets (#7804)
## Summary

- Integrate Git LFS into our automated Python tests in CI
- Add stripped model files with git-lfs
- Add `README.md` instructions to install and configure git-lfs
- Unrelated change (skip hashing to make unit test run faster)


## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-03-19 09:53:26 +11:00
psychedelicious
feed44b68d Stripped models (#7797)
## Summary

**Problem**
We want to have automated tests for model classification/probing, but
model files are too large to include in the source.

**Proposed Solution**
Classification/probing only requires metadata (key names, tensor
shapes), not weights.
This PR introduces "stripped" models - lightweight versions that retain
only essential metadata.

- Added script to strip models
- Added stripped models to automated tests


**Model size before and after "stripping":**
```
LLaVA Onevision Qwen2 0.5b-ov-hf before: 1.8 GB, after: 11.6 MB
text_encoder before: 246.1 MB, after: 35.6 kB
llava-onevision-qwen2-7b-si-hf before: 16.1 GB, after: 11.7 MB
RealESRGAN_x2plus.pth before: 67.1 MB, after: 143.0 kB
IP Adapter SD1 before: 2.5 GB, after: 94.9 kB
Hard Edge Detection (canny) before: 722.6 MB, after: 63.6 kB
Lineart before: 722.6 MB, after: 63.6 kB
Segmentation Map before: 722.6 MB, after: 63.6 kB
EasyNegative before: 24.7 kB, after: 151 Bytes
Face Reference (IP Adapter Plus Face) before: 98.2 MB, after: 13.7 kB
Standard Reference (IP Adapter) before: 44.6 MB, after: 6.0 kB
shinkai_makoto_offset before: 151.1 MB, after: 160.0 kB
thickline_fp16 before: 151.1 MB, after: 160.0 kB
Alien Style before: 228.5 MB, after: 582.6 kB
Noodles Style before: 228.5 MB, after: 582.6 kB
Juggernaut XL v9 before: 6.9 GB, after: 3.7 MB
dreamshaper-8 before: 168.9 MB, after: 1.6 MB
```
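Not part of the PR text, but to illustrate why stripping is possible at all: classification only inspects key names, tensor shapes and dtypes, so even a naive metadata dump like the following hypothetical sketch (not the project's actual stripping script) captures everything probing needs:

```python
import json
from pathlib import Path

import torch


def strip_to_metadata(src: Path, dst: Path) -> None:
    """Record only what probing needs: key names, tensor shapes and dtypes."""
    state_dict = torch.load(src, map_location="cpu")
    metadata = {
        key: {"shape": list(t.shape), "dtype": str(t.dtype)}
        for key, t in state_dict.items()
        if isinstance(t, torch.Tensor)
    }
    dst.write_text(json.dumps(metadata, indent=2))
```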






## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-03-19 08:13:10 +11:00
Billy
247f3b5d67 Merge branch 'stripped-models' into git-lfs 2025-03-19 07:53:27 +11:00
Billy
8e14f9d971 Merge branch 'main' into stripped-models 2025-03-19 07:52:56 +11:00
Billy
bdb44ee48d Merge branch 'git-lfs' of github.com:invoke-ai/InvokeAI into git-lfs 2025-03-19 07:30:34 +11:00
Billy
b57f5330c5 Pin action to commit 2025-03-19 07:28:28 +11:00
jazzhaiku
ade3c015b4 Update docs/contributing/dev-environment.md
Co-authored-by: Eugene Brodsky <ebr@users.noreply.github.com>
2025-03-19 07:23:23 +11:00
psychedelicious
7fe4d4c21a feat(app): better errors when scanning models with picklescan
Differentiate between malware detection and scan error.
2025-03-19 07:20:25 +11:00
psychedelicious
133a7fde55 Model classification api (#7742)
## Summary
The _goal_ of this PR is to make it easier to add a new config type.
The _scope_ of this PR is to integrate the API; it does not include
adding new configs (outside tests) or porting existing ones.


One of the glaring issues of the existing *legacy probe* is that the
logic for each type is spread across multiple classes and intertwined
with the other configs. This means that adding a new config type (or
modifying an existing one) is complex and error prone.

This PR attempts to remedy this by providing a new API for adding
configs that:

- Is backwards compatible with the existing probe.
- Encapsulates fields and logic in a single class, keeping things
self-contained and easy to modify safely.

Below is a minimal toy example illustrating the proposed new structure:

```python
class MinimalConfigExample(ModelConfigBase):
    type: ModelType = ModelType.Main
    format: ModelFormat = ModelFormat.Checkpoint
    fun_quote: str

    @classmethod
    def matches(cls, mod: ModelOnDisk) -> bool:
        return mod.path.suffix == ".json"

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        with open(mod.path, "r") as f:
            contents = json.load(f)

        return {
            "fun_quote": contents["quote"],
            "base": BaseModelType.Any,
        }
```

To create a new config type, one needs to inherit from `ModelConfigBase`
and implement its interface.

The code falls back to the legacy model probe for existing models using
the old API.
This allows us to incrementally port the configs one by one.
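
The fallback itself is small. A condensed sketch, adapted from the `ModelInstallService._probe` change later in this diff:

```python
from pathlib import Path

from invokeai.backend.model_manager.config import (
    AnyModelConfig,
    InvalidModelConfigException,
    ModelConfigBase,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe


def probe(model_path: Path, hash_algo: str) -> AnyModelConfig:
    try:
        # Try the new classification API first: each config class's matches()
        # decides whether it recognises the model on disk.
        return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo)
    except InvalidModelConfigException:
        # Not ported yet -- fall back to the legacy probe.
        return ModelProbe.probe(model_path=model_path, fields={}, hash_algo=hash_algo)
```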




## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2025-03-18 15:25:56 +11:00
Billy
6375214878 Merge branch 'stripped-models' into git-lfs 2025-03-18 14:57:58 +11:00
Billy
b9972be7f1 Merge branch 'model-classification-api' into stripped-models 2025-03-18 14:57:23 +11:00
Billy
e61c5a3f26 Merge 2025-03-18 14:55:11 +11:00
Billy
8c633786f6 Remove accidentally included files 2025-03-18 14:16:51 +11:00
Billy
8703eea49b LFS cache 2025-03-18 14:08:21 +11:00
Billy
c8888be4c3 Formatting 2025-03-18 13:10:07 +11:00
Billy
11963a65a4 CI/CD 2025-03-18 12:56:28 +11:00
Billy
ab6422fdf7 Add to README.md 2025-03-18 12:37:32 +11:00
psychedelicious
1f8632029e fix(nodes): add validator to vllm node images field to handle single image field inputs 2025-03-18 11:53:06 +11:00
Ryan Dick
88a762474d typegen 2025-03-18 11:53:06 +11:00
Ryan Dick
e6dd721e33 Add max_length=3 to the LLaVA OneVision image input field. 2025-03-18 11:53:06 +11:00
Billy
2a09604baf Formatting 2025-03-18 11:53:06 +11:00
Billy
f94f00ede0 Ruff formatting 2025-03-18 11:53:06 +11:00
Billy
37af281299 WIP - model selection for LLaVA 2025-03-18 11:53:06 +11:00
Billy
fc82775d7a WIP - model selection for LLaVA 2025-03-18 11:53:06 +11:00
Billy
9ed46f60b7 Add LLaVA OneVision to Config dropdown in UI 2025-03-18 11:53:06 +11:00
Ryan Dick
9a389e6b93 Add a LLaVA OneVision starter model. 2025-03-18 11:53:06 +11:00
Ryan Dick
2ef1ecf381 Fix copy-paste errors. 2025-03-18 11:53:06 +11:00
Ryan Dick
41de112932 Make LLaVA Onevision node work with 0 images, and other minor improvements. 2025-03-18 11:53:06 +11:00
Ryan Dick
e9714fe476 Add LLaVA Onevision model loading and inference support. 2025-03-18 11:53:06 +11:00
Ryan Dick
3f29293e39 Add LlavaOnevision model type and probing logic. 2025-03-18 11:53:06 +11:00
Billy
db1aa38e98 Warning 2025-03-18 09:55:13 +11:00
Billy
12717d4a4d Stripped model data 2025-03-18 09:51:10 +11:00
Billy
1953f3cbcd Skip hashing to make test quicker 2025-03-18 09:50:18 +11:00
Billy
3469fc9843 Ruff 2025-03-18 09:22:16 +11:00
Billy
7cdd4187a9 Update classify script 2025-03-18 09:21:38 +11:00
Billy
ad66c101d2 Remove stripped model files 2025-03-18 09:10:37 +11:00
psychedelicious
28d3356710 chore: prep for v5.8.1 2025-03-18 09:06:47 +11:00
psychedelicious
81e70fb9d2 tidy(app): errant character 2025-03-18 08:00:51 +11:00
psychedelicious
971c425734 fix(app): incorrect values inserted when retrying queue item
In #7688 we optimized queuing preparation logic. This inadvertently broke retrying queue items.

Previously, a `NamedTuple` was used to store the values to insert in the DB when enqueuing. This handy class provides an API similar to a dataclass, where you can instantiate it with kwargs in any order. The resultant tuple re-orders the kwargs to match the order in the class definition.

For example, consider this `NamedTuple`:
```py
class SessionQueueValueToInsert(NamedTuple):
    foo: str
    bar: str
```

When instantiating it, no matter the order of the kwargs, if you make a normal tuple out of it, the tuple values are in the same order as in the class definition:

```py
t1 = SessionQueueValueToInsert(foo="foo", bar="bar")
print(tuple(t1)) # -> ('foo', 'bar')

t2 = SessionQueueValueToInsert(bar="bar", foo="foo")
print(tuple(t2)) # -> ('foo', 'bar')
```

So, in the old code, when we used the `NamedTuple`, it implicitly normalized the order of the values we insert into the DB.

In the retry logic, the values of the tuple were not ordered correctly, but the use of `NamedTuple` had secretly fixed the order for us.

In the linked PR, `NamedTuple` was dropped for a normal tuple, after profiling showed `NamedTuple` to be meaningfully slower than a normal tuple.

The implicit order normalization behaviour wasn't understood, and the order wasn't fixed when changing the retry logic to use a normal tuple instead of `NamedTuple`. This resulted in a bug where we incorrectly created queue items in the DB. For example, we stored the `destination` in the `field_values` column.

When such an incorrectly-created queue item is dequeued, it fails pydantic validation and causes what appears to be an endless loop of errors.

The only workaround available to users is to add this line to `invokeai.yaml` and restart the app:
```yaml
clear_queue_on_startup: true
```

On next startup, the queue is forcibly cleared before the error loop is triggered. Then the user should remove this line so their queue is persisted across app launches per usual.

The solution is simple - fix the ordering of the tuple. I also added a type annotation and comment to the tuple type alias definition.

Note: The endless error loop, as a general problem, will take some thinking to fix. The queue service methods to cancel and fail a queue item still retrieve it and parse it. And the list queue items methods parse the queue items. Bit of a catch 22, maybe the solution is to simply delete totally borked queue items and log an error.
2025-03-18 08:00:51 +11:00
psychedelicious
b09008c530 feat(ui): add cancel and clear all as toggleable app feature 2025-03-18 06:48:10 +11:00
Billy
f9f99f873d More models 2025-03-17 04:18:44 +00:00
Billy
7f93f1b600 Dependencies 2025-03-17 12:57:13 +11:00
Billy
b1d336ce8a Ruff 2025-03-17 12:19:27 +11:00
Billy
40c7be8f5d Warning about missing test cases 2025-03-17 12:19:15 +11:00
Billy
24218b34bf Make ruff happy 2025-03-17 12:04:26 +11:00
Billy
d970c6d6d5 Use override fixture 2025-03-17 11:58:13 +11:00
Billy
e5308be0bb Use override fixture 2025-03-17 11:31:20 +11:00
Billy
7d5687e9ff Disable device meta for spandrel 2025-03-17 11:30:05 +11:00
Billy
654e992630 Accept extra args 2025-03-17 10:25:16 +11:00
Billy
21f247f499 Stripped models script 2025-03-17 09:18:58 +11:00
Billy
8bcd9fe4b7 Extend ModelOnDisk 2025-03-17 09:18:51 +11:00
Billy
637b93d2d8 Ruff 2025-03-14 10:18:25 +11:00
Billy
565b160060 More tests 2025-03-14 10:17:43 +11:00
Billy
bdd0b90769 Merge branch 'model-classification-api' of github.com:invoke-ai/InvokeAI into model-classification-api 2025-03-13 13:37:15 +11:00
Billy
4377158503 Variant 2025-03-13 13:32:57 +11:00
Billy
c8c27079ed Codegen 2025-03-13 13:12:12 +11:00
Billy
d8b9a8d0dd Merge branch 'main' into model-classification-api 2025-03-13 13:03:51 +11:00
Billy
39a4608d15 Fix annotations compatibility with Python 3.11 2025-03-13 13:01:19 +11:00
jazzhaiku
cd2d5431db Merge branch 'main' into model-classification-api 2025-03-13 11:21:18 +11:00
Billy
c04cdd9779 Typegen 2025-03-13 11:00:26 +11:00
Billy
b86ac5e049 Explicit union 2025-03-13 10:28:07 +11:00
Billy
665236bb79 Type hints 2025-03-13 09:21:58 +11:00
Billy
f45400a275 Remove hash algo 2025-03-12 18:39:29 +11:00
Billy
be53b89203 Remove redundant hash_algo field 2025-03-11 09:28:57 +11:00
Billy
a215eeaabf Update schema 2025-03-11 09:22:29 +11:00
Billy
d86b392bfd Remove redundant hash_algo field 2025-03-11 09:16:59 +11:00
Billy
3e9e45b177 Update comments 2025-03-11 09:04:19 +11:00
Billy
907d960745 PR suggestions 2025-03-11 08:37:43 +11:00
Billy
bfdace6437 New API for model classification 2025-03-11 08:34:34 +11:00
217 changed files with 4233 additions and 1550 deletions

3
.gitattributes vendored
View File

@@ -2,4 +2,5 @@
# Only affects text files and ignores other file types.
# For more info see: https://www.aleksandrhovhannisyan.com/blog/crlf-vs-lf-normalizing-line-endings-in-git/
* text=auto
docker/** text eol=lf
tests/test_model_probe/stripped_models/** filter=lfs diff=lfs merge=lfs -text

View File

@@ -72,7 +72,8 @@ jobs:
PIP_USE_PEP517: '1'
steps:
- name: checkout
uses: actions/checkout@v4
# https://github.com/nschloe/action-cached-lfs-checkout
uses: nschloe/action-cached-lfs-checkout@f46300cd8952454b9f0a21a3d133d4bd5684cfc2
- name: check for changed python files
if: ${{ inputs.always_run != true }}

View File

@@ -18,9 +18,19 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
2. [Fork and clone][forking link] the [InvokeAI repo][repo link].
3. Create an directory for user data (images, models, db, etc). This is typically at `~/invokeai`, but if you already have a non-dev install, you may want to create a separate directory for the dev install.
3. This repository uses Git LFS to manage large files. To ensure all assets are downloaded:
- Install git-lfs → [Download here](https://git-lfs.com/)
- Enable automatic LFS fetching for this repository:
```shell
git config lfs.fetchinclude "*"
```
- Fetch files from LFS (only needs to be done once; subsequent `git pull` will fetch changes automatically):
```
git lfs pull
```
4. Create an directory for user data (images, models, db, etc). This is typically at `~/invokeai`, but if you already have a non-dev install, you may want to create a separate directory for the dev install.
4. Follow the [manual install][manual install link] guide, with some modifications to the install command:
5. Follow the [manual install][manual install link] guide, with some modifications to the install command:
- Use `.` instead of `invokeai` to install from the current directory. You don't need to specify the version.
@@ -34,19 +44,19 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
uv pip install -e ".[dev,test,docs,xformers]" --python 3.11 --python-preference only-managed --index=https://download.pytorch.org/whl/cu124 --reinstall
```
5. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
6. At this point, you should have Invoke installed, a venv set up and activated, and the server running. But you will see a warning in the terminal that no UI was found. If you go to the URL for the server, you won't get a UI.
This is because the UI build is not distributed with the source code. You need to build it manually. End the running server instance.
If you only want to edit the docs, you can stop here and skip to the **Documentation** section below.
6. Install the frontend dev toolchain:
7. Install the frontend dev toolchain:
- [`nodejs`](https://nodejs.org/) (v20+)
- [`pnpm`](https://pnpm.io/8.x/installation) (must be v8 - not v9!)
7. Do a production build of the frontend:
8. Do a production build of the frontend:
```sh
cd <PATH_TO_INVOKEAI_REPO>/invokeai/frontend/web
@@ -54,7 +64,7 @@ If you just want to use Invoke, you should use the [launcher][launcher link].
pnpm build
```
8. Restart the server and navigate to the URL. You should get a UI. After making changes to the python code, restart the server to see those changes.
9. Restart the server and navigate to the URL. You should get a UI. After making changes to the python code, restart the server to see those changes.
## Updating the UI

View File

@@ -59,6 +59,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
ControlLoRAModel = "ControlLoRAModelField"
SigLipModel = "SigLipModelField"
FluxReduxModel = "FluxReduxModelField"
LlavaOnevisionModel = "LLaVAModelField"
# endregion
# region Misc Field Types
@@ -205,6 +206,8 @@ class FieldDescriptions:
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
flux_redux_conditioning = "FLUX Redux conditioning tensor"
vllm_model = "The VLLM model to use"
flux_fill_conditioning = "FLUX Fill conditioning tensor"
class ImageField(BaseModel):
@@ -274,6 +277,13 @@ class FluxReduxConditioningField(BaseModel):
)
class FluxFillConditioningField(BaseModel):
"""A FLUX Fill conditioning field."""
image: ImageField = Field(description="The FLUX Fill reference image.")
mask: TensorField = Field(description="The FLUX Fill inpaint mask.")
class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

View File

@@ -15,6 +15,7 @@ from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
FluxConditioningField,
FluxFillConditioningField,
FluxReduxConditioningField,
ImageField,
Input,
@@ -48,7 +49,7 @@ from invokeai.backend.flux.sampling_utils import (
unpack,
)
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.model_manager.config import ModelFormat, ModelVariantType
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -62,7 +63,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="3.2.3",
version="3.3.0",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -109,6 +110,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description="FLUX Redux conditioning tensor.",
input=Input.Connection,
)
fill_conditioning: FluxFillConditioningField | None = InputField(
default=None,
description="FLUX Fill conditioning.",
input=Input.Connection,
)
cfg_scale: float | list[float] = InputField(default=1.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
cfg_scale_start_step: int = InputField(
default=0,
@@ -261,8 +267,19 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
if is_schnell and self.control_lora:
raise ValueError("Control LoRAs cannot be used with FLUX Schnell")
# Prepare the extra image conditioning tensor if a FLUX structural control image is provided.
img_cond = self._prep_structural_control_img_cond(context)
# Prepare the extra image conditioning tensor (img_cond) for either FLUX structural control or FLUX Fill.
img_cond: torch.Tensor | None = None
is_flux_fill = transformer_config.variant == ModelVariantType.Inpaint # type: ignore
if is_flux_fill:
img_cond = self._prep_flux_fill_img_cond(
context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype
)
else:
if self.fill_conditioning is not None:
raise ValueError("fill_conditioning was provided, but the model is not a FLUX Fill model.")
if self.control_lora is not None:
img_cond = self._prep_structural_control_img_cond(context)
inpaint_mask = self._prep_inpaint_mask(context, x)
@@ -271,7 +288,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Pack all latent tensors.
init_latents = pack(init_latents) if init_latents is not None else None
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
img_cond = pack(img_cond) if img_cond is not None else None
noise = pack(noise)
x = pack(x)
@@ -664,7 +680,70 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
img_cond = einops.rearrange(img_cond, "h w c -> 1 c h w")
vae_info = context.models.load(self.controlnet_vae.vae)
return FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
img_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=img_cond)
return pack(img_cond)
def _prep_flux_fill_img_cond(
self, context: InvocationContext, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
"""Prepare the FLUX Fill conditioning. This method should be called iff the model is a FLUX Fill model.
This logic is based on:
https://github.com/black-forest-labs/flux/blob/716724eb276d94397be99710a0a54d352664e23b/src/flux/sampling.py#L107-L157
"""
# Validate inputs.
if self.fill_conditioning is None:
raise ValueError("A FLUX Fill model is being used without fill_conditioning.")
# TODO(ryand): We should probable rename controlnet_vae. It's used for more than just ControlNets.
if self.controlnet_vae is None:
raise ValueError("A FLUX Fill model is being used without controlnet_vae.")
if self.control_lora is not None:
raise ValueError(
"A FLUX Fill model is being used, but a control_lora was provided. Control LoRAs are not compatible with FLUX Fill models."
)
# Log input warnings related to FLUX Fill usage.
if self.denoise_mask is not None:
context.logger.warning(
"Both fill_conditioning and a denoise_mask were provided. You probably meant to use one or the other."
)
if self.guidance < 25.0:
context.logger.warning("A guidance value of ~30.0 is recommended for FLUX Fill models.")
# Load the conditioning image and resize it to the target image size.
cond_img = context.images.get_pil(self.fill_conditioning.image.image_name, mode="RGB")
cond_img = cond_img.resize((self.width, self.height), Image.Resampling.BICUBIC)
cond_img = np.array(cond_img)
cond_img = torch.from_numpy(cond_img).float() / 127.5 - 1.0
cond_img = einops.rearrange(cond_img, "h w c -> 1 c h w")
cond_img = cond_img.to(device=device, dtype=dtype)
# Load the mask and resize it to the target image size.
mask = context.tensors.load(self.fill_conditioning.mask.tensor_name)
# We expect mask to be a bool tensor with shape [1, H, W].
assert mask.dtype == torch.bool
assert mask.dim() == 3
assert mask.shape[0] == 1
mask = tv_resize(mask, size=[self.height, self.width], interpolation=tv_transforms.InterpolationMode.NEAREST)
mask = mask.to(device=device, dtype=dtype)
mask = einops.rearrange(mask, "1 h w -> 1 1 h w")
# Prepare image conditioning.
cond_img = cond_img * (1 - mask)
vae_info = context.models.load(self.controlnet_vae.vae)
cond_img = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=cond_img)
cond_img = pack(cond_img)
# Prepare mask conditioning.
mask = mask[:, 0, :, :]
# Rearrange mask to a 16-channel representation that matches the shape of the VAE-encoded latent space.
mask = einops.rearrange(mask, "b (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8)
mask = pack(mask)
# Merge image and mask conditioning.
img_cond = torch.cat((cond_img, mask), dim=-1)
return img_cond
def _normalize_ip_adapter_fields(self) -> list[IPAdapterField]:
if self.ip_adapter is None:

View File

@@ -0,0 +1,46 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxFillConditioningField,
InputField,
OutputField,
TensorField,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_fill_output")
class FluxFillOutput(BaseInvocationOutput):
"""The conditioning output of a FLUX Fill invocation."""
fill_cond: FluxFillConditioningField = OutputField(
description=FieldDescriptions.flux_redux_conditioning, title="Conditioning"
)
@invocation(
"flux_fill",
title="FLUX Fill Conditioning",
tags=["inpaint"],
category="inpaint",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxFillInvocation(BaseInvocation):
"""Prepare the FLUX Fill conditioning data."""
image: ImageField = InputField(description="The FLUX Fill reference image.")
mask: TensorField = InputField(
description="The bool inpainting mask. Excluded regions should be set to "
"False, included regions should be set to True.",
)
def invoke(self, context: InvocationContext) -> FluxFillOutput:
return FluxFillOutput(fill_cond=FluxFillConditioningField(image=self.image, mask=self.mask))

View File

@@ -28,10 +28,10 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
@invocation(
"flux_lora_loader",
title="FLUX LoRA",
title="Apply LoRA - FLUX",
tags=["lora", "model", "flux"],
category="model",
version="1.2.0",
version="1.2.1",
classification=Classification.Prototype,
)
class FluxLoRALoaderInvocation(BaseInvocation):
@@ -107,10 +107,10 @@ class FluxLoRALoaderInvocation(BaseInvocation):
@invocation(
"flux_lora_collection_loader",
title="FLUX LoRA Collection Loader",
title="Apply LoRA Collection - FLUX",
tags=["lora", "model", "flux"],
category="model",
version="1.3.0",
version="1.3.1",
classification=Classification.Prototype,
)
class FLUXLoRACollectionLoader(BaseInvocation):

View File

@@ -1051,7 +1051,7 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard):
tags=["image", "mask", "id"],
category="image",
version="1.0.0",
classification=Classification.Internal,
classification=Classification.Deprecated,
)
class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Handles Canvas V2 image output masking and cropping"""
@@ -1089,6 +1089,112 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard):
return ImageOutput.build(image_dto)
@invocation(
"expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.0"
)
class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
The mask is thresholded to create a binary mask, and then a distance transform is applied to create a fade effect.
The fade size is specified in pixels, and the mask is expanded by that amount. The result is a mask with a smooth transition from black to white.
"""
mask: ImageField = InputField(description="The mask to expand")
threshold: int = InputField(default=0, ge=0, le=255, description="The threshold for the binary mask (0-255)")
fade_size_px: int = InputField(default=32, ge=0, description="The size of the fade in pixels")
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_mask = context.images.get_pil(self.mask.image_name, mode="L")
np_mask = numpy.array(pil_mask)
# Threshold the mask to create a binary mask - 0 for black, 255 for white
# If we don't threshold we can get some weird artifacts
np_mask = numpy.where(np_mask > self.threshold, 255, 0).astype(numpy.uint8)
# Create a mask for the black region (1 where black, 0 otherwise)
black_mask = (np_mask == 0).astype(numpy.uint8)
# Invert the black region
bg_mask = 1 - black_mask
# Create a distance transform of the inverted mask
dist = cv2.distanceTransform(bg_mask, cv2.DIST_L2, 5)
# Normalize distances so that pixels <fade_size_px become a linear gradient (0 to 1)
d_norm = numpy.clip(dist / self.fade_size_px, 0, 1)
# Control points: x values (normalized distance) and corresponding fade pct y values.
# There are some magic numbers here that are used to create a smooth transition:
# - The first point is at 0% of fade size from edge of mask (meaning the edge of the mask), and is 0% fade (black)
# - The second point is 1px from the edge of the mask and also has 0% fade, effectively expanding the mask
# by 1px. This fixes an issue where artifacts can occur at the edge of the mask
# - The third point is at 20% of the fade size from the edge of the mask and has 20% fade
# - The fourth point is at 80% of the fade size from the edge of the mask and has 90% fade
# - The last point is at 100% of the fade size from the edge of the mask and has 100% fade (white)
# x values: 0 = mask edge, 1 = fade_size_px from edge
x_control = numpy.array([0.0, 1.0 / self.fade_size_px, 0.2, 0.8, 1.0])
# y values: 0 = black, 1 = white
y_control = numpy.array([0.0, 0.0, 0.2, 0.9, 1.0])
# Fit a cubic polynomial that smoothly passes through the control points
coeffs = numpy.polyfit(x_control, y_control, 3)
poly = numpy.poly1d(coeffs)
# Evaluate and clip the smooth mapping
feather = numpy.clip(poly(d_norm), 0, 1)
# Build final image.
np_result = numpy.where(black_mask == 1, 0, (feather * 255).astype(numpy.uint8))
# Convert back to PIL, grayscale
pil_result = Image.fromarray(np_result.astype(numpy.uint8), mode="L")
image_dto = context.images.save(image=pil_result, image_category=ImageCategory.MASK)
return ImageOutput.build(image_dto)
@invocation(
"apply_mask_to_image",
title="Apply Mask to Image",
tags=["image", "mask", "blend"],
category="image",
version="1.0.0",
)
class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""
Extracts a region from a generated image using a mask and blends it seamlessly onto a source image.
The mask uses black to indicate areas to keep from the generated image and white for areas to discard.
"""
image: ImageField = InputField(description="The image from which to extract the masked region")
mask: ImageField = InputField(description="The mask defining the region (black=keep, white=discard)")
invert_mask: bool = InputField(
default=False,
description="Whether to invert the mask before applying it",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
# Load images
image = context.images.get_pil(self.image.image_name, mode="RGBA")
mask = context.images.get_pil(self.mask.image_name, mode="L")
if self.invert_mask:
# Invert the mask if requested
mask = ImageOps.invert(mask.copy())
# Combine the mask as the alpha channel of the image
r, g, b, _ = image.split() # Split the image into RGB and alpha channels
result_image = Image.merge("RGBA", (r, g, b, mask)) # Use the mask as the new alpha channel
# Save the resulting image
image_dto = context.images.save(image=result_image)
return ImageOutput.build(image_dto)
@invocation(
"img_noise",
title="Add Image Noise",

View File

@@ -0,0 +1,60 @@
from typing import Any
import torch
from PIL.Image import Image
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, UIComponent, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.util.devices import TorchDevice
@invocation("llava_onevision_vllm", title="LLaVA OneVision VLLM", tags=["vllm"], category="vllm", version="1.0.0")
class LlavaOnevisionVllmInvocation(BaseInvocation):
"""Run a LLaVA OneVision VLLM model."""
images: list[ImageField] | ImageField | None = InputField(default=None, max_length=3, description="Input image.")
prompt: str = InputField(
default="",
description="Input text prompt.",
ui_component=UIComponent.Textarea,
)
vllm_model: ModelIdentifierField = InputField(
title="LLaVA Model Type",
description=FieldDescriptions.vllm_model,
ui_type=UIType.LlavaOnevisionModel,
)
@field_validator("images", mode="before")
def listify_images(cls, v: Any) -> list:
if v is None:
return v
if not isinstance(v, list):
return [v]
return v
def _get_images(self, context: InvocationContext) -> list[Image]:
if self.images is None:
return []
image_fields = self.images if isinstance(self.images, list) else [self.images]
return [context.images.get_pil(image_field.image_name, "RGB") for image_field in image_fields]
@torch.no_grad()
def invoke(self, context: InvocationContext) -> StringOutput:
images = self._get_images(context)
with context.models.load(self.vllm_model) as vllm_model:
assert isinstance(vllm_model, LlavaOnevisionModel)
output = vllm_model.run(
prompt=self.prompt,
images=images,
device=TorchDevice.choose_torch_device(),
dtype=TorchDevice.choose_torch_dtype(),
)
return StringOutput(value=output)

View File

@@ -67,7 +67,7 @@ class AlphaMaskToTensorInvocation(BaseInvocation):
invert: bool = InputField(default=False, description="Whether to invert the mask.")
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.images.get_pil(self.image.image_name)
image = context.images.get_pil(self.image.image_name, mode="RGBA")
mask = torch.zeros((1, image.height, image.width), dtype=torch.bool)
if self.invert:
mask[0] = torch.tensor(np.array(image)[:, :, 3] == 0, dtype=torch.bool)

View File

@@ -181,7 +181,7 @@ class LoRALoaderOutput(BaseInvocationOutput):
clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3")
@invocation("lora_loader", title="Apply LoRA - SD1.5", tags=["model"], category="model", version="1.0.4")
class LoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
@@ -244,7 +244,7 @@ class LoRASelectorOutput(BaseInvocationOutput):
lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA")
@invocation("lora_selector", title="LoRA Model - SD1.5", tags=["model"], category="model", version="1.0.2")
@invocation("lora_selector", title="Select LoRA", tags=["model"], category="model", version="1.0.3")
class LoRASelectorInvocation(BaseInvocation):
"""Selects a LoRA model and weight."""
@@ -258,7 +258,7 @@ class LoRASelectorInvocation(BaseInvocation):
@invocation(
"lora_collection_loader", title="LoRA Collection - SD1.5", tags=["model"], category="model", version="1.1.1"
"lora_collection_loader", title="Apply LoRA Collection - SD1.5", tags=["model"], category="model", version="1.1.2"
)
class LoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to the provided UNet and CLIP models."""
@@ -322,10 +322,10 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput):
@invocation(
"sdxl_lora_loader",
title="LoRA Model - SDXL",
title="Apply LoRA - SDXL",
tags=["lora", "model"],
category="model",
version="1.0.4",
version="1.0.5",
)
class SDXLLoRALoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
@@ -402,10 +402,10 @@ class SDXLLoRALoaderInvocation(BaseInvocation):
@invocation(
"sdxl_lora_collection_loader",
title="LoRA Collection - SDXL",
title="Apply LoRA Collection - SDXL",
tags=["model"],
category="model",
version="1.1.1",
version="1.1.2",
)
class SDXLLoRACollectionLoader(BaseInvocation):
"""Applies a collection of SDXL LoRAs to the provided UNet and CLIP models."""

View File

@@ -49,8 +49,6 @@ def run_app() -> None:
# Miscellaneous startup tasks.
apply_monkeypatches()
register_mime_types()
if app_config.dev_reload:
enable_dev_reload()
check_cudnn(logger)
# Initialize the app and event loop.
@@ -61,6 +59,11 @@ def run_app() -> None:
# core nodes have been imported so that we can catch when a custom node clobbers a core node.
load_custom_nodes(custom_nodes_path=app_config.custom_nodes_path, logger=logger)
if app_config.dev_reload:
# load_custom_nodes seems to bypass jurrigged's import sniffer, so be sure to call it *after* they're already
# imported.
enable_dev_reload(custom_nodes_path=app_config.custom_nodes_path)
# Start the server.
config = uvicorn.Config(
app=app,

View File

@@ -38,9 +38,11 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
InvalidModelConfigException,
ModelConfigBase,
ModelRepoVariant,
ModelSourceType,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
HuggingFaceMetadataFetch,
@@ -49,7 +51,6 @@ from invokeai.backend.model_manager.metadata import (
RemoteModelFile,
)
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint
@@ -182,9 +183,7 @@ class ModelInstallService(ModelInstallServiceBase):
) -> str: # noqa D102
model_path = Path(model_path)
config = config or ModelRecordChanges()
info: AnyModelConfig = ModelProbe.probe(
Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm
) # type: ignore
info: AnyModelConfig = self._probe(Path(model_path), config) # type: ignore
if preferred_name := config.name:
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
@@ -644,12 +643,22 @@ class ModelInstallService(ModelInstallServiceBase):
move(old_path, new_path)
return new_path
def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
config = config or ModelRecordChanges()
hash_algo = self._app_config.hashing_algorithm
fields = config.model_dump()
try:
return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
except InvalidModelConfigException:
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
) -> str:
config = config or ModelRecordChanges()
info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore
info = info or self._probe(model_path, config)
model_path = model_path.resolve()

View File

@@ -85,8 +85,11 @@ class ModelLoadService(ModelLoadServiceBase):
def torch_load_file(checkpoint: Path) -> AnyModel:
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
if scan_result.infected_files != 0:
raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
if scan_result.scan_err:
raise Exception(f"Error scanning model at {checkpoint} for malware. Aborting load.")
result = torch_load(checkpoint, map_location="cpu")
return result

View File

@@ -570,7 +570,10 @@ ValueToInsertTuple: TypeAlias = tuple[
str | None, # destination (optional)
int | None, # retried_from_item_id (optional, this is always None for new items)
]
"""A type alias for the tuple of values to insert into the session queue table."""
"""A type alias for the tuple of values to insert into the session queue table.
**If you change this, be sure to update the `enqueue_batch` and `retry_items_by_id` methods in the session queue service!**
"""
def prepare_values_to_insert(

View File

@@ -27,6 +27,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemDTO,
SessionQueueItemNotFoundError,
SessionQueueStatus,
ValueToInsertTuple,
calc_session_count,
prepare_values_to_insert,
)
@@ -689,7 +690,7 @@ class SqliteSessionQueue(SessionQueueBase):
"""Retries the given queue items"""
try:
cursor = self._conn.cursor()
values_to_insert: list[tuple] = []
values_to_insert: list[ValueToInsertTuple] = []
retried_item_ids: list[int] = []
for item_id in item_ids:
@@ -715,16 +716,16 @@ class SqliteSessionQueue(SessionQueueBase):
else queue_item.item_id
)
value_to_insert = (
value_to_insert: ValueToInsertTuple = (
queue_item.queue_id,
queue_item.batch_id,
queue_item.destination,
field_values_json,
queue_item.origin,
queue_item.priority,
workflow_json,
cloned_session_json,
cloned_session.id,
queue_item.batch_id,
field_values_json,
queue_item.priority,
workflow_json,
queue_item.origin,
queue_item.destination,
retried_from_item_id,
)
values_to_insert.append(value_to_insert)

View File

@@ -1,6 +1,7 @@
import logging
import mimetypes
import socket
from pathlib import Path
import torch
@@ -33,8 +34,9 @@ def check_cudnn(logger: logging.Logger) -> None:
)
def enable_dev_reload() -> None:
def enable_dev_reload(custom_nodes_path=None) -> None:
"""Enable hot reloading on python file changes during development."""
import invokeai
from invokeai.backend.util.logging import InvokeAILogger
try:
@@ -44,7 +46,10 @@ def enable_dev_reload() -> None:
'Can\'t start `--dev_reload` because jurigged is not found; `pip install -e ".[dev]"` to include development dependencies.'
) from e
else:
jurigged.watch(logger=InvokeAILogger.get_logger(name="jurigged").info)
paths = [str(Path(invokeai.__file__).with_name("*.py"))]
if custom_nodes_path:
paths.append(str(custom_nodes_path / "*.py"))
jurigged.watch(pattern=paths, logger=InvokeAILogger.get_logger(name="jurigged").info)
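With the new `custom_nodes_path` parameter, the watch list now covers both the installed `invokeai` package and the user's custom nodes directory. A rough sketch of the glob patterns built above (the custom nodes location is a hypothetical example):

```python
from pathlib import Path

import invokeai

custom_nodes_path = Path("~/invokeai/nodes").expanduser()  # hypothetical location
patterns = [str(Path(invokeai.__file__).with_name("*.py"))]
if custom_nodes_path.exists():
    patterns.append(str(custom_nodes_path / "*.py"))
print(patterns)
# e.g. ['/.../site-packages/invokeai/*.py', '/home/user/invokeai/nodes/*.py']
```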
def apply_monkeypatches() -> None:

View File

@@ -20,6 +20,7 @@ class ModelSpec:
max_seq_lengths: Dict[str, Literal[256, 512]] = {
"flux-dev": 512,
"flux-dev-fill": 512,
"flux-schnell": 256,
}
@@ -68,4 +69,19 @@ params = {
qkv_bias=True,
guidance_embed=False,
),
"flux-dev-fill": FluxParams(
in_channels=384,
out_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
}
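The new `flux-dev-fill` entries mirror `flux-dev` apart from the 384-channel input, and are selected by the probe via the `config_file` key. A hedged sanity check, assuming these tables live in `invokeai.backend.flux.util` as the probe's comment indicates:

```python
from invokeai.backend.flux.util import max_seq_lengths, params

fill = params["flux-dev-fill"]
assert fill.in_channels == 384       # FLUX Fill takes 384 input channels
assert fill.guidance_embed is True   # guidance-distilled, like flux-dev
assert max_seq_lengths["flux-dev-fill"] == 512
```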

View File

@@ -0,0 +1,49 @@
from pathlib import Path
from typing import Optional
import torch
from PIL.Image import Image
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, LlavaOnevisionProcessor
from invokeai.backend.raw_model import RawModel
class LlavaOnevisionModel(RawModel):
def __init__(self, vllm_model: LlavaOnevisionForConditionalGeneration, processor: LlavaOnevisionProcessor):
self._vllm_model = vllm_model
self._processor = processor
@classmethod
def load_from_path(cls, path: str | Path):
vllm_model = LlavaOnevisionForConditionalGeneration.from_pretrained(path, local_files_only=True)
assert isinstance(vllm_model, LlavaOnevisionForConditionalGeneration)
processor = AutoProcessor.from_pretrained(path, local_files_only=True)
assert isinstance(processor, LlavaOnevisionProcessor)
return cls(vllm_model, processor)
def run(self, prompt: str, images: list[Image], device: torch.device, dtype: torch.dtype) -> str:
# TODO(ryand): Tune the max number of images that are useful for the model.
if len(images) > 3:
raise ValueError(
f"{len(images)} images were provided as input to the LLaVA OneVision model. "
"Pass <=3 images for good performance."
)
# Define a chat history and use `apply_chat_template` to get correctly formatted prompt.
# "content" is a list of dicts with types "text" or "image".
content = [{"type": "text", "text": prompt}]
# Add the correct number of images.
for _ in images:
content.append({"type": "image"})
conversation = [{"role": "user", "content": content}]
prompt = self._processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = self._processor(images=images or None, text=prompt, return_tensors="pt").to(device=device, dtype=dtype)
output = self._vllm_model.generate(**inputs, max_new_tokens=400, do_sample=False)
output_str: str = self._processor.decode(output[0][2:], skip_special_tokens=True)
# The output_str will include the prompt, so we extract the response.
response = output_str.split("assistant\n", 1)[1].strip()
return response
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
self._vllm_model.to(device=device, dtype=dtype)
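A minimal usage sketch of the new wrapper, assuming a locally downloaded LLaVA OneVision checkpoint folder and a CUDA device (path, device and dtype are illustrative):

```python
from pathlib import Path

import torch
from PIL import Image

model = LlavaOnevisionModel.load_from_path(Path("/models/llava-onevision-qwen2-0.5b-ov-hf"))
device, dtype = torch.device("cuda"), torch.bfloat16
model.to(device=device, dtype=dtype)

image = Image.open("photo.png").convert("RGB")
answer = model.run(
    prompt="Describe this image in one sentence.",
    images=[image],  # run() rejects more than 3 images
    device=device,
    dtype=dtype,
)
print(answer)
```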

View File

@@ -5,6 +5,7 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
InvalidModelConfigException,
ModelConfigBase,
ModelConfigFactory,
ModelFormat,
ModelRepoVariant,
@@ -13,8 +14,8 @@ from invokeai.backend.model_manager.config import (
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
__all__ = [
@@ -32,4 +33,5 @@ __all__ = [
"ModelVariantType",
"SchedulerPredictionType",
"SubModelType",
"ModelConfigBase",
]

View File

@@ -20,21 +20,34 @@ Validation errors will raise an InvalidModelConfigException error.
"""
# pyright: reportIncompatibleVariableOverride=false
import json
import logging
import time
from abc import ABC, abstractmethod
from enum import Enum
from typing import Literal, Optional, Type, TypeAlias, Union
from inspect import isabstract
from pathlib import Path
from typing import ClassVar, Literal, Optional, TypeAlias, Union
import diffusers
import onnxruntime as ort
import safetensors.torch
import torch
from diffusers.models.modeling_utils import ModelMixin
from picklescan.scanner import scan_file_path
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.raw_model import RawModel
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.silence_warnings import SilenceWarnings
logger = logging.getLogger(__name__)
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
@@ -44,7 +57,7 @@ AnyModel = Union[
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
"""Exception for when config parser doesn't recognize this combination of model type and format."""
class BaseModelType(str, Enum):
@@ -78,6 +91,7 @@ class ModelType(str, Enum):
SpandrelImageToImage = "spandrel_image_to_image"
SigLIP = "siglip"
FluxRedux = "flux_redux"
LlavaOnevision = "llava_onevision"
class SubModelType(str, Enum):
@@ -190,12 +204,98 @@ class MainModelDefaultSettings(BaseModel):
class ControlAdapterDefaultSettings(BaseModel):
# This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
preprocessor: str | None
model_config = ConfigDict(extra="forbid")
class ModelConfigBase(BaseModel):
"""Base class for model configuration information."""
class ModelOnDisk:
"""A utility class representing a model stored on disk."""
def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
self.path = path
self.format_type = ModelFormat.Diffusers if path.is_dir() else ModelFormat.Checkpoint
if self.path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
self.name = path.stem
else:
self.name = path.name
self.hash_algo = hash_algo
def hash(self):
return ModelHash(algorithm=self.hash_algo).hash(self.path)
def size(self):
if self.format_type == ModelFormat.Checkpoint:
return self.path.stat().st_size
return sum(file.stat().st_size for file in self.path.rglob("*"))
def component_paths(self):
if self.format_type == ModelFormat.Checkpoint:
return {self.path}
extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
return {f for f in self.path.rglob("*") if f.suffix in extensions}
def repo_variant(self):
if self.format_type == ModelFormat.Checkpoint:
return None
weight_files = list(self.path.glob("**/*.safetensors"))
weight_files.extend(list(self.path.glob("**/*.bin")))
for x in weight_files:
if ".fp16" in x.suffixes:
return ModelRepoVariant.FP16
if "openvino_model" in x.name:
return ModelRepoVariant.OpenVINO
if "flax_model" in x.name:
return ModelRepoVariant.Flax
if x.suffix == ".onnx":
return ModelRepoVariant.ONNX
return ModelRepoVariant.Default
@staticmethod
def load_state_dict(path: Path):
with SilenceWarnings():
if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
scan_result = scan_file_path(path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
checkpoint = torch.load(path, map_location="cpu")
elif path.suffix.endswith(".gguf"):
checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
elif path.suffix.endswith(".safetensors"):
checkpoint = safetensors.torch.load_file(path)
else:
raise ValueError(f"Unrecognized model extension: {path.suffix}")
state_dict = checkpoint.get("state_dict", checkpoint)
return state_dict
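`ModelOnDisk` is what the new classify path hands to each config's `matches()`/`parse()`. A short usage sketch against a hypothetical diffusers folder:

```python
mod = ModelOnDisk(Path("/models/sdxl-base"), hash_algo="blake3_single")
print(mod.name, mod.format_type)  # "sdxl-base", ModelFormat.Diffusers
print(mod.size())                 # total size of every file under the folder
print(mod.repo_variant())         # e.g. ModelRepoVariant.FP16 when .fp16 weights are present
```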
class MatchSpeed(int, Enum):
"""Represents the estimated runtime speed of a config's 'matches' method."""
FAST = 0
MED = 1
SLOW = 2
class ModelConfigBase(ABC, BaseModel):
"""
Abstract Base class for model configurations.
To create a new config type, inherit from this class and implement its interface:
- (mandatory) override methods 'matches' and 'parse'
- (mandatory) define fields 'type' and 'format' as class attributes
- (optional) override method 'get_tag'
- (optional) override field _MATCH_SPEED
See MinimalConfigExample in test_model_probe.py for an example implementation.
"""
@staticmethod
def json_schema_extra(schema: dict[str, Any]) -> None:
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
hash: str = Field(description="The hash of the model file(s).")
@@ -203,27 +303,134 @@ class ModelConfigBase(BaseModel):
description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
)
name: str = Field(description="Name of the model.")
type: ModelType = Field(description="Model type")
format: ModelFormat = Field(description="Model format")
base: BaseModelType = Field(description="The base model.")
description: Optional[str] = Field(description="Model description", default=None)
source: str = Field(description="The original source of the model (path, URL or repo_id).")
source_type: ModelSourceType = Field(description="The type of source")
description: Optional[str] = Field(description="Model description", default=None)
source_api_response: Optional[str] = Field(
description="The original API response from the source, as stringified JSON.", default=None
)
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
description="Loadable submodels in this model", default=None
)
_USING_LEGACY_PROBE: ClassVar[set] = set()
_USING_CLASSIFY_API: ClassVar[set] = set()
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
class CheckpointConfigBase(ModelConfigBase):
"""Model config for checkpoint-style models."""
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if issubclass(cls, LegacyProbeMixin):
ModelConfigBase._USING_LEGACY_PROBE.add(cls)
else:
ModelConfigBase._USING_CLASSIFY_API.add(cls)
@staticmethod
def all_config_classes():
subclasses = ModelConfigBase._USING_LEGACY_PROBE | ModelConfigBase._USING_CLASSIFY_API
concrete = {cls for cls in subclasses if not isabstract(cls)}
return concrete
@staticmethod
def classify(model_path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
"""
Returns the best matching ModelConfig instance from a model's file/folder path.
Raises InvalidModelConfigException if no valid configuration is found.
Created to deprecate ModelProbe.probe
"""
candidates = ModelConfigBase._USING_CLASSIFY_API
sorted_by_match_speed = sorted(candidates, key=lambda cls: cls._MATCH_SPEED)
mod = ModelOnDisk(model_path, hash_algo)
for config_cls in sorted_by_match_speed:
try:
return config_cls.from_model_on_disk(mod, **overrides)
except InvalidModelConfigException:
logger.debug(f"ModelConfig '{config_cls.__name__}' failed to parse '{mod.path}', trying next config")
except Exception as e:
logger.error(f"Unexpected exception while parsing '{config_cls.__name__}': {e}, trying next config")
raise InvalidModelConfigException("No valid config found")
@classmethod
def get_tag(cls) -> Tag:
type = cls.model_fields["type"].default.value
format = cls.model_fields["format"].default.value
return Tag(f"{type}.{format}")
@classmethod
@abstractmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
"""Returns a dictionary with the fields needed to construct the model.
Raises InvalidModelConfigException if the model is invalid.
"""
pass
@classmethod
@abstractmethod
def matches(cls, mod: ModelOnDisk) -> bool:
"""Performs a quick check to determine if the config matches the model.
This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing."""
pass
@staticmethod
def cast_overrides(overrides: dict[str, Any]):
"""Casts user overrides from str to Enum"""
if "type" in overrides:
overrides["type"] = ModelType(overrides["type"])
if "format" in overrides:
overrides["format"] = ModelFormat(overrides["format"])
if "base" in overrides:
overrides["base"] = BaseModelType(overrides["base"])
if "source_type" in overrides:
overrides["source_type"] = ModelSourceType(overrides["source_type"])
@classmethod
def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
"""Creates an instance of this config or raises InvalidModelConfigException."""
if not cls.matches(mod):
raise InvalidModelConfigException(f"Path {mod.path} does not match {cls.__name__} format")
fields = cls.parse(mod)
cls.cast_overrides(overrides)
fields.update(overrides)
type = fields.get("type") or cls.model_fields["type"].default
base = fields.get("base") or cls.model_fields["base"].default
fields["path"] = mod.path.as_posix()
fields["source"] = fields.get("source") or fields["path"]
fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
fields["name"] = name = fields.get("name") or mod.name
fields["hash"] = fields.get("hash") or mod.hash()
fields["key"] = fields.get("key") or uuid_string()
fields["description"] = fields.get("description") or f"{base.value} {type.value} model {name}"
fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant()
return cls(**fields)
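Putting the interface from the `ModelConfigBase` docstring together: a new config only needs `type`, `format`, `matches` and `parse`, and `from_model_on_disk` fills in key, hash, path, name, source and description. A hypothetical sketch in the spirit of the `MinimalConfigExample` referenced above (not a real InvokeAI model type):

```python
class ExampleSidecarConfig(ModelConfigBase):
    """Hypothetical config: matches any single file with a made-up .example extension."""

    type: Literal[ModelType.Main] = ModelType.Main                    # reuses an existing type for the sketch
    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
    _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.FAST              # suffix check only, so try it early

    @classmethod
    def matches(cls, mod: ModelOnDisk) -> bool:
        return mod.path.suffix == ".example"

    @classmethod
    def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
        return {"base": BaseModelType.Any}

# Not inheriting LegacyProbeMixin means __init_subclass__ registers the class in
# _USING_CLASSIFY_API, so ModelConfigBase.classify(Path("foo.example")) will try it
# (in _MATCH_SPEED order) before raising InvalidModelConfigException.
```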
class LegacyProbeMixin:
"""Mixin for classes using the legacy probe for model classification."""
@classmethod
def matches(cls, *args, **kwargs):
raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}")
@classmethod
def parse(cls, *args, **kwargs):
raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}")
class CheckpointConfigBase(ABC, BaseModel):
"""Base class for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field(
description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint
@@ -234,153 +441,109 @@ class CheckpointConfigBase(ModelConfigBase):
)
class DiffusersConfigBase(ModelConfigBase):
"""Model config for diffusers-style models."""
class DiffusersConfigBase(ABC, BaseModel):
"""Base class for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
class LoRAConfigBase(ModelConfigBase):
class LoRAConfigBase(ABC, BaseModel):
"""Base class for LoRA models."""
type: Literal[ModelType.LoRA] = ModelType.LoRA
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
class T5EncoderConfigBase(ModelConfigBase):
class T5EncoderConfigBase(ABC, BaseModel):
"""Base class for diffusers-style models."""
type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
class T5EncoderConfig(T5EncoderConfigBase):
class T5EncoderConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase):
format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.T5Encoder.value}")
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase):
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase):
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T5Encoder.value}.{ModelFormat.BnbQuantizedLlmInt8b.value}")
class LoRALyCORISConfig(LoRAConfigBase):
class LoRALyCORISConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
class ControlAdapterConfigBase(BaseModel):
class ControlAdapterConfigBase(ABC, BaseModel):
default_settings: Optional[ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None
)
class ControlLoRALyCORISConfig(ModelConfigBase, ControlAdapterConfigBase):
class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for Control LoRA models."""
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.LyCORIS.value}")
class ControlLoRADiffusersConfig(ModelConfigBase, ControlAdapterConfigBase):
class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for Control LoRA models."""
type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlLoRa.value}.{ModelFormat.Diffusers.value}")
class LoRADiffusersConfig(LoRAConfigBase):
class LoRADiffusersConfig(LoRAConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for LoRA/Diffusers models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
class VAECheckpointConfig(CheckpointConfigBase):
class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.VAE] = ModelType.VAE
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
class VAEDiffusersConfig(ModelConfigBase):
class VAEDiffusersConfig(LegacyProbeMixin, ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.VAE] = ModelType.VAE
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Diffusers.value}")
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase):
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
class TextualInversionFileConfig(ModelConfigBase):
class TextualInversionFileConfig(LegacyProbeMixin, ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFile.value}")
class TextualInversionFolderConfig(ModelConfigBase):
class TextualInversionFolderConfig(LegacyProbeMixin, ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")
class MainConfigBase(ModelConfigBase):
class MainConfigBase(ABC, BaseModel):
type: Literal[ModelType.Main] = ModelType.Main
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings] = Field(
@@ -389,167 +552,146 @@ class MainConfigBase(ModelConfigBase):
variant: AnyVariant = ModelVariantType.Normal
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main checkpoint models."""
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.BnbQuantizednf4b] = ModelFormat.BnbQuantizednf4b
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.format = ModelFormat.BnbQuantizednf4b
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase):
class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main checkpoint models."""
format: Literal[ModelFormat.GGUFQuantized] = ModelFormat.GGUFQuantized
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.format = ModelFormat.GGUFQuantized
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.GGUFQuantized.value}")
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for main diffusers models."""
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
pass
class IPAdapterBaseConfig(ModelConfigBase):
class IPAdapterConfigBase(ABC, BaseModel):
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
class IPAdapterInvokeAIConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for IP Adapter diffusers format models."""
# TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long
# time. Need to go through the history to make sure I'm understanding this fully.
image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI
class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
class IPAdapterCheckpointConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for IP Adapter checkpoint format models."""
format: Literal[ModelFormat.Checkpoint]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
"""Model config for Clip Embeddings."""
variant: ClipVariantType = Field(description="Clip variant for this model")
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
variant: ClipVariantType = ClipVariantType.L
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
"""Model config for CLIP-G Embeddings."""
variant: ClipVariantType = ClipVariantType.G
variant: Literal[ClipVariantType.G] = ClipVariantType.G
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G}")
@classmethod
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}")
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase):
"""Model config for CLIP-L Embeddings."""
variant: ClipVariantType = ClipVariantType.L
variant: Literal[ClipVariantType.L] = ClipVariantType.L
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L}")
@classmethod
def get_tag(cls) -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}")
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
class CLIPVisionDiffusersConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for CLIPVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPVision.value}.{ModelFormat.Diffusers.value}")
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
class SpandrelImageToImageConfig(ModelConfigBase):
class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase):
"""Model config for Spandrel Image to Image models."""
_MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW # requires loading the model from disk
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
class SigLIPConfig(DiffusersConfigBase):
class SigLIPConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase):
"""Model config for SigLIP."""
type: Literal[ModelType.SigLIP] = ModelType.SigLIP
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.SigLIP.value}.{ModelFormat.Diffusers.value}")
class FluxReduxConfig(ModelConfigBase):
class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase):
"""Model config for FLUX Tools Redux model."""
type: Literal[ModelType.FluxRedux] = ModelType.FluxRedux
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.FluxRedux.value}.{ModelFormat.Checkpoint.value}")
class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase):
"""Model config for Llava Onevision models."""
type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.format_type == ModelFormat.Checkpoint:
return False
config_path = mod.path / "config.json"
try:
with open(config_path, "r") as file:
config = json.load(file)
except FileNotFoundError:
return False
architectures = config.get("architectures")
return architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration"
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
return {
"base": BaseModelType.Any,
"variant": ModelVariantType.Normal,
}
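`matches()` above keys off the `architectures` list in the folder's `config.json`. An abridged example of the fields it relies on (values are illustrative, not copied from a real checkpoint):

```python
example_config = {
    "architectures": ["LlavaOnevisionForConditionalGeneration"],
    "model_type": "llava_onevision",
}
architectures = example_config.get("architectures")
assert architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration"
```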
def get_model_discriminator_value(v: Any) -> str:
@@ -557,22 +699,40 @@ def get_model_discriminator_value(v: Any) -> str:
Computes the discriminator value for a model config.
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
"""
format_ = None
type_ = None
format_ = type_ = variant_ = None
if isinstance(v, dict):
format_ = v.get("format")
if isinstance(format_, Enum):
format_ = format_.value
type_ = v.get("type")
if isinstance(type_, Enum):
type_ = type_.value
variant_ = v.get("variant")
if isinstance(variant_, Enum):
variant_ = variant_.value
else:
format_ = v.format.value
type_ = v.type.value
v = f"{type_}.{format_}"
return v
variant_ = getattr(v, "variant", None)
if variant_:
variant_ = variant_.value
# Ideally, each config would be uniquely identified with a combination of fields
# i.e. (type, format, variant) without any special cases. Alas...
# Previously, CLIPEmbed did not have any variants, meaning older database entries lack a variant field.
# To maintain compatibility, we default to ClipVariantType.L in this case.
if type_ == ModelType.CLIPEmbed.value and format_ == ModelFormat.Diffusers.value:
variant_ = variant_ or ClipVariantType.L.value
return f"{type_}.{format_}.{variant_}"
return f"{type_}.{format_}"
# The types are listed explicitly because IDEs/LSPs can't identify the correct types
# when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
AnyModelConfig = Annotated[
Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
@@ -596,11 +756,11 @@ AnyModelConfig = Annotated[
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
Annotated[SigLIPConfig, SigLIPConfig.get_tag()],
Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()],
Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]
@@ -609,39 +769,12 @@ AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, ControlAdapterDefaultSettings]
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""
@classmethod
def make_config(
cls,
model_data: Union[Dict[str, Any], AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type[ModelConfigBase]] = None,
timestamp: Optional[float] = None,
) -> AnyModelConfig:
"""
Return the appropriate config object from raw dict values.
:param model_data: A raw dict corresponding the obect fields to be
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
object, which will be passed through unchanged.
:param dest_class: The config class to be returned. If not provided, will
be selected automatically.
"""
model: Optional[ModelConfigBase] = None
if isinstance(model_data, ModelConfigBase):
model = model_data
elif dest_class:
model = dest_class.model_validate(model_data)
else:
# mypy doesn't typecheck TypeAdapters well?
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
assert model is not None
if key:
model.key = key
if isinstance(model, CheckpointConfigBase) and timestamp is not None:
class ModelConfigFactory:
@staticmethod
def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig:
"""Return the appropriate config object from raw dict values."""
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
if isinstance(model, CheckpointConfigBase) and timestamp:
model.converted_at = timestamp
if model:
validate_hash(model.hash)
validate_hash(model.hash)
return model # type: ignore
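A sketch of the simplified factory in use: a raw dict (for example a database row) is validated by the `AnyModelConfig` TypeAdapter and comes back as the matching concrete config class. Field values below are illustrative only:

```python
raw_row = {
    "key": "abc123",
    "hash": "blake3:0000...",  # must be a hash that validate_hash accepts in practice
    "path": "models/sd-1/vae/some-vae",
    "name": "some-vae",
    "base": "sd-1",
    "type": "vae",
    "format": "diffusers",
    "source": "models/sd-1/vae/some-vae",
    "source_type": "path",
}
# config = ModelConfigFactory.make_config(raw_row)  # -> a VAEDiffusersConfig instance
```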

View File

@@ -3,10 +3,10 @@ import re
from pathlib import Path
from typing import Any, Callable, Dict, Literal, Optional, Union
import picklescan.scanner as pscan
import safetensors.torch
import spandrel
import torch
from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
@@ -141,6 +141,7 @@ class ModelProbe(object):
"SD3Transformer2DModel": ModelType.Main,
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
"SiglipModel": ModelType.SigLIP,
"LlavaOnevisionForConditionalGeneration": ModelType.LlavaOnevision,
}
TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
@@ -416,20 +417,22 @@ class ModelProbe(object):
# TODO: Decide between dev/schnell
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
# HACK: For FLUX, config_file is used as a key into invokeai.backend.flux.util.params during model
# loading. When FLUX support was first added, it was decided that this was the easiest way to support
# the various FLUX formats rather than adding new model types/formats. Be careful when modifying this in
# the future.
if (
"guidance_in.out_layer.weight" in state_dict
or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict
):
# For flux, this is a key in invokeai.backend.flux.util.params
# Due to model type and format being the descriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
config_file = "flux-dev"
if variant_type == ModelVariantType.Normal:
config_file = "flux-dev"
elif variant_type == ModelVariantType.Inpaint:
config_file = "flux-dev-fill"
else:
raise ValueError(f"Unexpected FLUX variant type: {variant_type}")
else:
# For flux, this is a key in invokeai.backend.flux.util.params
# Due to model type and format being the discriminator for model configs this
# is used rather than attempting to support flux with separate model types and format
# If changed in the future, please fix me
config_file = "flux-schnell"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type]
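A condensed sketch of the selection logic above for FLUX checkpoints; the returned string is later used as a key into `invokeai.backend.flux.util.params`:

```python
def pick_flux_config_file(has_guidance_keys: bool, variant: ModelVariantType) -> str:
    # Mirrors the branch above: guidance keys distinguish dev-family models from schnell,
    # and the Inpaint variant selects the FLUX Fill parameter set.
    if not has_guidance_keys:
        return "flux-schnell"
    if variant is ModelVariantType.Normal:
        return "flux-dev"
    if variant is ModelVariantType.Inpaint:
        return "flux-dev-fill"
    raise ValueError(f"Unexpected FLUX variant type: {variant}")
```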
@@ -482,9 +485,11 @@ class ModelProbe(object):
and option to exit if an infected file is identified.
"""
# scan model
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
scan_result = pscan.scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
if scan_result.scan_err:
raise Exception(f"Error scanning model {model_name} for malware. Aborting import.")
# Probing utilities
@@ -552,9 +557,27 @@ class CheckpointProbeBase(ProbeBase):
def get_variant_type(self) -> ModelVariantType:
model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint)
base_type = self.get_base_type()
if model_type != ModelType.Main or base_type == BaseModelType.Flux:
if model_type != ModelType.Main:
return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
if base_type == BaseModelType.Flux:
in_channels = state_dict["img_in.weight"].shape[1]
# FLUX Model variant types are distinguished by input channels:
# - Unquantized Dev and Schnell have in_channels=64
# - BNB-NF4 Dev and Schnell have in_channels=1
# - FLUX Fill has in_channels=384
# - Unsure of quantized FLUX Fill models
# - Unsure of GGUF-quantized models
if in_channels == 384:
# This is a FLUX Fill model. FLUX Fill needs special handling throughout the application. The variant
# type is used to determine whether to use the fill model or the base model.
return ModelVariantType.Inpaint
else:
# Fall back on "normal" variant type for all other FLUX models.
return ModelVariantType.Normal
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
if in_channels == 9:
return ModelVariantType.Inpaint
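The FLUX channel-count heuristic described in the comments above, as a standalone sketch (the quantized FLUX Fill and GGUF cases are noted as open questions in the source):

```python
def flux_variant_from_in_channels(in_channels: int) -> ModelVariantType:
    # 64 -> unquantized Dev/Schnell; 1 -> BNB-NF4 Dev/Schnell; 384 -> FLUX Fill
    if in_channels == 384:
        return ModelVariantType.Inpaint
    return ModelVariantType.Normal
```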
@@ -767,6 +790,11 @@ class FluxReduxCheckpointProbe(CheckpointProbeBase):
return BaseModelType.Flux
class LlavaOnevisionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
########################################################
# classes for probing folders
#######################################################
@@ -1047,6 +1075,11 @@ class FluxReduxFolderProbe(FolderProbeBase):
raise NotImplementedError()
class LlaveOnevisionFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
@@ -1082,6 +1115,7 @@ ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderPro
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SigLIP, SigLIPFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@@ -1095,5 +1129,6 @@ ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpoi
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.SigLIP, SigLIPCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.FluxRedux, FluxReduxCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.LlavaOnevision, LlavaOnevisionCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

View File

@@ -0,0 +1,32 @@
from pathlib import Path
from typing import Optional
from invokeai.backend.llava_onevision_model import LlavaOnevisionModel
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LlavaOnevision, format=ModelFormat.Diffusers)
class LlavaOnevisionModelLoader(ModelLoader):
"""Class for loading LLaVA Onevision VLLM models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("Unexpected submodel requested for LLaVA OneVision model.")
model_path = Path(config.path)
model = LlavaOnevisionModel.load_from_path(model_path)
model.to(dtype=self._torch_dtype)
return model

View File

@@ -614,6 +614,26 @@ flux_redux = StarterModel(
)
# endregion
# region LlavaOnevisionModel
llava_onevision = StarterModel(
name="LLaVA Onevision Qwen2 0.5B",
base=BaseModelType.Any,
source="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
description="LLaVA Onevision VLLM model",
type=ModelType.LlavaOnevision,
)
# endregion
# region FLUX Fill
flux_fill = StarterModel(
name="FLUX Fill",
base=BaseModelType.Flux,
source="black-forest-labs/FLUX.1-Fill-dev::flux1-fill-dev.safetensors",
description="FLUX Fill model (for inpainting).",
type=ModelType.Main,
)
# endregion
# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
@@ -683,6 +703,8 @@ STARTER_MODELS: list[StarterModel] = [
clip_l_encoder,
siglip,
flux_redux,
llava_onevision,
flux_fill,
]
sd1_bundle: list[StarterModel] = [
@@ -731,6 +753,7 @@ flux_bundle: list[StarterModel] = [
flux_canny_control_lora,
flux_depth_control_lora,
flux_redux,
flux_fill,
]
STARTER_BUNDLES: dict[str, list[StarterModel]] = {

View File

@@ -1,12 +1,12 @@
"""Utilities for parsing model files, used mostly by probe.py"""
"""Utilities for parsing model files, used mostly by legacy_probe.py"""
import json
from pathlib import Path
from typing import Dict, Optional, Union
import picklescan.scanner as pscan
import safetensors
import torch
from picklescan.scanner import scan_file_path
from invokeai.backend.model_manager.config import ClipVariantType
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
@@ -57,9 +57,12 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str,
checkpoint = gguf_sd_loader(Path(path), compute_dtype=torch.float32)
else:
if scan:
scan_result = scan_file_path(path)
if scan_result.infected_files != 0 or scan_result.scan_err:
raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.')
scan_result = pscan.scan_file_path(path)
if scan_result.infected_files != 0:
raise Exception(f"The model at {path} is potentially infected by malware. Aborting import.")
if scan_result.scan_err:
raise Exception(f"Error scanning model at {path} for malware. Aborting import.")
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint

View File

@@ -114,7 +114,9 @@
"layout": "Layout",
"board": "Ordner",
"combinatorial": "Kombinatorisch",
"saveChanges": "Änderungen speichern"
"saveChanges": "Änderungen speichern",
"error_withCount_one": "{{count}} Fehler",
"error_withCount_other": "{{count}} Fehler"
},
"gallery": {
"galleryImageSize": "Bildgröße",
@@ -764,10 +766,10 @@
"layerCopiedToClipboard": "Ebene in die Zwischenablage kopiert",
"sentToCanvas": "An Leinwand gesendet",
"problemDeletingWorkflow": "Problem beim Löschen des Arbeitsablaufs",
"uploadFailedInvalidUploadDesc_withCount_one": "Es darf maximal 1 PNG- oder JPEG-Bild sein.",
"uploadFailedInvalidUploadDesc_withCount_other": "Es dürfen maximal {{count}} PNG- oder JPEG-Bilder sein.",
"uploadFailedInvalidUploadDesc_withCount_one": "Darf maximal 1 PNG-, JPEG- oder WEBP-Bild sein.",
"uploadFailedInvalidUploadDesc_withCount_other": "Dürfen maximal {{count}} PNG-, JPEG- oder WEBP-Bild sein.",
"problemRetrievingWorkflow": "Problem beim Abrufen des Arbeitsablaufs",
"uploadFailedInvalidUploadDesc": "Müssen PNG- oder JPEG-Bilder sein.",
"uploadFailedInvalidUploadDesc": "Müssen PNG-, JPEG- oder WEBP-Bilder sein.",
"pasteSuccess": "Eingefügt in {{destination}}",
"pasteFailed": "Einfügen fehlgeschlagen",
"unableToCopy": "Kopieren nicht möglich",
@@ -1259,7 +1261,6 @@
"nodePack": "Knoten-Pack",
"loadWorkflow": "Lade Workflow",
"snapToGrid": "Am Gitternetz einrasten",
"unknownOutput": "Unbekannte Ausgabe: {{name}}",
"updateNode": "Knoten updaten",
"edge": "Rand / Kante",
"sourceNodeDoesNotExist": "Ungültiger Rand: Quell- / Ausgabe-Knoten {{node}} existiert nicht",
@@ -1325,7 +1326,9 @@
"description": "Beschreibung",
"loadWorkflowDesc": "Arbeitsablauf laden?",
"loadWorkflowDesc2": "Ihr aktueller Arbeitsablauf enthält nicht gespeicherte Änderungen.",
"loadingTemplates": "Lade {{name}}"
"loadingTemplates": "Lade {{name}}",
"missingSourceOrTargetHandle": "Fehlender Quell- oder Zielgriff",
"missingSourceOrTargetNode": "Fehlender Quell- oder Zielknoten"
},
"hrf": {
"enableHrf": "Korrektur für hohe Auflösungen",

View File

@@ -194,7 +194,9 @@
"combinatorial": "Combinatorial",
"layout": "Layout",
"row": "Row",
"column": "Column"
"column": "Column",
"value": "Value",
"label": "Label"
},
"hrf": {
"hrf": "High Resolution Fix",
@@ -846,6 +848,7 @@
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"controlLora": "Control LoRA",
"llavaOnevision": "LLaVA OneVision",
"syncModels": "Sync Models",
"textualInversions": "Textual Inversions",
"triggerPhrases": "Trigger Phrases",
@@ -1013,7 +1016,10 @@
"unknownNodeType": "Unknown node type",
"unknownTemplate": "Unknown Template",
"unknownInput": "Unknown input: {{name}}",
"unknownOutput": "Unknown output: {{name}}",
"missingField_withName": "Missing field \"{{name}}\"",
"unexpectedField_withName": "Unexpected field \"{{name}}\"",
"unknownField_withName": "Unknown field \"{{name}}\"",
"unknownFieldEditWorkflowToFix_withName": "Workflow contains an unknown field \"{{name}}\".\nEdit the workflow to fix the issue.",
"updateNode": "Update Node",
"updateApp": "Update App",
"loadingTemplates": "Loading {{name}}",
@@ -1298,7 +1304,8 @@
"problemDeletingWorkflow": "Problem Deleting Workflow",
"unableToCopy": "Unable to Copy",
"unableToCopyDesc": "Your browser does not support clipboard access. Firefox users may be able to fix this by following ",
"unableToCopyDesc_theseSteps": "these steps"
"unableToCopyDesc_theseSteps": "these steps",
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks."
},
"popovers": {
"clipSkip": {
@@ -1700,6 +1707,7 @@
"shared": "Shared",
"browseWorkflows": "Browse Workflows",
"deselectAll": "Deselect All",
"recommended": "Recommended For You",
"opened": "Opened",
"openWorkflow": "Open Workflow",
"updated": "Updated",
@@ -1738,6 +1746,7 @@
"openLibrary": "Open Library",
"workflowThumbnail": "Workflow Thumbnail",
"saveChanges": "Save Changes",
"emptyStringPlaceholder": "<empty string>",
"builder": {
"deleteAllElements": "Delete All Form Elements",
"resetAllNodeFields": "Reset All Node Fields",
@@ -1762,6 +1771,9 @@
"singleLine": "Single Line",
"multiLine": "Multi Line",
"slider": "Slider",
"dropdown": "Dropdown",
"addOption": "Add Option",
"resetOptions": "Reset Options",
"both": "Both",
"emptyRootPlaceholderViewMode": "Click Edit to start building a form for this workflow.",
"emptyRootPlaceholderEditMode": "Drag a form element or node field here to get started.",
@@ -1948,7 +1960,8 @@
"rgNegativePromptNotSupported": "Negative Prompt not supported for selected base model",
"rgReferenceImagesNotSupported": "regional Reference Images not supported for selected base model",
"rgAutoNegativeNotSupported": "Auto-Negative not supported for selected base model",
"rgNoRegion": "no region drawn"
"rgNoRegion": "no region drawn",
"fluxFillIncompatibleWithControlLoRA": "Control LoRA is not compatible with FLUX Fill"
},
"errors": {
"unableToFindImage": "Unable to find image",
@@ -2331,7 +2344,7 @@
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"Workflows: New and improved Workflow Library.",
"FLUX: Support for FLUX Redux in Workflows and Canvas."
"FLUX: Support for FLUX Redux & FLUX Fill in Workflows and Canvas."
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",

View File

@@ -1653,7 +1653,6 @@
"collectionFieldType": "{{name}} (Collection)",
"newWorkflow": "Nouveau Workflow",
"reorderLinearView": "Réorganiser la vue linéaire",
"unknownOutput": "Sortie inconnue : {{name}}",
"outputFieldTypeParseError": "Impossible d'analyser le type du champ de sortie {{node}}.{{field}} ({{message}})",
"unsupportedMismatchedUnion": "type CollectionOrScalar non concordant avec les types de base {{firstType}} et {{secondType}}",
"unableToParseFieldType": "impossible d'analyser le type de champ",

View File

@@ -113,7 +113,9 @@
"saveChanges": "Salva modifiche",
"error_withCount_one": "{{count}} errore",
"error_withCount_many": "{{count}} errori",
"error_withCount_other": "{{count}} errori"
"error_withCount_other": "{{count}} errori",
"value": "Valore",
"label": "Etichetta"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -779,7 +781,8 @@
"enableModelDescriptions": "Abilita le descrizioni dei modelli nei menu a discesa",
"modelDescriptionsDisabled": "Descrizioni dei modelli nei menu a discesa disabilitate",
"modelDescriptionsDisabledDesc": "Le descrizioni dei modelli nei menu a discesa sono state disabilitate. Abilitale nelle Impostazioni.",
"showDetailedInvocationProgress": "Mostra dettagli avanzamento"
"showDetailedInvocationProgress": "Mostra dettagli avanzamento",
"enableHighlightFocusedRegions": "Evidenzia le regioni interessate"
},
"toast": {
"uploadFailed": "Caricamento fallito",
@@ -847,7 +850,8 @@
"pasteSuccess": "Incollato su {{destination}}",
"unableToCopy": "Impossibile copiare",
"unableToCopyDesc": "Il tuo browser non supporta l'accesso agli appunti. Gli utenti di Firefox potrebbero risolvere il problema seguendo ",
"unableToCopyDesc_theseSteps": "questi passaggi"
"unableToCopyDesc_theseSteps": "questi passaggi",
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill non è compatibile con Testo a Immagine o Immagine a Immagine. Per queste attività, utilizzare altri modelli FLUX."
},
"accessibility": {
"invokeProgressBar": "Barra di avanzamento generazione",
@@ -968,7 +972,6 @@
"unableToGetWorkflowVersion": "Impossibile ottenere la versione dello schema del flusso di lavoro",
"nodePack": "Pacchetto di nodi",
"unableToExtractSchemaNameFromRef": "Impossibile estrarre il nome dello schema dal riferimento",
"unknownOutput": "Output sconosciuto: {{name}}",
"unknownNodeType": "Tipo di nodo sconosciuto",
"targetNodeDoesNotExist": "Connessione non valida: il nodo di destinazione/input {{node}} non esiste",
"unknownFieldType": "$t(nodes.unknownField) tipo: {{type}}",
@@ -1038,7 +1041,11 @@
"generatorImages_many": "{{count}} immagini",
"generatorImages_other": "{{count}} immagini",
"generatorImagesFromBoard": "Immagini dalla Bacheca",
"missingSourceOrTargetNode": "Nodo sorgente o di destinazione mancante"
"missingSourceOrTargetNode": "Nodo sorgente o di destinazione mancante",
"unknownField_withName": "Campo \"{{name}}\" sconosciuto",
"missingField_withName": "Campo \"{{name}}\" mancante",
"unknownFieldEditWorkflowToFix_withName": "Il flusso di lavoro contiene un campo \"{{name}}\" sconosciuto .\nModifica il flusso di lavoro per risolvere il problema.",
"unexpectedField_withName": "Campo \"{{name}}\" inaspettato"
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
@@ -1776,7 +1783,12 @@
"text": "Testo",
"numberInput": "Ingresso numerico",
"containerRowLayout": "Contenitore (disposizione riga)",
"containerColumnLayout": "Contenitore (disposizione colonna)"
"containerColumnLayout": "Contenitore (disposizione colonna)",
"minimum": "Minimo",
"maximum": "Massimo",
"dropdown": "Elenco a discesa",
"addOption": "Aggiungi opzione",
"resetOptions": "Reimposta opzioni"
},
"loadMore": "Carica altro",
"searchPlaceholder": "Cerca per nome, descrizione o etichetta",
@@ -1791,7 +1803,9 @@
"private": "Privato",
"deselectAll": "Deseleziona tutto",
"noRecentWorkflows": "Nessun flusso di lavoro recente",
"view": "Visualizza"
"view": "Visualizza",
"recommended": "Consigliato per te",
"emptyStringPlaceholder": "<stringa vuota>"
},
"accordions": {
"compositing": {
@@ -2235,7 +2249,8 @@
"rgNegativePromptNotSupported": "Prompt negativo non supportato per il modello base selezionato",
"ipAdapterIncompatibleBaseModel": "modello base dell'immagine di riferimento incompatibile",
"ipAdapterNoImageSelected": "nessuna immagine di riferimento selezionata",
"rgAutoNegativeNotSupported": "Auto-Negativo non supportato per il modello base selezionato"
"rgAutoNegativeNotSupported": "Auto-Negativo non supportato per il modello base selezionato",
"fluxFillIncompatibleWithControlLoRA": "Il controllo LoRA non è compatibile con FLUX Fill"
},
"pasteTo": "Incolla su",
"pasteToBboxDesc": "Nuovo livello (nel riquadro di delimitazione)",
@@ -2350,8 +2365,8 @@
"watchRecentReleaseVideos": "Guarda i video su questa versione",
"watchUiUpdatesOverview": "Guarda le novità dell'interfaccia",
"items": [
"Gestione della memoria: nuova impostazione per gli utenti con GPU Nvidia per ridurre l'utilizzo della VRAM.",
"Prestazioni: continui miglioramenti alle prestazioni e alla reattività complessive dell'applicazione."
"Flussi di lavoro: nuova e migliorata libreria dei flussi di lavoro.",
"FLUX: supporto per FLUX Redux e FLUX Fill in Flussi di lavoro e Tela."
]
},
"system": {

View File

@@ -425,7 +425,6 @@
"newWorkflow": "Nieuwe werkstroom",
"unknownErrorValidatingWorkflow": "Onbekende fout bij valideren werkstroom",
"unsupportedAnyOfLength": "te veel union-leden ({{count}})",
"unknownOutput": "Onbekende uitvoer: {{name}}",
"viewMode": "Gebruik in lineaire weergave",
"unableToExtractSchemaNameFromRef": "fout bij het extraheren van de schemanaam via de ref",
"unsupportedMismatchedUnion": "niet-overeenkomende soort CollectionOrScalar met basissoorten {{firstType}} en {{secondType}}",

View File

@@ -879,7 +879,6 @@
"unableToExtractSchemaNameFromRef": "невозможно извлечь имя схемы из ссылки",
"executionStateError": "Ошибка",
"prototypeDesc": "Этот вызов является прототипом. Он может претерпевать изменения при обновлении приложения и может быть удален в любой момент.",
"unknownOutput": "Неизвестный вывод: {{name}}",
"executionStateCompleted": "Выполнено",
"node": "Узел",
"workflowAuthor": "Автор",

View File

@@ -236,7 +236,8 @@
"layout": "Bố Cục",
"row": "Hàng",
"board": "Bảng",
"saveChanges": "Lưu Thay Đổi"
"saveChanges": "Lưu Thay Đổi",
"error_withCount_other": "{{count}} lỗi"
},
"prompt": {
"addPromptTrigger": "Thêm Prompt Trigger",
@@ -769,7 +770,8 @@
"urlForbiddenErrorMessage": "Bạn có thể cần yêu cầu quyền truy cập từ trang web đang cung cấp model.",
"urlUnauthorizedErrorMessage": "Bạn có thể cần thiếp lập một token API để dùng được model này.",
"fluxRedux": "FLUX Redux",
"sigLip": "SigLIP"
"sigLip": "SigLIP",
"llavaOnevision": "LLaVA OneVision"
},
"metadata": {
"guidance": "Hướng Dẫn",
@@ -892,7 +894,6 @@
"targetNodeFieldDoesNotExist": "Kết nối không phù hợp: đích đến/đầu vào của vùng {{node}}.{{field}} không tồn tại",
"missingTemplate": "Node không hợp lệ: node {{node}} thuộc loại {{type}} bị thiếu mẫu trình bày (chưa tải?)",
"unsupportedMismatchedUnion": "Dạng số lượng dữ liệu không khớp với {{firstType}} và {{secondType}}",
"unknownOutput": "Đầu Ra Không Rõ: {{name}}",
"betaDesc": "Trình kích hoạt này vẫn trong giai đoạn beta. Cho đến khi ổn định, nó có thể phá hỏng thay đổi trong khi cập nhật ứng dụng. Chúng tôi dự định hỗ trợ trình kích hoạt này về lâu dài.",
"cannotConnectInputToInput": "Không thế kết nối đầu vào với đầu vào",
"showEdgeLabelsHelp": "Hiển thị tên trên kết nối, chỉ ra những node được kết nối",
@@ -1019,7 +1020,11 @@
"downloadWorkflowError": "Lỗi tải xuống workflow",
"generatorImagesFromBoard": "Ảnh Từ Bảng",
"generatorImagesCategory": "Phân Loại",
"generatorImages_other": "{{count}} ảnh"
"generatorImages_other": "{{count}} ảnh",
"unknownField_withName": "Vùng Dữ Liệu Không Rõ \"{{name}}\"",
"unexpectedField_withName": "Sai Vùng Dữ Liệu \"{{name}}\"",
"unknownFieldEditWorkflowToFix_withName": "Workflow chứa vùng dữ liệu không rõ \"{{name}}\".\nHãy biên tập workflow để sửa lỗi.",
"missingField_withName": "Thiếu Vùng Dữ Liệu \"{{name}}\""
},
"popovers": {
"paramCFGRescaleMultiplier": {
@@ -1618,7 +1623,8 @@
"displayInProgress": "Hiển Thị Hình Ảnh Đang Xử Lý",
"intermediatesClearedFailed": "Có Vấn Đề Khi Dọn Sạch Sản Phẩm Trung Gian",
"enableInvisibleWatermark": "Bật Chế Độ Ẩn Watermark",
"showDetailedInvocationProgress": "Hiện Dữ Liệu Xử Lý"
"showDetailedInvocationProgress": "Hiện Dữ Liệu Xử Lý",
"enableHighlightFocusedRegions": "Nhấn Mạnh Khu Vực Chỉ Định"
},
"sdxl": {
"loading": "Đang Tải...",
@@ -2048,7 +2054,8 @@
"rgNegativePromptNotSupported": "Lệnh Tiêu Cực không được hỗ trợ cho model cơ sở được chọn",
"rgReferenceImagesNotSupported": "Ảnh Mẫu Khu Vực không được hỗ trợ cho model cơ sở được chọn",
"rgAutoNegativeNotSupported": "Tự Động Đảo Chiều không được hỗ trợ cho model cơ sở được chọn",
"rgNoRegion": "không có khu vực được vẽ"
"rgNoRegion": "không có khu vực được vẽ",
"fluxFillIncompatibleWithControlLoRA": "LoRA Điều Khiển Được không tương tích với FLUX Fill"
},
"pasteTo": "Dán Vào",
"pasteToAssets": "Tài Nguyên",
@@ -2199,7 +2206,8 @@
"unableToCopyDesc_theseSteps": "các bước sau",
"unableToCopyDesc": "Trình duyệt của bạn không hỗ trợ tính năng clipboard. Người dùng Firefox có thể khắc phục theo ",
"pasteSuccess": "Dán Vào {{destination}}",
"pasteFailed": "Dán Thất Bại"
"pasteFailed": "Dán Thất Bại",
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill không tương tích với Từ Ngữ Sang Hình Ảnh và Hình Ảnh Sang Hình Ảnh. Dùng model FLUX khác cho các tính năng này."
},
"ui": {
"tabs": {
@@ -2288,7 +2296,11 @@
"container": "Hộp Chứa",
"heading": "Đầu Dòng",
"text": "Văn Bản",
"divider": "Gạch Chia"
"divider": "Gạch Chia",
"minimum": "Tối Thiểu",
"maximum": "Tối Đa",
"containerRowLayout": "Hộp Chứa (bố cục hàng)",
"containerColumnLayout": "Hộp Chứa (bố cục cột)"
},
"yourWorkflows": "Workflow Của Bạn",
"browseWorkflows": "Khám Phá Workflow",
@@ -2300,7 +2312,11 @@
"filterByTags": "Lọc Theo Nhãn",
"recentlyOpened": "Mở Gần Đây",
"private": "Cá Nhân",
"loadMore": "Tải Thêm"
"loadMore": "Tải Thêm",
"view": "Xem",
"deselectAll": "Huỷ Chọn Tất Cả",
"noRecentWorkflows": "Không Có Workflows Gần Đây",
"recommended": "Có Thể Bạn Sẽ Cần"
},
"upscaling": {
"missingUpscaleInitialImage": "Thiếu ảnh dùng để upscale",
@@ -2327,7 +2343,8 @@
"gettingStartedSeries": "Cần thêm hướng dẫn? Xem thử <LinkComponent>Bắt Đầu Làm Quen</LinkComponent> để biết thêm mẹo khai thác toàn bộ tiềm năng của Invoke Studio.",
"toGetStarted": "Để bắt đầu, hãy nhập lệnh vào hộp và nhấp chuột vào <StrongComponent>Kích Hoạt</StrongComponent> để tạo ra bức ảnh đầu tiên. Chọn một mẫu trình bày cho lệnh để cải thiện kết quả. Bạn có thể chọn để lưu ảnh trực tiếp vào <StrongComponent>Thư Viện Ảnh</StrongComponent> hoặc chỉnh sửa chúng ở <StrongComponent>Canvas</StrongComponent>.",
"noModelsInstalled": "Dường như bạn chưa tải model nào cả! Bạn có thể <DownloadStarterModelsButton>tải xuống các model khởi đầu</DownloadStarterModelsButton> hoặc <ImportModelsButton>nhập vào thêm model</ImportModelsButton>.",
"lowVRAMMode": "Cho hiệu suất tốt nhất, hãy làm theo <LinkComponent>hướng dẫn VRAM Thấp</LinkComponent> của chúng tôi."
"lowVRAMMode": "Cho hiệu suất tốt nhất, hãy làm theo <LinkComponent>hướng dẫn VRAM Thấp</LinkComponent> của chúng tôi.",
"toGetStartedWorkflow": "Để bắt đầu, hãy điền vào khu vực bên trái và bấm <StrongComponent>Kích Hoạt</StrongComponent> nhằm tạo sinh ảnh. Muốn khám phá thêm workflow? Nhấp vào <StrongComponent>icon thư mục</StrongComponent> nằm cạnh tiêu đề workflow để xem một dãy các mẫu trình bày khác."
},
"whatsNew": {
"whatsNewInInvoke": "Có Gì Mới Ở Invoke",
@@ -2335,8 +2352,8 @@
"watchRecentReleaseVideos": "Xem Video Phát Hành Mới Nhất",
"watchUiUpdatesOverview": "Xem Tổng Quan Về Những Cập Nhật Cho Giao Diện Người Dùng",
"items": [
"Trình Quản Lý Bộ Nhớ: Thiết lập mới cho người dùng với GPU Nvidia để giảm lượng VRAM sử dụng.",
"Hiệu suất: Các cải thiện tiếp theo nhằm gói gọn hiệu suất và khả năng phản hồi của ứng dụng."
"Workflow: Thư Viện Workflow mới và đã được cải tiến.",
"FLUX: Hỗ trợ FLUX Redux & FLUX Fill trong Workflow và Canvas."
]
},
"upsell": {


@@ -908,7 +908,6 @@
"unableToGetWorkflowVersion": "无法获取工作流架构版本",
"nodePack": "节点包",
"unableToExtractSchemaNameFromRef": "无法从参考中提取架构名",
"unknownOutput": "未知输出:{{name}}",
"unknownErrorValidatingWorkflow": "验证工作流时出现未知错误",
"collectionFieldType": "{{name}}(合集)",
"unknownNodeType": "未知节点类型",


@@ -12,6 +12,12 @@ import { sentImageToCanvas } from 'features/gallery/store/actions';
import { parseAndRecallAllMetadata } from 'features/metadata/util/handlers';
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
import { $isWorkflowLibraryModalOpen } from 'features/nodes/store/workflowLibraryModal';
import {
$workflowLibraryTagOptions,
workflowLibraryTagsReset,
workflowLibraryTagToggled,
workflowLibraryViewChanged,
} from 'features/nodes/store/workflowLibrarySlice';
import { $isStylePresetsMenuOpen, activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { toast } from 'features/toast/toast';
import { activeTabCanvasRightPanelChanged, setActiveTab } from 'features/ui/store/uiSlice';
@@ -30,9 +36,17 @@ type SendToCanvasAction = _StudioInitAction<'sendToCanvas', { imageName: string
type UseAllParametersAction = _StudioInitAction<'useAllParameters', { imageName: string }>;
type StudioDestinationAction = _StudioInitAction<
'goToDestination',
{ destination: 'generation' | 'canvas' | 'workflows' | 'upscaling' | 'viewAllWorkflows' | 'viewAllStylePresets' }
{
destination:
| 'generation'
| 'canvas'
| 'workflows'
| 'upscaling'
| 'viewAllWorkflows'
| 'viewAllWorkflowsRecommended'
| 'viewAllStylePresets';
}
>;
// Use global state to show loader until we are ready to render the studio.
export const $didStudioInit = atom(false);
@@ -58,6 +72,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
const didParseOpenAPISchema = useStore($hasTemplates);
const store = useAppStore();
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
const workflowLibraryTagOptions = useStore($workflowLibraryTagOptions);
const handleSendToCanvas = useCallback(
async (imageName: string) => {
@@ -173,6 +188,18 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
store.dispatch(setActiveTab('workflows'));
$isWorkflowLibraryModalOpen.set(true);
break;
case 'viewAllWorkflowsRecommended':
// Go to the workflows tab and open the workflow library modal with the recommended workflows view
store.dispatch(setActiveTab('workflows'));
$isWorkflowLibraryModalOpen.set(true);
store.dispatch(workflowLibraryViewChanged('defaults'));
store.dispatch(workflowLibraryTagsReset());
for (const tag of workflowLibraryTagOptions) {
if (tag.recommended) {
store.dispatch(workflowLibraryTagToggled(tag.label));
}
}
break;
case 'viewAllStylePresets':
// Go to the canvas tab and open the style presets menu
store.dispatch(setActiveTab('canvas'));
@@ -180,7 +207,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
break;
}
},
[store]
[store, workflowLibraryTagOptions]
);
const handleStudioInitAction = useCallback(


@@ -27,7 +27,8 @@ export type AppFeature =
| 'bulkDownload'
| 'starterModels'
| 'hfToken'
| 'retryQueueItem';
| 'retryQueueItem'
| 'cancelAndClearAll';
/**
* A disable-able Stable Diffusion feature
*/


@@ -19,7 +19,7 @@ const styles: CSSProperties = { position: 'absolute', top: 0, left: 0, right: 0,
const ScrollableContent = ({ children, maxHeight, overflowX = 'hidden', overflowY = 'scroll' }: Props) => {
const overlayscrollbarsOptions = useMemo(
() => getOverlayScrollbarsParams(overflowX, overflowY).options,
() => getOverlayScrollbarsParams({ overflowX, overflowY }).options,
[overflowX, overflowY]
);
const [os, osRef] = useState<OverlayScrollbarsComponentRef | null>(null);


@@ -20,12 +20,22 @@ export const overlayScrollbarsParams: UseOverlayScrollbarsParams = {
},
};
export const getOverlayScrollbarsParams = (
overflowX: 'hidden' | 'scroll' = 'hidden',
overflowY: 'hidden' | 'scroll' = 'scroll'
) => {
export const getOverlayScrollbarsParams = ({
overflowX = 'hidden',
overflowY = 'scroll',
visibility = 'auto',
}: {
overflowX?: 'hidden' | 'scroll';
overflowY?: 'hidden' | 'scroll';
visibility?: 'auto' | 'hidden' | 'visible';
}) => {
const params = deepClone(overlayScrollbarsParams);
merge(params, { options: { overflow: { y: overflowY, x: overflowX } } });
merge(params, {
options: {
overflow: { y: overflowY, x: overflowX },
scrollbars: { visibility, autoHide: visibility === 'visible' ? 'never' : 'scroll' },
},
});
return params;
};
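
getOverlayScrollbarsParams now takes a single options object instead of positional arguments, so call sites can opt into the new visibility setting without restating the defaults. A minimal usage sketch (imports omitted; variable names are illustrative, but the calls mirror the updated call sites later in this changeset):

// Defaults: overflowX 'hidden', overflowY 'scroll', visibility 'auto'.
const defaultOptions = getOverlayScrollbarsParams({}).options;

// Scroll on both axes, as the DataViewer call site does.
const dataViewerOptions = getOverlayScrollbarsParams({ overflowX: 'scroll', overflowY: 'scroll' }).options;

// visibility 'visible' also pins autoHide to 'never', so the scrollbar never fades out.
const pinnedOptions = getOverlayScrollbarsParams({ visibility: 'visible' }).options;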


@@ -4,7 +4,6 @@ import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityIsEnabled } from 'features/controlLayers/hooks/useEntityIsEnabled';
import { selectModel } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import {
@@ -19,11 +18,12 @@ import { upperFirst } from 'lodash-es';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiWarningBold } from 'react-icons/pi';
import { selectMainModelConfig } from 'services/api/endpoints/models';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const buildSelectWarnings = (entityIdentifier: CanvasEntityIdentifier, t: TFunction) => {
return createSelector(selectCanvasSlice, selectModel, (canvas, model) => {
return createSelector(selectCanvasSlice, selectMainModelConfig, (canvas, model) => {
// This component is used within a <CanvasEntityStateGate /> so we can safely assume that the entity exists.
// Should never throw.
const entity = selectEntityOrThrow(canvas, entityIdentifier, 'CanvasEntityHeaderWarnings');


@@ -301,8 +301,9 @@ export class CanvasBrushToolModule extends CanvasModuleBase {
points,
strokeWidth: settings.brushWidth,
color: this.manager.stateApi.getCurrentColor(),
// When shift is held, the line may extend beyond the clip region. No clip for these lines.
clip: isShiftDraw ? null : this.parent.getClip(selectedEntity.state),
// When shift is held, the line may extend beyond the clip region. Clip only if we are clipping to bbox. If we
// are clipping to stage, we don't need to clip at all.
clip: isShiftDraw && !settings.clipToBbox ? null : this.parent.getClip(selectedEntity.state),
});
} else {
const lastLinePoint = getLastPointOfLastLine(selectedEntity.state.objects, 'brush_line');
@@ -325,8 +326,9 @@ export class CanvasBrushToolModule extends CanvasModuleBase {
points,
strokeWidth: settings.brushWidth,
color: this.manager.stateApi.getCurrentColor(),
// When shift is held, the line may extend beyond the clip region. No clip for these lines.
clip: isShiftDraw ? null : this.parent.getClip(selectedEntity.state),
// When shift is held, the line may extend beyond the clip region. Clip only if we are clipping to bbox. If we
// are clipping to stage, we don't need to clip at all.
clip: isShiftDraw && !settings.clipToBbox ? null : this.parent.getClip(selectedEntity.state),
});
}
};


@@ -5,7 +5,7 @@ import type {
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
} from 'features/controlLayers/store/types';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import type { MainModelConfig } from 'services/api/types';
const WARNINGS = {
UNSUPPORTED_MODEL: 'controlLayers.warnings.unsupportedModel',
@@ -20,13 +20,14 @@ const WARNINGS = {
CONTROL_ADAPTER_NO_MODEL_SELECTED: 'controlLayers.warnings.controlAdapterNoModelSelected',
CONTROL_ADAPTER_INCOMPATIBLE_BASE_MODEL: 'controlLayers.warnings.controlAdapterIncompatibleBaseModel',
CONTROL_ADAPTER_NO_CONTROL: 'controlLayers.warnings.controlAdapterNoControl',
FLUX_FILL_NO_WORKY_WITH_CONTROL_LORA: 'controlLayers.warnings.fluxFillIncompatibleWithControlLoRA',
} as const;
type WarningTKey = (typeof WARNINGS)[keyof typeof WARNINGS];
export const getRegionalGuidanceWarnings = (
entity: CanvasRegionalGuidanceState,
model: ParameterModel | null
model: MainModelConfig | null | undefined
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
@@ -78,7 +79,7 @@ export const getRegionalGuidanceWarnings = (
export const getGlobalReferenceImageWarnings = (
entity: CanvasReferenceImageState,
model: ParameterModel | null
model: MainModelConfig | null | undefined
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
@@ -110,7 +111,7 @@ export const getGlobalReferenceImageWarnings = (
export const getControlLayerWarnings = (
entity: CanvasControlLayerState,
model: ParameterModel | null
model: MainModelConfig | null | undefined
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
@@ -129,6 +130,13 @@ export const getControlLayerWarnings = (
} else if (entity.controlAdapter.model.base !== model.base) {
// Supported model architecture but doesn't match
warnings.push(WARNINGS.CONTROL_ADAPTER_INCOMPATIBLE_BASE_MODEL);
} else if (
model.base === 'flux' &&
model.variant === 'inpaint' &&
entity.controlAdapter.model.type === 'control_lora'
) {
// FLUX inpaint variants are FLUX Fill models - not compatible w/ Control LoRA
warnings.push(WARNINGS.FLUX_FILL_NO_WORKY_WITH_CONTROL_LORA);
}
}
@@ -137,7 +145,7 @@ export const getControlLayerWarnings = (
export const getRasterLayerWarnings = (
_entity: CanvasRasterLayerState,
_model: ParameterModel | null
_model: MainModelConfig | null | undefined
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
@@ -148,7 +156,7 @@ export const getRasterLayerWarnings = (
export const getInpaintMaskWarnings = (
_entity: CanvasInpaintMaskState,
_model: ParameterModel | null
_model: MainModelConfig | null | undefined
): WarningTKey[] => {
const warnings: WarningTKey[] = [];
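
The branch added to getControlLayerWarnings above flags a FLUX Fill main model (base 'flux' with the 'inpaint' variant) combined with a Control LoRA adapter on a control layer. A hypothetical pair of inputs that would trigger the new warning, showing only the properties the check reads:

// Hypothetical values - only the fields inspected by getControlLayerWarnings are shown.
const model = { base: 'flux', variant: 'inpaint' };
const controlAdapterModel = { base: 'flux', type: 'control_lora' };
// With these, the returned warnings would include
// 'controlLayers.warnings.fluxFillIncompatibleWithControlLoRA'.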


@@ -21,7 +21,10 @@ type Props = {
extraCopyActions?: { label: string; getData: (data: unknown) => unknown }[];
} & FlexProps;
const overlayscrollbarsOptions = getOverlayScrollbarsParams('scroll', 'scroll').options;
const overlayscrollbarsOptions = getOverlayScrollbarsParams({
overflowX: 'scroll',
overflowY: 'scroll',
}).options;
const DataViewer = (props: Props) => {
const { label, data, fileName, withDownload = true, withCopy = true, extraCopyActions, ...rest } = props;


@@ -26,6 +26,7 @@ const useFocusRegionOptions = {
};
const FOCUS_REGION_STYLES: SystemStyleObject = {
display: 'flex',
width: 'full',
height: 'full',
position: 'absolute',
@@ -45,7 +46,7 @@ export const ImageViewer = memo(({ closeButton }: Props) => {
<FocusRegionWrapper region="viewer" sx={FOCUS_REGION_STYLES} layerStyle="first" {...useFocusRegionOptions}>
{hasImageToCompare && <CompareToolbar />}
{!hasImageToCompare && <ViewerToolbar closeButton={closeButton} />}
<Box ref={containerRef} w="full" h="full" p={2}>
<Box ref={containerRef} w="full" h="full" p={2} overflow="hidden">
{!hasImageToCompare && <CurrentImagePreview />}
{hasImageToCompare && <ImageComparison containerDims={containerDims} />}
</Box>


@@ -42,8 +42,7 @@ export class InvalidModelConfigError extends Error {
export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> => {
const { dispatch } = getStore();
try {
const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key));
req.unsubscribe();
const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key, { subscribe: false }));
return await req.unwrap();
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for key ${key}`);
@@ -62,8 +61,9 @@ export const fetchModelConfig = async (key: string): Promise<AnyModelConfig> =>
const fetchModelConfigByAttrs = async (name: string, base: BaseModelType, type: ModelType): Promise<AnyModelConfig> => {
const { dispatch } = getStore();
try {
const req = dispatch(modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type }));
req.unsubscribe();
const req = dispatch(
modelsApi.endpoints.getModelConfigByAttrs.initiate({ name, base, type }, { subscribe: false })
);
return await req.unwrap();
} catch {
throw new ModelConfigNotFoundError(`Unable to retrieve model config for name/base/type ${name}/${base}/${type}`);
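
Passing `{ subscribe: false }` as the second argument to RTK Query's `initiate` performs a one-shot fetch that never registers a cache subscription, replacing the previous subscribe-then-immediately-unsubscribe pattern. A short sketch of the two patterns, using the same names as the code above:

// Before: create a subscription, drop it right away, then await the result.
const prev = dispatch(modelsApi.endpoints.getModelConfig.initiate(key));
prev.unsubscribe();
const oldStyle = await prev.unwrap();

// After: a one-shot fetch with no subscription bookkeeping at all.
const req = dispatch(modelsApi.endpoints.getModelConfig.initiate(key, { subscribe: false }));
const newStyle = await req.unwrap();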


@@ -16,6 +16,7 @@ import {
useEmbeddingModels,
useFluxReduxModels,
useIPAdapterModels,
useLLaVAModels,
useLoRAModels,
useMainModels,
useRefinerModels,
@@ -126,6 +127,12 @@ const ModelList = () => {
[fluxReduxModels, searchTerm, filteredModelType]
);
const [llavaOneVisionModels, { isLoading: isLoadingLlavaOneVisionModels }] = useLLaVAModels();
const filteredLlavaOneVisionModels = useMemo(
() => modelsFilter(llavaOneVisionModels, searchTerm, filteredModelType),
[llavaOneVisionModels, searchTerm, filteredModelType]
);
const totalFilteredModels = useMemo(() => {
return (
filteredMainModels.length +
@@ -236,6 +243,17 @@ const ModelList = () => {
{!isLoadingClipEmbedModels && filteredClipEmbedModels.length > 0 && (
<ModelListWrapper title={t('modelManager.clipEmbed')} modelList={filteredClipEmbedModels} key="clip-embed" />
)}
{/* LLaVA OneVision List */}
{isLoadingLlavaOneVisionModels && <FetchingModelsLoader loadingMessage="Loading LLaVA OneVision Models..." />}
{!isLoadingLlavaOneVisionModels && filteredLlavaOneVisionModels.length > 0 && (
<ModelListWrapper
title={t('modelManager.llavaOnevision')}
modelList={filteredLlavaOneVisionModels}
key="llava-onevision"
/>
)}
{/* Spandrel Image to Image List */}
{isLoadingSpandrelImageToImageModels && (
<FetchingModelsLoader loadingMessage="Loading Image-to-Image Models..." />


@@ -25,8 +25,9 @@ export const ModelTypeFilter = memo(() => {
clip_vision: 'CLIP Vision',
spandrel_image_to_image: t('modelManager.spandrelImageToImage'),
control_lora: t('modelManager.controlLora'),
siglip: t('modelManager.siglip'),
siglip: t('modelManager.sigLip'),
flux_redux: t('modelManager.fluxRedux'),
llava_onevision: t('modelManager.llavaOnevision'),
}),
[t]
);


@@ -14,6 +14,7 @@ import BottomLeftPanel from './flow/panels/BottomLeftPanel/BottomLeftPanel';
import MinimapPanel from './flow/panels/MinimapPanel/MinimapPanel';
const FOCUS_REGION_STYLES: SystemStyleObject = {
display: 'flex',
position: 'relative',
width: 'full',
height: 'full',


@@ -8,7 +8,7 @@ import {
Textarea,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useInputFieldDescription } from 'features/nodes/hooks/useInputFieldDescription';
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
import { fieldDescriptionChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { ChangeEvent } from 'react';
@@ -48,7 +48,7 @@ InputFieldDescriptionPopover.displayName = 'InputFieldDescriptionPopover';
const Content = memo(({ nodeId, fieldName }: Props) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const description = useInputFieldDescription(nodeId, fieldName);
const description = useInputFieldDescriptionSafe(nodeId, fieldName);
const onChange = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(fieldDescriptionChanged({ nodeId, fieldName, val: e.target.value }));


@@ -7,7 +7,7 @@ import { InputFieldResetToDefaultValueIconButton } from 'features/nodes/componen
import { useNodeFieldDnd } from 'features/nodes/components/sidePanel/builder/dnd-hooks';
import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected';
import { useInputFieldIsInvalid } from 'features/nodes/hooks/useInputFieldIsInvalid';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { NO_DRAG_CLASS } from 'features/nodes/types/constants';
import type { FieldInputTemplate } from 'features/nodes/types/field';
import { memo, useRef } from 'react';
@@ -22,7 +22,7 @@ interface Props {
}
export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => {
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
const isInvalid = useInputFieldIsInvalid(nodeId, fieldName);
const isConnected = useInputFieldIsConnected(nodeId, fieldName);


@@ -1,23 +1,83 @@
import { InputFieldUnknownPlaceholder } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldUnknownPlaceholder';
import { Flex, Text } from '@invoke-ai/ui-library';
import { InputFieldWrapper } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldWrapper';
import { useInputFieldInstanceExists } from 'features/nodes/hooks/useInputFieldInstanceExists';
import { useInputFieldNameSafe } from 'features/nodes/hooks/useInputFieldNameSafe';
import { useInputFieldTemplateExists } from 'features/nodes/hooks/useInputFieldTemplateExists';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
import type { PropsWithChildren, ReactNode } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type Props = PropsWithChildren<{
nodeId: string;
fieldName: string;
fallback?: ReactNode;
formatLabel?: (name: string) => string;
}>;
export const InputFieldGate = memo(({ nodeId, fieldName, children }: Props) => {
export const InputFieldGate = memo(({ nodeId, fieldName, children, fallback, formatLabel }: Props) => {
const hasInstance = useInputFieldInstanceExists(nodeId, fieldName);
const hasTemplate = useInputFieldTemplateExists(nodeId, fieldName);
if (!hasTemplate || !hasInstance) {
return <InputFieldUnknownPlaceholder nodeId={nodeId} fieldName={fieldName} />;
// fallback may be null, indicating we should render nothing at all - must check for undefined explicitly
if (fallback !== undefined) {
return fallback;
}
return (
<Fallback
nodeId={nodeId}
fieldName={fieldName}
formatLabel={formatLabel}
hasInstance={hasInstance}
hasTemplate={hasTemplate}
/>
);
}
return children;
});
InputFieldGate.displayName = 'InputFieldGate';
const Fallback = memo(
({
nodeId,
fieldName,
formatLabel,
hasTemplate,
hasInstance,
}: {
nodeId: string;
fieldName: string;
formatLabel?: (name: string) => string;
hasTemplate: boolean;
hasInstance: boolean;
}) => {
const { t } = useTranslation();
const name = useInputFieldNameSafe(nodeId, fieldName);
const label = useMemo(() => {
if (formatLabel) {
return formatLabel(name);
}
if (hasTemplate && !hasInstance) {
return t('nodes.missingField_withName', { name });
}
if (!hasTemplate && hasInstance) {
return t('nodes.unexpectedField_withName', { name });
}
return t('nodes.unknownField_withName', { name });
}, [formatLabel, hasInstance, hasTemplate, name, t]);
return (
<InputFieldWrapper>
<Flex w="full" px={1} py={1} justifyContent="center">
<Text fontWeight="semibold" color="error.300" whiteSpace="pre" textAlign="center">
{label}
</Text>
</Flex>
</InputFieldWrapper>
);
}
);
Fallback.displayName = 'Fallback';
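
InputFieldGate now accepts an optional fallback and a formatLabel callback. When the field's template or instance is missing, it renders a localized "missing / unexpected / unknown field" placeholder unless a fallback is supplied. A minimal sketch of the two usage patterns in this changeset (imports omitted; nodeId and fieldName are assumed to identify a node field):

// Default behaviour: show the localized placeholder when the field cannot be resolved.
const GatedField = ({ nodeId, fieldName }: { nodeId: string; fieldName: string }) => (
  <InputFieldGate nodeId={nodeId} fieldName={fieldName}>
    <InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
  </InputFieldGate>
);

// fallback={null} (explicitly null, not undefined) renders nothing at all when the
// field is missing - used for the builder header buttons later in this changeset.
const GatedOptionalControls = ({ nodeId, fieldName }: { nodeId: string; fieldName: string }) => (
  <InputFieldGate nodeId={nodeId} fieldName={fieldName} fallback={null}>
    <InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
  </InputFieldGate>
);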


@@ -7,7 +7,7 @@ import {
useIsConnectionInProgress,
useIsConnectionStartField,
} from 'features/nodes/hooks/useFieldConnectionState';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants';
import type { FieldInputTemplate } from 'features/nodes/types/field';
@@ -62,7 +62,7 @@ const handleStyles = {
} satisfies CSSProperties;
export const InputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]);
const isModelField = useMemo(() => isModelFieldType(fieldTemplate.type), [fieldTemplate.type]);


@@ -13,10 +13,11 @@ import { StringGeneratorFieldInputComponent } from 'features/nodes/components/fl
import { IntegerFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/IntegerField/IntegerFieldInput';
import { IntegerFieldInputAndSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/IntegerField/IntegerFieldInputAndSlider';
import { IntegerFieldSlider } from 'features/nodes/components/flow/nodes/Invocation/fields/IntegerField/IntegerFieldSlider';
import { StringFieldDropdown } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/StringFieldDropdown';
import { StringFieldInput } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/StringFieldInput';
import { StringFieldTextarea } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/StringFieldTextarea';
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import {
isBoardFieldInputInstance,
isBoardFieldInputTemplate,
@@ -62,6 +63,8 @@ import {
isIntegerGeneratorFieldInputTemplate,
isIPAdapterModelFieldInputInstance,
isIPAdapterModelFieldInputTemplate,
isLLaVAModelFieldInputInstance,
isLLaVAModelFieldInputTemplate,
isLoRAModelFieldInputInstance,
isLoRAModelFieldInputTemplate,
isMainModelFieldInputInstance,
@@ -112,6 +115,7 @@ import FluxReduxModelFieldInputComponent from './inputs/FluxReduxModelFieldInput
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
import LLaVAModelFieldInputComponent from './inputs/LLaVAModelFieldInputComponent';
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
@@ -132,7 +136,7 @@ type Props = {
export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props) => {
const field = useInputFieldInstance(nodeId, fieldName);
const template = useInputFieldTemplate(nodeId, fieldName);
const template = useInputFieldTemplateOrThrow(nodeId, fieldName);
// When deciding which component to render, first we check the type of the template, which is more efficient than the
// instance type check. The instance type check uses zod and is slower.
@@ -148,7 +152,7 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
if (!isStringFieldInputInstance(field)) {
return null;
}
if (settings?.type !== 'string-field-config') {
if (!settings || settings.type !== 'string-field-config') {
if (template.ui_component === 'textarea') {
return <StringFieldTextarea nodeId={nodeId} field={field} fieldTemplate={template} />;
} else {
@@ -159,8 +163,10 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <StringFieldInput nodeId={nodeId} field={field} fieldTemplate={template} />;
} else if (settings.component === 'textarea') {
return <StringFieldTextarea nodeId={nodeId} field={field} fieldTemplate={template} />;
} else if (settings.component === 'dropdown') {
return <StringFieldDropdown nodeId={nodeId} field={field} fieldTemplate={template} settings={settings} />;
} else {
assert<Equals<never, typeof settings.component>>(false, 'Unexpected settings.component');
assert<Equals<never, typeof settings>>(false, 'Unexpected settings');
}
}
@@ -322,6 +328,13 @@ export const InputFieldRenderer = memo(({ nodeId, fieldName, settings }: Props)
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isLLaVAModelFieldInputTemplate(template)) {
if (!isLLaVAModelFieldInputInstance(field)) {
return null;
}
return <LLaVAModelFieldInputComponent nodeId={nodeId} field={field} fieldTemplate={template} />;
}
if (isFluxVAEModelFieldInputTemplate(template)) {
if (!isFluxVAEModelFieldInputInstance(field)) {
return null;


@@ -9,7 +9,7 @@ import {
useIsConnectionStartField,
} from 'features/nodes/hooks/useFieldConnectionState';
import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected';
import { useInputFieldLabel } from 'features/nodes/hooks/useInputFieldLabel';
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
import { useInputFieldTemplateTitle } from 'features/nodes/hooks/useInputFieldTemplateTitle';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
import { HANDLE_TOOLTIP_OPEN_DELAY, NO_FIT_ON_DOUBLE_CLICK_CLASS } from 'features/nodes/types/constants';
@@ -43,7 +43,7 @@ interface Props {
export const InputFieldTitle = memo((props: Props) => {
const { nodeId, fieldName, isInvalid, isDragging } = props;
const inputRef = useRef<HTMLInputElement>(null);
const label = useInputFieldLabel(nodeId, fieldName);
const label = useInputFieldLabelSafe(nodeId, fieldName);
const fieldTemplateTitle = useInputFieldTemplateTitle(nodeId, fieldName);
const { t } = useTranslation();
const isConnected = useInputFieldIsConnected(nodeId, fieldName);


@@ -1,7 +1,7 @@
import { Flex, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
import { useInputFieldErrors } from 'features/nodes/hooks/useInputFieldErrors';
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import { startCase } from 'lodash-es';
import { memo, useMemo } from 'react';
@@ -16,7 +16,7 @@ export const InputFieldTooltipContent = memo(({ nodeId, fieldName }: Props) => {
const { t } = useTranslation();
const fieldInstance = useInputFieldInstance(nodeId, fieldName);
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
const fieldErrors = useInputFieldErrors(nodeId, fieldName);


@@ -1,27 +0,0 @@
import { FormControl, FormLabel } from '@invoke-ai/ui-library';
import { InputFieldWrapper } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldWrapper';
import { useInputFieldName } from 'features/nodes/hooks/useInputFieldName';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
nodeId: string;
fieldName: string;
};
export const InputFieldUnknownPlaceholder = memo(({ nodeId, fieldName }: Props) => {
const { t } = useTranslation();
const name = useInputFieldName(nodeId, fieldName);
return (
<InputFieldWrapper>
<FormControl isInvalid={true} alignItems="stretch" justifyContent="center" gap={2} h="full" w="full">
<FormLabel display="flex" mb={0} px={1} py={2} gap={2}>
{t('nodes.unknownInput', { name })}
</FormLabel>
</FormControl>
</InputFieldWrapper>
);
});
InputFieldUnknownPlaceholder.displayName = 'InputFieldUnknownPlaceholder';


@@ -1,7 +1,10 @@
import { OutputFieldUnknownPlaceholder } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldUnknownPlaceholder';
import { FormControl, FormLabel } from '@invoke-ai/ui-library';
import { OutputFieldWrapper } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldWrapper';
import { useOutputFieldName } from 'features/nodes/hooks/useOutputFieldName';
import { useOutputFieldTemplateExists } from 'features/nodes/hooks/useOutputFieldTemplateExists';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
type Props = PropsWithChildren<{
nodeId: string;
@@ -12,10 +15,27 @@ export const OutputFieldGate = memo(({ nodeId, fieldName, children }: Props) =>
const hasTemplate = useOutputFieldTemplateExists(nodeId, fieldName);
if (!hasTemplate) {
return <OutputFieldUnknownPlaceholder nodeId={nodeId} fieldName={fieldName} />;
return <Fallback nodeId={nodeId} fieldName={fieldName} />;
}
return children;
});
OutputFieldGate.displayName = 'OutputFieldGate';
const Fallback = memo(({ nodeId, fieldName }: Props) => {
const { t } = useTranslation();
const name = useOutputFieldName(nodeId, fieldName);
return (
<OutputFieldWrapper>
<FormControl isInvalid={true} alignItems="stretch" justifyContent="space-between" gap={2} h="full" w="full">
<FormLabel display="flex" alignItems="center" h="full" color="error.300" mb={0} px={1} gap={2}>
{t('nodes.unexpectedField_withName', { name })}
</FormLabel>
</FormControl>
</OutputFieldWrapper>
);
});
Fallback.displayName = 'Fallback';


@@ -1,27 +0,0 @@
import { FormControl, FormLabel } from '@invoke-ai/ui-library';
import { OutputFieldWrapper } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldWrapper';
import { useOutputFieldName } from 'features/nodes/hooks/useOutputFieldName';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
nodeId: string;
fieldName: string;
};
export const OutputFieldUnknownPlaceholder = memo(({ nodeId, fieldName }: Props) => {
const { t } = useTranslation();
const name = useOutputFieldName(nodeId, fieldName);
return (
<OutputFieldWrapper>
<FormControl isInvalid={true} alignItems="stretch" justifyContent="space-between" gap={2} h="full" w="full">
<FormLabel display="flex" alignItems="center" h="full" color="error.300" mb={0} px={1} gap={2}>
{t('nodes.unknownOutput', { name })}
</FormLabel>
</FormControl>
</OutputFieldWrapper>
);
});
OutputFieldUnknownPlaceholder.displayName = 'OutputFieldUnknownPlaceholder';


@@ -0,0 +1,36 @@
import { Select } from '@invoke-ai/ui-library';
import type { FieldComponentProps } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/types';
import { useStringField } from 'features/nodes/components/flow/nodes/Invocation/fields/StringField/useStringField';
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { StringFieldInputInstance, StringFieldInputTemplate } from 'features/nodes/types/field';
import type { NodeFieldStringSettings } from 'features/nodes/types/workflow';
import { memo } from 'react';
export const StringFieldDropdown = memo(
(
props: FieldComponentProps<
StringFieldInputInstance,
StringFieldInputTemplate,
{ settings: Extract<NodeFieldStringSettings, { component: 'dropdown' }> }
>
) => {
const { value, onChange } = useStringField(props);
return (
<Select
onChange={onChange}
className={`${NO_DRAG_CLASS} ${NO_PAN_CLASS} ${NO_WHEEL_CLASS}`}
isDisabled={props.settings.options.length === 0}
value={value}
>
{props.settings.options.map((choice, i) => (
<option key={`${i}_${choice.value}`} value={choice.value}>
{choice.label || choice.value || `Option ${i + 1}`}
</option>
))}
</Select>
);
}
);
StringFieldDropdown.displayName = 'StringFieldDropdown';
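
The dropdown reads its choices from the node field element's string settings. A hypothetical settings object, inferred from how the renderer and the builder settings panel in this changeset use it (the option values here are made up):

// Hypothetical dropdown configuration for a string field element.
const settings: NodeFieldStringSettings = {
  type: 'string-field-config',
  component: 'dropdown',
  options: [
    { value: 'low', label: 'Low' },
    { value: 'high', label: 'High' },
  ],
};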


@@ -4,12 +4,21 @@ import { useStringField } from 'features/nodes/components/flow/nodes/Invocation/
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { StringFieldInputInstance, StringFieldInputTemplate } from 'features/nodes/types/field';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
export const StringFieldInput = memo(
(props: FieldComponentProps<StringFieldInputInstance, StringFieldInputTemplate>) => {
const { value, onChange } = useStringField(props);
const { t } = useTranslation();
return <Input className={`${NO_DRAG_CLASS} ${NO_PAN_CLASS} ${NO_WHEEL_CLASS}`} value={value} onChange={onChange} />;
return (
<Input
className={`${NO_DRAG_CLASS} ${NO_PAN_CLASS} ${NO_WHEEL_CLASS}`}
placeholder={t('workflows.emptyStringPlaceholder')}
value={value}
onChange={onChange}
/>
);
}
);


@@ -4,14 +4,17 @@ import { useStringField } from 'features/nodes/components/flow/nodes/Invocation/
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { StringFieldInputInstance, StringFieldInputTemplate } from 'features/nodes/types/field';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
export const StringFieldTextarea = memo(
(props: FieldComponentProps<StringFieldInputInstance, StringFieldInputTemplate>) => {
const { t } = useTranslation();
const { value, onChange } = useStringField(props);
return (
<Textarea
className={`${NO_DRAG_CLASS} ${NO_PAN_CLASS} ${NO_WHEEL_CLASS}`}
placeholder={t('workflows.emptyStringPlaceholder')}
value={value}
onChange={onChange}
h="full"


@@ -10,7 +10,7 @@ export const useStringField = (props: FieldComponentProps<StringFieldInputInstan
const dispatch = useAppDispatch();
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement | HTMLTextAreaElement>) => {
(e: ChangeEvent<HTMLInputElement | HTMLTextAreaElement | HTMLSelectElement>) => {
dispatch(
fieldStringValueChanged({
nodeId,


@@ -24,7 +24,7 @@ import { PiXBold } from 'react-icons/pi';
import type { FieldComponentProps } from './types';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
const sx = {
borderWidth: 1,


@@ -24,7 +24,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebounce } from 'use-debounce';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
export const FloatGeneratorFieldInputComponent = memo(
(props: FieldComponentProps<FloatGeneratorFieldInputInstance, FloatGeneratorFieldInputTemplate>) => {


@@ -24,7 +24,7 @@ import type { ImageDTO } from 'services/api/types';
import type { FieldComponentProps } from './types';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
const sx = {
borderWidth: 1,


@@ -17,7 +17,7 @@ import type { ChangeEvent } from 'react';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
export const ImageGeneratorFieldInputComponent = memo(
(props: FieldComponentProps<ImageGeneratorFieldInputInstance, ImageGeneratorFieldInputTemplate>) => {


@@ -27,7 +27,7 @@ import { PiXBold } from 'react-icons/pi';
import type { FieldComponentProps } from './types';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
const sx = {
borderWidth: 1,


@@ -27,7 +27,7 @@ import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebounce } from 'use-debounce';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
export const IntegerGeneratorFieldInputComponent = memo(
(props: FieldComponentProps<IntegerGeneratorFieldInputInstance, IntegerGeneratorFieldInputTemplate>) => {


@@ -0,0 +1,55 @@
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldLLaVAModelValueChanged } from 'features/nodes/store/nodesSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useLLaVAModels } from 'services/api/hooks/modelsByType';
import type { LlavaOnevisionConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<LLaVAModelFieldInputInstance, LLaVAModelFieldInputTemplate>;
const LLaVAModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useLLaVAModels();
const _onChange = useCallback(
(value: LlavaOnevisionConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldLLaVAModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<FormControl className={`${NO_WHEEL_CLASS} ${NO_DRAG_CLASS}`} isInvalid={!value} isDisabled={!options.length}>
<Combobox
value={value}
placeholder={placeholder}
noOptionsMessage={noOptionsMessage}
options={options}
onChange={onChange}
/>
</FormControl>
);
};
export default memo(LLaVAModelFieldInputComponent);


@@ -1,5 +1,5 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Button, Divider, Flex, FormLabel, Grid, GridItem, IconButton, Input } from '@invoke-ai/ui-library';
import { Button, Divider, Flex, Grid, GridItem, IconButton, Input, Text } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { useInputFieldIsInvalid } from 'features/nodes/hooks/useInputFieldIsInvalid';
@@ -17,7 +17,7 @@ import { PiXBold } from 'react-icons/pi';
import type { FieldComponentProps } from './types';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
const sx = {
borderWidth: 1,
@@ -107,42 +107,6 @@ export const StringFieldCollectionInputComponent = memo(
StringFieldCollectionInputComponent.displayName = 'StringFieldCollectionInputComponent';
type StringListItemContentProps = {
value: string;
index: number;
onRemoveString: (index: number) => void;
onChangeString: (index: number, value: string) => void;
};
const StringListItemContent = memo(({ value, index, onRemoveString, onChangeString }: StringListItemContentProps) => {
const { t } = useTranslation();
const onClickRemove = useCallback(() => {
onRemoveString(index);
}, [index, onRemoveString]);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChangeString(index, e.target.value);
},
[index, onChangeString]
);
return (
<Flex alignItems="center" gap={1}>
<Input size="xs" resize="none" value={value} onChange={onChange} />
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
onClick={onClickRemove}
icon={<PiXBold />}
aria-label={t('common.remove')}
tooltip={t('common.remove')}
/>
</Flex>
);
});
StringListItemContent.displayName = 'StringListItemContent';
type ListItemContentProps = {
value: string;
index: number;
@@ -166,19 +130,26 @@ const ListItemContent = memo(({ value, index, onRemoveString, onChangeString }:
return (
<>
<GridItem>
<FormLabel ps={1} m={0}>
<Text variant="subtext" textAlign="center" minW={8}>
{index + 1}.
</FormLabel>
</Text>
</GridItem>
<GridItem>
<Input size="sm" resize="none" value={value} onChange={onChange} />
<Input
placeholder={t('workflows.emptyStringPlaceholder')}
size="sm"
resize="none"
value={value}
onChange={onChange}
/>
</GridItem>
<GridItem>
<IconButton
tabIndex={-1}
size="sm"
variant="link"
alignSelf="stretch"
minW={8}
minH={8}
onClick={onClickRemove}
icon={<PiXBold />}
aria-label={t('common.delete')}


@@ -21,7 +21,7 @@ import type { ChangeEvent } from 'react';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
export const StringGeneratorFieldInputComponent = memo(
(props: FieldComponentProps<StringGeneratorFieldInputInstance, StringGeneratorFieldInputTemplate>) => {


@@ -1,6 +1,7 @@
import type { FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
import { ContainerElementSettings } from 'features/nodes/components/sidePanel/builder/ContainerElementSettings';
import { useDepthContext } from 'features/nodes/components/sidePanel/builder/contexts';
import { NodeFieldElementSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementSettings';
@@ -47,8 +48,16 @@ export const FormElementEditModeHeader = memo(({ element, dragHandleRef, ...rest
<Label element={element} />
<Spacer />
{isContainerElement(element) && <ContainerElementSettings element={element} />}
{isNodeFieldElement(element) && <ZoomToNodeButton element={element} />}
{isNodeFieldElement(element) && <NodeFieldElementSettings element={element} />}
{isNodeFieldElement(element) && (
<InputFieldGate
nodeId={element.data.fieldIdentifier.nodeId}
fieldName={element.data.fieldIdentifier.fieldName}
fallback={null} // Do not render these buttons if the field is not found
>
<ZoomToNodeButton element={element} />
<NodeFieldElementSettings element={element} />
</InputFieldGate>
)}
<RemoveElementButton element={element} />
</Flex>
);


@@ -1,5 +1,4 @@
import { useAppSelector } from 'app/store/storeHooks';
import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
import { NodeFieldElementEditMode } from 'features/nodes/components/sidePanel/builder/NodeFieldElementEditMode';
import { NodeFieldElementViewMode } from 'features/nodes/components/sidePanel/builder/NodeFieldElementViewMode';
import { selectWorkflowMode, useElement } from 'features/nodes/store/workflowSlice';
@@ -15,19 +14,11 @@ export const NodeFieldElement = memo(({ id }: { id: string }) => {
}
if (mode === 'view') {
return (
<InputFieldGate nodeId={el.data.fieldIdentifier.nodeId} fieldName={el.data.fieldIdentifier.fieldName}>
<NodeFieldElementViewMode el={el} />
</InputFieldGate>
);
return <NodeFieldElementViewMode el={el} />;
}
// mode === 'edit'
return (
<InputFieldGate nodeId={el.data.fieldIdentifier.nodeId} fieldName={el.data.fieldIdentifier.fieldName}>
<NodeFieldElementEditMode el={el} />
</InputFieldGate>
);
return <NodeFieldElementEditMode el={el} />;
});
NodeFieldElement.displayName = 'NodeFieldElement';


@@ -2,8 +2,8 @@ import { FormHelperText, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { linkifyOptions, linkifySx } from 'common/components/linkify';
import { useEditable } from 'common/hooks/useEditable';
import { useInputFieldDescription } from 'features/nodes/hooks/useInputFieldDescription';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { fieldDescriptionChanged } from 'features/nodes/store/nodesSlice';
import type { NodeFieldElement } from 'features/nodes/types/workflow';
import Linkify from 'linkify-react';
@@ -13,8 +13,8 @@ export const NodeFieldElementDescriptionEditable = memo(({ el }: { el: NodeField
const { data } = el;
const { fieldIdentifier } = data;
const dispatch = useAppDispatch();
const description = useInputFieldDescription(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplate(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const inputRef = useRef<HTMLTextAreaElement>(null);
const onChange = useCallback(


@@ -1,5 +1,6 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, FormControl } from '@invoke-ai/ui-library';
import { Box, Divider, Flex, FormControl } from '@invoke-ai/ui-library';
import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
import { InputFieldRenderer } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
import { useContainerContext } from 'features/nodes/components/sidePanel/builder/contexts';
import { useFormElementDnd } from 'features/nodes/components/sidePanel/builder/dnd-hooks';
@@ -8,9 +9,11 @@ import { FormElementEditModeContent } from 'features/nodes/components/sidePanel/
import { FormElementEditModeHeader } from 'features/nodes/components/sidePanel/builder/FormElementEditModeHeader';
import { NodeFieldElementDescriptionEditable } from 'features/nodes/components/sidePanel/builder/NodeFieldElementDescriptionEditable';
import { NodeFieldElementLabelEditable } from 'features/nodes/components/sidePanel/builder/NodeFieldElementLabelEditable';
import { NodeFieldElementStringDropdownSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementStringDropdownSettings';
import { useMouseOverFormField, useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import type { NodeFieldElement } from 'features/nodes/types/workflow';
import { NODE_FIELD_CLASS_NAME } from 'features/nodes/types/workflow';
import type { RefObject } from 'react';
import { memo, useRef } from 'react';
const sx: SystemStyleObject = {
@@ -31,25 +34,11 @@ export const NodeFieldElementEditMode = memo(({ el }: { el: NodeFieldElement })
const dragHandleRef = useRef<HTMLDivElement>(null);
const [activeDropRegion, isDragging] = useFormElementDnd(el.id, draggableRef, dragHandleRef);
const containerCtx = useContainerContext();
const { id, data } = el;
const { fieldIdentifier, showDescription } = data;
const { id } = el;
return (
<Flex ref={draggableRef} id={id} className={NODE_FIELD_CLASS_NAME} sx={sx} data-parent-layout={containerCtx.layout}>
<FormElementEditModeHeader dragHandleRef={dragHandleRef} element={el} data-is-dragging={isDragging} />
<FormElementEditModeContent data-is-dragging={isDragging} p={4}>
<FormControl flex="1 1 0" orientation="vertical">
<NodeFieldElementLabelEditable el={el} />
<Flex w="full" gap={4}>
<InputFieldRenderer
nodeId={fieldIdentifier.nodeId}
fieldName={fieldIdentifier.fieldName}
settings={data.settings}
/>
</Flex>
{showDescription && <NodeFieldElementDescriptionEditable el={el} />}
</FormControl>
</FormElementEditModeContent>
<NodeFieldElementEditModeContent dragHandleRef={dragHandleRef} el={el} isDragging={isDragging} />
<NodeFieldElementOverlay element={el} />
<DndListDropIndicator activeDropRegion={activeDropRegion} gap="var(--invoke-space-4)" />
</Flex>
@@ -58,6 +47,48 @@ export const NodeFieldElementEditMode = memo(({ el }: { el: NodeFieldElement })
NodeFieldElementEditMode.displayName = 'NodeFieldElementEditMode';
const NodeFieldElementEditModeContent = memo(
({
el,
dragHandleRef,
isDragging,
}: {
el: NodeFieldElement;
dragHandleRef: RefObject<HTMLDivElement>;
isDragging: boolean;
}) => {
const { id, data } = el;
const { fieldIdentifier, showDescription } = data;
return (
<>
<FormElementEditModeHeader dragHandleRef={dragHandleRef} element={el} data-is-dragging={isDragging} />
<FormElementEditModeContent data-is-dragging={isDragging} p={4}>
<InputFieldGate nodeId={fieldIdentifier.nodeId} fieldName={fieldIdentifier.fieldName}>
<FormControl flex="1 1 0" orientation="vertical">
<NodeFieldElementLabelEditable el={el} />
<Flex w="full" gap={4}>
<InputFieldRenderer
nodeId={fieldIdentifier.nodeId}
fieldName={fieldIdentifier.fieldName}
settings={data.settings}
/>
</Flex>
{showDescription && <NodeFieldElementDescriptionEditable el={el} />}
{data.settings?.type === 'string-field-config' && data.settings.component === 'dropdown' && (
<>
<Divider />
<NodeFieldElementStringDropdownSettings id={id} settings={data.settings} />
</>
)}
</FormControl>
</InputFieldGate>
</FormElementEditModeContent>
</>
);
}
);
NodeFieldElementEditModeContent.displayName = 'NodeFieldElementEditModeContent';
const nodeFieldOverlaySx: SystemStyleObject = {
position: 'absolute',
top: 0,


@@ -14,42 +14,48 @@ import { useTranslation } from 'react-i18next';
type Props = {
id: string;
config: NodeFieldFloatSettings;
settings: NodeFieldFloatSettings;
nodeId: string;
fieldName: string;
fieldTemplate: FloatFieldInputTemplate;
};
export const NodeFieldElementFloatSettings = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props) => {
export const NodeFieldElementFloatSettings = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
return (
<>
<SettingComponent id={id} config={config} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
<SettingMin id={id} config={config} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
<SettingMax id={id} config={config} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
<SettingComponent
id={id}
settings={settings}
nodeId={nodeId}
fieldName={fieldName}
fieldTemplate={fieldTemplate}
/>
<SettingMin id={id} settings={settings} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
<SettingMax id={id} settings={settings} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
</>
);
});
NodeFieldElementFloatSettings.displayName = 'NodeFieldElementFloatSettings';
const SettingComponent = memo(({ id, config }: Props) => {
const SettingComponent = memo(({ id, settings }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const onChangeComponent = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
const newConfig: NodeFieldFloatSettings = {
...config,
const newSettings: NodeFieldFloatSettings = {
...settings,
component: zNumberComponent.parse(e.target.value),
};
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
},
[config, dispatch, id]
[settings, dispatch, id]
);
return (
<FormControl orientation="vertical">
<FormLabel flex={1}>{t('workflows.builder.component')}</FormLabel>
<Select value={config.component} onChange={onChangeComponent} size="sm">
<Select value={settings.component} onChange={onChangeComponent} size="sm">
<option value="number-input">{t('workflows.builder.numberInput')}</option>
<option value="slider">{t('workflows.builder.slider')}</option>
<option value="number-input-and-slider">{t('workflows.builder.both')}</option>
@@ -59,36 +65,36 @@ const SettingComponent = memo(({ id, config }: Props) => {
});
SettingComponent.displayName = 'SettingComponent';
const SettingMin = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props) => {
const SettingMin = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const field = useInputFieldInstance<FloatFieldInputInstance>(nodeId, fieldName);
const floatField = useFloatField(nodeId, fieldName, fieldTemplate);
const onToggleSetting = useCallback(() => {
const newConfig: NodeFieldFloatSettings = {
...config,
min: config.min !== undefined ? undefined : floatField.min,
const onToggleOverride = useCallback(() => {
const newSettings: NodeFieldFloatSettings = {
...settings,
min: settings.min !== undefined ? undefined : floatField.min,
};
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
}, [config, dispatch, floatField, id]);
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
}, [settings, dispatch, floatField, id]);
const onChange = useCallback(
(min: number) => {
const newConfig: NodeFieldFloatSettings = {
...config,
const newSettings: NodeFieldFloatSettings = {
...settings,
min,
};
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
// We may need to update the value if it is outside the new min/max range
const constrained = constrainNumber(field.value, floatField, newConfig);
const constrained = constrainNumber(field.value, floatField, newSettings);
if (field.value !== constrained) {
dispatch(fieldFloatValueChanged({ nodeId, fieldName, value: constrained }));
}
},
[config, dispatch, field.value, fieldName, floatField, id, nodeId]
[settings, dispatch, field.value, fieldName, floatField, id, nodeId]
);
const constraintMin = useMemo(
@@ -97,20 +103,20 @@ const SettingMin = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props
);
const constraintMax = useMemo(
() => (config.max ?? floatField.max) - floatField.step,
[config.max, floatField.max, floatField.step]
() => (settings.max ?? floatField.max) - floatField.step,
[settings.max, floatField.max, floatField.step]
);
return (
<FormControl orientation="vertical">
<Flex justifyContent="space-between" w="full" alignItems="center">
<FormLabel m={0}>{t('workflows.builder.minimum')}</FormLabel>
<Switch isChecked={config.min !== undefined} onChange={onToggleSetting} size="sm" />
<Switch isChecked={settings.min !== undefined} onChange={onToggleOverride} size="sm" />
</Flex>
<CompositeNumberInput
w="full"
isDisabled={config.min === undefined}
value={config.min === undefined ? (`${floatField.min} (inherited)` as unknown as number) : config.min}
isDisabled={settings.min === undefined}
value={settings.min === undefined ? (`${floatField.min} (inherited)` as unknown as number) : settings.min}
onChange={onChange}
min={constraintMin}
max={constraintMax}
@@ -121,41 +127,41 @@ const SettingMin = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props
});
SettingMin.displayName = 'SettingMin';
const SettingMax = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props) => {
const SettingMax = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const field = useInputFieldInstance<FloatFieldInputInstance>(nodeId, fieldName);
const floatField = useFloatField(nodeId, fieldName, fieldTemplate);
const onToggleSetting = useCallback(() => {
const newConfig: NodeFieldFloatSettings = {
...config,
max: config.max !== undefined ? undefined : floatField.max,
const onToggleOverride = useCallback(() => {
const newSettings: NodeFieldFloatSettings = {
...settings,
max: settings.max !== undefined ? undefined : floatField.max,
};
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
}, [config, dispatch, floatField, id]);
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
}, [settings, dispatch, floatField, id]);
const onChange = useCallback(
(max: number) => {
const newConfig: NodeFieldFloatSettings = {
...config,
const newSettings: NodeFieldFloatSettings = {
...settings,
max,
};
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
// We may need to update the value if it is outside the new min/max range
const constrained = constrainNumber(field.value, floatField, newConfig);
const constrained = constrainNumber(field.value, floatField, newSettings);
if (field.value !== constrained) {
dispatch(fieldFloatValueChanged({ nodeId, fieldName, value: constrained }));
}
},
[config, dispatch, field.value, fieldName, floatField, id, nodeId]
[settings, dispatch, field.value, fieldName, floatField, id, nodeId]
);
const constraintMin = useMemo(
() => (config.min ?? floatField.min) + floatField.step,
[config.min, floatField.min, floatField.step]
() => (settings.min ?? floatField.min) + floatField.step,
[settings.min, floatField.min, floatField.step]
);
const constraintMax = useMemo(
@@ -167,12 +173,12 @@ const SettingMax = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props
<FormControl orientation="vertical">
<Flex justifyContent="space-between" w="full" alignItems="center">
<FormLabel m={0}>{t('workflows.builder.maximum')}</FormLabel>
<Switch isChecked={config.max !== undefined} onChange={onToggleSetting} size="sm" />
<Switch isChecked={settings.max !== undefined} onChange={onToggleOverride} size="sm" />
</Flex>
<CompositeNumberInput
w="full"
isDisabled={config.max === undefined}
value={config.max === undefined ? (`${floatField.max} (inherited)` as unknown as number) : config.max}
isDisabled={settings.max === undefined}
value={settings.max === undefined ? (`${floatField.max} (inherited)` as unknown as number) : settings.max}
onChange={onChange}
min={constraintMin}
max={constraintMax}
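The min/max overrides above interact with the field's current value: whenever an override changes, constrainNumber re-clamps the value so it never falls outside the effective range. A minimal sketch of that clamping, under simplified shapes (the project's actual constrainNumber helper may also account for the step size):

// Sketch only: simplified stand-in, not the project's constrainNumber implementation.
type NumberConstraints = { min: number; max: number; step: number };
type NumberOverrides = { min?: number; max?: number };

const clampToEffectiveRange = (value: number, field: NumberConstraints, overrides: NumberOverrides): number => {
  // Overrides narrow the range; otherwise the field's own limits apply.
  const min = overrides.min ?? field.min;
  const max = overrides.max ?? field.max;
  return Math.min(Math.max(value, min), max);
};

// e.g. clampToEffectiveRange(150, { min: 0, max: 100, step: 1 }, { max: 50 }) === 50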

View File

@@ -15,42 +15,48 @@ import { useTranslation } from 'react-i18next';
type Props = {
id: string;
config: NodeFieldIntegerSettings;
settings: NodeFieldIntegerSettings;
nodeId: string;
fieldName: string;
fieldTemplate: IntegerFieldInputTemplate;
};
export const NodeFieldElementIntegerSettings = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props) => {
export const NodeFieldElementIntegerSettings = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
return (
<>
<SettingComponent id={id} config={config} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
<SettingMin id={id} config={config} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
<SettingMax id={id} config={config} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
<SettingComponent
id={id}
settings={settings}
nodeId={nodeId}
fieldName={fieldName}
fieldTemplate={fieldTemplate}
/>
<SettingMin id={id} settings={settings} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
<SettingMax id={id} settings={settings} nodeId={nodeId} fieldName={fieldName} fieldTemplate={fieldTemplate} />
</>
);
});
NodeFieldElementIntegerSettings.displayName = 'NodeFieldElementIntegerSettings';
const SettingComponent = memo(({ id, config }: Props) => {
const SettingComponent = memo(({ id, settings }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const onChangeComponent = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
const newConfig: NodeFieldIntegerSettings = {
...config,
const newSettings: NodeFieldIntegerSettings = {
...settings,
component: zNumberComponent.parse(e.target.value),
};
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
},
[config, dispatch, id]
[settings, dispatch, id]
);
return (
<FormControl orientation="vertical">
<FormLabel flex={1}>{t('workflows.builder.component')}</FormLabel>
<Select value={config.component} onChange={onChangeComponent} size="sm">
<Select value={settings.component} onChange={onChangeComponent} size="sm">
<option value="number-input">{t('workflows.builder.numberInput')}</option>
<option value="slider">{t('workflows.builder.slider')}</option>
<option value="number-input-and-slider">{t('workflows.builder.both')}</option>
@@ -60,37 +66,37 @@ const SettingComponent = memo(({ id, config }: Props) => {
});
SettingComponent.displayName = 'SettingComponent';
const SettingMin = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props) => {
const SettingMin = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const field = useInputFieldInstance<IntegerFieldInputInstance>(nodeId, fieldName);
const integerField = useIntegerField(nodeId, fieldName, fieldTemplate);
const onToggleSetting = useCallback(() => {
const newConfig: NodeFieldIntegerSettings = {
...config,
min: config.min !== undefined ? undefined : integerField.min,
const onToggleOverride = useCallback(() => {
const newSettings: NodeFieldIntegerSettings = {
...settings,
min: settings.min !== undefined ? undefined : integerField.min,
};
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
}, [config, dispatch, integerField.min, id]);
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
}, [settings, dispatch, integerField.min, id]);
const onChange = useCallback(
(min: number) => {
const newConfig: NodeFieldIntegerSettings = {
...config,
const newSettings: NodeFieldIntegerSettings = {
...settings,
min,
};
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
// We may need to update the value if it is outside the new min/max range
const constrained = constrainNumber(field.value, integerField, newConfig);
const constrained = constrainNumber(field.value, integerField, newSettings);
if (field.value !== constrained) {
dispatch(fieldIntegerValueChanged({ nodeId, fieldName, value: constrained }));
}
},
[config, dispatch, id, field, integerField, nodeId, fieldName]
[settings, dispatch, id, field, integerField, nodeId, fieldName]
);
const constraintMin = useMemo(
@@ -99,20 +105,20 @@ const SettingMin = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props
);
const constraintMax = useMemo(
() => (config.max ?? integerField.max) - integerField.step,
[config.max, integerField.max, integerField.step]
() => (settings.max ?? integerField.max) - integerField.step,
[settings.max, integerField.max, integerField.step]
);
return (
<FormControl orientation="vertical">
<Flex justifyContent="space-between" w="full" alignItems="center">
<FormLabel m={0}>{t('workflows.builder.minimum')}</FormLabel>
<Switch isChecked={config.min !== undefined} onChange={onToggleSetting} size="sm" />
<Switch isChecked={settings.min !== undefined} onChange={onToggleOverride} size="sm" />
</Flex>
<CompositeNumberInput
w="full"
isDisabled={config.min === undefined}
value={config.min ?? (`${integerField.min} (inherited)` as unknown as number)}
isDisabled={settings.min === undefined}
value={settings.min ?? (`${integerField.min} (inherited)` as unknown as number)}
onChange={onChange}
min={constraintMin}
max={constraintMax}
@@ -123,42 +129,42 @@ const SettingMin = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props
});
SettingMin.displayName = 'SettingMin';
const SettingMax = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props) => {
const SettingMax = memo(({ id, settings, nodeId, fieldName, fieldTemplate }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const field = useInputFieldInstance<IntegerFieldInputInstance>(nodeId, fieldName);
const integerField = useIntegerField(nodeId, fieldName, fieldTemplate);
const onToggleSetting = useCallback(() => {
const newConfig: NodeFieldIntegerSettings = {
...config,
max: config.max !== undefined ? undefined : integerField.max,
const onToggleOverride = useCallback(() => {
const newSettings: NodeFieldIntegerSettings = {
...settings,
max: settings.max !== undefined ? undefined : integerField.max,
};
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
}, [config, dispatch, integerField.max, id]);
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
}, [settings, dispatch, integerField.max, id]);
const onChange = useCallback(
(max: number) => {
const newConfig: NodeFieldIntegerSettings = {
...config,
const newSettings: NodeFieldIntegerSettings = {
...settings,
max,
};
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
// We may need to update the value if it is outside the new min/max range
const constrained = constrainNumber(field.value, integerField, newConfig);
const constrained = constrainNumber(field.value, integerField, newSettings);
if (field.value !== constrained) {
dispatch(fieldIntegerValueChanged({ nodeId, fieldName, value: constrained }));
}
},
[config, dispatch, field.value, fieldName, integerField, id, nodeId]
[settings, dispatch, field.value, fieldName, integerField, id, nodeId]
);
const constraintMin = useMemo(
() => (config.min ?? integerField.min) + integerField.step,
[config.min, integerField.min, integerField.step]
() => (settings.min ?? integerField.min) + integerField.step,
[settings.min, integerField.min, integerField.step]
);
const constraintMax = useMemo(
@@ -170,12 +176,12 @@ const SettingMax = memo(({ id, config, nodeId, fieldName, fieldTemplate }: Props
<FormControl orientation="vertical">
<Flex justifyContent="space-between" w="full" alignItems="center">
<FormLabel m={0}>{t('workflows.builder.maximum')}</FormLabel>
<Switch isChecked={config.max !== undefined} onChange={onToggleSetting} size="sm" />
<Switch isChecked={settings.max !== undefined} onChange={onToggleOverride} size="sm" />
</Flex>
<CompositeNumberInput
w="full"
isDisabled={config.max === undefined}
value={config.max ?? (`${integerField.max} (inherited)` as unknown as number)}
isDisabled={settings.max === undefined}
value={settings.max ?? (`${integerField.max} (inherited)` as unknown as number)}
onChange={onChange}
min={constraintMin}
max={constraintMax}
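The constraintMin/constraintMax memos keep the two override editors from crossing: the Minimum input can only rise to one step below the effective maximum, and the Maximum input can only fall to one step above the effective minimum. A compact sketch of that derivation, again with simplified shapes:

// Sketch only: how the override editors bound each other, one step apart.
const getOverrideEditorBounds = (
  field: { min: number; max: number; step: number },
  overrides: { min?: number; max?: number }
) => ({
  // Upper bound of the "Minimum" editor: one step below the effective maximum.
  minEditorMax: (overrides.max ?? field.max) - field.step,
  // Lower bound of the "Maximum" editor: one step above the effective minimum.
  maxEditorMin: (overrides.min ?? field.min) + field.step,
});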

View File

@@ -1,15 +1,15 @@
import { Flex, FormLabel, Spacer } from '@invoke-ai/ui-library';
import { NodeFieldElementResetToInitialValueIconButton } from 'features/nodes/components/flow/nodes/Invocation/fields/NodeFieldElementResetToInitialValueIconButton';
import { useInputFieldLabel } from 'features/nodes/hooks/useInputFieldLabel';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import type { NodeFieldElement } from 'features/nodes/types/workflow';
import { memo, useMemo } from 'react';
export const NodeFieldElementLabel = memo(({ el }: { el: NodeFieldElement }) => {
const { data } = el;
const { fieldIdentifier } = data;
const label = useInputFieldLabel(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplate(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const label = useInputFieldLabelSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const _label = useMemo(() => label || fieldTemplate.title, [label, fieldTemplate.title]);

View File

@@ -2,8 +2,8 @@ import { Flex, FormLabel, Input, Spacer } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEditable } from 'common/hooks/useEditable';
import { NodeFieldElementResetToInitialValueIconButton } from 'features/nodes/components/flow/nodes/Invocation/fields/NodeFieldElementResetToInitialValueIconButton';
import { useInputFieldLabel } from 'features/nodes/hooks/useInputFieldLabel';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldLabelSafe } from 'features/nodes/hooks/useInputFieldLabelSafe';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { fieldLabelChanged } from 'features/nodes/store/nodesSlice';
import type { NodeFieldElement } from 'features/nodes/types/workflow';
import { memo, useCallback, useRef } from 'react';
@@ -12,8 +12,8 @@ export const NodeFieldElementLabelEditable = memo(({ el }: { el: NodeFieldElemen
const { data } = el;
const { fieldIdentifier } = data;
const dispatch = useAppDispatch();
const label = useInputFieldLabel(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplate(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const label = useInputFieldLabelSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const inputRef = useRef<HTMLInputElement>(null);
const onChange = useCallback(

View File

@@ -15,7 +15,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { NodeFieldElementFloatSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementFloatSettings';
import { NodeFieldElementIntegerSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementIntegerSettings';
import { NodeFieldElementStringSettings } from 'features/nodes/components/sidePanel/builder/NodeFieldElementStringSettings';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { formElementNodeFieldDataChanged } from 'features/nodes/store/workflowSlice';
import {
isFloatFieldInputTemplate,
@@ -36,7 +36,7 @@ export const NodeFieldElementSettings = memo(({ element }: { element: NodeFieldE
const { id, data } = element;
const { showDescription, fieldIdentifier } = data;
const { nodeId, fieldName } = fieldIdentifier;
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
const { t } = useTranslation();
const dispatch = useAppDispatch();
@@ -79,7 +79,7 @@ export const NodeFieldElementSettings = memo(({ element }: { element: NodeFieldE
{settings?.type === 'integer-field-config' && isIntegerFieldInputTemplate(fieldTemplate) && (
<NodeFieldElementIntegerSettings
id={id}
config={settings}
settings={settings}
nodeId={nodeId}
fieldName={fieldName}
fieldTemplate={fieldTemplate}
@@ -88,13 +88,15 @@ export const NodeFieldElementSettings = memo(({ element }: { element: NodeFieldE
{settings?.type === 'float-field-config' && isFloatFieldInputTemplate(fieldTemplate) && (
<NodeFieldElementFloatSettings
id={id}
config={settings}
settings={settings}
nodeId={nodeId}
fieldName={fieldName}
fieldTemplate={fieldTemplate}
/>
)}
{settings?.type === 'string-field-config' && <NodeFieldElementStringSettings id={id} config={settings} />}
{settings?.type === 'string-field-config' && (
<NodeFieldElementStringSettings id={id} settings={settings} />
)}
</Flex>
</PopoverBody>
</PopoverContent>

View File

@@ -0,0 +1,205 @@
import { Button, ButtonGroup, Divider, Flex, Grid, GridItem, IconButton, Input, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { formElementNodeFieldDataChanged } from 'features/nodes/store/workflowSlice';
import { NO_DRAG_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import { getDefaultStringOption, type NodeFieldStringDropdownSettings } from 'features/nodes/types/workflow';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiPlusBold, PiXBold } from 'react-icons/pi';
const overlayscrollbarsOptions = getOverlayScrollbarsParams({}).options;
export const NodeFieldElementStringDropdownSettings = memo(
({ id, settings }: { id: string; settings: NodeFieldStringDropdownSettings }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const onRemoveOption = useCallback(
(index: number) => {
const options = [...settings.options];
options.splice(index, 1);
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: { ...settings, options } } }));
},
[settings, dispatch, id]
);
const onChangeOptionValue = useCallback(
(index: number, value: string) => {
if (!settings.options[index]) {
return;
}
const option = { ...settings.options[index], value };
const options = [...settings.options];
options[index] = option;
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: { ...settings, options } } }));
},
[dispatch, id, settings]
);
const onChangeOptionLabel = useCallback(
(index: number, label: string) => {
if (!settings.options[index]) {
return;
}
const option = { ...settings.options[index], label };
const options = [...settings.options];
options[index] = option;
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: { ...settings, options } } }));
},
[dispatch, id, settings]
);
const onAddOption = useCallback(() => {
const options = [...settings.options, getDefaultStringOption()];
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: { ...settings, options } } }));
}, [dispatch, id, settings]);
const onResetOptions = useCallback(() => {
const options = [getDefaultStringOption()];
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: { ...settings, options } } }));
}, [dispatch, id, settings]);
return (
<Flex
className={NO_DRAG_CLASS}
position="relative"
w="full"
h="auto"
maxH={64}
alignItems="stretch"
justifyContent="center"
borderRadius="base"
flexDir="column"
gap={2}
>
<ButtonGroup isAttached={false} w="full">
<Button onClick={onAddOption} variant="ghost" flex={1} leftIcon={<PiPlusBold />}>
{t('workflows.builder.addOption')}
</Button>
<Button onClick={onResetOptions} variant="ghost" flex={1} leftIcon={<PiArrowCounterClockwiseBold />}>
{t('workflows.builder.resetOptions')}
</Button>
</ButtonGroup>
{settings.options.length > 0 && (
<>
<Divider />
<OverlayScrollbarsComponent
className={NO_WHEEL_CLASS}
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Grid gap={1} gridTemplateColumns="auto 1fr 1fr auto" gridTemplateRows="auto 1fr" alignItems="center">
<GridItem minW={8}>
<Text textAlign="center" variant="subtext">
#
</Text>
</GridItem>
<GridItem>
<Text textAlign="center" variant="subtext">
{t('common.label')}
</Text>
</GridItem>
<GridItem>
<Text textAlign="center" variant="subtext">
{t('common.value')}
</Text>
</GridItem>
<GridItem />
{settings.options.map(({ value, label }, index) => (
<ListItemContent
key={index}
value={value}
label={label}
index={index}
onRemoveOption={onRemoveOption}
onChangeOptionValue={onChangeOptionValue}
onChangeOptionLabel={onChangeOptionLabel}
isRemoveDisabled={settings.options.length <= 1}
/>
))}
</Grid>
</OverlayScrollbarsComponent>
</>
)}
</Flex>
);
}
);
NodeFieldElementStringDropdownSettings.displayName = 'NodeFieldElementStringDropdownSettings';
type ListItemContentProps = {
value: string;
label: string;
index: number;
onRemoveOption: (index: number) => void;
onChangeOptionValue: (index: number, value: string) => void;
onChangeOptionLabel: (index: number, label: string) => void;
isRemoveDisabled: boolean;
};
const ListItemContent = memo(
({
value,
label,
index,
onRemoveOption,
onChangeOptionValue,
onChangeOptionLabel,
isRemoveDisabled,
}: ListItemContentProps) => {
const { t } = useTranslation();
const onClickRemove = useCallback(() => {
onRemoveOption(index);
}, [index, onRemoveOption]);
const onChangeValue = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChangeOptionValue(index, e.target.value);
},
[index, onChangeOptionValue]
);
const onChangeLabel = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChangeOptionLabel(index, e.target.value);
},
[index, onChangeOptionLabel]
);
return (
<>
<GridItem>
<Text variant="subtext" textAlign="center">
{index + 1}.
</Text>
</GridItem>
<GridItem>
<Input size="sm" resize="none" placeholder="label" value={label} onChange={onChangeLabel} />
</GridItem>
<GridItem>
<Input size="sm" resize="none" placeholder="value" value={value} onChange={onChangeValue} />
</GridItem>
<GridItem>
<IconButton
tabIndex={-1}
size="sm"
variant="link"
minW={8}
minH={8}
onClick={onClickRemove}
isDisabled={isRemoveDisabled}
icon={<PiXBold />}
aria-label={t('common.delete')}
/>
</GridItem>
</>
);
}
);
ListItemContent.displayName = 'ListItemContent';

View File

@@ -1,36 +1,55 @@
import { FormControl, FormLabel, Select } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { formElementNodeFieldDataChanged } from 'features/nodes/store/workflowSlice';
import { type NodeFieldStringSettings, zStringComponent } from 'features/nodes/types/workflow';
import { getDefaultStringOption, type NodeFieldStringSettings, zStringComponent } from 'features/nodes/types/workflow';
import { omit } from 'lodash-es';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const NodeFieldElementStringSettings = memo(
({ id, config }: { id: string; config: NodeFieldStringSettings }) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
type Props = {
id: string;
settings: NodeFieldStringSettings;
};
const onChangeComponent = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
const newConfig: NodeFieldStringSettings = {
...config,
component: zStringComponent.parse(e.target.value),
export const NodeFieldElementStringSettings = memo(({ id, settings }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const onChangeComponent = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
const component = zStringComponent.parse(e.target.value);
if (component === settings.component) {
// no change
return;
}
if (component === 'dropdown') {
// if the component is changing to dropdown, we need to set the options
const newSettings: NodeFieldStringSettings = {
...settings,
component,
options: [getDefaultStringOption()],
};
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newConfig } }));
},
[config, dispatch, id]
);
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
return;
} else {
// if the component is changing from dropdown, we need to remove the options
const newSettings: NodeFieldStringSettings = omit({ ...settings, component }, 'options');
dispatch(formElementNodeFieldDataChanged({ id, changes: { settings: newSettings } }));
}
},
[settings, dispatch, id]
);
return (
<FormControl>
<FormLabel flex={1}>{t('workflows.builder.component')}</FormLabel>
<Select value={config.component} onChange={onChangeComponent} size="sm">
<option value="input">{t('workflows.builder.singleLine')}</option>
<option value="textarea">{t('workflows.builder.multiLine')}</option>
</Select>
</FormControl>
);
}
);
return (
<FormControl>
<FormLabel flex={1}>{t('workflows.builder.component')}</FormLabel>
<Select value={settings.component} onChange={onChangeComponent} size="sm">
<option value="input">{t('workflows.builder.singleLine')}</option>
<option value="textarea">{t('workflows.builder.multiLine')}</option>
<option value="dropdown">{t('workflows.builder.dropdown')}</option>
</Select>
</FormControl>
);
});
NodeFieldElementStringSettings.displayName = 'NodeFieldElementStringSettings';

View File

@@ -1,15 +1,17 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, FormControl, FormHelperText } from '@invoke-ai/ui-library';
import { linkifyOptions, linkifySx } from 'common/components/linkify';
import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
import { InputFieldRenderer } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer';
import { useContainerContext } from 'features/nodes/components/sidePanel/builder/contexts';
import { NodeFieldElementLabel } from 'features/nodes/components/sidePanel/builder/NodeFieldElementLabel';
import { useInputFieldDescription } from 'features/nodes/hooks/useInputFieldDescription';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldDescriptionSafe } from 'features/nodes/hooks/useInputFieldDescriptionSafe';
import { useInputFieldTemplateOrThrow, useInputFieldTemplateSafe } from 'features/nodes/hooks/useInputFieldTemplate';
import type { NodeFieldElement } from 'features/nodes/types/workflow';
import { NODE_FIELD_CLASS_NAME } from 'features/nodes/types/workflow';
import Linkify from 'linkify-react';
import { memo, useMemo } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
const sx: SystemStyleObject = {
'&[data-parent-layout="column"]': {
@@ -25,16 +27,23 @@ const sx: SystemStyleObject = {
},
};
const useFormatFallbackLabel = () => {
const { t } = useTranslation();
const formatLabel = useCallback((name: string) => t('nodes.unknownFieldEditWorkflowToFix_withName', { name }), [t]);
return formatLabel;
};
export const NodeFieldElementViewMode = memo(({ el }: { el: NodeFieldElement }) => {
const { id, data } = el;
const { fieldIdentifier, showDescription } = data;
const description = useInputFieldDescription(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplate(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const containerCtx = useContainerContext();
const formatFallbackLabel = useFormatFallbackLabel();
const _description = useMemo(
() => description || fieldTemplate.description,
[description, fieldTemplate.description]
() => description || fieldTemplate?.description || '',
[description, fieldTemplate?.description]
);
return (
@@ -45,22 +54,45 @@ export const NodeFieldElementViewMode = memo(({ el }: { el: NodeFieldElement })
data-parent-layout={containerCtx.layout}
data-with-description={showDescription && !!_description}
>
<FormControl flex="1 1 0" orientation="vertical">
<NodeFieldElementLabel el={el} />
<Flex w="full" gap={4}>
<InputFieldRenderer
nodeId={fieldIdentifier.nodeId}
fieldName={fieldIdentifier.fieldName}
settings={data.settings}
/>
</Flex>
{showDescription && _description && (
<FormHelperText sx={linkifySx}>
<Linkify options={linkifyOptions}>{_description}</Linkify>
</FormHelperText>
)}
</FormControl>
<InputFieldGate
nodeId={el.data.fieldIdentifier.nodeId}
fieldName={el.data.fieldIdentifier.fieldName}
formatLabel={formatFallbackLabel}
>
<NodeFieldElementViewModeContent el={el} />
</InputFieldGate>
</Flex>
);
});
NodeFieldElementViewMode.displayName = 'NodeFieldElementViewMode';
const NodeFieldElementViewModeContent = memo(({ el }: { el: NodeFieldElement }) => {
const { data } = el;
const { fieldIdentifier, showDescription } = data;
const description = useInputFieldDescriptionSafe(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(fieldIdentifier.nodeId, fieldIdentifier.fieldName);
const _description = useMemo(
() => description || fieldTemplate.description,
[description, fieldTemplate.description]
);
return (
<FormControl flex="1 1 0" orientation="vertical">
<NodeFieldElementLabel el={el} />
<Flex w="full" gap={4}>
<InputFieldRenderer
nodeId={fieldIdentifier.nodeId}
fieldName={fieldIdentifier.fieldName}
settings={data.settings}
/>
</Flex>
{showDescription && _description && (
<FormHelperText sx={linkifySx}>
<Linkify options={linkifyOptions}>{_description}</Linkify>
</FormHelperText>
)}
</FormControl>
);
});
NodeFieldElementViewModeContent.displayName = 'NodeFieldElementViewModeContent';
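Splitting the view-mode element into a gate and a content component lets the content use the throwing template hook safely: by the time it renders, the gate has confirmed the field still exists and otherwise shows the formatted fallback label. The stand-in below only illustrates that contract; it is not the real InputFieldGate implementation:

import type { PropsWithChildren } from 'react';
import { useInputFieldTemplateSafe } from 'features/nodes/hooks/useInputFieldTemplate';

type GateProps = PropsWithChildren<{
  nodeId: string;
  fieldName: string;
  formatLabel: (name: string) => string;
}>;

// Stand-in for InputFieldGate: render children only when the field's template
// can still be resolved; otherwise show the formatted fallback label.
const FieldGateSketch = ({ nodeId, fieldName, formatLabel, children }: GateProps) => {
  const template = useInputFieldTemplateSafe(nodeId, fieldName);
  if (!template) {
    return <span>{formatLabel(fieldName)}</span>;
  }
  return <>{children}</>;
};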

View File

@@ -1,6 +1,6 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useInputFieldInstance } from 'features/nodes/hooks/useInputFieldInstance';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { formElementAdded, selectFormRootElementId } from 'features/nodes/store/workflowSlice';
import { buildNodeFieldElement } from 'features/nodes/types/workflow';
import { useCallback } from 'react';
@@ -8,7 +8,7 @@ import { useCallback } from 'react';
export const useAddNodeFieldToRoot = (nodeId: string, fieldName: string) => {
const dispatch = useAppDispatch();
const rootElementId = useAppSelector(selectFormRootElementId);
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
const field = useInputFieldInstance(nodeId, fieldName);
const addNodeFieldToRoot = useCallback(() => {

View File

@@ -1,7 +1,20 @@
import type { ButtonProps, CheckboxProps } from '@invoke-ai/ui-library';
import { Button, Checkbox, Collapse, Flex, Spacer, Text } from '@invoke-ai/ui-library';
import {
Box,
Button,
ButtonGroup,
Checkbox,
Collapse,
Flex,
Icon,
IconButton,
Spacer,
Text,
Tooltip,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import type { WorkflowLibraryView, WorkflowTagCategory } from 'features/nodes/store/workflowLibrarySlice';
import {
$workflowLibraryCategoriesOptions,
@@ -15,9 +28,10 @@ import {
} from 'features/nodes/store/workflowLibrarySlice';
import { NewWorkflowButton } from 'features/workflowLibrary/components/NewWorkflowButton';
import { UploadWorkflowButton } from 'features/workflowLibrary/components/UploadWorkflowButton';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold, PiUsersBold } from 'react-icons/pi';
import { PiArrowCounterClockwiseBold, PiStarFill, PiUsersBold } from 'react-icons/pi';
import { useDispatch } from 'react-redux';
import { useGetCountsByTagQuery } from 'services/api/endpoints/workflows';
@@ -27,7 +41,7 @@ export const WorkflowLibrarySideNav = () => {
const view = useAppSelector(selectWorkflowLibraryView);
return (
<Flex h="full" minH={0} overflow="hidden" flexDir="column" w={64} gap={1}>
<Flex h="full" minH={0} overflow="hidden" flexDir="column" w={64} gap={0}>
<Flex flexDir="column" w="full" pb={2}>
<WorkflowLibraryViewButton view="recent">{t('workflows.recentlyOpened')}</WorkflowLibraryViewButton>
</Flex>
@@ -48,7 +62,7 @@ export const WorkflowLibrarySideNav = () => {
)}
</Flex>
<Flex h="full" minH={0} overflow="hidden" flexDir="column">
<WorkflowLibraryViewButton view="defaults">{t('workflows.browseWorkflows')}</WorkflowLibraryViewButton>
<BrowseWorkflowsButton />
<DefaultsViewCheckboxesCollapsible />
</Flex>
<Spacer />
@@ -58,39 +72,56 @@ export const WorkflowLibrarySideNav = () => {
);
};
const DefaultsViewCheckboxesCollapsible = memo(() => {
const BrowseWorkflowsButton = memo(() => {
const { t } = useTranslation();
const dispatch = useDispatch();
const tags = useAppSelector(selectWorkflowLibrarySelectedTags);
const tagCategoryOptions = useStore($workflowLibraryTagCategoriesOptions);
const view = useAppSelector(selectWorkflowLibraryView);
const dispatch = useAppDispatch();
const selectedTags = useAppSelector(selectWorkflowLibrarySelectedTags);
const resetTags = useCallback(() => {
dispatch(workflowLibraryTagsReset());
}, [dispatch]);
if (view === 'defaults' && selectedTags.length > 0) {
return (
<ButtonGroup>
<WorkflowLibraryViewButton view="defaults" w="auto">
{t('workflows.browseWorkflows')}
</WorkflowLibraryViewButton>
<Tooltip label={t('workflows.deselectAll')}>
<IconButton
onClick={resetTags}
size="md"
aria-label={t('workflows.deselectAll')}
icon={<PiArrowCounterClockwiseBold size={12} />}
variant="ghost"
bg="base.700"
color="base.50"
/>
</Tooltip>
</ButtonGroup>
);
}
return <WorkflowLibraryViewButton view="defaults">{t('workflows.browseWorkflows')}</WorkflowLibraryViewButton>;
});
BrowseWorkflowsButton.displayName = 'BrowseWorkflowsButton';
const overlayscrollbarsOptions = getOverlayScrollbarsParams({ visibility: 'visible' }).options;
const DefaultsViewCheckboxesCollapsible = memo(() => {
const tagCategoryOptions = useStore($workflowLibraryTagCategoriesOptions);
const view = useAppSelector(selectWorkflowLibraryView);
return (
<Collapse in={view === 'defaults'}>
<Flex flexDir="column" gap={2} pl={4} py={2} overflow="hidden" h="100%" minH={0}>
<Button
isDisabled={tags.length === 0}
onClick={resetTags}
size="sm"
variant="link"
fontWeight="bold"
justifyContent="flex-start"
flexGrow={0}
flexShrink={0}
leftIcon={<PiArrowCounterClockwiseBold />}
h={8}
>
{t('workflows.deselectAll')}
</Button>
<Flex flexDir="column" gap={2} overflow="auto">
{tagCategoryOptions.map((tagCategory) => (
<TagCategory key={tagCategory.categoryTKey} tagCategory={tagCategory} />
))}
</Flex>
<OverlayScrollbarsComponent style={overlayScrollbarsStyles} options={overlayscrollbarsOptions}>
<Flex flexDir="column" gap={2} overflow="auto">
{tagCategoryOptions.map((tagCategory) => (
<TagCategory key={tagCategory.categoryTKey} tagCategory={tagCategory} />
))}
</Flex>
</OverlayScrollbarsComponent>
</Flex>
</Collapse>
);
@@ -102,7 +133,7 @@ const useCountForIndividualTag = (tag: string) => {
const queryArg = useMemo(
() =>
({
tags: allTags,
tags: allTags.map((tag) => tag.label),
categories: ['default'],
}) satisfies Parameters<typeof useGetCountsByTagQuery>[0],
[allTags]
@@ -127,7 +158,7 @@ const useCountForTagCategory = (tagCategory: WorkflowTagCategory) => {
const queryArg = useMemo(
() =>
({
tags: allTags,
tags: allTags.map((tag) => tag.label),
categories: ['default'], // We only allow filtering by tag for default workflows
}) satisfies Parameters<typeof useGetCountsByTagQuery>[0],
[allTags]
@@ -140,7 +171,7 @@ const useCountForTagCategory = (tagCategory: WorkflowTagCategory) => {
return { count: 0 };
}
return {
count: tagCategory.tags.reduce((acc, tag) => acc + (data[tag] ?? 0), 0),
count: tagCategory.tags.reduce((acc, tag) => acc + (data[tag.label] ?? 0), 0),
};
},
}) satisfies Parameters<typeof useGetCountsByTagQuery>[1],
@@ -190,7 +221,7 @@ const TagCategory = memo(({ tagCategory }: { tagCategory: WorkflowTagCategory })
</Text>
<Flex flexDir="column" gap={2} pl={4}>
{tagCategory.tags.map((tag) => (
<TagCheckbox key={tag} tag={tag} />
<TagCheckbox key={tag.label} tag={tag} />
))}
</Flex>
</Flex>
@@ -198,14 +229,15 @@ const TagCategory = memo(({ tagCategory }: { tagCategory: WorkflowTagCategory })
});
TagCategory.displayName = 'TagCategory';
const TagCheckbox = memo(({ tag, ...rest }: CheckboxProps & { tag: string }) => {
const TagCheckbox = memo(({ tag, ...rest }: CheckboxProps & { tag: { label: string; recommended?: boolean } }) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const selectedTags = useAppSelector(selectWorkflowLibrarySelectedTags);
const isChecked = selectedTags.includes(tag);
const count = useCountForIndividualTag(tag);
const isChecked = selectedTags.includes(tag.label);
const count = useCountForIndividualTag(tag.label);
const onChange = useCallback(() => {
dispatch(workflowLibraryTagToggled(tag));
dispatch(workflowLibraryTagToggled(tag.label));
}, [dispatch, tag]);
if (count === 0) {
@@ -213,9 +245,17 @@ const TagCheckbox = memo(({ tag, ...rest }: CheckboxProps & { tag: string }) =>
}
return (
<Checkbox isChecked={isChecked} onChange={onChange} {...rest} flexShrink={0}>
<Text>{`${tag} (${count})`}</Text>
</Checkbox>
<Flex alignItems="center" gap={2}>
<Checkbox isChecked={isChecked} onChange={onChange} {...rest} flexShrink={0} />
<Text>{`${tag.label} (${count})`}</Text>
{tag.recommended && (
<Tooltip label={t('workflows.recommended')}>
<Box as="span" lineHeight={0}>
<Icon as={PiStarFill} boxSize={4} fill="invokeYellow.500" />
</Box>
</Tooltip>
)}
</Flex>
);
});
TagCheckbox.displayName = 'TagCheckbox';

View File

@@ -1,6 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
import { useInputFieldTemplateOrThrow } from 'features/nodes/hooks/useInputFieldTemplate';
import { fieldValueReset } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation';
@@ -10,7 +10,7 @@ import { useCallback, useMemo } from 'react';
export const useInputFieldDefaultValue = (nodeId: string, fieldName: string) => {
const dispatch = useAppDispatch();
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
const selectIsChanged = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => {

View File

@@ -1,22 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { useMemo } from 'react';
export const useInputFieldDescription = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return '';
}
return node?.data.inputs[fieldName]?.description ?? '';
}),
[fieldName, nodeId]
);
const notes = useAppSelector(selector);
return notes;
};

View File

@@ -0,0 +1,26 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectFieldInputInstanceSafe, selectNodesSlice } from 'features/nodes/store/selectors';
import { useMemo } from 'react';
/**
* Gets the user-defined description of an input field for a given node.
*
* If the node doesn't exist or is not an invocation node, an empty string is returned.
*
* @param nodeId The ID of the node
* @param fieldName The name of the field
*/
export const useInputFieldDescriptionSafe = (nodeId: string, fieldName: string) => {
const selector = useMemo(
() =>
createSelector(
selectNodesSlice,
(nodes) => selectFieldInputInstanceSafe(nodes, nodeId, fieldName)?.description ?? ''
),
[fieldName, nodeId]
);
const description = useAppSelector(selector);
return description;
};

View File

@@ -1,18 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectFieldInputInstanceSafe, selectNodesSlice } from 'features/nodes/store/selectors';
import { useMemo } from 'react';
export const useInputFieldLabel = (nodeId: string, fieldName: string): string => {
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => {
return selectFieldInputInstanceSafe(nodes, nodeId, fieldName)?.label ?? '';
}),
[fieldName, nodeId]
);
const label = useAppSelector(selector);
return label;
};

View File

@@ -0,0 +1,24 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectFieldInputInstanceSafe, selectNodesSlice } from 'features/nodes/store/selectors';
import { useMemo } from 'react';
/**
* Gets the user-defined label of an input field for a given node.
*
* If the node doesn't exist or is not an invocation node, an empty string is returned.
*
* @param nodeId The ID of the node
* @param fieldName The name of the field
*/
export const useInputFieldLabelSafe = (nodeId: string, fieldName: string): string => {
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) => selectFieldInputInstanceSafe(nodes, nodeId, fieldName)?.label ?? ''),
[fieldName, nodeId]
);
const label = useAppSelector(selector);
return label;
};

View File

@@ -5,7 +5,7 @@ import { $templates } from 'features/nodes/store/nodesSlice';
import { selectInvocationNodeSafe, selectNodesSlice } from 'features/nodes/store/selectors';
import { useMemo } from 'react';
export const useInputFieldName = (nodeId: string, fieldName: string) => {
export const useInputFieldNameSafe = (nodeId: string, fieldName: string) => {
const templates = useStore($templates);
const selector = useMemo(

View File

@@ -3,7 +3,16 @@ import type { FieldInputTemplate } from 'features/nodes/types/field';
import { useMemo } from 'react';
import { assert } from 'tsafe';
export const useInputFieldTemplate = (nodeId: string, fieldName: string): FieldInputTemplate => {
/**
* Returns the template for a specific input field of a node.
*
* **Note:** This hook will throw an error if the template for the input field is not found.
*
* @param nodeId - The ID of the node.
* @param fieldName - The name of the input field.
* @throws Will throw an error if the template for the input field is not found.
*/
export const useInputFieldTemplateOrThrow = (nodeId: string, fieldName: string): FieldInputTemplate => {
const template = useNodeTemplate(nodeId);
const fieldTemplate = useMemo(() => {
const _fieldTemplate = template.inputs[fieldName];
@@ -12,3 +21,17 @@ export const useInputFieldTemplate = (nodeId: string, fieldName: string): FieldI
}, [fieldName, template.inputs]);
return fieldTemplate;
};
/**
* Returns the template for a specific input field of a node.
*
* **Note:** This hook is a safe version of `useInputFieldTemplateOrThrow` and will not throw an error if the template is not found.
*
* @param nodeId - The ID of the node.
* @param fieldName - The name of the input field.
*/
export const useInputFieldTemplateSafe = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
const template = useNodeTemplate(nodeId);
const fieldTemplate = useMemo(() => template.inputs[fieldName] ?? null, [fieldName, template.inputs]);
return fieldTemplate;
};
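The rename makes the two contracts explicit: the OrThrow variant is for call sites that are already guarded, while the Safe variant leaves the missing-field case to the caller. A short illustrative sketch of both call sites (component names here are made up, not from the codebase):

import { useInputFieldTemplateOrThrow, useInputFieldTemplateSafe } from 'features/nodes/hooks/useInputFieldTemplate';

// Appropriate only when a parent (e.g. a gate component) has verified the field exists.
const GuardedFieldTitle = ({ nodeId, fieldName }: { nodeId: string; fieldName: string }) => {
  const fieldTemplate = useInputFieldTemplateOrThrow(nodeId, fieldName);
  return <span>{fieldTemplate.title}</span>;
};

// Returns null instead of throwing, so missing fields can render a fallback.
const TolerantFieldTitle = ({ nodeId, fieldName }: { nodeId: string; fieldName: string }) => {
  const fieldTemplate = useInputFieldTemplateSafe(nodeId, fieldName);
  return <span>{fieldTemplate?.title ?? 'Unknown field'}</span>;
};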

View File

@@ -27,6 +27,7 @@ import type {
IntegerFieldValue,
IntegerGeneratorFieldValue,
IPAdapterModelFieldValue,
LLaVAModelFieldValue,
LoRAModelFieldValue,
MainModelFieldValue,
ModelIdentifierFieldValue,
@@ -64,6 +65,7 @@ import {
zIntegerFieldValue,
zIntegerGeneratorFieldValue,
zIPAdapterModelFieldValue,
zLLaVAModelFieldValue,
zLoRAModelFieldValue,
zMainModelFieldValue,
zModelIdentifierFieldValue,
@@ -380,6 +382,9 @@ export const nodesSlice = createSlice({
fieldLoRAModelValueChanged: (state, action: FieldValueAction<LoRAModelFieldValue>) => {
fieldValueReducer(state, action, zLoRAModelFieldValue);
},
fieldLLaVAModelValueChanged: (state, action: FieldValueAction<LLaVAModelFieldValue>) => {
fieldValueReducer(state, action, zLLaVAModelFieldValue);
},
fieldControlNetModelValueChanged: (state, action: FieldValueAction<ControlNetModelFieldValue>) => {
fieldValueReducer(state, action, zControlNetModelFieldValue);
},
@@ -509,6 +514,7 @@ export const {
fieldSpandrelImageToImageModelValueChanged,
fieldLabelChanged,
fieldLoRAModelValueChanged,
fieldLLaVAModelValueChanged,
fieldModelIdentifierValueChanged,
fieldMainModelValueChanged,
fieldIntegerValueChanged,
@@ -633,6 +639,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldT2IAdapterModelValueChanged,
fieldLabelChanged,
fieldLoRAModelValueChanged,
fieldLLaVAModelValueChanged,
fieldMainModelValueChanged,
fieldIntegerValueChanged,
fieldIntegerCollectionValueChanged,

View File

@@ -92,12 +92,21 @@ export const selectWorkflowLibraryView = createWorkflowLibrarySelector(({ view }
export const DEFAULT_WORKFLOW_LIBRARY_CATEGORIES = ['user', 'default'] satisfies WorkflowCategory[];
export const $workflowLibraryCategoriesOptions = atom<WorkflowCategory[]>(DEFAULT_WORKFLOW_LIBRARY_CATEGORIES);
export type WorkflowTagCategory = { categoryTKey: string; tags: string[] };
export type WorkflowTagCategory = { categoryTKey: string; tags: Array<{ label: string; recommended?: boolean }> };
export const DEFAULT_WORKFLOW_LIBRARY_TAG_CATEGORIES: WorkflowTagCategory[] = [
{ categoryTKey: 'Industry', tags: ['Architecture', 'Fashion', 'Game Dev', 'Food'] },
{ categoryTKey: 'Common Tasks', tags: ['Upscaling', 'Text to Image', 'Image to Image'] },
{ categoryTKey: 'Model Architecture', tags: ['SD1.5', 'SDXL', 'SD3.5', 'FLUX'] },
{ categoryTKey: 'Tech Showcase', tags: ['Control', 'Reference Image'] },
{
categoryTKey: 'Industry',
tags: [{ label: 'Architecture' }, { label: 'Fashion' }, { label: 'Game Dev' }, { label: 'Food' }],
},
{
categoryTKey: 'Common Tasks',
tags: [{ label: 'Upscaling' }, { label: 'Text to Image' }, { label: 'Image to Image' }],
},
{
categoryTKey: 'Model Architecture',
tags: [{ label: 'SD1.5' }, { label: 'SDXL' }, { label: 'SD3.5' }, { label: 'FLUX' }],
},
{ categoryTKey: 'Tech Showcase', tags: [{ label: 'Control' }, { label: 'Reference Image' }] },
];
export const $workflowLibraryTagCategoriesOptions = atom<WorkflowTagCategory[]>(
DEFAULT_WORKFLOW_LIBRARY_TAG_CATEGORIES
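Because workflow library tags are now objects rather than bare strings, every place that previously treated a tag as a string (query arguments, count lookups, React keys) must read tag.label instead. A hedged sketch of the per-category aggregation, with a made-up counts map standing in for the counts-by-tag query result:

type WorkflowTag = { label: string; recommended?: boolean };
type TagCategory = { categoryTKey: string; tags: WorkflowTag[] };

// Hypothetical counts keyed by tag label, standing in for the counts-by-tag query data.
const countsByTag: Record<string, number> = { Upscaling: 4, 'Text to Image': 12, 'Image to Image': 7 };

// Per-category count: sum the per-label counts, defaulting missing labels to zero.
const getCategoryCount = (category: TagCategory): number =>
  category.tags.reduce((acc, tag) => acc + (countsByTag[tag.label] ?? 0), 0);

// e.g. getCategoryCount({ categoryTKey: 'Common Tasks', tags: [{ label: 'Upscaling' }, { label: 'Text to Image' }] }) === 16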

View File

@@ -69,6 +69,7 @@ const zModelType = z.enum([
'main',
'vae',
'lora',
'llava_onevision',
'control_lora',
'controlnet',
't2i_adapter',

View File

@@ -189,6 +189,10 @@ const zLoRAModelFieldType = zFieldTypeBase.extend({
name: z.literal('LoRAModelField'),
originalType: zStatelessFieldType.optional(),
});
const zLLaVAModelFieldType = zFieldTypeBase.extend({
name: z.literal('LLaVAModelField'),
originalType: zStatelessFieldType.optional(),
});
const zControlNetModelFieldType = zFieldTypeBase.extend({
name: z.literal('ControlNetModelField'),
originalType: zStatelessFieldType.optional(),
@@ -273,6 +277,7 @@ const zStatefulFieldType = z.union([
zSDXLRefinerModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zLLaVAModelFieldType,
zControlNetModelFieldType,
zIPAdapterModelFieldType,
zT2IAdapterModelFieldType,
@@ -309,6 +314,7 @@ const modelFieldTypeNames = [
zSDXLRefinerModelFieldType.shape.name.value,
zVAEModelFieldType.shape.name.value,
zLoRAModelFieldType.shape.name.value,
zLLaVAModelFieldType.shape.name.value,
zControlNetModelFieldType.shape.name.value,
zIPAdapterModelFieldType.shape.name.value,
zT2IAdapterModelFieldType.shape.name.value,
@@ -891,6 +897,26 @@ export const isLoRAModelFieldInputInstance = buildInstanceTypeGuard(zLoRAModelFi
export const isLoRAModelFieldInputTemplate = buildTemplateTypeGuard<LoRAModelFieldInputTemplate>('LoRAModelField');
// #endregion
// #region LLaVAModelField
export const zLLaVAModelFieldValue = zModelIdentifierField.optional();
const zLLaVAModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zLLaVAModelFieldValue,
});
const zLLaVAModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zLLaVAModelFieldType,
originalType: zFieldType.optional(),
default: zLLaVAModelFieldValue,
});
const zLLaVAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zLLaVAModelFieldType,
});
export type LLaVAModelFieldValue = z.infer<typeof zLLaVAModelFieldValue>;
export type LLaVAModelFieldInputInstance = z.infer<typeof zLLaVAModelFieldInputInstance>;
export type LLaVAModelFieldInputTemplate = z.infer<typeof zLLaVAModelFieldInputTemplate>;
export const isLLaVAModelFieldInputInstance = buildInstanceTypeGuard(zLLaVAModelFieldInputInstance);
export const isLLaVAModelFieldInputTemplate = buildTemplateTypeGuard<LLaVAModelFieldInputTemplate>('LLaVAModelField');
// #endregion
// #region ControlNetModelField
export const zControlNetModelFieldValue = zModelIdentifierField.optional();
const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({
@@ -1739,6 +1765,7 @@ export const zStatefulFieldValue = z.union([
zSDXLRefinerModelFieldValue,
zVAEModelFieldValue,
zLoRAModelFieldValue,
zLLaVAModelFieldValue,
zControlNetModelFieldValue,
zIPAdapterModelFieldValue,
zT2IAdapterModelFieldValue,
@@ -1785,6 +1812,7 @@ const zStatefulFieldInputInstance = z.union([
zSDXLRefinerModelFieldInputInstance,
zVAEModelFieldInputInstance,
zLoRAModelFieldInputInstance,
zLLaVAModelFieldInputInstance,
zControlNetModelFieldInputInstance,
zIPAdapterModelFieldInputInstance,
zT2IAdapterModelFieldInputInstance,
@@ -1825,6 +1853,7 @@ const zStatefulFieldInputTemplate = z.union([
zSDXLRefinerModelFieldInputTemplate,
zVAEModelFieldInputTemplate,
zLoRAModelFieldInputTemplate,
zLLaVAModelFieldInputTemplate,
zControlNetModelFieldInputTemplate,
zIPAdapterModelFieldInputTemplate,
zT2IAdapterModelFieldInputTemplate,
@@ -1871,6 +1900,7 @@ const zStatefulFieldOutputTemplate = z.union([
zSDXLRefinerModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate,
zLoRAModelFieldOutputTemplate,
zLLaVAModelFieldOutputTemplate,
zControlNetModelFieldOutputTemplate,
zIPAdapterModelFieldOutputTemplate,
zT2IAdapterModelFieldOutputTemplate,

View File

@@ -91,14 +91,37 @@ const zNodeFieldIntegerSettings = z.object({
export type NodeFieldIntegerSettings = z.infer<typeof zNodeFieldIntegerSettings>;
export const getIntegerFieldSettingsDefaults = (): NodeFieldIntegerSettings => zNodeFieldIntegerSettings.parse({});
export const zStringComponent = z.enum(['input', 'textarea']);
const zStringOption = z
.object({
label: z.string(),
value: z.string(),
})
.default({ label: '', value: '' });
type StringOption = z.infer<typeof zStringOption>;
export const getDefaultStringOption = (): StringOption => ({ label: '', value: '' });
export const zStringComponent = z.enum(['input', 'textarea', 'dropdown']);
const STRING_FIELD_CONFIG_TYPE = 'string-field-config';
const zNodeFieldStringSettings = z.object({
const zNodeFieldStringInputSettings = z.object({
type: z.literal(STRING_FIELD_CONFIG_TYPE).default(STRING_FIELD_CONFIG_TYPE),
component: zStringComponent.default('input'),
component: z.literal('input').default('input'),
});
const zNodeFieldStringTextareaSettings = z.object({
type: z.literal(STRING_FIELD_CONFIG_TYPE).default(STRING_FIELD_CONFIG_TYPE),
component: z.literal('textarea').default('textarea'),
});
const zNodeFieldStringDropdownSettings = z.object({
type: z.literal(STRING_FIELD_CONFIG_TYPE).default(STRING_FIELD_CONFIG_TYPE),
component: z.literal('dropdown').default('dropdown'),
options: z.array(zStringOption),
});
export type NodeFieldStringDropdownSettings = z.infer<typeof zNodeFieldStringDropdownSettings>;
const zNodeFieldStringSettings = z.union([
zNodeFieldStringInputSettings,
zNodeFieldStringTextareaSettings,
zNodeFieldStringDropdownSettings,
]);
export type NodeFieldStringSettings = z.infer<typeof zNodeFieldStringSettings>;
export const getStringFieldSettingsDefaults = (): NodeFieldStringSettings => zNodeFieldStringSettings.parse({});
export const getStringFieldSettingsDefaults = (): NodeFieldStringSettings => zNodeFieldStringInputSettings.parse({});
const zNodeFieldData = z.object({
fieldIdentifier: zFieldIdentifier,
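Splitting the string settings into per-component schemas means only the dropdown variant carries an options array, and the defaults helper now seeds the plain input variant. A small sketch of how such a union behaves under zod parsing, using simplified stand-ins for the schemas above:

import { z } from 'zod';

// Simplified stand-ins for the schemas above, to show how the union behaves.
const inputSettings = z.object({ component: z.literal('input').default('input') });
const dropdownSettings = z.object({
  component: z.literal('dropdown').default('dropdown'),
  options: z.array(z.object({ label: z.string(), value: z.string() })),
});
const stringSettings = z.union([inputSettings, dropdownSettings]);

stringSettings.parse({ component: 'input' }); // ok: matches the input variant
stringSettings.parse({ component: 'dropdown', options: [{ label: '', value: '' }] }); // ok: dropdown with a seeded option
// stringSettings.parse({ component: 'dropdown' }) would throw: the dropdown variant requires options,
// which is why switching a field to the dropdown component seeds [getDefaultStringOption()].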

View File

@@ -4,9 +4,8 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasControlLayerState, Rect } from 'features/controlLayers/store/types';
import { getControlLayerWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { serializeError } from 'serialize-error';
import type { ImageDTO, Invocation } from 'services/api/types';
import type { ImageDTO, Invocation, MainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
const log = logger('system');
@@ -17,7 +16,7 @@ type AddControlNetsArg = {
g: Graph;
rect: Rect;
collector: Invocation<'collect'>;
model: ParameterModel;
model: MainModelConfig;
};
type AddControlNetsResult = {
@@ -66,7 +65,7 @@ type AddT2IAdaptersArg = {
g: Graph;
rect: Rect;
collector: Invocation<'collect'>;
model: ParameterModel;
model: MainModelConfig;
};
type AddT2IAdaptersResult = {
@@ -114,7 +113,7 @@ type AddControlLoRAArg = {
entities: CanvasControlLayerState[];
g: Graph;
rect: Rect;
model: ParameterModel;
model: MainModelConfig;
denoise: Invocation<'flux_denoise'>;
};
@@ -129,9 +128,9 @@ export const addControlLoRA = async ({ manager, entities, g, rect, model, denois
// No valid control LoRA found
return;
}
if (validControlLayers.length > 1) {
throw new Error('Cannot add more than one FLUX control LoRA.');
}
assert(model.variant !== 'inpaint', 'FLUX Control LoRA is not compatible with FLUX Fill.');
assert(validControlLayers.length <= 1, 'Cannot add more than one FLUX control LoRA.');
const getImageDTOResult = await withResultAsync(() => {
const adapter = manager.adapters.controlLayers.get(validControlLayer.id);

View File

@@ -0,0 +1,204 @@
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { Dimensions } from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { isEqual } from 'lodash-es';
import type { Invocation } from 'services/api/types';
type AddFLUXFillArg = {
state: RootState;
g: Graph;
manager: CanvasManager;
l2i: Invocation<'flux_vae_decode'>;
denoise: Invocation<'flux_denoise'>;
originalSize: Dimensions;
scaledSize: Dimensions;
};
export const addFLUXFill = async ({
state,
g,
manager,
l2i,
denoise,
originalSize,
scaledSize,
}: AddFLUXFillArg): Promise<Invocation<'invokeai_img_blend' | 'apply_mask_to_image'>> => {
// FLUX Fill always fully denoises
denoise.denoising_start = 0;
denoise.denoising_end = 1;
const params = selectParamsSlice(state);
const canvasSettings = selectCanvasSettingsSlice(state);
const canvas = selectCanvasSlice(state);
const { bbox } = canvas;
const rasterAdapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
const initialImage = await manager.compositor.getCompositeImageDTO(rasterAdapters, bbox.rect, {
is_intermediate: true,
silent: true,
});
const inpaintMaskAdapters = manager.compositor.getVisibleAdaptersOfType('inpaint_mask');
const maskImage = await manager.compositor.getCompositeImageDTO(inpaintMaskAdapters, bbox.rect, {
is_intermediate: true,
silent: true,
});
const fluxFill = g.addNode({ type: 'flux_fill', id: getPrefixedId('flux_fill') });
const needsScaleBeforeProcessing = !isEqual(scaledSize, originalSize);
if (needsScaleBeforeProcessing) {
// Scale before processing requires some resizing
// Combine the inpaint mask and the initial image's alpha channel into a single mask
const maskAlphaToMask = g.addNode({
id: getPrefixedId('alpha_to_mask'),
type: 'tomask',
image: { image_name: maskImage.image_name },
invert: !canvasSettings.preserveMask,
});
const initialImageAlphaToMask = g.addNode({
id: getPrefixedId('image_alpha_to_mask'),
type: 'tomask',
image: { image_name: initialImage.image_name },
});
const maskCombine = g.addNode({
id: getPrefixedId('mask_combine'),
type: 'mask_combine',
});
g.addEdge(maskAlphaToMask, 'image', maskCombine, 'mask1');
g.addEdge(initialImageAlphaToMask, 'image', maskCombine, 'mask2');
// Resize the combined and initial image to the scaled size
const resizeInputMaskToScaledSize = g.addNode({
id: getPrefixedId('resize_mask_to_scaled_size'),
type: 'img_resize',
...scaledSize,
});
g.addEdge(maskCombine, 'image', resizeInputMaskToScaledSize, 'image');
const alphaMaskToTensorMask = g.addNode({
type: 'image_mask_to_tensor',
id: getPrefixedId('image_mask_to_tensor'),
});
g.addEdge(resizeInputMaskToScaledSize, 'image', alphaMaskToTensorMask, 'image');
g.addEdge(alphaMaskToTensorMask, 'mask', fluxFill, 'mask');
// Resize the initial image to the scaled size and add to the FLUX Fill node
const resizeInputImageToScaledSize = g.addNode({
id: getPrefixedId('resize_image_to_scaled_size'),
type: 'img_resize',
image: { image_name: initialImage.image_name },
...scaledSize,
});
g.addEdge(resizeInputImageToScaledSize, 'image', fluxFill, 'image');
// Provide the FLUX Fill conditioning w/ image and mask to the denoise node
g.addEdge(fluxFill, 'fill_cond', denoise, 'fill_conditioning');
// Resize the output image back to the original size
const resizeOutputImageToOriginalSize = g.addNode({
id: getPrefixedId('resize_image_to_original_size'),
type: 'img_resize',
...originalSize,
});
// After denoising, resize the generated image back to the original size
g.addEdge(l2i, 'image', resizeOutputImageToOriginalSize, 'image');
const expandMask = g.addNode({
type: 'expand_mask_with_fade',
id: getPrefixedId('expand_mask_with_fade'),
fade_size_px: params.maskBlur,
});
g.addEdge(maskCombine, 'image', expandMask, 'mask');
// Do the paste back if we are sending to gallery (in which case we want to see the full image), or if we are sending
// to canvas but not outputting only masked regions
if (!canvasSettings.sendToCanvas || !canvasSettings.outputOnlyMaskedRegions) {
const imageLayerBlend = g.addNode({
type: 'invokeai_img_blend',
id: getPrefixedId('image_layer_blend'),
layer_base: { image_name: initialImage.image_name },
});
g.addEdge(resizeOutputImageToOriginalSize, 'image', imageLayerBlend, 'layer_upper');
g.addEdge(expandMask, 'image', imageLayerBlend, 'mask');
return imageLayerBlend;
} else {
// Otherwise, just apply the mask
const applyMaskToImage = g.addNode({
type: 'apply_mask_to_image',
id: getPrefixedId('apply_mask_to_image'),
invert_mask: true,
});
g.addEdge(expandMask, 'image', applyMaskToImage, 'mask');
g.addEdge(resizeOutputImageToOriginalSize, 'image', applyMaskToImage, 'image');
return applyMaskToImage;
}
} else {
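// No scale before processing - the composites already match the output size, so no resizing is needed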
// Combine the inpaint mask and the initial image's alpha channel into a single mask
const maskAlphaToMask = g.addNode({
id: getPrefixedId('alpha_to_mask'),
type: 'tomask',
image: { image_name: maskImage.image_name },
invert: !canvasSettings.preserveMask,
});
const initialImageAlphaToMask = g.addNode({
id: getPrefixedId('image_alpha_to_mask'),
type: 'tomask',
image: { image_name: initialImage.image_name },
});
const maskCombine = g.addNode({
id: getPrefixedId('mask_combine'),
type: 'mask_combine',
});
g.addEdge(maskAlphaToMask, 'image', maskCombine, 'mask1');
g.addEdge(initialImageAlphaToMask, 'image', maskCombine, 'mask2');
const alphaMaskToTensorMask = g.addNode({
type: 'image_mask_to_tensor',
id: getPrefixedId('image_mask_to_tensor'),
});
g.addEdge(maskCombine, 'image', alphaMaskToTensorMask, 'image');
g.addEdge(alphaMaskToTensorMask, 'mask', fluxFill, 'mask');
fluxFill.image = { image_name: initialImage.image_name };
g.addEdge(fluxFill, 'fill_cond', denoise, 'fill_conditioning');
const expandMask = g.addNode({
type: 'expand_mask_with_fade',
id: getPrefixedId('expand_mask_with_fade'),
fade_size_px: params.maskBlur,
});
g.addEdge(maskCombine, 'image', expandMask, 'mask');
// Do the paste back if we are sending to gallery (in which case we want to see the full image), or if we are sending
// to canvas but not outputting only masked regions
if (!canvasSettings.sendToCanvas || !canvasSettings.outputOnlyMaskedRegions) {
const imageLayerBlend = g.addNode({
type: 'invokeai_img_blend',
id: getPrefixedId('image_layer_blend'),
layer_base: { image_name: initialImage.image_name },
});
g.addEdge(l2i, 'image', imageLayerBlend, 'layer_upper');
g.addEdge(expandMask, 'image', imageLayerBlend, 'mask');
return imageLayerBlend;
} else {
// Otherwise, just apply the mask
const applyMaskToImage = g.addNode({
type: 'apply_mask_to_image',
id: getPrefixedId('apply_mask_to_image'),
invert_mask: true,
});
g.addEdge(expandMask, 'image', applyMaskToImage, 'mask');
g.addEdge(l2i, 'image', applyMaskToImage, 'image');
return applyMaskToImage;
}
}
};
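
For orientation, a minimal sketch of how a FLUX graph builder might call addFLUXFill, based only on the signature above. The surrounding builder, the node IDs, and the sizes are assumptions for illustration, not part of this diff.

// Editor's sketch - assumed call site, not taken from this commit.
const l2i = g.addNode({ type: 'flux_vae_decode', id: getPrefixedId('flux_vae_decode') });
const denoise = g.addNode({ type: 'flux_denoise', id: getPrefixedId('flux_denoise') });
const outputNode = await addFLUXFill({
  state, // RootState (assumed in scope in the builder)
  g,
  manager, // CanvasManager for the active canvas (assumed in scope)
  l2i,
  denoise,
  originalSize: { width: 1536, height: 1536 }, // bbox size on the canvas (assumed)
  scaledSize: { width: 1024, height: 1024 }, // scale-before-processing size (assumed)
});
// outputNode is an 'invokeai_img_blend' or 'apply_mask_to_image' invocation; the caller
// wires it up as the graph's output image.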


@@ -2,8 +2,7 @@ import type { CanvasReferenceImageState, FLUXReduxConfig } from 'features/contro
import { isFLUXReduxConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import type { Invocation } from 'services/api/types';
import type { Invocation, MainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
type AddFLUXReduxResult = {
@@ -14,7 +13,7 @@ type AddFLUXReduxArg = {
entities: CanvasReferenceImageState[];
g: Graph;
collector: Invocation<'collect'>;
model: ParameterModel;
model: MainModelConfig;
};
export const addFLUXReduxes = ({ entities, g, collector, model }: AddFLUXReduxArg): AddFLUXReduxResult => {


@@ -5,8 +5,7 @@ import {
} from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import type { Invocation } from 'services/api/types';
import type { Invocation, MainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
type AddIPAdaptersResult = {
@@ -17,7 +16,7 @@ type AddIPAdaptersArg = {
entities: CanvasReferenceImageState[];
g: Graph;
collector: Invocation<'collect'>;
model: ParameterModel;
model: MainModelConfig;
};
export const addIPAdapters = ({ entities, g, collector, model }: AddIPAdaptersArg): AddIPAdaptersResult => {


@@ -39,7 +39,7 @@ export const addInpaint = async ({
scaledSize,
denoising_start,
fp32,
}: AddInpaintArg): Promise<Invocation<'canvas_v2_mask_and_crop' | 'img_resize'>> => {
}: AddInpaintArg): Promise<Invocation<'invokeai_img_blend' | 'apply_mask_to_image'>> => {
denoise.denoising_start = denoising_start;
const params = selectParamsSlice(state);
@@ -60,7 +60,9 @@ export const addInpaint = async ({
silent: true,
});
if (!isEqual(scaledSize, originalSize)) {
const needsScaleBeforeProcessing = !isEqual(scaledSize, originalSize);
if (needsScaleBeforeProcessing) {
// Scale before processing requires some resizing
const i2l = g.addNode({
id: i2lNodeType,
@@ -104,10 +106,10 @@ export const addInpaint = async ({
edge_radius: params.canvasCoherenceEdgeSize,
fp32,
});
const canvasPasteBack = g.addNode({
id: getPrefixedId('canvas_v2_mask_and_crop'),
type: 'canvas_v2_mask_and_crop',
mask_blur: params.maskBlur,
const expandMask = g.addNode({
type: 'expand_mask_with_fade',
id: getPrefixedId('expand_mask_with_fade'),
fade_size_px: params.maskBlur,
});
// Resize initial image and mask to scaled size, feed into the gradient mask
@@ -127,19 +129,32 @@ export const addInpaint = async ({
// After denoising, resize the image and mask back to original size
g.addEdge(l2i, 'image', resizeImageToOriginalSize, 'image');
g.addEdge(createGradientMask, 'expanded_mask_area', resizeMaskToOriginalSize, 'image');
// Finally, paste the generated masked image back onto the original image
g.addEdge(resizeImageToOriginalSize, 'image', canvasPasteBack, 'generated_image');
g.addEdge(resizeMaskToOriginalSize, 'image', canvasPasteBack, 'mask');
g.addEdge(createGradientMask, 'expanded_mask_area', expandMask, 'mask');
g.addEdge(expandMask, 'image', resizeMaskToOriginalSize, 'image');
// After denoising, resize the image and mask back to original size
// Do the paste back if we are sending to gallery (in which case we want to see the full image), or if we are sending
// to canvas but not outputting only masked regions
if (!canvasSettings.sendToCanvas || !canvasSettings.outputOnlyMaskedRegions) {
canvasPasteBack.source_image = { image_name: initialImage.image_name };
const imageLayerBlend = g.addNode({
type: 'invokeai_img_blend',
id: getPrefixedId('image_layer_blend'),
layer_base: { image_name: initialImage.image_name },
});
g.addEdge(resizeImageToOriginalSize, 'image', imageLayerBlend, 'layer_upper');
g.addEdge(resizeMaskToOriginalSize, 'image', imageLayerBlend, 'mask');
return imageLayerBlend;
} else {
// Otherwise, just apply the mask
const applyMaskToImage = g.addNode({
type: 'apply_mask_to_image',
id: getPrefixedId('apply_mask_to_image'),
invert_mask: true,
});
g.addEdge(resizeMaskToOriginalSize, 'image', applyMaskToImage, 'mask');
g.addEdge(resizeImageToOriginalSize, 'image', applyMaskToImage, 'image');
return applyMaskToImage;
}
return canvasPasteBack;
} else {
// No scale before processing, much simpler
const i2l = g.addNode({
@@ -164,11 +179,6 @@ export const addInpaint = async ({
fp32,
image: { image_name: initialImage.image_name },
});
const canvasPasteBack = g.addNode({
id: getPrefixedId('canvas_v2_mask_and_crop'),
type: 'canvas_v2_mask_and_crop',
mask_blur: params.maskBlur,
});
g.addEdge(alphaToMask, 'image', createGradientMask, 'mask');
g.addEdge(i2l, 'latents', denoise, 'latents');
@@ -178,16 +188,35 @@ export const addInpaint = async ({
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
}
g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask');
g.addEdge(createGradientMask, 'expanded_mask_area', canvasPasteBack, 'mask');
g.addEdge(l2i, 'image', canvasPasteBack, 'generated_image');
const expandMask = g.addNode({
type: 'expand_mask_with_fade',
id: getPrefixedId('expand_mask_with_fade'),
fade_size_px: params.maskBlur,
});
g.addEdge(createGradientMask, 'expanded_mask_area', expandMask, 'mask');
// Do the paste back if we are sending to gallery (in which case we want to see the full image), or if we are sending
// to canvas but not outputting only masked regions
if (!canvasSettings.sendToCanvas || !canvasSettings.outputOnlyMaskedRegions) {
canvasPasteBack.source_image = { image_name: initialImage.image_name };
const imageLayerBlend = g.addNode({
type: 'invokeai_img_blend',
id: getPrefixedId('image_layer_blend'),
layer_base: { image_name: initialImage.image_name },
});
g.addEdge(l2i, 'image', imageLayerBlend, 'layer_upper');
g.addEdge(expandMask, 'image', imageLayerBlend, 'mask');
return imageLayerBlend;
} else {
// Otherwise, just apply the mask
const applyMaskToImage = g.addNode({
type: 'apply_mask_to_image',
id: getPrefixedId('apply_mask_to_image'),
invert_mask: true,
});
g.addEdge(expandMask, 'image', applyMaskToImage, 'mask');
g.addEdge(l2i, 'image', applyMaskToImage, 'image');
return applyMaskToImage;
}
return canvasPasteBack;
}
};

Some files were not shown because too many files have changed in this diff.