mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git (synced 2026-04-03 03:00:17 -04:00)
Add the clip text shark_model. (#458)
@@ -10,6 +10,7 @@ from model_wrappers import (
     get_vae16,
     get_unet16_wrapped,
     get_unet32_wrapped,
+    get_clipped_text,
 )
 from utils import get_shark_model
 import time
@@ -132,6 +133,13 @@ if __name__ == "__main__":
        f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
     )
 
+    clip_model = "clip_text"
+    clip_extra_args = [
+        "--iree-flow-linalg-ops-padding-size=16",
+        "--iree-flow-enable-padding-linalg-ops",
+    ]
+    clip = get_shark_model(GCLOUD_BUCKET, clip_model, clip_extra_args)
+
     prompt = args.prompts
     height = 512  # default height of Stable Diffusion
     width = 512  # default width of Stable Diffusion
@@ -149,9 +157,6 @@
     vae, unet = get_models()
 
     tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-    text_encoder = CLIPTextModel.from_pretrained(
-        "openai/clip-vit-large-patch14"
-    )
 
     scheduler = LMSDiscreteScheduler(
         beta_start=0.00085,
@@ -170,7 +175,8 @@
         return_tensors="pt",
     )
 
-    text_embeddings = text_encoder(text_input.input_ids)[0].to(dtype)
+    text_embeddings = clip.forward((text_input.input_ids,))
+    text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
     max_length = text_input.input_ids.shape[-1]
     uncond_input = tokenizer(
         [""] * batch_size,
@@ -178,7 +184,8 @@
         max_length=max_length,
         return_tensors="pt",
     )
-    uncond_embeddings = text_encoder(uncond_input.input_ids)[0].to(dtype)
+    uncond_embeddings = clip.forward((uncond_input.input_ids,))
+    uncond_embeddings = torch.from_numpy(uncond_embeddings).to(dtype)
 
     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
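Taken together, the main-script changes above swap the eager HuggingFace text encoder for a SHARK module fetched from the gcloud bucket with IREE padding flags, and convert its NumPy output back to torch. A minimal sketch of the resulting flow, assuming the diff's names (GCLOUD_BUCKET, get_shark_model); the prompt, dtype, and tokenizer padding kwargs are illustrative, not taken from the commit:

    # Sketch under the diff's assumptions; get_shark_model comes from the
    # script's utils, and the real bucket constant is defined elsewhere
    # in the script.
    import torch
    from transformers import CLIPTokenizer
    from utils import get_shark_model

    GCLOUD_BUCKET = "gs://..."  # placeholder for the script's bucket constant

    # Fetch the precompiled CLIP text model, passing the same IREE
    # linalg-padding flags the commit adds.
    clip = get_shark_model(
        GCLOUD_BUCKET,
        "clip_text",
        [
            "--iree-flow-linalg-ops-padding-size=16",
            "--iree-flow-enable-padding-linalg-ops",
        ],
    )

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_input = tokenizer(
        ["a photograph of an astronaut riding a horse"],
        padding="max_length",
        max_length=tokenizer.model_max_length,  # 77 for this checkpoint
        truncation=True,
        return_tensors="pt",
    )

    # The SHARK module takes a tuple of inputs and hands back a NumPy array,
    # hence the torch.from_numpy round-trip before the diffusion loop.
    text_embeddings = clip.forward((text_input.input_ids,))
    text_embeddings = torch.from_numpy(text_embeddings).to(torch.float32)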
@@ -1,4 +1,5 @@
 from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
+from transformers import CLIPTextModel
 from utils import compile_through_fx
 from stable_args import args
 import torch
@@ -9,6 +10,27 @@ YOUR_TOKEN = "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"
 BATCH_SIZE = len(args.prompts)
 
 
+def get_clipped_text(model_name="clip_text"):
+    class CLIPText(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.text_encoder = CLIPTextModel.from_pretrained(
+                "openai/clip-vit-large-patch14"
+            )
+
+        def forward(self, input):
+            return self.text_encoder(input)[0]
+
+    clip_model = CLIPText()
+    clip_input = torch.randint(1, 2, (BATCH_SIZE, 77))
+    shark_clip = compile_through_fx(
+        clip_model,
+        (clip_input,),
+        model_name=model_name,
+    )
+    return shark_clip
+
+
 def get_vae32(model_name="vae_fp32"):
     class VaeModel(torch.nn.Module):
         def __init__(self):
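The new get_clipped_text wrapper follows the same pattern as the existing VAE/UNet wrappers in this file: wrap the pretrained encoder in a small torch.nn.Module, compile it through compile_through_fx against a fixed-shape dummy input (BATCH_SIZE x 77, CLIP's maximum token length), and return the SHARK module. A hedged usage sketch; the tuple-input / NumPy-output calling convention is taken from the first file's changes, the rest is illustrative:

    # Illustrative only; assumes model_wrappers.py and its dependencies are
    # importable, and that args.prompts makes BATCH_SIZE match the batch here.
    import torch
    from transformers import CLIPTokenizer
    from model_wrappers import get_clipped_text

    shark_clip = get_clipped_text()  # traces CLIPText and compiles it

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    ids = tokenizer(
        ["a photograph of an astronaut riding a horse"],
        padding="max_length",
        max_length=77,  # must match the (BATCH_SIZE, 77) tracing shape
        truncation=True,
        return_tensors="pt",
    ).input_ids

    # Tuple in, NumPy out: convert back to torch for the diffusion loop.
    embeddings = torch.from_numpy(shark_clip.forward((ids,)))
    print(embeddings.shape)  # expected (1, 77, 768) for clip-vit-large-patch14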