fix DiskDevice reuse (#11039)

* fix DiskDevice reuse

* fix mypy and DiskDevice.count

* mypy

* add test

---------

Co-authored-by: b1tg <b1tg@users.noreply.github.com>
This commit is contained in:
b1tg
2025-07-01 22:29:21 +08:00
committed by GitHub
parent 5628e2054c
commit fcbefde8f5
2 changed files with 14 additions and 4 deletions

View File

@@ -1,4 +1,4 @@
import pathlib, tempfile, unittest
import os, pathlib, tempfile, unittest
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DType
@@ -410,5 +410,13 @@ class TestPathTensor(unittest.TestCase):
self.assertEqual(t_cpu.device, "CPU")
np.testing.assert_array_equal(t_cpu.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
def test_path_tensor_disk_device_bug(self):
test_file = pathlib.Path(self.temp_dir.name) / "disk_device_bug"
with open(test_file, "wb") as f: f.write(bytes(range(10)))
os.chmod(test_file, 0o000)
with self.assertRaises(PermissionError):
Tensor(pathlib.Path(test_file)).tolist()
os.chmod(test_file, 0o644)
assert Tensor(pathlib.Path(test_file)).tolist(), list(range(10))
if __name__ == "__main__":
unittest.main()

View File

@@ -16,10 +16,11 @@ class DiskDevice(Compiled):
self.fd: Optional[int] = None
self.count = 0
super().__init__(device, DiskAllocator(self), None, None, None)
def _might_open(self, size):
self.count += 1
def _might_open(self, size:int):
assert self.size is None or size <= self.size, f"can't reopen Disk tensor with larger size, opened with {self.size}, tried to open with {size}"
if self.size is not None: return
if self.size is not None and hasattr(self.device, "mem"):
self.count += 1
return
filename = self.device[len("disk:"):]
self.size = size
@@ -34,6 +35,7 @@ class DiskDevice(Compiled):
self.mem = mmap.mmap(self.fd, self.size)
if hasattr(self.mem, 'madvise') and (hp := getattr(mmap, "MADV_HUGEPAGE", None)) is not None:
with contextlib.suppress(OSError): self.mem.madvise(hp) # some systems have transparent_hugepage disabled
self.count += 1
def _might_close(self):
self.count -= 1
if self.count == 0: