tweaks in response to psychedelicious review of PR

This commit is contained in:
Lincoln Stein
2023-07-24 09:23:51 -04:00
parent 7ce5b6504f
commit 4194a0ed99
8 changed files with 96 additions and 130 deletions

View File

Before

Width:  |  Height:  |  Size: 33 KiB

After

Width:  |  Height:  |  Size: 33 KiB

View File

@@ -20,7 +20,7 @@ from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .image_defs import ImageOutput, PILInvocationConfig
from ..models.image import ImageOutput, PILInvocationConfig
CONTROLNET_DEFAULT_MODELS = [
###########################################

View File

@@ -4,24 +4,21 @@ from typing import Literal, Optional
import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field
from pydantic import Field
from pathlib import Path
from typing import Union
from invokeai.app.invocations.metadata import CoreMetadata
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ..models.image import (
ImageCategory, ImageField, ResourceOrigin,
PILInvocationConfig, ImageOutput, MaskOutput,
)
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from .image_defs import (
PILInvocationConfig,
ImageOutput,
MaskOutput,
)
from ..services.config import InvokeAIAppConfig
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend import SilenceWarnings
@@ -644,7 +641,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
device = choose_torch_device()
if self.enabled:
logger.info("Running NSFW checker")
logger.debug("Running NSFW checker")
safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker')
feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker')
@@ -681,8 +678,8 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
)
def _get_caution_img(self)->Image:
import invokeai.assets.web as web_assets
caution = Image.open(Path(web_assets.__path__[0]) / 'caution.png')
import invokeai.app.assets.images as image_assets
caution = Image.open(Path(image_assets.__path__[0]) / 'caution.png')
return caution.resize((caution.width // 2, caution.height //2))
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
@@ -716,7 +713,7 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
logger = context.services.logger
image = context.services.images.get_pil_image(self.image.image_name)
if self.enabled:
logger.info("Running invisible watermarker")
logger.debug("Running invisible watermarker")
bgr = cv2.cvtColor(numpy.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
wm = self.text
encoder = WatermarkEncoder()

View File

@@ -1,54 +0,0 @@
# Copyright 2023 Lincoln D. Stein and the InvokeAI Team
""" Common classes used by .image and .controlnet; avoids circular import issues """
from pydantic import BaseModel, Field
from typing import Literal
from ..models.image import ImageField
from .baseinvocation import (
BaseInvocationOutput,
InvocationConfig,
)
class PILInvocationConfig(BaseModel):
"""Helper class to provide all PIL invocations with additional config"""
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["PIL", "image"],
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
schema_extra = {
"required": [
"type",
"mask",
]
}

View File

@@ -1,9 +1,80 @@
from enum import Enum
from typing import Optional, Tuple
from typing import Optional, Tuple, Literal
from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum
from ..invocations.baseinvocation import (
BaseInvocationOutput,
InvocationConfig,
)
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel):
r: int = Field(ge=0, le=255, description="The red component")
g: int = Field(ge=0, le=255, description="The green component")
b: int = Field(ge=0, le=255, description="The blue component")
a: int = Field(ge=0, le=255, description="The alpha component")
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")
class PILInvocationConfig(BaseModel):
"""Helper class to provide all PIL invocations with additional config"""
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["PIL", "image"],
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
schema_extra = {
"required": [
"type",
"mask",
]
}
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
"""The origin of a resource (eg image).
@@ -63,28 +134,3 @@ class InvalidImageCategoryException(ValueError):
super().__init__(message)
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel):
r: int = Field(ge=0, le=255, description="The red component")
g: int = Field(ge=0, le=255, description="The green component")
b: int = Field(ge=0, le=255, description="The blue component")
a: int = Field(ge=0, le=255, description="The alpha component")
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")

View File

@@ -135,7 +135,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
board_id: str,
image_name: str,
) -> None:
print(f'DEBUG: board_id={board_id}, image_name={image_name}')
try:
self._lock.acquire()
self._cursor.execute(
@@ -147,7 +146,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
(board_id, image_name, board_id),
)
self._conn.commit()
print('got here')
except sqlite3.Error as e:
self._conn.rollback()
raise e

View File

@@ -48,7 +48,7 @@ export const buildLinearTextToImageGraph = (
}
/**
v * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
* ids.
*