add support for img1img

This commit is contained in:
PhaneeshB
2023-02-08 13:08:01 +05:30
committed by Gaurav Shukla
parent b2f3c96835
commit 3159a6f3e1
6 changed files with 491 additions and 11 deletions

View File

@@ -0,0 +1,320 @@
import os
if "AMD_ENABLE_LLPC" not in os.environ:
os.environ["AMD_ENABLE_LLPC"] = "1"
import sys
import json
import torch
import re
import time
from pathlib import Path
from PIL import Image, PngImagePlugin
from datetime import datetime as dt
from dataclasses import dataclass
from csv import DictWriter
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
get_schedulers,
set_init_device_flags,
)
@dataclass
class Config:
model_id: str
ckpt_loc: str
precision: str
batch_size: int
max_length: int
height: int
width: int
device: str
# This has to come before importing cache objects
if args.clear_all:
print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE")
from glob import glob
import shutil
vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb"))
for vmfb in vmfbs:
if os.path.exists(vmfb):
os.remove(vmfb)
# Temporary workaround of deleting yaml files to incorporate diffusers' pipeline.
# TODO: Remove this once we have better weight updation logic.
inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"]
for yaml in inference_yaml:
if os.path.exists(yaml):
os.remove(yaml)
home = os.path.expanduser("~")
if os.name == "nt": # Windows
appdata = os.getenv("LOCALAPPDATA")
shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True)
shutil.rmtree(os.path.join(home, "shark_tank"), ignore_errors=True)
elif os.name == "unix":
shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache"))
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
# save output images and the inputs correspoding to it.
def save_output_img(output_img):
output_path = args.output_dir if args.output_dir else Path.cwd()
generated_imgs_path = Path(output_path, "generated_imgs")
generated_imgs_path.mkdir(parents=True, exist_ok=True)
csv_path = Path(generated_imgs_path, "imgs_details.csv")
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
out_img_name = (
f"{prompt_slice}_{args.seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
)
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
else:
out_img_path = Path(generated_imgs_path, f"{out_img_name}.png")
pngInfo = PngImagePlugin.PngInfo()
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {args.seed}, Size: {args.width}x{args.height}, Model: {args.hf_model_id}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
if args.output_img_format not in ["png", "jpg"]:
print(
f"[ERROR] Format {args.output_img_format} is not supported yet."
"Image saved as png instead. Supported formats: png / jpg"
)
new_entry = {
"VARIANT": args.hf_model_id,
"SCHEDULER": args.scheduler,
"PROMPT": args.prompts[0],
"NEG_PROMPT": args.negative_prompts[0],
"IMG_INPUT": args.img_path,
"SEED": args.seed,
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
"HEIGHT": args.height,
"WIDTH": args.width,
"MAX_LENGTH": args.max_length,
"OUTPUT": out_img_path,
}
with open(csv_path, "a") as csv_obj:
dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys()))
dictwriter_obj.writerow(new_entry)
csv_obj.close()
if args.save_metadata_to_json:
del new_entry["OUTPUT"]
json_path = Path(generated_imgs_path, f"{out_img_name}.json")
with open(json_path, "w") as f:
json.dump(new_entry, f, indent=4)
img2img_obj = None
config_obj = None
schedulers = None
# Exposed to UI.
def image2image_inf(
prompt: str,
negative_prompt: str,
image: Image,
height: int,
width: int,
steps: int,
guidance_scale: float,
seed: int,
batch_size: int,
scheduler: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
save_metadata_to_json: bool,
save_metadata_to_png: bool,
):
global img2img_obj
global config_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.scheduler = scheduler
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = custom_model
else:
args.hf_model_id = custom_model
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
args.hf_model_id,
args.ckpt_loc,
precision,
batch_size,
max_length,
height,
width,
device,
)
if config_obj != new_config_obj:
config_obj = new_config_obj
args.precision = precision
args.batch_size = batch_size
args.max_length = max_length
args.height = height
args.width = width
args.device = device.split("=>", 1)[1].strip()
args.use_tuned = True
args.import_mlir = False
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "runwayml/stable-diffusion-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
img2img_obj = Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
)
if not img2img_obj:
sys.exit("text to image pipeline must not return a null value")
img2img_obj.scheduler = schedulers[scheduler]
start_time = time.time()
img2img_obj.log = ""
generated_imgs = img2img_obj.generate_images(
prompt,
negative_prompt,
image,
batch_size,
height,
width,
steps,
guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
save_output_img(generated_imgs[0])
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += img2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
return generated_imgs, text_output
if __name__ == "__main__":
if args.img_path is None:
print("Flag --img_path is required.")
exit()
dtype = torch.float32 if args.precision == "fp32" else torch.half
cpu_scheduling = not args.scheduler.startswith("Shark")
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
image = Image.open(args.img_path)
# Adjust for height and width based on model
img2img_obj = Image2ImagePipeline.from_pretrained(
scheduler_obj,
args.import_mlir,
args.hf_model_id,
args.ckpt_loc,
args.precision,
args.max_length,
args.batch_size,
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
)
start_time = time.time()
generated_imgs = img2img_obj.generate_images(
args.prompts,
args.negative_prompts,
image,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
args.seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += img2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0])
print(text_output)

View File

@@ -7,5 +7,6 @@ from apps.stable_diffusion.src.utils import (
from apps.stable_diffusion.src.pipelines import (
Text2ImagePipeline,
InpaintPipeline,
Image2ImagePipeline,
)
from apps.stable_diffusion.src.schedulers import get_schedulers

View File

@@ -249,7 +249,7 @@ class SharkifyStableDiffusionModel:
# Compiles Clip, Unet and Vae with `base_model_id` as defining their input
# configiration.
def compile_all(self, base_model_id, if_inpaint):
def compile_all(self, base_model_id, need_vae_encode):
self.inputs = get_input_info(
base_models[base_model_id],
self.max_len,
@@ -260,7 +260,7 @@ class SharkifyStableDiffusionModel:
compiled_unet = self.get_unet()
compiled_vae = self.get_vae()
compiled_clip = self.get_clip()
if if_inpaint:
if need_vae_encode:
compiled_vae_encode = self.get_vae_encode()
return compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode
@@ -269,9 +269,9 @@ class SharkifyStableDiffusionModel:
def __call__(self):
# Step 1:
# -- Fetch all vmfbs for the model, if present, else delete the lot.
if_inpaint = "inpaint" in self.model_id
need_vae_encode = args.img_path is not None
vmfbs = fetch_or_delete_vmfbs(
self.model_name, self.base_vae, if_inpaint, self.precision
self.model_name, self.base_vae, need_vae_encode, self.precision
)
if vmfbs[0]:
# -- If all vmfbs are indeed present, we also try and fetch the base
@@ -300,7 +300,7 @@ class SharkifyStableDiffusionModel:
print("Compiling all the models with the fetched base model configuration.")
if args.ckpt_loc != "":
args.hf_model_id = base_model_fetched
return self.compile_all(base_model_fetched)
return self.compile_all(base_model_fetched, need_vae_encode)
# Step 3:
# -- This is the retry mechanism where the base model's configuration is not
@@ -308,10 +308,10 @@ class SharkifyStableDiffusionModel:
print("Inferring base model configuration.")
for model_id in base_models:
try:
if if_inpaint:
compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode = self.compile_all(model_id, if_inpaint)
if need_vae_encode:
compiled_clip, compiled_unet, compiled_vae, compiled_vae_encode = self.compile_all(model_id, need_vae_encode)
else:
compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id, if_inpaint)
compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id, need_vae_encode)
except Exception as e:
if args.enable_stack_trace:
traceback.print_exc()
@@ -326,7 +326,7 @@ class SharkifyStableDiffusionModel:
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
if if_inpaint:
if need_vae_encode:
return (
compiled_clip,
compiled_unet,

View File

@@ -4,3 +4,6 @@ from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_inpaint import (
InpaintPipeline,
)
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
Image2ImagePipeline,
)

View File

@@ -0,0 +1,156 @@
import torch
import time
import numpy as np
from tqdm.auto import tqdm
from random import randint
from PIL import Image
from transformers import CLIPTokenizer
from typing import Union
from shark.shark_inference import SharkInference
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
StableDiffusionPipeline,
)
class Image2ImagePipeline(StableDiffusionPipeline):
def __init__(
self,
vae_encode: SharkInference,
vae: SharkInference,
text_encoder: SharkInference,
tokenizer: CLIPTokenizer,
unet: SharkInference,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
SharkEulerDiscreteScheduler,
],
):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
self.vae_encode = vae_encode
def prepare_image_latents(
self,
image,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
# Pre process image -> get image encoded -> process latents
# TODO: process with variable HxW combos
# Pre process image
image = image.resize((height, width)) # Current support for 512x512
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
image_arr = image_arr / 255.0
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
image_arr = 2 * (image_arr - 0.5)
# image encode
latents = self.encode_image((image_arr,))
latents = torch.from_numpy(latents).to(dtype)
# set scheduler steps
self.scheduler.set_timesteps(num_inference_steps)
# add noise to data
latents = latents * self.scheduler.init_noise_sigma
return latents
def encode_image(self, input_image):
vae_encode_start = time.time()
latents = self.vae_encode("forward", input_image)
vae_inf_time = (time.time() - vae_encode_start) * 1000
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
return latents
def generate_images(
self,
prompts,
neg_prompts,
image,
batch_size,
height,
width,
num_inference_steps,
guidance_scale,
seed,
max_length,
dtype,
use_base_vae,
cpu_scheduling,
):
# prompts and negative prompts must be a list.
if isinstance(prompts, str):
prompts = [prompts]
if isinstance(neg_prompts, str):
neg_prompts = [neg_prompts]
prompts = prompts * batch_size
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
# Get text embeddings from prompts
text_embeddings = self.encode_prompts(prompts, neg_prompts, max_length)
# guidance scale as a float32 tensor.
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
# Prepare input image latent
image_latents = self.prepare_image_latents(
image=image,
batch_size=batch_size,
height=height,
width=width,
generator=generator,
num_inference_steps=num_inference_steps,
dtype=dtype,
)
# Get Image latents
latents = self.produce_img_latents(
latents=image_latents,
text_embeddings=text_embeddings,
guidance_scale=guidance_scale,
total_timesteps=self.scheduler.timesteps,
dtype=dtype,
cpu_scheduling=cpu_scheduling,
)
# Img latents -> PIL images
all_imgs = []
for i in tqdm(range(0, latents.shape[0], batch_size)):
imgs = self.decode_latents(
latents=latents[i : i + batch_size],
use_base_vae=use_base_vae,
cpu_scheduling=cpu_scheduling,
)
all_imgs.extend(imgs)
return all_imgs

View File

@@ -212,7 +212,7 @@ class StableDiffusionPipeline:
use_base_vae=use_base_vae,
use_tuned=use_tuned,
)
if "inpaint" in model_id:
if cls.__name__ in ["Image2ImagePipeline", "InpaintPipeline"]:
clip, unet, vae, vae_encode = mlir_import()
return cls(
vae_encode, vae, clip, get_tokenizer(), unet, scheduler
@@ -220,7 +220,7 @@ class StableDiffusionPipeline:
clip, unet, vae = mlir_import()
return cls(vae, clip, get_tokenizer(), unet, scheduler)
if "inpaint" in model_id:
if cls.__name__ in ["Image2ImagePipeline", "InpaintPipeline"]:
return cls(
get_vae_encode(),
get_vae(),