mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
flux set model path in args (#7660)
in addition to default downloading through fetch, add an arg to pass model path directly
This commit is contained in:
@@ -326,11 +326,12 @@ class Flux:
|
||||
return self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
# https://github.com/black-forest-labs/flux/blob/main/src/flux/util.py
|
||||
def load_flow_model(name:str):
|
||||
def load_flow_model(name:str, model_path:str):
|
||||
# Loading Flux
|
||||
print("Init model")
|
||||
model = Flux(guidance_embed=(name != "flux-schnell"))
|
||||
state_dict = {k.replace("scale", "weight"): v for k, v in safe_load(fetch(urls[name])).items()}
|
||||
if not model_path: model_path = fetch(urls[name])
|
||||
state_dict = {k.replace("scale", "weight"): v for k, v in safe_load(model_path).items()}
|
||||
load_state_dict(model, state_dict)
|
||||
return model
|
||||
|
||||
@@ -420,6 +421,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run Flux.1", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
|
||||
parser.add_argument("--name", type=str, default="flux-schnell", help="Name of the model to load")
|
||||
parser.add_argument("--model_path", type=str, default="", help="path of the model file")
|
||||
parser.add_argument("--width", type=int, default=512, help="width of the sample in pixels (should be a multiple of 16)")
|
||||
parser.add_argument("--height", type=int, default=512, help="height of the sample in pixels (should be a multiple of 16)")
|
||||
parser.add_argument("--seed", type=int, default=None, help="Set a seed for sampling")
|
||||
@@ -461,7 +463,7 @@ if __name__ == "__main__":
|
||||
del T5, clip
|
||||
|
||||
# load model
|
||||
model = load_flow_model(args.name)
|
||||
model = load_flow_model(args.name, args.model_path)
|
||||
|
||||
# denoise initial noise
|
||||
x = denoise(model, **inp, timesteps=timesteps, guidance=args.guidance)
|
||||
|
||||
Reference in New Issue
Block a user