More readable torch_load ext check (#2853)

* more readable extension check

* enable tarfile test

* detach the torch tensor if it requires grad
This commit is contained in:
Oleg Rybalko
2023-12-19 22:53:15 +03:00
committed by GitHub
parent 172a88e719
commit 42a038c83f
2 changed files with 4 additions and 3 deletions

View File

@@ -15,6 +15,7 @@ def compare_weights_both(url):
for k in tg_weights:
if tg_weights[k].dtype == dtypes.bfloat16: tg_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
if torch_weights[k].dtype == torch.bfloat16: torch_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
if torch_weights[k].requires_grad: torch_weights[k] = torch_weights[k].detach()
np.testing.assert_equal(tg_weights[k].numpy(), torch_weights[k].numpy(), err_msg=f"mismatch at {k}, {tg_weights[k].shape}")
print(f"compared {len(tg_weights)} weights")
@@ -34,7 +35,7 @@ class TestTorchLoad(unittest.TestCase):
def test_load_llama2bfloat(self): compare_weights_both("https://huggingface.co/qazalin/bf16-lightweight/resolve/main/consolidated.00.pth?download=true")
# TODO: support pytorch tar format with minimal lines
#def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')
def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')
test_fn = pathlib.Path(__file__).parents[2] / "weights/LLaMA/7B/consolidated.00.pth"
#test_size = test_fn.stat().st_size

View File

@@ -111,7 +111,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
return intercept[name] if module_root == "torch" else super().find_class(module, name)
def persistent_load(self, pid): return deserialized_objects[pid] if pid in deserialized_objects else pid
if tuple(t[0:2].numpy()) == (0x50, 0x4b):
if zipfile.is_zipfile(fn):
myzip = zipfile.ZipFile(fn, 'r')
base_name = myzip.namelist()[0].split('/', 1)[0]
for n in myzip.namelist():
@@ -120,7 +120,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
with myzip.open(f'{base_name}/data.pkl') as myfile:
return TorchPickle(myfile).load()
elif bytes(t[0:0xe].numpy()) == b"././@PaxHeader": # TODO: is this how you detect a tarfile?
elif tarfile.is_tarfile(fn):
with tarfile.open(fn, "r") as tar:
storages_offset = tar.getmember('storages').offset_data
f = unwrap(tar.extractfile('storages'))