From 1df02f2b936ddaa06e145b606edd72861dffd636 Mon Sep 17 00:00:00 2001 From: lvmin Date: Sat, 11 Feb 2023 21:35:10 -0800 Subject: [PATCH] i --- .gitignore | 3 ++- tool_transfer_control.py | 46 ---------------------------------------- 2 files changed, 2 insertions(+), 47 deletions(-) delete mode 100644 tool_transfer_control.py diff --git a/.gitignore b/.gitignore index 5ae854a..330172c 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,8 @@ training/ *.ckpt *.safetensors -my_fix.py +gradio_pose2image_private.py +tool_transfer_control.py # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/tool_transfer_control.py b/tool_transfer_control.py deleted file mode 100644 index df0a273..0000000 --- a/tool_transfer_control.py +++ /dev/null @@ -1,46 +0,0 @@ -path_sd15 = './models/v1-5-pruned.ckpt' -path_sd15_with_control = './models/control_sd15_openpose.pth' -path_input = './models/anything-v3-full.safetensors' -path_output = './models/control_any3_openpose.ckpt' - - -import os - - -assert os.path.exists(path_sd15), 'Input path_sd15 does not exists!' -assert os.path.exists(path_sd15_with_control), 'Input path_sd15_with_control does not exists!' -assert os.path.exists(path_input), 'Input path_input does not exists!' -assert os.path.exists(os.path.dirname(path_output)), 'Output folder not exists!' - - -import torch -from share import * -from cldm.model import load_state_dict - - -sd15_state_dict = load_state_dict(path_sd15) -sd15_with_control_state_dict = load_state_dict(path_sd15_with_control) -input_state_dict = load_state_dict(path_input) - - -def get_node_name(name, parent_name): - if len(name) <= len(parent_name): - return False, '' - p = name[:len(parent_name)] - if p != parent_name: - return False, '' - return True, name[len(parent_name):] - - -keys = sd15_with_control_state_dict.keys() - -final_state_dict = {} -for key in keys: - is_control, node_name = get_node_name(key, 'control_') - if is_control: - pass - else: - pass - -torch.save(final_state_dict, path_output) -print('Transferred model saved at ' + path_output)