Compare commits

...

1 Commits

View File

@@ -219,9 +219,6 @@ class FluxCheckpointModel(ModelLoader):
assert isinstance(config, MainCheckpointConfig)
model_path = Path(config.path)
with accelerate.init_empty_weights():
model = Flux(params[config.config_path])
sd = load_file(model_path)
if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd:
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
@@ -230,6 +227,11 @@ class FluxCheckpointModel(ModelLoader):
for k in sd.keys():
# We need to cast to bfloat16 due to it being the only currently supported dtype for inference
sd[k] = sd[k].to(torch.bfloat16)
flux_params = infer_flux_params_from_state_dict(sd)
with accelerate.init_empty_weights():
model = Flux(flux_params)
model.load_state_dict(sd, assign=True)
return model