Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-08 05:24:00 -05:00)
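"""Compile the Stable Diffusion x4 upscaler submodules to MLIR.

Each get_*_mlir helper wraps one piece of the
stabilityai/stable-diffusion-x4-upscaler pipeline (the CLIP text encoder,
the VAE decoder, and the UNet) in a small torch.nn.Module and lowers it
with compile_through_fx using the example inputs in model_input.
"""
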
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from utils import compile_through_fx
import torch

model_id = "stabilityai/stable-diffusion-x4-upscaler"

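# Example inputs used to trace each submodule. The UNet latent example has
# 7 channels because the x4 upscaler UNet takes the 4-channel noisy latent
# concatenated with the 3-channel low-resolution image.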
model_input = {
    "clip": (torch.randint(1, 2, (1, 77)),),
    "vae": (torch.randn(1, 4, 128, 128),),
    "unet": (
        torch.randn(2, 7, 128, 128),  # latents
        torch.tensor([1]).to(torch.float32),  # timestep
        torch.randn(2, 77, 1024),  # embedding
        torch.randn(2).to(torch.int64),  # noise_level
    ),
}


def get_clip_mlir(model_name="clip_text", extra_args=[]):
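    """Lower the pipeline's CLIP text encoder to MLIR via compile_through_fx."""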
    text_encoder = CLIPTextModel.from_pretrained(
        model_id,
        subfolder="text_encoder",
    )

    class CLIPText(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.text_encoder = text_encoder

        def forward(self, input):
            return self.text_encoder(input)[0]

    clip_model = CLIPText()
    amdshark_clip = compile_through_fx(
        clip_model,
        model_input["clip"],
        model_name=model_name,
        extra_args=extra_args,
    )
    return amdshark_clip


def get_vae_mlir(model_name="vae", extra_args=[]):
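    """Lower the pipeline's VAE decoder to MLIR via compile_through_fx."""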
    class VaeModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.vae = AutoencoderKL.from_pretrained(
                model_id,
                subfolder="vae",
            )

        def forward(self, input):
            x = self.vae.decode(input, return_dict=False)[0]
            return x

    vae = VaeModel()
    amdshark_vae = compile_through_fx(
        vae,
        model_input["vae"],
        model_name=model_name,
        extra_args=extra_args,
    )
    return amdshark_vae


def get_unet_mlir(model_name="unet", extra_args=[]):
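    """Lower the pipeline's UNet to MLIR in float16 via compile_through_fx."""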
    class UnetModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.unet = UNet2DConditionModel.from_pretrained(
                model_id,
                subfolder="unet",
            )
            self.in_channels = self.unet.in_channels
            self.train(False)

        def forward(self, latent, timestep, text_embedding, noise_level):
            unet_out = self.unet.forward(
                latent,
                timestep,
                text_embedding,
                noise_level,
                return_dict=False,
            )[0]
            return unet_out

    unet = UnetModel()
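    # Compile the UNet in float16; the mask lines up with the four example
    # inputs (latents, timestep, embedding, noise_level) and keeps the
    # integer noise_level input out of the f16 cast.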
    f16_input_mask = (True, True, True, False)
    amdshark_unet = compile_through_fx(
        unet,
        model_input["unet"],
        model_name=model_name,
        is_f16=True,
        f16_input_mask=f16_input_mask,
        extra_args=extra_args,
    )
    return amdshark_unet
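

# Minimal usage sketch: each builder runs with its defaults; model_name and
# extra_args (extra compiler flags forwarded to compile_through_fx) can be
# overridden per call. The returned objects are whatever compile_through_fx
# produces for the given submodule.
if __name__ == "__main__":
    compiled_clip = get_clip_mlir()
    compiled_vae = get_vae_mlir()
    compiled_unet = get_unet_mlir()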