flux set model path in args (#7660)

in addition to default downloading through fetch, add an arg to pass model path directly
This commit is contained in:
chenyu
2024-11-12 22:11:40 -05:00
committed by GitHub
parent 08706c2ea4
commit 4c5f7ddf1f

View File

@@ -326,11 +326,12 @@ class Flux:
return self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
# https://github.com/black-forest-labs/flux/blob/main/src/flux/util.py
def load_flow_model(name:str):
def load_flow_model(name:str, model_path:str):
# Loading Flux
print("Init model")
model = Flux(guidance_embed=(name != "flux-schnell"))
state_dict = {k.replace("scale", "weight"): v for k, v in safe_load(fetch(urls[name])).items()}
if not model_path: model_path = fetch(urls[name])
state_dict = {k.replace("scale", "weight"): v for k, v in safe_load(model_path).items()}
load_state_dict(model, state_dict)
return model
@@ -420,6 +421,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run Flux.1", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--name", type=str, default="flux-schnell", help="Name of the model to load")
parser.add_argument("--model_path", type=str, default="", help="path of the model file")
parser.add_argument("--width", type=int, default=512, help="width of the sample in pixels (should be a multiple of 16)")
parser.add_argument("--height", type=int, default=512, help="height of the sample in pixels (should be a multiple of 16)")
parser.add_argument("--seed", type=int, default=None, help="Set a seed for sampling")
@@ -461,7 +463,7 @@ if __name__ == "__main__":
del T5, clip
# load model
model = load_flow_model(args.name)
model = load_flow_model(args.name, args.model_path)
# denoise initial noise
x = denoise(model, **inp, timesteps=timesteps, guidance=args.guidance)