Delete all flux bundle state dict keys when extracting the transformer state dict

This commit is contained in:
Brandon Rising
2024-09-03 22:20:38 -04:00
committed by Brandon
parent d20335dabc
commit 33edee1ba6

View File

@@ -143,6 +143,7 @@ def convert_bundle_to_flux_transformer_checkpoint(
for k, v in transformer_state_dict.items():
if not k.startswith("model.diffusion_model"):
keys_to_remove.append(k) # This can be removed in the future if we only want to delete transformer keys
continue
if k.endswith("scale"):
# Scale math must be done at bfloat16 due to our current flux model