mirror of
https://github.com/lllyasviel/ControlNet.git
synced 2026-04-24 03:00:54 -04:00
safetensor
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -5,6 +5,7 @@ training/
|
||||
*.pth
|
||||
*.pt
|
||||
*.ckpt
|
||||
*.safetensors
|
||||
|
||||
my_fix.py
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
@@ -9,7 +10,13 @@ def get_state_dict(d):
|
||||
|
||||
|
||||
def load_state_dict(ckpt_path, location='cpu'):
|
||||
state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
|
||||
_, extension = os.path.splitext(ckpt_path)
|
||||
if extension.lower() == ".safetensors":
|
||||
import safetensors.torch
|
||||
state_dict = safetensors.torch.load_file(ckpt_path, device=torch.device(location))
|
||||
else:
|
||||
state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
|
||||
state_dict = get_state_dict(state_dict)
|
||||
print(f'Loaded state_dict from [{ckpt_path}]')
|
||||
return state_dict
|
||||
|
||||
|
||||
@@ -31,3 +31,4 @@ dependencies:
|
||||
- addict==2.4.0
|
||||
- yapf==0.32.0
|
||||
- prettytable==3.6.0
|
||||
- safetensors==0.2.7
|
||||
|
||||
Reference in New Issue
Block a user