From e94575d62e30740b022cd77ec54bd1bea7c1edf3 Mon Sep 17 00:00:00 2001 From: lvmin Date: Sat, 11 Feb 2023 19:04:11 -0800 Subject: [PATCH] safetensor --- .gitignore | 1 + cldm/model.py | 9 ++++++++- environment.yaml | 1 + 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 2df140f..5ae854a 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ training/ *.pth *.pt *.ckpt +*.safetensors my_fix.py diff --git a/cldm/model.py b/cldm/model.py index 2934ca6..51dbb5a 100644 --- a/cldm/model.py +++ b/cldm/model.py @@ -1,3 +1,4 @@ +import os import torch from omegaconf import OmegaConf @@ -9,7 +10,13 @@ def get_state_dict(d): def load_state_dict(ckpt_path, location='cpu'): - state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) + _, extension = os.path.splitext(ckpt_path) + if extension.lower() == ".safetensors": + import safetensors.torch + state_dict = safetensors.torch.load_file(ckpt_path, device=torch.device(location)) + else: + state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) + state_dict = get_state_dict(state_dict) print(f'Loaded state_dict from [{ckpt_path}]') return state_dict diff --git a/environment.yaml b/environment.yaml index 5bcaf15..2030a7c 100644 --- a/environment.yaml +++ b/environment.yaml @@ -31,3 +31,4 @@ dependencies: - addict==2.4.0 - yapf==0.32.0 - prettytable==3.6.0 + - safetensors==0.2.7