Compare commits

...

201 Commits

Author SHA1 Message Date
psychedelicious
b897ca18ce feat(nodes): wip outputs_service 2023-05-17 00:42:27 +10:00
psychedelicious
cc21fb216c chore(ui): clean up GalleryPanel 2023-05-16 10:43:26 +10:00
psychedelicious
6fe62a2705 feat(ui): sampler --> scheduler 2023-05-16 10:40:26 +10:00
psychedelicious
da87378713 chore(ui): regen api client 2023-05-16 10:39:40 +10:00
psychedelicious
b6f5267385 chore(ui): clean up generationSlice 2023-05-16 10:21:18 +10:00
psychedelicious
f9e78d3c64 chore(ui): clean up gallerySlice 2023-05-16 10:16:36 +10:00
psychedelicious
b7b5bd1b46 chore(ui): clean up uiSlice 2023-05-16 09:57:19 +10:00
psychedelicious
9a3727d3ad chore(ui): clean up systemSlice 2023-05-16 09:48:58 +10:00
psychedelicious
d68c14516c chore(ui): clean up persist denylists 2023-05-16 09:46:03 +10:00
psychedelicious
9f4d39aa42 chore(ui): clean up modelSlice 2023-05-16 09:45:49 +10:00
blessedcoolant
84b801d88f ui: restore canvas and upload functionality (#3414)
- refactor image uploading, fix init image upload button 
- refactor toast and hotkey hooks into logical components
- restore canvas save/download/copy/merge functionality
- clean up unused files and packages
- fix canvas rendering issue resulting from fractional stage coords
2023-05-16 02:23:39 +12:00
blessedcoolant
2fc70c509b Merge branch 'main' into feat/ui/fix-uploading 2023-05-16 02:20:59 +12:00
Lincoln Stein
34fb1c4b19 make conditioning.py work with compel 1.1.5 (#3383)
This PR fixes the ValueError issue that was preventing all prompts from
working.
2023-05-15 09:46:04 -04:00
Lincoln Stein
80bdd550cf Merge branch 'main' into lstein/bugfix/compel 2023-05-15 09:25:21 -04:00
psychedelicious
2359b92b46 chore(ui): tidy unused component ref 2023-05-15 22:58:15 +10:00
psychedelicious
a404fb2d32 docs(ui): update PACKAGE_SCRIPTS.md 2023-05-15 22:49:28 +10:00
psychedelicious
513eb11616 chore(ui): clean up unused files/packages 2023-05-15 22:48:06 +10:00
psychedelicious
d2c9140e69 feat(ui): restore save/copy/download/merge functionality 2023-05-15 22:21:03 +10:00
psychedelicious
d95fe5925a feat(ui): restore image post-upload actions
e.g. set init image if on img2img when uploading
2023-05-15 18:52:48 +10:00
psychedelicious
835922ea8f fix(ui): floor canvas coords to prevent partial pixel offset rendering issues 2023-05-15 18:50:34 +10:00
psychedelicious
e1e5266fc3 feat(ui): refactor base image uploading logic 2023-05-15 17:45:05 +10:00
psychedelicious
5e4457445f feat(ui): make toast/hotkey into logical components 2023-05-15 15:25:27 +10:00
psychedelicious
0221ca8f49 fix(ui): use cloned canvas for retrieving dataURL/Blobs 2023-05-15 13:54:30 +10:00
Eugene Brodsky
cf36e4029e fix(ui): fix syntax error in the logo component flexbox 2023-05-15 08:24:33 +10:00
Eugene Brodsky
c8a98a9a22 Merge branch 'main' into lstein/bugfix/compel 2023-05-14 14:43:18 -04:00
blessedcoolant
38ecca9362 Logging Improvements (#3401)
This PR improves the logging module a bit, along with the
documentation.

**New Look:**


![WindowsTerminal_XaijwCqFpo](https://github.com/invoke-ai/InvokeAI/assets/54517381/49a97411-1927-4a49-80ff-f4d9665be55f)

## Usage

**General Logger**

InvokeAI has a module-level logger. You can call it this way.

In the example below, the default logger `InvokeAI` is used, and all
your messages will be logged under that name.

```python
from invokeai.backend.util.logging import logger

logger.critical("Critical Message")  # In Bold Red
logger.error("Error Message")  # In Red
logger.warning("Warning Message")  # In Yellow
logger.info("Info Message")  # In Grey
logger.debug("Debug Message")  # In Grey
```

Results:

```
[12-05-2023 20]::[InvokeAI]::CRITICAL --> Critical Message [In Bold Red]
[12-05-2023 20]::[InvokeAI]::ERROR --> Error Message [In Red]
[12-05-2023 20]::[InvokeAI]::WARNING --> Warning Message [In Yellow]
[12-05-2023 20]::[InvokeAI]::INFO --> Info Message [In Grey]
[12-05-2023 20]::[InvokeAI]::DEBUG --> Debug Message [In Grey]
```

**Custom Logger**

If you want to use a custom logger for your module, you can create one
the following way.

```python
from invokeai.backend.util.logging import logging

logger = logging.getLogger(name='Model Manager')

logger.critical("Critical Message")  # In Bold Red
logger.error("Error Message")  # In Red
logger.warning("Warning Message")  # In Yellow
logger.info("Info Message")  # In Grey
logger.debug("Debug Message")  # In Grey
```

Results:

```
[12-05-2023 20]::[Model Manager]::CRITICAL --> Critical Message [In Bold Red]
[12-05-2023 20]::[Model Manager]::ERROR --> Error Message [In Red]
[12-05-2023 20]::[Model Manager]::WARNING --> Warning Message [In Yellow]
[12-05-2023 20]::[Model Manager]::INFO --> Info Message [In Grey]
[12-05-2023 20]::[Model Manager]::DEBUG --> Debug Message [In Grey]
```

**When to use a custom logger?**

Using a custom logger is recommended if your module is not part of
base InvokeAI, for example custom extensions / nodes.
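For instance, a hypothetical third-party node pack might give itself its own log namespace (the module and logger names below are made up for illustration; the pattern is the custom-logger one shown above):

```python
# my_custom_nodes.py -- a hypothetical extension module
from invokeai.backend.util.logging import logging

logger = logging.getLogger(name='My Custom Nodes')
logger.info("Custom node pack loaded")  # logged under [My Custom Nodes]
```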
2023-05-15 02:18:20 +12:00
blessedcoolant
c4681774a5 Merge branch 'main' into logging-facelift 2023-05-15 02:08:29 +12:00
Damian Stewart
050add58d2 fix getting conditionings 2023-05-14 12:20:54 +02:00
blessedcoolant
3d60c958c7 ui: commercial fixes (#3409)
minor commercial fixes
2023-05-14 20:44:06 +12:00
psychedelicious
f5df150097 feat(ui): add callback to signal app is ready
needed for commercial
2023-05-14 18:42:15 +10:00
psychedelicious
dac82adb5b fix(ui): make logo component non-selectable 2023-05-14 18:41:11 +10:00
Eugene
b72c9787a9 Revert "comment out customer_attention_context"
This reverts commit 8f8cd90787.

Due to NameError: name 'options' is not defined
2023-05-14 00:37:55 -04:00
Eugene Brodsky
2623941d91 Merge branch 'main' into lstein/bugfix/compel 2023-05-13 22:23:59 -04:00
psychedelicious
d3a7fea939 Revert "fix: Rework the layout of the parameters scrollbar"
This reverts commit 6f1fc397f7.
2023-05-14 11:45:08 +10:00
psychedelicious
5a7b687c84 fix(ui): add missing packages 2023-05-14 11:45:08 +10:00
psychedelicious
0020457fc7 fix(ui): tweak settings scheduler styling 2023-05-14 11:45:08 +10:00
psychedelicious
658b556544 feat(ui): IAICustomSelect v2, implement for scheduler & model 2023-05-14 11:45:08 +10:00
psychedelicious
37da0fc075 feat(ui): IAICustomSelect v1 2023-05-14 11:45:08 +10:00
psychedelicious
6d3e8507cc fix(ui): fix "no image" fallbacks 2023-05-14 11:45:08 +10:00
blessedcoolant
0e9470503f fix: Rework the layout of the parameters scrollbar 2023-05-14 11:45:08 +10:00
blessedcoolant
d2ebc6741b feat: Add setting to hide / display schedulers 2023-05-14 11:45:08 +10:00
blessedcoolant
026d3260b4 Add Heun Karras Scheduler 2023-05-14 11:45:08 +10:00
blessedcoolant
78533714e3 Merge branch 'main' into logging-facelift 2023-05-14 09:07:51 +12:00
blessedcoolant
691e1bf829 Make debug messages cyan/blue 2023-05-14 09:06:57 +12:00
Mary Hipp
47a088d685 rehydrate selectedImage URL when results and uploads are fetched 2023-05-13 09:48:38 +10:00
Eugene Brodsky
63db3fc22f reduce queue check interval to 0.5s 2023-05-12 17:54:26 -04:00
Eugene
ad0bb3f61a fix: queue error should not crash InvocationProcessor
1. if retrieving an item from the queue raises an exception, the
   InvocationProcessor thread crashes, but the API continues running in
   a non-functional state. This fixes the issue
2. when there are no items in the queue, sleep 1 second before checking
   again.
3. Also ensures the thread isn't crashed if an exception is raised from
   invoker, and emits the error event

Intentionally using base Exceptions because for now we don't know which
specific exception to expect.

Fixes (sort of)? #3222
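A minimal sketch of the hardened loop described in points 1-3 above (names and structure are illustrative, not the actual InvocationProcessor code):

```python
import time
import traceback

def process_queue_forever(queue, invoker, logger, stop_event, emit_error):
    # Illustrative sketch: survive queue failures, idle politely when the
    # queue is empty, and report invoker errors instead of dying
    while not stop_event.is_set():
        try:
            item = queue.get()  # 1. may raise; must not kill the thread
        except Exception:
            logger.error(traceback.format_exc())
            time.sleep(1)  # back off instead of crashing
            continue
        if item is None:
            time.sleep(1)  # 2. nothing queued; check again shortly
            continue
        try:
            invoker.invoke(item)
        except Exception as e:
            # 3. base Exception on purpose: we don't yet know which
            # specific exceptions to expect; emit the error event instead
            emit_error(item, e)
```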
2023-05-12 17:54:26 -04:00
Kent Keirsey
8f8cd90787 comment out customer_attention_context 2023-05-12 13:59:00 -04:00
blessedcoolant
d796ea7bec feat: Logging Improvements 2023-05-13 02:13:49 +12:00
psychedelicious
e5b7dd63e9 fix(nodes): temporarily disable librarygraphs
- Do not retrieve graph from DB until we resolve the issue of changing node schemas causing application to fail to start up due to invalid graphs
2023-05-12 22:33:49 +10:00
Eugene Brodsky
af060188bd Merge branch 'main' into lstein/bugfix/compel 2023-05-12 08:22:18 -04:00
blessedcoolant
4270e7ae25 Feat/ui/improve-language (#3399) 2023-05-12 23:32:50 +12:00
psychedelicious
60a565d7de feat(ui): use chakra menu for theme changer 2023-05-12 20:04:29 +10:00
psychedelicious
78cf70eaad fix(ui): tweak lang picker style 2023-05-12 20:04:10 +10:00
psychedelicious
eebaa50710 fix(ui): fix language picker tooltip 2023-05-12 19:52:21 +10:00
psychedelicious
7d582553f2 feat(ui): use chakra menu for language picker 2023-05-12 19:50:34 +10:00
psychedelicious
4d6eea7e81 feat(ui): store language in redux 2023-05-12 19:35:03 +10:00
blessedcoolant
f44593331d ui: misc fixes (#3398)
- do not show canvas intermediates in gallery
- do not show progress image in uploads gallery category
- use custom dark mode `localStorage` key (prevents collision with
commercial)
- use variable font (reduce bundle size by factor of 10)
- change how custom headers are used
- use style injection for building package
- fix tab icon sizes
2023-05-12 21:00:47 +12:00
psychedelicious
3d9ecbf3c7 fix(ui): add missing package 2023-05-12 18:55:59 +10:00
psychedelicious
032aa1d59c fix(ui): excise most zIndexs
our stacking contexts are accurate, `zIndex` isn't needed
2023-05-12 18:50:54 +10:00
psychedelicious
35e0863bdb fix(ui): fix tab icon sizes 2023-05-12 17:56:18 +10:00
psychedelicious
14070d674e build(ui): add style injection plugin
When building for package, CSS is all in JS files; when used as a package, it is then injected into the page. Bit of a hack to work around missing CSS in the commercial product.
2023-05-12 17:56:18 +10:00
psychedelicious
108ce06c62 feat(ui): change custom header to be a prop instead of children 2023-05-12 17:56:18 +10:00
psychedelicious
da364f3444 feat(ui): use variable font
reduces package build's CSS by an order of magnitude
2023-05-12 17:56:18 +10:00
psychedelicious
df5ba75c14 feat(ui): use custom dark mode localStorage key 2023-05-12 17:56:18 +10:00
psychedelicious
e4fb9cb33f chore(ui): regen api client 2023-05-12 17:56:18 +10:00
psychedelicious
65b527eb20 fix(ui): do not show progress images in uploads gallery category 2023-05-12 17:56:18 +10:00
psychedelicious
7dc9d18052 fix(ui): do not show intermediates uploads in gallery 2023-05-12 17:56:18 +10:00
blessedcoolant
5013a4b9f3 feat(ui): expand config options (#3393)
Individual SD features (e.g. Noise, Variation) may now be disabled - stuff
which is not ready for consumption in the commercial product.
2023-05-12 16:10:17 +12:00
blessedcoolant
f929359322 Merge branch 'main' into feat/ui/expand-config 2023-05-12 16:06:31 +12:00
blessedcoolant
6522c71971 feat(nodes): add RandomIntInvocation (#3390)
just outputs a single random int
2023-05-12 16:06:06 +12:00
blessedcoolant
9c1e65f3a3 Merge branch 'main' into feat/nodes/add-randomintinvocation 2023-05-12 15:56:41 +12:00
psychedelicious
ebec200ba6 Remove unused import 2023-05-12 13:56:02 +10:00
blessedcoolant
e559730b6e feat(nodes): add w/h to latents outputs (#3389)
This reduces the number of nodes needed when working with latents (i.e.
fewer plain integer value nodes).

Also corrects a few mistakes in the fields.
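Roughly, the idea (mirrored by `build_latents_output` in the latents diff further down): latents are 1/8 of pixel resolution, so the pixel dimensions can be read straight off the tensor shape instead of being wired in via separate integer nodes.

```python
import torch

def latent_pixel_dims(latents: torch.Tensor) -> tuple[int, int]:
    # latents have shape (batch, channels, height / 8, width / 8),
    # so pixel width/height fall out of the tensor shape directly
    return latents.size()[3] * 8, latents.size()[2] * 8
```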
2023-05-12 15:40:46 +12:00
blessedcoolant
0acb8ed85d Merge branch 'main' into feat/nodes/add-w-h-latentsoutput 2023-05-12 15:23:29 +12:00
blessedcoolant
8c1c9cd702 Merge branch 'main' into feat/nodes/add-randomintinvocation 2023-05-12 15:21:49 +12:00
blessedcoolant
0ece4686aa fix(nodes): remove Optionals on ImageOutputs (#3392) 2023-05-12 15:21:42 +12:00
blessedcoolant
af95cef7f9 Merge branch 'main' into fix/nodes/fix-imageoutput-optionals 2023-05-12 15:08:19 +12:00
blessedcoolant
1eca7a918a feat(ui): make core parameters layout consistent (#3394) 2023-05-12 15:08:07 +12:00
blessedcoolant
9e6b958023 Merge branch 'main' into feat/ui/consistent-param-layout 2023-05-12 15:06:16 +12:00
blessedcoolant
f7b99d93ae docs(ui): update ui readme (#3396) 2023-05-12 15:05:55 +12:00
blessedcoolant
85d03dcd90 Merge branch 'main' into docs/ui/update-ui-readme 2023-05-12 15:04:12 +12:00
blessedcoolant
032555bcfe fix(model manager): fix string formatting error on model checksum timer (#3397)
The error occurs when loading a model for the first time (or, probably,
after removing its checksum file).
2023-05-12 15:04:01 +12:00
Kevin Turner
4caa1f19b2 fix(model manager): fix string formatting error on model checksum timer 2023-05-11 19:06:02 -07:00
Lincoln Stein
95d4bd3012 Merge branch 'lstein/bugfix/compel' of github.com:invoke-ai/InvokeAI into lstein/bugfix/compel 2023-05-11 21:13:29 -04:00
Lincoln Stein
037078c8ad make InvokeAIDiffuserComponent.custom_attention_control a classmethod 2023-05-11 21:13:18 -04:00
psychedelicious
6de2f66b50 docs(ui): update ui readme 2023-05-12 11:11:59 +10:00
blessedcoolant
cd7b248eda Add UniPC / Euler Karras / DPMPP_2 Karras / DEIS / DDPM Schedulers (#3388)
**Features:**

- Add UniPC Scheduler
- Add Euler Karras Scheduler
- Add DPMPP_2 Karras Scheduler
- Add DEIS Scheduler
- Add DDPM Scheduler

**Other:**

- Renamed schedulers to their accurate names: _a = Ancestral, _k = Karras
- Fixed the scheduler not defaulting correctly to DDIM.
- Code-split SCHEDULER_MAP so it's consistently loaded from the same
place.

**Known Bugs:**

- dpmpp_2s is not working in img2img for denoising values < 0.8. This
seems to be an upstream bug, so it has been disabled in img2img and
canvas until the upstream bug is fixed:
https://github.com/huggingface/diffusers/issues/1866
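For illustration, the shared map pairs each scheduler name with a diffusers class plus extra config to apply on load; the entries below are assumptions based on the `get_scheduler` change in the diff further down, not the exact contents:

```python
from diffusers import (
    DDIMScheduler,
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    UniPCMultistepScheduler,
)

# Each value: (scheduler class, extra config merged into the model's scheduler config)
SCHEDULER_MAP = dict(
    ddim=(DDIMScheduler, dict()),
    deis=(DEISMultistepScheduler, dict()),
    dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
    dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
    unipc=(UniPCMultistepScheduler, dict()),
)

# Consumers then resolve a name in one place, falling back to ddim:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get("dpmpp_2m_k", SCHEDULER_MAP["ddim"])
```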
2023-05-12 09:06:22 +12:00
blessedcoolant
6d8c077f4e Merge branch 'main' into unipc-sched 2023-05-12 05:59:13 +12:00
blessedcoolant
97127e560e Disable dpmpp_2s in img2img & unifiedCanvas
... until upstream bug is fixed.
2023-05-12 04:51:58 +12:00
Sergey Borisov
27dc07d95a Set zero eta by default (fix ddim scheduler error) 2023-05-11 18:49:27 +03:00
blessedcoolant
f7dc171c4f Rename default schedulers across the app 2023-05-12 03:44:20 +12:00
blessedcoolant
4b957edfec Add DDPM Scheduler 2023-05-12 03:18:34 +12:00
blessedcoolant
46ca7718d9 Add DEIS Scheduler 2023-05-12 03:10:30 +12:00
blessedcoolant
b928d7a6e6 Change scheduler names to be accurate
_a = Ancestral
_k = Karras
2023-05-12 02:59:43 +12:00
blessedcoolant
8a836247c8 Add DPMPP Single, Euler Karras and DPMPP2 Multi Karras Schedulers 2023-05-12 02:23:33 +12:00
Mary Hipp
95c3644564 fix it again 2023-05-12 00:10:39 +10:00
psychedelicious
799cd07174 feat(ui): make core parameters layout consistent 2023-05-11 22:45:53 +10:00
psychedelicious
9af385468d feat(ui): expand config options
Individual SD features (e.g. Noise, Variation) may now be disabled - stuff which is not ready for consumption in the commercial product.
2023-05-11 22:42:13 +10:00
blessedcoolant
3487388788 Merge branch 'unipc-sched' of https://github.com/blessedcoolant/InvokeAI into unipc-sched 2023-05-12 00:40:24 +12:00
blessedcoolant
9a383e456d Code-split SCHEDULER_MAP for reuse 2023-05-12 00:40:03 +12:00
blessedcoolant
805f9f8f4a Merge branch 'main' into unipc-sched 2023-05-12 00:24:55 +12:00
blessedcoolant
52aa0c9bbd ui: miscellaneous fixes (#3386) 2023-05-12 00:21:29 +12:00
psychedelicious
7f5f4689cc fix(ui): clear progress image on cancel 2023-05-11 22:20:37 +10:00
psychedelicious
a3f81f4b98 fix(ui): fix results not displaying
- fix for commercial product
2023-05-11 22:20:37 +10:00
psychedelicious
15c59e606f feat(ui): add spinner to gallery progress images
- otherwise you may think you can click it but you cannot
2023-05-11 22:20:37 +10:00
psychedelicious
40d4cabecd feat(ui): improve image overlay 2023-05-11 22:20:37 +10:00
psychedelicious
3493c8119b feat(ui): improve image preview css and fallback 2023-05-11 22:20:30 +10:00
blessedcoolant
c1e7460d39 Merge branch 'main' into unipc-sched 2023-05-12 00:11:09 +12:00
blessedcoolant
3ffff023b2 Add missing key to scheduler_map
It was breaking because the sampler was not being reset, so each entry needs a key. Will simplify this later.
2023-05-12 00:08:50 +12:00
psychedelicious
f9384be59b fix(ui): fix init image causing overflow 2023-05-11 20:55:30 +10:00
psychedelicious
6cf308004a fix(nodes): remove Optionals on ImageOutputs 2023-05-11 20:54:57 +10:00
blessedcoolant
d1029138d2 Default to DDIM if scheduler is missing 2023-05-11 22:54:35 +12:00
blessedcoolant
06b5800d28 Add UniPC Scheduler 2023-05-11 22:43:18 +12:00
psychedelicious
483f2ccb56 feat(nodes): add RandomIntInvocation
just outputs a single random int
2023-05-11 20:33:32 +10:00
psychedelicious
93ced0bec6 feat(nodes): add w/h to latents outputs
This reduces the number of nodes needed when working with latents (i.e. fewer plain integer value nodes)

Also correct a few mistakes in the fields
2023-05-11 20:32:55 +10:00
psychedelicious
4333852c37 fix(nodes): fix missing context arg in LatentsToLatents 2023-05-11 19:28:42 +10:00
Eugene Brodsky
3baa230077 Merge branch 'main' into lstein/bugfix/compel 2023-05-11 00:50:45 -04:00
Eugene
9e594f9018 pad conditioning tensors to same length
fixes crash when prompt length is greater than 75 tokens
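A minimal sketch of the padding approach (illustrative only; the real change lives in the conditioning code, and the padding value and mechanics here are assumptions):

```python
import torch

def pad_conditioning_to_same_length(uc: torch.Tensor, c: torch.Tensor):
    # Equalize the token dimension so the unconditioned and conditioned
    # embeddings can be used together even past the 75-token window
    max_len = max(uc.shape[1], c.shape[1])

    def pad(t: torch.Tensor) -> torch.Tensor:
        if t.shape[1] < max_len:
            filler = t.new_zeros((t.shape[0], max_len - t.shape[1], t.shape[2]))
            t = torch.cat([t, filler], dim=1)
        return t

    return pad(uc), pad(c)
```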
2023-05-11 00:34:15 -04:00
Mary Hipp Rogers
b0c41b4828 filter out websocket errors (#3382)
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2023-05-11 01:58:40 +00:00
psychedelicious
e0d6946b6b fix(nodes): fix metadata test
- `progress_images` is no longer a parameter
- `seamless` needs to be reworked as a model config, removed as a param
2023-05-11 11:55:51 +10:00
psychedelicious
bf7ea8309f fix(ui): change tab to img2img when selected initial image 2023-05-11 11:55:51 +10:00
psychedelicious
54b65f725f fix(ui): rescale canvas on gallery resize 2023-05-11 11:55:51 +10:00
psychedelicious
8ef49c2640 fix(ui): fix canvas img2img if no init image selected 2023-05-11 11:55:51 +10:00
psychedelicious
f488b1a7f2 fix(nodes): fix usage of Optional 2023-05-11 11:55:51 +10:00
psychedelicious
d2edb7c402 build(ui): add yalc to gitignore 2023-05-11 11:55:51 +10:00
psychedelicious
f0a3f07b45 feat(ui): antialias progress images 2023-05-11 11:55:51 +10:00
psychedelicious
b42b630583 fix(ui): h/w disabled bug 2023-05-11 11:55:51 +10:00
psychedelicious
31a78d571b feat(ui): canvas antialiasing 2023-05-11 11:55:51 +10:00
psychedelicious
fdc2232ea0 feat(ui): progress images in gallery and viewer 2023-05-11 11:55:51 +10:00
psychedelicious
e94d0b2d40 fix(ui): fix janky gallery image delete 2023-05-11 11:55:51 +10:00
psychedelicious
75ccbaee9c fix(ui): disable invoke button as soon as pressed 2023-05-11 11:55:51 +10:00
psychedelicious
2848c8397c fix(ui): fix missing images on reload issue
- Mainly an issue for commercial due to incomplete metadata handling
2023-05-11 11:55:51 +10:00
psychedelicious
fe8b5193de feat(ui): half-baked use all parameters
until we have a better system for metadata, this will remain half-baked
2023-05-11 11:55:51 +10:00
psychedelicious
3d1470399c fix(ui): fix metadataviewer styling 2023-05-11 11:55:51 +10:00
psychedelicious
fcf9c63049 fix(ui): fix copying image link 2023-05-11 11:55:51 +10:00
blessedcoolant
7bfb5640ad cleanup(ui): Remove unused vars + minor bug fixes 2023-05-11 11:55:51 +10:00
psychedelicious
15e57e3a3d fix(ui): duplicate gallery in nodes editor 2023-05-11 11:55:51 +10:00
psychedelicious
279468c0e8 feat(ui): restore tab names 2023-05-11 11:55:51 +10:00
psychedelicious
c565812723 feat(ui): organize parameters panels 2023-05-11 11:55:51 +10:00
psychedelicious
ec6c8e2a38 feat(ui): wip layout 2023-05-11 11:55:51 +10:00
psychedelicious
77f2690711 fix(ui): remove duplicate gallery 2023-05-11 11:55:51 +10:00
psychedelicious
c4b3a24ed7 feat(ui): revert tabs to txt2img/img2img 2023-05-11 11:55:51 +10:00
psychedelicious
33c69359c2 feat(ui): add IAICollapse for parameters 2023-05-11 11:55:51 +10:00
psychedelicious
864f4bb4af feat(ui): wip img2img layouting 2023-05-11 11:55:51 +10:00
psychedelicious
5365f42a04 feat(ui): wip layouting 2023-05-11 11:55:51 +10:00
psychedelicious
3dc60254b9 feat(ui): support collect nodes 2023-05-11 11:55:51 +10:00
psychedelicious
027a8562d7 fix(ui): default node model selection 2023-05-11 11:55:51 +10:00
psychedelicious
34f3a0f0e3 feat(nodes): improve default model choosing output 2023-05-11 11:55:51 +10:00
psychedelicious
d0bac1675e fix(nodes): fix ImageOutput Config 2023-05-11 11:55:51 +10:00
psychedelicious
4e56c962f4 fix(nodes): fix infill docstrings 2023-05-11 11:55:51 +10:00
psychedelicious
4ef0e43759 fix(nodes): remove dataURL invocation 2023-05-11 11:55:51 +10:00
psychedelicious
6945d10297 chore(ui): regen api client 2023-05-11 11:55:51 +10:00
psychedelicious
4d6cef7ac8 fix(ui): fix types bug 2023-05-11 11:55:51 +10:00
psychedelicious
a7786d5ff2 fix(nodes): restore seamless to TextToLatents 2023-05-11 11:55:51 +10:00
psychedelicious
6c1de975d9 feat(nodes): add infill nodes 2023-05-11 11:55:51 +10:00
psychedelicious
a1079e455a feat(nodes): cleanup unused params, seed generation 2023-05-11 11:55:51 +10:00
psychedelicious
5457c7f069 fix(ui): use lodash-es instead of lodash 2023-05-11 11:55:51 +10:00
psychedelicious
b8c1a3f96c chore(ui): remove unused babelrc & npm script 2023-05-11 11:55:51 +10:00
psychedelicious
cee8e85f76 chore(ui): bump redux-remember 2023-05-11 11:55:51 +10:00
psychedelicious
09f166577e feat(ui): migrate to redux-remember 2023-05-11 11:55:51 +10:00
psychedelicious
bcc21531fb feat(ui): update for InfillInvocation 2023-05-11 11:55:51 +10:00
psychedelicious
da4eacdffe feat(nodes): add InfillInvocation 2023-05-11 11:55:51 +10:00
psychedelicious
6102e560ba feat(nodes): add LatentsToImage node (VAE encode) 2023-05-11 11:55:51 +10:00
psychedelicious
ff3aa57117 feat(ui): fix endless gallery scroll for single col layout 2023-05-11 11:55:51 +10:00
psychedelicious
49db6f4fac fix(nodes): fix trivial typing issues 2023-05-11 11:55:51 +10:00
psychedelicious
20f6a597ab fix(nodes): add MetadataColorField 2023-05-11 11:55:51 +10:00
psychedelicious
04c453721c feat(ui): tweak gallery loading indicator 2023-05-11 11:55:51 +10:00
psychedelicious
350ffecc1f feat(ui): endless gallery scroll 2023-05-11 11:55:51 +10:00
psychedelicious
b0557aa16b fix(ui): fix currentimagepreview not working for uploads 2023-05-11 11:55:51 +10:00
psychedelicious
1c9429a6ea feat(ui): wip canvas 2023-05-11 11:55:51 +10:00
psychedelicious
206e6b1730 feat(nodes): wip inpaint node 2023-05-11 11:55:51 +10:00
psychedelicious
357cee2849 fix(nodes): fix cfg scale min value 2023-05-11 11:55:51 +10:00
psychedelicious
0b49997bb6 feat(nodes): allow uploaded images to be any ImageType (eg intermediates) 2023-05-11 11:55:51 +10:00
psychedelicious
5e09dd380d Revert "feat(nodes): free gpu mem after invocation"
This reverts commit 99cb33f477306d5dcc455efe04053ce41b8d85bd.
2023-05-11 11:55:51 +10:00
psychedelicious
c7303adb0d feat(ui): fix generation mode logic 2023-05-11 11:55:51 +10:00
psychedelicious
ed1f096a6f feat(ui): wip canvas migration 4 2023-05-11 11:55:51 +10:00
psychedelicious
6ab5d28cf3 feat(ui): wip canvas migration, createListenerMiddleware 2023-05-11 11:55:51 +10:00
psychedelicious
a75148cb16 feat(nodes): free gpu mem after invocation 2023-05-11 11:55:51 +10:00
psychedelicious
f7bbc4004a feat(ui): wip canvas nodes migration 3 2023-05-11 11:55:51 +10:00
psychedelicious
cee21ca082 feat(ui): wip canvas nodes migration 2 2023-05-11 11:55:51 +10:00
psychedelicious
08ec12b391 feat(ui): wip canvas nodes migration 2023-05-11 11:55:51 +10:00
psychedelicious
ff5e2a9a8c chore(ui): regen api client 2023-05-11 11:55:51 +10:00
psychedelicious
e0b9b5cc6c feat(nodes): add dataURL to image node 2023-05-11 11:55:51 +10:00
Lincoln Stein
aca4770481 fixed compel.py as requested 2023-05-10 21:40:44 -04:00
Lincoln Stein
5d5157fc65 make conditioning.py work with compel 1.1.5 2023-05-10 18:08:33 -04:00
Mary Hipp Rogers
fb6ef61a4d change path for locale (#3381)
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2023-05-10 10:30:17 -04:00
psychedelicious
ee24ad7b13 fix(nodes): fix broken docs routes 2023-05-10 08:28:17 -04:00
psychedelicious
f8e90ba3f0 feat(nodes): add ui build static route 2023-05-10 08:28:17 -04:00
blessedcoolant
ad0b70ca23 fix(nodes): fix #3306 (#3377)
Check if the cache has the object before deleting it.
2023-05-10 17:39:45 +12:00
psychedelicious
7dfa135b2c fix(nodes): fix #3306
Check if the cache has the object before deleting it.
2023-05-10 15:29:10 +10:00
Lincoln Stein
beeaa05658 Update dependencies to get deterministic image generation behavior (main branch) (#3354)
This PR updates to `xformers ~= 0.0.19` and `torch ~= 2.0.0`, which
together seem to solve the non-deterministic image generation issue that
was previously seen with earlier versions of `xformers`.
2023-05-10 00:10:51 -04:00
Lincoln Stein
6b6d654f60 Merge branch 'main' into enhance/update-dependencies 2023-05-09 23:56:46 -04:00
Mary Hipp
853c83d0c2 surface detail field for 403 errors 2023-05-09 12:40:19 +10:00
Mary Hipp
1809990ed4 if backend returns an error, show it in toast 2023-05-09 11:09:36 +10:00
Eugene
79d49853d2 use websocket transport first for socket.io 2023-05-09 11:01:02 +10:00
Lincoln Stein
bd0ad59c27 bump compel version 2023-05-07 15:22:46 -04:00
Lincoln Stein
cce40acba5 Merge branch 'enhance/update-dependencies' of github.com:invoke-ai/InvokeAI into enhance/update-dependencies 2023-05-07 15:22:31 -04:00
Lincoln Stein
bc9491ab69 bump compel version 2023-05-07 15:21:24 -04:00
blessedcoolant
b909bac0dc Merge branch 'main' into enhance/update-dependencies 2023-05-07 21:44:43 +12:00
Lincoln Stein
8f80ba9520 update dependencies to get deterministic image generation 2023-05-06 23:09:24 -04:00
425 changed files with 8063 additions and 7018 deletions

View File

@@ -247,8 +247,8 @@ class InvokeAiInstance:
         pip[
             "install",
             "--require-virtualenv",
-            "torch",
-            "torchvision",
+            "torch~=2.0.0",
+            "torchvision>=0.14.1",
             "--force-reinstall",
             "--find-links" if find_links is not None else None,
             find_links,

View File

@@ -1,6 +1,8 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 import os
 from invokeai.app.models.image import ImageField
+from invokeai.app.services.outputs_sqlite import OutputsSqliteItemStorage
 import invokeai.backend.util.logging as logger
 from typing import types
@@ -76,6 +78,7 @@ class ApiDependencies:
             latents=latents,
             images=images,
             metadata=metadata,
+            outputs=OutputsSqliteItemStorage(filename=db_location),
             queue=MemoryInvocationQueue(),
             graph_library=SqliteItemStorage[LibraryGraph](
                 filename=db_location, table_name="graphs"

View File

@@ -83,7 +83,7 @@ async def get_thumbnail(
     status_code=201,
 )
 async def upload_image(
-    file: UploadFile, request: Request, response: Response
+    file: UploadFile, image_type: ImageType, request: Request, response: Response
 ) -> ImageResponse:
     if not file.content_type.startswith("image"):
         raise HTTPException(status_code=415, detail="Not an image")
@@ -99,21 +99,21 @@ async def upload_image(
     filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
     saved_image = ApiDependencies.invoker.services.images.save(
-        ImageType.UPLOAD, filename, img
+        image_type, filename, img
     )
     invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
     image_url = ApiDependencies.invoker.services.images.get_uri(
-        ImageType.UPLOAD, saved_image.image_name
+        image_type, saved_image.image_name
     )
     thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
-        ImageType.UPLOAD, saved_image.image_name, True
+        image_type, saved_image.image_name, True
     )
     res = ImageResponse(
-        image_type=ImageType.UPLOAD,
+        image_type=image_type,
         image_name=saved_image.image_name,
         image_url=image_url,
         thumbnail_url=thumbnail_url,

View File

@@ -126,7 +126,6 @@ app.openapi = custom_openapi
 # Override API doc favicons
 app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
 @app.get("/docs", include_in_schema=False)
 def overridden_swagger():
     return get_swagger_ui_html(
@@ -144,6 +143,8 @@ def overridden_redoc():
         redoc_favicon_url="/static/favicon.ico",
     )
+# Must mount *after* the other routes else it borks em
+app.mount("/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui")
 def invoke_api():
     # Start our own event loop for eventing usage

View File

@@ -3,12 +3,12 @@
 from typing import Literal, Optional
 import numpy as np
-import numpy.random
 from pydantic import Field
+from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from .baseinvocation import (
     BaseInvocation,
     InvocationConfig,
     InvocationContext,
     BaseInvocationOutput,
 )
@@ -50,11 +50,11 @@ class RandomRangeInvocation(BaseInvocation):
         default=np.iinfo(np.int32).max, description="The exclusive high value"
     )
     size: int = Field(default=1, description="The number of values to generate")
-    seed: Optional[int] = Field(
+    seed: int = Field(
         ge=0,
-        le=np.iinfo(np.int32).max,
-        description="The seed for the RNG",
-        default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
+        le=SEED_MAX,
+        description="The seed for the RNG (omit for random)",
+        default_factory=get_random_seed,
     )

     def invoke(self, context: InvocationContext) -> IntCollectionOutput:

View File

@@ -100,7 +100,8 @@ class CompelInvocation(BaseInvocation):
         # TODO: support legacy blend?
-        prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)
+        conjunction = Compel.parse_prompt_string(prompt_str)
+        prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0]

         if getattr(Globals, "log_tokenization", False):
             log_tokenization_for_prompt_object(prompt, tokenizer)

View File

@@ -1,15 +1,17 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 from functools import partial
-from typing import Literal, Optional, Union
+from typing import Literal, Optional, Union, get_args
 import numpy as np
 from torch import Tensor
 from pydantic import BaseModel, Field
-from invokeai.app.models.image import ImageField, ImageType
+from invokeai.app.models.image import ColorField, ImageField, ImageType
 from invokeai.app.invocations.util.choose_model import choose_model
+from invokeai.app.util.misc import SEED_MAX, get_random_seed
+from invokeai.backend.generator.inpaint import infill_methods
 from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
 from .image import ImageOutput, build_image_output
 from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
@@ -17,7 +19,8 @@ from ...backend.stable_diffusion import PipelineIntermediateState
 from ..util.step_callback import stable_diffusion_step_callback

 SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
+INFILL_METHODS = Literal[tuple(infill_methods())]
+DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'

 class SDImageInvocation(BaseModel):
     """Helper class to provide all Stable Diffusion raster image invocations with additional config"""
@@ -44,15 +47,13 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
     # TODO: consider making prompt optional to enable providing prompt through a link
     # fmt: off
     prompt: Optional[str] = Field(description="The prompt to generate an image from")
-    seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
-    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
+    seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
+    steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
     width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
     height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
-    cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
-    scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
-    seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
+    cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
+    scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
     model: str = Field(default="", description="The model to use (currently ignored)")
-    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
     # fmt: on

     # TODO: pass this an emitter method or something? or a session for dispatching?
@@ -105,6 +106,11 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
         context.services.images.save(
             image_type, image_name, generate_output.image, metadata
         )
+        context.services.outputs.set(image_name, context.graph_execution_state_id)
+        s = context.services.outputs.get(image_name)
+        print(s)

         return build_image_output(
             image_type=image_type,
             image_name=image_name,
@@ -148,7 +154,6 @@ class ImageToImageInvocation(TextToImageInvocation):
                 self.image.image_type, self.image.image_name
             )
         )
-        mask = None

         if self.fit:
             image = image.resize((self.width, self.height))
@@ -165,7 +170,6 @@ class ImageToImageInvocation(TextToImageInvocation):
         outputs = Img2Img(model).generate(
             prompt=self.prompt,
             init_image=image,
-            init_mask=mask,
             step_callback=partial(self.dispatch_progress, context, source_node_id),
             **self.dict(
                 exclude={"prompt", "image", "mask"}
@@ -197,7 +201,6 @@ class ImageToImageInvocation(TextToImageInvocation):
             image=result_image,
         )

 class InpaintInvocation(ImageToImageInvocation):
     """Generates an image using inpaint."""
@@ -205,6 +208,17 @@ class InpaintInvocation(ImageToImageInvocation):
     # Inputs
     mask: Union[ImageField, None] = Field(description="The mask")
+    seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
+    seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
+    seam_strength: float = Field(
+        default=0.75, gt=0, le=1, description="The seam inpaint strength"
+    )
+    seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
+    tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
+    infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)")
+    inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
+    inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
+    inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
     inpaint_replace: float = Field(
         default=0.0,
         ge=0.0,
View File

@@ -1,5 +1,6 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+import io
 from typing import Literal, Optional

 import numpy
@@ -32,14 +33,12 @@ class ImageOutput(BaseInvocationOutput):
     # fmt: off
     type: Literal["image"] = "image"
     image: ImageField = Field(default=None, description="The output image")
-    width: Optional[int] = Field(default=None, description="The width of the image in pixels")
-    height: Optional[int] = Field(default=None, description="The height of the image in pixels")
+    width: int = Field(description="The width of the image in pixels")
+    height: int = Field(description="The height of the image in pixels")
     # fmt: on

     class Config:
-        schema_extra = {
-            "required": ["type", "image", "width", "height", "mode"]
-        }
+        schema_extra = {"required": ["type", "image", "width", "height"]}

 def build_image_output(
@@ -54,7 +53,6 @@ def build_image_output(
         image=image_field,
         width=image.width,
         height=image.height,
-        mode=image.mode,
     )
@@ -151,7 +149,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
         metadata = context.services.metadata.build_metadata(
             session_id=context.graph_execution_state_id, node=self
         )
         context.services.images.save(image_type, image_name, image_crop, metadata)

         return build_image_output(
             image_type=image_type,
@@ -209,7 +207,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
         metadata = context.services.metadata.build_metadata(
             session_id=context.graph_execution_state_id, node=self
         )
         context.services.images.save(image_type, image_name, new_image, metadata)

         return build_image_output(
             image_type=image_type,

View File

@@ -0,0 +1,233 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Literal, Optional, Union, get_args

import numpy as np
import math
from PIL import Image, ImageOps
from pydantic import Field

from invokeai.app.invocations.image import ImageOutput, build_image_output
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch

from ..models.image import ColorField, ImageField, ImageType
from .baseinvocation import (
    BaseInvocation,
    InvocationContext,
)


def infill_methods() -> list[str]:
    methods = [
        "tile",
        "solid",
    ]
    if PatchMatch.patchmatch_available():
        methods.insert(0, "patchmatch")
    return methods


INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = (
    "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
)


def infill_patchmatch(im: Image.Image) -> Image.Image:
    if im.mode != "RGBA":
        return im

    # Skip patchmatch if patchmatch isn't available
    if not PatchMatch.patchmatch_available():
        return im

    # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
    im_patched_np = PatchMatch.inpaint(
        im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
    )
    im_patched = Image.fromarray(im_patched_np, mode="RGB")
    return im_patched


def get_tile_images(image: np.ndarray, width=8, height=8):
    _nrows, _ncols, depth = image.shape
    _strides = image.strides

    nrows, _m = divmod(_nrows, height)
    ncols, _n = divmod(_ncols, width)
    if _m != 0 or _n != 0:
        return None

    return np.lib.stride_tricks.as_strided(
        np.ravel(image),
        shape=(nrows, ncols, height, width, depth),
        strides=(height * _strides[0], width * _strides[1], *_strides),
        writeable=False,
    )


def tile_fill_missing(
    im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
) -> Image.Image:
    # Only fill if there's an alpha layer
    if im.mode != "RGBA":
        return im

    a = np.asarray(im, dtype=np.uint8)

    tile_size_tuple = (tile_size, tile_size)

    # Get the image as tiles of a specified size
    tiles = get_tile_images(a, *tile_size_tuple).copy()

    # Get the mask as tiles
    tiles_mask = tiles[:, :, :, :, 3]

    # Find any mask tiles with any fully transparent pixels (we will be replacing these later)
    tmask_shape = tiles_mask.shape
    tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
    n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
    tiles_mask = tiles_mask > 0
    tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)

    # Get RGB tiles in single array and filter by the mask
    tshape = tiles.shape
    tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
    filtered_tiles = tiles_all[tiles_mask]

    if len(filtered_tiles) == 0:
        return im

    # Find all invalid tiles and replace with a random valid tile
    replace_count = (tiles_mask == False).sum()
    rng = np.random.default_rng(seed=seed)
    tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
        rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
    ]

    # Convert back to an image
    tiles_all = tiles_all.reshape(tshape)
    tiles_all = tiles_all.swapaxes(1, 2)
    st = tiles_all.reshape(
        (
            math.prod(tiles_all.shape[0:2]),
            math.prod(tiles_all.shape[2:4]),
            tiles_all.shape[4],
        )
    )
    si = Image.fromarray(st, mode="RGBA")
    return si


class InfillColorInvocation(BaseInvocation):
    """Infills transparent areas of an image with a solid color"""

    type: Literal["infill_rgba"] = "infill_rgba"
    image: Optional[ImageField] = Field(default=None, description="The image to infill")
    color: Optional[ColorField] = Field(
        default=ColorField(r=127, g=127, b=127, a=255),
        description="The color to use to infill",
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        solid_bg = Image.new("RGBA", image.size, self.color.tuple())
        infilled = Image.alpha_composite(solid_bg, image)
        infilled.paste(image, (0, 0), image.split()[-1])

        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )
        context.services.images.save(image_type, image_name, infilled, metadata)

        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=image,
        )


class InfillTileInvocation(BaseInvocation):
    """Infills transparent areas of an image with tiles of the image"""

    type: Literal["infill_tile"] = "infill_tile"
    image: Optional[ImageField] = Field(default=None, description="The image to infill")
    tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
    seed: int = Field(
        ge=0,
        le=SEED_MAX,
        description="The seed to use for tile generation (omit for random)",
        default_factory=get_random_seed,
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        infilled = tile_fill_missing(
            image.copy(), seed=self.seed, tile_size=self.tile_size
        )
        infilled.paste(image, (0, 0), image.split()[-1])

        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )
        context.services.images.save(image_type, image_name, infilled, metadata)

        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=image,
        )


class InfillPatchMatchInvocation(BaseInvocation):
    """Infills transparent areas of an image using the PatchMatch algorithm"""

    type: Literal["infill_patchmatch"] = "infill_patchmatch"
    image: Optional[ImageField] = Field(default=None, description="The image to infill")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get(
            self.image.image_type, self.image.image_name
        )

        if PatchMatch.patchmatch_available():
            infilled = infill_patchmatch(image.copy())
        else:
            raise ValueError("PatchMatch is not available on this system")

        image_type = ImageType.RESULT
        image_name = context.services.images.create_name(
            context.graph_execution_state_id, self.id
        )
        metadata = context.services.metadata.build_metadata(
            session_id=context.graph_execution_state_id, node=self
        )
        context.services.images.save(image_type, image_name, infilled, metadata)

        return build_image_output(
            image_type=image_type,
            image_name=image_name,
            image=image,
        )

View File

@@ -1,11 +1,13 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 import random
-from typing import Literal, Optional
+from typing import Literal, Optional, Union
+import einops
 from pydantic import BaseModel, Field
 import torch

 from invokeai.app.invocations.util.choose_model import choose_model
+from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
@@ -13,7 +15,9 @@ from ...backend.model_management.model_manager import ModelManager
 from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.image_util.seamless import configure_model_padding
-from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
+from ...backend.prompting.conditioning import get_uc_and_c_and_ec
+from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
+from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
 import numpy as np
 from ..services.image_storage import ImageType
@@ -37,41 +41,55 @@ class LatentsField(BaseModel):
 class LatentsOutput(BaseInvocationOutput):
     """Base class for invocations that output latents"""
     #fmt: off
-    type: Literal["latent_output"] = "latent_output"
-    latents: LatentsField = Field(default=None, description="The output latents")
+    type: Literal["latents_output"] = "latents_output"
+    # Inputs
+    latents: LatentsField = Field(default=None, description="The output latents")
+    width: int = Field(description="The width of the latents in pixels")
+    height: int = Field(description="The height of the latents in pixels")
     #fmt: on

+def build_latents_output(latents_name: str, latents: torch.Tensor):
+    return LatentsOutput(
+        latents=LatentsField(latents_name=latents_name),
+        width=latents.size()[3] * 8,
+        height=latents.size()[2] * 8,
+    )

 class NoiseOutput(BaseInvocationOutput):
     """Invocation noise output"""
     #fmt: off
-    type: Literal["noise_output"] = "noise_output"
+    type: Literal["noise_output"] = "noise_output"
+    # Inputs
     noise: LatentsField = Field(default=None, description="The output noise")
+    width: int = Field(description="The width of the noise in pixels")
+    height: int = Field(description="The height of the noise in pixels")
     #fmt: on

-# TODO: this seems like a hack
-scheduler_map = dict(
-    ddim=diffusers.DDIMScheduler,
-    dpmpp_2=diffusers.DPMSolverMultistepScheduler,
-    k_dpm_2=diffusers.KDPM2DiscreteScheduler,
-    k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
-    k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
-    k_euler=diffusers.EulerDiscreteScheduler,
-    k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
-    k_heun=diffusers.HeunDiscreteScheduler,
-    k_lms=diffusers.LMSDiscreteScheduler,
-    plms=diffusers.PNDMScheduler,
-)
+def build_noise_output(latents_name: str, latents: torch.Tensor):
+    return NoiseOutput(
+        noise=LatentsField(latents_name=latents_name),
+        width=latents.size()[3] * 8,
+        height=latents.size()[2] * 8,
+    )

 SAMPLER_NAME_VALUES = Literal[
-    tuple(list(scheduler_map.keys()))
+    tuple(list(SCHEDULER_MAP.keys()))
 ]

 def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
-    scheduler_class = scheduler_map.get(scheduler_name,'ddim')
-    scheduler = scheduler_class.from_config(model.scheduler.config)
+    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
+    scheduler_config = model.scheduler.config
+    if "_backup" in scheduler_config:
+        scheduler_config = scheduler_config["_backup"]
+    scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
+    scheduler = scheduler_class.from_config(scheduler_config)
     # hack copied over from generate.py
     if not hasattr(scheduler, 'uses_inpainting_model'):
         scheduler.uses_inpainting_model = lambda: False
@@ -102,17 +120,13 @@ def get_noise(width:int, height:int, device:torch.device, seed:int = 0, latent_c
     return x

-def random_seed():
-    return random.randint(0, np.iinfo(np.uint32).max)

 class NoiseInvocation(BaseInvocation):
     """Generates latent noise."""

     type: Literal["noise"] = "noise"

     # Inputs
-    seed: int = Field(ge=0, le=np.iinfo(np.uint32).max, description="The seed to use", default_factory=random_seed)
+    seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use", default_factory=get_random_seed)
     width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting noise", )
     height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting noise", )
@@ -131,9 +145,7 @@ class NoiseInvocation(BaseInvocation):
         name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.set(name, noise)
-        return NoiseOutput(
-            noise=LatentsField(latents_name=name)
-        )
+        return build_noise_output(latents_name=name, latents=noise)

 # Text to image
@@ -149,11 +161,10 @@ class TextToLatentsInvocation(BaseInvocation):
     noise: Optional[LatentsField] = Field(description="The noise to use")
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
-    scheduler: SAMPLER_NAME_VALUES = Field(default="k_lms", description="The scheduler to use" )
+    scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
+    model: str = Field(default="", description="The model to use (currently ignored)")
     seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
     seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
-    model: str = Field(default="", description="The model to use (currently ignored)")
-    progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
     # fmt: on

     # Schema customisation
@@ -218,7 +229,7 @@ class TextToLatentsInvocation(BaseInvocation):
                 h_symmetry_time_pct=None,#h_symmetry_time_pct,
                 v_symmetry_time_pct=None#v_symmetry_time_pct,
             ),
-        ).add_scheduler_args_if_applicable(model.scheduler, eta=None)#ddim_eta)
+        ).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
         return conditioning_data
@@ -250,9 +261,7 @@ class TextToLatentsInvocation(BaseInvocation):
         name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.set(name, result_latents)
-        return LatentsOutput(
-            latents=LatentsField(latents_name=name)
-        )
+        return build_latents_output(latents_name=name, latents=result_latents)

 class LatentsToLatentsInvocation(TextToLatentsInvocation):
@@ -260,6 +269,10 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
     type: Literal["l2l"] = "l2l"

+    # Inputs
+    latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
+    strength: float = Field(default=0.5, description="The strength of the latents to use")

     # Schema customisation
     class Config(InvocationConfig):
         schema_extra = {
@@ -271,10 +284,6 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             },
         }

-    # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
-    strength: float = Field(default=0.5, description="The strength of the latents to use")

     def invoke(self, context: InvocationContext) -> LatentsOutput:
         noise = context.services.latents.get(self.noise.latents_name)
         latent = context.services.latents.get(self.latents.latents_name)
@@ -287,7 +296,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             self.dispatch_progress(context, source_node_id, state)

         model = self.get_model(context.services.model_manager)
-        conditioning_data = self.get_conditioning_data(model)
+        conditioning_data = self.get_conditioning_data(context, model)

         # TODO: Verify the noise is the right size
@@ -295,11 +304,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             latent, device=model.device, dtype=latent.dtype
         )

-        timesteps, _ = model.get_img2img_timesteps(
-            self.steps,
-            self.strength,
-            device=model.device,
-        )
+        timesteps, _ = model.get_img2img_timesteps(self.steps, self.strength)

         result_latents, result_attention_map_saver = model.latents_from_embeddings(
             latents=initial_latents,
@@ -315,9 +320,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.set(name, result_latents)
-        return LatentsOutput(
-            latents=LatentsField(latents_name=name)
-        )
+        return build_latents_output(latents_name=name, latents=result_latents)

 # Latent to image
@@ -381,11 +384,11 @@ class ResizeLatentsInvocation(BaseInvocation):
     type: Literal["lresize"] = "lresize"

     # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to resize")
-    width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
-    height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
-    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
-    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
+    latents: Optional[LatentsField] = Field(description="The latents to resize")
+    width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
+    height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
+    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
+    antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")

     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
@@ -402,7 +405,7 @@ class ResizeLatentsInvocation(BaseInvocation):
         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.set(name, resized_latents)
-        return LatentsOutput(latents=LatentsField(latents_name=name))
+        return build_latents_output(latents_name=name, latents=resized_latents)

 class ScaleLatentsInvocation(BaseInvocation):
@@ -411,10 +414,10 @@ class ScaleLatentsInvocation(BaseInvocation):
     type: Literal["lscale"] = "lscale"

     # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to scale")
-    scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
-    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
-    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
+    latents: Optional[LatentsField] = Field(description="The latents to scale")
+    scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
+    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
+    antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")

     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
@@ -432,4 +435,48 @@ class ScaleLatentsInvocation(BaseInvocation):
         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.set(name, resized_latents)
-        return LatentsOutput(latents=LatentsField(latents_name=name))
+        return build_latents_output(latents_name=name, latents=resized_latents)

+class ImageToLatentsInvocation(BaseInvocation):
+    """Encodes an image into latents."""
+
+    type: Literal["i2l"] = "i2l"
+
+    # Inputs
+    image: Union[ImageField, None] = Field(description="The image to encode")
+    model: str = Field(default="", description="The model to use")
+
+    # Schema customisation
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["latents", "image"],
+                "type_hints": {"model": "model"},
+            },
+        }
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        image = context.services.images.get(
+            self.image.image_type, self.image.image_name
+        )
+
+        # TODO: this only really needs the vae
+        model_info = choose_model(context.services.model_manager, self.model)
+        model: StableDiffusionGeneratorPipeline = model_info["model"]
+
+        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+        if image_tensor.dim() == 3:
+            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
+
+        latents = model.non_noised_latents_from_image(
+            image_tensor,
+            device=model._model_group.device_for(model.unet),
+            dtype=model.unet.dtype,
+        )
+
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        context.services.latents.set(name, latents)
+        return build_latents_output(latents_name=name, latents=latents)

View File

@@ -3,6 +3,7 @@
 from typing import Literal

 from pydantic import BaseModel, Field
+import numpy as np

 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
@@ -73,3 +74,12 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
     def invoke(self, context: InvocationContext) -> IntOutput:
         return IntOutput(a=int(self.a / self.b))

+class RandomIntInvocation(BaseInvocation):
+    """Outputs a single random integer."""
+
+    #fmt: off
+    type: Literal["rand_int"] = "rand_int"
+    #fmt: on
+
+    def invoke(self, context: InvocationContext) -> IntOutput:
+        return IntOutput(a=np.random.randint(0, np.iinfo(np.int32).max))

View File

@@ -4,10 +4,11 @@ from invokeai.backend.model_management.model_manager import ModelManager
def choose_model(model_manager: ModelManager, model_name: str):
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
logger = model_manager.logger
if model_manager.valid_model(model_name):
model = model_manager.get_model(model_name)
else:
if model_name and not model_manager.valid_model(model_name):
default_model_name = model_manager.default_model()
logger.warning(f"\'{model_name}\' is not a valid model name. Using default model \'{default_model_name}\' instead.")
model = model_manager.get_model()
logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
else:
model = model_manager.get_model(model_name)
return model


@@ -1,5 +1,5 @@
from enum import Enum
from typing import Optional
from typing import Optional, Tuple
from pydantic import BaseModel, Field
@@ -27,3 +27,13 @@ class ImageField(BaseModel):
class Config:
schema_extra = {"required": ["image_type", "image_name"]}
class ColorField(BaseModel):
r: int = Field(ge=0, le=255, description="The red component")
g: int = Field(ge=0, le=255, description="The green component")
b: int = Field(ge=0, le=255, description="The blue component")
a: int = Field(ge=0, le=255, description="The alpha component")
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
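`tuple()` emits RGBA in the order PIL expects, so a `ColorField` can feed `Image.new` directly. A standalone sketch (the class body is repeated from the diff above; the mid-grey value mirrors the solid-infill default used elsewhere in this changeset):

```python
from PIL import Image
from pydantic import BaseModel, Field


class ColorField(BaseModel):
    r: int = Field(ge=0, le=255, description="The red component")
    g: int = Field(ge=0, le=255, description="The green component")
    b: int = Field(ge=0, le=255, description="The blue component")
    a: int = Field(ge=0, le=255, description="The alpha component")

    def tuple(self) -> tuple[int, int, int, int]:
        return (self.r, self.g, self.b, self.a)


# mid-grey RGBA, the same value used as the solid-infill default
grey = ColorField(r=0x7F, g=0x7F, b=0x7F, a=0xFF)
solid_bg = Image.new("RGBA", (64, 64), grey.tuple())
```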


@@ -48,13 +48,14 @@ def create_text_to_image() -> LibraryGraph:
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
graphs: list[LibraryGraph] = list()
text_to_image = graph_library.get(default_text_to_image_graph_id)
# text_to_image = graph_library.get(default_text_to_image_graph_id)
# TODO: Check if the graph is the same as the default one, and if not, update it
#if text_to_image is None:
# # TODO: Check if the graph is the same as the default one, and if not, update it
# #if text_to_image is None:
text_to_image = create_text_to_image()
graph_library.set(text_to_image)


@@ -270,4 +270,5 @@ class DiskImageStorage(ImageStorageBase):
) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
del self.__cache[cache_id]
if cache_id in self.__cache:
del self.__cache[cache_id]
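The added membership check stops a `KeyError` when an id queued for eviction has already been removed. The same bounded-cache idea in isolation, sketched with an `OrderedDict` (a hypothetical helper, not the actual `DiskImageStorage` API) — this variant also does the LRU position refresh the TODO above mentions:

```python
from collections import OrderedDict


class BoundedCache:
    """Bounded cache with LRU eviction (a sketch, not the DiskImageStorage API)."""

    def __init__(self, max_size: int = 10):
        self._items: OrderedDict = OrderedDict()
        self._max_size = max_size

    def set(self, key: str, value) -> None:
        self._items[key] = value
        self._items.move_to_end(key)  # refresh position on write
        while len(self._items) > self._max_size:
            # popping the oldest entry cannot KeyError, unlike deleting
            # a queued id that may already have been dropped
            self._items.popitem(last=False)

    def get(self, key: str):
        if key in self._items:
            self._items.move_to_end(key)  # refresh position on read
        return self._items.get(key)
```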


@@ -1,6 +1,8 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import types
from invokeai.app.services.outputs_session_storage import OutputsSessionStorageABC
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.backend import ModelManager
@@ -17,6 +19,7 @@ class InvocationServices:
events: EventServiceBase
latents: LatentsStorageBase
images: ImageStorageBase
outputs: OutputsSessionStorageABC
metadata: MetadataServiceBase
queue: InvocationQueueABC
model_manager: ModelManager
@@ -34,6 +37,7 @@ class InvocationServices:
logger: types.ModuleType,
latents: LatentsStorageBase,
images: ImageStorageBase,
outputs: OutputsSessionStorageABC,
metadata: MetadataServiceBase,
queue: InvocationQueueABC,
graph_library: ItemStorageABC["LibraryGraph"],
@@ -46,6 +50,7 @@ class InvocationServices:
self.logger = logger
self.latents = latents
self.images = images
self.outputs = outputs
self.metadata = metadata
self.queue = queue
self.graph_library = graph_library


@@ -20,9 +20,18 @@ class MetadataLatentsField(TypedDict):
latents_name: str
class MetadataColorField(TypedDict):
"""Pydantic-less ColorField, used for metadata parsing"""
r: int
g: int
b: int
a: int
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[
str, str | int | float | bool | MetadataImageField | MetadataLatentsField
str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField | MetadataColorField
]


@@ -0,0 +1,59 @@
from abc import ABC, abstractmethod
from typing import Callable, Generic, TypeVar
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
class PaginatedStringResults(GenericModel):
"""Paginated results"""
#fmt: off
items: list[str] = Field(description="Session IDs")
page: int = Field(description="Current Page")
pages: int = Field(description="Total number of pages")
per_page: int = Field(description="Number of items per page")
total: int = Field(description="Total number of items in result")
#fmt: on
class OutputsSessionStorageABC(ABC):
"""Base item storage class"""
_on_changed_callbacks: list[Callable[[str], None]]
_on_deleted_callbacks: list[Callable[[str], None]]
def __init__(self) -> None:
self._on_changed_callbacks = list()
self._on_deleted_callbacks = list()
@abstractmethod
def get(self, output_id: str) -> str:
pass
@abstractmethod
def set(self, output_id: str, session_id: str) -> None:
pass
@abstractmethod
def list(self, page: int = 0, per_page: int = 10) -> PaginatedStringResults:
pass
@abstractmethod
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedStringResults:
pass
def on_changed(self, on_changed: Callable[[str], None]) -> None:
"""Register a callback for when an item is changed"""
self._on_changed_callbacks.append(on_changed)
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
"""Register a callback for when an item is deleted"""
self._on_deleted_callbacks.append(on_deleted)
def _on_changed(self, foreign_key_value: str) -> None:
for callback in self._on_changed_callbacks:
callback(foreign_key_value)
def _on_deleted(self, item_id: str) -> None:
for callback in self._on_deleted_callbacks:
callback(item_id)
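A minimal in-memory implementation of this ABC, useful for seeing how the change/delete callbacks are expected to fire (a sketch for illustration; the real service is the SQLite-backed class in the next file):

```python
class InMemoryOutputsStorage(OutputsSessionStorageABC):
    """Dict-backed sketch of the output-id -> session-id mapping."""

    def __init__(self) -> None:
        super().__init__()  # sets up the callback lists
        self._map: dict[str, str] = {}

    def get(self, output_id: str) -> str:
        return self._map[output_id]

    def set(self, output_id: str, session_id: str) -> None:
        self._map[output_id] = session_id
        self._on_changed(output_id)  # notify registered listeners

    def list(self, page: int = 0, per_page: int = 10) -> PaginatedStringResults:
        items = list(self._map.values())[page * per_page : (page + 1) * per_page]
        total = len(self._map)
        return PaginatedStringResults(
            items=items,
            page=page,
            pages=(total + per_page - 1) // per_page,
            per_page=per_page,
            total=total,
        )

    def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedStringResults:
        matches = [sid for oid, sid in self._map.items() if query in oid]
        return PaginatedStringResults(
            items=matches[page * per_page : (page + 1) * per_page],
            page=page,
            pages=(len(matches) + per_page - 1) // per_page,
            per_page=per_page,
            total=len(matches),
        )


storage = InMemoryOutputsStorage()
storage.on_changed(lambda output_id: print(f"changed: {output_id}"))
storage.set("output-1", "session-1")  # prints "changed: output-1"
```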


@@ -0,0 +1,162 @@
import json
import sqlite3
from threading import Lock
from typing import Union
from invokeai.app.services.outputs_session_storage import (
OutputsSessionStorageABC,
PaginatedStringResults,
)
sqlite_memory = ":memory:"
class OutputsSqliteItemStorage(OutputsSessionStorageABC):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: Lock
def __init__(
self,
filename: str,
):
super().__init__()
self._filename = filename
self._lock = Lock()
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution
self._cursor = self._conn.cursor()
self._create_table()
def _create_table(self):
try:
self._lock.acquire()
self._cursor.execute(
f"""CREATE TABLE IF NOT EXISTS outputs (
id TEXT NOT NULL PRIMARY KEY,
session_id TEXT NOT NULL
);"""
)
self._cursor.execute(
f"""CREATE UNIQUE INDEX IF NOT EXISTS outputs_id ON outputs(id);"""
)
finally:
self._lock.release()
def set(self, output_id: str, session_id: str):
try:
self._lock.acquire()
self._cursor.execute(
f"""INSERT OR REPLACE INTO outputs (id, session_id) VALUES (?, ?);""",
(output_id, session_id),
)
self._conn.commit()
finally:
self._lock.release()
self._on_changed(output_id)
def get(self, output_id: str) -> Union[str, None]:
try:
self._lock.acquire()
self._cursor.execute(
f"""
SELECT graph_executions.item session
FROM graph_executions
INNER JOIN outputs ON outputs.session_id = graph_executions.id
WHERE outputs.id = ?;
""",
(output_id,),
)
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
return result[0]
def delete(self, output_id: str):
try:
self._lock.acquire()
self._cursor.execute(
f"""DELETE FROM outputs WHERE id = ?;""", (str(id),)
)
self._conn.commit()
finally:
self._lock.release()
self._on_deleted(output_id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedStringResults:
try:
self._lock.acquire()
self._cursor.execute(
f"""
SELECT graph_executions.item session
FROM graph_executions
INNER JOIN outputs ON outputs.session_id = graph_executions.id
LIMIT ? OFFSET ?;
""",
(per_page, page * per_page),
)
result = self._cursor.fetchall()
items = list(map(lambda r: r[0], result))
self._cursor.execute(
f"""
SELECT count(*)
FROM graph_executions
INNER JOIN outputs ON outputs.session_id = graph_executions.id;
""")
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = (count + per_page - 1) // per_page  # ceil(count / per_page)
return PaginatedStringResults(
items=items, page=page, pages=pageCount, per_page=per_page, total=count
)
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedStringResults:
try:
self._lock.acquire()
self._cursor.execute(
f"""
SELECT graph_executions.item session
FROM graph_executions
INNER JOIN outputs ON outputs.session_id = graph_executions.id
WHERE outputs.id LIKE ? LIMIT ? OFFSET ?;
""",
(f"%{query}%", per_page, page * per_page),
)
result = self._cursor.fetchall()
items = list(map(lambda r: r[0], result))
self._cursor.execute(
f"""
SELECT count(*)
FROM graph_executions
INNER JOIN outputs ON outputs.session_id = graph_executions.id
WHERE outputs.id LIKE ?;
""",
(f"%{query}%",),
)
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = (count + per_page - 1) // per_page  # ceil(count / per_page)
return PaginatedStringResults(
items=items, page=page, pages=pageCount, per_page=per_page, total=count
)
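Note that every query except `set()` joins against a `graph_executions` table that this class does not create; it is assumed to be maintained by the session storage elsewhere in the app. A usage sketch under that assumption, creating a stand-in table by hand:

```python
import os
import sqlite3
import tempfile

db_path = os.path.join(tempfile.mkdtemp(), "outputs.db")

# stand-in for the table the app's session storage is assumed to own
setup = sqlite3.connect(db_path)
setup.execute("CREATE TABLE graph_executions (id TEXT PRIMARY KEY, item TEXT);")
setup.execute("""INSERT INTO graph_executions VALUES ('session-1', '{"id": "session-1"}');""")
setup.commit()
setup.close()

storage = OutputsSqliteItemStorage(filename=db_path)
storage.set("output-1", "session-1")
print(storage.get("output-1"))  # -> the serialized session row: {"id": "session-1"}
```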


@@ -1,3 +1,4 @@
import time
import traceback
from threading import Event, Thread, BoundedSemaphore
@@ -6,6 +7,7 @@ from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
from ..models.exceptions import CanceledException
import invokeai.backend.util.logging as logger
class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
__stop_event: Event
@@ -34,8 +36,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
try:
self.__threadLimit.acquire()
while not stop_event.is_set():
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
try:
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
except Exception as e:
logger.debug("Exception while getting from queue: %s" % e)
if not queue_item: # Probably stopping
# do not hammer the queue
time.sleep(0.5)
continue
graph_execution_state = (
@@ -124,7 +132,16 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Queue any further commands if invoking all
is_complete = graph_execution_state.is_complete()
if queue_item.invoke_all and not is_complete:
self.__invoker.invoke(graph_execution_state, invoke_all=True)
try:
self.__invoker.invoke(graph_execution_state, invoke_all=True)
except Exception as e:
logger.error("Error while invoking: %s" % e)
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
error=traceback.format_exc()
)
elif is_complete:
self.__invoker.services.events.emit_graph_execution_complete(
graph_execution_state.id


@@ -1,5 +1,13 @@
import datetime
import numpy as np
def get_timestamp():
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
SEED_MAX = np.iinfo(np.int32).max
def get_random_seed():
return np.random.randint(0, SEED_MAX)


@@ -108,17 +108,21 @@ APP_VERSION = invokeai.version.__version__
SAMPLER_CHOICES = [
"ddim",
"k_dpm_2_a",
"k_dpm_2",
"k_dpmpp_2_a",
"k_dpmpp_2",
"k_euler_a",
"k_euler",
"k_heun",
"k_lms",
"plms",
# diffusers:
"ddpm",
"deis",
"lms",
"pndm",
"heun",
"heun_k",
"euler",
"euler_k",
"euler_a",
"kdpm_2",
"kdpm_2_a",
"dpmpp_2s",
"dpmpp_2m",
"dpmpp_2m_k",
"unipc",
]
PRECISION_CHOICES = [
@@ -631,7 +635,7 @@ class Args(object):
choices=SAMPLER_CHOICES,
metavar="SAMPLER_NAME",
help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
default="k_lms",
default="lms",
)
render_group.add_argument(
"--log_tokenization",


@@ -37,6 +37,7 @@ from .safety_checker import SafetyChecker
from .prompting import get_uc_and_c_and_ec
from .prompting.conditioning import log_tokenization
from .stable_diffusion import HuggingFaceConceptsLibrary
from .stable_diffusion.schedulers import SCHEDULER_MAP
from .util import choose_precision, choose_torch_device
def fix_func(orig):
@@ -141,7 +142,7 @@ class Generate:
model=None,
conf="configs/models.yaml",
embedding_path=None,
sampler_name="k_lms",
sampler_name="lms",
ddim_eta=0.0, # deterministic
full_precision=False,
precision="auto",
@@ -1047,29 +1048,12 @@ class Generate:
def _set_scheduler(self):
default = self.model.scheduler
# See https://github.com/huggingface/diffusers/issues/277#issuecomment-1371428672
scheduler_map = dict(
ddim=diffusers.DDIMScheduler,
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
# DPMSolverMultistepScheduler is technically not `k_` anything, as it is neither
# the k-diffusers implementation nor included in EDM (Karras 2022), but we can
# provide an alias for compatibility.
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_euler=diffusers.EulerDiscreteScheduler,
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
k_heun=diffusers.HeunDiscreteScheduler,
k_lms=diffusers.LMSDiscreteScheduler,
plms=diffusers.PNDMScheduler,
)
if self.sampler_name in scheduler_map:
sampler_class = scheduler_map[self.sampler_name]
if self.sampler_name in SCHEDULER_MAP:
sampler_class, sampler_extra_config = SCHEDULER_MAP[self.sampler_name]
msg = (
f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
)
self.sampler = sampler_class.from_config(self.model.scheduler.config)
self.sampler = sampler_class.from_config({**self.model.scheduler.config, **sampler_extra_config})
else:
msg = (
f" Unsupported Sampler: {self.sampler_name} "+


@@ -31,6 +31,7 @@ from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..stable_diffusion.schedulers import SCHEDULER_MAP
downsampling = 8
@@ -71,19 +72,6 @@ class InvokeAIGeneratorOutput:
# we are interposing a wrapper around the original Generator classes so that
# old code that calls Generate will continue to work.
class InvokeAIGenerator(metaclass=ABCMeta):
scheduler_map = dict(
ddim=diffusers.DDIMScheduler,
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
k_euler=diffusers.EulerDiscreteScheduler,
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
k_heun=diffusers.HeunDiscreteScheduler,
k_lms=diffusers.LMSDiscreteScheduler,
plms=diffusers.PNDMScheduler,
)
def __init__(self,
model_info: dict,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
@@ -175,14 +163,20 @@ class InvokeAIGenerator(metaclass=ABCMeta):
'''
Return list of all the schedulers that we currently handle.
'''
return list(self.scheduler_map.keys())
return list(SCHEDULER_MAP.keys())
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
return generator_class(model, self.params.precision)
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
scheduler = scheduler_class.from_config(model.scheduler.config)
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
scheduler_config = model.scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
@@ -226,10 +220,10 @@ class Inpaint(Img2Img):
def generate(self,
mask_image: Image.Image | torch.FloatTensor,
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
seam_size: int = 96,
seam_blur: int = 16,
seam_strength: float = 0.7,
seam_steps: int = 10,
seam_steps: int = 30,
tile_size: int = 32,
inpaint_replace=False,
infill_method=None,


@@ -4,6 +4,7 @@ invokeai.backend.generator.inpaint descends from .generator
from __future__ import annotations
import math
from typing import Tuple, Union
import cv2
import numpy as np
@@ -59,7 +60,7 @@ class Inpaint(Img2Img):
writeable=False,
)
def infill_patchmatch(self, im: Image.Image) -> Image:
def infill_patchmatch(self, im: Image.Image) -> Image.Image:
if im.mode != "RGBA":
return im
@@ -75,18 +76,18 @@ class Inpaint(Img2Img):
return im_patched
def tile_fill_missing(
self, im: Image.Image, tile_size: int = 16, seed: int = None
) -> Image:
self, im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
) -> Image.Image:
# Only fill if there's an alpha layer
if im.mode != "RGBA":
return im
a = np.asarray(im, dtype=np.uint8)
tile_size = (tile_size, tile_size)
tile_size_tuple = (tile_size, tile_size)
# Get the image as tiles of a specified size
tiles = self.get_tile_images(a, *tile_size).copy()
tiles = self.get_tile_images(a, *tile_size_tuple).copy()
# Get the mask as tiles
tiles_mask = tiles[:, :, :, :, 3]
@@ -127,7 +128,9 @@ class Inpaint(Img2Img):
return si
def mask_edge(self, mask: Image, edge_size: int, edge_blur: int) -> Image:
def mask_edge(
self, mask: Image.Image, edge_size: int, edge_blur: int
) -> Image.Image:
npimg = np.asarray(mask, dtype=np.uint8)
# Detect any partially transparent regions
@@ -206,15 +209,15 @@ class Inpaint(Img2Img):
cfg_scale,
ddim_eta,
conditioning,
init_image: PIL.Image.Image | torch.FloatTensor,
mask_image: PIL.Image.Image | torch.FloatTensor,
init_image: Image.Image | torch.FloatTensor,
mask_image: Image.Image | torch.FloatTensor,
strength: float,
mask_blur_radius: int = 8,
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
seam_size: int = 96,
seam_blur: int = 16,
seam_strength: float = 0.7,
seam_steps: int = 10,
seam_steps: int = 30,
tile_size: int = 32,
step_callback=None,
inpaint_replace=False,
@@ -222,7 +225,7 @@ class Inpaint(Img2Img):
infill_method=None,
inpaint_width=None,
inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
inpaint_fill: Tuple[int, int, int, int] = (0x7F, 0x7F, 0x7F, 0xFF),
attention_maps_callback=None,
**kwargs,
):
@@ -239,7 +242,7 @@ class Inpaint(Img2Img):
self.inpaint_width = inpaint_width
self.inpaint_height = inpaint_height
if isinstance(init_image, PIL.Image.Image):
if isinstance(init_image, Image.Image):
self.pil_image = init_image.copy()
# Do infill
@@ -250,8 +253,8 @@ class Inpaint(Img2Img):
self.pil_image.copy(), seed=self.seed, tile_size=tile_size
)
elif infill_method == "solid":
solid_bg = PIL.Image.new("RGBA", init_image.size, inpaint_fill)
init_filled = PIL.Image.alpha_composite(solid_bg, init_image)
solid_bg = Image.new("RGBA", init_image.size, inpaint_fill)
init_filled = Image.alpha_composite(solid_bg, init_image)
else:
raise ValueError(
f"Non-supported infill type {infill_method}", infill_method
@@ -269,7 +272,7 @@ class Inpaint(Img2Img):
# Create init tensor
init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB"))
if isinstance(mask_image, PIL.Image.Image):
if isinstance(mask_image, Image.Image):
self.pil_mask = mask_image.copy()
debug_image(
mask_image,


@@ -47,6 +47,7 @@ from diffusers import (
LDMTextToImagePipeline,
LMSDiscreteScheduler,
PNDMScheduler,
UniPCMultistepScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
@@ -1209,6 +1210,8 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "dpm":
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
elif scheduler_type == 'unipc':
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
elif scheduler_type == "ddim":
scheduler = scheduler
else:


@@ -30,7 +30,7 @@ from diffusers import (
UNet2DConditionModel,
SchedulerMixin,
logging as dlogging,
)
)
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
@@ -68,7 +68,7 @@ class SDModelComponent(Enum):
scheduler="scheduler"
safety_checker="safety_checker"
feature_extractor="feature_extractor"
DEFAULT_MAX_MODELS = 2
class ModelManager(object):
@@ -182,7 +182,7 @@ class ModelManager(object):
vae from the model currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.vae)
def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned CLIPTokenizer. If no
@@ -190,12 +190,12 @@ class ModelManager(object):
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.tokenizer)
def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
"""Given a model name identified in models.yaml, load the model into
GPU if necessary and return its assigned UNet2DConditionModel. If no model
name is provided, return the UNet from the model
currently in the GPU.
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.unet)
@@ -222,7 +222,7 @@ class ModelManager(object):
currently in the GPU.
"""
return self._get_sub_model(model_name, SDModelComponent.scheduler)
def _get_sub_model(
self,
model_name: str=None,
@@ -1228,7 +1228,7 @@ class ModelManager(object):
sha.update(chunk)
hash = sha.hexdigest()
toc = time.time()
self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
self.logger.debug(f"sha256 = {hash} ({count} files hashed in {toc - tic:4.2f}s)")
with open(hashpath, "w") as f:
f.write(hash)
return hash


@@ -16,6 +16,7 @@ from compel.prompt_parser import (
FlattenedPrompt,
Fragment,
PromptParser,
Conjunction,
)
import invokeai.backend.util.logging as logger
@@ -25,58 +26,48 @@ from ..stable_diffusion import InvokeAIDiffuserComponent
from ..util import torch_dtype
def get_uc_and_c_and_ec(
prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
):
def get_uc_and_c_and_ec(prompt_string,
model: InvokeAIDiffuserComponent,
log_tokens=False, skip_normalize_legacy_blend=False):
# lazy-load any deferred textual inversions.
# this might take a couple of seconds the first time a textual inversion is used.
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
prompt_string
)
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
tokenizer = model.tokenizer
compel = Compel(
tokenizer=tokenizer,
text_encoder=model.text_encoder,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False
)
compel = Compel(tokenizer=model.tokenizer,
text_encoder=model.text_encoder,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=False,
)
# get rid of any newline characters
prompt_string = prompt_string.replace("\n", " ")
(
positive_prompt_string,
negative_prompt_string,
) = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(
positive_prompt_string, skip_normalize_legacy_blend
)
positive_prompt: Union[FlattenedPrompt, Blend]
if legacy_blend is not None:
positive_prompt = legacy_blend
else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
negative_prompt_string
)
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
positive_conjunction: Conjunction
if legacy_blend is not None:
positive_conjunction = legacy_blend
else:
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
positive_prompt = positive_conjunction.prompts[0]
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]
tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
if log_tokens or getattr(Globals, "log_tokenization", False):
log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
tokens_count = get_max_token_count(tokenizer, positive_prompt)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=tokens_count,
cross_attention_control_args=options.get("cross_attention_control", None),
)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
cross_attention_control_args=options.get(
'cross_attention_control', None))
return uc, c, ec
def get_prompt_structure(
prompt_string, skip_normalize_legacy_blend: bool = False
) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt):
@@ -87,18 +78,17 @@ def get_prompt_structure(
legacy_blend = try_parse_legacy_blend(
positive_prompt_string, skip_normalize_legacy_blend
)
positive_prompt: Union[FlattenedPrompt, Blend]
positive_prompt: Conjunction
if legacy_blend is not None:
positive_prompt = legacy_blend
positive_conjunction = legacy_blend
else:
positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
negative_prompt_string
)
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
positive_prompt = positive_conjunction.prompts[0]
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]
return positive_prompt, negative_prompt
def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
) -> int:
@@ -245,22 +235,21 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t
logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
logger.debug(f"{discarded}\x1b[0m")
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]:
def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
if len(weighted_subprompts) <= 1:
return None
strings = [x[0] for x in weighted_subprompts]
weights = [x[1] for x in weighted_subprompts]
pp = PromptParser()
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
return Blend(
prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize
)
flattened_prompts = []
weights = []
for i, x in enumerate(parsed_conjunctions):
if len(x.prompts)>0:
flattened_prompts.append(x.prompts[0])
weights.append(weighted_subprompts[i][1])
return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)])
def split_weighted_subprompts(text, skip_normalize=False) -> list:
"""


@@ -509,10 +509,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
run_id=None,
callback: Callable[[PipelineIntermediateState], None] = None,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu')
else:
scheduler_device = self._model_group.device_for(self.unet)
if timesteps is None:
self.scheduler.set_timesteps(
num_inference_steps, device=self._model_group.device_for(self.unet)
)
self.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
timesteps = self.scheduler.timesteps
infer_latents_from_embeddings = GeneratorToCallbackinator(
self.generate_latents_from_embeddings, PipelineIntermediateState
@@ -545,8 +548,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance = []
extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context(
extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps),
self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps),
):
yield PipelineIntermediateState(
run_id=run_id,
@@ -725,12 +729,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
noise: torch.Tensor,
run_id=None,
callback=None,
) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps(
num_inference_steps,
strength,
device=self._model_group.device_for(self.unet),
)
) -> InvokeAIStableDiffusionPipelineOutput:
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
result_latents, result_attention_maps = self.latents_from_embeddings(
latents=initial_latents if strength < 1.0 else torch.zeros_like(
initial_latents, device=initial_latents.device, dtype=initial_latents.dtype
@@ -756,13 +756,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return self.check_for_safety(output, dtype=conditioning_data.dtype)
def get_img2img_timesteps(
self, num_inference_steps: int, strength: float, device
self, num_inference_steps: int, strength: float, device=None
) -> (torch.Tensor, int):
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
assert img2img_pipeline.scheduler is self.scheduler
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
if self.scheduler.config.get("cpu_only", False):
scheduler_device = torch.device('cpu')
else:
scheduler_device = self._model_group.device_for(self.unet)
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=scheduler_device)
timesteps, adjusted_steps = img2img_pipeline.get_timesteps(
num_inference_steps, strength, device=device
num_inference_steps, strength, device=scheduler_device
)
# Workaround for low strength resulting in zero timesteps.
# TODO: submit upstream fix for zero-step img2img
@@ -796,9 +802,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if init_image.dim() == 3:
init_image = init_image.unsqueeze(0)
timesteps, _ = self.get_img2img_timesteps(
num_inference_steps, strength, device=device
)
timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength)
# 6. Prepare latent variables
# can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents


@@ -10,6 +10,7 @@ import diffusers
import psutil
import torch
from compel.cross_attention_control import Arguments
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor
from torch import nn
@@ -352,8 +353,7 @@ def restore_default_cross_attention(
else:
remove_attention_function(model)
def override_cross_attention(model, context: Context, is_running_diffusers=False):
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
@@ -372,37 +372,22 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False
indices = torch.arange(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
if b0 < max_length:
if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0):
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
context.cross_attention_mask = mask.to(device)
context.cross_attention_index_map = indices.to(device)
if is_running_diffusers:
unet = model
old_attn_processors = unet.attn_processors
if torch.backends.mps.is_available():
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
unet.set_attn_processor(SwapCrossAttnProcessor())
else:
# try to re-use an existing slice size
default_slice_size = 4
slice_size = next(
(
p.slice_size
for p in old_attn_processors.values()
if type(p) is SlicedAttnProcessor
),
default_slice_size,
)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
return old_attn_processors
old_attn_processors = unet.attn_processors
if torch.backends.mps.is_available():
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
unet.set_attn_processor(SwapCrossAttnProcessor())
else:
context.register_cross_attention_modules(model)
inject_attention_function(model, context)
return None
# try to re-use an existing slice size
default_slice_size = 4
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
def get_cross_attention_modules(
model, which: CrossAttentionType


@@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Union
import numpy as np
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias
@@ -17,8 +18,8 @@ from .cross_attention_control import (
CrossAttentionType,
SwapCrossAttnContext,
get_cross_attention_modules,
override_cross_attention,
restore_default_cross_attention,
setup_cross_attention_control_attention_processors,
)
from .cross_attention_map_saving import AttentionMapSaver
@@ -79,24 +80,35 @@ class InvokeAIDiffuserComponent:
self.cross_attention_control_context = None
self.sequential_guidance = Globals.sequential_guidance
@classmethod
@contextmanager
def custom_attention_context(
self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
cls,
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int
):
do_swap = (
extra_conditioning_info is not None
and extra_conditioning_info.wants_cross_attention_control
)
old_attn_processor = None
if do_swap:
old_attn_processor = self.override_cross_attention(
extra_conditioning_info, step_count=step_count
)
old_attn_processors = None
if extra_conditioning_info and (
extra_conditioning_info.wants_cross_attention_control
):
old_attn_processors = unet.attn_processors
# Load lora conditions into the model
if extra_conditioning_info.wants_cross_attention_control:
cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
cross_attention_control_context,
)
try:
yield None
finally:
if old_attn_processor is not None:
self.restore_default_cross_attention(old_attn_processor)
if old_attn_processors is not None:
unet.set_attn_processor(old_attn_processors)
# TODO resuscitate attention map saving
# self.remove_attention_map_saving()
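Stripped of the conditioning details, the new `custom_attention_context` is a save-and-restore of the UNet's attention processors around generation. The bare pattern as a generic sketch (`attn_processors` and `set_attn_processor` are the standard diffusers `UNet2DConditionModel` accessors; the helper itself is hypothetical):

```python
from contextlib import contextmanager


@contextmanager
def swapped_attn_processors(unet, new_processor):
    """Install `new_processor` on the UNet for the duration of the block (hypothetical helper)."""
    old_processors = unet.attn_processors  # snapshot whatever is installed
    unet.set_attn_processor(new_processor)
    try:
        yield
    finally:
        # always restore, even if generation raises
        unet.set_attn_processor(old_processors)
```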


@@ -0,0 +1 @@
from .schedulers import SCHEDULER_MAP


@@ -0,0 +1,23 @@
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler
SCHEDULER_MAP = dict(
ddim=(DDIMScheduler, dict()),
ddpm=(DDPMScheduler, dict()),
deis=(DEISMultistepScheduler, dict()),
lms=(LMSDiscreteScheduler, dict()),
pndm=(PNDMScheduler, dict()),
heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)),
heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)),
euler=(EulerDiscreteScheduler, dict(use_karras_sigmas=False)),
euler_k=(EulerDiscreteScheduler, dict(use_karras_sigmas=True)),
euler_a=(EulerAncestralDiscreteScheduler, dict()),
kdpm_2=(KDPM2DiscreteScheduler, dict()),
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
dpmpp_2s=(DPMSolverSinglestepScheduler, dict()),
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
)
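Each entry pairs a scheduler class with extra config keys: `use_karras_sigmas` toggles the Karras sigma schedule for the `_k` variants, and `cpu_only` is the hint the pipeline checks when picking a device for timesteps. A sketch of how a caller consumes the map, assuming `pipeline` is an already-loaded diffusers `StableDiffusionPipeline`:

```python
# look up the scheduler class and its extra kwargs, falling back to ddim
scheduler_class, extra_config = SCHEDULER_MAP.get("euler_k", SCHEDULER_MAP["ddim"])

# merge the extra keys into the existing scheduler config and rebuild;
# `pipeline` is assumed to be a loaded diffusers StableDiffusionPipeline
pipeline.scheduler = scheduler_class.from_config(
    {**pipeline.scheduler.config, **extra_config}
)
```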


@@ -2,34 +2,37 @@
"""invokeai.util.logging
Logging class for InvokeAI that produces console messages that follow
the conventions established in InvokeAI 1.X through 2.X.
Logging class for InvokeAI that produces console messages
One way to use it:
Usage:
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.getLogger(__name__)
logger.critical('this is critical')
logger.error('this is an error')
logger.warning('this is a warning')
logger.info('this is info')
logger.debug('this is debugging')
logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization
(or)
logger = InvokeAILogger.getLogger(__name__) // To use the filename
logger.critical('this is critical') // Critical Message
logger.error('this is an error') // Error Message
logger.warning('this is a warning') // Warning Message
logger.info('this is info') // Info Message
logger.debug('this is debugging') // Debug Message
Console messages:
### this is critical
*** this is an error ***
** this is a warning
>> this is info
| this is debugging
[12-05-2023 20]::[InvokeAI]::CRITICAL --> This is a critical message [In Bold Red]
[12-05-2023 20]::[InvokeAI]::ERROR --> This is an error message [In Red]
[12-05-2023 20]::[InvokeAI]::WARNING --> This is a warning message [In Yellow]
[12-05-2023 20]::[InvokeAI]::INFO --> This is an info message [In Grey]
[12-05-2023 20]::[InvokeAI]::DEBUG --> This is a debug message [In Cyan]
Another way:
import invokeai.backend.util.logging as ialog
ialog.debug('this is a debugging message')
Alternate Method (in this case the logger name will be set to InvokeAI):
import invokeai.backend.util.logging as IAILogger
IAILogger.debug('this is a debugging message')
"""
import logging
# module level functions
def debug(msg, *args, **kwargs):
InvokeAILogger.getLogger().debug(msg, *args, **kwargs)
@@ -42,7 +45,7 @@ def warning(msg, *args, **kwargs):
def error(msg, *args, **kwargs):
InvokeAILogger.getLogger().error(msg, *args, **kwargs)
def critical(msg, *args, **kwargs):
InvokeAILogger.getLogger().critical(msg, *args, **kwargs)
@@ -55,49 +58,47 @@ def disable(level=logging.CRITICAL):
def basicConfig(**kwargs):
InvokeAILogger.getLogger().basicConfig(**kwargs)
def getLogger(name: str=None)->logging.Logger:
def getLogger(name: str = None) -> logging.Logger:
return InvokeAILogger.getLogger(name)
class InvokeAILogFormatter(logging.Formatter):
'''
Repurposed from:
https://stackoverflow.com/questions/14844970/modifying-logging-message-format-based-on-message-logging-level-in-python3
Custom Formatting for the InvokeAI Logger
'''
crit_fmt = "### %(msg)s"
err_fmt = "*** %(msg)s"
warn_fmt = "** %(msg)s"
info_fmt = ">> %(msg)s"
dbg_fmt = " | %(msg)s"
def __init__(self):
super().__init__(fmt="%(levelno)d: %(msg)s", datefmt=None, style='%')
# Color Codes
grey = "\x1b[38;20m"
yellow = "\x1b[33;20m"
red = "\x1b[31;20m"
cyan = "\x1b[36;20m"
bold_red = "\x1b[31;1m"
reset = "\x1b[0m"
# Log Format
format = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
## More Formatting Options: %(pathname)s, %(filename)s, %(module)s, %(lineno)d
# Format Map
FORMATS = {
logging.DEBUG: cyan + format + reset,
logging.INFO: grey + format + reset,
logging.WARNING: yellow + format + reset,
logging.ERROR: red + format + reset,
logging.CRITICAL: bold_red + format + reset
}
def format(self, record):
# Remember the format used when the logging module
# was installed (in the event that this formatter is
# used with the vanilla logging module.)
format_orig = self._style._fmt
if record.levelno == logging.DEBUG:
self._style._fmt = InvokeAILogFormatter.dbg_fmt
if record.levelno == logging.INFO:
self._style._fmt = InvokeAILogFormatter.info_fmt
if record.levelno == logging.WARNING:
self._style._fmt = InvokeAILogFormatter.warn_fmt
if record.levelno == logging.ERROR:
self._style._fmt = InvokeAILogFormatter.err_fmt
if record.levelno == logging.CRITICAL:
self._style._fmt = InvokeAILogFormatter.crit_fmt
log_fmt = self.FORMATS.get(record.levelno)
formatter = logging.Formatter(log_fmt, datefmt="%d-%m-%Y %H:%M:%S")
return formatter.format(record)
# parent class does the work
result = super().format(record)
self._style._fmt = format_orig
return result
class InvokeAILogger(object):
loggers = dict()
@classmethod
def getLogger(self, name:str='invokeai')->logging.Logger:
def getLogger(self, name: str = 'InvokeAI') -> logging.Logger:
if name not in self.loggers:
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)


@@ -4,17 +4,21 @@ from .parse_seed_weights import parse_seed_weights
SAMPLER_CHOICES = [
"ddim",
"k_dpm_2_a",
"k_dpm_2",
"k_dpmpp_2_a",
"k_dpmpp_2",
"k_euler_a",
"k_euler",
"k_heun",
"k_lms",
"plms",
# diffusers:
"ddpm",
"deis",
"lms",
"pndm",
"heun",
"heun_k",
"euler",
"euler_k",
"euler_a",
"kdpm_2",
"kdpm_2_a",
"dpmpp_2s",
"dpmpp_2m",
"dpmpp_2m_k",
"unipc",
]


@@ -1,13 +0,0 @@
{
"plugins": [
[
"transform-imports",
{
"lodash": {
"transform": "lodash/${member}",
"preventFullImport": true
}
}
]
]
}


@@ -34,4 +34,8 @@ stats.html
!.yarn/plugins
!.yarn/releases
!.yarn/sdks
!.yarn/versions
!.yarn/versions
# Yalc
.yalc
yalc.lock


@@ -5,6 +5,7 @@ import { PluginOption, UserConfig } from 'vite';
import dts from 'vite-plugin-dts';
import eslint from 'vite-plugin-eslint';
import tsconfigPaths from 'vite-tsconfig-paths';
import cssInjectedByJsPlugin from 'vite-plugin-css-injected-by-js';
export const packageConfig: UserConfig = {
base: './',
@@ -16,9 +17,10 @@ export const packageConfig: UserConfig = {
dts({
insertTypesEntry: true,
}),
cssInjectedByJsPlugin(),
],
build: {
chunkSizeWarningLimit: 1500,
cssCodeSplit: true,
lib: {
entry: path.resolve(__dirname, '../src/index.ts'),
name: 'InvokeAIUI',
@@ -30,6 +32,7 @@ export const packageConfig: UserConfig = {
globals: {
react: 'React',
'react-dom': 'ReactDOM',
'@emotion/react': 'EmotionReact',
},
},
},


@@ -15,15 +15,3 @@ The `postinstall` script patches a few packages and runs the Chakra CLI to gener
### Patch `@chakra-ui/cli`
See: <https://github.com/chakra-ui/chakra-ui/issues/7394>
### Patch `redux-persist`
We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`.
`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it.
So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that.
### Patch `redux-deep-persist`
This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work.


@@ -37,7 +37,7 @@ From `invokeai/frontend/web/` run `yarn install` to get everything set up.
Start everything in dev mode:
1. Start the dev server: `yarn dev`
2. Start the InvokeAI UI per usual: `invokeai --web`
2. Start the InvokeAI Nodes backend: `python scripts/invokeai-new.py --web # run from the repo root`
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
### Production builds


@@ -21,7 +21,6 @@
"scripts": {
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
@@ -63,11 +62,13 @@
"@dagrejs/graphlib": "^2.1.12",
"@emotion/react": "^11.10.6",
"@emotion/styled": "^11.10.6",
"@floating-ui/react-dom": "^2.0.0",
"@fontsource/inter": "^4.5.15",
"@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5",
"chakra-ui-contextmenu": "^1.0.5",
"dateformat": "^5.0.3",
"downshift": "^7.6.0",
"formik": "^2.2.9",
"framer-motion": "^10.12.4",
"fuse.js": "^6.6.2",
@@ -88,17 +89,14 @@
"react-i18next": "^12.2.2",
"react-icons": "^4.7.1",
"react-konva": "^18.2.7",
"react-konva-utils": "^1.0.4",
"react-redux": "^8.0.5",
"react-rnd": "^10.4.1",
"react-transition-group": "^4.4.5",
"react-resizable-panels": "^0.0.42",
"react-use": "^17.4.0",
"react-virtuoso": "^4.3.5",
"react-zoom-pan-pinch": "^3.0.7",
"reactflow": "^11.7.0",
"redux-deep-persist": "^1.0.7",
"redux-dynamic-middlewares": "^2.2.0",
"redux-persist": "^6.0.0",
"redux-remember": "^3.3.1",
"roarr": "^7.15.0",
"serialize-error": "^11.0.0",
"socket.io-client": "^4.6.0",
@@ -118,6 +116,7 @@
"@types/node": "^18.16.2",
"@types/react": "^18.2.0",
"@types/react-dom": "^18.2.1",
"@types/react-redux": "^7.1.25",
"@types/react-transition-group": "^4.4.5",
"@types/uuid": "^9.0.0",
"@typescript-eslint/eslint-plugin": "^5.59.1",
@@ -143,6 +142,7 @@
"terser": "^5.17.1",
"ts-toolbelt": "^9.6.0",
"vite": "^4.3.3",
"vite-plugin-css-injected-by-js": "^3.1.1",
"vite-plugin-dts": "^2.3.0",
"vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^4.2.0",


@@ -1,24 +0,0 @@
diff --git a/node_modules/redux-deep-persist/lib/types.d.ts b/node_modules/redux-deep-persist/lib/types.d.ts
index b67b8c2..7fc0fa1 100644
--- a/node_modules/redux-deep-persist/lib/types.d.ts
+++ b/node_modules/redux-deep-persist/lib/types.d.ts
@@ -35,6 +35,7 @@ export interface PersistConfig<S, RS = any, HSS = any, ESS = any> {
whitelist?: Array<string>;
transforms?: Array<Transform<HSS, ESS, S, RS>>;
throttle?: number;
+ debounce?: number;
migrate?: PersistMigrate;
stateReconciler?: false | StateReconciler<S>;
getStoredState?: (config: PersistConfig<S, RS, HSS, ESS>) => Promise<PersistedState>;
diff --git a/node_modules/redux-deep-persist/src/types.ts b/node_modules/redux-deep-persist/src/types.ts
index 398ac19..cbc5663 100644
--- a/node_modules/redux-deep-persist/src/types.ts
+++ b/node_modules/redux-deep-persist/src/types.ts
@@ -91,6 +91,7 @@ export interface PersistConfig<S, RS = any, HSS = any, ESS = any> {
whitelist?: Array<string>;
transforms?: Array<Transform<HSS, ESS, S, RS>>;
throttle?: number;
+ debounce?: number;
migrate?: PersistMigrate;
stateReconciler?: false | StateReconciler<S>;
/**


@@ -1,116 +0,0 @@
diff --git a/node_modules/redux-persist/es/createPersistoid.js b/node_modules/redux-persist/es/createPersistoid.js
index 8b43b9a..184faab 100644
--- a/node_modules/redux-persist/es/createPersistoid.js
+++ b/node_modules/redux-persist/es/createPersistoid.js
@@ -6,6 +6,7 @@ export default function createPersistoid(config) {
var whitelist = config.whitelist || null;
var transforms = config.transforms || [];
var throttle = config.throttle || 0;
+ var debounce = config.debounce || 0;
var storageKey = "".concat(config.keyPrefix !== undefined ? config.keyPrefix : KEY_PREFIX).concat(config.key);
var storage = config.storage;
var serialize;
@@ -28,30 +29,37 @@ export default function createPersistoid(config) {
var timeIterator = null;
var writePromise = null;
- var update = function update(state) {
- // add any changed keys to the queue
- Object.keys(state).forEach(function (key) {
- if (!passWhitelistBlacklist(key)) return; // is keyspace ignored? noop
+ // Timer for debounced `update()`
+ let timer = 0;
- if (lastState[key] === state[key]) return; // value unchanged? noop
+ function update(state) {
+ // Debounce the update
+ clearTimeout(timer);
+ timer = setTimeout(() => {
+ // add any changed keys to the queue
+ Object.keys(state).forEach(function (key) {
+ if (!passWhitelistBlacklist(key)) return; // is keyspace ignored? noop
- if (keysToProcess.indexOf(key) !== -1) return; // is key already queued? noop
+ if (lastState[key] === state[key]) return; // value unchanged? noop
- keysToProcess.push(key); // add key to queue
- }); //if any key is missing in the new state which was present in the lastState,
- //add it for processing too
+ if (keysToProcess.indexOf(key) !== -1) return; // is key already queued? noop
- Object.keys(lastState).forEach(function (key) {
- if (state[key] === undefined && passWhitelistBlacklist(key) && keysToProcess.indexOf(key) === -1 && lastState[key] !== undefined) {
- keysToProcess.push(key);
- }
- }); // start the time iterator if not running (read: throttle)
+ keysToProcess.push(key); // add key to queue
+ }); //if any key is missing in the new state which was present in the lastState,
+ //add it for processing too
- if (timeIterator === null) {
- timeIterator = setInterval(processNextKey, throttle);
- }
+ Object.keys(lastState).forEach(function (key) {
+ if (state[key] === undefined && passWhitelistBlacklist(key) && keysToProcess.indexOf(key) === -1 && lastState[key] !== undefined) {
+ keysToProcess.push(key);
+ }
+ }); // start the time iterator if not running (read: throttle)
+
+ if (timeIterator === null) {
+ timeIterator = setInterval(processNextKey, throttle);
+ }
- lastState = state;
+ lastState = state;
+ }, debounce)
};
function processNextKey() {
diff --git a/node_modules/redux-persist/es/types.js.flow b/node_modules/redux-persist/es/types.js.flow
index c50d3cd..39d8be2 100644
--- a/node_modules/redux-persist/es/types.js.flow
+++ b/node_modules/redux-persist/es/types.js.flow
@@ -19,6 +19,7 @@ export type PersistConfig = {
whitelist?: Array<string>,
transforms?: Array<Transform>,
throttle?: number,
+ debounce?: number,
migrate?: (PersistedState, number) => Promise<PersistedState>,
stateReconciler?: false | Function,
getStoredState?: PersistConfig => Promise<PersistedState>, // used for migrations
diff --git a/node_modules/redux-persist/lib/types.js.flow b/node_modules/redux-persist/lib/types.js.flow
index c50d3cd..39d8be2 100644
--- a/node_modules/redux-persist/lib/types.js.flow
+++ b/node_modules/redux-persist/lib/types.js.flow
@@ -19,6 +19,7 @@ export type PersistConfig = {
whitelist?: Array<string>,
transforms?: Array<Transform>,
throttle?: number,
+ debounce?: number,
migrate?: (PersistedState, number) => Promise<PersistedState>,
stateReconciler?: false | Function,
getStoredState?: PersistConfig => Promise<PersistedState>, // used for migrations
diff --git a/node_modules/redux-persist/src/types.js b/node_modules/redux-persist/src/types.js
index c50d3cd..39d8be2 100644
--- a/node_modules/redux-persist/src/types.js
+++ b/node_modules/redux-persist/src/types.js
@@ -19,6 +19,7 @@ export type PersistConfig = {
whitelist?: Array<string>,
transforms?: Array<Transform>,
throttle?: number,
+ debounce?: number,
migrate?: (PersistedState, number) => Promise<PersistedState>,
stateReconciler?: false | Function,
getStoredState?: PersistConfig => Promise<PersistedState>, // used for migrations
diff --git a/node_modules/redux-persist/types/types.d.ts b/node_modules/redux-persist/types/types.d.ts
index b3733bc..2a1696c 100644
--- a/node_modules/redux-persist/types/types.d.ts
+++ b/node_modules/redux-persist/types/types.d.ts
@@ -35,6 +35,7 @@ declare module "redux-persist/es/types" {
whitelist?: Array<string>;
transforms?: Array<Transform<HSS, ESS, S, RS>>;
throttle?: number;
+ debounce?: number;
migrate?: PersistMigrate;
stateReconciler?: false | StateReconciler<S>;
/**


@@ -25,7 +25,7 @@
"common": {
"hotkeysLabel": "Hotkeys",
"themeLabel": "Theme",
"languagePickerLabel": "Language Picker",
"languagePickerLabel": "Language",
"reportBugLabel": "Report Bug",
"githubLabel": "Github",
"discordLabel": "Discord",
@@ -54,7 +54,7 @@
"img2img": "Image To Image",
"unifiedCanvas": "Unified Canvas",
"linear": "Linear",
"nodes": "Nodes",
"nodes": "Node Editor",
"postprocessing": "Post Processing",
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
"postProcessing": "Post Processing",
@@ -102,7 +102,8 @@
"generate": "Generate",
"openInNewTab": "Open in New Tab",
"dontAskMeAgain": "Don't ask me again",
"areYouSure": "Are you sure?"
"areYouSure": "Are you sure?",
"imagePrompt": "Image Prompt"
},
"gallery": {
"generations": "Generations",
@@ -449,13 +450,14 @@
"cfgScale": "CFG Scale",
"width": "Width",
"height": "Height",
"sampler": "Sampler",
"scheduler": "Scheduler",
"seed": "Seed",
"imageToImage": "Image to Image",
"randomizeSeed": "Randomize Seed",
"shuffle": "Shuffle",
"shuffle": "Shuffle Seed",
"noiseThreshold": "Noise Threshold",
"perlinNoise": "Perlin Noise",
"noiseSettings": "Noise",
"variations": "Variations",
"variationAmount": "Variation Amount",
"seedWeights": "Seed Weights",
@@ -470,6 +472,8 @@
"scale": "Scale",
"otherOptions": "Other Options",
"seamlessTiling": "Seamless Tiling",
"seamlessXAxis": "X Axis",
"seamlessYAxis": "Y Axis",
"hiresOptim": "High Res Optimization",
"hiresStrength": "High Res Strength",
"imageFit": "Fit Initial Image To Output Size",
@@ -527,7 +531,8 @@
"useCanvasBeta": "Use Canvas Beta Layout",
"enableImageDebugging": "Enable Image Debugging",
"useSlidersForAll": "Use Sliders For All Options",
"autoShowProgress": "Auto Show Progress Images",
"showProgressInViewer": "Show Progress Images in Viewer",
"antialiasProgressImages": "Antialias Progress Images",
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
@@ -535,7 +540,10 @@
"consoleLogLevel": "Log Level",
"shouldLogToConsole": "Console Logging",
"developer": "Developer",
"general": "General"
"general": "General",
"generation": "Generation",
"ui": "User Interface",
"availableSchedulers": "Available Schedulers"
},
"toast": {
"serverError": "Server Error",
@@ -544,13 +552,14 @@
"canceled": "Processing Canceled",
"tempFoldersEmptied": "Temp Folder Emptied",
"uploadFailed": "Upload failed",
"uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",
"uploadFailedUnableToLoadDesc": "Unable to load file",
"uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image",
"downloadImageStarted": "Image Download Started",
"imageCopied": "Image Copied",
"imageLinkCopied": "Image Link Copied",
"problemCopyingImageLink": "Unable to Copy Image Link",
"imageNotLoaded": "No Image Loaded",
"imageNotLoadedDesc": "No image found to send to image to image module",
"imageNotLoadedDesc": "Could not find image",
"imageSavedToGallery": "Image Saved to Gallery",
"canvasMerged": "Canvas Merged",
"sentToImageToImage": "Sent To Image To Image",
@@ -645,7 +654,8 @@
"betaClear": "Clear",
"betaDarkenOutside": "Darken Outside",
"betaLimitToBox": "Limit To Box",
"betaPreserveMasked": "Preserve Masked"
"betaPreserveMasked": "Preserve Masked",
"antialiasing": "Antialiasing"
},
"ui": {
"showProgressImages": "Show Progress Images",


@@ -1,46 +1,44 @@
import ImageUploader from 'common/components/ImageUploader';
import ProgressBar from 'features/system/components/ProgressBar';
import SiteHeader from 'features/system/components/SiteHeader';
import ProgressBar from 'features/system/components/ProgressBar';
import InvokeTabs from 'features/ui/components/InvokeTabs';
import useToastWatcher from 'features/system/hooks/useToastWatcher';
import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton';
import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons';
import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
import { Box, Flex, Grid, Portal } from '@chakra-ui/react';
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
import GalleryDrawer from 'features/gallery/components/GalleryPanel';
import Lightbox from 'features/lightbox/components/Lightbox';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
memo,
PropsWithChildren,
useCallback,
useEffect,
useState,
} from 'react';
import { memo, ReactNode, useCallback, useEffect, useState } from 'react';
import { motion, AnimatePresence } from 'framer-motion';
import Loading from 'common/components/Loading/Loading';
import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady';
import { PartialAppConfig } from 'app/types/invokeai';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import { configChanged } from 'features/system/store/configSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useLogger } from 'app/logging/useLogger';
import ProgressImagePreview from 'features/parameters/components/ProgressImagePreview';
import ParametersDrawer from 'features/ui/components/ParametersDrawer';
import { languageSelector } from 'features/system/store/systemSelectors';
import i18n from 'i18n';
import Toaster from './Toaster';
import GlobalHotkeys from './GlobalHotkeys';
const DEFAULT_CONFIG = {};
interface Props extends PropsWithChildren {
interface Props {
config?: PartialAppConfig;
headerComponent?: ReactNode;
setIsReady?: (isReady: boolean) => void;
}
const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
useToastWatcher();
useGlobalHotkeys();
const log = useLogger();
const App = ({
config = DEFAULT_CONFIG,
headerComponent,
setIsReady,
}: Props) => {
const language = useAppSelector(languageSelector);
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
const log = useLogger();
const isLightboxEnabled = useFeatureStatus('lightbox').isFeatureEnabled;
@@ -48,81 +46,95 @@ const App = ({ config = DEFAULT_CONFIG, children }: Props) => {
const [loadingOverridden, setLoadingOverridden] = useState(false);
const { setColorMode } = useColorMode();
const dispatch = useAppDispatch();
useEffect(() => {
i18n.changeLanguage(language);
}, [language]);
useEffect(() => {
log.info({ namespace: 'App', data: config }, 'Received config');
dispatch(configChanged(config));
}, [dispatch, config, log]);
useEffect(() => {
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');
}, [setColorMode, currentTheme]);
const handleOverrideClicked = useCallback(() => {
setLoadingOverridden(true);
}, []);
useEffect(() => {
if (isApplicationReady && setIsReady) {
setIsReady(true);
}
return () => {
setIsReady && setIsReady(false);
};
}, [isApplicationReady, setIsReady]);
return (
<Grid w="100vw" h="100vh" position="relative" overflow="hidden">
{isLightboxEnabled && <Lightbox />}
<ImageUploader>
<ProgressBar />
<Grid
gap={4}
p={4}
gridAutoRows="min-content auto"
w={APP_WIDTH}
h={APP_HEIGHT}
>
{children || <SiteHeader />}
<Flex
<>
<Grid w="100vw" h="100vh" position="relative" overflow="hidden">
{isLightboxEnabled && <Lightbox />}
<ImageUploader>
<ProgressBar />
<Grid
gap={4}
w={{ base: '100vw', xl: 'full' }}
h="full"
flexDir={{ base: 'column', xl: 'row' }}
p={4}
gridAutoRows="min-content auto"
w={APP_WIDTH}
h={APP_HEIGHT}
>
<InvokeTabs />
<ImageGalleryPanel />
</Flex>
</Grid>
</ImageUploader>
{headerComponent || <SiteHeader />}
<Flex
gap={4}
w={{ base: '100vw', xl: 'full' }}
h="full"
flexDir={{ base: 'column', xl: 'row' }}
>
<InvokeTabs />
</Flex>
</Grid>
</ImageUploader>
<AnimatePresence>
{!isApplicationReady && !loadingOverridden && (
<motion.div
key="loading"
initial={{ opacity: 1 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
transition={{ duration: 0.3 }}
style={{ zIndex: 3 }}
>
<Box position="absolute" top={0} left={0} w="100vw" h="100vh">
<Loading />
</Box>
<Box
onClick={handleOverrideClicked}
position="absolute"
top={0}
right={0}
cursor="pointer"
w="2rem"
h="2rem"
/>
</motion.div>
)}
</AnimatePresence>
<GalleryDrawer />
<ParametersDrawer />
<Portal>
<FloatingParametersPanelButtons />
</Portal>
<Portal>
<FloatingGalleryButton />
</Portal>
<ProgressImagePreview />
</Grid>
<AnimatePresence>
{!isApplicationReady && !loadingOverridden && (
<motion.div
key="loading"
initial={{ opacity: 1 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
transition={{ duration: 0.3 }}
style={{ zIndex: 3 }}
>
<Box position="absolute" top={0} left={0} w="100vw" h="100vh">
<Loading />
</Box>
<Box
onClick={handleOverrideClicked}
position="absolute"
top={0}
right={0}
cursor="pointer"
w="2rem"
h="2rem"
/>
</motion.div>
)}
</AnimatePresence>
<Portal>
<FloatingParametersPanelButtons />
</Portal>
<Portal>
<FloatingGalleryButton />
</Portal>
</Grid>
<Toaster />
<GlobalHotkeys />
</>
);
};


@@ -0,0 +1,44 @@
import { Flex, Spinner, Tooltip } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { systemSelector } from 'features/system/store/systemSelectors';
import { memo } from 'react';
const selector = createSelector(systemSelector, (system) => {
const { isUploading } = system;
let tooltip = '';
if (isUploading) {
tooltip = 'Uploading...';
}
return {
tooltip,
shouldShow: isUploading,
};
});
export const AuxiliaryProgressIndicator = () => {
const { shouldShow, tooltip } = useAppSelector(selector);
if (!shouldShow) {
return null;
}
return (
<Flex
sx={{
alignItems: 'center',
justifyContent: 'center',
color: 'base.600',
}}
>
<Tooltip label={tooltip} placement="right" hasArrow>
<Spinner />
</Tooltip>
</Flex>
);
};
export default memo(AuxiliaryProgressIndicator);


@@ -2,7 +2,15 @@ import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
import {
setActiveTab,
toggleGalleryPanel,
toggleParametersPanel,
togglePinGalleryPanel,
togglePinParametersPanel,
} from 'features/ui/store/uiSlice';
import { isEqual } from 'lodash-es';
import React, { memo } from 'react';
import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook';
const globalHotkeysSelector = createSelector(
@@ -20,7 +28,11 @@ const globalHotkeysSelector = createSelector(
// TODO: Does not catch keypresses while focused in an input. Maybe there is a way?
export const useGlobalHotkeys = () => {
/**
* Logical component. Handles app-level global hotkeys.
* @returns null
*/
const GlobalHotkeys: React.FC = () => {
const dispatch = useAppDispatch();
const { shift } = useAppSelector(globalHotkeysSelector);
@@ -36,4 +48,40 @@ export const useGlobalHotkeys = () => {
{ keyup: true, keydown: true },
[shift]
);
useHotkeys('o', () => {
dispatch(toggleParametersPanel());
});
useHotkeys(['shift+o'], () => {
dispatch(togglePinParametersPanel());
});
useHotkeys('g', () => {
dispatch(toggleGalleryPanel());
});
useHotkeys(['shift+g'], () => {
dispatch(togglePinGalleryPanel());
});
useHotkeys('1', () => {
dispatch(setActiveTab('txt2img'));
});
useHotkeys('2', () => {
dispatch(setActiveTab('img2img'));
});
useHotkeys('3', () => {
dispatch(setActiveTab('unifiedCanvas'));
});
useHotkeys('4', () => {
dispatch(setActiveTab('nodes'));
});
return null;
};
export default memo(GlobalHotkeys);
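The hook-to-component refactor above uses the "logical component" pattern: a component that renders nothing, but whose hooks run app-wide once it is mounted. A minimal sketch of the pattern, with a hypothetical binding that is not part of this diff:

```tsx
import React, { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';

// Sketch of a logical component: no UI, side effects only.
const ExampleHotkeys: React.FC = () => {
  useHotkeys('ctrl+k', (e) => {
    e.preventDefault();
    // dispatch an action here, as GlobalHotkeys does above
  });

  // mounting the component is what registers the hotkeys
  return null;
};

export default memo(ExampleHotkeys);
```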


@@ -1,18 +1,13 @@
import React, { lazy, memo, PropsWithChildren, useEffect } from 'react';
import React, {
lazy,
memo,
PropsWithChildren,
ReactNode,
useEffect,
} from 'react';
import { Provider } from 'react-redux';
import { PersistGate } from 'redux-persist/integration/react';
import { store } from 'app/store/store';
import { persistor } from '../store/persistor';
import { OpenAPI } from 'services/api';
import '@fontsource/inter/100.css';
import '@fontsource/inter/200.css';
import '@fontsource/inter/300.css';
import '@fontsource/inter/400.css';
import '@fontsource/inter/500.css';
import '@fontsource/inter/600.css';
import '@fontsource/inter/700.css';
import '@fontsource/inter/800.css';
import '@fontsource/inter/900.css';
import Loading from '../../common/components/Loading/Loading';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
@@ -28,9 +23,17 @@ interface Props extends PropsWithChildren {
apiUrl?: string;
token?: string;
config?: PartialAppConfig;
headerComponent?: ReactNode;
setIsReady?: (isReady: boolean) => void;
}
const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => {
const InvokeAIUI = ({
apiUrl,
token,
config,
headerComponent,
setIsReady,
}: Props) => {
useEffect(() => {
// configure API client token
if (token) {
@@ -57,13 +60,15 @@ const InvokeAIUI = ({ apiUrl, token, config, children }: Props) => {
return (
<React.StrictMode>
<Provider store={store}>
<PersistGate loading={<Loading />} persistor={persistor}>
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<App config={config}>{children}</App>
</ThemeLocaleProvider>
</React.Suspense>
</PersistGate>
<React.Suspense fallback={<Loading />}>
<ThemeLocaleProvider>
<App
config={config}
headerComponent={headerComponent}
setIsReady={setIsReady}
/>
</ThemeLocaleProvider>
</React.Suspense>
</Provider>
</React.StrictMode>
);
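With `children` replaced by explicit `headerComponent` and `setIsReady` props, the embedding contract is clearer for host apps consuming the package build. A sketch of host-side usage (the import path and header markup are assumptions, not part of this diff):

```tsx
import React, { useState } from 'react';
import InvokeAIUI from 'app/components/InvokeAIUI';

const Host = () => {
  // the host can gate its own UI on the app becoming ready
  const [isReady, setIsReady] = useState(false);

  return (
    <>
      {!isReady && <span>Loading InvokeAI…</span>}
      <InvokeAIUI
        apiUrl="http://localhost:9090"
        headerComponent={<header>Custom Site Header</header>}
        setIsReady={setIsReady}
      />
    </>
  );
};
```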


@@ -1,4 +1,8 @@
import { ChakraProvider, extendTheme } from '@chakra-ui/react';
import {
ChakraProvider,
createLocalStorageManager,
extendTheme,
} from '@chakra-ui/react';
import { ReactNode, useEffect } from 'react';
import { useTranslation } from 'react-i18next';
import { theme as invokeAITheme } from 'theme/theme';
@@ -9,15 +13,8 @@ import { greenTeaThemeColors } from 'theme/colors/greenTea';
import { invokeAIThemeColors } from 'theme/colors/invokeAI';
import { lightThemeColors } from 'theme/colors/lightTheme';
import { oceanBlueColors } from 'theme/colors/oceanBlue';
import '@fontsource/inter/100.css';
import '@fontsource/inter/200.css';
import '@fontsource/inter/300.css';
import '@fontsource/inter/400.css';
import '@fontsource/inter/500.css';
import '@fontsource/inter/600.css';
import '@fontsource/inter/700.css';
import '@fontsource/inter/800.css';
import '@fontsource/inter/900.css';
import '@fontsource/inter/variable.css';
import 'overlayscrollbars/overlayscrollbars.css';
import 'theme/css/overlayscrollbars.css';
@@ -32,6 +29,8 @@ const THEMES = {
ocean: oceanBlueColors,
};
const manager = createLocalStorageManager('@@invokeai-color-mode');
function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
const { i18n } = useTranslation();
@@ -51,7 +50,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
document.body.dir = direction;
}, [direction]);
return <ChakraProvider theme={theme}>{children}</ChakraProvider>;
return (
<ChakraProvider theme={theme} colorModeManager={manager}>
{children}
</ChakraProvider>
);
}
export default ThemeLocaleProvider;


@@ -0,0 +1,65 @@
import { useToast, UseToastOptions } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { toastQueueSelector } from 'features/system/store/systemSelectors';
import { addToast, clearToastQueue } from 'features/system/store/systemSlice';
import { useCallback, useEffect } from 'react';
export type MakeToastArg = string | UseToastOptions;
/**
* Makes a toast from a string or a UseToastOptions object.
* If a string is passed, the toast will have the status 'info' and will be closable with a duration of 2500ms.
*/
export const makeToast = (arg: MakeToastArg): UseToastOptions => {
if (typeof arg === 'string') {
return {
title: arg,
status: 'info',
isClosable: true,
duration: 2500,
};
}
return { status: 'info', isClosable: true, duration: 2500, ...arg };
};
/**
* Logical component. Watches the toast queue and makes toasts when the queue is not empty.
* @returns null
*/
const Toaster = () => {
const dispatch = useAppDispatch();
const toastQueue = useAppSelector(toastQueueSelector);
const toast = useToast();
useEffect(() => {
toastQueue.forEach((t) => {
toast(t);
});
toastQueue.length > 0 && dispatch(clearToastQueue());
}, [dispatch, toast, toastQueue]);
return null;
};
/**
* Returns a function that can be used to make a toast.
* @example
* const toaster = useAppToaster();
* toaster('Hello world!');
* toaster({ title: 'Hello world!', status: 'success' });
* @returns A function that can be used to make a toast.
* @see makeToast
* @see MakeToastArg
* @see UseToastOptions
*/
export const useAppToaster = () => {
const dispatch = useAppDispatch();
const toaster = useCallback(
(arg: MakeToastArg) => dispatch(addToast(makeToast(arg))),
[dispatch]
);
return toaster;
};
export default Toaster;
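Because the caller's options are spread after the defaults in `makeToast`, an object argument can override any default, while a string argument always yields the plain info toast:

```ts
makeToast('Image Saved');
// -> { title: 'Image Saved', status: 'info', isClosable: true, duration: 2500 }

makeToast({ title: 'Upload Failed', status: 'error', duration: 5000 });
// -> defaults merged first, explicit fields win:
//    { status: 'error', isClosable: true, duration: 5000, title: 'Upload Failed' }
```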


@@ -1,17 +1,28 @@
// TODO: use Enums?
export const DIFFUSERS_SCHEDULERS: Array<string> = [
export const SCHEDULERS = [
'ddim',
'plms',
'k_lms',
'dpmpp_2',
'k_dpm_2',
'k_dpm_2_a',
'k_dpmpp_2',
'k_euler',
'k_euler_a',
'k_heun',
];
'lms',
'euler',
'euler_k',
'euler_a',
'dpmpp_2s',
'dpmpp_2m',
'dpmpp_2m_k',
'kdpm_2',
'kdpm_2_a',
'deis',
'ddpm',
'pndm',
'heun',
'heun_k',
'unipc',
] as const;
export type Scheduler = (typeof SCHEDULERS)[number];
export const isScheduler = (x: string): x is Scheduler =>
SCHEDULERS.includes(x as Scheduler);
// Valid image widths
export const WIDTHS: Array<number> = Array.from(Array(64)).map(
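The `as const` assertion makes `SCHEDULERS` double as the runtime list and the source of the `Scheduler` union, and `isScheduler` narrows untrusted strings such as persisted or user-supplied values. A sketch of the guard in use (the storage key is illustrative):

```ts
const raw: string = localStorage.getItem('scheduler') ?? 'euler';

if (isScheduler(raw)) {
  // raw is narrowed from string to the Scheduler union here
  const scheduler: Scheduler = raw;
}
```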


@@ -1,26 +1,20 @@
import { createSelector } from '@reduxjs/toolkit';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { validateSeedWeights } from 'common/util/seedWeightPairs';
import { initialCanvasImageSelector } from 'features/canvas/store/canvasSelectors';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash-es';
export const readinessSelector = createSelector(
[
generationSelector,
systemSelector,
initialCanvasImageSelector,
activeTabNameSelector,
],
(generation, system, initialCanvasImage, activeTabName) => {
[generationSelector, systemSelector, activeTabNameSelector],
(generation, system, activeTabName) => {
const {
prompt,
shouldGenerateVariations,
seedWeights,
initialImage,
seed,
isImageToImageEnabled,
} = generation;
const { isProcessing, isConnected } = system;
@@ -34,7 +28,7 @@ export const readinessSelector = createSelector(
reasonsWhyNotReady.push('Missing prompt');
}
if (isImageToImageEnabled && !initialImage) {
if (activeTabName === 'img2img' && !initialImage) {
isReady = false;
reasonsWhyNotReady.push('No initial image selected');
}
@@ -64,10 +58,5 @@ export const readinessSelector = createSelector(
// All good
return { isReady, reasonsWhyNotReady };
},
{
memoizeOptions: {
equalityCheck: isEqual,
resultEqualityCheck: isEqual,
},
}
defaultSelectorOptions
);


@@ -1,209 +1,209 @@
// import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
// import * as InvokeAI from 'app/types/invokeai';
// import type { RootState } from 'app/store/store';
// import {
// frontendToBackendParameters,
// FrontendToBackendParametersConfig,
// } from 'common/util/parameterTranslation';
// import dateFormat from 'dateformat';
// import {
// GalleryCategory,
// GalleryState,
// removeImage,
// } from 'features/gallery/store/gallerySlice';
// import {
// generationRequested,
// modelChangeRequested,
// modelConvertRequested,
// modelMergingRequested,
// setIsProcessing,
// } from 'features/system/store/systemSlice';
// import { InvokeTabName } from 'features/ui/store/tabMap';
// import { Socket } from 'socket.io-client';
import { AnyAction, Dispatch, MiddlewareAPI } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai';
import type { RootState } from 'app/store/store';
import {
frontendToBackendParameters,
FrontendToBackendParametersConfig,
} from 'common/util/parameterTranslation';
import dateFormat from 'dateformat';
import {
GalleryCategory,
GalleryState,
removeImage,
} from 'features/gallery/store/gallerySlice';
import {
generationRequested,
modelChangeRequested,
modelConvertRequested,
modelMergingRequested,
setIsProcessing,
} from 'features/system/store/systemSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { Socket } from 'socket.io-client';
// /**
// * Returns an object containing all functions which use `socketio.emit()`.
// * i.e. those which make server requests.
// */
// const makeSocketIOEmitters = (
// store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
// socketio: Socket
// ) => {
// // We need to dispatch actions to redux and get pieces of state from the store.
// const { dispatch, getState } = store;
/**
* Returns an object containing all functions which use `socketio.emit()`.
* i.e. those which make server requests.
*/
const makeSocketIOEmitters = (
store: MiddlewareAPI<Dispatch<AnyAction>, RootState>,
socketio: Socket
) => {
// We need to dispatch actions to redux and get pieces of state from the store.
const { dispatch, getState } = store;
// return {
// emitGenerateImage: (generationMode: InvokeTabName) => {
// dispatch(setIsProcessing(true));
return {
emitGenerateImage: (generationMode: InvokeTabName) => {
dispatch(setIsProcessing(true));
// const state: RootState = getState();
const state: RootState = getState();
// const {
// generation: generationState,
// postprocessing: postprocessingState,
// system: systemState,
// canvas: canvasState,
// } = state;
const {
generation: generationState,
postprocessing: postprocessingState,
system: systemState,
canvas: canvasState,
} = state;
// const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
// {
// generationMode,
// generationState,
// postprocessingState,
// canvasState,
// systemState,
// };
const frontendToBackendParametersConfig: FrontendToBackendParametersConfig =
{
generationMode,
generationState,
postprocessingState,
canvasState,
systemState,
};
// dispatch(generationRequested());
dispatch(generationRequested());
// const { generationParameters, esrganParameters, facetoolParameters } =
// frontendToBackendParameters(frontendToBackendParametersConfig);
const { generationParameters, esrganParameters, facetoolParameters } =
frontendToBackendParameters(frontendToBackendParametersConfig);
// socketio.emit(
// 'generateImage',
// generationParameters,
// esrganParameters,
// facetoolParameters
// );
socketio.emit(
'generateImage',
generationParameters,
esrganParameters,
facetoolParameters
);
// // we need to truncate the init_mask base64 else it takes up the whole log
// // TODO: handle maintaining masks for reproducibility in future
// if (generationParameters.init_mask) {
// generationParameters.init_mask = generationParameters.init_mask
// .substr(0, 64)
// .concat('...');
// }
// if (generationParameters.init_img) {
// generationParameters.init_img = generationParameters.init_img
// .substr(0, 64)
// .concat('...');
// }
// we need to truncate the init_mask base64 else it takes up the whole log
// TODO: handle maintaining masks for reproducibility in future
if (generationParameters.init_mask) {
generationParameters.init_mask = generationParameters.init_mask
.substr(0, 64)
.concat('...');
}
if (generationParameters.init_img) {
generationParameters.init_img = generationParameters.init_img
.substr(0, 64)
.concat('...');
}
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Image generation requested: ${JSON.stringify({
// ...generationParameters,
// ...esrganParameters,
// ...facetoolParameters,
// })}`,
// })
// );
// },
// emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
// dispatch(setIsProcessing(true));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Image generation requested: ${JSON.stringify({
...generationParameters,
...esrganParameters,
...facetoolParameters,
})}`,
})
);
},
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
// const {
// postprocessing: {
// upscalingLevel,
// upscalingDenoising,
// upscalingStrength,
// },
// } = getState();
const {
postprocessing: {
upscalingLevel,
upscalingDenoising,
upscalingStrength,
},
} = getState();
// const esrganParameters = {
// upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
// };
// socketio.emit('runPostprocessing', imageToProcess, {
// type: 'esrgan',
// ...esrganParameters,
// });
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `ESRGAN upscale requested: ${JSON.stringify({
// file: imageToProcess.url,
// ...esrganParameters,
// })}`,
// })
// );
// },
// emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
// dispatch(setIsProcessing(true));
const esrganParameters = {
upscale: [upscalingLevel, upscalingDenoising, upscalingStrength],
};
socketio.emit('runPostprocessing', imageToProcess, {
type: 'esrgan',
...esrganParameters,
});
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `ESRGAN upscale requested: ${JSON.stringify({
file: imageToProcess.url,
...esrganParameters,
})}`,
})
);
},
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
// const {
// postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
// } = getState();
const {
postprocessing: { facetoolType, facetoolStrength, codeformerFidelity },
} = getState();
// const facetoolParameters: Record<string, unknown> = {
// facetool_strength: facetoolStrength,
// };
const facetoolParameters: Record<string, unknown> = {
facetool_strength: facetoolStrength,
};
// if (facetoolType === 'codeformer') {
// facetoolParameters.codeformer_fidelity = codeformerFidelity;
// }
if (facetoolType === 'codeformer') {
facetoolParameters.codeformer_fidelity = codeformerFidelity;
}
// socketio.emit('runPostprocessing', imageToProcess, {
// type: facetoolType,
// ...facetoolParameters,
// });
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
// {
// file: imageToProcess.url,
// ...facetoolParameters,
// }
// )}`,
// })
// );
// },
// emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
// const { url, uuid, category, thumbnail } = imageToDelete;
// dispatch(removeImage(imageToDelete));
// socketio.emit('deleteImage', url, thumbnail, uuid, category);
// },
// emitRequestImages: (category: GalleryCategory) => {
// const gallery: GalleryState = getState().gallery;
// const { earliest_mtime } = gallery.categories[category];
// socketio.emit('requestImages', category, earliest_mtime);
// },
// emitRequestNewImages: (category: GalleryCategory) => {
// const gallery: GalleryState = getState().gallery;
// const { latest_mtime } = gallery.categories[category];
// socketio.emit('requestLatestImages', category, latest_mtime);
// },
// emitCancelProcessing: () => {
// socketio.emit('cancel');
// },
// emitRequestSystemConfig: () => {
// socketio.emit('requestSystemConfig');
// },
// emitSearchForModels: (modelFolder: string) => {
// socketio.emit('searchForModels', modelFolder);
// },
// emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
// socketio.emit('addNewModel', modelConfig);
// },
// emitDeleteModel: (modelName: string) => {
// socketio.emit('deleteModel', modelName);
// },
// emitConvertToDiffusers: (
// modelToConvert: InvokeAI.InvokeModelConversionProps
// ) => {
// dispatch(modelConvertRequested());
// socketio.emit('convertToDiffusers', modelToConvert);
// },
// emitMergeDiffusersModels: (
// modelMergeInfo: InvokeAI.InvokeModelMergingProps
// ) => {
// dispatch(modelMergingRequested());
// socketio.emit('mergeDiffusersModels', modelMergeInfo);
// },
// emitRequestModelChange: (modelName: string) => {
// dispatch(modelChangeRequested());
// socketio.emit('requestModelChange', modelName);
// },
// emitSaveStagingAreaImageToGallery: (url: string) => {
// socketio.emit('requestSaveStagingAreaImageToGallery', url);
// },
// emitRequestEmptyTempFolder: () => {
// socketio.emit('requestEmptyTempFolder');
// },
// };
// };
socketio.emit('runPostprocessing', imageToProcess, {
type: facetoolType,
...facetoolParameters,
});
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Face restoration (${facetoolType}) requested: ${JSON.stringify(
{
file: imageToProcess.url,
...facetoolParameters,
}
)}`,
})
);
},
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
const { url, uuid, category, thumbnail } = imageToDelete;
dispatch(removeImage(imageToDelete));
socketio.emit('deleteImage', url, thumbnail, uuid, category);
},
emitRequestImages: (category: GalleryCategory) => {
const gallery: GalleryState = getState().gallery;
const { earliest_mtime } = gallery.categories[category];
socketio.emit('requestImages', category, earliest_mtime);
},
emitRequestNewImages: (category: GalleryCategory) => {
const gallery: GalleryState = getState().gallery;
const { latest_mtime } = gallery.categories[category];
socketio.emit('requestLatestImages', category, latest_mtime);
},
emitCancelProcessing: () => {
socketio.emit('cancel');
},
emitRequestSystemConfig: () => {
socketio.emit('requestSystemConfig');
},
emitSearchForModels: (modelFolder: string) => {
socketio.emit('searchForModels', modelFolder);
},
emitAddNewModel: (modelConfig: InvokeAI.InvokeModelConfigProps) => {
socketio.emit('addNewModel', modelConfig);
},
emitDeleteModel: (modelName: string) => {
socketio.emit('deleteModel', modelName);
},
emitConvertToDiffusers: (
modelToConvert: InvokeAI.InvokeModelConversionProps
) => {
dispatch(modelConvertRequested());
socketio.emit('convertToDiffusers', modelToConvert);
},
emitMergeDiffusersModels: (
modelMergeInfo: InvokeAI.InvokeModelMergingProps
) => {
dispatch(modelMergingRequested());
socketio.emit('mergeDiffusersModels', modelMergeInfo);
},
emitRequestModelChange: (modelName: string) => {
dispatch(modelChangeRequested());
socketio.emit('requestModelChange', modelName);
},
emitSaveStagingAreaImageToGallery: (url: string) => {
socketio.emit('requestSaveStagingAreaImageToGallery', url);
},
emitRequestEmptyTempFolder: () => {
socketio.emit('requestEmptyTempFolder');
},
};
};
// export default makeSocketIOEmitters;
export default makeSocketIOEmitters;
export default {};


@@ -0,0 +1,4 @@
import { createAction } from '@reduxjs/toolkit';
import { InvokeTabName } from 'features/ui/store/tabMap';
export const userInvoked = createAction<InvokeTabName>('app/userInvoked');


@@ -0,0 +1,8 @@
export const LOCALSTORAGE_KEYS = [
'chakra-ui-color-mode',
'i18nextLng',
'ROARR_FILTER',
'ROARR_LOG',
];
export const LOCALSTORAGE_PREFIX = '@@invokeai-';


@@ -0,0 +1,36 @@
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { resultsPersistDenylist } from 'features/gallery/store/resultsPersistDenylist';
import { uploadsPersistDenylist } from 'features/gallery/store/uploadsPersistDenylist';
import { lightboxPersistDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import { modelsPersistDenylist } from 'features/system/store/modelsPersistDenylist';
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
import { omit } from 'lodash-es';
import { SerializeFunction } from 'redux-remember';
const serializationDenylist: {
[key: string]: string[];
} = {
canvas: canvasPersistDenylist,
gallery: galleryPersistDenylist,
generation: generationPersistDenylist,
lightbox: lightboxPersistDenylist,
models: modelsPersistDenylist,
nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist,
results: resultsPersistDenylist,
system: systemPersistDenylist,
// config: configPersistDenyList,
ui: uiPersistDenylist,
uploads: uploadsPersistDenylist,
// hotkeys: hotkeysPersistDenylist,
};
export const serialize: SerializeFunction = (data, key) => {
const result = omit(data, serializationDenylist[key]);
return JSON.stringify(result);
};


@@ -0,0 +1,38 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
import { initialResultsState } from 'features/gallery/store/resultsSlice';
import { initialUploadsState } from 'features/gallery/store/uploadsSlice';
import { initialLightboxState } from 'features/lightbox/store/lightboxSlice';
import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
import { initialConfigState } from 'features/system/store/configSlice';
import { initialModelsState } from 'features/system/store/modelSlice';
import { initialSystemState } from 'features/system/store/systemSlice';
import { initialHotkeysState } from 'features/ui/store/hotkeysSlice';
import { initialUIState } from 'features/ui/store/uiSlice';
import { defaultsDeep } from 'lodash-es';
import { UnserializeFunction } from 'redux-remember';
const initialStates: {
[key: string]: any;
} = {
canvas: initialCanvasState,
gallery: initialGalleryState,
generation: initialGenerationState,
lightbox: initialLightboxState,
models: initialModelsState,
nodes: initialNodesState,
postprocessing: initialPostprocessingState,
results: initialResultsState,
system: initialSystemState,
config: initialConfigState,
ui: initialUIState,
uploads: initialUploadsState,
hotkeys: initialHotkeysState,
};
export const unserialize: UnserializeFunction = (data, key) => {
const result = defaultsDeep(JSON.parse(data), initialStates[key]);
return result;
};
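These two functions form redux-remember's round trip: `serialize` omits each slice's denylisted keys before stringifying, and `unserialize` deep-merges the stored JSON over that slice's initial state, so fields added in later versions fall back to their defaults instead of being `undefined`. A sketch of the round trip for one slice:

```ts
// persist: transient keys in systemPersistDenylist are stripped first
const stored = serialize(store.getState().system, 'system');

// rehydrate: keys missing from storage are filled from initialSystemState
const revived = unserialize(stored, 'system');
```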


@@ -0,0 +1,30 @@
import { AnyAction } from '@reduxjs/toolkit';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { forEach } from 'lodash-es';
import { Graph } from 'services/api';
export const actionSanitizer = <A extends AnyAction>(action: A): A => {
if (isAnyGraphBuilt(action)) {
if (action.payload.nodes) {
const sanitizedNodes: Graph['nodes'] = {};
// Sanitize nodes as needed
forEach(action.payload.nodes, (node, key) => {
// Don't log the whole freaking dataURL
if (node.type === 'dataURL_image') {
const { dataURL, ...rest } = node;
sanitizedNodes[key] = { ...rest, dataURL: '<dataURL>' };
} else {
sanitizedNodes[key] = { ...node };
}
});
return {
...action,
payload: { ...action.payload, nodes: sanitizedNodes },
};
}
}
return action;
};


@@ -0,0 +1,11 @@
export const actionsDenylist = [
'canvas/setCursorPosition',
'canvas/setStageCoordinates',
'canvas/setStageScale',
'canvas/setIsDrawing',
'canvas/setBoundingBoxCoordinates',
'canvas/setBoundingBoxDimensions',
'canvas/addPointToCurrentLine',
'socket/generatorProgress',
];


@@ -0,0 +1,3 @@
export const stateSanitizer = <S>(state: S): S => {
return state;
};


@@ -0,0 +1,54 @@
import {
createListenerMiddleware,
addListener,
ListenerEffect,
AnyAction,
} from '@reduxjs/toolkit';
import type { TypedStartListening, TypedAddListener } from '@reduxjs/toolkit';
import type { RootState, AppDispatch } from '../../store';
import { addInitialImageSelectedListener } from './listeners/initialImageSelected';
import { addImageResultReceivedListener } from './listeners/invocationComplete';
import { addImageUploadedListener } from './listeners/imageUploaded';
import { addRequestedImageDeletionListener } from './listeners/imageDeleted';
import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery';
import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage';
import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard';
import { addCanvasMergedListener } from './listeners/canvasMerged';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
export const startAppListening =
listenerMiddleware.startListening as AppStartListening;
export const addAppListener = addListener as TypedAddListener<
RootState,
AppDispatch
>;
export type AppListenerEffect = ListenerEffect<
AnyAction,
RootState,
AppDispatch
>;
addImageUploadedListener();
addInitialImageSelectedListener();
addImageResultReceivedListener();
addRequestedImageDeletionListener();
addUserInvokedCanvasListener();
addUserInvokedNodesListener();
addUserInvokedTextToImageListener();
addUserInvokedImageToImageListener();
addCanvasSavedToGalleryListener();
addCanvasDownloadedAsImageListener();
addCanvasCopiedToClipboardListener();
addCanvasMergedListener();
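`startAppListening` is the statically-typed entry point used by the listener modules that follow; `addAppListener` additionally allows registering a listener at runtime by dispatching it. A sketch of both (the `configChanged` example is illustrative, not part of this file):

```ts
import { store } from 'app/store/store';
import { configChanged } from 'features/system/store/configSlice';
import {
  addAppListener,
  startAppListening,
} from 'app/store/middleware/listenerMiddleware';

// static registration, as in the listener modules below
startAppListening({
  actionCreator: configChanged,
  effect: (action, { getState }) => {
    // runs after reducers; getState() returns the typed RootState
    console.log('config updated', getState().config);
  },
});

// dynamic registration, from anywhere with access to dispatch
store.dispatch(
  addAppListener({
    actionCreator: configChanged,
    effect: () => {
      // one-off or feature-scoped side effects
    },
  })
);
```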


@@ -0,0 +1,33 @@
import { canvasCopiedToClipboard } from 'features/canvas/store/actions';
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
import { copyBlobToClipboard } from 'features/canvas/util/copyBlobToClipboard';
const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' });
export const addCanvasCopiedToClipboardListener = () => {
startAppListening({
actionCreator: canvasCopiedToClipboard,
effect: async (action, { dispatch, getState }) => {
const state = getState();
const blob = await getBaseLayerBlob(state);
if (!blob) {
moduleLog.error('Problem getting base layer blob');
dispatch(
addToast({
title: 'Problem Copying Canvas',
description: 'Unable to export base layer',
status: 'error',
})
);
return;
}
copyBlobToClipboard(blob);
},
});
};


@@ -0,0 +1,33 @@
import { canvasDownloadedAsImage } from 'features/canvas/store/actions';
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { downloadBlob } from 'features/canvas/util/downloadBlob';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
export const addCanvasDownloadedAsImageListener = () => {
startAppListening({
actionCreator: canvasDownloadedAsImage,
effect: async (action, { dispatch, getState }) => {
const state = getState();
const blob = await getBaseLayerBlob(state);
if (!blob) {
moduleLog.error('Problem getting base layer blob');
dispatch(
addToast({
title: 'Problem Downloading Canvas',
description: 'Unable to export base layer',
status: 'error',
})
);
return;
}
downloadBlob(blob, 'mergedCanvas.png');
},
});
};


@@ -0,0 +1,88 @@
import { canvasMerged } from 'features/canvas/store/actions';
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
import { imageUploaded } from 'services/thunks/image';
import { v4 as uuidv4 } from 'uuid';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' });
export const addCanvasMergedListener = () => {
startAppListening({
actionCreator: canvasMerged,
effect: async (action, { dispatch, getState, take }) => {
const state = getState();
const blob = await getBaseLayerBlob(state, true);
if (!blob) {
moduleLog.error('Problem getting base layer blob');
dispatch(
addToast({
title: 'Problem Merging Canvas',
description: 'Unable to export base layer',
status: 'error',
})
);
return;
}
const canvasBaseLayer = getCanvasBaseLayer();
if (!canvasBaseLayer) {
moduleLog.error('Problem getting canvas base layer');
dispatch(
addToast({
title: 'Problem Merging Canvas',
description: 'Unable to export base layer',
status: 'error',
})
);
return;
}
const baseLayerRect = canvasBaseLayer.getClientRect({
relativeTo: canvasBaseLayer.getParent(),
});
const filename = `mergedCanvas_${uuidv4()}.png`;
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([blob], filename, { type: 'image/png' }),
},
})
);
const [{ payload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === filename
);
const mergedCanvasImage = deserializeImageResponse(payload.response);
dispatch(
setMergedCanvas({
kind: 'image',
layer: 'base',
image: mergedCanvasImage,
...baseLayerRect,
})
);
dispatch(
addToast({
title: 'Canvas Merged',
status: 'success',
})
);
},
});
};


@@ -0,0 +1,40 @@
import { canvasSavedToGallery } from 'features/canvas/store/actions';
import { startAppListening } from '..';
import { log } from 'app/logging/useLogger';
import { imageUploaded } from 'services/thunks/image';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' });
export const addCanvasSavedToGalleryListener = () => {
startAppListening({
actionCreator: canvasSavedToGallery,
effect: async (action, { dispatch, getState }) => {
const state = getState();
const blob = await getBaseLayerBlob(state);
if (!blob) {
moduleLog.error('Problem getting base layer blob');
dispatch(
addToast({
title: 'Problem Saving Canvas',
description: 'Unable to export base layer',
status: 'error',
})
);
return;
}
dispatch(
imageUploaded({
imageType: 'results',
formData: {
file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }),
},
})
);
},
});
};


@@ -0,0 +1,59 @@
import { requestedImageDeletion } from 'features/gallery/store/actions';
import { startAppListening } from '..';
import { imageDeleted } from 'services/thunks/image';
import { log } from 'app/logging/useLogger';
import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice';
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
export const addRequestedImageDeletionListener = () => {
startAppListening({
actionCreator: requestedImageDeletion,
effect: (action, { dispatch, getState }) => {
const image = action.payload;
if (!image) {
moduleLog.warn('No image provided');
return;
}
const { name, type } = image;
if (type !== 'uploads' && type !== 'results') {
moduleLog.warn({ data: image }, `Invalid image type ${type}`);
return;
}
const selectedImageName = getState().gallery.selectedImage?.name;
if (selectedImageName === name) {
const allIds = getState()[type].ids;
const allEntities = getState()[type].entities;
const deletedImageIndex = allIds.findIndex(
(result) => result.toString() === name
);
const filteredIds = allIds.filter((id) => id.toString() !== name);
const newSelectedImageIndex = clamp(
deletedImageIndex,
0,
filteredIds.length - 1
);
const newSelectedImageId = filteredIds[newSelectedImageIndex];
const newSelectedImage = allEntities[newSelectedImageId];
if (newSelectedImageId) {
dispatch(imageSelected(newSelectedImage));
} else {
dispatch(imageSelected());
}
}
dispatch(imageDeleted({ imageName: name, imageType: type }));
},
});
};


@@ -0,0 +1,46 @@
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import { startAppListening } from '..';
import { uploadAdded } from 'features/gallery/store/uploadsSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imageUploaded } from 'services/thunks/image';
import { addToast } from 'features/system/store/systemSlice';
import { initialImageSelected } from 'features/parameters/store/actions';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { resultAdded } from 'features/gallery/store/resultsSlice';
export const addImageUploadedListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.payload.response.image_type !== 'intermediates',
effect: (action, { dispatch, getState }) => {
const { response } = action.payload;
const { imageType } = action.meta.arg;
const state = getState();
const image = deserializeImageResponse(response);
if (imageType === 'uploads') {
dispatch(uploadAdded(image));
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
if (state.gallery.shouldAutoSwitchToNewImages) {
dispatch(imageSelected(image));
}
if (action.meta.arg.activeTabName === 'img2img') {
dispatch(initialImageSelected(image));
}
if (action.meta.arg.activeTabName === 'unifiedCanvas') {
dispatch(setInitialCanvasImage(image));
}
}
if (imageType === 'results') {
dispatch(resultAdded(image));
}
},
});
};


@@ -0,0 +1,54 @@
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { Image, isInvokeAIImage } from 'app/types/invokeai';
import { selectResultsById } from 'features/gallery/store/resultsSlice';
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
import { t } from 'i18next';
import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..';
import { initialImageSelected } from 'features/parameters/store/actions';
import { makeToast } from 'app/components/Toaster';
export const addInitialImageSelectedListener = () => {
startAppListening({
actionCreator: initialImageSelected,
effect: (action, { getState, dispatch }) => {
if (!action.payload) {
dispatch(
addToast(
makeToast({ title: t('toast.imageNotLoadedDesc'), status: 'error' })
)
);
return;
}
if (isInvokeAIImage(action.payload)) {
dispatch(initialImageChanged(action.payload));
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
return;
}
const { name, type } = action.payload;
let image: Image | undefined;
const state = getState();
if (type === 'results') {
image = selectResultsById(state, name);
} else if (type === 'uploads') {
image = selectUploadsById(state, name);
}
if (!image) {
dispatch(
addToast(
makeToast({ title: t('toast.imageNotLoadedDesc'), status: 'error' })
)
);
return;
}
dispatch(initialImageChanged(image));
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
},
});
};


@@ -0,0 +1,88 @@
import { invocationComplete } from 'services/events/actions';
import { isImageOutput } from 'services/types/guards';
import {
buildImageUrls,
extractTimestampFromImageName,
} from 'services/util/deserializeImageField';
import { Image } from 'app/types/invokeai';
import { resultAdded } from 'features/gallery/store/resultsSlice';
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
import { startAppListening } from '..';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
const nodeDenylist = ['dataURL_image'];
export const addImageResultReceivedListener = () => {
startAppListening({
predicate: (action) => {
if (
invocationComplete.match(action) &&
isImageOutput(action.payload.data.result)
) {
return true;
}
return false;
},
effect: (action, { getState, dispatch }) => {
if (!invocationComplete.match(action)) {
return;
}
const { data, shouldFetchImages } = action.payload;
const { result, node, graph_execution_state_id } = data;
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const name = result.image.image_name;
const type = result.image.image_type;
const state = getState();
// if we need to refetch, set URLs to placeholder for now
const { url, thumbnail } = shouldFetchImages
? { url: '', thumbnail: '' }
: buildImageUrls(type, name);
const timestamp = extractTimestampFromImageName(name);
const image: Image = {
name,
type,
url,
thumbnail,
metadata: {
created: timestamp,
width: result.width,
height: result.height,
invokeai: {
session_id: graph_execution_state_id,
...(node ? { node } : {}),
},
},
};
dispatch(resultAdded(image));
if (state.gallery.shouldAutoSwitchToNewImages) {
dispatch(imageSelected(image));
}
if (state.config.shouldFetchImages) {
dispatch(imageReceived({ imageName: name, imageType: type }));
dispatch(
thumbnailReceived({
thumbnailName: name,
thumbnailType: type,
})
);
}
if (
graph_execution_state_id ===
state.canvas.layerState.stagingArea.sessionId
) {
dispatch(addImageToStagingArea(image));
}
}
},
});
};


@@ -0,0 +1,164 @@
import { startAppListening } from '..';
import { sessionCreated, sessionInvoked } from 'services/thunks/session';
import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
import { log } from 'app/logging/useLogger';
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { imageUploaded } from 'services/thunks/image';
import { v4 as uuidv4 } from 'uuid';
import { Graph } from 'services/api';
import {
canvasSessionIdChanged,
stagingAreaInitialized,
} from 'features/canvas/store/canvasSlice';
import { userInvoked } from 'app/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode';
import { blobToDataURL } from 'features/canvas/util/blobToDataURL';
import openBase64ImageInTab from 'common/util/openBase64ImageInTab';
const moduleLog = log.child({ namespace: 'invoke' });
/**
* This listener is responsible for building the canvas graph and blobs when the user invokes the canvas.
* It is also responsible for uploading the base and mask layers to the server.
*/
export const addUserInvokedCanvasListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'unifiedCanvas',
effect: async (action, { getState, dispatch, take }) => {
const state = getState();
// Build canvas blobs
const canvasBlobsAndImageData = await getCanvasData(state);
if (!canvasBlobsAndImageData) {
moduleLog.error('Unable to create canvas data');
return;
}
const { baseBlob, baseImageData, maskBlob, maskImageData } =
canvasBlobsAndImageData;
// Determine the generation mode
const generationMode = getCanvasGenerationMode(
baseImageData,
maskImageData
);
if (state.system.enableImageDebugging) {
const baseDataURL = await blobToDataURL(baseBlob);
const maskDataURL = await blobToDataURL(maskBlob);
openBase64ImageInTab([
{ base64: maskDataURL, caption: 'mask b64' },
{ base64: baseDataURL, caption: 'image b64' },
]);
}
moduleLog.debug(`Generation mode: ${generationMode}`);
// Build the canvas graph
const graphComponents = await buildCanvasGraphComponents(
state,
generationMode
);
if (!graphComponents) {
moduleLog.error('Problem building graph');
return;
}
const { rangeNode, iterateNode, baseNode, edges } = graphComponents;
// Upload the base layer, to be used as init image
const baseFilename = `${uuidv4()}.png`;
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([baseBlob], baseFilename, { type: 'image/png' }),
},
})
);
if (baseNode.type === 'img2img' || baseNode.type === 'inpaint') {
const [{ payload: basePayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === baseFilename
);
const { image_name: baseName, image_type: baseType } =
basePayload.response;
baseNode.image = {
image_name: baseName,
image_type: baseType,
};
}
// Upload the mask layer image
const maskFilename = `${uuidv4()}.png`;
if (baseNode.type === 'inpaint') {
dispatch(
imageUploaded({
imageType: 'intermediates',
formData: {
file: new File([maskBlob], maskFilename, { type: 'image/png' }),
},
})
);
const [{ payload: maskPayload }] = await take(
(action): action is ReturnType<typeof imageUploaded.fulfilled> =>
imageUploaded.fulfilled.match(action) &&
action.meta.arg.formData.file.name === maskFilename
);
const { image_name: maskName, image_type: maskType } =
maskPayload.response;
baseNode.mask = {
image_name: maskName,
image_type: maskType,
};
}
// Assemble!
const nodes: Graph['nodes'] = {
[rangeNode.id]: rangeNode,
[iterateNode.id]: iterateNode,
[baseNode.id]: baseNode,
};
const graph = { nodes, edges };
dispatch(canvasGraphBuilt(graph));
moduleLog({ data: graph }, 'Canvas graph built');
// Actually create the session
dispatch(sessionCreated({ graph }));
// Wait for the session to be invoked (this is just the HTTP request to start processing)
const [{ meta }] = await take(sessionInvoked.fulfilled.match);
const { sessionId } = meta.arg;
if (!state.canvas.layerState.stagingArea.boundingBox) {
dispatch(
stagingAreaInitialized({
sessionId,
boundingBox: {
...state.canvas.boundingBoxCoordinates,
...state.canvas.boundingBoxDimensions,
},
})
);
}
dispatch(canvasSessionIdChanged(sessionId));
},
});
};
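A note on the dispatch-then-`take` pattern used twice above: `take` resolves to a `[action, currentState, previousState]` tuple (or `null` when a timeout argument expires), which is why the code destructures `const [{ payload }] = await take(...)`. The unique filename baked into the upload is what lets the type predicate pick the matching fulfilled action out of the action stream:

```ts
// Sketch, inside a listener effect where dispatch/take are provided by the
// listener API; `blob` stands for the layer blob built earlier in the effect.
const filename = `${uuidv4()}.png`;

dispatch(
  imageUploaded({
    imageType: 'intermediates',
    formData: { file: new File([blob], filename, { type: 'image/png' }) },
  })
);

// resolves once an action satisfying the predicate is dispatched
const [{ payload }] = await take(
  (action): action is ReturnType<typeof imageUploaded.fulfilled> =>
    imageUploaded.fulfilled.match(action) &&
    action.meta.arg.formData.file.name === filename
);
```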


@@ -0,0 +1,24 @@
import { startAppListening } from '..';
import { buildImageToImageGraph } from 'features/nodes/util/graphBuilders/buildImageToImageGraph';
import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger';
import { imageToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
export const addUserInvokedImageToImageListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'img2img',
effect: (action, { getState, dispatch }) => {
const state = getState();
const graph = buildImageToImageGraph(state);
dispatch(imageToImageGraphBuilt(graph));
moduleLog({ data: graph }, 'Image to Image graph built');
dispatch(sessionCreated({ graph }));
},
});
};


@@ -0,0 +1,24 @@
import { startAppListening } from '..';
import { sessionCreated } from 'services/thunks/session';
import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph';
import { log } from 'app/logging/useLogger';
import { nodesGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
export const addUserInvokedNodesListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'nodes',
effect: (action, { getState, dispatch }) => {
const state = getState();
const graph = buildNodesGraph(state);
dispatch(nodesGraphBuilt(graph));
moduleLog({ data: graph }, 'Nodes graph built');
dispatch(sessionCreated({ graph }));
},
});
};


@@ -0,0 +1,24 @@
import { startAppListening } from '..';
import { buildTextToImageGraph } from 'features/nodes/util/graphBuilders/buildTextToImageGraph';
import { sessionCreated } from 'services/thunks/session';
import { log } from 'app/logging/useLogger';
import { textToImageGraphBuilt } from 'features/nodes/store/actions';
import { userInvoked } from 'app/store/actions';
const moduleLog = log.child({ namespace: 'invoke' });
export const addUserInvokedTextToImageListener = () => {
startAppListening({
predicate: (action): action is ReturnType<typeof userInvoked> =>
userInvoked.match(action) && action.payload === 'txt2img',
effect: (action, { getState, dispatch }) => {
const state = getState();
const graph = buildTextToImageGraph(state);
dispatch(textToImageGraphBuilt(graph));
moduleLog({ data: graph }, 'Text to Image graph built');
dispatch(sessionCreated({ graph }));
},
});
};


@@ -1,4 +0,0 @@
import { store } from 'app/store/store';
import { persistStore } from 'redux-persist';
export const persistor = persistStore(store);


@@ -1,9 +1,12 @@
import { combineReducers, configureStore } from '@reduxjs/toolkit';
import {
AnyAction,
ThunkDispatch,
combineReducers,
configureStore,
} from '@reduxjs/toolkit';
import { persistReducer } from 'redux-persist';
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
import { rememberReducer, rememberEnhancer } from 'redux-remember';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { getPersistConfig } from 'redux-deep-persist';
import canvasReducer from 'features/canvas/store/canvasSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
@@ -19,33 +22,17 @@ import hotkeysReducer from 'features/ui/store/hotkeysSlice';
import modelsReducer from 'features/system/store/modelSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import { canvasDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { galleryDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { generationDenylist } from 'features/parameters/store/generationPersistDenylist';
import { lightboxDenylist } from 'features/lightbox/store/lightboxPersistDenylist';
import { modelsDenylist } from 'features/system/store/modelsPersistDenylist';
import { nodesDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { postprocessingDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import { systemDenylist } from 'features/system/store/systemPersistDenylist';
import { uiDenylist } from 'features/ui/store/uiPersistDenylist';
import { resultsDenylist } from 'features/gallery/store/resultsPersistDenylist';
import { uploadsDenylist } from 'features/gallery/store/uploadsPersistDenylist';
import { listenerMiddleware } from './middleware/listenerMiddleware';
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
*
* While we definitely want generation parameters to be persisted, there are a number
* of things we do *not* want to be persisted across reloads:
* - Gallery/selected image (user may add/delete images from disk between page loads)
* - Connection/processing status
* - Availability of external libraries like ESRGAN/GFPGAN
*
* These can be denylisted in redux-persist.
*
* The necessary nested persistors with denylists are configured below.
*/
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
const rootReducer = combineReducers({
import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize';
import { LOCALSTORAGE_PREFIX } from './constants';
const allReducers = {
canvas: canvasReducer,
gallery: galleryReducer,
generation: generationReducer,
@@ -59,65 +46,54 @@ const rootReducer = combineReducers({
ui: uiReducer,
uploads: uploadsReducer,
hotkeys: hotkeysReducer,
});
};
const rootPersistConfig = getPersistConfig({
key: 'root',
storage,
rootReducer,
blacklist: [
...canvasDenylist,
...galleryDenylist,
...generationDenylist,
...lightboxDenylist,
...modelsDenylist,
...nodesDenylist,
...postprocessingDenylist,
// ...resultsDenylist,
'results',
...systemDenylist,
...uiDenylist,
// ...uploadsDenylist,
'uploads',
'hotkeys',
'config',
],
});
const rootReducer = combineReducers(allReducers);
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
const rememberedRootReducer = rememberReducer(rootReducer);
// TODO: rip the old middleware out when nodes is complete
// export function buildMiddleware() {
// if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
// return socketMiddleware();
// } else {
// return socketioMiddleware();
// }
// }
const rememberedKeys: (keyof typeof allReducers)[] = [
'canvas',
'gallery',
'generation',
'lightbox',
// 'models',
'nodes',
'postprocessing',
'system',
'ui',
// 'hotkeys',
// 'results',
// 'uploads',
// 'config',
];
export const store = configureStore({
reducer: persistedReducer,
reducer: rememberedRootReducer,
enhancers: [
rememberEnhancer(window.localStorage, rememberedKeys, {
persistDebounce: 300,
serialize,
unserialize,
prefix: LOCALSTORAGE_PREFIX,
}),
],
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
immutableCheck: false,
serializableCheck: false,
}).concat(dynamicMiddlewares),
})
.concat(dynamicMiddlewares)
.prepend(listenerMiddleware.middleware),
devTools: {
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
actionsDenylist: [
'canvas/setCursorPosition',
'canvas/setStageCoordinates',
'canvas/setStageScale',
'canvas/setIsDrawing',
'canvas/setBoundingBoxCoordinates',
'canvas/setBoundingBoxDimensions',
'canvas/setIsDrawing',
'canvas/addPointToCurrentLine',
'socket/generatorProgress',
],
actionsDenylist,
actionSanitizer,
stateSanitizer,
trace: true,
},
});
export type AppGetState = typeof store.getState;
export type RootState = ReturnType<typeof store.getState>;
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
export type AppDispatch = typeof store.dispatch;


@@ -1,6 +1,6 @@
import { TypedUseSelectorHook, useDispatch, useSelector } from 'react-redux';
import { AppDispatch, RootState } from 'app/store/store';
import { AppThunkDispatch, RootState } from 'app/store/store';
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch: () => AppDispatch = useDispatch;
export const useAppDispatch = () => useDispatch<AppThunkDispatch>();
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
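Typing the dispatch hook with `AppThunkDispatch` instead of the store's inferred dispatch type lets components dispatch async thunks without casts. Illustrative usage with a thunk from this diff (the component is hypothetical):

```tsx
import { useAppDispatch } from 'app/store/storeHooks';
import { imageDeleted } from 'services/thunks/image';

const DeleteButton = ({ name }: { name: string }) => {
  const dispatch = useAppDispatch();

  // thunks now type-check against AppThunkDispatch, no casting needed
  const onClick = () =>
    dispatch(imageDeleted({ imageName: name, imageType: 'results' }));

  return <button onClick={onClick}>Delete</button>;
};
```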


@@ -0,0 +1,7 @@
import { isEqual } from 'lodash-es';
export const defaultSelectorOptions = {
memoizeOptions: {
resultEqualityCheck: isEqual,
},
};
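This is the shared memoization config adopted by `readinessSelector` above: a deep `resultEqualityCheck` means a selector that builds a fresh object each run still returns the previous reference while the contents are deep-equal, avoiding spurious re-renders. Sketch:

```ts
import { createSelector } from '@reduxjs/toolkit';
import { systemSelector } from 'features/system/store/systemSelectors';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';

const busySelector = createSelector(
  [systemSelector],
  // fresh object every run, but the same reference is handed back
  // to subscribers while it stays deep-equal to the last result
  (system) => ({ isBusy: system.isProcessing }),
  defaultSelectorOptions
);
```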


@@ -12,12 +12,10 @@
* 'gfpgan'.
*/
import { GalleryCategory } from 'features/gallery/store/gallerySlice';
import { FacetoolType } from 'features/parameters/store/postprocessingSlice';
import { SelectedImage } from 'features/parameters/store/actions';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { ImageResponseMetadata, ImageType } from 'services/api';
import { AnyInvocation } from 'services/events/types';
import { O } from 'ts-toolbelt';
/**
@@ -49,15 +47,21 @@ export type CommonGeneratedImageMetadata = {
postprocessing: null | Array<ESRGANMetadata | FacetoolMetadata>;
   sampler:
     | 'ddim'
-    | 'k_dpm_2_a'
-    | 'k_dpm_2'
-    | 'k_dpmpp_2_a'
-    | 'k_dpmpp_2'
-    | 'k_euler_a'
-    | 'k_euler'
-    | 'k_heun'
-    | 'k_lms'
-    | 'plms';
+    | 'ddpm'
+    | 'deis'
+    | 'lms'
+    | 'pndm'
+    | 'heun'
+    | 'heun_k'
+    | 'euler'
+    | 'euler_k'
+    | 'euler_a'
+    | 'kdpm_2'
+    | 'kdpm_2_a'
+    | 'dpmpp_2s'
+    | 'dpmpp_2m'
+    | 'dpmpp_2m_k'
+    | 'unipc';
prompt: Prompt;
seed: number;
variations: SeedWeights;
@@ -126,6 +130,14 @@ export type Image = {
metadata: ImageResponseMetadata;
};
export const isInvokeAIImage = (obj: Image | SelectedImage): obj is Image => {
if ('url' in obj && 'thumbnail' in obj) {
return true;
}
return false;
};
/**
* Types related to the system status.
*/
@@ -270,7 +282,7 @@ export type FoundModelResponse = {
// export type SystemConfigResponse = SystemConfig;
export type ImageResultResponse = Omit<_Image, 'uuid'> & {
export type ImageResultResponse = Omit<Image, 'uuid'> & {
boundingBox?: IRect;
generationMode: InvokeTabName;
};
@@ -315,11 +327,11 @@ export type AppFeature =
/**
* A disable-able Stable Diffusion feature
*/
-export type StableDiffusionFeature =
-  | 'noiseConfig'
-  | 'variations'
+export type SDFeature =
+  | 'noise'
+  | 'variation'
| 'symmetry'
| 'tiling'
| 'seamless'
| 'hires';
/**
@@ -337,6 +349,7 @@ export type AppConfig = {
shouldFetchImages: boolean;
disabledTabs: InvokeTabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];
canRestoreDeletedImagesFromBin: boolean;
sd: {
iterations: {
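For illustration, the `isInvokeAIImage` guard added above narrows the `Image | SelectedImage` union; a hypothetical handler:

```ts
import { Image, isInvokeAIImage } from 'app/types/invokeai';
import { SelectedImage } from 'features/parameters/store/actions';

const resolveThumbnail = (selection: Image | SelectedImage): string | undefined => {
  if (isInvokeAIImage(selection)) {
    // narrowed to Image, so url and thumbnail are available
    return selection.thumbnail;
  }
  // otherwise it is only a SelectedImage reference
  return undefined;
};
```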

View File

@@ -0,0 +1,61 @@
import { ChevronUpIcon } from '@chakra-ui/icons';
import { Box, Collapse, Flex, Spacer, Switch } from '@chakra-ui/react';
import { PropsWithChildren, memo } from 'react';
export type IAIToggleCollapseProps = PropsWithChildren & {
label: string;
isOpen: boolean;
onToggle: () => void;
withSwitch?: boolean;
};
const IAICollapse = (props: IAIToggleCollapseProps) => {
const { label, isOpen, onToggle, children, withSwitch = false } = props;
return (
<Box>
<Flex
onClick={onToggle}
sx={{
alignItems: 'center',
p: 2,
px: 4,
borderTopRadius: 'base',
borderBottomRadius: isOpen ? 0 : 'base',
bg: isOpen ? 'base.750' : 'base.800',
color: 'base.100',
_hover: {
bg: isOpen ? 'base.700' : 'base.750',
},
fontSize: 'sm',
fontWeight: 600,
cursor: 'pointer',
transitionProperty: 'common',
transitionDuration: 'normal',
userSelect: 'none',
}}
>
{label}
<Spacer />
{withSwitch && <Switch isChecked={isOpen} pointerEvents="none" />}
{!withSwitch && (
<ChevronUpIcon
sx={{
w: '1rem',
h: '1rem',
transform: isOpen ? 'rotate(0deg)' : 'rotate(180deg)',
transitionProperty: 'common',
transitionDuration: 'normal',
}}
/>
)}
</Flex>
<Collapse in={isOpen} animateOpacity>
<Box sx={{ p: 4, borderBottomRadius: 'base', bg: 'base.800' }}>
{children}
</Box>
</Collapse>
</Box>
);
};
export default memo(IAICollapse);
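Typical usage pairs the component with Chakra's `useDisclosure`; a sketch (the panel name and contents are illustrative):

```tsx
import { useDisclosure } from '@chakra-ui/react';
import IAICollapse from 'common/components/IAICollapse'; // path assumed

const SeamlessSettings = () => {
  const { isOpen, onToggle } = useDisclosure();
  return (
    <IAICollapse label="Seamless Tiling" isOpen={isOpen} onToggle={onToggle} withSwitch>
      {/* parameter controls go here */}
    </IAICollapse>
  );
};
```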

View File

@@ -0,0 +1,172 @@
import { CheckIcon } from '@chakra-ui/icons';
import {
Box,
Flex,
FlexProps,
FormControl,
FormControlProps,
FormLabel,
Grid,
GridItem,
List,
ListItem,
Select,
Text,
Tooltip,
TooltipProps,
} from '@chakra-ui/react';
import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
import { useSelect } from 'downshift';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo } from 'react';
type IAICustomSelectProps = {
label?: string;
items: string[];
selectedItem: string;
setSelectedItem: (v: string | null | undefined) => void;
withCheckIcon?: boolean;
formControlProps?: FormControlProps;
buttonProps?: FlexProps;
tooltip?: string;
tooltipProps?: Omit<TooltipProps, 'children'>;
};
const IAICustomSelect = (props: IAICustomSelectProps) => {
const {
label,
items,
setSelectedItem,
selectedItem,
withCheckIcon,
formControlProps,
tooltip,
buttonProps,
tooltipProps,
} = props;
const {
isOpen,
getToggleButtonProps,
getLabelProps,
getMenuProps,
highlightedIndex,
getItemProps,
} = useSelect({
items,
selectedItem,
onSelectedItemChange: ({ selectedItem: newSelectedItem }) =>
setSelectedItem(newSelectedItem),
});
const { refs, floatingStyles } = useFloating<HTMLButtonElement>({
whileElementsMounted: autoUpdate,
middleware: [offset(4), shift({ crossAxis: true, padding: 8 })],
});
return (
<FormControl sx={{ w: 'full' }} {...formControlProps}>
{label && (
<FormLabel
{...getLabelProps()}
onClick={() => {
refs.floating.current && refs.floating.current.focus();
}}
>
{label}
</FormLabel>
)}
<Tooltip label={tooltip} {...tooltipProps}>
<Select
{...getToggleButtonProps({ ref: refs.setReference })}
{...buttonProps}
as={Flex}
sx={{
alignItems: 'center',
userSelect: 'none',
cursor: 'pointer',
}}
>
<Text sx={{ fontSize: 'sm', fontWeight: 500, color: 'base.100' }}>
{selectedItem}
</Text>
</Select>
</Tooltip>
<Box {...getMenuProps()}>
{isOpen && (
<List
as={Flex}
ref={refs.setFloating}
sx={{
...floatingStyles,
width: 'max-content',
top: 0,
left: 0,
flexDirection: 'column',
zIndex: 1,
bg: 'base.800',
borderRadius: 'base',
border: '1px',
borderColor: 'base.700',
shadow: 'dark-lg',
py: 2,
px: 0,
h: 'fit-content',
maxH: 64,
}}
>
<OverlayScrollbarsComponent>
{items.map((item, index) => (
<ListItem
sx={{
bg: highlightedIndex === index ? 'base.700' : undefined,
py: 1,
paddingInlineStart: 3,
paddingInlineEnd: 6,
cursor: 'pointer',
transitionProperty: 'common',
transitionDuration: '0.15s',
}}
key={`${item}${index}`}
{...getItemProps({ item, index })}
>
{withCheckIcon ? (
<Grid gridTemplateColumns="1.25rem auto">
<GridItem>
{selectedItem === item && <CheckIcon boxSize={2} />}
</GridItem>
<GridItem>
<Text
sx={{
fontSize: 'sm',
color: 'base.100',
fontWeight: 500,
}}
>
{item}
</Text>
</GridItem>
</Grid>
) : (
<Text
sx={{
fontSize: 'sm',
color: 'base.100',
fontWeight: 500,
}}
>
{item}
</Text>
)}
</ListItem>
))}
</OverlayScrollbarsComponent>
</List>
)}
</Box>
</FormControl>
);
};
export default memo(IAICustomSelect);
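A usage sketch wiring the select to local state (the scheduler list is a subset of the names from the metadata type above):

```tsx
import { useState } from 'react';
import IAICustomSelect from 'common/components/IAICustomSelect'; // path assumed

const SCHEDULERS = ['ddim', 'euler', 'euler_a', 'unipc'];

const SchedulerSelect = () => {
  const [scheduler, setScheduler] = useState('euler');
  return (
    <IAICustomSelect
      label="Scheduler"
      items={SCHEDULERS}
      selectedItem={scheduler}
      setSelectedItem={(v) => {
        if (v) setScheduler(v);
      }}
      withCheckIcon
    />
  );
};
```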

View File

@@ -5,6 +5,7 @@ import {
Input,
InputProps,
} from '@chakra-ui/react';
import { stopPastePropagation } from 'common/util/stopPastePropagation';
import { ChangeEvent, memo } from 'react';
interface IAIInputProps extends InputProps {
@@ -31,7 +32,7 @@ const IAIInput = (props: IAIInputProps) => {
{...formControlProps}
>
{label !== '' && <FormLabel>{label}</FormLabel>}
<Input {...rest} />
<Input {...rest} onPaste={stopPastePropagation} />
</FormControl>
);
};
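`stopPastePropagation` is imported here (and in the number input and textarea below) but not shown in this diff. Presumably it keeps a paste inside a field from bubbling up to the document-level paste listener that the image uploader installs; a minimal sketch under that assumption:

```ts
import type { ClipboardEvent } from 'react';

// Hypothetical body for common/util/stopPastePropagation: stop the paste
// event here so the global document 'paste' handler (which uploads images)
// never sees pastes aimed at text fields.
export const stopPastePropagation = (e: ClipboardEvent) => {
  e.stopPropagation();
};
```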

View File

@@ -14,6 +14,7 @@ import {
Tooltip,
TooltipProps,
} from '@chakra-ui/react';
import { stopPastePropagation } from 'common/util/stopPastePropagation';
import { clamp } from 'lodash-es';
import { FocusEvent, memo, useEffect, useState } from 'react';
@@ -125,6 +126,7 @@ const IAINumberInput = (props: Props) => {
onChange={handleOnChange}
onBlur={handleBlur}
{...rest}
onPaste={stopPastePropagation}
>
<NumberInputField {...numberInputFieldProps} />
{showStepper && (

View File

@@ -27,7 +27,7 @@ const IAIPopover = (props: IAIPopoverProps) => {
return (
<Popover isLazy={isLazy} {...rest}>
<PopoverTrigger>{triggerComponent}</PopoverTrigger>
<PopoverContent>
<PopoverContent shadow="dark-lg">
{hasArrow && <PopoverArrow />}
{children}
</PopoverContent>

View File

@@ -0,0 +1,9 @@
import { Textarea, TextareaProps, forwardRef } from '@chakra-ui/react';
import { stopPastePropagation } from 'common/util/stopPastePropagation';
import { memo } from 'react';
const IAITextarea = forwardRef((props: TextareaProps, ref) => {
return <Textarea ref={ref} onPaste={stopPastePropagation} {...props} />;
});
export default memo(IAITextarea);
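An illustrative drop-in usage (the prompt field wrapper is hypothetical):

```tsx
import IAITextarea from 'common/components/IAITextarea'; // path assumed

const PromptField = (props: { prompt: string; onChange: (v: string) => void }) => (
  // pastes into the textarea stay local instead of triggering image upload
  <IAITextarea
    value={props.prompt}
    onChange={(e) => props.onChange(e.target.value)}
    rows={4}
  />
);
```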

View File

@@ -0,0 +1,54 @@
import { Badge, Flex } from '@chakra-ui/react';
import { Image } from 'app/types/invokeai';
import { isNumber, isString } from 'lodash-es';
import { useMemo } from 'react';
type ImageMetadataOverlayProps = {
image: Image;
};
const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
const dimensions = useMemo(() => {
    if (!isNumber(image.metadata?.width) || !isNumber(image.metadata?.height)) {
return;
}
return `${image.metadata?.width} × ${image.metadata?.height}`;
}, [image.metadata]);
const model = useMemo(() => {
if (!isString(image.metadata?.invokeai?.node?.model)) {
return;
}
return image.metadata?.invokeai?.node?.model;
}, [image.metadata]);
return (
<Flex
sx={{
pointerEvents: 'none',
flexDirection: 'column',
position: 'absolute',
top: 0,
right: 0,
p: 2,
alignItems: 'flex-end',
gap: 2,
}}
>
{dimensions && (
<Badge variant="solid" colorScheme="base">
{dimensions}
</Badge>
)}
{model && (
<Badge variant="solid" colorScheme="base">
{model}
</Badge>
)}
</Flex>
);
};
export default ImageMetadataOverlay;
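The overlay positions itself absolutely, so it expects a positioned ancestor; an illustrative mount (import paths are assumptions):

```tsx
import { Box, Image as ChakraImage } from '@chakra-ui/react';
import { Image } from 'app/types/invokeai';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';

const CurrentImage = ({ image }: { image: Image }) => (
  // position: relative so the absolutely-positioned badges anchor here
  <Box sx={{ position: 'relative', w: 'full', h: 'full' }}>
    <ChakraImage src={image.url} />
    <ImageMetadataOverlay image={image} />
  </Box>
);
```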

View File

@@ -1,36 +0,0 @@
import { Badge, Box, Flex } from '@chakra-ui/react';
import { Image } from 'app/types/invokeai';
type ImageToImageOverlayProps = {
image: Image;
};
const ImageToImageOverlay = ({ image }: ImageToImageOverlayProps) => {
return (
<Box
sx={{
top: 0,
left: 0,
w: 'full',
h: 'full',
position: 'absolute',
}}
>
<Flex
sx={{
position: 'absolute',
top: 0,
right: 0,
p: 2,
alignItems: 'flex-start',
}}
>
<Badge variant="solid" colorScheme="base">
{image.metadata?.width} × {image.metadata?.height}
</Badge>
</Flex>
</Box>
);
};
export default ImageToImageOverlay;

View File

@@ -1,4 +1,4 @@
import { Box, useToast } from '@chakra-ui/react';
import { Box } from '@chakra-ui/react';
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import useImageUploader from 'common/hooks/useImageUploader';
@@ -10,12 +10,33 @@ import {
ReactNode,
useCallback,
useEffect,
useMemo,
useRef,
useState,
} from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { imageUploaded } from 'services/thunks/image';
import ImageUploadOverlay from './ImageUploadOverlay';
import { useAppToaster } from 'app/components/Toaster';
import { filter, map, some } from 'lodash-es';
import { createSelector } from '@reduxjs/toolkit';
import { systemSelector } from 'features/system/store/systemSelectors';
import { ErrorCode } from 'react-dropzone';
const selector = createSelector(
[systemSelector, activeTabNameSelector],
(system, activeTabName) => {
const { isConnected, isUploading } = system;
const isUploaderDisabled = !isConnected || isUploading;
return {
isUploaderDisabled,
activeTabName,
};
}
);
type ImageUploaderProps = {
children: ReactNode;
@@ -24,38 +45,49 @@ type ImageUploaderProps = {
const ImageUploader = (props: ImageUploaderProps) => {
const { children } = props;
const dispatch = useAppDispatch();
const activeTabName = useAppSelector(activeTabNameSelector);
const toast = useToast({});
const { isUploaderDisabled, activeTabName } = useAppSelector(selector);
const toaster = useAppToaster();
const { t } = useTranslation();
const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false);
const { setOpenUploader } = useImageUploader();
const { setOpenUploaderFunction } = useImageUploader();
const fileRejectionCallback = useCallback(
(rejection: FileRejection) => {
setIsHandlingUpload(true);
const msg = rejection.errors.reduce(
(acc: string, cur: { message: string }) => `${acc}\n${cur.message}`,
''
);
toast({
toaster({
title: t('toast.uploadFailed'),
description: msg,
description: rejection.errors.map((error) => error.message).join('\n'),
status: 'error',
isClosable: true,
});
},
[t, toast]
[t, toaster]
);
const fileAcceptedCallback = useCallback(
async (file: File) => {
dispatch(imageUploaded({ formData: { file } }));
dispatch(
imageUploaded({
imageType: 'uploads',
formData: { file },
activeTabName,
})
);
},
[dispatch]
[dispatch, activeTabName]
);
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
if (fileRejections.length > 1) {
toaster({
title: t('toast.uploadFailed'),
description: t('toast.uploadFailedInvalidUploadDesc'),
status: 'error',
});
return;
}
fileRejections.forEach((rejection: FileRejection) => {
fileRejectionCallback(rejection);
});
@@ -64,7 +96,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
fileAcceptedCallback(file);
});
},
[fileAcceptedCallback, fileRejectionCallback]
[t, toaster, fileAcceptedCallback, fileRejectionCallback]
);
const {
@@ -73,92 +105,73 @@ const ImageUploader = (props: ImageUploaderProps) => {
isDragAccept,
isDragReject,
isDragActive,
inputRef,
open,
} = useDropzone({
    accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg'] },
noClick: true,
onDrop,
onDragOver: () => setIsHandlingUpload(true),
maxFiles: 1,
disabled: isUploaderDisabled,
multiple: false,
});
setOpenUploader(open);
   useEffect(() => {
-    const pasteImageListener = (e: ClipboardEvent) => {
-      const dataTransferItemList = e.clipboardData?.items;
-      if (!dataTransferItemList) return;
-      const imageItems: Array<DataTransferItem> = [];
-      for (const item of dataTransferItemList) {
-        if (
-          item.kind === 'file' &&
-          ['image/png', 'image/jpg'].includes(item.type)
-        ) {
-          imageItems.push(item);
-        }
-      }
-      if (!imageItems.length) return;
-      e.stopImmediatePropagation();
-      if (imageItems.length > 1) {
-        toast({
-          description: t('toast.uploadFailedMultipleImagesDesc'),
-          status: 'error',
-          isClosable: true,
-        });
-        return;
-      }
-      const file = imageItems[0].getAsFile();
-      if (!file) {
-        toast({
-          description: t('toast.uploadFailedUnableToLoadDesc'),
-          status: 'error',
-          isClosable: true,
-        });
-        return;
-      }
-      dispatch(imageUploaded({ formData: { file } }));
-    };
-    document.addEventListener('paste', pasteImageListener);
+    // This is a hack to allow pasting images into the uploader
+    const handlePaste = async (e: ClipboardEvent) => {
+      if (!inputRef.current) {
+        return;
+      }
+      if (e.clipboardData?.files) {
+        // Set the files on the inputRef
+        inputRef.current.files = e.clipboardData.files;
+        // Dispatch the change event, dropzone catches this and we get to use its own validation
+        inputRef.current?.dispatchEvent(new Event('change', { bubbles: true }));
+      }
+    };
+    // Set the open function so we can open the uploader from anywhere
+    setOpenUploaderFunction(open);
+    // Add the paste event listener
+    document.addEventListener('paste', handlePaste);
     return () => {
-      document.removeEventListener('paste', pasteImageListener);
+      document.removeEventListener('paste', handlePaste);
+      setOpenUploaderFunction(() => {
+        return;
+      });
     };
-  }, [t, dispatch, toast, activeTabName]);
+  }, [inputRef, open, setOpenUploaderFunction]);
const overlaySecondaryText = ['img2img', 'unifiedCanvas'].includes(
activeTabName
)
? ` to ${String(t(`common.${activeTabName}` as ResourceKey))}`
: ``;
const overlaySecondaryText = useMemo(() => {
if (['img2img', 'unifiedCanvas'].includes(activeTabName)) {
return ` to ${String(t(`common.${activeTabName}` as ResourceKey))}`;
}
return '';
}, [t, activeTabName]);
return (
<ImageUploaderTriggerContext.Provider value={open}>
<Box
{...getRootProps({ style: {} })}
onKeyDown={(e: KeyboardEvent) => {
// Bail out if user hits spacebar - do not open the uploader
if (e.key === ' ') return;
}}
>
<input {...getInputProps()} />
{children}
{isDragActive && isHandlingUpload && (
<ImageUploadOverlay
isDragAccept={isDragAccept}
isDragReject={isDragReject}
overlaySecondaryText={overlaySecondaryText}
setIsHandlingUpload={setIsHandlingUpload}
/>
)}
</Box>
</ImageUploaderTriggerContext.Provider>
<Box
{...getRootProps({ style: {} })}
onKeyDown={(e: KeyboardEvent) => {
// Bail out if user hits spacebar - do not open the uploader
if (e.key === ' ') return;
}}
>
<input {...getInputProps()} />
{children}
{isDragActive && isHandlingUpload && (
<ImageUploadOverlay
isDragAccept={isDragAccept}
isDragReject={isDragReject}
overlaySecondaryText={overlaySecondaryText}
setIsHandlingUpload={setIsHandlingUpload}
/>
)}
</Box>
);
};
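The paste handler above is worth calling out: instead of dispatching an upload directly, it assigns the clipboard's `FileList` to the hidden file input and fires a synthetic `change` event, so dropzone's own validation runs. The same trick works outside React; a minimal vanilla sketch:

```ts
const input = document.querySelector<HTMLInputElement>('input[type="file"]');

document.addEventListener('paste', (e: ClipboardEvent) => {
  const files = e.clipboardData?.files;
  if (!input || !files || files.length === 0) {
    return;
  }
  // modern browsers allow assigning a FileList directly to a file input
  input.files = files;
  // existing change-driven validation (e.g. dropzone's) now runs as if the
  // user had picked the files manually
  input.dispatchEvent(new Event('change', { bubbles: true }));
});
```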

View File

@@ -1,6 +1,5 @@
import { Flex, Heading, Icon } from '@chakra-ui/react';
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
import { useContext } from 'react';
import useImageUploader from 'common/hooks/useImageUploader';
import { FaUpload } from 'react-icons/fa';
type ImageUploaderButtonProps = {
@@ -9,11 +8,7 @@ type ImageUploaderButtonProps = {
const ImageUploaderButton = (props: ImageUploaderButtonProps) => {
const { styleClass } = props;
const open = useContext(ImageUploaderTriggerContext);
const handleClickUpload = () => {
open && open();
};
const { openUploader } = useImageUploader();
return (
<Flex
@@ -26,7 +21,7 @@ const ImageUploaderButton = (props: ImageUploaderButtonProps) => {
className={styleClass}
>
<Flex
onClick={handleClickUpload}
onClick={openUploader}
sx={{
display: 'flex',
flexDirection: 'column',

View File

@@ -1,19 +1,18 @@
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
import { FaUpload } from 'react-icons/fa';
import IAIIconButton from './IAIIconButton';
import useImageUploader from 'common/hooks/useImageUploader';
const ImageUploaderIconButton = () => {
const { t } = useTranslation();
const openImageUploader = useContext(ImageUploaderTriggerContext);
const { openUploader } = useImageUploader();
return (
<IAIIconButton
aria-label={t('accessibility.uploadImage')}
tooltip="Upload Image"
icon={<FaUpload />}
onClick={openImageUploader || undefined}
onClick={openUploader}
/>
);
};

View File

@@ -6,10 +6,12 @@ import { FaUndo, FaUpload } from 'react-icons/fa';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCallback } from 'react';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import useImageUploader from 'common/hooks/useImageUploader';
const ImageToImageSettingsHeader = () => {
const InitialImageButtons = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { openUploader } = useImageUploader();
const handleResetInitialImage = useCallback(() => {
dispatch(clearInitialImage());
@@ -18,19 +20,18 @@ const ImageToImageSettingsHeader = () => {
return (
<Flex w="full" alignItems="center">
<Text size="sm" fontWeight={500} color="base.300">
Image to Image
{t('parameters.initialImage')}
</Text>
<Spacer />
<ButtonGroup>
<IAIIconButton
size="sm"
icon={<FaUndo />}
aria-label={t('accessibility.reset')}
onClick={handleResetInitialImage}
/>
<IAIIconButton
size="sm"
icon={<FaUpload />}
onClick={openUploader}
aria-label={t('common.upload')}
/>
</ButtonGroup>
@@ -38,4 +39,4 @@ const ImageToImageSettingsHeader = () => {
);
};
export default ImageToImageSettingsHeader;
export default InitialImageButtons;

View File

@@ -24,7 +24,6 @@ const Loading = () => {
height="24px !important"
right="1.5rem"
bottom="1.5rem"
speed="1.2s"
/>
</Flex>
);

View File

@@ -7,7 +7,7 @@ const SelectImagePlaceholder = () => {
sx={{
w: 'full',
h: 'full',
bg: 'base.800',
// bg: 'base.800',
borderRadius: 'base',
alignItems: 'center',
justifyContent: 'center',

View File

@@ -1,29 +0,0 @@
import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import WorkInProgress from './WorkInProgress';
export const PostProcessingWIP = () => {
const { t } = useTranslation();
return (
<WorkInProgress>
<Flex
sx={{
flexDirection: 'column',
alignItems: 'center',
justifyContent: 'center',
w: '100%',
h: '100%',
gap: 4,
textAlign: 'center',
}}
>
<Heading>{t('common.postProcessing')}</Heading>
<VStack maxW="50rem" gap={4}>
<Text>{t('common.postProcessDesc1')}</Text>
<Text>{t('common.postProcessDesc2')}</Text>
<Text>{t('common.postProcessDesc3')}</Text>
</VStack>
</Flex>
</WorkInProgress>
);
};

View File

@@ -1,28 +0,0 @@
import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import WorkInProgress from './WorkInProgress';
export default function TrainingWIP() {
const { t } = useTranslation();
return (
<WorkInProgress>
<Flex
sx={{
flexDirection: 'column',
alignItems: 'center',
justifyContent: 'center',
w: '100%',
h: '100%',
gap: 4,
textAlign: 'center',
}}
>
<Heading>{t('common.training')}</Heading>
<VStack maxW="50rem" gap={4}>
<Text>{t('common.trainingDesc1')}</Text>
<Text>{t('common.trainingDesc2')}</Text>
</VStack>
</Flex>
</WorkInProgress>
);
}

View File

@@ -1,26 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { ReactNode } from 'react';
type WorkInProgressProps = {
children: ReactNode;
};
const WorkInProgress = (props: WorkInProgressProps) => {
const { children } = props;
return (
<Flex
sx={{
width: '100%',
height: '100%',
bg: 'base.850',
borderRadius: 'base',
position: 'relative',
}}
>
{children}
</Flex>
);
};
export default WorkInProgress;

View File

@@ -1,35 +0,0 @@
import { RefObject, useEffect } from 'react';
const watchers: {
ref: RefObject<HTMLElement>;
enable: boolean;
callback: () => void;
}[] = [];
const useClickOutsideWatcher = () => {
useEffect(() => {
function handleClickOutside(e: MouseEvent) {
watchers.forEach(({ ref, enable, callback }) => {
if (enable && ref.current && !ref.current.contains(e.target as Node)) {
callback();
}
});
}
document.addEventListener('mousedown', handleClickOutside);
return () => {
document.removeEventListener('mousedown', handleClickOutside);
};
}, []);
return {
addWatcher: (watcher: {
ref: RefObject<HTMLElement>;
callback: () => void;
enable: boolean;
}) => {
watchers.push(watcher);
},
};
};
export default useClickOutsideWatcher;

View File

@@ -1,13 +1,22 @@
-let openFunction: () => void;
+import { useCallback } from 'react';
+
+let openUploader = () => {
+  return;
+};
 const useImageUploader = () => {
-  return {
-    setOpenUploader: (open?: () => void) => {
-      if (open) {
-        openFunction = open;
-      }
-    },
-    openUploader: openFunction,
-  };
+  const setOpenUploaderFunction = useCallback(
+    (openUploaderFunction?: () => void) => {
+      if (openUploaderFunction) {
+        openUploader = openUploaderFunction;
+      }
+    },
+    []
+  );
+  return {
+    setOpenUploaderFunction,
+    openUploader,
+  };
 };
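A condensed view of the contract this hook creates (`ImageUploader` above registers dropzone's `open` via `setOpenUploaderFunction`; the consuming button is illustrative):

```tsx
import useImageUploader from 'common/hooks/useImageUploader';

// Any component rendered under <ImageUploader> can trigger the file picker:
const UploadTrigger = () => {
  const { openUploader } = useImageUploader();
  return <button onClick={openUploader}>Upload</button>;
};
```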

View File

@@ -1,17 +0,0 @@
import React from 'react';
import { useTranslation } from 'react-i18next';
export default function useUpdateTranslations(fn: () => void) {
const { i18n } = useTranslation();
const currentLang = localStorage.getItem('i18nextLng');
React.useEffect(() => {
fn();
}, [fn]);
React.useEffect(() => {
i18n.on('languageChanged', () => {
fn();
});
}, [fn, i18n, currentLang]);
}

View File

@@ -1,20 +0,0 @@
import { createIcon } from '@chakra-ui/react';
const ImageToImageIcon = createIcon({
displayName: 'ImageToImageIcon',
viewBox: '0 0 3543 3543',
path: (
<g transform="matrix(1.10943,0,0,1.10943,-206.981,-213.533)">
<path
fill="currentColor"
fillRule="evenodd"
clipRule="evenodd"
d="M688.533,2405.95L542.987,2405.95C349.532,2405.95 192.47,2248.89 192.47,2055.44L192.47,542.987C192.47,349.532 349.532,192.47 542.987,192.47L2527.88,192.47C2721.33,192.47 2878.4,349.532 2878.4,542.987L2878.4,1172.79L3023.94,1172.79C3217.4,1172.79 3374.46,1329.85 3374.46,1523.3C3374.46,1523.3 3374.46,3035.75 3374.46,3035.75C3374.46,3229.21 3217.4,3386.27 3023.94,3386.27L1039.05,3386.27C845.595,3386.27 688.533,3229.21 688.533,3035.75L688.533,2405.95ZM3286.96,2634.37L3286.96,1523.3C3286.96,1378.14 3169.11,1260.29 3023.94,1260.29C3023.94,1260.29 1039.05,1260.29 1039.05,1260.29C893.887,1260.29 776.033,1378.14 776.033,1523.3L776.033,2489.79L1440.94,1736.22L2385.83,2775.59L2880.71,2200.41L3286.96,2634.37ZM2622.05,1405.51C2778.5,1405.51 2905.51,1532.53 2905.51,1688.98C2905.51,1845.42 2778.5,1972.44 2622.05,1972.44C2465.6,1972.44 2338.58,1845.42 2338.58,1688.98C2338.58,1532.53 2465.6,1405.51 2622.05,1405.51ZM2790.9,1172.79L1323.86,1172.79L944.882,755.906L279.97,1509.47L279.97,542.987C279.97,397.824 397.824,279.97 542.987,279.97C542.987,279.97 2527.88,279.97 2527.88,279.97C2673.04,279.97 2790.9,397.824 2790.9,542.987L2790.9,1172.79ZM2125.98,425.197C2282.43,425.197 2409.45,552.213 2409.45,708.661C2409.45,865.11 2282.43,992.126 2125.98,992.126C1969.54,992.126 1842.52,865.11 1842.52,708.661C1842.52,552.213 1969.54,425.197 2125.98,425.197Z"
/>
</g>
),
defaultProps: {
boxSize: '24px',
},
});
export default ImageToImageIcon;

File diff suppressed because one or more lines are too long

Some files were not shown because too many files have changed in this diff.