diff --git a/docs/features/CLI.md b/docs/features/CLI.md index 85524f6fa9..67a187fb3b 100644 --- a/docs/features/CLI.md +++ b/docs/features/CLI.md @@ -86,6 +86,7 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt | `--model ` | | `stable-diffusion-1.4` | Loads model specified in configs/models.yaml. Currently one of "stable-diffusion-1.4" or "laion400m" | | `--full_precision` | `-F` | `False` | Run in slower full-precision mode. Needed for Macintosh M1/M2 hardware and some older video cards. | | `--png_compression <0-9>` | `-z<0-9>` | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) | +| `--safety_checker` | | `False` | Activate safety checker for NSFW and other potentially disturbing imagery | | `--web` | | `False` | Start in web server mode | | `--host ` | | `localhost` | Which network interface web server should listen on. Set to 0.0.0.0 to listen on any. | | `--port ` | | `9090` | Which port web server should listen for requests on. | @@ -97,7 +98,6 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt | `--embedding_path ` | | `None` | Path to pre-trained embedding manager checkpoints, for custom models | | `--gfpgan_dir` | | `src/gfpgan` | Path to where GFPGAN is installed. | | `--gfpgan_model_path` | | `experiments/pretrained_models/GFPGANv1.4.pth` | Path to GFPGAN model file, relative to `--gfpgan_dir`. | -| `--device ` | `-d` | `torch.cuda.current_device()` | Device to run SD on, e.g. "cuda:0" | | `--free_gpu_mem` | | `False` | Free GPU memory after sampling, to allow image decoding and saving in low VRAM conditions | | `--precision` | | `auto` | Set model precision, default is selected by device. 
Options: auto, float32, float16, autocast | diff --git a/environment-mac.yml b/environment-mac.yml index 14509ccff7..16fcccb67f 100644 --- a/environment-mac.yml +++ b/environment-mac.yml @@ -19,6 +19,7 @@ dependencies: # ``` - albumentations==1.2.1 - coloredlogs==15.0.1 + - diffusers==0.6.0 - einops==0.4.1 - grpcio==1.46.4 - humanfriendly==10.0 diff --git a/environment.yml b/environment.yml index 820f940608..fe9eb37768 100644 --- a/environment.yml +++ b/environment.yml @@ -26,6 +26,7 @@ dependencies: - pyreadline3 - torch-fidelity==0.3.0 - transformers==4.21.3 + - diffusers==0.6.0 - torchmetrics==0.7.0 - flask==2.1.3 - flask_socketio==5.3.0 diff --git a/ldm/generate.py b/ldm/generate.py index ce2331806c..4030e16e41 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -132,20 +132,21 @@ class Generate: def __init__( self, - model = None, - conf = 'configs/models.yaml', - embedding_path = None, - sampler_name = 'k_lms', - ddim_eta = 0.0, # deterministic - full_precision = False, - precision = 'auto', - # these are deprecated; if present they override values in the conf file - weights = None, - config = None, + model = None, + conf = 'configs/models.yaml', + embedding_path = None, + sampler_name = 'k_lms', + ddim_eta = 0.0, # deterministic + full_precision = False, + precision = 'auto', gfpgan=None, codeformer=None, esrgan=None, free_gpu_mem=False, + safety_checker:bool=False, + # these are deprecated; if present they override values in the conf file + weights = None, + config = None, ): mconfig = OmegaConf.load(conf) self.height = None @@ -176,6 +177,7 @@ class Generate: self.free_gpu_mem = free_gpu_mem self.size_matters = True # used to warn once about large image sizes and VRAM self.txt2mask = None + self.safety_checker = None # Note that in previous versions, there was an option to pass the # device to Generate(). 
However the device was then ignored, so @@ -203,6 +205,19 @@ class Generate: # gets rid of annoying messages about random seed logging.getLogger('pytorch_lightning').setLevel(logging.ERROR) + # load safety checker if requested + if safety_checker: + try: + print('>> Initializing safety checker') + from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + from transformers import AutoFeatureExtractor + safety_model_id = "CompVis/stable-diffusion-safety-checker" + self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True) + self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True) + except Exception: + print('** An error was encountered while installing the safety checker:') + print(traceback.format_exc()) + def prompt2png(self, prompt, outdir, **kwargs): """ Takes a prompt and an output directory, writes out the requested number @@ -418,6 +433,11 @@ class Generate: self.seed, variation_amount, with_variations ) + checker = { + 'checker':self.safety_checker, + 'extractor':self.safety_feature_extractor + } if self.safety_checker else None + results = generator.generate( prompt, iterations=iterations, @@ -428,10 +448,10 @@ class Generate: conditioning=(uc, c), ddim_eta=ddim_eta, image_callback=image_callback, # called after the final image is generated - step_callback=step_callback, # called after each intermediate image is generated + step_callback=step_callback, # called after each intermediate image is generated width=width, height=height, - init_img=init_img, # embiggen needs to manipulate from the unmodified init_img + init_img=init_img, # embiggen needs to manipulate from the unmodified init_img init_image=init_image, # notice that init_image is different from init_img mask_image=mask_image, strength=strength, @@ -440,7 +460,8 @@ class Generate: embiggen=embiggen, embiggen_tiles=embiggen_tiles, inpaint_replace=inpaint_replace, - 
mask_blur_radius=mask_blur_radius + mask_blur_radius=mask_blur_radius, + safety_checker=checker ) if init_color: diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py index e2302e4452..dca3e49780 100644 --- a/ldm/invoke/args.py +++ b/ldm/invoke/args.py @@ -418,6 +418,11 @@ class Args(object): help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}', default='auto', ) + model_group.add_argument( + '--safety_checker', + action='store_true', + help='Check for and blur potentially NSFW images', + ) file_group.add_argument( '--from_file', dest='infile', diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py index 89476cd216..cab550d76f 100644 --- a/ldm/invoke/generator/base.py +++ b/ldm/invoke/generator/base.py @@ -7,25 +7,27 @@ import numpy as np import random import os from tqdm import tqdm, trange -from PIL import Image +from PIL import Image, ImageFilter from einops import rearrange, repeat from pytorch_lightning import seed_everything from ldm.invoke.devices import choose_autocast from ldm.util import rand_perlin_2d downsampling = 8 +CAUTION_IMG = 'assets/caution.png' class Generator(): def __init__(self, model, precision): - self.model = model - self.precision = precision - self.seed = None - self.latent_channels = model.channels + self.model = model + self.precision = precision + self.seed = None + self.latent_channels = model.channels self.downsampling_factor = downsampling # BUG: should come from model or config - self.perlin = 0.0 - self.threshold = 0 - self.variation_amount = 0 - self.with_variations = [] + self.safety_checker = None + self.perlin = 0.0 + self.threshold = 0 + self.variation_amount = 0 + self.with_variations = [] # this is going to be overridden in img2img.py, txt2img.py and inpaint.py def get_make_image(self,prompt,**kwargs): @@ -42,8 +44,10 @@ class Generator(): def generate(self,prompt,init_image,width,height,iterations=1,seed=None, image_callback=None, 
step_callback=None, threshold=0.0, perlin=0.0, + safety_checker:dict=None, **kwargs): scope = choose_autocast(self.precision) + self.safety_checker = safety_checker make_image = self.get_make_image( prompt, init_image = init_image, @@ -77,10 +81,17 @@ class Generator(): pass image = make_image(x_T) + + if self.safety_checker is not None: + image = self.safety_check(image) + results.append([image, seed]) + if image_callback is not None: image_callback(image, seed, first_seed=first_seed) + seed = self.new_seed() + return results def sample_to_image(self,samples): @@ -169,6 +180,39 @@ class Generator(): return v2 + def safety_check(self,image:Image.Image): + ''' + If the CompVis safety checker flags an NSFW image, we + blur it out. + ''' + import diffusers + + checker = self.safety_checker['checker'] + extractor = self.safety_checker['extractor'] + features = extractor([image], return_tensors="pt") + + # unfortunately checker requires the numpy version, so we have to convert back + x_image = np.array(image).astype(np.float32) / 255.0 + x_image = x_image[None].transpose(0, 3, 1, 2) + + diffusers.logging.set_verbosity_error() + checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values) + if has_nsfw_concept[0]: + print('** An image with potential non-safe content has been detected. A blurred image will be returned. **') + return self.blur(image) + else: + return image + + def blur(self,input): + blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32)) + try: + caution = Image.open(CAUTION_IMG) + caution = caution.resize((caution.width // 2, caution.height //2)) + blurry.paste(caution,(0,0),caution) + except FileNotFoundError: + pass + return blurry + # this is a handy routine for debugging use. 
Given a generated sample, # convert it into a PNG image and store it at the indicated path def save_sample(self, sample, filepath): diff --git a/requirements-linux-arm64.txt b/requirements-linux-arm64.txt index 5ee4df2399..a0be77057b 100644 --- a/requirements-linux-arm64.txt +++ b/requirements-linux-arm64.txt @@ -1,5 +1,6 @@ albumentations==0.4.3 einops==0.3.0 +diffusers==0.6.0 huggingface-hub==0.8.1 imageio==2.9.0 imageio-ffmpeg==0.4.2 diff --git a/requirements.txt b/requirements.txt index 2e85166841..5671b596b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,6 +32,7 @@ send2trash dependency_injector==4.40.0 eventlet realesrgan +diffusers git+https://github.com/openai/CLIP.git@main#egg=clip git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan diff --git a/scripts/invoke.py b/scripts/invoke.py index f4d4f3c4c0..2247f64219 100644 --- a/scripts/invoke.py +++ b/scripts/invoke.py @@ -69,16 +69,17 @@ def main(): # creating a Generate object: try: gen = Generate( - conf = opt.conf, - model = opt.model, - sampler_name = opt.sampler_name, + conf = opt.conf, + model = opt.model, + sampler_name = opt.sampler_name, embedding_path = opt.embedding_path, full_precision = opt.full_precision, - precision = opt.precision, + precision = opt.precision, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan, free_gpu_mem=opt.free_gpu_mem, + safety_checker=opt.safety_checker, ) except (FileNotFoundError, IOError, KeyError) as e: print(f'{e}. Aborting.') diff --git a/scripts/preload_models.py b/scripts/preload_models.py index 1b0ad80e5c..bf0a5ffb99 100644 --- a/scripts/preload_models.py +++ b/scripts/preload_models.py @@ -5,7 +5,7 @@ # two machines must share a common .cache directory. 
from transformers import CLIPTokenizer, CLIPTextModel import clip -from transformers import BertTokenizerFast +from transformers import BertTokenizerFast, AutoFeatureExtractor import sys import transformers import os @@ -17,41 +17,39 @@ import traceback transformers.logging.set_verbosity_error() +#--------------------------------------------- # this will preload the Bert tokenizer fles -print('Loading bert tokenizer (ignore deprecation errors)...', end='') -with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=DeprecationWarning) - tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') -print('...success') -sys.stdout.flush() +def download_bert(): + print('Installing bert tokenizer (ignore deprecation errors)...', end='') + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') + print('...success') + sys.stdout.flush() +#--------------------------------------------- # this will download requirements for Kornia -print('Loading Kornia requirements...', end='') -with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=DeprecationWarning) - import kornia -print('...success') +def download_kornia(): + print('Installing Kornia requirements...', end='') + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + import kornia + print('...success') -version = 'openai/clip-vit-large-patch14' -sys.stdout.flush() -print('Loading CLIP model...',end='') -tokenizer = CLIPTokenizer.from_pretrained(version) -transformer = CLIPTextModel.from_pretrained(version) -print('...success') +#--------------------------------------------- +def download_clip(): + version = 'openai/clip-vit-large-patch14' + sys.stdout.flush() + print('Loading CLIP model...',end='') + tokenizer = CLIPTokenizer.from_pretrained(version) + transformer = CLIPTextModel.from_pretrained(version) + 
print('...success') -# In the event that the user has installed GFPGAN and also elected to use -# RealESRGAN, this will attempt to download the model needed by RealESRGANer -gfpgan = False -try: - from realesrgan import RealESRGANer - - gfpgan = True -except ModuleNotFoundError: - pass - -if gfpgan: - print('Loading models from RealESRGAN and facexlib...',end='') +#--------------------------------------------- +def download_gfpgan(): + print('Installing models from RealESRGAN and facexlib...',end='') try: + from realesrgan import RealESRGANer from realesrgan.archs.srvgg_arch import SRVGGNetCompact from facexlib.utils.face_restoration_helper import FaceRestoreHelper @@ -94,44 +92,72 @@ if gfpgan: print('Error loading GFPGAN:') print(traceback.format_exc()) -print('preloading CodeFormer model file...',end='') -try: - model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' - model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth' - if not os.path.exists(model_dest): - print('Downloading codeformer model file...') +#--------------------------------------------- +def download_codeformer(): + print('Installing CodeFormer model file...',end='') + try: + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' + model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth' + if not os.path.exists(model_dest): + print('Downloading codeformer model file...') + os.makedirs(os.path.dirname(model_dest), exist_ok=True) + urllib.request.urlretrieve(model_url,model_dest) + except Exception: + print('Error loading CodeFormer:') + print(traceback.format_exc()) + print('...success') + +#--------------------------------------------- +def download_clipseg(): + print('Installing clipseg model for text-based masking...',end='') + try: + model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download' + model_dest = 'src/clipseg/clipseg_weights.zip' + weights_dir = 
'src/clipseg/weights' + if not os.path.exists(weights_dir): os.makedirs(os.path.dirname(model_dest), exist_ok=True) urllib.request.urlretrieve(model_url,model_dest) -except Exception: - print('Error loading CodeFormer:') - print(traceback.format_exc()) -print('...success') + with zipfile.ZipFile(model_dest,'r') as zip: + zip.extractall('src/clipseg') + os.rename('src/clipseg/clipseg_weights','src/clipseg/weights') + os.remove(model_dest) + from clipseg_models.clipseg import CLIPDensePredT + model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, ) + model.eval() + model.load_state_dict( + torch.load( + 'src/clipseg/weights/rd64-uni-refined.pth', + map_location=torch.device('cpu') + ), + strict=False, + ) + except Exception: + print('Error installing clipseg model:') + print(traceback.format_exc()) + print('...success') -print('Loading clipseg model for text-based masking...',end='') -try: - model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download' - model_dest = 'src/clipseg/clipseg_weights.zip' - weights_dir = 'src/clipseg/weights' - if not os.path.exists(weights_dir): - os.makedirs(os.path.dirname(model_dest), exist_ok=True) - urllib.request.urlretrieve(model_url,model_dest) - with zipfile.ZipFile(model_dest,'r') as zip: - zip.extractall('src/clipseg') - os.rename('src/clipseg/clipseg_weights','src/clipseg/weights') - os.remove(model_dest) - from clipseg_models.clipseg import CLIPDensePredT - model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, ) - model.eval() - model.load_state_dict( - torch.load( - 'src/clipseg/weights/rd64-uni-refined.pth', - map_location=torch.device('cpu') - ), - strict=False, - ) -except Exception: - print('Error installing clipseg model:') - print(traceback.format_exc()) -print('...success') +#------------------------------------- +def download_safety_checker(): + print('Installing safety model for NSFW content detection...',end='') + try: + from diffusers.pipelines.stable_diffusion.safety_checker import 
StableDiffusionSafetyChecker + except ModuleNotFoundError: + print('Error installing safety checker model:') + print(traceback.format_exc()) + return + safety_model_id = "CompVis/stable-diffusion-safety-checker" + safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) + safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) + print('...success') + +#------------------------------------- +if __name__ == '__main__': + download_bert() + download_kornia() + download_clip() + download_gfpgan() + download_codeformer() + download_clipseg() + download_safety_checker() - +