diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 5a8c4dbf99..f0db3e6d9e 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -98,7 +98,8 @@ class CompelInvocation(BaseInvocation): # TODO: support legacy blend? - prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str) + conjunction = Compel.parse_prompt_string(prompt_str) + prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0] if context.services.configuration.log_tokenization: log_tokenization_for_prompt_object(prompt, tokenizer) diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index 98f87d2dd4..2ce58c016b 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -5,7 +5,12 @@ from typing import Literal from pydantic import BaseModel, Field import numpy as np -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InvocationContext, + InvocationConfig, +) class MathInvocationConfig(BaseModel): @@ -22,19 +27,21 @@ class MathInvocationConfig(BaseModel): class IntOutput(BaseInvocationOutput): """An integer output""" - #fmt: off + + # fmt: off type: Literal["int_output"] = "int_output" a: int = Field(default=None, description="The output integer") - #fmt: on + # fmt: on class AddInvocation(BaseInvocation, MathInvocationConfig): """Adds two numbers""" - #fmt: off + + # fmt: off type: Literal["add"] = "add" a: int = Field(default=0, description="The first number") b: int = Field(default=0, description="The second number") - #fmt: on + # fmt: on def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=self.a + self.b) @@ -42,11 +49,12 @@ class AddInvocation(BaseInvocation, MathInvocationConfig): class SubtractInvocation(BaseInvocation, MathInvocationConfig): """Subtracts two numbers""" - #fmt: off + + # fmt: off type: Literal["sub"] = "sub" a: int = Field(default=0, description="The first number") b: int = Field(default=0, description="The second number") - #fmt: on + # fmt: on def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=self.a - self.b) @@ -54,11 +62,12 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig): class MultiplyInvocation(BaseInvocation, MathInvocationConfig): """Multiplies two numbers""" - #fmt: off + + # fmt: off type: Literal["mul"] = "mul" a: int = Field(default=0, description="The first number") b: int = Field(default=0, description="The second number") - #fmt: on + # fmt: on def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=self.a * self.b) @@ -66,11 +75,12 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig): class DivideInvocation(BaseInvocation, MathInvocationConfig): """Divides two numbers""" - #fmt: off + + # fmt: off type: Literal["div"] = "div" a: int = Field(default=0, description="The first number") b: int = Field(default=0, description="The second number") - #fmt: on + # fmt: on def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=int(self.a / self.b)) @@ -78,8 +88,13 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig): class RandomIntInvocation(BaseInvocation): """Outputs a single random integer.""" - #fmt: off + + # fmt: off type: Literal["rand_int"] = "rand_int" - #fmt: on + low: int = Field(default=0, description="The inclusive low value") + high: int = Field( + default=np.iinfo(np.int32).max, description="The exclusive high value" + ) + # fmt: on def invoke(self, context: InvocationContext) -> IntOutput: - return IntOutput(a=np.random.randint(0, np.iinfo(np.int32).max)) + return IntOutput(a=np.random.randint(self.low, self.high)) diff --git a/invokeai/backend/prompting/conditioning.py b/invokeai/backend/prompting/conditioning.py index 1d2b16e9e1..fe2a553015 100644 --- a/invokeai/backend/prompting/conditioning.py +++ b/invokeai/backend/prompting/conditioning.py @@ -16,6 +16,7 @@ from compel.prompt_parser import ( FlattenedPrompt, Fragment, PromptParser, + Conjunction, ) import invokeai.backend.util.logging as logger @@ -26,58 +27,48 @@ from ..util import torch_dtype config = get_invokeai_config() -def get_uc_and_c_and_ec( - prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False -): +def get_uc_and_c_and_ec(prompt_string, + model: InvokeAIDiffuserComponent, + log_tokens=False, skip_normalize_legacy_blend=False): # lazy-load any deferred textual inversions. # this might take a couple of seconds the first time a textual inversion is used. - model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms( - prompt_string - ) + model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string) - tokenizer = model.tokenizer - compel = Compel( - tokenizer=tokenizer, - text_encoder=model.text_encoder, - textual_inversion_manager=model.textual_inversion_manager, - dtype_for_device_getter=torch_dtype, - truncate_long_prompts=False - ) + compel = Compel(tokenizer=model.tokenizer, + text_encoder=model.text_encoder, + textual_inversion_manager=model.textual_inversion_manager, + dtype_for_device_getter=torch_dtype, + truncate_long_prompts=False, + ) # get rid of any newline characters prompt_string = prompt_string.replace("\n", " ") - ( - positive_prompt_string, - negative_prompt_string, - ) = split_prompt_to_positive_and_negative(prompt_string) - legacy_blend = try_parse_legacy_blend( - positive_prompt_string, skip_normalize_legacy_blend - ) - positive_prompt: Union[FlattenedPrompt, Blend] - if legacy_blend is not None: - positive_prompt = legacy_blend - else: - positive_prompt = Compel.parse_prompt_string(positive_prompt_string) - negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string( - negative_prompt_string - ) + positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string) + legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend) + positive_conjunction: Conjunction + if legacy_blend is not None: + positive_conjunction = legacy_blend + else: + positive_conjunction = Compel.parse_prompt_string(positive_prompt_string) + positive_prompt = positive_conjunction.prompts[0] + + negative_conjunction = Compel.parse_prompt_string(negative_prompt_string) + negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0] + + tokens_count = get_max_token_count(model.tokenizer, positive_prompt) if log_tokens or config.log_tokenization: - log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer) + log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer) c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt) uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt) [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) - tokens_count = get_max_token_count(tokenizer, positive_prompt) - - ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( - tokens_count_including_eos_bos=tokens_count, - cross_attention_control_args=options.get("cross_attention_control", None), - ) + ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count, + cross_attention_control_args=options.get( + 'cross_attention_control', None)) return uc, c, ec - def get_prompt_structure( prompt_string, skip_normalize_legacy_blend: bool = False ) -> (Union[FlattenedPrompt, Blend], FlattenedPrompt): @@ -88,18 +79,17 @@ def get_prompt_structure( legacy_blend = try_parse_legacy_blend( positive_prompt_string, skip_normalize_legacy_blend ) - positive_prompt: Union[FlattenedPrompt, Blend] + positive_prompt: Conjunction if legacy_blend is not None: - positive_prompt = legacy_blend + positive_conjunction = legacy_blend else: - positive_prompt = Compel.parse_prompt_string(positive_prompt_string) - negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string( - negative_prompt_string - ) + positive_conjunction = Compel.parse_prompt_string(positive_prompt_string) + positive_prompt = positive_conjunction.prompts[0] + negative_conjunction = Compel.parse_prompt_string(negative_prompt_string) + negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0] return positive_prompt, negative_prompt - def get_max_token_count( tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False ) -> int: @@ -246,22 +236,21 @@ def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_t logger.info(f"[TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):") logger.debug(f"{discarded}\x1b[0m") - -def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Blend]: +def try_parse_legacy_blend(text: str, skip_normalize: bool = False) -> Optional[Conjunction]: weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize) if len(weighted_subprompts) <= 1: return None strings = [x[0] for x in weighted_subprompts] - weights = [x[1] for x in weighted_subprompts] pp = PromptParser() parsed_conjunctions = [pp.parse_conjunction(x) for x in strings] - flattened_prompts = [x.prompts[0] for x in parsed_conjunctions] - - return Blend( - prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize - ) - + flattened_prompts = [] + weights = [] + for i, x in enumerate(parsed_conjunctions): + if len(x.prompts)>0: + flattened_prompts.append(x.prompts[0]) + weights.append(weighted_subprompts[i][1]) + return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)]) def split_weighted_subprompts(text, skip_normalize=False) -> list: """ diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 58cb244ac3..b9fc5946c0 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -548,8 +548,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance = [] extra_conditioning_info = conditioning_data.extra with self.invokeai_diffuser.custom_attention_context( - extra_conditioning_info=extra_conditioning_info, - step_count=len(self.scheduler.timesteps), + self.invokeai_diffuser.model, + extra_conditioning_info=extra_conditioning_info, + step_count=len(self.scheduler.timesteps), ): yield PipelineIntermediateState( run_id=run_id, diff --git a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py index dfd19ea964..79a0982cfe 100644 --- a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py +++ b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py @@ -10,6 +10,7 @@ import diffusers import psutil import torch from compel.cross_attention_control import Arguments +from diffusers.models.unet_2d_condition import UNet2DConditionModel from diffusers.models.attention_processor import AttentionProcessor from torch import nn @@ -352,8 +353,7 @@ def restore_default_cross_attention( else: remove_attention_function(model) - -def override_cross_attention(model, context: Context, is_running_diffusers=False): +def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context): """ Inject attention parameters and functions into the passed in model to enable cross attention editing. @@ -372,37 +372,22 @@ def override_cross_attention(model, context: Context, is_running_diffusers=False indices = torch.arange(max_length, dtype=torch.long) for name, a0, a1, b0, b1 in context.arguments.edit_opcodes: if b0 < max_length: - if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0): + if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): # these tokens have not been edited indices[b0:b1] = indices_target[a0:a1] mask[b0:b1] = 1 context.cross_attention_mask = mask.to(device) context.cross_attention_index_map = indices.to(device) - if is_running_diffusers: - unet = model - old_attn_processors = unet.attn_processors - if torch.backends.mps.is_available(): - # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS - unet.set_attn_processor(SwapCrossAttnProcessor()) - else: - # try to re-use an existing slice size - default_slice_size = 4 - slice_size = next( - ( - p.slice_size - for p in old_attn_processors.values() - if type(p) is SlicedAttnProcessor - ), - default_slice_size, - ) - unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) - return old_attn_processors + old_attn_processors = unet.attn_processors + if torch.backends.mps.is_available(): + # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS + unet.set_attn_processor(SwapCrossAttnProcessor()) else: - context.register_cross_attention_modules(model) - inject_attention_function(model, context) - return None - + # try to re-use an existing slice size + default_slice_size = 4 + slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size) + unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) def get_cross_attention_modules( model, which: CrossAttentionType diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 254a81f03b..7970bc8691 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Optional, Union import numpy as np import torch +from diffusers import UNet2DConditionModel from diffusers.models.attention_processor import AttentionProcessor from typing_extensions import TypeAlias @@ -17,8 +18,8 @@ from .cross_attention_control import ( CrossAttentionType, SwapCrossAttnContext, get_cross_attention_modules, - override_cross_attention, restore_default_cross_attention, + setup_cross_attention_control_attention_processors, ) from .cross_attention_map_saving import AttentionMapSaver @@ -80,24 +81,35 @@ class InvokeAIDiffuserComponent: self.cross_attention_control_context = None self.sequential_guidance = config.sequential_guidance + @classmethod @contextmanager def custom_attention_context( - self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int + cls, + unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs + extra_conditioning_info: Optional[ExtraConditioningInfo], + step_count: int ): - do_swap = ( - extra_conditioning_info is not None - and extra_conditioning_info.wants_cross_attention_control - ) - old_attn_processor = None - if do_swap: - old_attn_processor = self.override_cross_attention( - extra_conditioning_info, step_count=step_count - ) + old_attn_processors = None + if extra_conditioning_info and ( + extra_conditioning_info.wants_cross_attention_control + ): + old_attn_processors = unet.attn_processors + # Load lora conditions into the model + if extra_conditioning_info.wants_cross_attention_control: + cross_attention_control_context = Context( + arguments=extra_conditioning_info.cross_attention_control_args, + step_count=step_count, + ) + setup_cross_attention_control_attention_processors( + unet, + cross_attention_control_context, + ) + try: yield None finally: - if old_attn_processor is not None: - self.restore_default_cross_attention(old_attn_processor) + if old_attn_processors is not None: + unet.set_attn_processor(old_attn_processors) # TODO resuscitate attention map saving # self.remove_attention_map_saving() diff --git a/invokeai/frontend/web/docs/PACKAGE_SCRIPTS.md b/invokeai/frontend/web/docs/PACKAGE_SCRIPTS.md index 90d85bb540..5f882717b1 100644 --- a/invokeai/frontend/web/docs/PACKAGE_SCRIPTS.md +++ b/invokeai/frontend/web/docs/PACKAGE_SCRIPTS.md @@ -15,15 +15,3 @@ The `postinstall` script patches a few packages and runs the Chakra CLI to gener ### Patch `@chakra-ui/cli` See: - -### Patch `redux-persist` - -We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`. - -`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it. - -So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that. - -### Patch `redux-deep-persist` - -This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work. diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index 404d20d937..317929c6a4 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -89,18 +89,13 @@ "react-i18next": "^12.2.2", "react-icons": "^4.7.1", "react-konva": "^18.2.7", - "react-konva-utils": "^1.0.4", "react-redux": "^8.0.5", "react-resizable-panels": "^0.0.42", - "react-rnd": "^10.4.1", - "react-transition-group": "^4.4.5", "react-use": "^17.4.0", "react-virtuoso": "^4.3.5", "react-zoom-pan-pinch": "^3.0.7", "reactflow": "^11.7.0", - "redux-deep-persist": "^1.0.7", "redux-dynamic-middlewares": "^2.2.0", - "redux-persist": "^6.0.0", "redux-remember": "^3.3.1", "roarr": "^7.15.0", "serialize-error": "^11.0.0", diff --git a/invokeai/frontend/web/patches/redux-deep-persist+1.0.7.patch b/invokeai/frontend/web/patches/redux-deep-persist+1.0.7.patch deleted file mode 100644 index 47a62e6aac..0000000000 --- a/invokeai/frontend/web/patches/redux-deep-persist+1.0.7.patch +++ /dev/null @@ -1,24 +0,0 @@ -diff --git a/node_modules/redux-deep-persist/lib/types.d.ts b/node_modules/redux-deep-persist/lib/types.d.ts -index b67b8c2..7fc0fa1 100644 ---- a/node_modules/redux-deep-persist/lib/types.d.ts -+++ b/node_modules/redux-deep-persist/lib/types.d.ts -@@ -35,6 +35,7 @@ export interface PersistConfig { - whitelist?: Array; - transforms?: Array>; - throttle?: number; -+ debounce?: number; - migrate?: PersistMigrate; - stateReconciler?: false | StateReconciler; - getStoredState?: (config: PersistConfig) => Promise; -diff --git a/node_modules/redux-deep-persist/src/types.ts b/node_modules/redux-deep-persist/src/types.ts -index 398ac19..cbc5663 100644 ---- a/node_modules/redux-deep-persist/src/types.ts -+++ b/node_modules/redux-deep-persist/src/types.ts -@@ -91,6 +91,7 @@ export interface PersistConfig { - whitelist?: Array; - transforms?: Array>; - throttle?: number; -+ debounce?: number; - migrate?: PersistMigrate; - stateReconciler?: false | StateReconciler; - /** diff --git a/invokeai/frontend/web/patches/redux-persist+6.0.0.patch b/invokeai/frontend/web/patches/redux-persist+6.0.0.patch deleted file mode 100644 index 9e0a8492db..0000000000 --- a/invokeai/frontend/web/patches/redux-persist+6.0.0.patch +++ /dev/null @@ -1,116 +0,0 @@ -diff --git a/node_modules/redux-persist/es/createPersistoid.js b/node_modules/redux-persist/es/createPersistoid.js -index 8b43b9a..184faab 100644 ---- a/node_modules/redux-persist/es/createPersistoid.js -+++ b/node_modules/redux-persist/es/createPersistoid.js -@@ -6,6 +6,7 @@ export default function createPersistoid(config) { - var whitelist = config.whitelist || null; - var transforms = config.transforms || []; - var throttle = config.throttle || 0; -+ var debounce = config.debounce || 0; - var storageKey = "".concat(config.keyPrefix !== undefined ? config.keyPrefix : KEY_PREFIX).concat(config.key); - var storage = config.storage; - var serialize; -@@ -28,30 +29,37 @@ export default function createPersistoid(config) { - var timeIterator = null; - var writePromise = null; - -- var update = function update(state) { -- // add any changed keys to the queue -- Object.keys(state).forEach(function (key) { -- if (!passWhitelistBlacklist(key)) return; // is keyspace ignored? noop -+ // Timer for debounced `update()` -+ let timer = 0; - -- if (lastState[key] === state[key]) return; // value unchanged? noop -+ function update(state) { -+ // Debounce the update -+ clearTimeout(timer); -+ timer = setTimeout(() => { -+ // add any changed keys to the queue -+ Object.keys(state).forEach(function (key) { -+ if (!passWhitelistBlacklist(key)) return; // is keyspace ignored? noop - -- if (keysToProcess.indexOf(key) !== -1) return; // is key already queued? noop -+ if (lastState[key] === state[key]) return; // value unchanged? noop - -- keysToProcess.push(key); // add key to queue -- }); //if any key is missing in the new state which was present in the lastState, -- //add it for processing too -+ if (keysToProcess.indexOf(key) !== -1) return; // is key already queued? noop - -- Object.keys(lastState).forEach(function (key) { -- if (state[key] === undefined && passWhitelistBlacklist(key) && keysToProcess.indexOf(key) === -1 && lastState[key] !== undefined) { -- keysToProcess.push(key); -- } -- }); // start the time iterator if not running (read: throttle) -+ keysToProcess.push(key); // add key to queue -+ }); //if any key is missing in the new state which was present in the lastState, -+ //add it for processing too - -- if (timeIterator === null) { -- timeIterator = setInterval(processNextKey, throttle); -- } -+ Object.keys(lastState).forEach(function (key) { -+ if (state[key] === undefined && passWhitelistBlacklist(key) && keysToProcess.indexOf(key) === -1 && lastState[key] !== undefined) { -+ keysToProcess.push(key); -+ } -+ }); // start the time iterator if not running (read: throttle) -+ -+ if (timeIterator === null) { -+ timeIterator = setInterval(processNextKey, throttle); -+ } - -- lastState = state; -+ lastState = state; -+ }, debounce) - }; - - function processNextKey() { -diff --git a/node_modules/redux-persist/es/types.js.flow b/node_modules/redux-persist/es/types.js.flow -index c50d3cd..39d8be2 100644 ---- a/node_modules/redux-persist/es/types.js.flow -+++ b/node_modules/redux-persist/es/types.js.flow -@@ -19,6 +19,7 @@ export type PersistConfig = { - whitelist?: Array, - transforms?: Array, - throttle?: number, -+ debounce?: number, - migrate?: (PersistedState, number) => Promise, - stateReconciler?: false | Function, - getStoredState?: PersistConfig => Promise, // used for migrations -diff --git a/node_modules/redux-persist/lib/types.js.flow b/node_modules/redux-persist/lib/types.js.flow -index c50d3cd..39d8be2 100644 ---- a/node_modules/redux-persist/lib/types.js.flow -+++ b/node_modules/redux-persist/lib/types.js.flow -@@ -19,6 +19,7 @@ export type PersistConfig = { - whitelist?: Array, - transforms?: Array, - throttle?: number, -+ debounce?: number, - migrate?: (PersistedState, number) => Promise, - stateReconciler?: false | Function, - getStoredState?: PersistConfig => Promise, // used for migrations -diff --git a/node_modules/redux-persist/src/types.js b/node_modules/redux-persist/src/types.js -index c50d3cd..39d8be2 100644 ---- a/node_modules/redux-persist/src/types.js -+++ b/node_modules/redux-persist/src/types.js -@@ -19,6 +19,7 @@ export type PersistConfig = { - whitelist?: Array, - transforms?: Array, - throttle?: number, -+ debounce?: number, - migrate?: (PersistedState, number) => Promise, - stateReconciler?: false | Function, - getStoredState?: PersistConfig => Promise, // used for migrations -diff --git a/node_modules/redux-persist/types/types.d.ts b/node_modules/redux-persist/types/types.d.ts -index b3733bc..2a1696c 100644 ---- a/node_modules/redux-persist/types/types.d.ts -+++ b/node_modules/redux-persist/types/types.d.ts -@@ -35,6 +35,7 @@ declare module "redux-persist/es/types" { - whitelist?: Array; - transforms?: Array>; - throttle?: number; -+ debounce?: number; - migrate?: PersistMigrate; - stateReconciler?: false | StateReconciler; - /** diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index f82b3af677..94dff3934a 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -450,7 +450,7 @@ "cfgScale": "CFG Scale", "width": "Width", "height": "Height", - "sampler": "Sampler", + "scheduler": "Scheduler", "seed": "Seed", "imageToImage": "Image to Image", "randomizeSeed": "Randomize Seed", @@ -552,8 +552,8 @@ "canceled": "Processing Canceled", "tempFoldersEmptied": "Temp Folder Emptied", "uploadFailed": "Upload failed", - "uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time", "uploadFailedUnableToLoadDesc": "Unable to load file", + "uploadFailedInvalidUploadDesc": "Must be single PNG or JPEG image", "downloadImageStarted": "Image Download Started", "imageCopied": "Image Copied", "imageLinkCopied": "Image Link Copied", diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index eb6496f43e..40554356b1 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -2,14 +2,11 @@ import ImageUploader from 'common/components/ImageUploader'; import SiteHeader from 'features/system/components/SiteHeader'; import ProgressBar from 'features/system/components/ProgressBar'; import InvokeTabs from 'features/ui/components/InvokeTabs'; - -import useToastWatcher from 'features/system/hooks/useToastWatcher'; - import FloatingGalleryButton from 'features/ui/components/FloatingGalleryButton'; import FloatingParametersPanelButtons from 'features/ui/components/FloatingParametersPanelButtons'; import { Box, Flex, Grid, Portal } from '@chakra-ui/react'; import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants'; -import GalleryDrawer from 'features/gallery/components/ImageGalleryPanel'; +import GalleryDrawer from 'features/gallery/components/GalleryPanel'; import Lightbox from 'features/lightbox/components/Lightbox'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { memo, ReactNode, useCallback, useEffect, useState } from 'react'; @@ -17,13 +14,14 @@ import { motion, AnimatePresence } from 'framer-motion'; import Loading from 'common/components/Loading/Loading'; import { useIsApplicationReady } from 'features/system/hooks/useIsApplicationReady'; import { PartialAppConfig } from 'app/types/invokeai'; -import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys'; import { configChanged } from 'features/system/store/configSlice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { useLogger } from 'app/logging/useLogger'; import ParametersDrawer from 'features/ui/components/ParametersDrawer'; import { languageSelector } from 'features/system/store/systemSelectors'; import i18n from 'i18n'; +import Toaster from './Toaster'; +import GlobalHotkeys from './GlobalHotkeys'; const DEFAULT_CONFIG = {}; @@ -38,9 +36,6 @@ const App = ({ headerComponent, setIsReady, }: Props) => { - useToastWatcher(); - useGlobalHotkeys(); - const language = useAppSelector(languageSelector); const log = useLogger(); @@ -77,65 +72,69 @@ const App = ({ }, [isApplicationReady, setIsReady]); return ( - - {isLightboxEnabled && } - - - - {headerComponent || } - + + {isLightboxEnabled && } + + + - - - - + {headerComponent || } + + + + + - - + + - - {!isApplicationReady && !loadingOverridden && ( - - - - - - - )} - + + {!isApplicationReady && !loadingOverridden && ( + + + + + + + )} + - - - - - - - + + + + + + + + + + ); }; diff --git a/invokeai/frontend/web/src/app/components/AuxiliaryProgressIndicator.tsx b/invokeai/frontend/web/src/app/components/AuxiliaryProgressIndicator.tsx new file mode 100644 index 0000000000..a0c5d22266 --- /dev/null +++ b/invokeai/frontend/web/src/app/components/AuxiliaryProgressIndicator.tsx @@ -0,0 +1,44 @@ +import { Flex, Spinner, Tooltip } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppSelector } from 'app/store/storeHooks'; +import { systemSelector } from 'features/system/store/systemSelectors'; +import { memo } from 'react'; + +const selector = createSelector(systemSelector, (system) => { + const { isUploading } = system; + + let tooltip = ''; + + if (isUploading) { + tooltip = 'Uploading...'; + } + + return { + tooltip, + shouldShow: isUploading, + }; +}); + +export const AuxiliaryProgressIndicator = () => { + const { shouldShow, tooltip } = useAppSelector(selector); + + if (!shouldShow) { + return null; + } + + return ( + + + + + + ); +}; + +export default memo(AuxiliaryProgressIndicator); diff --git a/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts b/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts similarity index 89% rename from invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts rename to invokeai/frontend/web/src/app/components/GlobalHotkeys.ts index 3935a390fb..c4660416bf 100644 --- a/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts +++ b/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts @@ -10,6 +10,7 @@ import { togglePinParametersPanel, } from 'features/ui/store/uiSlice'; import { isEqual } from 'lodash-es'; +import React, { memo } from 'react'; import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook'; const globalHotkeysSelector = createSelector( @@ -27,7 +28,11 @@ const globalHotkeysSelector = createSelector( // TODO: Does not catch keypresses while focused in an input. Maybe there is a way? -export const useGlobalHotkeys = () => { +/** + * Logical component. Handles app-level global hotkeys. + * @returns null + */ +const GlobalHotkeys: React.FC = () => { const dispatch = useAppDispatch(); const { shift } = useAppSelector(globalHotkeysSelector); @@ -75,4 +80,8 @@ export const useGlobalHotkeys = () => { useHotkeys('4', () => { dispatch(setActiveTab('nodes')); }); + + return null; }; + +export default memo(GlobalHotkeys); diff --git a/invokeai/frontend/web/src/app/components/Toaster.ts b/invokeai/frontend/web/src/app/components/Toaster.ts new file mode 100644 index 0000000000..66ba1d4925 --- /dev/null +++ b/invokeai/frontend/web/src/app/components/Toaster.ts @@ -0,0 +1,65 @@ +import { useToast, UseToastOptions } from '@chakra-ui/react'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { toastQueueSelector } from 'features/system/store/systemSelectors'; +import { addToast, clearToastQueue } from 'features/system/store/systemSlice'; +import { useCallback, useEffect } from 'react'; + +export type MakeToastArg = string | UseToastOptions; + +/** + * Makes a toast from a string or a UseToastOptions object. + * If a string is passed, the toast will have the status 'info' and will be closable with a duration of 2500ms. + */ +export const makeToast = (arg: MakeToastArg): UseToastOptions => { + if (typeof arg === 'string') { + return { + title: arg, + status: 'info', + isClosable: true, + duration: 2500, + }; + } + + return { status: 'info', isClosable: true, duration: 2500, ...arg }; +}; + +/** + * Logical component. Watches the toast queue and makes toasts when the queue is not empty. + * @returns null + */ +const Toaster = () => { + const dispatch = useAppDispatch(); + const toastQueue = useAppSelector(toastQueueSelector); + const toast = useToast(); + useEffect(() => { + toastQueue.forEach((t) => { + toast(t); + }); + toastQueue.length > 0 && dispatch(clearToastQueue()); + }, [dispatch, toast, toastQueue]); + + return null; +}; + +/** + * Returns a function that can be used to make a toast. + * @example + * const toaster = useAppToaster(); + * toaster('Hello world!'); + * toaster({ title: 'Hello world!', status: 'success' }); + * @returns A function that can be used to make a toast. + * @see makeToast + * @see MakeToastArg + * @see UseToastOptions + */ +export const useAppToaster = () => { + const dispatch = useAppDispatch(); + const toaster = useCallback( + (arg: MakeToastArg) => dispatch(addToast(makeToast(arg))), + [dispatch] + ); + + return toaster; +}; + +export default Toaster; diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts index b74c67befd..d312d725ba 100644 --- a/invokeai/frontend/web/src/app/constants.ts +++ b/invokeai/frontend/web/src/app/constants.ts @@ -1,6 +1,6 @@ // TODO: use Enums? -export const SCHEDULERS: Array = [ +export const SCHEDULERS = [ 'ddim', 'lms', 'euler', @@ -17,7 +17,12 @@ export const SCHEDULERS: Array = [ 'heun', 'heun_k', 'unipc', -]; +] as const; + +export type Scheduler = (typeof SCHEDULERS)[number]; + +export const isScheduler = (x: string): x is Scheduler => + SCHEDULERS.includes(x as Scheduler); // Valid image widths export const WIDTHS: Array = Array.from(Array(64)).map( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 36bf6adfe7..f23e83a191 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -15,6 +15,10 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas'; import { addUserInvokedNodesListener } from './listeners/userInvokedNodes'; import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage'; import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage'; +import { addCanvasSavedToGalleryListener } from './listeners/canvasSavedToGallery'; +import { addCanvasDownloadedAsImageListener } from './listeners/canvasDownloadedAsImage'; +import { addCanvasCopiedToClipboardListener } from './listeners/canvasCopiedToClipboard'; +import { addCanvasMergedListener } from './listeners/canvasMerged'; export const listenerMiddleware = createListenerMiddleware(); @@ -43,3 +47,8 @@ addUserInvokedCanvasListener(); addUserInvokedNodesListener(); addUserInvokedTextToImageListener(); addUserInvokedImageToImageListener(); + +addCanvasSavedToGalleryListener(); +addCanvasDownloadedAsImageListener(); +addCanvasCopiedToClipboardListener(); +addCanvasMergedListener(); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard.ts new file mode 100644 index 0000000000..16642f1f32 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasCopiedToClipboard.ts @@ -0,0 +1,33 @@ +import { canvasCopiedToClipboard } from 'features/canvas/store/actions'; +import { startAppListening } from '..'; +import { log } from 'app/logging/useLogger'; +import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; +import { addToast } from 'features/system/store/systemSlice'; +import { copyBlobToClipboard } from 'features/canvas/util/copyBlobToClipboard'; + +const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' }); + +export const addCanvasCopiedToClipboardListener = () => { + startAppListening({ + actionCreator: canvasCopiedToClipboard, + effect: async (action, { dispatch, getState }) => { + const state = getState(); + + const blob = await getBaseLayerBlob(state); + + if (!blob) { + moduleLog.error('Problem getting base layer blob'); + dispatch( + addToast({ + title: 'Problem Copying Canvas', + description: 'Unable to export base layer', + status: 'error', + }) + ); + return; + } + + copyBlobToClipboard(blob); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage.ts new file mode 100644 index 0000000000..ef4c63b31c --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasDownloadedAsImage.ts @@ -0,0 +1,33 @@ +import { canvasDownloadedAsImage } from 'features/canvas/store/actions'; +import { startAppListening } from '..'; +import { log } from 'app/logging/useLogger'; +import { downloadBlob } from 'features/canvas/util/downloadBlob'; +import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; +import { addToast } from 'features/system/store/systemSlice'; + +const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' }); + +export const addCanvasDownloadedAsImageListener = () => { + startAppListening({ + actionCreator: canvasDownloadedAsImage, + effect: async (action, { dispatch, getState }) => { + const state = getState(); + + const blob = await getBaseLayerBlob(state); + + if (!blob) { + moduleLog.error('Problem getting base layer blob'); + dispatch( + addToast({ + title: 'Problem Downloading Canvas', + description: 'Unable to export base layer', + status: 'error', + }) + ); + return; + } + + downloadBlob(blob, 'mergedCanvas.png'); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasGraphBuilt.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasGraphBuilt.ts deleted file mode 100644 index 532bac3eee..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasGraphBuilt.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { canvasGraphBuilt } from 'features/nodes/store/actions'; -import { startAppListening } from '..'; -import { - canvasSessionIdChanged, - stagingAreaInitialized, -} from 'features/canvas/store/canvasSlice'; -import { sessionInvoked } from 'services/thunks/session'; - -export const addCanvasGraphBuiltListener = () => - startAppListening({ - actionCreator: canvasGraphBuilt, - effect: async (action, { dispatch, getState, take }) => { - const [{ meta }] = await take(sessionInvoked.fulfilled.match); - const { sessionId } = meta.arg; - const state = getState(); - - if (!state.canvas.layerState.stagingArea.boundingBox) { - dispatch( - stagingAreaInitialized({ - sessionId, - boundingBox: { - ...state.canvas.boundingBoxCoordinates, - ...state.canvas.boundingBoxDimensions, - }, - }) - ); - } - - dispatch(canvasSessionIdChanged(sessionId)); - }, - }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts new file mode 100644 index 0000000000..d7a58c2050 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasMerged.ts @@ -0,0 +1,88 @@ +import { canvasMerged } from 'features/canvas/store/actions'; +import { startAppListening } from '..'; +import { log } from 'app/logging/useLogger'; +import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; +import { addToast } from 'features/system/store/systemSlice'; +import { imageUploaded } from 'services/thunks/image'; +import { v4 as uuidv4 } from 'uuid'; +import { deserializeImageResponse } from 'services/util/deserializeImageResponse'; +import { setMergedCanvas } from 'features/canvas/store/canvasSlice'; +import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider'; + +const moduleLog = log.child({ namespace: 'canvasCopiedToClipboardListener' }); + +export const addCanvasMergedListener = () => { + startAppListening({ + actionCreator: canvasMerged, + effect: async (action, { dispatch, getState, take }) => { + const state = getState(); + + const blob = await getBaseLayerBlob(state, true); + + if (!blob) { + moduleLog.error('Problem getting base layer blob'); + dispatch( + addToast({ + title: 'Problem Merging Canvas', + description: 'Unable to export base layer', + status: 'error', + }) + ); + return; + } + + const canvasBaseLayer = getCanvasBaseLayer(); + + if (!canvasBaseLayer) { + moduleLog.error('Problem getting canvas base layer'); + dispatch( + addToast({ + title: 'Problem Merging Canvas', + description: 'Unable to export base layer', + status: 'error', + }) + ); + return; + } + + const baseLayerRect = canvasBaseLayer.getClientRect({ + relativeTo: canvasBaseLayer.getParent(), + }); + + const filename = `mergedCanvas_${uuidv4()}.png`; + + dispatch( + imageUploaded({ + imageType: 'intermediates', + formData: { + file: new File([blob], filename, { type: 'image/png' }), + }, + }) + ); + + const [{ payload }] = await take( + (action): action is ReturnType => + imageUploaded.fulfilled.match(action) && + action.meta.arg.formData.file.name === filename + ); + + const mergedCanvasImage = deserializeImageResponse(payload.response); + + dispatch( + setMergedCanvas({ + kind: 'image', + layer: 'base', + image: mergedCanvasImage, + ...baseLayerRect, + }) + ); + + dispatch( + addToast({ + title: 'Canvas Merged', + status: 'success', + }) + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts new file mode 100644 index 0000000000..d8237d1d5c --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts @@ -0,0 +1,40 @@ +import { canvasSavedToGallery } from 'features/canvas/store/actions'; +import { startAppListening } from '..'; +import { log } from 'app/logging/useLogger'; +import { imageUploaded } from 'services/thunks/image'; +import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; +import { addToast } from 'features/system/store/systemSlice'; + +const moduleLog = log.child({ namespace: 'canvasSavedToGalleryListener' }); + +export const addCanvasSavedToGalleryListener = () => { + startAppListening({ + actionCreator: canvasSavedToGallery, + effect: async (action, { dispatch, getState }) => { + const state = getState(); + + const blob = await getBaseLayerBlob(state); + + if (!blob) { + moduleLog.error('Problem getting base layer blob'); + dispatch( + addToast({ + title: 'Problem Saving Canvas', + description: 'Unable to export base layer', + status: 'error', + }) + ); + return; + } + + dispatch( + imageUploaded({ + imageType: 'results', + formData: { + file: new File([blob], 'mergedCanvas.png', { type: 'image/png' }), + }, + }) + ); + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index c32da2e710..de06220ecd 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -3,6 +3,10 @@ import { startAppListening } from '..'; import { uploadAdded } from 'features/gallery/store/uploadsSlice'; import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageUploaded } from 'services/thunks/image'; +import { addToast } from 'features/system/store/systemSlice'; +import { initialImageSelected } from 'features/parameters/store/actions'; +import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; +import { resultAdded } from 'features/gallery/store/resultsSlice'; export const addImageUploadedListener = () => { startAppListening({ @@ -11,14 +15,31 @@ export const addImageUploadedListener = () => { action.payload.response.image_type !== 'intermediates', effect: (action, { dispatch, getState }) => { const { response } = action.payload; + const { imageType } = action.meta.arg; const state = getState(); const image = deserializeImageResponse(response); - dispatch(uploadAdded(image)); + if (imageType === 'uploads') { + dispatch(uploadAdded(image)); - if (state.gallery.shouldAutoSwitchToNewImages) { - dispatch(imageSelected(image)); + dispatch(addToast({ title: 'Image Uploaded', status: 'success' })); + + if (state.gallery.shouldAutoSwitchToNewImages) { + dispatch(imageSelected(image)); + } + + if (action.meta.arg.activeTabName === 'img2img') { + dispatch(initialImageSelected(image)); + } + + if (action.meta.arg.activeTabName === 'unifiedCanvas') { + dispatch(setInitialCanvasImage(image)); + } + } + + if (imageType === 'results') { + dispatch(resultAdded(image)); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts index 6bc2f9e9bc..ae3a35f537 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/initialImageSelected.ts @@ -2,11 +2,11 @@ import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { Image, isInvokeAIImage } from 'app/types/invokeai'; import { selectResultsById } from 'features/gallery/store/resultsSlice'; import { selectUploadsById } from 'features/gallery/store/uploadsSlice'; -import { makeToast } from 'features/system/hooks/useToastWatcher'; import { t } from 'i18next'; import { addToast } from 'features/system/store/systemSlice'; import { startAppListening } from '..'; import { initialImageSelected } from 'features/parameters/store/actions'; +import { makeToast } from 'app/components/Toaster'; export const addInitialImageSelectedListener = () => { startAppListening({ diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts index cdb2c83e12..2ebd3684e9 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts @@ -1,6 +1,6 @@ import { startAppListening } from '..'; import { sessionCreated, sessionInvoked } from 'services/thunks/session'; -import { buildCanvasGraphAndBlobs } from 'features/nodes/util/graphBuilders/buildCanvasGraph'; +import { buildCanvasGraphComponents } from 'features/nodes/util/graphBuilders/buildCanvasGraph'; import { log } from 'app/logging/useLogger'; import { canvasGraphBuilt } from 'features/nodes/store/actions'; import { imageUploaded } from 'services/thunks/image'; @@ -11,9 +11,17 @@ import { stagingAreaInitialized, } from 'features/canvas/store/canvasSlice'; import { userInvoked } from 'app/store/actions'; +import { getCanvasData } from 'features/canvas/util/getCanvasData'; +import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode'; +import { blobToDataURL } from 'features/canvas/util/blobToDataURL'; +import openBase64ImageInTab from 'common/util/openBase64ImageInTab'; const moduleLog = log.child({ namespace: 'invoke' }); +/** + * This listener is responsible for building the canvas graph and blobs when the user invokes the canvas. + * It is also responsible for uploading the base and mask layers to the server. + */ export const addUserInvokedCanvasListener = () => { startAppListening({ predicate: (action): action is ReturnType => @@ -21,25 +29,49 @@ export const addUserInvokedCanvasListener = () => { effect: async (action, { getState, dispatch, take }) => { const state = getState(); - const data = await buildCanvasGraphAndBlobs(state); + // Build canvas blobs + const canvasBlobsAndImageData = await getCanvasData(state); - if (!data) { + if (!canvasBlobsAndImageData) { + moduleLog.error('Unable to create canvas data'); + return; + } + + const { baseBlob, baseImageData, maskBlob, maskImageData } = + canvasBlobsAndImageData; + + // Determine the generation mode + const generationMode = getCanvasGenerationMode( + baseImageData, + maskImageData + ); + + if (state.system.enableImageDebugging) { + const baseDataURL = await blobToDataURL(baseBlob); + const maskDataURL = await blobToDataURL(maskBlob); + openBase64ImageInTab([ + { base64: maskDataURL, caption: 'mask b64' }, + { base64: baseDataURL, caption: 'image b64' }, + ]); + } + + moduleLog.debug(`Generation mode: ${generationMode}`); + + // Build the canvas graph + const graphComponents = await buildCanvasGraphComponents( + state, + generationMode + ); + + if (!graphComponents) { moduleLog.error('Problem building graph'); return; } - const { - rangeNode, - iterateNode, - baseNode, - edges, - baseBlob, - maskBlob, - generationMode, - } = data; + const { rangeNode, iterateNode, baseNode, edges } = graphComponents; + // Upload the base layer, to be used as init image const baseFilename = `${uuidv4()}.png`; - const maskFilename = `${uuidv4()}.png`; dispatch( imageUploaded({ @@ -66,6 +98,9 @@ export const addUserInvokedCanvasListener = () => { }; } + // Upload the mask layer image + const maskFilename = `${uuidv4()}.png`; + if (baseNode.type === 'inpaint') { dispatch( imageUploaded({ @@ -103,9 +138,12 @@ export const addUserInvokedCanvasListener = () => { dispatch(canvasGraphBuilt(graph)); moduleLog({ data: graph }, 'Canvas graph built'); + // Actually create the session dispatch(sessionCreated({ graph })); + // Wait for the session to be invoked (this is just the HTTP request to start processing) const [{ meta }] = await take(sessionInvoked.fulfilled.match); + const { sessionId } = meta.arg; if (!state.canvas.layerState.stagingArea.boundingBox) { diff --git a/invokeai/frontend/web/src/common/components/IAIInput.tsx b/invokeai/frontend/web/src/common/components/IAIInput.tsx index 3e90dca83a..3cba36d2c9 100644 --- a/invokeai/frontend/web/src/common/components/IAIInput.tsx +++ b/invokeai/frontend/web/src/common/components/IAIInput.tsx @@ -5,6 +5,7 @@ import { Input, InputProps, } from '@chakra-ui/react'; +import { stopPastePropagation } from 'common/util/stopPastePropagation'; import { ChangeEvent, memo } from 'react'; interface IAIInputProps extends InputProps { @@ -31,7 +32,7 @@ const IAIInput = (props: IAIInputProps) => { {...formControlProps} > {label !== '' && {label}} - + ); }; diff --git a/invokeai/frontend/web/src/common/components/IAINumberInput.tsx b/invokeai/frontend/web/src/common/components/IAINumberInput.tsx index 762182eb47..bf598f3b12 100644 --- a/invokeai/frontend/web/src/common/components/IAINumberInput.tsx +++ b/invokeai/frontend/web/src/common/components/IAINumberInput.tsx @@ -14,6 +14,7 @@ import { Tooltip, TooltipProps, } from '@chakra-ui/react'; +import { stopPastePropagation } from 'common/util/stopPastePropagation'; import { clamp } from 'lodash-es'; import { FocusEvent, memo, useEffect, useState } from 'react'; @@ -125,6 +126,7 @@ const IAINumberInput = (props: Props) => { onChange={handleOnChange} onBlur={handleBlur} {...rest} + onPaste={stopPastePropagation} > {showStepper && ( diff --git a/invokeai/frontend/web/src/common/components/IAITextarea.tsx b/invokeai/frontend/web/src/common/components/IAITextarea.tsx new file mode 100644 index 0000000000..b5247887bb --- /dev/null +++ b/invokeai/frontend/web/src/common/components/IAITextarea.tsx @@ -0,0 +1,9 @@ +import { Textarea, TextareaProps, forwardRef } from '@chakra-ui/react'; +import { stopPastePropagation } from 'common/util/stopPastePropagation'; +import { memo } from 'react'; + +const IAITextarea = forwardRef((props: TextareaProps, ref) => { + return