mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[SD][CLI] Add a warmup phase (#670)
This commit is contained in:
@@ -118,6 +118,19 @@ if __name__ == "__main__":
|
||||
subfolder="scheduler",
|
||||
)
|
||||
|
||||
latents = torch.randn(
|
||||
(batch_size, 4, height // 8, width // 8),
|
||||
generator=generator,
|
||||
dtype=torch.float32,
|
||||
).to(dtype)
|
||||
# Warmup phase to improve performance.
|
||||
if args.warmup_count >= 1:
|
||||
vae_warmup_input = torch.clone(latents).detach().numpy()
|
||||
clip_warmup_input = torch.randint(1, 2, (2, 77))
|
||||
for i in range(args.warmup_count):
|
||||
vae.forward((vae_warmup_input,))
|
||||
clip.forward((clip_warmup_input,))
|
||||
|
||||
start = time.time()
|
||||
|
||||
text_input = tokenizer(
|
||||
|
||||
@@ -178,4 +178,11 @@ p.add_argument(
|
||||
help="flag for hiding the details of iteration/sec for each step.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--warmup_count",
|
||||
type=int,
|
||||
default=0,
|
||||
help="flag setting warmup count for clip and vae [>= 0].",
|
||||
)
|
||||
|
||||
args = p.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user