mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git

Pipeline tweaks, add cmd_opts parsing to sd api
@@ -34,8 +34,6 @@ from apps.shark_studio.modules.ckpt_processing import (
)
from transformers import CLIPTokenizer
from diffusers.image_processor import VaeImageProcessor
from math import ceil
from PIL import Image

sd_model_map = {
    "clip": {
@@ -166,35 +164,40 @@ class StableDiffusion(SharkPipelineBase):
        del static_kwargs
        gc.collect()

    def prepare_pipe(self, scheduler, custom_weights, adapters, embeddings, is_img2img):
        print(
            f"\n[LOG] Preparing pipeline with scheduler {scheduler}"
            f"\n[LOG] Custom embeddings currently unsupported."
        )
    def prepare_pipe(self, custom_weights, adapters, embeddings, is_img2img):
        print(f"\n[LOG] Preparing pipeline...")
        self.is_img2img = is_img2img
        schedulers = get_schedulers(self.base_model_id)
        self.scheduler = schedulers[scheduler]
        self.image_processor = VaeImageProcessor()#do_convert_rgb=True)
        self.weights_path = os.path.join(get_checkpoints_path(), self.safe_name(self.base_model_id))
        self.schedulers = get_schedulers(self.base_model_id)

        self.weights_path = os.path.join(
            get_checkpoints_path(), self.safe_name(self.base_model_id)
        )
        if not os.path.exists(self.weights_path):
            os.mkdir(self.weights_path)
        print(f"[LOG] Loaded scheduler: {scheduler}")

        for model in adapters:
            self.model_map[model] = adapters[model]

        for submodel in self.static_kwargs:
            if custom_weights:
                custom_weights_params, _ = process_custom_pipe_weights(custom_weights)
                if submodel not in ["clip", "clip2"]:
                    self.static_kwargs[submodel]["external_weight_file"] = custom_weights
                    self.static_kwargs[submodel][
                        "external_weight_file"
                    ] = custom_weights_params
                else:
                    self.static_kwargs[submodel]["external_weight_path"] = os.path.join(self.weights_path, submodel + ".safetensors")
                    self.static_kwargs[submodel]["external_weight_path"] = os.path.join(
                        self.weights_path, submodel + ".safetensors"
                    )
            else:
                self.static_kwargs[submodel]["external_weight_path"] = os.path.join(self.weights_path, submodel + ".safetensors")
                self.static_kwargs[submodel]["external_weight_path"] = os.path.join(
                    self.weights_path, submodel + ".safetensors"
                )

        self.get_compiled_map(pipe_id=self.pipe_id)
        print("\n[LOG] Pipeline successfully prepared for runtime.")
        return


    def encode_prompts_weight(
        self,
        prompt,
@@ -335,9 +338,9 @@ class StableDiffusion(SharkPipelineBase):

            latent_history.append(latents)
            step_time = (time.time() - step_start_time) * 1000
            # self.log += (
            #     f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms"
            # )
            # print(
            #     f"\n [LOG] step = {i} | timestep = {t} | time = {step_time:.2f}ms"
            # )
            step_time_sum += step_time

            # if self.status == SD_STATE_CANCEL:
@@ -371,51 +374,52 @@ class StableDiffusion(SharkPipelineBase):
        pil_images = self.image_processor.numpy_to_pil(images)
        return pil_images

    def process_sd_init_image(self, sd_init_image, resample_type):
        if isinstance(sd_init_image, list):
            images = []
            for img in sd_init_image:
                img, _ = self.process_sd_init_image(img, resample_type)
                images.append(img)
            is_img2img = True
            return images, is_img2img
        if isinstance(sd_init_image, str):
            if os.path.isfile(sd_init_image):
                sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB")
                image, is_img2img = self.process_sd_init_image(
                    sd_init_image, resample_type
                )
            else:
                image = None
                is_img2img = False
        elif isinstance(sd_init_image, Image.Image):
            image = sd_init_image.convert("RGB")
        elif sd_init_image:
            image = sd_init_image["image"].convert("RGB")
        else:
            image = None
            is_img2img = False
        if image:
            resample_type = (
                resamplers[resample_type]
                if resample_type in resampler_list
                # Fallback to Lanczos
                else Image.Resampling.LANCZOS
            )
            image = image.resize((self.width, self.height), resample=resample_type)
            image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
            image_arr = image_arr / 255.0
            image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
            image_arr = 2 * (image_arr - 0.5)
            is_img2img = True
            image = image_arr
        return image, is_img2img

    # def process_sd_init_image(self, sd_init_image, resample_type):
    #     if isinstance(sd_init_image, list):
    #         images = []
    #         for img in sd_init_image:
    #             img, _ = self.process_sd_init_image(img, resample_type)
    #             images.append(img)
    #         is_img2img = True
    #         return images, is_img2img
    #     if isinstance(sd_init_image, str):
    #         if os.path.isfile(sd_init_image):
    #             sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB")
    #             image, is_img2img = self.process_sd_init_image(
    #                 sd_init_image, resample_type
    #             )
    #         else:
    #             image = None
    #             is_img2img = False
    #     elif isinstance(sd_init_image, Image.Image):
    #         image = sd_init_image.convert("RGB")
    #     elif sd_init_image:
    #         image = sd_init_image["image"].convert("RGB")
    #     else:
    #         image = None
    #         is_img2img = False
    #     if image:
    #         resample_type = (
    #             resamplers[resample_type]
    #             if resample_type in resampler_list
    #             # Fallback to Lanczos
    #             else Image.Resampling.LANCZOS
    #         )
    #         image = image.resize((self.width, self.height), resample=resample_type)
    #         image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
    #         image_arr = image_arr / 255.0
    #         image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(self.dtype)
    #         image_arr = 2 * (image_arr - 0.5)
    #         is_img2img = True
    #         image = image_arr
    #     return image, is_img2img

    def generate_images(
        self,
        prompt,
        negative_prompt,
        image,
        scheduler,
        steps,
        strength,
        guidance_scale,
@@ -427,9 +431,11 @@ class StableDiffusion(SharkPipelineBase):
        hints,
    ):
        # TODO: Batched args
        self.image_processor = VaeImageProcessor(do_convert_rgb=True)
        self.scheduler = self.schedulers[scheduler]
        self.ondemand = ondemand
        if self.is_img2img:
            image, _ = self.process_sd_init_image(image, resample_type)
            image, _ = self.image_processor.preprocess(image, resample_type)
        else:
            image = None
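
Read together with the shark_sd_fn hunks further down, these changes appear to move scheduler selection out of pipeline preparation and into the per-generation call: prepare_pipe keeps the full scheduler dict, and generate_images picks self.scheduler from it at run time. A rough sketch of the resulting call pattern follows; the pipeline instance and any argument names not visible in these hunks are assumptions, and generate_images takes more arguments than shown here.

prep_kwargs = {
    "custom_weights": custom_weights,
    "adapters": adapters,
    "embeddings": embeddings,
    "is_img2img": is_img2img,  # assumed: matches the prepare_pipe signature above
}
run_kwargs = {
    "prompt": prompt,
    "negative_prompt": negative_prompt,
    "image": sd_init_image,
    "steps": steps,
    "scheduler": scheduler,  # now chosen per run, rather than as part of the prep kwargs
    "strength": strength,
    "guidance_scale": guidance_scale,
    "seed": seed,
}
sd_pipe.prepare_pipe(**prep_kwargs)             # sd_pipe: an already-constructed StableDiffusion
images = sd_pipe.generate_images(**run_kwargs)  # remaining args (resample_type, ondemand, hints, ...) omitted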

@@ -532,6 +538,8 @@ def shark_sd_fn(
    embeddings: dict,
):
    sd_kwargs = locals()
    if not isinstance(sd_init_image, list):
        sd_init_image = [sd_init_image]
    is_img2img = True if sd_init_image[0] is not None else False

    print("\n[LOG] Performing Stable Diffusion Pipeline setup...")
@@ -581,7 +589,6 @@ def shark_sd_fn(
        "is_controlled": is_controlled,
    }
    submit_prep_kwargs = {
        "scheduler": scheduler,
        "custom_weights": custom_weights,
        "adapters": adapters,
        "embeddings": embeddings,
@@ -592,6 +599,7 @@ def shark_sd_fn(
        "negative_prompt": negative_prompt,
        "image": sd_init_image,
        "steps": steps,
        "scheduler": scheduler,
        "strength": strength,
        "guidance_scale": guidance_scale,
        "seed": seed,
@@ -667,5 +675,8 @@ if __name__ == "__main__":

    sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json"))
    sd_kwargs = json.loads(sd_json)
    for arg in vars(cmd_opts):
        if arg in sd_kwargs:
            sd_kwargs[arg] = getattr(cmd_opts, arg)
    for i in shark_sd_fn_dict_input(sd_kwargs):
        print(i)
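
The cmd_opts parsing added to the sd api entry point follows a simple pattern: load the JSON defaults, then let any command-line option with a matching key override them. A minimal standalone sketch of the same pattern; the flags and values here are illustrative, not the real SHARK option set:

import argparse
import json

p = argparse.ArgumentParser()
p.add_argument("--steps", type=int, default=50)
p.add_argument("--scheduler", type=str, default="DDIM")
cmd_opts = p.parse_args(["--steps", "20"])

sd_kwargs = json.loads('{"steps": 50, "scheduler": "EulerDiscrete", "seed": -1}')
for arg in vars(cmd_opts):
    if arg in sd_kwargs:
        sd_kwargs[arg] = getattr(cmd_opts, arg)

print(sd_kwargs)  # {'steps': 20, 'scheduler': 'DDIM', 'seed': -1}

Note that argparse defaults also override the JSON values, so any option that overlaps with the config effectively takes its default from cmd_opts rather than from the config file.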

@@ -102,7 +102,7 @@ class SharkPipelineBase:
                write_to=os.path.join(self.pipe_vmfb_path, submodel + ".vmfb"),
            )
        return


    def get_io_params(self, submodel):
        if "external_weight_file" in self.static_kwargs[submodel]:
            # we are using custom weights
@@ -114,7 +114,7 @@ class SharkPipelineBase:
            # assume the torch IR contains the weights.
            weights_path = None
        return weights_path


    def get_precompiled(self, pipe_id, submodel="None"):
        if submodel == "None":
            for model in self.model_map:
@@ -125,7 +125,9 @@ class SharkPipelineBase:
                    break
            for file in vmfbs:
                if submodel in file:
                    self.pipe_map[submodel]["vmfb_path"] = os.path.join(self.pipe_vmfb_path, file)
                    self.pipe_map[submodel]["vmfb_path"] = os.path.join(
                        self.pipe_vmfb_path, file
                    )
            return

    def import_torch_ir(self, submodel, kwargs):
@@ -153,9 +155,9 @@ class SharkPipelineBase:
                continue
            if "vmfb_path" in self.pipe_map[submodel]:
                weights_path = self.get_io_params(submodel)
                print(
                    f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}"
                )
                # print(
                #     f"\n[LOG] Loading .vmfb for {submodel} from {self.pipe_map[submodel]['vmfb_path']}"
                # )
                self.iree_module_dict[submodel] = {}
                (
                    self.iree_module_dict[submodel]["vmfb"],

apps/shark_studio/modules/seed.py (new file, 66 lines)
@@ -0,0 +1,66 @@
import numpy as np
import json
from random import (
    randint,
    seed as seed_random,
    getstate as random_getstate,
    setstate as random_setstate,
)


# Generate and return a new seed if the provided one is not in the
# supported range (including -1)
def sanitize_seed(seed: int | str):
    seed = int(seed)
    uint32_info = np.iinfo(np.uint32)
    uint32_min, uint32_max = uint32_info.min, uint32_info.max
    if seed < uint32_min or seed >= uint32_max:
        seed = randint(uint32_min, uint32_max)
    return seed


# take a seed expression in an input format and convert it to
# a list of integers, where possible
def parse_seed_input(seed_input: str | list | int):
    if isinstance(seed_input, str):
        try:
            seed_input = json.loads(seed_input)
        except (ValueError, TypeError):
            seed_input = None

    if isinstance(seed_input, int):
        return [seed_input]

    if isinstance(seed_input, list) and all(type(seed) is int for seed in seed_input):
        return seed_input

    raise TypeError(
        "Seed input must be an integer or an array of integers in JSON format"
    )


# Generate a set of seeds from an input expression for batch_count batches,
# optionally using that input as the rng seed for any randomly generated seeds.
def batch_seeds(seed_input: str | list | int, batch_count: int, repeatable=False):
    # turn the input into a list if possible
    seeds = parse_seed_input(seed_input)

    # slice or pad the list to be of batch_count length
    seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))

    if repeatable:
        if all(seed < 0 for seed in seeds):
            seeds[0] = sanitize_seed(seeds[0])

        # set seed for the rng based on what we have so far
        saved_random_state = random_getstate()
        seed_random(str([n for n in seeds if n > -1]))

    # generate any seeds that are unspecified
    seeds = [sanitize_seed(seed) for seed in seeds]

    if repeatable:
        # reset the rng back to normal
        random_setstate(saved_random_state)

    return seeds
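
A quick illustration of how the new seed helpers behave; this snippet is not part of the commit, and any value drawn from the random range will differ between runs:

from apps.shark_studio.modules.seed import batch_seeds, parse_seed_input, sanitize_seed

parse_seed_input("[5, 10]")             # [5, 10]
parse_seed_input(42)                    # [42]
sanitize_seed(-1)                       # a random seed in the uint32 range
batch_seeds("[7]", 3)                   # [7, <random>, <random>], different on every call
batch_seeds("[7]", 3, repeatable=True)  # [7, x, y] with the same x and y on every call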

@@ -32,7 +32,7 @@ p.add_argument(
)
p.add_argument(
    "-p",
    "--prompts",
    "--prompt",
    nargs="+",
    default=[
        "a photo taken of the front of a super-car drifting on a road near "
@@ -44,7 +44,7 @@ p.add_argument(
)

p.add_argument(
    "--negative_prompts",
    "--negative_prompt",
    nargs="+",
    default=[
        "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), "
@@ -54,7 +54,7 @@ p.add_argument(
)

p.add_argument(
    "--img_path",
    "--sd_init_image",
    type=str,
    help="Path to the image input for img2img/inpainting.",
)
@@ -320,7 +320,7 @@ p.add_argument(
p.add_argument(
    "--scheduler",
    type=str,
    default="SharkEulerDiscrete",
    default="DDIM",
    help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, "
    "DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, "
    "DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, "
@@ -359,10 +359,10 @@ p.add_argument(
)

p.add_argument(
    "--ckpt_loc",
    "--custom_weights",
    type=str,
    default="",
    help="Path to SD's .ckpt file.",
    help="Path to a .safetensors or .ckpt file for SD pipeline weights.",
)

p.add_argument(
@@ -374,7 +374,7 @@ p.add_argument(
)

p.add_argument(
    "--hf_model_id",
    "--base_model_id",
    type=str,
    default="stabilityai/stable-diffusion-2-1-base",
    help="The repo-id of hugging face.",

@@ -1,24 +1 @@
{
    "prompt": [ "a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))" ],
    "negative_prompt": [ "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped" ],
    "sd_init_image": [ null ],
    "height": 512,
    "width": 512,
    "steps": 50,
    "strength": 0.8,
    "guidance_scale": 7.5,
    "seed": -1,
    "batch_count": 1,
    "batch_size": 1,
    "scheduler": "EulerDiscrete",
    "base_model_id": "runwayml/stable-diffusion-v1-5",
    "custom_weights": "",
    "custom_vae": "",
    "precision": "fp16",
    "device": "vulkan",
    "ondemand": false,
    "repeatable_seeds": false,
    "resample_type": "Nearest Neighbor",
    "controlnets": {},
    "embeddings": {}
}
{"prompt": ["a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"], "negative_prompt": ["watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"], "sd_init_image": [null], "height": 512, "width": 512, "steps": 50, "strength": 0.8, "guidance_scale": 7.5, "seed": "-1", "batch_count": 1, "batch_size": 1, "scheduler": "EulerDiscrete", "base_model_id": "stabilityai/stable-diffusion-2-1-base", "custom_weights": "None", "custom_vae": "None", "precision": "fp16", "device": "AMD Radeon RX 7900 XTX => vulkan://0", "ondemand": false, "repeatable_seeds": false, "resample_type": "Nearest Neighbor", "controlnets": {}, "embeddings": {}}

@@ -300,13 +300,13 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
        with gr.Group(elem_id="prompt_box_outer"):
            prompt = gr.Textbox(
                label="Prompt",
                value=cmd_opts.prompts[0],
                value=cmd_opts.prompt[0],
                lines=2,
                elem_id="prompt_box",
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value=cmd_opts.negative_prompts[0],
                value=cmd_opts.negative_prompt[0],
                lines=2,
                elem_id="negative_prompt_box",
            )
@@ -54,11 +54,11 @@ def set_prep_kwargs(value):
    _prep_kwargs = value


def set_gen_kwargs(value):
    global _gen_kwargs
    _gen_kwargs = value


def set_schedulers(value):
    global _schedulers
    _schedulers = value

@@ -3,7 +3,6 @@ import gc


def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1):
    print(f"Getting status label for {tab_name}")
    if batch_index < batch_count:
        bs = f"x{batch_size}" if batch_size > 1 else ""
        return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}"