Compare commits

...

6 Commits

Author SHA1 Message Date
psychedelicious
60d16fde2c feat(nodes): raise on NSFW 2024-05-13 17:35:35 +10:00
psychedelicious
8bddcce2b0 fix(backend): nsfw checker on non-rgb images 2024-05-13 17:18:05 +10:00
psychedelicious
1ca583e407 chore(ui): rip out refs to nsfw checker/watermark from UI 2024-05-13 17:18:05 +10:00
psychedelicious
b46cd5dcd0 feat(nodes): check/watermark all images in invocation api
Note: The NSFW checker model fails when the image isn't RGB.
2024-05-13 17:16:26 +10:00
psychedelicious
28d81a4372 feat(config): add nsfw_check and watermark to config 2024-05-13 17:16:26 +10:00
psychedelicious
e4f0c8a5d0 fix(nodes): fix nsfw checker model download 2024-05-13 17:16:26 +10:00
23 changed files with 269 additions and 321 deletions

View File

@@ -13,7 +13,6 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.util.logging import logging
from invokeai.version import __version__
@@ -109,9 +108,7 @@ async def get_config() -> AppConfig:
upscaling_models.append(str(Path(model).stem))
upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models)
nsfw_methods = []
if SafetyChecker.safety_checker_available():
nsfw_methods.append("nsfw_checker")
nsfw_methods = ["nsfw_checker"]
watermarking_methods = ["invisible_watermark"]

View File

@@ -112,6 +112,8 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
nsfw_check: Enable NSFW checking for images. NSFW images will be blurred.
watermark: Watermark all images with `invisible-watermark`.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
node_cache_size: How many cached nodes to keep in memory.
@@ -184,6 +186,8 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
nsfw_check: bool = Field(default=False, description="Enable NSFW checking for images. If an NSFW image is encountered during generation, execution will immediately stop. If disabled, the NSFW model is never loaded.")
watermark: bool = Field(default=False, description="Watermark all images with `invisible-watermark`.")
# NODES
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")

View File

@@ -15,6 +15,8 @@ from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -191,6 +193,12 @@ class ImagesInterface(InvocationContextInterface):
elif isinstance(self._data.invocation, WithBoard) and self._data.invocation.board:
board_id_ = self._data.invocation.board.board_id
if self._services.configuration.nsfw_check:
SafetyChecker.raise_if_nsfw(image)
if self._services.configuration.watermark:
image = InvisibleWatermark.add_watermark(image, "InvokeAI")
return self._services.images.create(
image=image,
is_intermediate=self._data.invocation.is_intermediate,

View File

@@ -4,8 +4,6 @@ wraps the safety_checker model. It respects the global "nsfw_checker"
configuration variable, that allows the checker to be supressed.
"""
from pathlib import Path
import numpy as np
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from PIL import Image
@@ -16,38 +14,44 @@ from invokeai.app.services.config.config_default import get_config
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.silence_warnings import SilenceWarnings
repo_id = "CompVis/stable-diffusion-safety-checker"
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
class SafetyChecker:
"""
Wrapper around SafetyChecker model.
"""
class NSFWImageException(Exception):
"""Raised when a NSFW image is detected."""
def __init__(self):
super().__init__("A potentially NSFW image has been detected.")
class SafetyChecker:
"""Wrapper around SafetyChecker model."""
safety_checker = None
feature_extractor = None
tried_load: bool = False
safety_checker = None
@classmethod
def _load_safety_checker(cls):
if cls.tried_load:
if cls.safety_checker is not None and cls.feature_extractor is not None:
return
try:
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(get_config().models_path / CHECKER_PATH)
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(get_config().models_path / CHECKER_PATH)
model_path = get_config().models_path / CHECKER_PATH
if model_path.exists():
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(model_path)
else:
model_path.mkdir(parents=True, exist_ok=True)
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
cls.feature_extractor.save_pretrained(model_path, safe_serialization=True)
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(repo_id)
cls.safety_checker.save_pretrained(model_path, safe_serialization=True)
except Exception as e:
logger.warning(f"Could not load NSFW checker: {str(e)}")
cls.tried_load = True
@classmethod
def safety_checker_available(cls) -> bool:
return Path(get_config().models_path, CHECKER_PATH).exists()
@classmethod
def has_nsfw_concept(cls, image: Image.Image) -> bool:
if not cls.safety_checker_available() and cls.tried_load:
return False
cls._load_safety_checker()
if cls.safety_checker is None or cls.feature_extractor is None:
return False
@@ -55,8 +59,25 @@ class SafetyChecker:
features = cls.feature_extractor([image], return_tensors="pt")
features.to(device)
cls.safety_checker.to(device)
x_image = np.array(image).astype(np.float32) / 255.0
# Only RGB(A) images are supported, so to prevent an error when a NSFW concept is detect in an image with
# a different image mode, we _must_ convert it to RGB and then to a normalized, batched np array.
rgb_image = image.convert("RGB")
# Convert to normalized (0-1) np array
x_image = np.array(rgb_image).astype(np.float32) / 255.0
# Add batch dimension and transpose to NCHW
x_image = x_image[None].transpose(0, 3, 1, 2)
# A warning is logged if a NSFW concept is detected - silence those, so we can handle it ourselves.
with SilenceWarnings():
# `clip_input` (features) is used to check for NSFW concepts. `images` is required, but it isn't actually
# checked for NSFW concepts. If a NSFW concept is detected, the the image is replaced with a black image.
checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
return has_nsfw_concept[0]
@classmethod
def raise_if_nsfw(cls, image: Image.Image) -> Image.Image:
"""Raises an exception if the image contains NSFW content."""
if cls.has_nsfw_concept(image):
raise NSFWImageException()
return image

View File

@@ -1,13 +1,12 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { setInfillMethod } from 'features/parameters/store/generationSlice';
import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged } from 'features/system/store/systemSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';
export const addAppConfigReceivedListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: appInfoApi.endpoints.getAppConfig.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
const { infill_methods = [], nsfw_methods = [], watermarking_methods = [] } = action.payload;
const { infill_methods = [] } = action.payload;
const infillMethod = getState().generation.infillMethod;
if (!infill_methods.includes(infillMethod)) {
@@ -15,14 +14,6 @@ export const addAppConfigReceivedListener = (startAppListening: AppStartListenin
// if there is no first one... god help us
dispatch(setInfillMethod(infill_methods[0] as string));
}
if (!nsfw_methods.includes('nsfw_checker')) {
dispatch(shouldUseNSFWCheckerChanged(false));
}
if (!watermarking_methods.includes('invisible_watermark')) {
dispatch(shouldUseWatermarkerChanged(false));
}
},
});
};

View File

@@ -1,40 +0,0 @@
import type { RootState } from 'app/store/store';
import type { ImageNSFWBlurInvocation, LatentsToImageInvocation, NonNullableGraph } from 'services/api/types';
import { LATENTS_TO_IMAGE, NSFW_CHECKER } from './constants';
import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
export const addNSFWCheckerToGraph = (
state: RootState,
graph: NonNullableGraph,
nodeIdToAddTo = LATENTS_TO_IMAGE
): void => {
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined;
if (!nodeToAddTo) {
// something has gone terribly awry
return;
}
nodeToAddTo.is_intermediate = true;
nodeToAddTo.use_cache = true;
const nsfwCheckerNode: ImageNSFWBlurInvocation = {
id: NSFW_CHECKER,
type: 'img_nsfw',
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
};
graph.nodes[NSFW_CHECKER] = nsfwCheckerNode as ImageNSFWBlurInvocation;
graph.edges.push({
source: {
node_id: nodeIdToAddTo,
field: 'image',
},
destination: {
node_id: NSFW_CHECKER,
field: 'image',
},
});
};

View File

@@ -1,66 +0,0 @@
import type { RootState } from 'app/store/store';
import type {
ImageNSFWBlurInvocation,
ImageWatermarkInvocation,
LatentsToImageInvocation,
NonNullableGraph,
} from 'services/api/types';
import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants';
import { getBoardField, getIsIntermediate } from './graphBuilderUtils';
export const addWatermarkerToGraph = (
state: RootState,
graph: NonNullableGraph,
nodeIdToAddTo = LATENTS_TO_IMAGE
): void => {
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as LatentsToImageInvocation | undefined;
const nsfwCheckerNode = graph.nodes[NSFW_CHECKER] as ImageNSFWBlurInvocation | undefined;
if (!nodeToAddTo) {
// something has gone terribly awry
return;
}
const watermarkerNode: ImageWatermarkInvocation = {
id: WATERMARKER,
type: 'img_watermark',
is_intermediate: getIsIntermediate(state),
board: getBoardField(state),
};
graph.nodes[WATERMARKER] = watermarkerNode;
// no matter the situation, we want the l2i node to be intermediate
nodeToAddTo.is_intermediate = true;
nodeToAddTo.use_cache = true;
if (nsfwCheckerNode) {
// if we are using NSFW checker, we need to "disable" it output by marking it intermediate,
// then connect it to the watermark node
nsfwCheckerNode.is_intermediate = true;
graph.edges.push({
source: {
node_id: NSFW_CHECKER,
field: 'image',
},
destination: {
node_id: WATERMARKER,
field: 'image',
},
});
} else {
// otherwise we just connect to the watermark node
graph.edges.push({
source: {
node_id: nodeIdToAddTo,
field: 'image',
},
destination: {
node_id: WATERMARKER,
field: 'image',
},
});
}
};

View File

@@ -12,11 +12,9 @@ import {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CANVAS_IMAGE_TO_IMAGE_GRAPH,
CANVAS_OUTPUT,
@@ -357,16 +355,5 @@ export const buildCanvasImageToImageGraph = async (
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph, CANVAS_OUTPUT);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
return graph;
};

View File

@@ -12,11 +12,9 @@ import type {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CANVAS_INPAINT_GRAPH,
CANVAS_OUTPUT,
@@ -446,16 +444,6 @@ export const buildCanvasInpaintGraph = async (
// Add IP Adapter
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph, CANVAS_OUTPUT);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
return graph;
};

View File

@@ -12,11 +12,9 @@ import type {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CANVAS_OUTPAINT_GRAPH,
CANVAS_OUTPUT,
@@ -606,16 +604,5 @@ export const buildCanvasOutpaintGraph = async (
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph, CANVAS_OUTPUT);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
return graph;
};

View File

@@ -10,13 +10,11 @@ import {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CANVAS_OUTPUT,
IMAGE_TO_LATENTS,
@@ -367,16 +365,5 @@ export const buildCanvasSDXLImageToImageGraph = async (
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph, CANVAS_OUTPUT);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
return graph;
};

View File

@@ -11,13 +11,11 @@ import type {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CANVAS_OUTPUT,
INPAINT_CREATE_MASK,
@@ -466,16 +464,5 @@ export const buildCanvasSDXLInpaintGraph = async (
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph, CANVAS_OUTPUT);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
return graph;
};

View File

@@ -11,13 +11,11 @@ import type {
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CANVAS_OUTPUT,
INPAINT_CREATE_MASK,
@@ -623,16 +621,5 @@ export const buildCanvasSDXLOutpaintGraph = async (
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph, CANVAS_OUTPUT);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
return graph;
};

View File

@@ -5,13 +5,11 @@ import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CANVAS_OUTPUT,
LATENTS_TO_IMAGE,
@@ -322,16 +320,5 @@ export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise
await addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph, CANVAS_OUTPUT);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
return graph;
};

View File

@@ -7,11 +7,9 @@ import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CANVAS_OUTPUT,
CANVAS_TEXT_TO_IMAGE_GRAPH,
@@ -303,16 +301,5 @@ export const buildCanvasTextToImageGraph = async (state: RootState): Promise<Non
await addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
await addT2IAdaptersToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph, CANVAS_OUTPUT);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
return graph;
};

View File

@@ -8,10 +8,8 @@ import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api
import { addHrfToGraph } from './addHrfToGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
CLIP_SKIP,
CONTROL_LAYERS_GRAPH,
@@ -252,16 +250,5 @@ export const buildGenerationTabGraph = async (state: RootState): Promise<NonNull
addHrfToGraph(state, graph);
}
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
}
return graph;
};

View File

@@ -4,12 +4,10 @@ import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetch
import { addControlLayersToGraph } from 'features/nodes/util/graph/addControlLayersToGraph';
import { isNonRefinerMainModelConfig, type NonNullableGraph } from 'services/api/types';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
import {
LATENTS_TO_IMAGE,
NEGATIVE_CONDITIONING,
@@ -263,16 +261,5 @@ export const buildGenerationTabSDXLGraph = async (state: RootState): Promise<Non
// optionally add custom VAE
await addVAEToGraph(state, graph, modelLoaderNodeId);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
addNSFWCheckerToGraph(state, graph);
}
if (state.system.shouldUseWatermarker) {
// must add after nsfw checker!
addWatermarkerToGraph(state, graph);
}
return graph;
};

View File

@@ -9,8 +9,6 @@ export const LATENTS_TO_IMAGE_HRF_LR = 'latents_to_image_hrf_lr';
export const IMAGE_TO_LATENTS_HRF = 'image_to_latents_hrf';
export const RESIZE_HRF = 'resize_hrf';
export const ESRGAN_HRF = 'esrgan_hrf';
export const NSFW_CHECKER = 'nsfw_checker';
export const WATERMARKER = 'invisible_watermark';
export const NOISE = 'noise';
export const NOISE_HRF = 'noise_hrf';
export const MAIN_MODEL_LOADER = 'main_model_loader';

View File

@@ -28,14 +28,11 @@ import {
setShouldEnableInformationalPopovers,
shouldAntialiasProgressImageChanged,
shouldLogToConsoleChanged,
shouldUseNSFWCheckerChanged,
shouldUseWatermarkerChanged,
} from 'features/system/store/systemSlice';
import { setShouldShowProgressInViewer } from 'features/ui/store/uiSlice';
import type { ChangeEvent, ReactElement } from 'react';
import { cloneElement, memo, useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetAppConfigQuery } from 'services/api/endpoints/appInfo';
import { SettingsLanguageSelect } from './SettingsLanguageSelect';
import { SettingsLogLevelSelect } from './SettingsLogLevelSelect';
@@ -69,13 +66,6 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
}
}, [shouldShowDeveloperSettings, dispatch]);
const { isNSFWCheckerAvailable, isWatermarkerAvailable } = useGetAppConfigQuery(undefined, {
selectFromResult: ({ data }) => ({
isNSFWCheckerAvailable: data?.nsfw_methods.includes('nsfw_checker') ?? false,
isWatermarkerAvailable: data?.watermarking_methods.includes('invisible_watermark') ?? false,
}),
});
const {
clearIntermediates,
hasPendingItems,
@@ -94,8 +84,6 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
const shouldShowProgressInViewer = useAppSelector((s) => s.ui.shouldShowProgressInViewer);
const shouldLogToConsole = useAppSelector((s) => s.system.shouldLogToConsole);
const shouldAntialiasProgressImage = useAppSelector((s) => s.system.shouldAntialiasProgressImage);
const shouldUseNSFWChecker = useAppSelector((s) => s.system.shouldUseNSFWChecker);
const shouldUseWatermarker = useAppSelector((s) => s.system.shouldUseWatermarker);
const shouldEnableInformationalPopovers = useAppSelector((s) => s.system.shouldEnableInformationalPopovers);
const clearStorage = useClearStorage();
@@ -133,18 +121,6 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
},
[dispatch]
);
const handleChangeShouldUseNSFWChecker = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldUseNSFWCheckerChanged(e.target.checked));
},
[dispatch]
);
const handleChangeShouldUseWatermarker = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldUseWatermarkerChanged(e.target.checked));
},
[dispatch]
);
const handleChangeShouldShowProgressInViewer = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(setShouldShowProgressInViewer(e.target.checked));
@@ -198,17 +174,6 @@ const SettingsModal = ({ children, config }: SettingsModalProps) => {
</FormControl>
</StickyScrollable>
<StickyScrollable title={t('settings.generation')}>
<FormControl isDisabled={!isNSFWCheckerAvailable}>
<FormLabel>{t('settings.enableNSFWChecker')}</FormLabel>
<Switch isChecked={shouldUseNSFWChecker} onChange={handleChangeShouldUseNSFWChecker} />
</FormControl>
<FormControl isDisabled={!isWatermarkerAvailable}>
<FormLabel>{t('settings.enableInvisibleWatermark')}</FormLabel>
<Switch isChecked={shouldUseWatermarker} onChange={handleChangeShouldUseWatermarker} />
</FormControl>
</StickyScrollable>
<StickyScrollable title={t('settings.ui')}>
<FormControl>
<FormLabel>{t('settings.showProgressInViewer')}</FormLabel>

View File

@@ -35,8 +35,6 @@ const initialSystemState: SystemState = {
consoleLogLevel: 'debug',
shouldLogToConsole: true,
language: 'en',
shouldUseNSFWChecker: false,
shouldUseWatermarker: false,
shouldEnableInformationalPopovers: false,
status: 'DISCONNECTED',
};
@@ -69,12 +67,6 @@ export const systemSlice = createSlice({
languageChanged: (state, action: PayloadAction<Language>) => {
state.language = action.payload;
},
shouldUseNSFWCheckerChanged(state, action: PayloadAction<boolean>) {
state.shouldUseNSFWChecker = action.payload;
},
shouldUseWatermarkerChanged(state, action: PayloadAction<boolean>) {
state.shouldUseWatermarker = action.payload;
},
setShouldEnableInformationalPopovers(state, action: PayloadAction<boolean>) {
state.shouldEnableInformationalPopovers = action.payload;
},
@@ -189,8 +181,6 @@ export const {
shouldLogToConsoleChanged,
shouldAntialiasProgressImageChanged,
languageChanged,
shouldUseNSFWCheckerChanged,
shouldUseWatermarkerChanged,
setShouldEnableInformationalPopovers,
} = systemSlice.actions;

View File

@@ -53,8 +53,6 @@ export interface SystemState {
shouldLogToConsole: boolean;
shouldAntialiasProgressImage: boolean;
language: Language;
shouldUseNSFWChecker: boolean;
shouldUseWatermarker: boolean;
status: SystemStatus;
shouldEnableInformationalPopovers: boolean;
}

File diff suppressed because one or more lines are too long

View File

@@ -143,8 +143,6 @@ export type ImageToLatentsInvocation = S['ImageToLatentsInvocation'];
export type LatentsToImageInvocation = S['LatentsToImageInvocation'];
export type LoRALoaderInvocation = S['LoRALoaderInvocation'];
export type ESRGANInvocation = S['ESRGANInvocation'];
export type ImageNSFWBlurInvocation = S['ImageNSFWBlurInvocation'];
export type ImageWatermarkInvocation = S['ImageWatermarkInvocation'];
export type SeamlessModeInvocation = S['SeamlessModeInvocation'];
export type CoreMetadataInvocation = S['CoreMetadataInvocation'];