mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-09 13:57:54 -05:00
SD - Fix civitai download on Windows +improvements (#1907)
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -182,7 +182,7 @@ generated_imgs/
|
||||
|
||||
# Custom model related artefacts
|
||||
variants.json
|
||||
models/
|
||||
/models/
|
||||
|
||||
# models folder
|
||||
apps/stable_diffusion/web/models/
|
||||
|
||||
@@ -8,6 +8,7 @@ import traceback
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import requests
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
get_opt_flags,
|
||||
@@ -16,6 +17,7 @@ from apps.stable_diffusion.src.utils import (
|
||||
preprocessCKPT,
|
||||
convert_original_vae,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
get_civitai_checkpoint,
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
@@ -94,21 +96,19 @@ class SharkifyStableDiffusionModel:
|
||||
self.height = height // 8
|
||||
self.width = width // 8
|
||||
self.batch_size = batch_size
|
||||
self.custom_weights = custom_weights
|
||||
self.custom_weights = custom_weights.strip()
|
||||
self.use_quantize = use_quantize
|
||||
if custom_weights != "":
|
||||
if "civitai" in custom_weights:
|
||||
weights_id = custom_weights.split("/")[-1]
|
||||
# TODO: use model name and identify file type by civitai rest api
|
||||
weights_path = (
|
||||
str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
|
||||
)
|
||||
if not os.path.isfile(weights_path):
|
||||
subprocess.run(
|
||||
["wget", custom_weights, "-O", weights_path]
|
||||
)
|
||||
if custom_weights.startswith("https://civitai.com/api/"):
|
||||
# download the checkpoint from civitai if we don't already have it
|
||||
weights_path = get_civitai_checkpoint(custom_weights)
|
||||
|
||||
# act as if we were given the local file as custom_weights originally
|
||||
custom_weights = get_path_to_diffusers_checkpoint(weights_path)
|
||||
self.custom_weights = weights_path
|
||||
|
||||
# needed to ensure webui sets the correct model name metadata
|
||||
args.ckpt_loc = weights_path
|
||||
else:
|
||||
assert custom_weights.lower().endswith(
|
||||
(".ckpt", ".safetensors")
|
||||
@@ -116,6 +116,7 @@ class SharkifyStableDiffusionModel:
|
||||
custom_weights = get_path_to_diffusers_checkpoint(
|
||||
custom_weights
|
||||
)
|
||||
|
||||
self.model_id = model_id if custom_weights == "" else custom_weights
|
||||
# TODO: remove the following line when stable-diffusion-2-1 works
|
||||
if self.model_id == "stabilityai/stable-diffusion-2-1":
|
||||
|
||||
@@ -41,3 +41,4 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
resize_stencil,
|
||||
_compile_module,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.civitai import get_civitai_checkpoint
|
||||
|
||||
42
apps/stable_diffusion/src/utils/civitai.py
Normal file
42
apps/stable_diffusion/src/utils/civitai.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import re
|
||||
import requests
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_civitai_checkpoint(url: str):
|
||||
with requests.get(url, allow_redirects=True, stream=True) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
# civitai api returns the filename in the content disposition
|
||||
base_filename = re.findall(
|
||||
'"([^"]*)"', response.headers["Content-Disposition"]
|
||||
)[0]
|
||||
destination_path = (
|
||||
Path.cwd() / (args.ckpt_dir or "models") / base_filename
|
||||
)
|
||||
|
||||
# we don't have this model downloaded yet
|
||||
if not destination_path.is_file():
|
||||
print(
|
||||
f"downloading civitai model from {url} to {destination_path}"
|
||||
)
|
||||
|
||||
size = int(response.headers["content-length"], 0)
|
||||
progress_bar = tqdm(total=size, unit="iB", unit_scale=True)
|
||||
|
||||
with open(destination_path, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=65536):
|
||||
f.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
|
||||
progress_bar.close()
|
||||
|
||||
# we already have this model downloaded
|
||||
else:
|
||||
print(f"civitai model already downloaded to {destination_path}")
|
||||
|
||||
response.close()
|
||||
return destination_path.as_posix()
|
||||
Reference in New Issue
Block a user