# Patch summary (text before the first "diff --git" header is commentary only).
#
# 1. configs/models.yaml (new file)
#    Adds a registry of selectable models. Each entry carries four keys:
#    config, weights, width and height. Two entries are defined:
#      - laion400m            (256 x 256)
#      - stable-diffusion-1.4 (512 x 512)
#
# 2. scripts/dream.py
#    Imports:
#      - removes `from ldm.dream.devices import choose_torch_device`
#      - adds    `from omegaconf import OmegaConf`
#    main():
#      - the hard-coded width/height/config/weights branches keyed on
#        opt.laion400m and opt.weights are removed;
#      - a truthy opt.laion400m now prints a deprecation message and exits
#        with status -1;
#      - opt.weights different from 'model' now prints a deprecation message
#        and exits with status -1 (presumably 'model' is the argparse default
#        for --weights; that definition is outside this patch -- confirm);
#      - width, height, config and weights are read from
#        OmegaConf.load(opt.config)[opt.model];
#      - FileNotFoundError, IOError and KeyError raised while loading are
#        printed followed by '. Aborting.' and the script exits with -1.
#    create_argv_parser():
#      - the argument with dest='full_precision' loses its
#        `default=choose_torch_device() == 'mps'` (and the MPS comment), so it
#        no longer depends on the detected torch device;
#      - adds --model  (default 'stable-diffusion-1.4');
#      - adds --config (default 'configs/models.yaml').
#
# NOTE(review): the except clause does not list AttributeError. Depending on
#   the OmegaConf version, looking up an unknown model name may return None
#   instead of raising KeyError, in which case `.width` would raise an
#   uncaught AttributeError -- confirm against the pinned OmegaConf version.
# NOTE(review): the physical line break inside the first print('...') string
#   below looks like an extraction artifact, not part of the original patch
#   -- confirm against the upstream commit.
diff --git a/configs/models.yaml b/configs/models.yaml new file mode 100644 index 0000000000..a3c929d29f --- /dev/null +++ b/configs/models.yaml @@ -0,0 +1,18 @@ +# This file describes the alternative machine learning models +# available to the dream script. +# +# To add a new model, follow the examples below. Each +# model requires a model config file, a weights file, +# and the width and height of the images it +# was trained on. + +laion400m: + config: configs/latent-diffusion/txt2img-1p4B-eval.yaml + weights: models/ldm/text2img-large/model.ckpt + width: 256 + height: 256 +stable-diffusion-1.4: + config: configs/stable-diffusion/v1-inference.yaml + weights: models/ldm/stable-diffusion-v1/model.ckpt + width: 512 + height: 512 diff --git a/scripts/dream.py b/scripts/dream.py index 1535ac386c..8901034eb1 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -9,31 +9,34 @@ import sys import copy import warnings import time -from ldm.dream.devices import choose_torch_device import ldm.dream.readline from ldm.dream.pngwriter import PngWriter, PromptFormatter from ldm.dream.server import DreamServer, ThreadingDreamServer from ldm.dream.image_util import make_grid +from omegaconf import OmegaConf def main(): """Initialize command-line parsers and the diffusion model""" arg_parser = create_argv_parser() opt = arg_parser.parse_args() + if opt.laion400m: - # defaults suitable to the older latent diffusion weights - width = 256 - height = 256 - config = 'configs/latent-diffusion/txt2img-1p4B-eval.yaml' - weights = 'models/ldm/text2img-large/model.ckpt' - else: - # some defaults suitable for stable diffusion weights - width = 512 - height = 512 - config = 'configs/stable-diffusion/v1-inference.yaml' - if '.ckpt' in opt.weights: - weights = opt.weights - else: - weights = f'models/ldm/stable-diffusion-v1/{opt.weights}.ckpt' + print('--laion400m flag has been deprecated. 
Please use --model laion400m instead.') + sys.exit(-1) + if opt.weights != 'model': + print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.') + sys.exit(-1) + + try: + print(f'attempting to load {opt.config}') + models = OmegaConf.load(opt.config) + width = models[opt.model].width + height = models[opt.model].height + config = models[opt.model].config + weights = models[opt.model].weights + except (FileNotFoundError, IOError, KeyError) as e: + print(f'{e}. Aborting.') + sys.exit(-1) print('* Initializing, be patient...\n') sys.path.append('.') @@ -348,8 +351,6 @@ def create_argv_parser(): dest='full_precision', action='store_true', help='Use slower full precision math for calculations', - # MPS only functions with full precision, see https://github.com/lstein/stable-diffusion/issues/237 - default=choose_torch_device() == 'mps', ) parser.add_argument( '-g', @@ -429,6 +430,16 @@ def create_argv_parser(): default='cuda', help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available" ) + parser.add_argument( + '--model', + default='stable-diffusion-1.4', + help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")', + ) + parser.add_argument( + '--config', + default ='configs/models.yaml', + help ='Path to configuration file for alternate models.', + ) return parser