Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-09 13:57:54 -05:00)
[Vicuna] Revert the formatting for Brevitas op (#1626)

-- This commit reverts the formatting for the Brevitas op.
-- It also excludes the vicuna.py script from the `black` formatter.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
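A note on the two exclusion mechanisms touched below: flake8's `exclude` is a comma-separated list of paths, while black's `exclude` in `pyproject.toml` is interpreted as a regular expression matched against file paths. A rough Python illustration of the difference (simplified; the real matching in both tools is more involved):

    import re

    # black: exclude is a regex searched against the file path (simplified).
    black_exclude = re.compile(r"apps/language_models/scripts/vicuna.py")
    # flake8: exclude is a plain comma-separated list of paths/globs.
    flake8_exclude = ["lit.cfg.py", "apps/language_models/scripts/vicuna.py"]

    path = "apps/language_models/scripts/vicuna.py"
    print(bool(black_exclude.search(path)))  # True -> black skips the file
    print(path in flake8_exclude)            # True -> flake8 skips the file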
.flake8 (2 changed lines)
@@ -2,4 +2,4 @@
 count = 1
 show-source = 1
 select = E9,F63,F7,F82
-exclude = lit.cfg.py
+exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py

.gitignore (vendored, 3 changed lines)
@@ -189,6 +189,3 @@ apps/stable_diffusion/web/models/
 
 # Stencil annotators.
 stencil_annotator/
-
-# brevitas custom op lib
-apps/language_models/scripts/vicuna.py

apps/language_models/scripts/vicuna.py
@@ -103,14 +103,7 @@ parser.add_argument(
 )
 
 
-def brevitas〇matmul_rhs_group_quant〡shape(
-    lhs: List[int],
-    rhs: List[int],
-    rhs_scale: List[int],
-    rhs_zero_point: List[int],
-    rhs_bit_width: int,
-    rhs_group_size: int,
-) -> List[int]:
+def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
     if len(lhs) == 3 and len(rhs) == 2:
         return [lhs[0], lhs[1], rhs[0]]
     elif len(lhs) == 2 and len(rhs) == 2:
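The fullwidth 〇 and 〡 characters in these names follow torch-mlir's convention for shape/dtype library functions, where they stand in for symbols that are not valid in Python identifiers; that machinery is presumably why the definitions are kept out of `black`'s reach. A minimal sketch of the shape rule the function encodes (the helper name is illustrative, and the 2-D branch is filled in by assumption since it falls in the gap between hunks):

    from typing import List

    def matmul_rhs_group_quant_shape(lhs: List[int], rhs: List[int]) -> List[int]:
        # Batched lhs [B, M, K] against a 2-D rhs: the result takes rhs[0],
        # i.e. the rhs is laid out with the output dimension first.
        if len(lhs) == 3 and len(rhs) == 2:
            return [lhs[0], lhs[1], rhs[0]]
        elif len(lhs) == 2 and len(rhs) == 2:
            # Assumed by analogy with the 3-D case; this branch is cut off above.
            return [lhs[0], rhs[0]]
        raise ValueError("Input shapes not supported.")

    print(matmul_rhs_group_quant_shape([4, 64, 128], [256, 128]))  # [4, 64, 256]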
@@ -119,30 +112,20 @@ def brevitas〇matmul_rhs_group_quant〡shape(
         raise ValueError("Input shapes not supported.")
 
 
-def brevitas〇matmul_rhs_group_quant〡dtype(
-    lhs_rank_dtype: Tuple[int, int],
-    rhs_rank_dtype: Tuple[int, int],
-    rhs_scale_rank_dtype: Tuple[int, int],
-    rhs_zero_point_rank_dtype: Tuple[int, int],
-    rhs_bit_width: int,
-    rhs_group_size: int,
-) -> int:
+def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
     # output dtype is the dtype of the lhs float input
     lhs_rank, lhs_dtype = lhs_rank_dtype
     return lhs_dtype
 
 
-def brevitas〇matmul_rhs_group_quant〡has_value_semantics(
-    lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
-) -> None:
+def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
     return
 
 
 brevitas_matmul_rhs_group_quant_library = [
     brevitas〇matmul_rhs_group_quant〡shape,
     brevitas〇matmul_rhs_group_quant〡dtype,
-    brevitas〇matmul_rhs_group_quant〡has_value_semantics,
-]
+    brevitas〇matmul_rhs_group_quant〡has_value_semantics]
 
 
 class ShardedVicuna(SharkLLMBase):

@@ -45,6 +45,7 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
             new_shape.append(width * mul_val)
         elif "/" in shape[i]:
             import math
+
             div_val = int(shape[i].split("/")[1])
             if "batch_size" in shape[i]:
                 new_shape.append(math.ceil(batch_size / div_val))
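For context, `replace_shape_str` resolves symbolic shape entries such as "batch_size/2" into concrete integers. A hedged reconstruction of the "/" branch touched above (the helper below is illustrative, not the repo's function):

    import math

    def resolve_div_entry(entry: str, batch_size: int) -> int:
        # "batch_size/2" -> ceil(batch_size / 2); rounding up keeps the
        # dimension nonzero for small batch sizes.
        div_val = int(entry.split("/")[1])
        if "batch_size" in entry:
            return math.ceil(batch_size / div_val)
        raise ValueError(f"unhandled shape entry: {entry}")

    print(resolve_div_entry("batch_size/2", 3))  # 2, since ceil(3 / 2) == 2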
@@ -59,7 +60,9 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
 
 def check_compilation(model, model_name):
     if not model:
-        raise Exception(f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues")
+        raise Exception(
+            f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues"
+        )
 
 
 class SharkifyStableDiffusionModel:
@@ -97,16 +100,22 @@ class SharkifyStableDiffusionModel:
         if "civitai" in custom_weights:
             weights_id = custom_weights.split("/")[-1]
             # TODO: use model name and identify file type by civitai rest api
-            weights_path = str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
+            weights_path = (
+                str(Path.cwd()) + "/models/" + weights_id + ".safetensors"
+            )
             if not os.path.isfile(weights_path):
-                subprocess.run(["wget", custom_weights, "-O", weights_path])
+                subprocess.run(
+                    ["wget", custom_weights, "-O", weights_path]
+                )
             custom_weights = get_path_to_diffusers_checkpoint(weights_path)
             self.custom_weights = weights_path
         else:
             assert custom_weights.lower().endswith(
                 (".ckpt", ".safetensors")
             ), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
-            custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
+            custom_weights = get_path_to_diffusers_checkpoint(
+                custom_weights
+            )
         self.model_id = model_id if custom_weights == "" else custom_weights
         # TODO: remove the following line when stable-diffusion-2-1 works
         if self.model_id == "stabilityai/stable-diffusion-2-1":
@@ -126,7 +135,7 @@ class SharkifyStableDiffusionModel:
             + "_"
             + precision
         )
-        print(f'use_tuned? sharkify: {use_tuned}')
+        print(f"use_tuned? sharkify: {use_tuned}")
         self.use_tuned = use_tuned
         if use_tuned:
             self.model_name = self.model_name + "_tuned"
@@ -163,14 +172,24 @@ class SharkifyStableDiffusionModel:
 
     def get_extended_name_for_all_model(self):
         model_name = {}
-        sub_model_list = ["clip", "unet", "unet512", "stencil_unet", "vae", "vae_encode", "stencil_adaptor"]
+        sub_model_list = [
+            "clip",
+            "unet",
+            "unet512",
+            "stencil_unet",
+            "vae",
+            "vae_encode",
+            "stencil_adaptor",
+        ]
         index = 0
         for model in sub_model_list:
             sub_model = model
             model_config = self.model_name
             if "vae" == model:
                 if self.custom_vae != "":
-                    model_config = model_config + get_path_stem(self.custom_vae)
+                    model_config = model_config + get_path_stem(
+                        self.custom_vae
+                    )
                 if self.base_vae:
                     sub_model = "base_vae"
             if "stencil_adaptor" == model and self.use_stencil is not None:
@@ -197,7 +216,11 @@ class SharkifyStableDiffusionModel:
         tensor = None
         if isinstance(shape, list):
             clean_shape = replace_shape_str(
-                shape, self.max_len, self.width, self.height, self.batch_size
+                shape,
+                self.max_len,
+                self.width,
+                self.height,
+                self.batch_size,
             )
             if dtype == torch.int64:
                 tensor = torch.randint(1, 3, tuple(clean_shape))
@@ -209,10 +232,12 @@ class SharkifyStableDiffusionModel:
                 sys.exit("shape isn't specified correctly.")
             input_map.append(tensor)
         return input_map
 
     def get_vae_encode(self):
         class VaeEncodeModel(torch.nn.Module):
-            def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
+            def __init__(
+                self, model_id=self.model_id, low_cpu_mem_usage=False
+            ):
                 super().__init__()
                 self.vae = AutoencoderKL.from_pretrained(
                     model_id,
@@ -226,7 +251,11 @@ class SharkifyStableDiffusionModel:
 
         vae_encode = VaeEncodeModel()
         inputs = tuple(self.inputs["vae_encode"])
-        is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
+        is_f16 = (
+            True
+            if not self.is_upscaler and self.precision == "fp16"
+            else False
+        )
         shark_vae_encode, vae_encode_mlir = compile_through_fx(
             vae_encode,
             inputs,
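A side note on the expression being wrapped here: `True if cond else False` is a verbose spelling of the condition itself, so the five-line form `black` produces could be avoided entirely. A sketch of the equivalence (not a change this commit makes):

    is_upscaler, precision = False, "fp16"
    # The ternary and the bare condition are equivalent here, since the
    # condition already evaluates to a bool.
    is_f16 = True if not is_upscaler and precision == "fp16" else False
    assert is_f16 == (not is_upscaler and precision == "fp16")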
@@ -243,7 +272,13 @@ class SharkifyStableDiffusionModel:
 
     def get_vae(self):
         class VaeModel(torch.nn.Module):
-            def __init__(self, model_id=self.model_id, base_vae=self.base_vae, custom_vae=self.custom_vae, low_cpu_mem_usage=False):
+            def __init__(
+                self,
+                model_id=self.model_id,
+                base_vae=self.base_vae,
+                custom_vae=self.custom_vae,
+                low_cpu_mem_usage=False,
+            ):
                 super().__init__()
                 self.vae = None
                 if custom_vae == "":
@@ -279,7 +314,11 @@ class SharkifyStableDiffusionModel:
 
         vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
         inputs = tuple(self.inputs["vae"])
-        is_f16 = True if not self.is_upscaler and self.precision == "fp16" else False
+        is_f16 = (
+            True
+            if not self.is_upscaler and self.precision == "fp16"
+            else False
+        )
         save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
         if self.debug:
             os.makedirs(save_dir, exist_ok=True)
@@ -303,7 +342,10 @@ class SharkifyStableDiffusionModel:
     def get_controlled_unet(self):
         class ControlledUnetModel(torch.nn.Module):
             def __init__(
-                self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora
+                self,
+                model_id=self.model_id,
+                low_cpu_mem_usage=False,
+                use_lora=self.use_lora,
             ):
                 super().__init__()
                 self.unet = UNet2DConditionModel.from_pretrained(
@@ -316,12 +358,43 @@ class SharkifyStableDiffusionModel:
                 self.in_channels = self.unet.in_channels
                 self.train(False)
 
-            def forward( self, latent, timestep, text_embedding, guidance_scale, control1,
-                control2, control3, control4, control5, control6, control7,
-                control8, control9, control10, control11, control12, control13,
+            def forward(
+                self,
+                latent,
+                timestep,
+                text_embedding,
+                guidance_scale,
+                control1,
+                control2,
+                control3,
+                control4,
+                control5,
+                control6,
+                control7,
+                control8,
+                control9,
+                control10,
+                control11,
+                control12,
+                control13,
             ):
                 # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
-                db_res_samples = tuple([ control1, control2, control3, control4, control5, control6, control7, control8, control9, control10, control11, control12,])
+                db_res_samples = tuple(
+                    [
+                        control1,
+                        control2,
+                        control3,
+                        control4,
+                        control5,
+                        control6,
+                        control7,
+                        control8,
+                        control9,
+                        control10,
+                        control11,
+                        control12,
+                    ]
+                )
                 mb_res_samples = control13
                 latents = torch.cat([latent] * 2)
                 unet_out = self.unet.forward(
@@ -342,7 +415,25 @@ class SharkifyStableDiffusionModel:
         is_f16 = True if self.precision == "fp16" else False
 
         inputs = tuple(self.inputs["unet"])
-        input_mask = [True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True,]
+        input_mask = [
+            True,
+            True,
+            True,
+            False,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+            True,
+        ]
         shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
             unet,
             inputs,
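The 17-entry `input_mask` lines up with the 17 forward arguments (latent, timestep, text_embedding, guidance_scale, control1–13); only `guidance_scale`, the lone scalar, is masked out. Presumably the mask tells `compile_through_fx` which inputs to cast when `is_f16` is set; a sketch of that interpretation (the helper is an assumption, not SHARK's API):

    import torch

    def cast_masked_inputs(inputs, mask, dtype=torch.float16):
        # Cast only the masked-in tensors; leave the rest (e.g. a scalar
        # guidance value) untouched.
        return tuple(
            t.to(dtype) if m and torch.is_tensor(t) else t
            for t, m in zip(inputs, mask)
        )

    latent = torch.randn(2, 4, 64, 64)
    guidance = torch.tensor(7.5, dtype=torch.float64)
    out = cast_masked_inputs((latent, guidance), [True, False])
    print(out[0].dtype, out[1].dtype)  # torch.float16 torch.float64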
@@ -386,16 +477,23 @@ class SharkifyStableDiffusionModel:
                 stencil_image = torch.cat(
                     [stencil_image_input] * 2
                 )  # needs to be same as controlledUNET latents
-                down_block_res_samples, mid_block_res_sample = self.cnet.forward(
+                (
+                    down_block_res_samples,
+                    mid_block_res_sample,
+                ) = self.cnet.forward(
                     latents,
                     timestep,
                     encoder_hidden_states=text_embedding,
                     controlnet_cond=stencil_image,
                     return_dict=False,
                 )
-                return tuple(list(down_block_res_samples) + [mid_block_res_sample])
+                return tuple(
+                    list(down_block_res_samples) + [mid_block_res_sample]
+                )
 
-        scnet = StencilControlNetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
+        scnet = StencilControlNetModel(
+            low_cpu_mem_usage=self.low_cpu_mem_usage
+        )
         is_f16 = True if self.precision == "fp16" else False
 
         inputs = tuple(self.inputs["stencil_adaptor"])
@@ -417,7 +515,12 @@ class SharkifyStableDiffusionModel:
 
     def get_unet(self, use_large=False):
         class UnetModel(torch.nn.Module):
-            def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
+            def __init__(
+                self,
+                model_id=self.model_id,
+                low_cpu_mem_usage=False,
+                use_lora=self.use_lora,
+            ):
                 super().__init__()
                 self.unet = UNet2DConditionModel.from_pretrained(
                     model_id,
@@ -428,15 +531,24 @@ class SharkifyStableDiffusionModel:
                     update_lora_weight(self.unet, use_lora, "unet")
                 self.in_channels = self.unet.config.in_channels
                 self.train(False)
-                if(args.attention_slicing is not None and args.attention_slicing != "none"):
-                    if(args.attention_slicing.isdigit()):
-                        self.unet.set_attention_slice(int(args.attention_slicing))
+                if (
+                    args.attention_slicing is not None
+                    and args.attention_slicing != "none"
+                ):
+                    if args.attention_slicing.isdigit():
+                        self.unet.set_attention_slice(
+                            int(args.attention_slicing)
+                        )
                     else:
                         self.unet.set_attention_slice(args.attention_slicing)
 
             # TODO: Instead of flattening the `control` try to use the list.
             def forward(
-                self, latent, timestep, text_embedding, guidance_scale,
+                self,
+                latent,
+                timestep,
+                text_embedding,
+                guidance_scale,
             ):
                 # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
                 latents = torch.cat([latent] * 2)
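For reference, diffusers' `set_attention_slice` accepts either an integer slice size or a string policy such as "auto" or "max", which is why the digit check above converts before calling. A small sketch of the dispatch (standalone, mirroring the branch logic rather than calling into diffusers):

    def normalize_attention_slicing(value):
        # None or "none" disables slicing; digit strings become ints;
        # other strings (e.g. "auto", "max") pass through unchanged.
        if value is None or value == "none":
            return None
        return int(value) if value.isdigit() else value

    print(normalize_attention_slicing("2"))     # 2
    print(normalize_attention_slicing("auto"))  # auto
    print(normalize_attention_slicing("none"))  # None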
@@ -452,16 +564,22 @@ class SharkifyStableDiffusionModel:
         unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
         is_f16 = True if self.precision == "fp16" else False
         inputs = tuple(self.inputs["unet"])
-        if(use_large):
+        if use_large:
             pad = (0, 0) * (len(inputs[2].shape) - 2)
             pad = pad + (0, 512 - inputs[2].shape[1])
-            inputs = (inputs[0],
-                inputs[1],
-                torch.nn.functional.pad(inputs[2], pad),
-                inputs[3])
-            save_dir = os.path.join(self.sharktank_dir, self.model_name["unet512"])
+            inputs = (
+                inputs[0],
+                inputs[1],
+                torch.nn.functional.pad(inputs[2], pad),
+                inputs[3],
+            )
+            save_dir = os.path.join(
+                self.sharktank_dir, self.model_name["unet512"]
+            )
         else:
-            save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
+            save_dir = os.path.join(
+                self.sharktank_dir, self.model_name["unet"]
+            )
         input_mask = [True, True, True, False]
         if self.debug:
             os.makedirs(
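The padding logic above relies on `torch.nn.functional.pad` reading its pad tuple from the last dimension backwards: `(0, 0) * (ndim - 2)` leaves the trailing dimensions alone, and the final `(0, 512 - shape[1])` pair then lands on dimension 1, growing it to 512 (e.g. a 77-token text embedding padded out for the unet512 variant). A worked example:

    import torch
    import torch.nn.functional as F

    emb = torch.randn(2, 77, 768)        # e.g. a CLIP text embedding
    pad = (0, 0) * (emb.dim() - 2)       # first pair covers the last dim
    pad = pad + (0, 512 - emb.shape[1])  # final pair reaches dim 1
    padded = F.pad(emb, pad)             # zero-pads by default
    print(padded.shape)                  # torch.Size([2, 512, 768])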
@@ -489,7 +607,9 @@ class SharkifyStableDiffusionModel:
 
     def get_unet_upscaler(self, use_large=False):
         class UnetModel(torch.nn.Module):
-            def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False):
+            def __init__(
+                self, model_id=self.model_id, low_cpu_mem_usage=False
+            ):
                 super().__init__()
                 self.unet = UNet2DConditionModel.from_pretrained(
                     model_id,
@@ -512,13 +632,15 @@ class SharkifyStableDiffusionModel:
         unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
         is_f16 = True if self.precision == "fp16" else False
         inputs = tuple(self.inputs["unet"])
-        if(use_large):
+        if use_large:
             pad = (0, 0) * (len(inputs[2].shape) - 2)
             pad = pad + (0, 512 - inputs[2].shape[1])
-            inputs = (inputs[0],
-                inputs[1],
-                torch.nn.functional.pad(inputs[2], pad),
-                inputs[3])
+            inputs = (
+                inputs[0],
+                inputs[1],
+                torch.nn.functional.pad(inputs[2], pad),
+                inputs[3],
+            )
         input_mask = [True, True, True, False]
         model_name = "unet512" if use_large else "unet"
         shark_unet, unet_mlir = compile_through_fx(
@@ -538,7 +660,12 @@ class SharkifyStableDiffusionModel:
 
     def get_clip(self):
         class CLIPText(torch.nn.Module):
-            def __init__(self, model_id=self.model_id, low_cpu_mem_usage=False, use_lora=self.use_lora):
+            def __init__(
+                self,
+                model_id=self.model_id,
+                low_cpu_mem_usage=False,
+                use_lora=self.use_lora,
+            ):
                 super().__init__()
                 self.text_encoder = CLIPTextModel.from_pretrained(
                     model_id,
@@ -546,7 +673,9 @@ class SharkifyStableDiffusionModel:
                     low_cpu_mem_usage=low_cpu_mem_usage,
                 )
                 if use_lora != "":
-                    update_lora_weight(self.text_encoder, use_lora, "text_encoder")
+                    update_lora_weight(
+                        self.text_encoder, use_lora, "text_encoder"
+                    )
 
             def forward(self, input):
                 return self.text_encoder(input)[0]
@@ -585,16 +714,24 @@ class SharkifyStableDiffusionModel:
         vae_checkpoint = None
         vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
         if custom_vae.endswith(".ckpt"):
-            vae_checkpoint = torch.load(self.custom_vae, map_location="cpu")
+            vae_checkpoint = torch.load(
+                self.custom_vae, map_location="cpu"
+            )
         else:
-            vae_checkpoint = safetensors.torch.load_file(self.custom_vae, device="cpu")
+            vae_checkpoint = safetensors.torch.load_file(
+                self.custom_vae, device="cpu"
+            )
         if "state_dict" in vae_checkpoint:
             vae_checkpoint = vae_checkpoint["state_dict"]
 
         try:
             vae_checkpoint = convert_original_vae(vae_checkpoint)
         finally:
-            vae_dict = {k: v for k, v in vae_checkpoint.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
+            vae_dict = {
+                k: v
+                for k, v in vae_checkpoint.items()
+                if k[0:4] != "loss" and k not in vae_ignore_keys
+            }
         return vae_dict
 
     def compile_unet_variants(self, model, use_large=False):
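Note the control flow in the `try`/`finally` above: the key filter always runs, but if `convert_original_vae` raises, the exception still propagates after the `finally` block, so `return vae_dict` is only reached on success. A minimal demonstration of that semantics:

    def load_and_filter():
        try:
            raise RuntimeError("conversion failed")
        finally:
            print("filter runs regardless")  # executes even on failure

    try:
        load_and_filter()
    except RuntimeError as e:
        print(e)  # "conversion failed" -- the exception was not swallowed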
@@ -603,7 +740,10 @@ class SharkifyStableDiffusionModel:
             return self.get_unet_upscaler(use_large=use_large)
         # TODO: Plug the experimental "int8" support at right place.
         elif self.use_quantize == "int8":
-            from apps.stable_diffusion.src.models.opt_params import get_unet
+            from apps.stable_diffusion.src.models.opt_params import (
+                get_unet,
+            )
+
             return get_unet()
         else:
             return self.get_unet(use_large=use_large)
@@ -612,7 +752,9 @@ class SharkifyStableDiffusionModel:
 
     def vae_encode(self):
         try:
-            self.inputs["vae_encode"] = self.get_input_info_for(base_models["vae_encode"])
+            self.inputs["vae_encode"] = self.get_input_info_for(
+                base_models["vae_encode"]
+            )
             compiled_vae_encode, vae_encode_mlir = self.get_vae_encode()
 
             check_compilation(compiled_vae_encode, "Vae Encode")
@@ -641,18 +783,28 @@ class SharkifyStableDiffusionModel:
         unet_inputs = base_models[model]
 
         if self.base_model_id != "":
-            self.inputs["unet"] = self.get_input_info_for(unet_inputs[self.base_model_id])
-            compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
+            self.inputs["unet"] = self.get_input_info_for(
+                unet_inputs[self.base_model_id]
+            )
+            compiled_unet, unet_mlir = self.compile_unet_variants(
+                model, use_large=use_large
+            )
         else:
             for model_id in unet_inputs:
                 self.base_model_id = model_id
-                self.inputs["unet"] = self.get_input_info_for(unet_inputs[model_id])
+                self.inputs["unet"] = self.get_input_info_for(
+                    unet_inputs[model_id]
+                )
 
                 try:
-                    compiled_unet, unet_mlir = self.compile_unet_variants(model, use_large=use_large)
+                    compiled_unet, unet_mlir = self.compile_unet_variants(
+                        model, use_large=use_large
+                    )
                 except Exception as e:
                     print(e)
-                    print("Retrying with a different base model configuration")
+                    print(
+                        "Retrying with a different base model configuration"
+                    )
                     continue
 
                 # -- Once a successful compilation has taken place we'd want to store
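The `else` branch above implements a probe-and-retry loop: when no base model id is known, each candidate input signature is tried in turn and failures are skipped with `continue`. A generic sketch of the pattern (names are illustrative, not SHARK's API):

    def compile_with_fallback(candidates, compile_fn):
        # Try each candidate configuration until one compiles.
        last_err = None
        for candidate in candidates:
            try:
                return candidate, compile_fn(candidate)
            except Exception as e:
                print(e)
                print("Retrying with a different base model configuration")
                last_err = e
        raise RuntimeError("no base model configuration compiled") from last_err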
@@ -675,7 +827,11 @@ class SharkifyStableDiffusionModel:
 
     def vae(self):
         try:
-            vae_input = base_models["vae"]["vae_upscaler"] if self.is_upscaler else base_models["vae"]["vae"]
+            vae_input = (
+                base_models["vae"]["vae_upscaler"]
+                if self.is_upscaler
+                else base_models["vae"]["vae"]
+            )
             self.inputs["vae"] = self.get_input_info_for(vae_input)
 
             is_base_vae = self.base_vae
@@ -693,7 +849,9 @@ class SharkifyStableDiffusionModel:
 
     def controlnet(self):
         try:
-            self.inputs["stencil_adaptor"] = self.get_input_info_for(base_models["stencil_adaptor"])
+            self.inputs["stencil_adaptor"] = self.get_input_info_for(
+                base_models["stencil_adaptor"]
+            )
             compiled_stencil_adaptor, controlnet_mlir = self.get_control_net()
 
             check_compilation(compiled_stencil_adaptor, "Stencil")

@@ -17,9 +17,13 @@ hf_model_variant_map = {
     "stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
     "CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
     "runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"],
-    "stabilityai/stable-diffusion-2-inpainting": ["stablediffusion", "inpaint_v2"],
+    "stabilityai/stable-diffusion-2-inpainting": [
+        "stablediffusion",
+        "inpaint_v2",
+    ],
 }
 
+
 # TODO: Add the quantized model as a part model_db.json.
 # This is currently in experimental phase.
 def get_quantize_model():
@@ -27,9 +31,12 @@ def get_quantize_model():
     model_key = "unet_int8"
     iree_flags = get_opt_flags("unet", precision="fp16")
     if args.height != 512 and args.width != 512 and args.max_length != 77:
-        sys.exit("The int8 quantized model currently requires the height and width to be 512, and max_length to be 77")
+        sys.exit(
+            "The int8 quantized model currently requires the height and width to be 512, and max_length to be 77"
+        )
     return bucket_key, model_key, iree_flags
 
 
 def get_variant_version(hf_model_id):
     return hf_model_variant_map[hf_model_id]
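One observation on the guard being reformatted above: with `and`, the program exits only when height, width, and max_length all differ from the supported values, yet the message states all three are required. If the requirement is as stated, an `or`-based check would match it; a hedged sketch (commentary, not part of the commit):

    import sys

    def check_int8_constraints(height, width, max_length):
        # Rejects any nonconforming value, matching the wording of the
        # error message ("requires ... 512 ... and max_length to be 77").
        if height != 512 or width != 512 or max_length != 77:
            sys.exit(
                "The int8 quantized model currently requires the height and "
                "width to be 512, and max_length to be 77"
            )

    check_int8_constraints(512, 512, 77)  # passes silently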
pyproject.toml
@@ -14,4 +14,4 @@ build-backend = "setuptools.build_meta"
 [tool.black]
 line-length = 79
 include = '\.pyi?$'
-
+exclude = "apps/language_models/scripts/vicuna.py"