mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-08 05:24:00 -05:00
Add functionality to params prefixer.
This commit is contained in:
@@ -64,11 +64,14 @@ def preprocessCKPT(custom_weights, precision="fp16", is_inpaint=False):
|
||||
return path_to_diffusers
|
||||
|
||||
|
||||
def save_irpa(weights_path, prepend_str):
|
||||
def save_irpa(weights_path, prepend_str, remove_str=None):
|
||||
weights = safetensors.torch.load_file(weights_path)
|
||||
archive = ParameterArchiveBuilder()
|
||||
for key in weights.keys():
|
||||
new_key = prepend_str + key
|
||||
if remove_str:
|
||||
new_key = key.replace(remove_str, prepend_str)
|
||||
else:
|
||||
new_key = prepend_str + key
|
||||
archive.add_tensor(new_key, weights[key])
|
||||
|
||||
if "safetensors" in weights_path:
|
||||
|
||||
@@ -15,6 +15,12 @@ parser.add_argument(
|
||||
default="",
|
||||
help="prefix to add to all the keys in the irpa",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--replace",
|
||||
type=str,
|
||||
default=None,
|
||||
help="prefix to be removed"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
output_file = save_irpa(args.input, args.prefix)
|
||||
output_file = save_irpa(args.input, args.prefix, args.replace)
|
||||
print("saved irpa to", output_file, "with prefix", args.prefix)
|
||||
|
||||
Reference in New Issue
Block a user