[Vicuna] Revert the formatting for Brevitas op (#1626)

-- This commit reverts the formatting for the Brevitas op.
-- It also excludes the vicuna.py script from the `black` formatter.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Abhishek Varma
2023-07-06 19:26:17 +05:30
committed by GitHub
parent c5a47887f4
commit 1b62dc4529
6 changed files with 223 additions and 78 deletions
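
Taken together, the tooling changes below carve vicuna.py out of both the lint and the format passes. A consolidated sketch of the resulting settings (the flake8 file name is elided on this page, so treating it as a .flake8-style INI file is an assumption; the [tool.black] section is quoted verbatim from the last diff):

    # flake8 configuration (assumed file name: .flake8)
    exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py

    # pyproject.toml
    [tool.black]
    line-length = 79
    include = '\.pyi?$'
    exclude = "apps/language_models/scripts/vicuna.py"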


@@ -2,4 +2,4 @@
 count = 1
 show-source = 1
 select = E9,F63,F7,F82
-exclude = lit.cfg.py
+exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py

.gitignore

@@ -189,6 +189,3 @@ apps/stable_diffusion/web/models/
 # Stencil annotators.
 stencil_annotator/
-# brevitas custom op lib
-apps/language_models/scripts/vicuna.py


@@ -103,14 +103,7 @@ parser.add_argument(
 )

-def brevitasmatmul_rhs_group_quant〡shape(
-    lhs: List[int],
-    rhs: List[int],
-    rhs_scale: List[int],
-    rhs_zero_point: List[int],
-    rhs_bit_width: int,
-    rhs_group_size: int,
-) -> List[int]:
+def brevitasmatmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
     if len(lhs) == 3 and len(rhs) == 2:
         return [lhs[0], lhs[1], rhs[0]]
     elif len(lhs) == 2 and len(rhs) == 2:
@@ -119,30 +112,20 @@ def brevitasmatmul_rhs_group_quant〡shape(
         raise ValueError("Input shapes not supported.")

-def brevitasmatmul_rhs_group_quant〡dtype(
-    lhs_rank_dtype: Tuple[int, int],
-    rhs_rank_dtype: Tuple[int, int],
-    rhs_scale_rank_dtype: Tuple[int, int],
-    rhs_zero_point_rank_dtype: Tuple[int, int],
-    rhs_bit_width: int,
-    rhs_group_size: int,
-) -> int:
+def brevitasmatmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
     # output dtype is the dtype of the lhs float input
     lhs_rank, lhs_dtype = lhs_rank_dtype
     return lhs_dtype

-def brevitasmatmul_rhs_group_quant〡has_value_semantics(
-    lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
-) -> None:
+def brevitasmatmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
     return

 brevitas_matmul_rhs_group_quant_library = [
     brevitasmatmul_rhs_group_quant〡shape,
     brevitasmatmul_rhs_group_quant〡dtype,
-    brevitasmatmul_rhs_group_quant〡has_value_semantics,
-]
+    brevitasmatmul_rhs_group_quant〡has_value_semantics]

 class ShardedVicuna(SharkLLMBase):
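
For intuition, the shape function restored above computes matmul result shapes for a right-hand side stored as [N, K]. A standalone sketch in plain Python (the name is shortened and the quantization parameters are dropped; the 2-D return value is cut off by the hunk above, so [lhs[0], rhs[0]] is an inferred completion):

    from typing import List

    def matmul_rhs_group_quant_shape(lhs: List[int], rhs: List[int]) -> List[int]:
        # Batched case: [B, M, K] @ rhs stored as [N, K] -> [B, M, N]
        if len(lhs) == 3 and len(rhs) == 2:
            return [lhs[0], lhs[1], rhs[0]]
        # Unbatched case: [M, K] @ [N, K] -> [M, N] (inferred completion)
        elif len(lhs) == 2 and len(rhs) == 2:
            return [lhs[0], rhs[0]]
        raise ValueError("Input shapes not supported.")

    print(matmul_rhs_group_quant_shape([4, 32, 128], [64, 128]))  # [4, 32, 64]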


@@ -45,6 +45,7 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
             new_shape.append(width * mul_val)
         elif "/" in shape[i]:
             import math
+
             div_val = int(shape[i].split("/")[1])
             if "batch_size" in shape[i]:
                 new_shape.append(math.ceil(batch_size / div_val))
@@ -59,7 +60,9 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
 def check_compilation(model, model_name):
     if not model:
-        raise Exception(f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues")
+        raise Exception(
+            f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues"
+        )

 class SharkifyStableDiffusionModel:
@@ -97,16 +100,22 @@ class SharkifyStableDiffusionModel:
         if "civitai" in custom_weights:
             weights_id = custom_weights.split("/")[-1]
             # TODO: use model name and identify file type by civitai rest api
-            weights_path = str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
+            weights_path = (
+                str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
+            )
             if not os.path.isfile(weights_path):
-                subprocess.run(["wget", custom_weights, "-O", weights_path])
+                subprocess.run(
+                    ["wget", custom_weights, "-O", weights_path]
+                )
             custom_weights = get_path_to_diffusers_checkpoint(weights_path)
             self.custom_weights = weights_path
         else:
             assert custom_weights.lower().endswith(
                 (".ckpt", ".safetensors")
             ), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
-            custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
+            custom_weights = get_path_to_diffusers_checkpoint(
+                custom_weights
+            )
         self.model_id = model_id if custom_weights == "" else custom_weights
         # TODO: remove the following line when stable-diffusion-2-1 works
         if self.model_id == "stabilityai/stable-diffusion-2-1":
@@ -126,7 +135,7 @@
             + "_"
             + precision
         )
-        print(f'use_tuned? sharkify: {use_tuned}')
+        print(f"use_tuned? sharkify: {use_tuned}")
         self.use_tuned = use_tuned
         if use_tuned:
             self.model_name = self.model_name + "_tuned"
@@ -163,14 +172,24 @@
     def get_extended_name_for_all_model(self):
         model_name = {}
-        sub_model_list = ["clip", "unet", "unet512", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
+        sub_model_list = [
+            "clip",
+            "unet",
+            "unet512",
+            "stencil_unet",
+            "vae",
+            "vae_encode",
+            "stencil_adaptor",
+        ]
         index = 0
         for model in sub_model_list:
             sub_model = model
             model_config = self.model_name
             if "vae" == model:
                 if self.custom_vae != "":
-                    model_config = model_config + get_path_stem(self.custom_vae)
+                    model_config = model_config + get_path_stem(
+                        self.custom_vae
+                    )
                 if self.base_vae:
                     sub_model = "base_vae"
             if "stencil_adaptor" == model and self.use_stencil is not None:
@@ -197,7 +216,11 @@
         tensor = None
         if isinstance(shape, list):
             clean_shape = replace_shape_str(
-                shape, self.max_len, self.width, self.height, self.batch_size
+                shape,
+                self.max_len,
+                self.width,
+                self.height,
+                self.batch_size,
             )
             if dtype == torch.int64:
                 tensor = torch.randint(1, 3, tuple(clean_shape))
@@ -209,10 +232,12 @@
                 sys.exit("shape isn't specified correctly.")
             input_map.append(tensor)
         return input_map

     def get_vae_encode(self):
         class VaeEncodeModel(torch.nn.Module):
-            def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
+            def __init__(
+                self, model_id=self.model_id, low_cpu_mem_usage=False
+            ):
                 super().__init__()
                 self.vae = AutoencoderKL.from_pretrained(
                     model_id,
@@ -226,7 +251,11 @@
         vae_encode = VaeEncodeModel()
         inputs = tuple(self.inputs["vae_encode"])
-        is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
+        is_f16 = (
+            True
+            if not self.is_upscaler and self.precision == "fp16"
+            else False
+        )
         shark_vae_encode, vae_encode_mlir = compile_through_fx(
             vae_encode,
             inputs,
@@ -243,7 +272,13 @@
     def get_vae(self):
         class VaeModel(torch.nn.Module):
-            def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae, low_cpu_mem_usage=False):
+            def __init__(
+                self,
+                model_id=self.model_id,
+                base_vae=self.base_vae,
+                custom_vae=self.custom_vae,
+                low_cpu_mem_usage=False,
+            ):
                 super().__init__()
                 self.vae = None
                 if custom_vae == "":
@@ -279,7 +314,11 @@
         vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
         inputs = tuple(self.inputs["vae"])
-        is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
+        is_f16 = (
+            True
+            if not self.is_upscaler and self.precision == "fp16"
+            else False
+        )
         save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
         if self.debug:
             os.makedirs(save_dir, exist_ok=True)
@@ -303,7 +342,10 @@
     def get_controlled_unet(self):
         class ControlledUnetModel(torch.nn.Module):
             def __init__(
-                self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
+                self,
+                model_id=self.model_id,
+                low_cpu_mem_usage=False,
+                use_lora=self.use_lora,
             ):
                 super().__init__()
                 self.unet = UNet2DConditionModel.from_pretrained(
@@ -316,12 +358,43 @@
                 self.in_channels = self.unet.in_channels
                 self.train(False)

-            def forward( self, latent, timestep, text_embedding, guidance_scale, control1,
-                control2, control3, control4, control5, control6, control7,
-                control8, control9, control10, control11, control12, control13,
+            def forward(
+                self,
+                latent,
+                timestep,
+                text_embedding,
+                guidance_scale,
+                control1,
+                control2,
+                control3,
+                control4,
+                control5,
+                control6,
+                control7,
+                control8,
+                control9,
+                control10,
+                control11,
+                control12,
+                control13,
             ):
                 # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
-                db_res_samples = tuple([ control1, control2, control3, control4, control5, control6, control7, control8, control9, control10, control11, control12,])
+                db_res_samples = tuple(
+                    [
+                        control1,
+                        control2,
+                        control3,
+                        control4,
+                        control5,
+                        control6,
+                        control7,
+                        control8,
+                        control9,
+                        control10,
+                        control11,
+                        control12,
+                    ]
+                )
                 mb_res_samples = control13
                 latents = torch.cat([latent] * 2)
                 unet_out = self.unet.forward(
@@ -342,7 +415,25 @@
         is_f16 = True if self.precision == "fp16" else False
         inputs = tuple(self.inputs["unet"])
-        input_mask = [True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True,]
+        input_mask = [
+            True,
+            True,
+            True,
+            False,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+        ]
         shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
             unet,
             inputs,
@@ -386,16 +477,23 @@
                 stencil_image = torch.cat(
                     [stencil_image_input] * 2
                 )  # needs to be same as controlledUNET latents
-                down_block_res_samples, mid_block_res_sample = self.cnet.forward(
+                (
+                    down_block_res_samples,
+                    mid_block_res_sample,
+                ) = self.cnet.forward(
                     latents,
                     timestep,
                     encoder_hidden_states=text_embedding,
                     controlnet_cond=stencil_image,
                     return_dict=False,
                 )
-                return tuple(list(down_block_res_samples) + [mid_block_res_sample])
+                return tuple(
+                    list(down_block_res_samples) + [mid_block_res_sample]
+                )

-        scnet = StencilControlNetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
+        scnet = StencilControlNetModel(
+            low_cpu_mem_usage=self.low_cpu_mem_usage
+        )
         is_f16 = True if self.precision == "fp16" else False
         inputs = tuple(self.inputs["stencil_adaptor"])
@@ -417,7 +515,12 @@
     def get_unet(self, use_large=False):
         class UnetModel(torch.nn.Module):
-            def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
+            def __init__(
+                self,
+                model_id=self.model_id,
+                low_cpu_mem_usage=False,
+                use_lora=self.use_lora,
+            ):
                 super().__init__()
                 self.unet = UNet2DConditionModel.from_pretrained(
                     model_id,
@@ -428,15 +531,24 @@
                     update_lora_weight(self.unet, use_lora, "unet")
                 self.in_channels = self.unet.config.in_channels
                 self.train(False)
-                if(args.attention_slicing is not None and args.attention_slicing != "none"):
-                    if(args.attention_slicing.isdigit()):
-                        self.unet.set_attention_slice(int(args.attention_slicing))
+                if (
+                    args.attention_slicing is not None
+                    and args.attention_slicing != "none"
+                ):
+                    if args.attention_slicing.isdigit():
+                        self.unet.set_attention_slice(
+                            int(args.attention_slicing)
+                        )
                     else:
                         self.unet.set_attention_slice(args.attention_slicing)

             # TODO: Instead of flattening the `control` try to use the list.
             def forward(
-                self, latent, timestep, text_embedding, guidance_scale,
+                self,
+                latent,
+                timestep,
+                text_embedding,
+                guidance_scale,
             ):
                 # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                 latents = torch.cat([latent] * 2)
@@ -452,16 +564,22 @@
         unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
         is_f16 = True if self.precision == "fp16" else False
         inputs = tuple(self.inputs["unet"])
-        if(use_large):
+        if use_large:
             pad = (0, 0) * (len(inputs[2].shape) - 2)
             pad = pad + (0, 512 - inputs[2].shape[1])
-            inputs = (inputs[0],
+            inputs = (
+                inputs[0],
                 inputs[1],
                 torch.nn.functional.pad(inputs[2], pad),
-                inputs[3])
-            save_dir = os.path.join(self.sharktank_dir, self.model_name["unet512"])
+                inputs[3],
+            )
+            save_dir = os.path.join(
+                self.sharktank_dir, self.model_name["unet512"]
+            )
         else:
-            save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
+            save_dir = os.path.join(
+                self.sharktank_dir, self.model_name["unet"]
+            )
         input_mask = [True, True, True, False]
         if self.debug:
             os.makedirs(
@@ -489,7 +607,9 @@
     def get_unet_upscaler(self, use_large=False):
         class UnetModel(torch.nn.Module):
-            def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
+            def __init__(
+                self, model_id=self.model_id, low_cpu_mem_usage=False
+            ):
                 super().__init__()
                 self.unet = UNet2DConditionModel.from_pretrained(
                     model_id,
@@ -512,13 +632,15 @@
         unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
         is_f16 = True if self.precision == "fp16" else False
         inputs = tuple(self.inputs["unet"])
-        if(use_large):
+        if use_large:
             pad = (0, 0) * (len(inputs[2].shape) - 2)
             pad = pad + (0, 512 - inputs[2].shape[1])
-            inputs = (inputs[0],
+            inputs = (
+                inputs[0],
                 inputs[1],
                 torch.nn.functional.pad(inputs[2], pad),
-                inputs[3])
+                inputs[3],
+            )
         input_mask = [True, True, True, False]
         model_name = "unet512" if use_large else "unet"
         shark_unet, unet_mlir = compile_through_fx(
@@ -538,7 +660,12 @@
     def get_clip(self):
         class CLIPText(torch.nn.Module):
-            def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
+            def __init__(
+                self,
+                model_id=self.model_id,
+                low_cpu_mem_usage=False,
+                use_lora=self.use_lora,
+            ):
                 super().__init__()
                 self.text_encoder = CLIPTextModel.from_pretrained(
                     model_id,
@@ -546,7 +673,9 @@
                     low_cpu_mem_usage=low_cpu_mem_usage,
                 )
                 if use_lora != "":
-                    update_lora_weight(self.text_encoder, use_lora, "text_encoder")
+                    update_lora_weight(
+                        self.text_encoder, use_lora, "text_encoder"
+                    )

             def forward(self, input):
                 return self.text_encoder(input)[0]
@@ -585,16 +714,24 @@
         vae_checkpoint = None
         vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
         if custom_vae.endswith(".ckpt"):
-            vae_checkpoint = torch.load(self.custom_vae, map_location="cpu")
+            vae_checkpoint = torch.load(
+                self.custom_vae, map_location="cpu"
+            )
         else:
-            vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
+            vae_checkpoint = safetensors.torch.load_file(
+                self.custom_vae, device="cpu"
+            )
         if "state_dict" in vae_checkpoint:
             vae_checkpoint = vae_checkpoint["state_dict"]
         try:
             vae_checkpoint = convert_original_vae(vae_checkpoint)
         finally:
-            vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+            vae_dict = {
+                k: v
+                for k, v in vae_checkpoint.items()
+                if k[0:4] != "loss" and k not in vae_ignore_keys
+            }
         return vae_dict

     def compile_unet_variants(self, model, use_large=False):
@@ -603,7 +740,10 @@
             return self.get_unet_upscaler(use_large=use_large)
         # TODO: Plug the experimental "int8" support at right place.
         elif self.use_quantize == "int8":
-            from apps.stable_diffusion.src.models.opt_params import get_unet
+            from apps.stable_diffusion.src.models.opt_params import (
+                get_unet,
+            )

             return get_unet()
         else:
             return self.get_unet(use_large=use_large)
@@ -612,7 +752,9 @@
     def vae_encode(self):
         try:
-            self.inputs["vae_encode"] = self.get_input_info_for(base_models["vae_encode"])
+            self.inputs["vae_encode"] = self.get_input_info_for(
+                base_models["vae_encode"]
+            )
             compiled_vae_encode, vae_encode_mlir = self.get_vae_encode()
             check_compilation(compiled_vae_encode, "Vae Encode")
@@ -641,18 +783,28 @@
         unet_inputs = base_models[model]
         if self.base_model_id != "":
-            self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
-            compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
+            self.inputs["unet"] = self.get_input_info_for(
+                unet_inputs[self.base_model_id]
+            )
+            compiled_unet, unet_mlir = self.compile_unet_variants(
+                model, use_large=use_large
+            )
         else:
             for model_id in unet_inputs:
                 self.base_model_id = model_id
-                self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
+                self.inputs["unet"] = self.get_input_info_for(
+                    unet_inputs[model_id]
+                )
                 try:
-                    compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
+                    compiled_unet, unet_mlir = self.compile_unet_variants(
+                        model, use_large=use_large
+                    )
                 except Exception as e:
                     print(e)
-                    print("Retrying with a different base model configuration")
+                    print(
+                        "Retrying with a different base model configuration"
+                    )
                     continue

                 # -- Once a successful compilation has taken place we'd want to store
@@ -675,7 +827,11 @@
     def vae(self):
         try:
-            vae_input = base_models["vae"]["vae_upscaler"] if self.is_upscaler else base_models["vae"]["vae"]
+            vae_input = (
+                base_models["vae"]["vae_upscaler"]
+                if self.is_upscaler
+                else base_models["vae"]["vae"]
+            )
             self.inputs["vae"] = self.get_input_info_for(vae_input)

             is_base_vae = self.base_vae
@@ -693,7 +849,9 @@
     def controlnet(self):
         try:
-            self.inputs["stencil_adaptor"] = self.get_input_info_for(base_models["stencil_adaptor"])
+            self.inputs["stencil_adaptor"] = self.get_input_info_for(
+                base_models["stencil_adaptor"]
+            )
             compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
             check_compilation(compiled_stencil_adaptor, "Stencil")


@@ -17,9 +17,13 @@ hf_model_variant_map = {
     "stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
     "CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
     "runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
-    "stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
+    "stabilityai/stable-diffusion-2-inpainting": [
+        "stablediffusion",
+        "inpaint_v2",
+    ],
 }

 # TODO: Add the quantized model as a part model_db.json.
 # This is currently in experimental phase.
 def get_quantize_model():
@@ -27,9 +31,12 @@
     model_key = "unet_int8"
     iree_flags = get_opt_flags("unet", precision="fp16")
     if args.height != 512 and args.width != 512 and args.max_length != 77:
-        sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
+        sys.exit(
+            "The int8 quantized model currently requires the height and width to be 512, and max_length to be 77"
+        )
     return bucket_key, model_key, iree_flags

 def get_variant_version(hf_model_id):
     return hf_model_variant_map[hf_model_id]


@@ -14,4 +14,4 @@ build-backend = "setuptools.build_meta"
 [tool.black]
 line-length = 79
 include = '\.pyi?$'
+exclude = "apps/language_models/scripts/vicuna.py"
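
A caveat on the [tool.black] change above, based on black's documented behavior rather than anything in this commit: exclude only filters files found during recursive discovery, so `black .` skips vicuna.py, but passing the path explicitly (`black apps/language_models/scripts/vicuna.py`) would still reformat it. If that matters, black's force-exclude option also applies to explicitly listed paths; a hypothetical variant, not part of this commit:

    [tool.black]
    line-length = 79
    include = '\.pyi?$'
    force-exclude = "apps/language_models/scripts/vicuna.py"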