mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-04 14:05:18 -05:00
tweaks in response to psychedelicious review of PR
This commit is contained in:
|
Before Width: | Height: | Size: 33 KiB After Width: | Height: | Size: 33 KiB |
@@ -20,7 +20,7 @@ from ...backend.model_management import BaseModelType, ModelType
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||
InvocationConfig, InvocationContext)
|
||||
from .image_defs import ImageOutput, PILInvocationConfig
|
||||
from ..models.image import ImageOutput, PILInvocationConfig
|
||||
|
||||
CONTROLNET_DEFAULT_MODELS = [
|
||||
###########################################
|
||||
|
||||
@@ -4,24 +4,21 @@ from typing import Literal, Optional
|
||||
|
||||
import numpy
|
||||
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import Field
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||
from ..models.image import (
|
||||
ImageCategory, ImageField, ResourceOrigin,
|
||||
PILInvocationConfig, ImageOutput, MaskOutput,
|
||||
)
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InvocationContext,
|
||||
InvocationConfig,
|
||||
)
|
||||
from .image_defs import (
|
||||
PILInvocationConfig,
|
||||
ImageOutput,
|
||||
MaskOutput,
|
||||
)
|
||||
from ..services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend import SilenceWarnings
|
||||
@@ -644,7 +641,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
device = choose_torch_device()
|
||||
|
||||
if self.enabled:
|
||||
logger.info("Running NSFW checker")
|
||||
logger.debug("Running NSFW checker")
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker')
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker')
|
||||
|
||||
@@ -681,8 +678,8 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||
)
|
||||
|
||||
def _get_caution_img(self)->Image:
|
||||
import invokeai.assets.web as web_assets
|
||||
caution = Image.open(Path(web_assets.__path__[0]) / 'caution.png')
|
||||
import invokeai.app.assets.images as image_assets
|
||||
caution = Image.open(Path(image_assets.__path__[0]) / 'caution.png')
|
||||
return caution.resize((caution.width // 2, caution.height //2))
|
||||
|
||||
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
||||
@@ -716,7 +713,7 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
||||
logger = context.services.logger
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
if self.enabled:
|
||||
logger.info("Running invisible watermarker")
|
||||
logger.debug("Running invisible watermarker")
|
||||
bgr = cv2.cvtColor(numpy.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
||||
wm = self.text
|
||||
encoder = WatermarkEncoder()
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
# Copyright 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
""" Common classes used by .image and .controlnet; avoids circular import issues """
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Literal
|
||||
from ..models.image import ImageField
|
||||
from .baseinvocation import (
|
||||
BaseInvocationOutput,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
class PILInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all PIL invocations with additional config"""
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["PIL", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a mask"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mask"] = "mask"
|
||||
mask: ImageField = Field(default=None, description="The output mask")
|
||||
width: int = Field(description="The width of the mask in pixels")
|
||||
height: int = Field(description="The height of the mask in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"mask",
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,80 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from ..invocations.baseinvocation import (
|
||||
BaseInvocationOutput,
|
||||
InvocationConfig,
|
||||
)
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["image_name"]}
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
class PILInvocationConfig(BaseModel):
|
||||
"""Helper class to provide all PIL invocations with additional config"""
|
||||
|
||||
class Config(InvocationConfig):
|
||||
schema_extra = {
|
||||
"ui": {
|
||||
"tags": ["PIL", "image"],
|
||||
},
|
||||
}
|
||||
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output an image"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["image_output"] = "image_output"
|
||||
image: ImageField = Field(default=None, description="The output image")
|
||||
width: int = Field(description="The width of the image in pixels")
|
||||
height: int = Field(description="The height of the image in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||
|
||||
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""Base class for invocations that output a mask"""
|
||||
|
||||
# fmt: off
|
||||
type: Literal["mask"] = "mask"
|
||||
mask: ImageField = Field(default=None, description="The output mask")
|
||||
width: int = Field(description="The width of the mask in pixels")
|
||||
height: int = Field(description="The height of the mask in pixels")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"required": [
|
||||
"type",
|
||||
"mask",
|
||||
]
|
||||
}
|
||||
|
||||
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||
"""The origin of a resource (eg image).
|
||||
@@ -63,28 +134,3 @@ class InvalidImageCategoryException(ValueError):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image field used for passing image objects between invocations"""
|
||||
|
||||
image_name: Optional[str] = Field(default=None, description="The name of the image")
|
||||
|
||||
class Config:
|
||||
schema_extra = {"required": ["image_name"]}
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
@@ -135,7 +135,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
board_id: str,
|
||||
image_name: str,
|
||||
) -> None:
|
||||
print(f'DEBUG: board_id={board_id}, image_name={image_name}')
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
@@ -147,7 +146,6 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||
(board_id, image_name, board_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
print('got here')
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
raise e
|
||||
|
||||
@@ -48,7 +48,7 @@ export const buildLinearTextToImageGraph = (
|
||||
}
|
||||
|
||||
/**
|
||||
v * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
|
||||
* full graph here as a template. Then use the parameters from app state and set friendlier node
|
||||
* ids.
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user