mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-20 10:07:54 -05:00
Compare commits
35 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3207822738 | ||
|
|
8d86fabf4b | ||
|
|
af3e910ad3 | ||
|
|
af25d00964 | ||
|
|
d4a30d08ef | ||
|
|
bd8a33e824 | ||
|
|
b425646b7b | ||
|
|
293e11cfa6 | ||
|
|
c73aabdfbf | ||
|
|
ca989c54b0 | ||
|
|
260e24733f | ||
|
|
bb6e3e726d | ||
|
|
6b394554e2 | ||
|
|
ae1955a1a8 | ||
|
|
1bef13db37 | ||
|
|
a461537087 | ||
|
|
99e28da19b | ||
|
|
42a159beaa | ||
|
|
0aa5aadfe8 | ||
|
|
2537d260e3 | ||
|
|
bbf919a933 | ||
|
|
01897ec576 | ||
|
|
bc12d6654e | ||
|
|
6d7c8d5f57 | ||
|
|
38604aa408 | ||
|
|
781de914f4 | ||
|
|
c094bad233 | ||
|
|
0063014f2b | ||
|
|
d7b5ad02e8 | ||
|
|
2cee436ecf | ||
|
|
e6386d969f | ||
|
|
4b2b983646 | ||
|
|
53808149fb | ||
|
|
21ba55d0a6 | ||
|
|
28c28b2fc0 |
@@ -51,6 +51,7 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
@@ -185,7 +186,7 @@ class GradientMaskOutput(BaseInvocationOutput):
|
||||
title="Create Gradient Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
version="1.1.0",
|
||||
)
|
||||
class CreateGradientMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
@@ -198,6 +199,32 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
minimum_denoise: float = InputField(
|
||||
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
|
||||
)
|
||||
image: Optional[ImageField] = InputField(
|
||||
default=None,
|
||||
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
|
||||
title="[OPTIONAL] Image",
|
||||
ui_order=6,
|
||||
)
|
||||
unet: Optional[UNetField] = InputField(
|
||||
description="OPTIONAL: If the Unet is a specialized Inpainting model, masked_latents will be generated from the image with the VAE",
|
||||
default=None,
|
||||
input=Input.Connection,
|
||||
title="[OPTIONAL] UNet",
|
||||
ui_order=5,
|
||||
)
|
||||
vae: Optional[VAEField] = InputField(
|
||||
default=None,
|
||||
description="OPTIONAL: Only connect for specialized Inpainting models, masked_latents will be generated from the image with the VAE",
|
||||
title="[OPTIONAL] VAE",
|
||||
input=Input.Connection,
|
||||
ui_order=7,
|
||||
)
|
||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
|
||||
fp32: bool = InputField(
|
||||
default=DEFAULT_PRECISION == "float32",
|
||||
description=FieldDescriptions.fp32,
|
||||
ui_order=9,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> GradientMaskOutput:
|
||||
@@ -233,8 +260,27 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
||||
expanded_mask_image = Image.fromarray((expanded_mask.squeeze(0).numpy() * 255).astype(np.uint8), mode="L")
|
||||
expanded_image_dto = context.images.save(expanded_mask_image)
|
||||
|
||||
masked_latents_name = None
|
||||
if self.unet is not None and self.vae is not None and self.image is not None:
|
||||
# all three fields must be present at the same time
|
||||
main_model_config = context.models.get_config(self.unet.unet.key)
|
||||
assert isinstance(main_model_config, MainConfigBase)
|
||||
if main_model_config.variant is ModelVariantType.Inpaint:
|
||||
mask = blur_tensor
|
||||
vae_info: LoadedModel = context.models.load(self.vae.vae)
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = image_tensor.unsqueeze(0)
|
||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
masked_latents = ImageToLatentsInvocation.vae_encode(
|
||||
vae_info, self.fp32, self.tiled, masked_image.clone()
|
||||
)
|
||||
masked_latents_name = context.tensors.save(tensor=masked_latents)
|
||||
|
||||
return GradientMaskOutput(
|
||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=None, gradient=True),
|
||||
denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=True),
|
||||
expanded_mask_area=ImageField(image_name=expanded_image_dto.image_name),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||
|
||||
@@ -17,12 +17,6 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeleteAllResult:
|
||||
deleted_count: int
|
||||
freed_space_bytes: float
|
||||
|
||||
|
||||
class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
|
||||
|
||||
@@ -35,6 +29,12 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
self._ephemeral = ephemeral
|
||||
self._base_output_dir = output_dir
|
||||
self._base_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if self._ephemeral:
|
||||
# Remove dangling tempdirs that might have been left over from an earlier unplanned shutdown.
|
||||
for temp_dir in filter(Path.is_dir, self._base_output_dir.glob("tmp*")):
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
# Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows
|
||||
self._tempdir = (
|
||||
tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None
|
||||
|
||||
@@ -301,12 +301,12 @@ class MainConfigBase(ModelConfigBase):
|
||||
default_settings: Optional[MainModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
|
||||
|
||||
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
|
||||
"""Model config for main checkpoint models."""
|
||||
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
|
||||
@@ -155,7 +155,7 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
StarterModel(
|
||||
name="IP Adapter",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="InvokeAI/ip_adapter_sd15",
|
||||
source="https://huggingface.co/InvokeAI/ip_adapter_sd15/resolve/main/ip-adapter_sd15.safetensors",
|
||||
description="IP-Adapter for SD 1.5 models",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[ip_adapter_sd_image_encoder],
|
||||
@@ -163,7 +163,7 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
StarterModel(
|
||||
name="IP Adapter Plus",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="InvokeAI/ip_adapter_plus_sd15",
|
||||
source="https://huggingface.co/InvokeAI/ip_adapter_plus_sd15/resolve/main/ip-adapter-plus_sd15.safetensors",
|
||||
description="Refined IP-Adapter for SD 1.5 models",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[ip_adapter_sd_image_encoder],
|
||||
@@ -171,7 +171,7 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
StarterModel(
|
||||
name="IP Adapter Plus Face",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
source="InvokeAI/ip_adapter_plus_face_sd15",
|
||||
source="https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15/resolve/main/ip-adapter-plus-face_sd15.safetensors",
|
||||
description="Refined IP-Adapter for SD 1.5 models, adapted for faces",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[ip_adapter_sd_image_encoder],
|
||||
@@ -179,7 +179,7 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
StarterModel(
|
||||
name="IP Adapter SDXL",
|
||||
base=BaseModelType.StableDiffusionXL,
|
||||
source="InvokeAI/ip_adapter_sdxl",
|
||||
source="https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h/resolve/main/ip-adapter_sdxl_vit-h.safetensors",
|
||||
description="IP-Adapter for SDXL models",
|
||||
type=ModelType.IPAdapter,
|
||||
dependencies=[ip_adapter_sdxl_image_encoder],
|
||||
|
||||
@@ -51,6 +51,7 @@
|
||||
}
|
||||
},
|
||||
"dependencies": {
|
||||
"@chakra-ui/react-use-size": "^2.1.0",
|
||||
"@dagrejs/dagre": "^1.1.1",
|
||||
"@dagrejs/graphlib": "^2.2.1",
|
||||
"@dnd-kit/core": "^6.1.0",
|
||||
|
||||
3
invokeai/frontend/web/pnpm-lock.yaml
generated
3
invokeai/frontend/web/pnpm-lock.yaml
generated
@@ -8,6 +8,9 @@ dependencies:
|
||||
'@chakra-ui/react':
|
||||
specifier: ^2.8.2
|
||||
version: 2.8.2(@emotion/react@11.11.3)(@emotion/styled@11.11.0)(@types/react@18.2.59)(framer-motion@11.0.6)(react-dom@18.2.0)(react@18.2.0)
|
||||
'@chakra-ui/react-use-size':
|
||||
specifier: ^2.1.0
|
||||
version: 2.1.0(react@18.2.0)
|
||||
'@dagrejs/dagre':
|
||||
specifier: ^1.1.1
|
||||
version: 1.1.1
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
|
||||
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/hooks/useControlNetOrT2IAdapterDefaultSettings';
|
||||
import { DefaultPreprocessor } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/DefaultPreprocessor';
|
||||
@@ -92,13 +92,9 @@ export const ControlNetOrT2IAdapterDefaultSettings = () => {
|
||||
</Button>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={8}>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultPreprocessor control={control} name="preprocessor" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<SimpleGrid columns={2} gap={8}>
|
||||
<DefaultPreprocessor control={control} name="preprocessor" />
|
||||
</SimpleGrid>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Button, Flex, Heading, Text } from '@invoke-ai/ui-library';
|
||||
import { Button, Flex, Heading, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useMainModelDefaultSettings } from 'features/modelManagerV2/hooks/useMainModelDefaultSettings';
|
||||
import { DefaultHeight } from 'features/modelManagerV2/subpanels/ModelPanel/MainModelDefaultSettings/DefaultHeight';
|
||||
@@ -122,40 +122,16 @@ export const MainModelDefaultSettings = () => {
|
||||
</Button>
|
||||
</Flex>
|
||||
|
||||
<Flex flexDir="column" gap={8}>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultVae control={control} name="vae" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultVaePrecision control={control} name="vaePrecision" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultScheduler control={control} name="scheduler" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultSteps control={control} name="steps" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultCfgScale control={control} name="cfgScale" />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex gap={8}>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultWidth control={control} optimalDimension={optimalDimension} />
|
||||
</Flex>
|
||||
<Flex gap={4} w="full">
|
||||
<DefaultHeight control={control} optimalDimension={optimalDimension} />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<SimpleGrid columns={2} gap={8}>
|
||||
<DefaultVae control={control} name="vae" />
|
||||
<DefaultVaePrecision control={control} name="vaePrecision" />
|
||||
<DefaultScheduler control={control} name="scheduler" />
|
||||
<DefaultSteps control={control} name="steps" />
|
||||
<DefaultCfgScale control={control} name="cfgScale" />
|
||||
<DefaultCfgRescaleMultiplier control={control} name="cfgRescaleMultiplier" />
|
||||
<DefaultWidth control={control} optimalDimension={optimalDimension} />
|
||||
<DefaultHeight control={control} optimalDimension={optimalDimension} />
|
||||
</SimpleGrid>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -6,6 +6,7 @@ import {
|
||||
FormLabel,
|
||||
Heading,
|
||||
Input,
|
||||
SimpleGrid,
|
||||
Text,
|
||||
Textarea,
|
||||
} from '@invoke-ai/ui-library';
|
||||
@@ -66,25 +67,21 @@ export const ModelEdit = ({ form }: Props) => {
|
||||
<Heading as="h3" fontSize="md" mt="4">
|
||||
{t('modelManager.modelSettings')}
|
||||
</Heading>
|
||||
<Flex gap={4}>
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<BaseModelSelect control={form.control} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||
<>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={form.control} />
|
||||
</FormControl>
|
||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||
<>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.pathToConfig')}</FormLabel>
|
||||
<Input {...form.register('config_path', stringFieldOptions)} />
|
||||
</FormControl>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={form.control} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
<Flex gap={4}>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.predictionType')}</FormLabel>
|
||||
<PredictionTypeSelect control={form.control} />
|
||||
@@ -93,9 +90,9 @@ export const ModelEdit = ({ form }: Props) => {
|
||||
<FormLabel>{t('modelManager.upcastAttention')}</FormLabel>
|
||||
<Checkbox {...form.register('upcast_attention')} />
|
||||
</FormControl>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</SimpleGrid>
|
||||
</Flex>
|
||||
</form>
|
||||
</Flex>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { Box, Flex, Text } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, SimpleGrid, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ControlNetOrT2IAdapterDefaultSettings } from 'features/modelManagerV2/subpanels/ModelPanel/ControlNetOrT2IAdapterDefaultSettings/ControlNetOrT2IAdapterDefaultSettings';
|
||||
@@ -24,57 +24,32 @@ export const ModelView = () => {
|
||||
return (
|
||||
<Flex flexDir="column" h="full" gap={4}>
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
<Flex flexDir="column" gap={4}>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={data.base} />
|
||||
<ModelAttrView label={t('modelManager.modelType')} value={data.type} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('common.format')} value={data.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={data.path} />
|
||||
</Flex>
|
||||
|
||||
<SimpleGrid columns={2} gap={4}>
|
||||
<ModelAttrView label={t('modelManager.baseModel')} value={data.base} />
|
||||
<ModelAttrView label={t('modelManager.modelType')} value={data.type} />
|
||||
<ModelAttrView label={t('common.format')} value={data.format} />
|
||||
<ModelAttrView label={t('modelManager.path')} value={data.path} />
|
||||
{data.type === 'main' && <ModelAttrView label={t('modelManager.variant')} value={data.variant} />}
|
||||
{data.type === 'main' && data.format === 'diffusers' && data.repo_variant && (
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={data.repo_variant} />
|
||||
</Flex>
|
||||
<ModelAttrView label={t('modelManager.repoVariant')} value={data.repo_variant} />
|
||||
)}
|
||||
|
||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||
<>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={data.config_path} />
|
||||
<ModelAttrView label={t('modelManager.variant')} value={data.variant} />
|
||||
</Flex>
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={data.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${data.upcast_attention}`} />
|
||||
</Flex>
|
||||
<ModelAttrView label={t('modelManager.pathToConfig')} value={data.config_path} />
|
||||
<ModelAttrView label={t('modelManager.predictionType')} value={data.prediction_type} />
|
||||
<ModelAttrView label={t('modelManager.upcastAttention')} value={`${data.upcast_attention}`} />
|
||||
</>
|
||||
)}
|
||||
|
||||
{data.type === 'ip_adapter' && data.format === 'invokeai' && (
|
||||
<Flex gap={2}>
|
||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
|
||||
</Flex>
|
||||
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
|
||||
)}
|
||||
</Flex>
|
||||
</SimpleGrid>
|
||||
</Box>
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
{data.type === 'main' && data.base !== 'sdxl-refiner' && <MainModelDefaultSettings />}
|
||||
{(data.type === 'controlnet' || data.type === 't2i_adapter') && <ControlNetOrT2IAdapterDefaultSettings />}
|
||||
{(data.type === 'main' || data.type === 'lora') && <TriggerPhrases />}
|
||||
</Box>
|
||||
{data.type === 'main' && data.base !== 'sdxl-refiner' && (
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
<MainModelDefaultSettings />
|
||||
</Box>
|
||||
)}
|
||||
{(data.type === 'controlnet' || data.type === 't2i_adapter') && (
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
<ControlNetOrT2IAdapterDefaultSettings />
|
||||
</Box>
|
||||
)}
|
||||
{(data.type === 'main' || data.type === 'lora') && (
|
||||
<Box layerStyle="second" borderRadius="base" p={4}>
|
||||
<TriggerPhrases />
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -77,9 +77,17 @@ export const TriggerPhrases = () => {
|
||||
[updateModel, selectedModelKey, triggerPhrases]
|
||||
);
|
||||
|
||||
const onTriggerPhraseAddFormSubmit = useCallback(
|
||||
(e: React.FormEvent<HTMLFormElement>) => {
|
||||
e.preventDefault();
|
||||
addTriggerPhrase();
|
||||
},
|
||||
[addTriggerPhrase]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" w="full" gap="5">
|
||||
<form>
|
||||
<form onSubmit={onTriggerPhraseAddFormSubmit}>
|
||||
<FormControl w="full" isInvalid={Boolean(errors.length)} orientation="vertical">
|
||||
<FormLabel>{t('modelManager.triggerPhrases')}</FormLabel>
|
||||
<Flex flexDir="column" w="full">
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
CANVAS_TEXT_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_IMAGE_GRAPH,
|
||||
IMAGE_TO_LATENTS,
|
||||
INPAINT_CREATE_MASK,
|
||||
INPAINT_IMAGE,
|
||||
LATENTS_TO_IMAGE,
|
||||
MAIN_MODEL_LOADER,
|
||||
@@ -145,6 +146,16 @@ export const addVAEToGraph = async (
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: isSeamlessEnabled ? SEAMLESS : isAutoVae ? modelLoaderNodeId : VAE_LOADER,
|
||||
field: 'vae',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'vae',
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
source: {
|
||||
|
||||
@@ -133,6 +133,8 @@ export const buildCanvasInpaintGraph = async (
|
||||
coherence_mode: canvasCoherenceMode,
|
||||
minimum_denoise: canvasCoherenceMinDenoise,
|
||||
edge_radius: canvasCoherenceEdgeSize,
|
||||
tiled: false,
|
||||
fp32: fp32,
|
||||
},
|
||||
[DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
@@ -182,6 +184,16 @@ export const buildCanvasInpaintGraph = async (
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
// Connect CLIP Skip to Conditioning
|
||||
{
|
||||
source: {
|
||||
@@ -331,6 +343,16 @@ export const buildCanvasInpaintGraph = async (
|
||||
field: 'mask',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Resize Down
|
||||
{
|
||||
source: {
|
||||
|
||||
@@ -157,6 +157,8 @@ export const buildCanvasOutpaintGraph = async (
|
||||
coherence_mode: canvasCoherenceMode,
|
||||
edge_radius: canvasCoherenceEdgeSize,
|
||||
minimum_denoise: canvasCoherenceMinDenoise,
|
||||
tiled: false,
|
||||
fp32: fp32,
|
||||
},
|
||||
[DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
@@ -207,6 +209,16 @@ export const buildCanvasOutpaintGraph = async (
|
||||
field: 'clip',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
// Connect CLIP Skip to Conditioning
|
||||
{
|
||||
source: {
|
||||
@@ -453,6 +465,16 @@ export const buildCanvasOutpaintGraph = async (
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Resize Results Down
|
||||
{
|
||||
source: {
|
||||
|
||||
@@ -135,6 +135,8 @@ export const buildCanvasSDXLInpaintGraph = async (
|
||||
coherence_mode: canvasCoherenceMode,
|
||||
minimum_denoise: refinerModel ? Math.max(0.2, canvasCoherenceMinDenoise) : canvasCoherenceMinDenoise,
|
||||
edge_radius: canvasCoherenceEdgeSize,
|
||||
tiled: false,
|
||||
fp32: fp32,
|
||||
},
|
||||
[SDXL_DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
@@ -214,6 +216,16 @@ export const buildCanvasSDXLInpaintGraph = async (
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
// Connect Everything To Inpaint Node
|
||||
{
|
||||
source: {
|
||||
@@ -342,6 +354,16 @@ export const buildCanvasSDXLInpaintGraph = async (
|
||||
field: 'mask',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Resize Down
|
||||
{
|
||||
source: {
|
||||
|
||||
@@ -157,6 +157,8 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
||||
coherence_mode: canvasCoherenceMode,
|
||||
edge_radius: canvasCoherenceEdgeSize,
|
||||
minimum_denoise: refinerModel ? Math.max(0.2, canvasCoherenceMinDenoise) : canvasCoherenceMinDenoise,
|
||||
tiled: false,
|
||||
fp32: fp32,
|
||||
},
|
||||
[SDXL_DENOISE_LATENTS]: {
|
||||
type: 'denoise_latents',
|
||||
@@ -237,6 +239,16 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
||||
field: 'clip2',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: modelLoaderNodeId,
|
||||
field: 'unet',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'unet',
|
||||
},
|
||||
},
|
||||
// Connect Infill Result To Inpaint Image
|
||||
{
|
||||
source: {
|
||||
@@ -451,6 +463,16 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: INPAINT_IMAGE_RESIZE_UP,
|
||||
field: 'image',
|
||||
},
|
||||
destination: {
|
||||
node_id: INPAINT_CREATE_MASK,
|
||||
field: 'image',
|
||||
},
|
||||
},
|
||||
// Take combined mask and resize
|
||||
{
|
||||
source: {
|
||||
|
||||
@@ -2,7 +2,7 @@ import { Flex } from '@invoke-ai/ui-library';
|
||||
import { StageComponent } from 'features/regionalPrompts/components/StageComponent';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const AspectRatioPreview = memo(() => {
|
||||
export const AspectRatioCanvasPreview = memo(() => {
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center" position="relative">
|
||||
<StageComponent asPreview />
|
||||
@@ -10,4 +10,4 @@ export const AspectRatioPreview = memo(() => {
|
||||
);
|
||||
});
|
||||
|
||||
AspectRatioPreview.displayName = 'AspectRatioPreview';
|
||||
AspectRatioCanvasPreview.displayName = 'AspectRatioCanvasPreview';
|
||||
@@ -0,0 +1,75 @@
|
||||
import { useSize } from '@chakra-ui/react-use-size';
|
||||
import { Flex, Icon } from '@invoke-ai/ui-library';
|
||||
import { useImageSizeContext } from 'features/parameters/components/ImageSize/ImageSizeContext';
|
||||
import { AnimatePresence, motion } from 'framer-motion';
|
||||
import { memo, useMemo, useRef } from 'react';
|
||||
import { PiFrameCorners } from 'react-icons/pi';
|
||||
|
||||
import {
|
||||
BOX_SIZE_CSS_CALC,
|
||||
ICON_CONTAINER_STYLES,
|
||||
ICON_HIGH_CUTOFF,
|
||||
ICON_LOW_CUTOFF,
|
||||
MOTION_ICON_ANIMATE,
|
||||
MOTION_ICON_EXIT,
|
||||
MOTION_ICON_INITIAL,
|
||||
} from './constants';
|
||||
|
||||
export const AspectRatioIconPreview = memo(() => {
|
||||
const ctx = useImageSizeContext();
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const containerSize = useSize(containerRef);
|
||||
|
||||
const shouldShowIcon = useMemo(
|
||||
() => ctx.aspectRatioState.value < ICON_HIGH_CUTOFF && ctx.aspectRatioState.value > ICON_LOW_CUTOFF,
|
||||
[ctx.aspectRatioState.value]
|
||||
);
|
||||
|
||||
const { width, height } = useMemo(() => {
|
||||
if (!containerSize) {
|
||||
return { width: 0, height: 0 };
|
||||
}
|
||||
|
||||
let width = ctx.width;
|
||||
let height = ctx.height;
|
||||
|
||||
if (ctx.width > ctx.height) {
|
||||
width = containerSize.width;
|
||||
height = width / ctx.aspectRatioState.value;
|
||||
} else {
|
||||
height = containerSize.height;
|
||||
width = height * ctx.aspectRatioState.value;
|
||||
}
|
||||
|
||||
return { width, height };
|
||||
}, [containerSize, ctx.width, ctx.height, ctx.aspectRatioState.value]);
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center" ref={containerRef}>
|
||||
<Flex
|
||||
bg="blackAlpha.400"
|
||||
borderRadius="base"
|
||||
width={`${width}px`}
|
||||
height={`${height}px`}
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
>
|
||||
<AnimatePresence>
|
||||
{shouldShowIcon && (
|
||||
<Flex
|
||||
as={motion.div}
|
||||
initial={MOTION_ICON_INITIAL}
|
||||
animate={MOTION_ICON_ANIMATE}
|
||||
exit={MOTION_ICON_EXIT}
|
||||
style={ICON_CONTAINER_STYLES}
|
||||
>
|
||||
<Icon as={PiFrameCorners} color="base.700" boxSize={BOX_SIZE_CSS_CALC} />
|
||||
</Flex>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
AspectRatioIconPreview.displayName = 'AspectRatioIconPreview';
|
||||
@@ -1,6 +1,5 @@
|
||||
import type { FormLabelProps } from '@invoke-ai/ui-library';
|
||||
import { Flex, FormControlGroup } from '@invoke-ai/ui-library';
|
||||
import { AspectRatioPreview } from 'features/parameters/components/ImageSize/AspectRatioPreview';
|
||||
import { AspectRatioSelect } from 'features/parameters/components/ImageSize/AspectRatioSelect';
|
||||
import type { ImageSizeContextInnerValue } from 'features/parameters/components/ImageSize/ImageSizeContext';
|
||||
import { ImageSizeContext } from 'features/parameters/components/ImageSize/ImageSizeContext';
|
||||
@@ -13,10 +12,11 @@ import { memo } from 'react';
|
||||
type ImageSizeProps = ImageSizeContextInnerValue & {
|
||||
widthComponent: ReactNode;
|
||||
heightComponent: ReactNode;
|
||||
previewComponent: ReactNode;
|
||||
};
|
||||
|
||||
export const ImageSize = memo((props: ImageSizeProps) => {
|
||||
const { widthComponent, heightComponent, ...ctx } = props;
|
||||
const { widthComponent, heightComponent, previewComponent, ...ctx } = props;
|
||||
return (
|
||||
<ImageSizeContext.Provider value={ctx}>
|
||||
<Flex gap={4} alignItems="center">
|
||||
@@ -33,7 +33,7 @@ export const ImageSize = memo((props: ImageSizeProps) => {
|
||||
</FormControlGroup>
|
||||
</Flex>
|
||||
<Flex w="108px" h="108px" flexShrink={0} flexGrow={0}>
|
||||
<AspectRatioPreview />
|
||||
{previewComponent}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</ImageSizeContext.Provider>
|
||||
|
||||
@@ -1,7 +1,29 @@
|
||||
import type { ComboboxOption } from '@invoke-ai/ui-library';
|
||||
|
||||
import type { AspectRatioID, AspectRatioState } from './types';
|
||||
|
||||
// When the aspect ratio is between these two values, we show the icon (experimentally determined)
|
||||
export const ICON_LOW_CUTOFF = 0.23;
|
||||
export const ICON_HIGH_CUTOFF = 1 / ICON_LOW_CUTOFF;
|
||||
const ICON_SIZE_PX = 64;
|
||||
const ICON_PADDING_PX = 16;
|
||||
export const BOX_SIZE_CSS_CALC = `min(${ICON_SIZE_PX}px, calc(100% - ${ICON_PADDING_PX}px))`;
|
||||
export const MOTION_ICON_INITIAL = {
|
||||
opacity: 0,
|
||||
};
|
||||
export const MOTION_ICON_ANIMATE = {
|
||||
opacity: 1,
|
||||
transition: { duration: 0.1 },
|
||||
};
|
||||
export const MOTION_ICON_EXIT = {
|
||||
opacity: 0,
|
||||
transition: { duration: 0.1 },
|
||||
};
|
||||
export const ICON_CONTAINER_STYLES = {
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
};
|
||||
export const ASPECT_RATIO_OPTIONS: ComboboxOption[] = [
|
||||
{ label: 'Free' as const, value: 'Free' },
|
||||
{ label: '16:9' as const, value: '16:9' },
|
||||
|
||||
@@ -75,7 +75,7 @@ export const RPLayerListItem = memo(({ layerId }: Props) => {
|
||||
<RPLayerSettingsPopover layerId={layerId} />
|
||||
<RPLayerMenu layerId={layerId} />
|
||||
</Flex>
|
||||
<AddPromptButtons layerId={layerId} />
|
||||
{!hasPositivePrompt && !hasNegativePrompt && !hasIPAdapters && <AddPromptButtons layerId={layerId} />}
|
||||
{hasPositivePrompt && <RPLayerPositivePrompt layerId={layerId} />}
|
||||
{hasNegativePrompt && <RPLayerNegativePrompt layerId={layerId} />}
|
||||
{hasIPAdapters && <RPLayerIPAdapterList layerId={layerId} />}
|
||||
|
||||
@@ -6,6 +6,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useMouseEvents } from 'features/regionalPrompts/hooks/mouseEventHooks';
|
||||
import {
|
||||
$cursorPosition,
|
||||
$isMouseOver,
|
||||
$lastMouseDownPos,
|
||||
$tool,
|
||||
isVectorMaskLayer,
|
||||
@@ -14,7 +15,7 @@ import {
|
||||
layerTranslated,
|
||||
selectRegionalPromptsSlice,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import { renderers } from 'features/regionalPrompts/util/renderers';
|
||||
import { debouncedRenderers, renderers as normalRenderers } from 'features/regionalPrompts/util/renderers';
|
||||
import Konva from 'konva';
|
||||
import type { IRect } from 'konva/lib/types';
|
||||
import type { MutableRefObject } from 'react';
|
||||
@@ -49,18 +50,10 @@ const useStageRenderer = (
|
||||
const { onMouseDown, onMouseUp, onMouseMove, onMouseEnter, onMouseLeave, onMouseWheel } = useMouseEvents();
|
||||
const cursorPosition = useStore($cursorPosition);
|
||||
const lastMouseDownPos = useStore($lastMouseDownPos);
|
||||
const isMouseOver = useStore($isMouseOver);
|
||||
const selectedLayerIdColor = useAppSelector(selectSelectedLayerColor);
|
||||
|
||||
const renderLayers = useMemo(() => (asPreview ? renderers.layersDebounced : renderers.layers), [asPreview]);
|
||||
const renderToolPreview = useMemo(
|
||||
() => (asPreview ? renderers.toolPreviewDebounced : renderers.toolPreview),
|
||||
[asPreview]
|
||||
);
|
||||
const renderBbox = useMemo(() => (asPreview ? renderers.bboxDebounced : renderers.bbox), [asPreview]);
|
||||
const renderBackground = useMemo(
|
||||
() => (asPreview ? renderers.backgroundDebounced : renderers.background),
|
||||
[asPreview]
|
||||
);
|
||||
const layerIds = useMemo(() => state.layers.map((l) => l.id), [state.layers]);
|
||||
const renderers = useMemo(() => (asPreview ? debouncedRenderers : normalRenderers), [asPreview]);
|
||||
|
||||
const onLayerPosChanged = useCallback(
|
||||
(layerId: string, x: number, y: number) => {
|
||||
@@ -147,17 +140,19 @@ const useStageRenderer = (
|
||||
}, [stageRef, width, height, wrapper]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Rendering brush preview');
|
||||
log.trace('Rendering tool preview');
|
||||
if (asPreview) {
|
||||
// Preview should not display tool
|
||||
return;
|
||||
}
|
||||
renderToolPreview(
|
||||
renderers.renderToolPreview(
|
||||
stageRef.current,
|
||||
tool,
|
||||
selectedLayerIdColor,
|
||||
state.globalMaskLayerOpacity,
|
||||
cursorPosition,
|
||||
lastMouseDownPos,
|
||||
isMouseOver,
|
||||
state.brushSize
|
||||
);
|
||||
}, [
|
||||
@@ -168,30 +163,38 @@ const useStageRenderer = (
|
||||
state.globalMaskLayerOpacity,
|
||||
cursorPosition,
|
||||
lastMouseDownPos,
|
||||
isMouseOver,
|
||||
state.brushSize,
|
||||
renderToolPreview,
|
||||
renderers,
|
||||
]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Rendering layers');
|
||||
renderLayers(stageRef.current, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged);
|
||||
}, [stageRef, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged, renderLayers]);
|
||||
renderers.renderLayers(stageRef.current, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged);
|
||||
}, [stageRef, state.layers, state.globalMaskLayerOpacity, tool, onLayerPosChanged, renderers]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Rendering bbox');
|
||||
if (asPreview) {
|
||||
// Preview should not display bboxes
|
||||
return;
|
||||
}
|
||||
renderBbox(stageRef.current, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown);
|
||||
}, [stageRef, asPreview, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown, renderBbox]);
|
||||
renderers.renderBbox(stageRef.current, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown);
|
||||
}, [stageRef, asPreview, state.layers, state.selectedLayerId, tool, onBboxChanged, onBboxMouseDown, renderers]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Rendering background');
|
||||
if (asPreview) {
|
||||
// The preview should not have a background
|
||||
return;
|
||||
}
|
||||
renderBackground(stageRef.current, width, height);
|
||||
}, [stageRef, asPreview, width, height, renderBackground]);
|
||||
renderers.renderBackground(stageRef.current, width, height);
|
||||
}, [stageRef, asPreview, width, height, renderers]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
log.trace('Arranging layers');
|
||||
renderers.arrangeLayers(stageRef.current, layerIds);
|
||||
}, [stageRef, layerIds, renderers]);
|
||||
};
|
||||
|
||||
type Props = {
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import type Konva from 'konva';
|
||||
import type { KonvaEventObject } from 'konva/lib/Node';
|
||||
import type { Vector2d } from 'konva/lib/types';
|
||||
import { useCallback, useRef } from 'react';
|
||||
|
||||
const getIsFocused = (stage: Konva.Stage) => {
|
||||
@@ -23,21 +24,28 @@ const getIsFocused = (stage: Konva.Stage) => {
|
||||
|
||||
export const getScaledFlooredCursorPosition = (stage: Konva.Stage) => {
|
||||
const pointerPosition = stage.getPointerPosition();
|
||||
|
||||
const stageTransform = stage.getAbsoluteTransform().copy();
|
||||
|
||||
if (!pointerPosition || !stageTransform) {
|
||||
return;
|
||||
}
|
||||
|
||||
const scaledCursorPosition = stageTransform.invert().point(pointerPosition);
|
||||
|
||||
return {
|
||||
x: Math.floor(scaledCursorPosition.x),
|
||||
y: Math.floor(scaledCursorPosition.y),
|
||||
};
|
||||
};
|
||||
|
||||
const syncCursorPos = (stage: Konva.Stage): Vector2d | null => {
|
||||
const pos = getScaledFlooredCursorPosition(stage);
|
||||
if (!pos) {
|
||||
return null;
|
||||
}
|
||||
$cursorPosition.set(pos);
|
||||
return pos;
|
||||
};
|
||||
|
||||
const BRUSH_SPACING = 20;
|
||||
|
||||
export const useMouseEvents = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selectedLayerId = useAppSelector((s) => s.regionalPrompts.present.selectedLayerId);
|
||||
@@ -52,7 +60,7 @@ export const useMouseEvents = () => {
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
const pos = $cursorPosition.get();
|
||||
const pos = syncCursorPos(stage);
|
||||
if (!pos) {
|
||||
return;
|
||||
}
|
||||
@@ -61,12 +69,11 @@ export const useMouseEvents = () => {
|
||||
if (!selectedLayerId) {
|
||||
return;
|
||||
}
|
||||
// const tool = getTool();
|
||||
if (tool === 'brush' || tool === 'eraser') {
|
||||
dispatch(
|
||||
maskLayerLineAdded({
|
||||
layerId: selectedLayerId,
|
||||
points: [Math.floor(pos.x), Math.floor(pos.y), Math.floor(pos.x), Math.floor(pos.y)],
|
||||
points: [pos.x, pos.y, pos.x, pos.y],
|
||||
tool,
|
||||
})
|
||||
);
|
||||
@@ -109,33 +116,47 @@ export const useMouseEvents = () => {
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
const pos = getScaledFlooredCursorPosition(stage);
|
||||
const pos = syncCursorPos(stage);
|
||||
if (!pos || !selectedLayerId) {
|
||||
return;
|
||||
}
|
||||
$cursorPosition.set(pos);
|
||||
if (getIsFocused(stage) && $isMouseOver.get() && $isMouseDown.get() && (tool === 'brush' || tool === 'eraser')) {
|
||||
if (lastCursorPosRef.current) {
|
||||
if (Math.hypot(lastCursorPosRef.current[0] - pos.x, lastCursorPosRef.current[1] - pos.y) < 20) {
|
||||
// Dispatching redux events impacts perf substantially - using brush spacing keeps dispatches to a reasonable number
|
||||
if (Math.hypot(lastCursorPosRef.current[0] - pos.x, lastCursorPosRef.current[1] - pos.y) < BRUSH_SPACING) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
lastCursorPosRef.current = [Math.floor(pos.x), Math.floor(pos.y)];
|
||||
lastCursorPosRef.current = [pos.x, pos.y];
|
||||
dispatch(maskLayerPointsAdded({ layerId: selectedLayerId, point: lastCursorPosRef.current }));
|
||||
}
|
||||
},
|
||||
[dispatch, selectedLayerId, tool]
|
||||
);
|
||||
|
||||
const onMouseLeave = useCallback((e: KonvaEventObject<MouseEvent | TouchEvent>) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
$isMouseOver.set(false);
|
||||
$isMouseDown.set(false);
|
||||
$cursorPosition.set(null);
|
||||
}, []);
|
||||
const onMouseLeave = useCallback(
|
||||
(e: KonvaEventObject<MouseEvent | TouchEvent>) => {
|
||||
const stage = e.target.getStage();
|
||||
if (!stage) {
|
||||
return;
|
||||
}
|
||||
const pos = syncCursorPos(stage);
|
||||
if (
|
||||
pos &&
|
||||
selectedLayerId &&
|
||||
getIsFocused(stage) &&
|
||||
$isMouseOver.get() &&
|
||||
$isMouseDown.get() &&
|
||||
(tool === 'brush' || tool === 'eraser')
|
||||
) {
|
||||
dispatch(maskLayerPointsAdded({ layerId: selectedLayerId, point: [pos.x, pos.y] }));
|
||||
}
|
||||
$isMouseOver.set(false);
|
||||
$isMouseDown.set(false);
|
||||
$cursorPosition.set(null);
|
||||
},
|
||||
[selectedLayerId, tool, dispatch]
|
||||
);
|
||||
|
||||
const onMouseEnter = useCallback(
|
||||
(e: KonvaEventObject<MouseEvent>) => {
|
||||
@@ -144,7 +165,7 @@ export const useMouseEvents = () => {
|
||||
return;
|
||||
}
|
||||
$isMouseOver.set(true);
|
||||
const pos = $cursorPosition.get();
|
||||
const pos = syncCursorPos(stage);
|
||||
if (!pos) {
|
||||
return;
|
||||
}
|
||||
@@ -162,7 +183,7 @@ export const useMouseEvents = () => {
|
||||
dispatch(
|
||||
maskLayerLineAdded({
|
||||
layerId: selectedLayerId,
|
||||
points: [Math.floor(pos.x), Math.floor(pos.y), Math.floor(pos.x), Math.floor(pos.y)],
|
||||
points: [pos.x, pos.y, pos.x, pos.y],
|
||||
tool,
|
||||
})
|
||||
);
|
||||
|
||||
@@ -16,7 +16,7 @@ type DrawingTool = 'brush' | 'eraser';
|
||||
|
||||
export type Tool = DrawingTool | 'move' | 'rect';
|
||||
|
||||
type VectorMaskLine = {
|
||||
export type VectorMaskLine = {
|
||||
id: string;
|
||||
type: 'vector_mask_line';
|
||||
tool: DrawingTool;
|
||||
@@ -24,7 +24,7 @@ type VectorMaskLine = {
|
||||
points: number[];
|
||||
};
|
||||
|
||||
type VectorMaskRect = {
|
||||
export type VectorMaskRect = {
|
||||
id: string;
|
||||
type: 'vector_mask_rect';
|
||||
x: number;
|
||||
@@ -109,7 +109,7 @@ export const regionalPromptsSlice = createSlice({
|
||||
y: 0,
|
||||
autoNegative: 'invert',
|
||||
needsPixelBbox: false,
|
||||
positivePrompt: null,
|
||||
positivePrompt: '',
|
||||
negativePrompt: null,
|
||||
ipAdapterIds: [],
|
||||
};
|
||||
|
||||
@@ -20,7 +20,7 @@ export const getRegionalPromptLayerBlobs = async (
|
||||
const reduxLayers = state.regionalPrompts.present.layers;
|
||||
const container = document.createElement('div');
|
||||
const stage = new Konva.Stage({ container, width: state.generation.width, height: state.generation.height });
|
||||
renderers.layers(stage, reduxLayers, 1, 'brush');
|
||||
renderers.renderLayers(stage, reduxLayers, 1, 'brush');
|
||||
|
||||
const konvaLayers = stage.find<Konva.Layer>(`.${VECTOR_MASK_LAYER_NAME}`);
|
||||
const blobs: Record<string, Blob> = {};
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import { rgbaColorToString, rgbColorToString } from 'features/canvas/util/colorToString';
|
||||
import { getScaledFlooredCursorPosition } from 'features/regionalPrompts/hooks/mouseEventHooks';
|
||||
import type { Layer, Tool, VectorMaskLayer } from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import type {
|
||||
Layer,
|
||||
Tool,
|
||||
VectorMaskLayer,
|
||||
VectorMaskLine,
|
||||
VectorMaskRect,
|
||||
} from 'features/regionalPrompts/store/regionalPromptsSlice';
|
||||
import {
|
||||
$isMouseOver,
|
||||
$tool,
|
||||
BACKGROUND_LAYER_ID,
|
||||
BACKGROUND_RECT_ID,
|
||||
@@ -35,6 +40,7 @@ const BBOX_NOT_SELECTED_STROKE = 'rgba(255, 255, 255, 0.353)';
|
||||
const BBOX_NOT_SELECTED_MOUSEOVER_STROKE = 'rgba(255, 255, 255, 0.661)';
|
||||
const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)';
|
||||
const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)';
|
||||
// This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
|
||||
const STAGE_BG_DATAURL =
|
||||
'';
|
||||
|
||||
@@ -51,6 +57,68 @@ const selectVectorMaskObjects = (node: Konva.Node) => {
|
||||
return node.name() === VECTOR_MASK_LAYER_LINE_NAME || node.name() === VECTOR_MASK_LAYER_RECT_NAME;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates the brush preview layer.
|
||||
* @param stage The konva stage to render on.
|
||||
* @returns The brush preview layer.
|
||||
*/
|
||||
const createToolPreviewLayer = (stage: Konva.Stage) => {
|
||||
// Initialize the brush preview layer & add to the stage
|
||||
const toolPreviewLayer = new Konva.Layer({ id: TOOL_PREVIEW_LAYER_ID, visible: false, listening: false });
|
||||
stage.add(toolPreviewLayer);
|
||||
|
||||
// Add handlers to show/hide the brush preview layer
|
||||
stage.on('mousemove', (e) => {
|
||||
const tool = $tool.get();
|
||||
e.target
|
||||
.getStage()
|
||||
?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)
|
||||
?.visible(tool === 'brush' || tool === 'eraser');
|
||||
});
|
||||
stage.on('mouseleave', (e) => {
|
||||
e.target.getStage()?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(false);
|
||||
});
|
||||
stage.on('mouseenter', (e) => {
|
||||
const tool = $tool.get();
|
||||
e.target
|
||||
.getStage()
|
||||
?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)
|
||||
?.visible(tool === 'brush' || tool === 'eraser');
|
||||
});
|
||||
|
||||
// Create the brush preview group & circles
|
||||
const brushPreviewGroup = new Konva.Group({ id: TOOL_PREVIEW_BRUSH_GROUP_ID });
|
||||
const brushPreviewFill = new Konva.Circle({
|
||||
id: TOOL_PREVIEW_BRUSH_FILL_ID,
|
||||
listening: false,
|
||||
strokeEnabled: false,
|
||||
});
|
||||
brushPreviewGroup.add(brushPreviewFill);
|
||||
const brushPreviewBorderInner = new Konva.Circle({
|
||||
id: TOOL_PREVIEW_BRUSH_BORDER_INNER_ID,
|
||||
listening: false,
|
||||
stroke: BRUSH_BORDER_INNER_COLOR,
|
||||
strokeWidth: 1,
|
||||
strokeEnabled: true,
|
||||
});
|
||||
brushPreviewGroup.add(brushPreviewBorderInner);
|
||||
const brushPreviewBorderOuter = new Konva.Circle({
|
||||
id: TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID,
|
||||
listening: false,
|
||||
stroke: BRUSH_BORDER_OUTER_COLOR,
|
||||
strokeWidth: 1,
|
||||
strokeEnabled: true,
|
||||
});
|
||||
brushPreviewGroup.add(brushPreviewBorderOuter);
|
||||
toolPreviewLayer.add(brushPreviewGroup);
|
||||
|
||||
// Create the rect preview
|
||||
const rectPreview = new Konva.Rect({ id: TOOL_PREVIEW_RECT_ID, listening: false, stroke: 'white', strokeWidth: 1 });
|
||||
toolPreviewLayer.add(rectPreview);
|
||||
|
||||
return toolPreviewLayer;
|
||||
};
|
||||
|
||||
/**
|
||||
* Renders the brush preview for the selected tool.
|
||||
* @param stage The konva stage to render on.
|
||||
@@ -60,13 +128,14 @@ const selectVectorMaskObjects = (node: Konva.Node) => {
|
||||
* @param lastMouseDownPos The position of the last mouse down event - used for the rect tool.
|
||||
* @param brushSize The brush size.
|
||||
*/
|
||||
const toolPreview = (
|
||||
const renderToolPreview = (
|
||||
stage: Konva.Stage,
|
||||
tool: Tool,
|
||||
color: RgbColor | null,
|
||||
globalMaskLayerOpacity: number,
|
||||
cursorPos: Vector2d | null,
|
||||
lastMouseDownPos: Vector2d | null,
|
||||
isMouseOver: boolean,
|
||||
brushSize: number
|
||||
) => {
|
||||
const layerCount = stage.find(`.${VECTOR_MASK_LAYER_NAME}`).length;
|
||||
@@ -85,65 +154,9 @@ const toolPreview = (
|
||||
stage.container().style.cursor = 'none';
|
||||
}
|
||||
|
||||
let toolPreviewLayer = stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`);
|
||||
const toolPreviewLayer = stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`) ?? createToolPreviewLayer(stage);
|
||||
|
||||
// Create the layer if it doesn't exist
|
||||
if (!toolPreviewLayer) {
|
||||
// Initialize the brush preview layer & add to the stage
|
||||
toolPreviewLayer = new Konva.Layer({ id: TOOL_PREVIEW_LAYER_ID, visible: tool !== 'move', listening: false });
|
||||
stage.add(toolPreviewLayer);
|
||||
|
||||
// Add handlers to show/hide the brush preview layer
|
||||
stage.on('mousemove', (e) => {
|
||||
const tool = $tool.get();
|
||||
e.target
|
||||
.getStage()
|
||||
?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)
|
||||
?.visible(tool === 'brush' || tool === 'eraser');
|
||||
});
|
||||
stage.on('mouseleave', (e) => {
|
||||
e.target.getStage()?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.visible(false);
|
||||
});
|
||||
stage.on('mouseenter', (e) => {
|
||||
const tool = $tool.get();
|
||||
e.target
|
||||
.getStage()
|
||||
?.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)
|
||||
?.visible(tool === 'brush' || tool === 'eraser');
|
||||
});
|
||||
|
||||
// Create the brush preview group & circles
|
||||
const brushPreviewGroup = new Konva.Group({ id: TOOL_PREVIEW_BRUSH_GROUP_ID });
|
||||
const brushPreviewFill = new Konva.Circle({
|
||||
id: TOOL_PREVIEW_BRUSH_FILL_ID,
|
||||
listening: false,
|
||||
strokeEnabled: false,
|
||||
});
|
||||
brushPreviewGroup.add(brushPreviewFill);
|
||||
const brushPreviewBorderInner = new Konva.Circle({
|
||||
id: TOOL_PREVIEW_BRUSH_BORDER_INNER_ID,
|
||||
listening: false,
|
||||
stroke: BRUSH_BORDER_INNER_COLOR,
|
||||
strokeWidth: 1,
|
||||
strokeEnabled: true,
|
||||
});
|
||||
brushPreviewGroup.add(brushPreviewBorderInner);
|
||||
const brushPreviewBorderOuter = new Konva.Circle({
|
||||
id: TOOL_PREVIEW_BRUSH_BORDER_OUTER_ID,
|
||||
listening: false,
|
||||
stroke: BRUSH_BORDER_OUTER_COLOR,
|
||||
strokeWidth: 1,
|
||||
strokeEnabled: true,
|
||||
});
|
||||
brushPreviewGroup.add(brushPreviewBorderOuter);
|
||||
toolPreviewLayer.add(brushPreviewGroup);
|
||||
|
||||
// Create the rect preview
|
||||
const rectPreview = new Konva.Rect({ id: TOOL_PREVIEW_RECT_ID, listening: false, stroke: 'white', strokeWidth: 1 });
|
||||
toolPreviewLayer.add(rectPreview);
|
||||
}
|
||||
|
||||
if (!$isMouseOver.get() || layerCount === 0) {
|
||||
if (!isMouseOver || layerCount === 0) {
|
||||
// We can bail early if the mouse isn't over the stage or there are no layers
|
||||
toolPreviewLayer.visible(false);
|
||||
return;
|
||||
@@ -200,85 +213,140 @@ const toolPreview = (
|
||||
}
|
||||
};
|
||||
|
||||
const vectorMaskLayer = (
|
||||
/**
|
||||
* Creates a vector mask layer.
|
||||
* @param stage The konva stage to attach the layer to.
|
||||
* @param reduxLayer The redux layer to create the konva layer from.
|
||||
* @param onLayerPosChanged Callback for when the layer's position changes.
|
||||
*/
|
||||
const createVectorMaskLayer = (
|
||||
stage: Konva.Stage,
|
||||
vmLayer: VectorMaskLayer,
|
||||
vmLayerIndex: number,
|
||||
reduxLayer: VectorMaskLayer,
|
||||
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
|
||||
) => {
|
||||
// This layer hasn't been added to the konva state yet
|
||||
const konvaLayer = new Konva.Layer({
|
||||
id: reduxLayer.id,
|
||||
name: VECTOR_MASK_LAYER_NAME,
|
||||
draggable: true,
|
||||
dragDistance: 0,
|
||||
});
|
||||
|
||||
// Create a `dragmove` listener for this layer
|
||||
if (onLayerPosChanged) {
|
||||
konvaLayer.on('dragend', function (e) {
|
||||
onLayerPosChanged(reduxLayer.id, Math.floor(e.target.x()), Math.floor(e.target.y()));
|
||||
});
|
||||
}
|
||||
|
||||
// The dragBoundFunc limits how far the layer can be dragged
|
||||
konvaLayer.dragBoundFunc(function (pos) {
|
||||
const cursorPos = getScaledFlooredCursorPosition(stage);
|
||||
if (!cursorPos) {
|
||||
return this.getAbsolutePosition();
|
||||
}
|
||||
// Prevent the user from dragging the layer out of the stage bounds.
|
||||
if (
|
||||
cursorPos.x < 0 ||
|
||||
cursorPos.x > stage.width() / stage.scaleX() ||
|
||||
cursorPos.y < 0 ||
|
||||
cursorPos.y > stage.height() / stage.scaleY()
|
||||
) {
|
||||
return this.getAbsolutePosition();
|
||||
}
|
||||
return pos;
|
||||
});
|
||||
|
||||
// The object group holds all of the layer's objects (e.g. lines and rects)
|
||||
const konvaObjectGroup = new Konva.Group({
|
||||
id: getVectorMaskLayerObjectGroupId(reduxLayer.id, uuidv4()),
|
||||
name: VECTOR_MASK_LAYER_OBJECT_GROUP_NAME,
|
||||
listening: false,
|
||||
});
|
||||
konvaLayer.add(konvaObjectGroup);
|
||||
|
||||
stage.add(konvaLayer);
|
||||
|
||||
return konvaLayer;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a konva line from a redux vector mask line.
|
||||
* @param reduxObject The redux object to create the konva line from.
|
||||
* @param konvaGroup The konva group to add the line to.
|
||||
*/
|
||||
const createVectorMaskLine = (reduxObject: VectorMaskLine, konvaGroup: Konva.Group): Konva.Line => {
|
||||
const vectorMaskLine = new Konva.Line({
|
||||
id: reduxObject.id,
|
||||
key: reduxObject.id,
|
||||
name: VECTOR_MASK_LAYER_LINE_NAME,
|
||||
strokeWidth: reduxObject.strokeWidth,
|
||||
tension: 0,
|
||||
lineCap: 'round',
|
||||
lineJoin: 'round',
|
||||
shadowForStrokeEnabled: false,
|
||||
globalCompositeOperation: reduxObject.tool === 'brush' ? 'source-over' : 'destination-out',
|
||||
listening: false,
|
||||
});
|
||||
konvaGroup.add(vectorMaskLine);
|
||||
return vectorMaskLine;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates a konva rect from a redux vector mask rect.
|
||||
* @param reduxObject The redux object to create the konva rect from.
|
||||
* @param konvaGroup The konva group to add the rect to.
|
||||
*/
|
||||
const createVectorMaskRect = (reduxObject: VectorMaskRect, konvaGroup: Konva.Group): Konva.Rect => {
|
||||
const vectorMaskRect = new Konva.Rect({
|
||||
id: reduxObject.id,
|
||||
key: reduxObject.id,
|
||||
name: VECTOR_MASK_LAYER_RECT_NAME,
|
||||
x: reduxObject.x,
|
||||
y: reduxObject.y,
|
||||
width: reduxObject.width,
|
||||
height: reduxObject.height,
|
||||
listening: false,
|
||||
});
|
||||
konvaGroup.add(vectorMaskRect);
|
||||
return vectorMaskRect;
|
||||
};
|
||||
|
||||
/**
|
||||
* Renders a vector mask layer.
|
||||
* @param stage The konva stage to render on.
|
||||
* @param reduxLayer The redux vector mask layer to render.
|
||||
* @param reduxLayerIndex The index of the layer in the redux store.
|
||||
* @param globalMaskLayerOpacity The opacity of the global mask layer.
|
||||
* @param tool The current tool.
|
||||
*/
|
||||
const renderVectorMaskLayer = (
|
||||
stage: Konva.Stage,
|
||||
reduxLayer: VectorMaskLayer,
|
||||
globalMaskLayerOpacity: number,
|
||||
tool: Tool,
|
||||
onLayerPosChanged?: (layerId: string, x: number, y: number) => void
|
||||
) => {
|
||||
let konvaLayer = stage.findOne<Konva.Layer>(`#${vmLayer.id}`);
|
||||
|
||||
if (!konvaLayer) {
|
||||
// This layer hasn't been added to the konva state yet
|
||||
konvaLayer = new Konva.Layer({
|
||||
id: vmLayer.id,
|
||||
name: VECTOR_MASK_LAYER_NAME,
|
||||
draggable: true,
|
||||
dragDistance: 0,
|
||||
});
|
||||
|
||||
// Create a `dragmove` listener for this layer
|
||||
if (onLayerPosChanged) {
|
||||
konvaLayer.on('dragend', function (e) {
|
||||
onLayerPosChanged(vmLayer.id, Math.floor(e.target.x()), Math.floor(e.target.y()));
|
||||
});
|
||||
}
|
||||
|
||||
// The dragBoundFunc limits how far the layer can be dragged
|
||||
konvaLayer.dragBoundFunc(function (pos) {
|
||||
const cursorPos = getScaledFlooredCursorPosition(stage);
|
||||
if (!cursorPos) {
|
||||
return this.getAbsolutePosition();
|
||||
}
|
||||
// Prevent the user from dragging the layer out of the stage bounds.
|
||||
if (
|
||||
cursorPos.x < 0 ||
|
||||
cursorPos.x > stage.width() / stage.scaleX() ||
|
||||
cursorPos.y < 0 ||
|
||||
cursorPos.y > stage.height() / stage.scaleY()
|
||||
) {
|
||||
return this.getAbsolutePosition();
|
||||
}
|
||||
return pos;
|
||||
});
|
||||
|
||||
// The object group holds all of the layer's objects (e.g. lines and rects)
|
||||
const konvaObjectGroup = new Konva.Group({
|
||||
id: getVectorMaskLayerObjectGroupId(vmLayer.id, uuidv4()),
|
||||
name: VECTOR_MASK_LAYER_OBJECT_GROUP_NAME,
|
||||
listening: false,
|
||||
});
|
||||
konvaLayer.add(konvaObjectGroup);
|
||||
|
||||
stage.add(konvaLayer);
|
||||
|
||||
// When a layer is added, it ends up on top of the brush preview - we need to move the preview back to the top.
|
||||
stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.moveToTop();
|
||||
}
|
||||
): void => {
|
||||
const konvaLayer =
|
||||
stage.findOne<Konva.Layer>(`#${reduxLayer.id}`) ?? createVectorMaskLayer(stage, reduxLayer, onLayerPosChanged);
|
||||
|
||||
// Update the layer's position and listening state
|
||||
konvaLayer.setAttrs({
|
||||
listening: tool === 'move', // The layer only listens when using the move tool - otherwise the stage is handling mouse events
|
||||
x: Math.floor(vmLayer.x),
|
||||
y: Math.floor(vmLayer.y),
|
||||
// We have a konva layer for each redux layer, plus a brush preview layer, which should always be on top. We can
|
||||
// therefore use the index of the redux layer as the zIndex for konva layers. If more layers are added to the
|
||||
// stage, this may no longer be work.
|
||||
zIndex: vmLayerIndex,
|
||||
x: Math.floor(reduxLayer.x),
|
||||
y: Math.floor(reduxLayer.y),
|
||||
});
|
||||
|
||||
// Convert the color to a string, stripping the alpha - the object group will handle opacity.
|
||||
const rgbColor = rgbColorToString(vmLayer.previewColor);
|
||||
const rgbColor = rgbColorToString(reduxLayer.previewColor);
|
||||
|
||||
const konvaObjectGroup = konvaLayer.findOne<Konva.Group>(`.${VECTOR_MASK_LAYER_OBJECT_GROUP_NAME}`);
|
||||
assert(konvaObjectGroup, `Object group not found for layer ${vmLayer.id}`);
|
||||
assert(konvaObjectGroup, `Object group not found for layer ${reduxLayer.id}`);
|
||||
|
||||
// We use caching to handle "global" layer opacity, but caching is expensive and we should only do it when required.
|
||||
let groupNeedsCache = false;
|
||||
|
||||
const objectIds = vmLayer.objects.map(mapId);
|
||||
const objectIds = reduxLayer.objects.map(mapId);
|
||||
for (const objectNode of konvaObjectGroup.find(selectVectorMaskObjects)) {
|
||||
if (!objectIds.includes(objectNode.id())) {
|
||||
objectNode.destroy();
|
||||
@@ -286,26 +354,10 @@ const vectorMaskLayer = (
|
||||
}
|
||||
}
|
||||
|
||||
for (const reduxObject of vmLayer.objects) {
|
||||
for (const reduxObject of reduxLayer.objects) {
|
||||
if (reduxObject.type === 'vector_mask_line') {
|
||||
let vectorMaskLine = stage.findOne<Konva.Line>(`#${reduxObject.id}`);
|
||||
|
||||
// Create the line if it doesn't exist
|
||||
if (!vectorMaskLine) {
|
||||
vectorMaskLine = new Konva.Line({
|
||||
id: reduxObject.id,
|
||||
key: reduxObject.id,
|
||||
name: VECTOR_MASK_LAYER_LINE_NAME,
|
||||
strokeWidth: reduxObject.strokeWidth,
|
||||
tension: 0,
|
||||
lineCap: 'round',
|
||||
lineJoin: 'round',
|
||||
shadowForStrokeEnabled: false,
|
||||
globalCompositeOperation: reduxObject.tool === 'brush' ? 'source-over' : 'destination-out',
|
||||
listening: false,
|
||||
});
|
||||
konvaObjectGroup.add(vectorMaskLine);
|
||||
}
|
||||
const vectorMaskLine =
|
||||
stage.findOne<Konva.Line>(`#${reduxObject.id}`) ?? createVectorMaskLine(reduxObject, konvaObjectGroup);
|
||||
|
||||
// Only update the points if they have changed. The point values are never mutated, they are only added to the
|
||||
// array, so checking the length is sufficient to determine if we need to re-cache.
|
||||
@@ -319,20 +371,9 @@ const vectorMaskLayer = (
|
||||
groupNeedsCache = true;
|
||||
}
|
||||
} else if (reduxObject.type === 'vector_mask_rect') {
|
||||
let konvaObject = stage.findOne<Konva.Rect>(`#${reduxObject.id}`);
|
||||
if (!konvaObject) {
|
||||
konvaObject = new Konva.Rect({
|
||||
id: reduxObject.id,
|
||||
key: reduxObject.id,
|
||||
name: VECTOR_MASK_LAYER_RECT_NAME,
|
||||
x: reduxObject.x,
|
||||
y: reduxObject.y,
|
||||
width: reduxObject.width,
|
||||
height: reduxObject.height,
|
||||
listening: false,
|
||||
});
|
||||
konvaObjectGroup.add(konvaObject);
|
||||
}
|
||||
const konvaObject =
|
||||
stage.findOne<Konva.Rect>(`#${reduxObject.id}`) ?? createVectorMaskRect(reduxObject, konvaObjectGroup);
|
||||
|
||||
// Only update the color if it has changed.
|
||||
if (konvaObject.fill() !== rgbColor) {
|
||||
konvaObject.fill(rgbColor);
|
||||
@@ -342,20 +383,16 @@ const vectorMaskLayer = (
|
||||
}
|
||||
|
||||
// Only update layer visibility if it has changed.
|
||||
if (konvaLayer.visible() !== vmLayer.isVisible) {
|
||||
konvaLayer.visible(vmLayer.isVisible);
|
||||
if (konvaLayer.visible() !== reduxLayer.isVisible) {
|
||||
konvaLayer.visible(reduxLayer.isVisible);
|
||||
groupNeedsCache = true;
|
||||
}
|
||||
|
||||
if (konvaObjectGroup.children.length > 0) {
|
||||
// If we have objects, we need to cache the group to apply the layer opacity...
|
||||
if (groupNeedsCache) {
|
||||
// ...but only if we've done something that needs the cache.
|
||||
konvaObjectGroup.cache();
|
||||
}
|
||||
} else {
|
||||
// No children - clear the cache to reset the previous pixel data
|
||||
if (konvaObjectGroup.children.length === 0) {
|
||||
// No objects - clear the cache to reset the previous pixel data
|
||||
konvaObjectGroup.clearCache();
|
||||
} else if (groupNeedsCache) {
|
||||
konvaObjectGroup.cache();
|
||||
}
|
||||
|
||||
// Updating group opacity does not require re-caching
|
||||
@@ -372,7 +409,7 @@ const vectorMaskLayer = (
|
||||
* @param onLayerPosChanged Callback for when the layer's position changes. This is optional to allow for offscreen rendering.
|
||||
* @returns
|
||||
*/
|
||||
const layers = (
|
||||
const renderLayers = (
|
||||
stage: Konva.Stage,
|
||||
reduxLayers: Layer[],
|
||||
globalMaskLayerOpacity: number,
|
||||
@@ -388,24 +425,57 @@ const layers = (
|
||||
}
|
||||
}
|
||||
|
||||
for (let layerIndex = 0; layerIndex < reduxLayers.length; layerIndex++) {
|
||||
const reduxLayer = reduxLayers[layerIndex];
|
||||
assert(reduxLayer, `Layer at index ${layerIndex} is undefined`);
|
||||
for (const reduxLayer of reduxLayers) {
|
||||
if (isVectorMaskLayer(reduxLayer)) {
|
||||
vectorMaskLayer(stage, reduxLayer, layerIndex, globalMaskLayerOpacity, tool, onLayerPosChanged);
|
||||
renderVectorMaskLayer(stage, reduxLayer, globalMaskLayerOpacity, tool, onLayerPosChanged);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
*
|
||||
* @param stage The konva stage to render on.
|
||||
* @param tool The current tool.
|
||||
* @param selectedLayerIdId The currently selected layer id.
|
||||
* @param onBboxChanged A callback to be called when the bounding box changes.
|
||||
* Creates a bounding box rect for a layer.
|
||||
* @param reduxLayer The redux layer to create the bounding box for.
|
||||
* @param konvaLayer The konva layer to attach the bounding box to.
|
||||
* @param onBboxMouseDown Callback for when the bounding box is clicked.
|
||||
*/
|
||||
const createBboxRect = (reduxLayer: Layer, konvaLayer: Konva.Layer, onBboxMouseDown: (layerId: string) => void) => {
|
||||
const rect = new Konva.Rect({
|
||||
id: getLayerBboxId(reduxLayer.id),
|
||||
name: LAYER_BBOX_NAME,
|
||||
strokeWidth: 1,
|
||||
});
|
||||
rect.on('mousedown', function () {
|
||||
onBboxMouseDown(reduxLayer.id);
|
||||
});
|
||||
rect.on('mouseover', function (e) {
|
||||
if (getIsSelected(e.target.getLayer()?.id())) {
|
||||
this.stroke(BBOX_SELECTED_STROKE);
|
||||
} else {
|
||||
this.stroke(BBOX_NOT_SELECTED_MOUSEOVER_STROKE);
|
||||
}
|
||||
});
|
||||
rect.on('mouseout', function (e) {
|
||||
if (getIsSelected(e.target.getLayer()?.id())) {
|
||||
this.stroke(BBOX_SELECTED_STROKE);
|
||||
} else {
|
||||
this.stroke(BBOX_NOT_SELECTED_STROKE);
|
||||
}
|
||||
});
|
||||
konvaLayer.add(rect);
|
||||
return rect;
|
||||
};
|
||||
|
||||
/**
|
||||
* Renders the bounding boxes for the layers.
|
||||
* @param stage The konva stage to render on
|
||||
* @param reduxLayers An array of all redux layers to draw bboxes for
|
||||
* @param selectedLayerId The selected layer's id
|
||||
* @param tool The current tool
|
||||
* @param onBboxChanged Callback for when the bbox is changed
|
||||
* @param onBboxMouseDown Callback for when the bbox is clicked
|
||||
* @returns
|
||||
*/
|
||||
const bbox = (
|
||||
const renderBbox = (
|
||||
stage: Konva.Stage,
|
||||
reduxLayers: Layer[],
|
||||
selectedLayerId: string | null,
|
||||
@@ -433,7 +503,6 @@ const bbox = (
|
||||
if (reduxLayer.bboxNeedsUpdate && reduxLayer.objects.length) {
|
||||
// We only need to use the pixel-perfect bounding box if the layer has eraser strokes
|
||||
bbox = reduxLayer.needsPixelBbox ? getLayerBboxPixels(konvaLayer) : getLayerBboxFast(konvaLayer);
|
||||
|
||||
// Update the layer's bbox in the redux store
|
||||
onBboxChanged(reduxLayer.id, bbox);
|
||||
}
|
||||
@@ -442,32 +511,8 @@ const bbox = (
|
||||
continue;
|
||||
}
|
||||
|
||||
let rect = konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`);
|
||||
if (!rect) {
|
||||
rect = new Konva.Rect({
|
||||
id: getLayerBboxId(reduxLayer.id),
|
||||
name: LAYER_BBOX_NAME,
|
||||
strokeWidth: 1,
|
||||
});
|
||||
rect.on('mousedown', function () {
|
||||
onBboxMouseDown(reduxLayer.id);
|
||||
});
|
||||
rect.on('mouseover', function (e) {
|
||||
if (getIsSelected(e.target.getLayer()?.id())) {
|
||||
this.stroke(BBOX_SELECTED_STROKE);
|
||||
} else {
|
||||
this.stroke(BBOX_NOT_SELECTED_MOUSEOVER_STROKE);
|
||||
}
|
||||
});
|
||||
rect.on('mouseout', function (e) {
|
||||
if (getIsSelected(e.target.getLayer()?.id())) {
|
||||
this.stroke(BBOX_SELECTED_STROKE);
|
||||
} else {
|
||||
this.stroke(BBOX_NOT_SELECTED_STROKE);
|
||||
}
|
||||
});
|
||||
konvaLayer.add(rect);
|
||||
}
|
||||
const rect =
|
||||
konvaLayer.findOne<Konva.Rect>(`.${LAYER_BBOX_NAME}`) ?? createBboxRect(reduxLayer, konvaLayer, onBboxMouseDown);
|
||||
|
||||
rect.setAttrs({
|
||||
visible: true,
|
||||
@@ -481,31 +526,41 @@ const bbox = (
|
||||
}
|
||||
};
|
||||
|
||||
const background = (stage: Konva.Stage, width: number, height: number) => {
|
||||
let layer = stage.findOne<Konva.Layer>(`#${BACKGROUND_LAYER_ID}`);
|
||||
/**
|
||||
* Creates the background layer for the stage.
|
||||
* @param stage The konva stage to render on
|
||||
*/
|
||||
const createBackgroundLayer = (stage: Konva.Stage): Konva.Layer => {
|
||||
const layer = new Konva.Layer({
|
||||
id: BACKGROUND_LAYER_ID,
|
||||
});
|
||||
const background = new Konva.Rect({
|
||||
id: BACKGROUND_RECT_ID,
|
||||
x: stage.x(),
|
||||
y: 0,
|
||||
width: stage.width() / stage.scaleX(),
|
||||
height: stage.height() / stage.scaleY(),
|
||||
listening: false,
|
||||
opacity: 0.2,
|
||||
});
|
||||
layer.add(background);
|
||||
stage.add(layer);
|
||||
const image = new Image();
|
||||
image.onload = () => {
|
||||
background.fillPatternImage(image);
|
||||
};
|
||||
image.src = STAGE_BG_DATAURL;
|
||||
return layer;
|
||||
};
|
||||
|
||||
if (!layer) {
|
||||
layer = new Konva.Layer({
|
||||
id: BACKGROUND_LAYER_ID,
|
||||
});
|
||||
const background = new Konva.Rect({
|
||||
id: BACKGROUND_RECT_ID,
|
||||
x: stage.x(),
|
||||
y: 0,
|
||||
width: stage.width() / stage.scaleX(),
|
||||
height: stage.height() / stage.scaleY(),
|
||||
listening: false,
|
||||
opacity: 0.2,
|
||||
});
|
||||
layer.add(background);
|
||||
stage.add(layer);
|
||||
const image = new Image();
|
||||
image.onload = () => {
|
||||
background.fillPatternImage(image);
|
||||
};
|
||||
// This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
|
||||
image.src = STAGE_BG_DATAURL;
|
||||
}
|
||||
/**
|
||||
* Renders the background layer for the stage.
|
||||
* @param stage The konva stage to render on
|
||||
* @param width The unscaled width of the canvas
|
||||
* @param height The unscaled height of the canvas
|
||||
*/
|
||||
const renderBackground = (stage: Konva.Stage, width: number, height: number) => {
|
||||
const layer = stage.findOne<Konva.Layer>(`#${BACKGROUND_LAYER_ID}`) ?? createBackgroundLayer(stage);
|
||||
|
||||
const background = layer.findOne<Konva.Rect>(`#${BACKGROUND_RECT_ID}`);
|
||||
assert(background, 'Background rect not found');
|
||||
@@ -528,15 +583,37 @@ const background = (stage: Konva.Stage, width: number, height: number) => {
|
||||
background.fillPatternOffset(stagePos);
|
||||
};
|
||||
|
||||
const DEBOUNCE_MS = 300;
|
||||
/**
|
||||
* Arranges all layers in the z-axis by updating their z-indices.
|
||||
* @param stage The konva stage
|
||||
* @param layerIds An array of redux layer ids, in their z-index order
|
||||
*/
|
||||
const arrangeLayers = (stage: Konva.Stage, layerIds: string[]): void => {
|
||||
let nextZIndex = 0;
|
||||
// Background is the first layer
|
||||
stage.findOne<Konva.Layer>(`#${BACKGROUND_LAYER_ID}`)?.zIndex(nextZIndex++);
|
||||
// Then arrange the redux layers in order
|
||||
for (const layerId of layerIds) {
|
||||
stage.findOne<Konva.Layer>(`#${layerId}`)?.zIndex(nextZIndex++);
|
||||
}
|
||||
// Finally, the tool preview layer is always on top
|
||||
stage.findOne<Konva.Layer>(`#${TOOL_PREVIEW_LAYER_ID}`)?.zIndex(nextZIndex++);
|
||||
};
|
||||
|
||||
export const renderers = {
|
||||
toolPreview,
|
||||
toolPreviewDebounced: debounce(toolPreview, DEBOUNCE_MS),
|
||||
layers,
|
||||
layersDebounced: debounce(layers, DEBOUNCE_MS),
|
||||
bbox,
|
||||
bboxDebounced: debounce(bbox, DEBOUNCE_MS),
|
||||
background,
|
||||
backgroundDebounced: debounce(background, DEBOUNCE_MS),
|
||||
renderToolPreview,
|
||||
renderLayers,
|
||||
renderBbox,
|
||||
renderBackground,
|
||||
arrangeLayers,
|
||||
};
|
||||
|
||||
const DEBOUNCE_MS = 300;
|
||||
|
||||
export const debouncedRenderers = {
|
||||
renderToolPreview: debounce(renderToolPreview, DEBOUNCE_MS),
|
||||
renderLayers: debounce(renderLayers, DEBOUNCE_MS),
|
||||
renderBbox: debounce(renderBbox, DEBOUNCE_MS),
|
||||
renderBackground: debounce(renderBackground, DEBOUNCE_MS),
|
||||
arrangeLayers: debounce(arrangeLayers, DEBOUNCE_MS),
|
||||
};
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { aspectRatioChanged, setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
|
||||
import ParamBoundingBoxHeight from 'features/parameters/components/Canvas/BoundingBox/ParamBoundingBoxHeight';
|
||||
import ParamBoundingBoxWidth from 'features/parameters/components/Canvas/BoundingBox/ParamBoundingBoxWidth';
|
||||
import { AspectRatioIconPreview } from 'features/parameters/components/ImageSize/AspectRatioIconPreview';
|
||||
import { ImageSize } from 'features/parameters/components/ImageSize/ImageSize';
|
||||
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
@@ -41,6 +42,7 @@ export const ImageSizeCanvas = memo(() => {
|
||||
aspectRatioState={aspectRatioState}
|
||||
heightComponent={<ParamBoundingBoxHeight />}
|
||||
widthComponent={<ParamBoundingBoxWidth />}
|
||||
previewComponent={<AspectRatioIconPreview />}
|
||||
onChangeAspectRatioState={onChangeAspectRatioState}
|
||||
onChangeWidth={onChangeWidth}
|
||||
onChangeHeight={onChangeHeight}
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { ParamHeight } from 'features/parameters/components/Core/ParamHeight';
|
||||
import { ParamWidth } from 'features/parameters/components/Core/ParamWidth';
|
||||
import { AspectRatioCanvasPreview } from 'features/parameters/components/ImageSize/AspectRatioCanvasPreview';
|
||||
import { AspectRatioIconPreview } from 'features/parameters/components/ImageSize/AspectRatioIconPreview';
|
||||
import { ImageSize } from 'features/parameters/components/ImageSize/ImageSize';
|
||||
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
|
||||
import { aspectRatioChanged, heightChanged, widthChanged } from 'features/parameters/store/generationSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
export const ImageSizeLinear = memo(() => {
|
||||
const dispatch = useAppDispatch();
|
||||
const tab = useAppSelector(activeTabNameSelector);
|
||||
const width = useAppSelector((s) => s.generation.width);
|
||||
const height = useAppSelector((s) => s.generation.height);
|
||||
const aspectRatioState = useAppSelector((s) => s.generation.aspectRatio);
|
||||
@@ -40,6 +44,7 @@ export const ImageSizeLinear = memo(() => {
|
||||
aspectRatioState={aspectRatioState}
|
||||
heightComponent={<ParamHeight />}
|
||||
widthComponent={<ParamWidth />}
|
||||
previewComponent={tab === 'txt2img' ? <AspectRatioCanvasPreview /> : <AspectRatioIconPreview />}
|
||||
onChangeAspectRatioState={onChangeAspectRatioState}
|
||||
onChangeWidth={onChangeWidth}
|
||||
onChangeHeight={onChangeHeight}
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -1 +1 @@
|
||||
__version__ = "4.2.0a2"
|
||||
__version__ = "4.2.0a3"
|
||||
|
||||
@@ -4,7 +4,7 @@ import pytest
|
||||
from torch import tensor
|
||||
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import InvalidModelConfigException
|
||||
from invokeai.backend.model_manager.config import InvalidModelConfigException, MainDiffusersConfig, ModelVariantType
|
||||
from invokeai.backend.model_manager.probe import (
|
||||
CkptType,
|
||||
ModelProbe,
|
||||
@@ -78,3 +78,11 @@ def test_probe_handles_state_dict_with_integer_keys():
|
||||
}
|
||||
with pytest.raises(InvalidModelConfigException):
|
||||
ModelProbe.get_model_type_from_checkpoint(Path("embedding.pt"), state_dict_with_integer_keys)
|
||||
|
||||
|
||||
def test_probe_sd1_diffusers_inpainting(datadir: Path):
|
||||
config = ModelProbe.probe(datadir / "sd-1/main/dreamshaper-8-inpainting")
|
||||
assert isinstance(config, MainDiffusersConfig)
|
||||
assert config.base is BaseModelType.StableDiffusion1
|
||||
assert config.variant is ModelVariantType.Inpaint
|
||||
assert config.repo_variant is ModelRepoVariant.FP16
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
This folder contains config files copied from [Lykon/dreamshaper-8-inpainting](https://huggingface.co/Lykon/dreamshaper-8-inpainting).
|
||||
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"_class_name": "StableDiffusionInpaintPipeline",
|
||||
"_diffusers_version": "0.21.0.dev0",
|
||||
"_name_or_path": "lykon-models/dreamshaper-8-inpainting",
|
||||
"feature_extractor": [
|
||||
"transformers",
|
||||
"CLIPFeatureExtractor"
|
||||
],
|
||||
"requires_safety_checker": true,
|
||||
"safety_checker": [
|
||||
"stable_diffusion",
|
||||
"StableDiffusionSafetyChecker"
|
||||
],
|
||||
"scheduler": [
|
||||
"diffusers",
|
||||
"DEISMultistepScheduler"
|
||||
],
|
||||
"text_encoder": [
|
||||
"transformers",
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"tokenizer": [
|
||||
"transformers",
|
||||
"CLIPTokenizer"
|
||||
],
|
||||
"unet": [
|
||||
"diffusers",
|
||||
"UNet2DConditionModel"
|
||||
],
|
||||
"vae": [
|
||||
"diffusers",
|
||||
"AutoencoderKL"
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"_class_name": "DEISMultistepScheduler",
|
||||
"_diffusers_version": "0.21.0.dev0",
|
||||
"algorithm_type": "deis",
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"clip_sample": false,
|
||||
"dynamic_thresholding_ratio": 0.995,
|
||||
"lower_order_final": true,
|
||||
"num_train_timesteps": 1000,
|
||||
"prediction_type": "epsilon",
|
||||
"sample_max_value": 1.0,
|
||||
"set_alpha_to_one": false,
|
||||
"skip_prk_steps": true,
|
||||
"solver_order": 2,
|
||||
"solver_type": "logrho",
|
||||
"steps_offset": 1,
|
||||
"thresholding": false,
|
||||
"timestep_spacing": "leading",
|
||||
"trained_betas": null,
|
||||
"use_karras_sigmas": false
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
{
|
||||
"_class_name": "UNet2DConditionModel",
|
||||
"_diffusers_version": "0.21.0.dev0",
|
||||
"_name_or_path": "/home/patrick/.cache/huggingface/hub/models--lykon-models--dreamshaper-8-inpainting/snapshots/15dcb9dec91a39ee498e3917c9ef6174b103862d/unet",
|
||||
"act_fn": "silu",
|
||||
"addition_embed_type": null,
|
||||
"addition_embed_type_num_heads": 64,
|
||||
"addition_time_embed_dim": null,
|
||||
"attention_head_dim": 8,
|
||||
"attention_type": "default",
|
||||
"block_out_channels": [
|
||||
320,
|
||||
640,
|
||||
1280,
|
||||
1280
|
||||
],
|
||||
"center_input_sample": false,
|
||||
"class_embed_type": null,
|
||||
"class_embeddings_concat": false,
|
||||
"conv_in_kernel": 3,
|
||||
"conv_out_kernel": 3,
|
||||
"cross_attention_dim": 768,
|
||||
"cross_attention_norm": null,
|
||||
"down_block_types": [
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D"
|
||||
],
|
||||
"downsample_padding": 1,
|
||||
"dual_cross_attention": false,
|
||||
"encoder_hid_dim": null,
|
||||
"encoder_hid_dim_type": null,
|
||||
"flip_sin_to_cos": true,
|
||||
"freq_shift": 0,
|
||||
"in_channels": 9,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_only_cross_attention": null,
|
||||
"mid_block_scale_factor": 1,
|
||||
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_attention_heads": null,
|
||||
"num_class_embeds": null,
|
||||
"only_cross_attention": false,
|
||||
"out_channels": 4,
|
||||
"projection_class_embeddings_input_dim": null,
|
||||
"resnet_out_scale_factor": 1.0,
|
||||
"resnet_skip_time_act": false,
|
||||
"resnet_time_scale_shift": "default",
|
||||
"sample_size": 64,
|
||||
"time_cond_proj_dim": null,
|
||||
"time_embedding_act_fn": null,
|
||||
"time_embedding_dim": null,
|
||||
"time_embedding_type": "positional",
|
||||
"timestep_post_act": null,
|
||||
"transformer_layers_per_block": 1,
|
||||
"up_block_types": [
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D"
|
||||
],
|
||||
"upcast_attention": null,
|
||||
"use_linear_projection": false
|
||||
}
|
||||
@@ -99,6 +99,20 @@ def test_obj_serializer_ephemeral_writes_to_tempdir(tmp_path: Path):
|
||||
assert not Path(tmp_path, obj_1_name).exists()
|
||||
|
||||
|
||||
def test_obj_serializer_ephemeral_deletes_dangling_tempdirs_on_init(tmp_path: Path):
|
||||
tempdir = tmp_path / "tmpdir"
|
||||
tempdir.mkdir()
|
||||
ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=True)
|
||||
assert not tempdir.exists()
|
||||
|
||||
|
||||
def test_obj_serializer_does_not_delete_tempdirs_on_init(tmp_path: Path):
|
||||
tempdir = tmp_path / "tmpdir"
|
||||
tempdir.mkdir()
|
||||
ObjectSerializerDisk[MockDataclass](tmp_path, ephemeral=False)
|
||||
assert tempdir.exists()
|
||||
|
||||
|
||||
def test_obj_serializer_disk_different_types(tmp_path: Path):
|
||||
obj_serializer_1 = ObjectSerializerDisk[MockDataclass](tmp_path)
|
||||
obj_1 = MockDataclass(foo="bar")
|
||||
|
||||
Reference in New Issue
Block a user