Add a comment explaining why we're converting scale tensors in flux models to bfloat16

Brandon Rising
2024-09-03 16:55:53 -04:00
committed by Brandon
parent d57ba1ed8b
commit d10d258213


@@ -143,6 +143,8 @@ def convert_bundle_to_flux_transformer_checkpoint(
         if not k.startswith("model.diffusion_model"):
             continue
         if k.endswith("scale"):
+            # Scale math must be done at bfloat16 due to our current flux model
+            # support limitations at inference time
             v = v.to(dtype=torch.bfloat16)
         original_state_dict[k.replace("model.diffusion_model.", "")] = v
     return original_state_dict
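
For context, a minimal standalone sketch of the conversion this hunk touches, assuming a plain PyTorch state dict. The function name, key names, and the float32 scale value in the example are hypothetical illustrations, not taken from a real flux checkpoint or from InvokeAI's API.

import torch

def convert_scale_tensors_to_bfloat16(
    state_dict: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Cast every tensor whose key ends in "scale" to bfloat16.

    Mirrors the diff above: scale math must happen in bfloat16 because of
    current flux model support limitations at inference time.
    """
    converted: dict[str, torch.Tensor] = {}
    for k, v in state_dict.items():
        if k.endswith("scale"):
            v = v.to(dtype=torch.bfloat16)
        converted[k] = v
    return converted

# Hypothetical checkpoint fragment: a weight tensor plus a per-tensor
# dequantization scale stored in float32.
example = {
    "double_blocks.0.img_attn.qkv.weight": torch.randn(4, 4),
    "double_blocks.0.img_attn.qkv.scale": torch.tensor(0.5),
}
converted = convert_scale_tensors_to_bfloat16(example)
assert converted["double_blocks.0.img_attn.qkv.scale"].dtype == torch.bfloat16
assert converted["double_blocks.0.img_attn.qkv.weight"].dtype == torch.float32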