Add the CLIP text shark_model. (#458)

This commit is contained in:
Prashant Kumar
2022-11-02 12:38:33 +05:30
committed by GitHub
parent 06ccfb0533
commit a081733a42
2 changed files with 34 additions and 5 deletions

View File

@@ -10,6 +10,7 @@ from model_wrappers import (
get_vae16,
get_unet16_wrapped,
get_unet32_wrapped,
get_clipped_text,
)
from utils import get_shark_model
import time
@@ -132,6 +133,13 @@ if __name__ == "__main__":
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
clip_model = "clip_text"
clip_extra_args = [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops",
]
clip = get_shark_model(GCLOUD_BUCKET, clip_model, clip_extra_args)
prompt = args.prompts
height = 512 # default height of Stable Diffusion
width = 512 # default width of Stable Diffusion
@@ -149,9 +157,6 @@ if __name__ == "__main__":
vae, unet = get_models()
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14"
)
scheduler = LMSDiscreteScheduler(
beta_start=0.00085,
@@ -170,7 +175,8 @@ if __name__ == "__main__":
return_tensors="pt",
)
text_embeddings = text_encoder(text_input.input_ids)[0].to(dtype)
text_embeddings = clip.forward((text_input.input_ids,))
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(
[""] * batch_size,
@@ -178,7 +184,8 @@ if __name__ == "__main__":
max_length=max_length,
return_tensors="pt",
)
uncond_embeddings = text_encoder(uncond_input.input_ids)[0].to(dtype)
uncond_embeddings = clip.forward((uncond_input.input_ids,))
uncond_embeddings = torch.from_numpy(uncond_embeddings).to(dtype)
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

View File

@@ -1,4 +1,5 @@
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
from transformers import CLIPTextModel
from utils import compile_through_fx
from stable_args import args
import torch
@@ -9,6 +10,27 @@ YOUR_TOKEN = "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"
BATCH_SIZE = len(args.prompts)
def get_clipped_text(model_name="clip_text", max_length=77):
    """Compile the CLIP text encoder through FX and return the Shark module.

    Wraps ``CLIPTextModel`` (openai/clip-vit-large-patch14) in a small
    ``nn.Module`` that returns only the last hidden state, then hands it to
    ``compile_through_fx`` with a dummy tracing input of shape
    ``(BATCH_SIZE, max_length)``.

    Args:
        model_name: Name under which the compiled artifact is registered.
        max_length: Token-sequence length of the tracing input. Defaults to
            77, the fixed max length of the CLIP tokenizer.

    Returns:
        The compiled module produced by ``compile_through_fx``.
    """

    class CLIPText(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.text_encoder = CLIPTextModel.from_pretrained(
                "openai/clip-vit-large-patch14"
            )

        def forward(self, input):
            # Index [0] selects the hidden-state tensor used as the
            # text embeddings downstream.
            return self.text_encoder(input)[0]

    clip_model = CLIPText()
    # randint(1, 2, ...) yields a tensor of all ones — a dummy token-id
    # input; presumably only its shape/dtype matter for tracing (TODO
    # confirm against compile_through_fx).
    clip_input = torch.randint(1, 2, (BATCH_SIZE, max_length))
    shark_clip = compile_through_fx(
        clip_model,
        (clip_input,),
        model_name=model_name,
    )
    return shark_clip
def get_vae32(model_name="vae_fp32"):
class VaeModel(torch.nn.Module):
def __init__(self):