Pipeline tweaks, add cmd_opts parsing to sd api

This commit is contained in:
Ean Garvey
2023-12-20 13:42:39 -06:00
parent 12884591a5
commit a42ecb0be2
8 changed files with 157 additions and 102 deletions

View File

@@ -34,8 +34,6 @@ from apps.shark_studio.modules.ckpt_processing import (
)
from transformers import CLIPTokenizer
from diffusers.image_processor import VaeImageProcessor
from math import ceil
from PIL import Image
sd_model_map = {
"clip": {
@@ -166,35 +164,40 @@ class StableDiffusion(SharkPipelineBase):
del static_kwargs
gc.collect()
def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings, is_img2img):
print(
f"\n[LOG] Preparing pipeline with scheduler {scheduler}"
f"\n[LOG] Custom embeddings currently unsupported."
)
def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
print(f"\n[LOG] Preparing pipeline...")
self.is_img2img = is_img2img
schedulers = get_schedulers(self.base_model_id)
self.scheduler = schedulers[scheduler]
self.image_processor = VaeImageProcessor()#do_convert_rgb=True)
self.weights_path = os.path.join(get_checkpoints_path(), self.safe_name(self.base_model_id))
self.schedulers = get_schedulers(self.base_model_id)
self.weights_path = os.path.join(
get_checkpoints_path(), self.safe_name(self.base_model_id)
)
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)
print(f"[LOG] Loaded scheduler: {scheduler}")
for model in adapters:
self.model_map[model] = adapters[model]
for submodel in self.static_kwargs:
if custom_weights:
custom_weights_params, _ = process_custom_pipe_weights(custom_weights)
if submodel not in ["clip", "clip2"]:
self.static_kwargs[submodel]["external_weight_file"] = custom_weights
self.static_kwargs[submodel][
"external_weight_file"
] = custom_weights_params
else:
self.static_kwargs[submodel]["external_weight_path"] = os.path.join(self.weights_path, submodel + ".safetensors")
self.static_kwargs[submodel]["external_weight_path"] = os.path.join(
self.weights_path, submodel + ".safetensors"
)
else:
self.static_kwargs[submodel]["external_weight_path"] = os.path.join(self.weights_path, submodel + ".safetensors")
self.static_kwargs[submodel]["external_weight_path"] = os.path.join(
self.weights_path, submodel + ".safetensors"
)
self.get_compiled_map(pipe_id=self.pipe_id)
print("\n[LOG] Pipeline successfully prepared for runtime.")
return
def encode_prompts_weight(
self,
prompt,
@@ -335,9 +338,9 @@ class StableDiffusion(SharkPipelineBase):
latent_history.append(latents)
step_time = (time.time() - step_start_time) * 1000
# self.log += (
# f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
# print(
# f"\n [LOG] step = {i} | timestep = {t} | time = {step_time:.2f}ms"
# )
step_time_sum += step_time
# if self.status == SD_STATE_CANCEL:
@@ -371,51 +374,52 @@ class StableDiffusion(SharkPipelineBase):
pil_images = self.image_processor.numpy_to_pil(images)
return pil_images
def process_sd_init_image(self, sd_init_image, resample_type):
if isinstance(sd_init_image, list):
images = []
for img in sd_init_image:
img, _ = self.process_sd_init_image(img, resample_type)
images.append(img)
is_img2img = True
return images, is_img2img
if isinstance(sd_init_image, str):
if os.path.isfile(sd_init_image):
sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB")
image, is_img2img = self.process_sd_init_image(
sd_init_image, resample_type
)
else:
image = None
is_img2img = False
elif isinstance(sd_init_image, Image.Image):
image = sd_init_image.convert("RGB")
elif sd_init_image:
image = sd_init_image["image"].convert("RGB")
else:
image = None
is_img2img = False
if image:
resample_type = (
resamplers[resample_type]
if resample_type in resampler_list
# Fallback to Lanczos
else Image.Resampling.LANCZOS
)
image = image.resize((self.width, self.height), resample=resample_type)
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
image_arr = 2 * (image_arr - 0.5)
is_img2img = True
image = image_arr
return image, is_img2img
# def process_sd_init_image(self, sd_init_image, resample_type):
# if isinstance(sd_init_image, list):
# images = []
# for img in sd_init_image:
# img, _ = self.process_sd_init_image(img, resample_type)
# images.append(img)
# is_img2img = True
# return images, is_img2img
# if isinstance(sd_init_image, str):
# if os.path.isfile(sd_init_image):
# sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB")
# image, is_img2img = self.process_sd_init_image(
# sd_init_image, resample_type
# )
# else:
# image = None
# is_img2img = False
# elif isinstance(sd_init_image, Image.Image):
# image = sd_init_image.convert("RGB")
# elif sd_init_image:
# image = sd_init_image["image"].convert("RGB")
# else:
# image = None
# is_img2img = False
# if image:
# resample_type = (
# resamplers[resample_type]
# if resample_type in resampler_list
# # Fallback to Lanczos
# else Image.Resampling.LANCZOS
# )
# image = image.resize((self.width, self.height), resample=resample_type)
# image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
# image_arr = image_arr / 255.0
# image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
# image_arr = 2 * (image_arr - 0.5)
# is_img2img = True
# image = image_arr
# return image, is_img2img
def generate_images(
self,
prompt,
negative_prompt,
image,
scheduler,
steps,
strength,
guidance_scale,
@@ -427,9 +431,11 @@ class StableDiffusion(SharkPipelineBase):
hints,
):
# TODO: Batched args
self.image_processor = VaeImageProcessor(do_convert_rgb=True)
self.scheduler = self.schedulers[scheduler]
self.ondemand = ondemand
if self.is_img2img:
image, _ = self.process_sd_init_image(image, resample_type)
image, _ = self.image_processor.preprocess(image, resample_type)
else:
image = None
@@ -532,6 +538,8 @@ def shark_sd_fn(
embeddings: dict,
):
sd_kwargs = locals()
if not isinstance(sd_init_image, list):
sd_init_image = [sd_init_image]
is_img2img = True if sd_init_image[0] is not None else False
print("\n[LOG] Performing Stable Diffusion Pipeline setup...")
@@ -581,7 +589,6 @@ def shark_sd_fn(
"is_controlled": is_controlled,
}
submit_prep_kwargs = {
"scheduler": scheduler,
"custom_weights": custom_weights,
"adapters": adapters,
"embeddings": embeddings,
@@ -592,6 +599,7 @@ def shark_sd_fn(
"negative_prompt": negative_prompt,
"image": sd_init_image,
"steps": steps,
"scheduler": scheduler,
"strength": strength,
"guidance_scale": guidance_scale,
"seed": seed,
@@ -667,5 +675,8 @@ if __name__ == "__main__":
sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json"))
sd_kwargs = json.loads(sd_json)
for arg in vars(cmd_opts):
if arg in sd_kwargs:
sd_kwargs[arg] = getattr(cmd_opts, arg)
for i in shark_sd_fn_dict_input(sd_kwargs):
print(i)

View File

@@ -102,7 +102,7 @@ class SharkPipelineBase:
write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
)
return
def get_io_params(self, submodel):
if "external_weight_file" in self.static_kwargs[submodel]:
# we are using custom weights
@@ -114,7 +114,7 @@ class SharkPipelineBase:
# assume the torch IR contains the weights.
weights_path = None
return weights_path
def get_precompiled(self, pipe_id, submodel="None"):
if submodel == "None":
for model in self.model_map:
@@ -125,7 +125,9 @@ class SharkPipelineBase:
break
for file in vmfbs:
if submodel in file:
self.pipe_map[submodel]["vmfb_path"] = os.path.join(self.pipe_vmfb_path, file)
self.pipe_map[submodel]["vmfb_path"] = os.path.join(
self.pipe_vmfb_path, file
)
return
def import_torch_ir(self, submodel, kwargs):
@@ -153,9 +155,9 @@ class SharkPipelineBase:
continue
if "vmfb_path" in self.pipe_map[submodel]:
weights_path = self.get_io_params(submodel)
print(
f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}"
)
# print(
# f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}"
# )
self.iree_module_dict[submodel] = {}
(
self.iree_module_dict[submodel]["vmfb"],

View File

@@ -0,0 +1,66 @@
import numpy as np
import json
from random import (
randint,
seed as seed_random,
getstate as random_getstate,
setstate as random_setstate,
)
# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed: int | str):
seed = int(seed)
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed
# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
if isinstance(seed_input, str):
try:
seed_input = json.loads(seed_input)
except (ValueError, TypeError):
seed_input = None
if isinstance(seed_input, int):
return [seed_input]
if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
return seed_input
raise TypeError(
"Seed input must be an integer or an array of integers in JSON format"
)
def batch_seeds(seed_input: str | list | int, batch_count: int, repeatable=False):
    """Build the list of seeds to use for ``batch_count`` batches.

    Args:
        seed_input: Seed expression accepted by ``parse_seed_input``.
        batch_count: Number of batches. The parsed seed list is sliced with
            ``[:batch_count]`` and then padded with -1 up to this length.
        repeatable: When true, the seeds already fixed are used to seed the
            ``random`` module before the unspecified ones are generated, and
            the previous RNG state is restored afterwards.

    Returns:
        A list of ints in which every entry has been passed through
        ``sanitize_seed``.

    Raises:
        TypeError: Propagated from ``parse_seed_input`` for invalid input.
    """
    # Turn the input into a list if possible.
    seeds = parse_seed_input(seed_input)

    # Slice or pad the list to batch_count entries; -1 marks "unspecified".
    seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))

    if not repeatable:
        return [sanitize_seed(seed) for seed in seeds]

    # With nothing fixed there is nothing to derive the RNG seed from, so
    # pin the first entry. The emptiness guard matters because all() is True
    # for an empty list (e.g. batch_count == 0), and indexing seeds[0] would
    # then raise IndexError.
    if seeds and all(seed < 0 for seed in seeds):
        seeds[0] = sanitize_seed(seeds[0])

    # Seed the RNG from the fixed seeds, and always put the caller's RNG
    # state back, even if generation fails part-way.
    saved_random_state = random_getstate()
    try:
        seed_random(str([n for n in seeds if n > -1]))
        return [sanitize_seed(seed) for seed in seeds]
    finally:
        random_setstate(saved_random_state)

View File

@@ -32,7 +32,7 @@ p.add_argument(
)
p.add_argument(
"-p",
"--prompts",
"--prompt",
nargs="+",
default=[
"a photo taken of the front of a super-car drifting on a road near "
@@ -44,7 +44,7 @@ p.add_argument(
)
p.add_argument(
"--negative_prompts",
"--negative_prompt",
nargs="+",
default=[
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
@@ -54,7 +54,7 @@ p.add_argument(
)
p.add_argument(
"--img_path",
"--sd_init_image",
type=str,
help="Path to the image input for img2img/inpainting.",
)
@@ -320,7 +320,7 @@ p.add_argument(
p.add_argument(
"--scheduler",
type=str,
default="SharkEulerDiscrete",
default="DDIM",
help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
"DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
"DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
@@ -359,10 +359,10 @@ p.add_argument(
)
p.add_argument(
"--ckpt_loc",
"--custom_weights",
type=str,
default="",
help="Path to SD's .ckpt file.",
help="Path to a .safetensors or .ckpt file for SD pipeline weights.",
)
p.add_argument(
@@ -374,7 +374,7 @@ p.add_argument(
)
p.add_argument(
"--hf_model_id",
"--base_model_id",
type=str,
default="stabilityai/stable-diffusion-2-1-base",
help="The repo-id of hugging face.",

View File

@@ -1,24 +1 @@
{
"prompt": [ "a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))" ],
"negative_prompt": [ "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped" ],
"sd_init_image": [ null ],
"height": 512,
"width": 512,
"steps": 50,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": -1,
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "runwayml/stable-diffusion-v1-5",
"custom_weights": "",
"custom_vae": "",
"precision": "fp16",
"device": "vulkan",
"ondemand": false,
"repeatable_seeds": false,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}
{"prompt": ["a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"], "negative_prompt": ["watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"], "sd_init_image": [null], "height": 512, "width": 512, "steps": 50, "strength": 0.8, "guidance_scale": 7.5, "seed": "-1", "batch_count": 1, "batch_size": 1, "scheduler": "EulerDiscrete", "base_model_id": "stabilityai/stable-diffusion-2-1-base", "custom_weights": "None", "custom_vae": "None", "precision": "fp16", "device": "AMD Radeon RX 7900 XTX => vulkan://0", "ondemand": false, "repeatable_seeds": false, "resample_type": "Nearest Neighbor", "controlnets": {}, "embeddings": {}}

View File

@@ -300,13 +300,13 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=cmd_opts.prompts[0],
value=cmd_opts.prompt[0],
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=cmd_opts.negative_prompts[0],
value=cmd_opts.negative_prompt[0],
lines=2,
elem_id="negative_prompt_box",
)

View File

@@ -54,11 +54,11 @@ def set_prep_kwargs(value):
_prep_kwargs = value
def set_gen_kwargs(value):
    """Rebind the module-level ``_gen_kwargs`` global to ``value``."""
    global _gen_kwargs
    _gen_kwargs = value
def set_schedulers(value):
    """Rebind the module-level ``_schedulers`` global to ``value``."""
    global _schedulers
    _schedulers = value

View File

@@ -3,7 +3,6 @@ import gc
def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
print(f"Getting status label for {tab_name}")
if batch_index < batch_count:
bs = f"x{batch_size}" if batch_size > 1 else ""
return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"