add support for safety checker (NSFW filter)

Now you can activate the Hugging Face `diffusers` library safety check
for NSFW and other potentially disturbing imagery.

To turn on the safety check, pass --safety_checker at the command
line. For developers, the flag is `safety_checker=True` passed to
ldm.generate.Generate(). Once the safety checker is turned on, it
cannot be turned off unless you reinitialize a new Generate object.

When the safety checker is active, suspect images will be blurred and
a warning icon is added. There is also a warning message printed in
the CLI, but it can be a little hard to see because of its positioning
in the output stream.

There is a slight but noticeable delay when the safety checker runs.

Note that invisible watermarking is *not* currently implemented. The
watermark code distributed by the CompViz distribution uses a library
that does not seem to be able to retrieve the watermarks it creates,
and it does not appear that Hugging Face `diffusers` or other SD
distributions are doing any watermarking.
This commit is contained in:
Lincoln Stein
2022-10-23 22:26:18 -04:00
parent b7ce5b4f1b
commit b159b2fe42
10 changed files with 195 additions and 94 deletions

View File

@@ -418,6 +418,11 @@ class Args(object):
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
default='auto',
)
model_group.add_argument(
'--safety_checker',
action='store_true',
help='Check for and blur potentially NSFW images',
)
file_group.add_argument(
'--from_file',
dest='infile',

View File

@@ -7,25 +7,27 @@ import numpy as np
import random
import os
from tqdm import tqdm, trange
from PIL import Image
from PIL import Image, ImageFilter
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from ldm.invoke.devices import choose_autocast
from ldm.util import rand_perlin_2d
downsampling = 8
CAUTION_IMG = 'assets/caution.png'
class Generator():
def __init__(self, model, precision):
self.model = model
self.precision = precision
self.seed = None
self.latent_channels = model.channels
self.model = model
self.precision = precision
self.seed = None
self.latent_channels = model.channels
self.downsampling_factor = downsampling # BUG: should come from model or config
self.perlin = 0.0
self.threshold = 0
self.variation_amount = 0
self.with_variations = []
self.safety_checker = None
self.perlin = 0.0
self.threshold = 0
self.variation_amount = 0
self.with_variations = []
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self,prompt,**kwargs):
@@ -42,8 +44,10 @@ class Generator():
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None,
**kwargs):
scope = choose_autocast(self.precision)
self.safety_checker = safety_checker
make_image = self.get_make_image(
prompt,
init_image = init_image,
@@ -77,10 +81,17 @@ class Generator():
pass
image = make_image(x_T)
if self.safety_checker is not None:
image = self.safety_check(image)
results.append([image, seed])
if image_callback is not None:
image_callback(image, seed, first_seed=first_seed)
seed = self.new_seed()
return results
def sample_to_image(self,samples):
@@ -169,6 +180,39 @@ class Generator():
return v2
def safety_check(self,image:Image.Image):
'''
If the CompViz safety checker flags an NSFW image, we
blur it out.
'''
import diffusers
checker = self.safety_checker['checker']
extractor = self.safety_checker['extractor']
features = extractor([image], return_tensors="pt")
# unfortunately checker requires the numpy version, so we have to convert back
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
diffusers.logging.set_verbosity_error()
checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
if has_nsfw_concept[0]:
print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
return self.blur(image)
else:
return image
def blur(self,input):
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
try:
caution = Image.open(CAUTION_IMG)
caution = caution.resize((caution.width // 2, caution.height //2))
blurry.paste(caution,(0,0),caution)
except FileNotFoundError:
pass
return blurry
# this is a handy routine for debugging use. Given a generated sample,
# convert it into a PNG image and store it at the indicated path
def save_sample(self, sample, filepath):