control which revision of a diffusers model is downloaded

- Previously the user's preferred precision was used to select which
  version branch of a diffusers model would be downloaded. Half-precision
  would try to download the 'fp16' branch if it existed.

- Turns out that with waifu-diffusion this logic doesn't work, as
  'fp16' gets you waifu-diffusion v1.3, while 'main' gets you
  waifu-diffusion v1.4. Who knew?

- This PR adds a new optional "revision" field to `models.yaml`. This
  can be used to override the diffusers branch version. In the case of
  Waifu diffusion, INITIAL_MODELS.yaml now specifies the "main" branch.

- This PR also quenches the NSFW nag that downloading diffusers sometimes
  triggers.

- Closes #3160
This commit is contained in:
Lincoln Stein
2023-04-09 22:04:00 -04:00
parent 2af511c98a
commit 16ccc807cc
3 changed files with 20 additions and 8 deletions

View File

@@ -80,7 +80,8 @@ trinart-2.0:
repo_id: stabilityai/sd-vae-ft-mse
recommended: False
waifu-diffusion-1.4:
description: An SD-1.5 model trained on 680k anime/manga-style images (2.13 GB)
description: An SD-2.1 model trained on 5.4M anime/manga-style images (4.27 GB)
revision: main
repo_id: hakurei/waifu-diffusion
format: diffusers
vae:

View File

@@ -11,6 +11,7 @@ from tempfile import TemporaryFile
import requests
from diffusers import AutoencoderKL
from diffusers import logging as dlogging
from huggingface_hub import hf_hub_url
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
@@ -296,13 +297,21 @@ def _download_diffusion_weights(
mconfig: DictConfig, access_token: str, precision: str = "float32"
):
repo_id = mconfig["repo_id"]
revision = mconfig.get('revision',None)
model_class = (
StableDiffusionGeneratorPipeline
if mconfig.get("format", None) == "diffusers"
else AutoencoderKL
)
extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
extra_arg_list = [{"revision": revision}] if revision \
else [{"revision": "fp16"}, {}] if precision == "float16" \
else [{}]
path = None
# quench safety checker warnings
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
for extra_args in extra_arg_list:
try:
path = download_from_hf(
@@ -318,6 +327,7 @@ def _download_diffusion_weights(
print(f"An unexpected error occurred while downloading the model: {e})")
if path:
break
dlogging.set_verbosity(verbosity)
return path
@@ -448,6 +458,8 @@ def new_config_file_contents(
stanza["description"] = mod["description"]
stanza["repo_id"] = mod["repo_id"]
stanza["format"] = mod["format"]
if "revision" in mod:
stanza["revision"] = mod["revision"]
# diffusers don't need width and height (probably .ckpt doesn't either)
# so we no longer require these in INITIAL_MODELS.yaml
if "width" in mod:
@@ -472,10 +484,9 @@ def new_config_file_contents(
conf[model] = stanza
# if no default model was chosen, then we select the first
# one in the list
# if no default model was chosen, then we select the first one in the list
if not default_selected:
conf[list(successfully_downloaded.keys())[0]]["default"] = True
conf[list(conf.keys())[0]]["default"] = True
return OmegaConf.to_yaml(conf)

View File

@@ -423,9 +423,9 @@ class ModelManager(object):
pipeline_args.update(cache_dir=global_cache_dir("hub"))
if using_fp16:
pipeline_args.update(torch_dtype=torch.float16)
fp_args_list = [{"revision": "fp16"}, {}]
else:
fp_args_list = [{}]
revision = mconfig.get('revision') or ('fp16' if using_fp16 else None)
fp_args_list = [{"revision": revision}] if revision else []
fp_args_list.append({})
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()