Merge branch 'main' into ean-gen-sharktank

This commit is contained in:
Ean Garvey
2023-03-22 11:26:03 -05:00
committed by GitHub
15 changed files with 88 additions and 78 deletions

View File

@@ -2,6 +2,7 @@ import sys
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
@@ -15,8 +16,6 @@ from apps.stable_diffusion.src import (
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -85,8 +84,6 @@ def img2img_inf(
SD_STATE_CANCEL,
)
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
@@ -175,8 +172,8 @@ def img2img_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
if use_stencil is not None:
args.use_tuned = False
@@ -221,7 +218,7 @@ def img2img_inf(
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""

View File

@@ -1,6 +1,7 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
@@ -13,8 +14,6 @@ from apps.stable_diffusion.src import (
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -56,8 +55,6 @@ def inpaint_inf(
SD_STATE_CANCEL,
)
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
@@ -125,14 +122,15 @@ def inpaint_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
InpaintPipeline.from_pretrained(
scheduler=scheduler_obj,
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,
@@ -146,7 +144,7 @@ def inpaint_inf(
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
@@ -223,6 +221,7 @@ if __name__ == "__main__":
import_mlir=args.import_mlir,
model_id=args.hf_model_id,
ckpt_loc=args.ckpt_loc,
custom_vae=args.custom_vae,
precision=args.precision,
max_length=args.max_length,
batch_size=args.batch_size,

View File

@@ -1,6 +1,7 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
OutpaintPipeline,
@@ -13,8 +14,6 @@ from apps.stable_diffusion.src import (
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -59,8 +58,6 @@ def outpaint_inf(
SD_STATE_CANCEL,
)
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
@@ -127,8 +124,8 @@ def outpaint_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-inpainting"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
OutpaintPipeline.from_pretrained(
scheduler_obj,
@@ -147,7 +144,7 @@ def outpaint_inf(
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""

View File

@@ -159,9 +159,6 @@ class LoraDataset(Dataset):
return example
schedulers = None
########## Setting up the model ##########
def lora_train(
prompt: str,
@@ -187,8 +184,6 @@ def lora_train(
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
print(
"Note LoRA training is not compatible with the latest torch-mlir branch"
)
@@ -227,7 +222,7 @@ def lora_train(
args.max_length = max_length
args.height = height
args.width = width
args.device = device
args.device = device.split("=>", 1)[1].strip()
# Load the Stable Diffusion model
text_encoder = CLIPTextModel.from_pretrained(

View File

@@ -1,4 +1,5 @@
import torch
import transformers
import time
from apps.stable_diffusion.src import (
args,
@@ -11,7 +12,6 @@ from apps.stable_diffusion.src import (
)
from apps.stable_diffusion.src.utils import get_generation_text_info
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
@@ -51,8 +51,6 @@ def txt2img_inf(
SD_STATE_CANCEL,
)
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
@@ -119,8 +117,8 @@ def txt2img_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
Text2ImagePipeline.from_pretrained(
scheduler=scheduler_obj,
@@ -141,7 +139,7 @@ def txt2img_inf(
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.set_sd_scheduler(scheduler)
start_time = time.time()
global_obj.get_sd_obj().log = ""
@@ -208,6 +206,7 @@ if __name__ == "__main__":
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
use_quantize=args.use_quantize,
)
for current_batch in range(args.batch_count):

View File

@@ -1,6 +1,7 @@
import torch
import time
from PIL import Image
import transformers
from apps.stable_diffusion.src import (
args,
UpscalerPipeline,
@@ -12,8 +13,6 @@ from apps.stable_diffusion.src import (
)
schedulers = None
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_use_tuned = args.use_tuned
@@ -51,8 +50,6 @@ def upscaler_inf(
)
import apps.stable_diffusion.web.utils.global_obj as global_obj
global schedulers
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
@@ -121,8 +118,8 @@ def upscaler_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
global_obj.set_sd_obj(
UpscalerPipeline.from_pretrained(
scheduler_obj,
@@ -142,8 +139,10 @@ def upscaler_inf(
)
)
global_obj.set_schedulers(schedulers[scheduler])
global_obj.get_sd_obj().low_res_scheduler = schedulers["DDPM"]
global_obj.set_sd_scheduler(scheduler)
global_obj.get_sd_obj().low_res_scheduler = global_obj.get_scheduler(
"DDPM"
)
start_time = time.time()
global_obj.get_sd_obj().log = ""

View File

@@ -100,7 +100,8 @@ class SharkifyStableDiffusionModel:
is_inpaint: bool = False,
is_upscaler: bool = False,
use_stencil: str = None,
use_lora: str = ""
use_lora: str = "",
use_quantize: str = None,
):
self.check_params(max_len, width, height)
self.max_len = max_len
@@ -108,6 +109,7 @@ class SharkifyStableDiffusionModel:
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
self.use_quantize = use_quantize
if custom_weights != "":
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
@@ -570,7 +572,12 @@ class SharkifyStableDiffusionModel:
compiled_controlnet = self.get_control_net()
compiled_controlled_unet = self.get_controlled_unet()
else:
compiled_unet = self.get_unet()
# TODO: Plug the experimental "int8" support at right place.
if self.use_quantize == "int8":
from apps.stable_diffusion.src.models.opt_params import get_unet
compiled_unet = get_unet()
else:
compiled_unet = self.get_unet()
if self.custom_vae != "":
print("Plugging in custom Vae")
compiled_vae = self.get_vae()

View File

@@ -20,6 +20,15 @@ hf_model_variant_map = {
"stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
}
# TODO: Add the quantized model as a part model_db.json.
# This is currently in experimental phase.
def get_quantize_model():
bucket_key = "gs://shark_tank/prashant_nod"
model_key = "unet_int8"
iree_flags = get_opt_flags("unet", precision="fp16")
if args.height != 512 and args.width != 512 and args.max_length != 77:
sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
return bucket_key, model_key, iree_flags
def get_variant_version(hf_model_id):
return hf_model_variant_map[hf_model_id]
@@ -41,6 +50,12 @@ def get_unet():
variant, version = get_variant_version(args.hf_model_id)
# Tuned model is present only for `fp16` precision.
is_tuned = "tuned" if args.use_tuned else "untuned"
# TODO: Get the quantize model from model_db.json
if args.use_quantize == "int8":
bk, mk, flags = get_quantize_model()
return get_shark_model(bk, mk, flags)
if "vulkan" not in args.device and args.use_tuned:
bucket_key = f"{variant}/{is_tuned}/{args.device}"
model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}"

View File

@@ -321,6 +321,7 @@ class StableDiffusionPipeline:
use_stencil: str = None,
use_lora: str = "",
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
):
is_inpaint = cls.__name__ in [
"InpaintPipeline",
@@ -349,6 +350,7 @@ class StableDiffusionPipeline:
is_upscaler=is_upscaler,
use_stencil=use_stencil,
use_lora=use_lora,
use_quantize=use_quantize,
)
if cls.__name__ in [
"Image2ImagePipeline",

View File

@@ -340,6 +340,14 @@ p.add_argument(
help="Use standalone LoRA weight using a HF ID or a checkpoint file (~3 MB)",
)
p.add_argument(
"--use_quantize",
type=str,
default="none",
help="""Runs the quantized version of stable diffusion model. This is currently in experimental phase.
Currently, only runs the stable-diffusion-2-1-base model in int8 quantization.""",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################

View File

@@ -80,7 +80,7 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
frontend="torch",
)
shark_module = SharkInference(
mlir_model, device=args.device, mlir_dialect="linalg"
mlir_model, device=args.device, mlir_dialect="tm_tensor"
)
return _compile_module(shark_module, model_name, extra_args)
@@ -126,14 +126,14 @@ def compile_through_fx(
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
mlir_dialect="tm_tensor",
)
if generate_vmfb:
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
mlir_dialect="tm_tensor",
)
del mlir_module
gc.collect()

View File

@@ -1,5 +1,6 @@
import os
import sys
import transformers
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"

View File

@@ -11,8 +11,10 @@ Also we could avoid memory leak when switching models by clearing the cache.
def _init():
global _sd_obj
global _config_obj
global _schedulers
_sd_obj = None
_config_obj = None
_schedulers = None
def set_sd_obj(value):
@@ -20,9 +22,9 @@ def set_sd_obj(value):
_sd_obj = value
def set_schedulers(value):
def set_sd_scheduler(key):
global _sd_obj
_sd_obj.scheduler = value
_sd_obj.scheduler = _schedulers[key]
def set_sd_status(value):
@@ -35,6 +37,11 @@ def set_cfg_obj(value):
_config_obj = value
def set_schedulers(value):
global _schedulers
_schedulers = value
def get_sd_obj():
return _sd_obj
@@ -47,6 +54,10 @@ def get_cfg_obj():
return _config_obj
def get_scheduler(key):
return _schedulers[key]
def clear_cache():
global _sd_obj
global _config_obj

View File

@@ -6,36 +6,16 @@ from distutils.sysconfig import get_python_lib
import fileinput
from pathlib import Path
# Diffusers 0.13.1 fails with transformers __init.py errros in BLIP. So remove it for now until we fork it
pix2pix_init = Path(get_python_lib() + "/diffusers/__init__.py")
for line in fileinput.input(pix2pix_init, inplace=True):
if "Pix2Pix" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
pix2pix_init = Path(get_python_lib() + "/diffusers/pipelines/__init__.py")
for line in fileinput.input(pix2pix_init, inplace=True):
if "Pix2Pix" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
pix2pix_init = Path(
get_python_lib() + "/diffusers/pipelines/stable_diffusion/__init__.py"
# Temorary workaround for transformers/__init__.py.
path_to_tranformers_hook = Path(
get_python_lib()
+ "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py"
)
for line in fileinput.input(pix2pix_init, inplace=True):
if "StableDiffusionPix2PixZeroPipeline" in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
else:
print(line, end="")
if path_to_tranformers_hook.is_file():
pass
else:
with open(path_to_tranformers_hook, "w") as f:
f.write("module_collection_mode = 'pyz+py'")
path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")

View File

@@ -52,7 +52,7 @@ def get_iree_device_args(device, extra_args=[]):
# Get the iree-compiler arguments given frontend.
def get_iree_frontend_args(frontend):
if frontend in ["torch", "pytorch", "linalg"]:
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
return ["--iree-llvmcpu-target-cpu-features=host"]
elif frontend in ["tensorflow", "tf", "mhlo"]:
return [