mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add width and height support for the scheduler.
This commit is contained in:
@@ -17,8 +17,8 @@ SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
|
||||
|
||||
model_input = {
|
||||
"euler": {
|
||||
"latent": torch.randn(1, 4, 64, 64),
|
||||
"output": torch.randn(1, 4, 64, 64),
|
||||
"latent": torch.randn(1, 4, args.height // 8, args.width // 8),
|
||||
"output": torch.randn(1, 4, args.height // 8, args.width // 8),
|
||||
"sigma": torch.tensor(1).to(torch.float32),
|
||||
"dt": torch.tensor(1).to(torch.float32),
|
||||
},
|
||||
@@ -84,7 +84,8 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
self.scaling_model = compile_through_fx(
|
||||
scaling_model,
|
||||
(example_latent, example_sigma),
|
||||
model_name="euler_scale_model_input_" + args.precision,
|
||||
model_name=f"euler_scale_model_input_{args.height}_{args.width}"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
@@ -92,7 +93,8 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
self.step_model = compile_through_fx(
|
||||
step_model,
|
||||
(example_output, example_sigma, example_latent, example_dt),
|
||||
model_name="euler_step_" + args.precision,
|
||||
model_name=f"euler_step_{args.height}_{args.width}"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user