Add width and height support for the scheduler.

This commit is contained in:
Prashant Kumar
2023-01-22 17:04:26 +00:00
parent f09f217478
commit b3ab0a1843

View File

@@ -17,8 +17,8 @@ SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
model_input = {
"euler": {
"latent": torch.randn(1, 4, 64, 64),
"output": torch.randn(1, 4, 64, 64),
"latent": torch.randn(1, 4, args.height // 8, args.width // 8),
"output": torch.randn(1, 4, args.height // 8, args.width // 8),
"sigma": torch.tensor(1).to(torch.float32),
"dt": torch.tensor(1).to(torch.float32),
},
@@ -84,7 +84,8 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
self.scaling_model = compile_through_fx(
scaling_model,
(example_latent, example_sigma),
model_name="euler_scale_model_input_" + args.precision,
model_name=f"euler_scale_model_input_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
@@ -92,7 +93,8 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
self.step_model = compile_through_fx(
step_model,
(example_output, example_sigma, example_latent, example_dt),
model_name="euler_step_" + args.precision,
model_name=f"euler_step_{args.height}_{args.width}"
+ args.precision,
extra_args=iree_flags,
)
else: