mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-04 18:45:05 -05:00
add support for safety checker (NSFW filter)
Now you can activate the Hugging Face `diffusers` library safety check for NSFW and other potentially disturbing imagery. To turn on the safety check, pass --safety_checker at the command line. For developers, the flag is `safety_checker=True` passed to ldm.generate.Generate(). Once the safety checker is turned on, it cannot be turned off unless you reinitialize a new Generate object. When the safety checker is active, suspect images will be blurred and a warning icon is added. There is also a warning message printed in the CLI, but it can be a little hard to see because of its positioning in the output stream. There is a slight but noticeable delay when the safety checker runs. Note that invisible watermarking is *not* currently implemented. The watermark code distributed by the CompViz distribution uses a library that does not seem to be able to retrieve the watermarks it creates, and it does not appear that Hugging Face `diffusers` or other SD distributions are doing any watermarking.
This commit is contained in:
@@ -418,6 +418,11 @@ class Args(object):
|
||||
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
||||
default='auto',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--safety_checker',
|
||||
action='store_true',
|
||||
help='Check for and blur potentially NSFW images',
|
||||
)
|
||||
file_group.add_argument(
|
||||
'--from_file',
|
||||
dest='infile',
|
||||
|
||||
@@ -7,25 +7,27 @@ import numpy as np
|
||||
import random
|
||||
import os
|
||||
from tqdm import tqdm, trange
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageFilter
|
||||
from einops import rearrange, repeat
|
||||
from pytorch_lightning import seed_everything
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.util import rand_perlin_2d
|
||||
|
||||
downsampling = 8
|
||||
CAUTION_IMG = 'assets/caution.png'
|
||||
|
||||
class Generator():
|
||||
def __init__(self, model, precision):
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
self.latent_channels = model.channels
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
self.latent_channels = model.channels
|
||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||
self.perlin = 0.0
|
||||
self.threshold = 0
|
||||
self.variation_amount = 0
|
||||
self.with_variations = []
|
||||
self.safety_checker = None
|
||||
self.perlin = 0.0
|
||||
self.threshold = 0
|
||||
self.variation_amount = 0
|
||||
self.with_variations = []
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self,prompt,**kwargs):
|
||||
@@ -42,8 +44,10 @@ class Generator():
|
||||
|
||||
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
||||
safety_checker:dict=None,
|
||||
**kwargs):
|
||||
scope = choose_autocast(self.precision)
|
||||
self.safety_checker = safety_checker
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
init_image = init_image,
|
||||
@@ -77,10 +81,17 @@ class Generator():
|
||||
pass
|
||||
|
||||
image = make_image(x_T)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
image = self.safety_check(image)
|
||||
|
||||
results.append([image, seed])
|
||||
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed, first_seed=first_seed)
|
||||
|
||||
seed = self.new_seed()
|
||||
|
||||
return results
|
||||
|
||||
def sample_to_image(self,samples):
|
||||
@@ -169,6 +180,39 @@ class Generator():
|
||||
|
||||
return v2
|
||||
|
||||
def safety_check(self,image:Image.Image):
|
||||
'''
|
||||
If the CompViz safety checker flags an NSFW image, we
|
||||
blur it out.
|
||||
'''
|
||||
import diffusers
|
||||
|
||||
checker = self.safety_checker['checker']
|
||||
extractor = self.safety_checker['extractor']
|
||||
features = extractor([image], return_tensors="pt")
|
||||
|
||||
# unfortunately checker requires the numpy version, so we have to convert back
|
||||
x_image = np.array(image).astype(np.float32) / 255.0
|
||||
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||
|
||||
diffusers.logging.set_verbosity_error()
|
||||
checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
|
||||
if has_nsfw_concept[0]:
|
||||
print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
|
||||
return self.blur(image)
|
||||
else:
|
||||
return image
|
||||
|
||||
def blur(self,input):
|
||||
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||
try:
|
||||
caution = Image.open(CAUTION_IMG)
|
||||
caution = caution.resize((caution.width // 2, caution.height //2))
|
||||
blurry.paste(caution,(0,0),caution)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return blurry
|
||||
|
||||
# this is a handy routine for debugging use. Given a generated sample,
|
||||
# convert it into a PNG image and store it at the indicated path
|
||||
def save_sample(self, sample, filepath):
|
||||
|
||||
Reference in New Issue
Block a user