diff --git a/apps/shark_studio/modules/ckpt_processing.py b/apps/shark_studio/modules/ckpt_processing.py index b1d75806..9f75c8dd 100644 --- a/apps/shark_studio/modules/ckpt_processing.py +++ b/apps/shark_studio/modules/ckpt_processing.py @@ -64,11 +64,14 @@ def preprocessCKPT(custom_weights, precision="fp16", is_inpaint=False): return path_to_diffusers -def save_irpa(weights_path, prepend_str): +def save_irpa(weights_path, prepend_str, remove_str=None): weights = safetensors.torch.load_file(weights_path) archive = ParameterArchiveBuilder() for key in weights.keys(): - new_key = prepend_str + key + if remove_str: + new_key = key.replace(remove_str, prepend_str) + else: + new_key = prepend_str + key archive.add_tensor(new_key, weights[key]) if "safetensors" in weights_path: diff --git a/apps/shark_studio/tools/params_prefixer.py b/apps/shark_studio/tools/params_prefixer.py index 33bcfb89..313621e0 100644 --- a/apps/shark_studio/tools/params_prefixer.py +++ b/apps/shark_studio/tools/params_prefixer.py @@ -15,6 +15,12 @@ parser.add_argument( default="", help="prefix to add to all the keys in the irpa", ) +parser.add_argument( + "--replace", + type=str, + default=None, + help="prefix to be removed" +) args = parser.parse_args() -output_file = save_irpa(args.input, args.prefix) +output_file = save_irpa(args.input, args.prefix, args.replace) print("saved irpa to", output_file, "with prefix", args.prefix)