Tune the working memory estimate for FLUX VAE decoding.

Ryan Dick
2025-01-02 12:31:26 -05:00
parent bd8017ecd5
commit 6a5cee61be


@@ -42,12 +42,11 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
     def _estimate_working_memory(self, latents: torch.Tensor, vae: AutoEncoder) -> int:
         """Estimate the working memory required by the invocation in bytes."""
         # It was found experimentally that the peak working memory scales linearly with the number of pixels and the
-        # element size (precision). This estimate is accurate for both SD1 and SDXL.
+        # element size (precision).
         out_h = LATENT_SCALE_FACTOR * latents.shape[-2]
         out_w = LATENT_SCALE_FACTOR * latents.shape[-1]
         element_size = next(vae.parameters()).element_size()
-        # TODO(ryand): Need to tune this value, it was copied from the SD1 implementation.
-        scaling_constant = 960  # Determined experimentally.
+        scaling_constant = 1090  # Determined experimentally.
         working_memory = out_h * out_w * element_size * scaling_constant
         return working_memory
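
The sketch below works through the tuned estimate on a concrete input. It is not the invocation itself: it assumes LATENT_SCALE_FACTOR = 8 (the usual latent-to-pixel scale for these VAEs) and a bfloat16 VAE, so element_size is 2 bytes; the estimate_working_memory helper name is hypothetical.

    # Minimal sketch of the tuned estimate, not the real invocation.
    # Assumes LATENT_SCALE_FACTOR = 8 and a bfloat16 VAE (2 bytes/element).
    LATENT_SCALE_FACTOR = 8  # latent-to-pixel upscale factor (assumption)


    def estimate_working_memory(latent_h: int, latent_w: int, element_size: int = 2) -> int:
        """Estimate peak VAE-decode working memory in bytes.

        Peak memory scales linearly with output pixel count and element size;
        1090 is the constant tuned for the FLUX VAE in this commit.
        """
        out_h = LATENT_SCALE_FACTOR * latent_h
        out_w = LATENT_SCALE_FACTOR * latent_w
        scaling_constant = 1090  # Determined experimentally.
        return out_h * out_w * element_size * scaling_constant


    # A 1024x1024 output decodes from a 128x128 latent. With bf16 weights:
    # 1024 * 1024 * 2 * 1090 = 2,285,895,680 bytes, roughly 2.1 GiB.
    print(estimate_working_memory(128, 128))  # 2285895680

For the same image, the old SD1-derived constant of 960 would have estimated about 1.9 GiB, so the tuned value raises the reservation by roughly 14 percent.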