diff --git a/examples/flux1.py b/examples/flux1.py index 66d44823cb..b0fcff13b1 100644 --- a/examples/flux1.py +++ b/examples/flux1.py @@ -326,11 +326,12 @@ class Flux: return self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) # https://github.com/black-forest-labs/flux/blob/main/src/flux/util.py -def load_flow_model(name:str): +def load_flow_model(name:str, model_path:str): # Loading Flux print("Init model") model = Flux(guidance_embed=(name != "flux-schnell")) - state_dict = {k.replace("scale", "weight"): v for k, v in safe_load(fetch(urls[name])).items()} + if not model_path: model_path = fetch(urls[name]) + state_dict = {k.replace("scale", "weight"): v for k, v in safe_load(model_path).items()} load_state_dict(model, state_dict) return model @@ -420,6 +421,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run Flux.1", formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--name", type=str, default="flux-schnell", help="Name of the model to load") + parser.add_argument("--model_path", type=str, default="", help="path of the model file") parser.add_argument("--width", type=int, default=512, help="width of the sample in pixels (should be a multiple of 16)") parser.add_argument("--height", type=int, default=512, help="height of the sample in pixels (should be a multiple of 16)") parser.add_argument("--seed", type=int, default=None, help="Set a seed for sampling") @@ -461,7 +463,7 @@ if __name__ == "__main__": del T5, clip # load model - model = load_flow_model(args.name) + model = load_flow_model(args.name, args.model_path) # denoise initial noise x = denoise(model, **inp, timesteps=timesteps, guidance=args.guidance)