SD - Fix civitai download on Windows +improvements (#1907)

This commit is contained in:
Stefan Kapusniak
2023-10-21 19:17:41 +01:00
committed by GitHub
parent 7cd14fdc47
commit 134441957d
4 changed files with 56 additions and 12 deletions

2
.gitignore vendored
View File

@@ -182,7 +182,7 @@ generated_imgs/
# Custom model related artefacts
variants.json
models/
/models/
# models folder
apps/stable_diffusion/web/models/

View File

@@ -8,6 +8,7 @@ import traceback
import subprocess
import sys
import os
import requests
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
@@ -16,6 +17,7 @@ from apps.stable_diffusion.src.utils import (
preprocessCKPT,
convert_original_vae,
get_path_to_diffusers_checkpoint,
get_civitai_checkpoint,
fetch_and_update_base_model_id,
get_path_stem,
get_extended_name,
@@ -94,21 +96,19 @@ class SharkifyStableDiffusionModel:
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.custom_weights = custom_weights.strip()
self.use_quantize = use_quantize
if custom_weights != "":
if "civitai" in custom_weights:
weights_id = custom_weights.split("/")[-1]
# TODO: use model name and identify file type by civitai rest api
weights_path = (
str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
)
if not os.path.isfile(weights_path):
subprocess.run(
["wget", custom_weights, "-O", weights_path]
)
if custom_weights.startswith("https://civitai.com/api/"):
# download the checkpoint from civitai if we don't already have it
weights_path = get_civitai_checkpoint(custom_weights)
# act as if we were given the local file as custom_weights originally
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
self.custom_weights = weights_path
# needed to ensure webui sets the correct model name metadata
args.ckpt_loc = weights_path
else:
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
@@ -116,6 +116,7 @@ class SharkifyStableDiffusionModel:
custom_weights = get_path_to_diffusers_checkpoint(
custom_weights
)
self.model_id = model_id if custom_weights == "" else custom_weights
# TODO: remove the following line when stable-diffusion-2-1 works
if self.model_id == "stabilityai/stable-diffusion-2-1":

View File

@@ -41,3 +41,4 @@ from apps.stable_diffusion.src.utils.utils import (
resize_stencil,
_compile_module,
)
from apps.stable_diffusion.src.utils.civitai import get_civitai_checkpoint

View File

@@ -0,0 +1,42 @@
import re
import requests
from apps.stable_diffusion.src.utils.stable_args import args
from pathlib import Path
from tqdm import tqdm
def get_civitai_checkpoint(url: str):
with requests.get(url, allow_redirects=True, stream=True) as response:
response.raise_for_status()
# civitai api returns the filename in the content disposition
base_filename = re.findall(
'"([^"]*)"', response.headers["Content-Disposition"]
)[0]
destination_path = (
Path.cwd() / (args.ckpt_dir or "models") / base_filename
)
# we don't have this model downloaded yet
if not destination_path.is_file():
print(
f"downloading civitai model from {url} to {destination_path}"
)
size = int(response.headers["content-length"], 0)
progress_bar = tqdm(total=size, unit="iB", unit_scale=True)
with open(destination_path, "wb") as f:
for chunk in response.iter_content(chunk_size=65536):
f.write(chunk)
progress_bar.update(len(chunk))
progress_bar.close()
# we already have this model downloaded
else:
print(f"civitai model already downloaded to {destination_path}")
response.close()
return destination_path.as_posix()