mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
control which revision of a diffusers model is downloaded
- Previously the user's preferred precision was used to select which version branch of a diffusers model would be downloaded. Half-precision would try to download the 'fp16' branch if it existed. - Turns out that with waifu-diffusion this logic doesn't work, as 'fp16' gets you waifu-diffusion v1.3, while 'main' gets you waifu-diffusion v1.4. Who knew? - This PR adds a new optional "revision" field to `models.yaml`. This can be used to override the diffusers branch version. In the case of Waifu diffusion, INITIAL_MODELS.yaml now specifies the "main" branch. - This PR also quenches the NSFW nag that downloading diffusers sometimes triggers. - Closes #3160
This commit is contained in:
@@ -80,7 +80,8 @@ trinart-2.0:
|
||||
repo_id: stabilityai/sd-vae-ft-mse
|
||||
recommended: False
|
||||
waifu-diffusion-1.4:
|
||||
description: An SD-1.5 model trained on 680k anime/manga-style images (2.13 GB)
|
||||
description: An SD-2.1 model trained on 5.4M anime/manga-style images (4.27 GB)
|
||||
revision: main
|
||||
repo_id: hakurei/waifu-diffusion
|
||||
format: diffusers
|
||||
vae:
|
||||
|
||||
@@ -11,6 +11,7 @@ from tempfile import TemporaryFile
|
||||
|
||||
import requests
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers import logging as dlogging
|
||||
from huggingface_hub import hf_hub_url
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
@@ -296,13 +297,21 @@ def _download_diffusion_weights(
|
||||
mconfig: DictConfig, access_token: str, precision: str = "float32"
|
||||
):
|
||||
repo_id = mconfig["repo_id"]
|
||||
revision = mconfig.get('revision',None)
|
||||
model_class = (
|
||||
StableDiffusionGeneratorPipeline
|
||||
if mconfig.get("format", None) == "diffusers"
|
||||
else AutoencoderKL
|
||||
)
|
||||
extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
|
||||
extra_arg_list = [{"revision": revision}] if revision \
|
||||
else [{"revision": "fp16"}, {}] if precision == "float16" \
|
||||
else [{}]
|
||||
path = None
|
||||
|
||||
# quench safety checker warnings
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
for extra_args in extra_arg_list:
|
||||
try:
|
||||
path = download_from_hf(
|
||||
@@ -318,6 +327,7 @@ def _download_diffusion_weights(
|
||||
print(f"An unexpected error occurred while downloading the model: {e})")
|
||||
if path:
|
||||
break
|
||||
dlogging.set_verbosity(verbosity)
|
||||
return path
|
||||
|
||||
|
||||
@@ -448,6 +458,8 @@ def new_config_file_contents(
|
||||
stanza["description"] = mod["description"]
|
||||
stanza["repo_id"] = mod["repo_id"]
|
||||
stanza["format"] = mod["format"]
|
||||
if "revision" in mod:
|
||||
stanza["revision"] = mod["revision"]
|
||||
# diffusers don't need width and height (probably .ckpt doesn't either)
|
||||
# so we no longer require these in INITIAL_MODELS.yaml
|
||||
if "width" in mod:
|
||||
@@ -472,10 +484,9 @@ def new_config_file_contents(
|
||||
|
||||
conf[model] = stanza
|
||||
|
||||
# if no default model was chosen, then we select the first
|
||||
# one in the list
|
||||
# if no default model was chosen, then we select the first one in the list
|
||||
if not default_selected:
|
||||
conf[list(successfully_downloaded.keys())[0]]["default"] = True
|
||||
conf[list(conf.keys())[0]]["default"] = True
|
||||
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
|
||||
@@ -423,9 +423,9 @@ class ModelManager(object):
|
||||
pipeline_args.update(cache_dir=global_cache_dir("hub"))
|
||||
if using_fp16:
|
||||
pipeline_args.update(torch_dtype=torch.float16)
|
||||
fp_args_list = [{"revision": "fp16"}, {}]
|
||||
else:
|
||||
fp_args_list = [{}]
|
||||
revision = mconfig.get('revision') or ('fp16' if using_fp16 else None)
|
||||
fp_args_list = [{"revision": revision}] if revision else []
|
||||
fp_args_list.append({})
|
||||
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
Reference in New Issue
Block a user