mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Update the models to 8th Dec version.
This commit is contained in:
@@ -8,6 +8,7 @@ BATCH_SIZE = len(args.prompts)
|
||||
|
||||
model_config = {
|
||||
"v2": "stabilityai/stable-diffusion-2",
|
||||
"v2.1base": "stabilityai/stable-diffusion-2-1-base",
|
||||
"v1.4": "CompVis/stable-diffusion-v1-4",
|
||||
}
|
||||
|
||||
@@ -22,6 +23,16 @@ model_input = {
|
||||
torch.tensor(1).to(torch.float32), # guidance_scale
|
||||
),
|
||||
},
|
||||
"v2.1base": {
|
||||
"clip": (torch.randint(1, 2, (1, 77)),),
|
||||
"vae": (torch.randn(1, 4, 64, 64),),
|
||||
"unet": (
|
||||
torch.randn(1, 4, 64, 64), # latents
|
||||
torch.tensor([1]).to(torch.float32), # timestep
|
||||
torch.randn(2, 77, 1024), # embedding
|
||||
torch.tensor(1).to(torch.float32), # guidance_scale
|
||||
),
|
||||
},
|
||||
"v1.4": {
|
||||
"clip": (torch.randint(1, 2, (1, 77)),),
|
||||
"vae": (torch.randn(1, 4, 64, 64),),
|
||||
|
||||
@@ -27,7 +27,7 @@ def get_unet():
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
else:
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "unet_1dec_fp16"
|
||||
model_name = "unet_8dec_fp16"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "unet2base_8dec_fp16"
|
||||
iree_flags += [
|
||||
@@ -76,7 +76,7 @@ def get_vae():
|
||||
)
|
||||
if args.precision in ["fp16", "int8"]:
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "vae_1dec_fp16"
|
||||
model_name = "vae_8dec_fp16"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "vae2base_8dec_fp16"
|
||||
iree_flags += [
|
||||
@@ -141,7 +141,7 @@ def get_clip():
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
bucket = "gs://shark_tank/stable_diffusion"
|
||||
model_name = "clip_1dec_fp32"
|
||||
model_name = "clip_8dec_fp32"
|
||||
if args.version == "v2.1base":
|
||||
model_name = "clip2base_8dec_fp32"
|
||||
iree_flags += [
|
||||
|
||||
@@ -265,6 +265,7 @@ def import_with_fx(model, inputs, debug=False):
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
torch.ops.aten.native_layer_norm,
|
||||
]
|
||||
),
|
||||
)(*inputs)
|
||||
|
||||
Reference in New Issue
Block a user