mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-08 05:24:00 -05:00
146 lines
5.4 KiB
Python
146 lines
5.4 KiB
Python
import os
|
|
import json
|
|
import re
|
|
import requests
|
|
import torch
|
|
import safetensors
|
|
from iree.turbine.aot.params import (
|
|
ParameterArchiveBuilder,
|
|
)
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
from tqdm import tqdm
|
|
from omegaconf import OmegaConf
|
|
from diffusers import StableDiffusionPipeline
|
|
from apps.amdshark_studio.modules.shared_cmd_opts import cmd_opts
|
|
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
|
download_from_original_stable_diffusion_ckpt,
|
|
create_vae_diffusers_config,
|
|
convert_ldm_vae_checkpoint,
|
|
)
|
|
|
|
|
|
def get_path_to_diffusers_checkpoint(custom_weights, precision="fp16"):
    """Return the diffusers output folder for a checkpoint, creating it if needed.

    The folder sits beside the checkpoint file at
    ``<checkpoint dir>/diffusers/<checkpoint stem>_<precision>``. It and any
    missing parents are created; an existing folder is left untouched.

    Args:
        custom_weights: path to the checkpoint file.
        precision: suffix appended to the folder name (default ``"fp16"``).

    Returns:
        The folder path as a POSIX-style string.
    """
    checkpoint = Path(custom_weights)
    folder_name = f"{checkpoint.stem}_{precision}"
    target_dir = checkpoint.parent.absolute() / "diffusers" / folder_name
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir.as_posix()
|
|
|
|
|
|
def preprocessCKPT(custom_weights, precision="fp16", is_inpaint=False):
    """Convert an original Stable Diffusion checkpoint into a diffusers folder.

    The destination comes from ``get_path_to_diffusers_checkpoint``. If that
    folder already has any entry in it, conversion is skipped and the folder
    is reused as-is.

    Args:
        custom_weights: path to a ``.ckpt`` or ``.safetensors`` checkpoint.
        precision: when ``"fp16"`` the converted pipeline is cast to
            ``torch.float16`` before saving; the value also names the folder.
        is_inpaint: build the UNet with 9 input channels instead of 4.

    Returns:
        The POSIX path of the diffusers folder.
    """
    target_dir = get_path_to_diffusers_checkpoint(custom_weights, precision)

    # Any existing entry is taken to mean a previous conversion already ran.
    first_entry = next(Path(target_dir).iterdir(), None)
    if first_entry:
        print("Checkpoint already loaded at : ", target_dir)
        return target_dir

    print(
        "Diffusers' checkpoint will be identified here : ",
        target_dir,
    )
    is_safetensors = custom_weights.lower().endswith(".safetensors")

    # EMA weights usually yield higher quality images for inference but
    # non-EMA weights have been yielding better results in our case.
    # TODO: Add an option `--ema` (`--no-ema`) for users to specify if
    # they want to go for EMA weight extraction or not.
    use_ema = False

    print("Loading diffusers' pipeline from original stable diffusion checkpoint")
    unet_in_channels = 9 if is_inpaint else 4
    pipeline = download_from_original_stable_diffusion_ckpt(
        checkpoint_path_or_dict=custom_weights,
        extract_ema=use_ema,
        from_safetensors=is_safetensors,
        num_in_channels=unet_in_channels,
    )

    if precision == "fp16":
        pipeline.to(dtype=torch.float16)
    pipeline.save_pretrained(target_dir)
    # Drop the reference so the pipeline's memory can be reclaimed promptly.
    del pipeline

    print("Loading complete")
    return target_dir
|
|
|
|
|
|
def save_irpa(weights_path, prepend_str):
    """Repack a ``.safetensors`` file as an IREE parameter archive (``.irpa``).

    Every tensor is stored in the archive under ``prepend_str + <original key>``.
    The archive is written next to the input, with the file's final extension
    replaced by ``.irpa``.

    Args:
        weights_path: path to the ``.safetensors`` file to convert.
        prepend_str: prefix added to every tensor name in the archive.

    Returns:
        The path of the written ``.irpa`` file, as a string.
    """
    # Import the submodule explicitly: ``import safetensors`` alone does not by
    # itself guarantee that ``safetensors.torch`` has been loaded.
    import safetensors.torch

    weights = safetensors.torch.load_file(weights_path)
    archive = ParameterArchiveBuilder()
    for name, tensor in weights.items():
        archive.add_tensor(prepend_str + name, tensor)

    # Swap only the final extension. ``str.replace`` would also rewrite any
    # ".safetensors" earlier in the path. For an input without that extension
    # it would return the input path itself, so the archive would overwrite
    # the source weights.
    root, _ = os.path.splitext(weights_path)
    irpa_file = root + ".irpa"
    archive.save(irpa_file)
    return irpa_file
|
|
|
|
|
|
def convert_original_vae(vae_checkpoint):
    """Convert an original (LDM-format) VAE state dict to the diffusers layout.

    The Stable Diffusion v1 inference config is downloaded from the CompVis
    GitHub repository on every call. A VAE config for image size 512 is built
    from it and used to convert the weights.

    Args:
        vae_checkpoint: mapping of VAE parameter names (without the
            ``first_stage_model.`` prefix) to tensors.

    Returns:
        The converted state dict produced by ``convert_ldm_vae_checkpoint``.

    Raises:
        requests.HTTPError: if the config download returns an HTTP error.
        requests.Timeout: if the config server does not respond in time.
    """
    # Re-add the "first_stage_model." prefix that full LDM checkpoints use;
    # the diffusers converter (convert_ldm_vae_checkpoint) selects VAE weights
    # by that prefix.
    vae_state_dict = {
        "first_stage_model." + key: value for key, value in vae_checkpoint.items()
    }

    config_url = (
        "https://raw.githubusercontent.com/CompVis/stable-diffusion/"
        "main/configs/stable-diffusion/v1-inference.yaml"
    )
    # Bound the wait so an unreachable server cannot hang the caller forever,
    # and fail on HTTP errors instead of parsing an error page as YAML.
    response = requests.get(config_url, timeout=60)
    response.raise_for_status()
    original_config = OmegaConf.load(BytesIO(response.content))
    vae_config = create_vae_diffusers_config(original_config, image_size=512)

    return convert_ldm_vae_checkpoint(vae_state_dict, vae_config)
|
|
|
|
|
|
def process_custom_pipe_weights(custom_weights):
    """Resolve user-supplied custom weights to a local file and diffusers folder.

    Args:
        custom_weights: a civitai API URL (``https://civitai.com/api/...``),
            which is downloaded unless already present, or a local ``.ckpt`` /
            ``.safetensors`` path. An empty string (or ``None``) means that no
            custom weights were requested.

    Returns:
        ``(custom_weights_params, custom_weights_tgt)``. The first item is the
        local checkpoint path. The second is the diffusers folder returned by
        ``get_path_to_diffusers_checkpoint`` for that file. Both are ``None``
        when ``custom_weights`` is empty.

    Raises:
        AssertionError: if a local path does not end in ``.ckpt`` or
            ``.safetensors``.
    """
    # Nothing to resolve. Return explicit placeholders instead of falling
    # through to the final return with both results unbound
    # (UnboundLocalError).
    if not custom_weights:
        return None, None

    if custom_weights.startswith("https://civitai.com/api/"):
        # download the checkpoint from civitai if we don't already have it
        custom_weights_params = get_civitai_checkpoint(custom_weights)
    else:
        assert custom_weights.lower().endswith(
            (".ckpt", ".safetensors")
        ), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
        custom_weights_params = custom_weights

    # Either way we now hold a local checkpoint file; derive its diffusers dir.
    custom_weights_tgt = get_path_to_diffusers_checkpoint(custom_weights_params)
    return custom_weights_params, custom_weights_tgt
|
|
|
|
|
|
def get_civitai_checkpoint(url: str):
    """Download a checkpoint from the civitai API unless it is already on disk.

    The file name comes from the response's ``Content-Disposition`` header.
    The file is stored under ``<cwd>/<cmd_opts.model_dir or "models">/``.

    Args:
        url: civitai API download URL.

    Returns:
        The local checkpoint path as a POSIX-style string.

    Raises:
        requests.HTTPError: if the server returns an HTTP error.
        ValueError: if no usable file name can be read from the response.
    """
    with requests.get(url, allow_redirects=True, stream=True, timeout=60) as response:
        response.raise_for_status()

        # civitai api returns the filename in the content disposition, normally
        # quoted; fall back to an unquoted ``filename=`` token.
        disposition = response.headers.get("Content-Disposition", "")
        match = re.search(r'"([^"]*)"', disposition) or re.search(
            r"filename=([^;\s]+)", disposition
        )
        if match is None:
            raise ValueError(
                f"could not determine a file name from Content-Disposition: "
                f"{disposition!r}"
            )
        # The name is server-controlled: keep only its final path component so
        # a value such as "../../x" cannot escape the model directory.
        base_filename = Path(match.group(1).strip()).name
        if base_filename in ("", ".", ".."):
            raise ValueError(f"unusable file name in Content-Disposition: {disposition!r}")

        model_dir = Path.cwd() / (cmd_opts.model_dir or "models")
        destination_path = model_dir / base_filename

        # we already have this model downloaded
        if destination_path.is_file():
            print(f"civitai model already downloaded to {destination_path}")
            return destination_path.as_posix()

        # we don't have this model downloaded yet
        print(f"downloading civitai model from {url} to {destination_path}")
        model_dir.mkdir(parents=True, exist_ok=True)

        # Content-Length may be missing (e.g. chunked transfer). In that case
        # show an open-ended progress bar instead of failing with KeyError.
        size = int(response.headers.get("content-length", 0)) or None

        # Download under a temporary name and rename only once complete, so an
        # interrupted download is never mistaken for a finished checkpoint.
        partial_path = destination_path.with_name(destination_path.name + ".part")
        with open(partial_path, "wb") as f, tqdm(
            total=size, unit="iB", unit_scale=True
        ) as progress_bar:
            for chunk in response.iter_content(chunk_size=65536):
                f.write(chunk)
                progress_bar.update(len(chunk))
        partial_path.replace(destination_path)

    return destination_path.as_posix()
|