mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
More readable torch_load ext check (#2853)
* more readable extension check * enable tarfile test * detach tensor if requires grad in torch
This commit is contained in:
@@ -15,6 +15,7 @@ def compare_weights_both(url):
|
||||
for k in tg_weights:
|
||||
if tg_weights[k].dtype == dtypes.bfloat16: tg_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
|
||||
if torch_weights[k].dtype == torch.bfloat16: torch_weights[k] = torch_weights[k].float() # numpy doesn't support bfloat16
|
||||
if torch_weights[k].requires_grad: torch_weights[k] = torch_weights[k].detach()
|
||||
np.testing.assert_equal(tg_weights[k].numpy(), torch_weights[k].numpy(), err_msg=f"mismatch at {k}, {tg_weights[k].shape}")
|
||||
print(f"compared {len(tg_weights)} weights")
|
||||
|
||||
@@ -34,7 +35,7 @@ class TestTorchLoad(unittest.TestCase):
|
||||
def test_load_llama2bfloat(self): compare_weights_both("https://huggingface.co/qazalin/bf16-lightweight/resolve/main/consolidated.00.pth?download=true")
|
||||
|
||||
# TODO: support pytorch tar format with minimal lines
|
||||
#def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')
|
||||
def test_load_resnet(self): compare_weights_both('https://download.pytorch.org/models/resnet50-19c8e357.pth')
|
||||
|
||||
test_fn = pathlib.Path(__file__).parents[2] / "weights/LLaMA/7B/consolidated.00.pth"
|
||||
#test_size = test_fn.stat().st_size
|
||||
|
||||
@@ -111,7 +111,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
|
||||
return intercept[name] if module_root == "torch" else super().find_class(module, name)
|
||||
def persistent_load(self, pid): return deserialized_objects[pid] if pid in deserialized_objects else pid
|
||||
|
||||
if tuple(t[0:2].numpy()) == (0x50, 0x4b):
|
||||
if zipfile.is_zipfile(fn):
|
||||
myzip = zipfile.ZipFile(fn, 'r')
|
||||
base_name = myzip.namelist()[0].split('/', 1)[0]
|
||||
for n in myzip.namelist():
|
||||
@@ -120,7 +120,7 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
|
||||
offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
|
||||
with myzip.open(f'{base_name}/data.pkl') as myfile:
|
||||
return TorchPickle(myfile).load()
|
||||
elif bytes(t[0:0xe].numpy()) == b"././@PaxHeader": # TODO: is this how you detect a tarfile?
|
||||
elif tarfile.is_tarfile(fn):
|
||||
with tarfile.open(fn, "r") as tar:
|
||||
storages_offset = tar.getmember('storages').offset_data
|
||||
f = unwrap(tar.extractfile('storages'))
|
||||
|
||||
Reference in New Issue
Block a user