Update the models to 8th Dec version.

This commit is contained in:
Prashant Kumar
2022-12-13 12:08:26 +00:00
parent b03038222d
commit 4cb50a3d06
3 changed files with 15 additions and 3 deletions

View File

@@ -8,6 +8,7 @@ BATCH_SIZE = len(args.prompts)
model_config = {
"v2": "stabilityai/stable-diffusion-2",
"v2.1base": "stabilityai/stable-diffusion-2-1-base",
"v1.4": "CompVis/stable-diffusion-v1-4",
}
@@ -22,6 +23,16 @@ model_input = {
torch.tensor(1).to(torch.float32), # guidance_scale
),
},
"v2.1base": {
"clip": (torch.randint(1, 2, (1, 77)),),
"vae": (torch.randn(1, 4, 64, 64),),
"unet": (
torch.randn(1, 4, 64, 64), # latents
torch.tensor([1]).to(torch.float32), # timestep
torch.randn(2, 77, 1024), # embedding
torch.tensor(1).to(torch.float32), # guidance_scale
),
},
"v1.4": {
"clip": (torch.randint(1, 2, (1, 77)),),
"vae": (torch.randn(1, 4, 64, 64),),

View File

@@ -27,7 +27,7 @@ def get_unet():
return get_shark_model(bucket, model_name, iree_flags)
else:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "unet_1dec_fp16"
model_name = "unet_8dec_fp16"
if args.version == "v2.1base":
model_name = "unet2base_8dec_fp16"
iree_flags += [
@@ -76,7 +76,7 @@ def get_vae():
)
if args.precision in ["fp16", "int8"]:
bucket = "gs://shark_tank/stable_diffusion"
model_name = "vae_1dec_fp16"
model_name = "vae_8dec_fp16"
if args.version == "v2.1base":
model_name = "vae2base_8dec_fp16"
iree_flags += [
@@ -141,7 +141,7 @@ def get_clip():
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
bucket = "gs://shark_tank/stable_diffusion"
model_name = "clip_1dec_fp32"
model_name = "clip_8dec_fp32"
if args.version == "v2.1base":
model_name = "clip2base_8dec_fp32"
iree_flags += [

View File

@@ -265,6 +265,7 @@ def import_with_fx(model, inputs, debug=False):
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
torch.ops.aten.native_layer_norm,
]
),
)(*inputs)