mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
fix DiskDevice reuse (#11039)
* fix DiskDevice reuse * fix mypy and DiskDevice.count * mypy * add test --------- Co-authored-by: b1tg <b1tg@users.noreply.github.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import pathlib, tempfile, unittest
|
||||
import os, pathlib, tempfile, unittest
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.dtype import DType
|
||||
@@ -410,5 +410,13 @@ class TestPathTensor(unittest.TestCase):
|
||||
self.assertEqual(t_cpu.device, "CPU")
|
||||
np.testing.assert_array_equal(t_cpu.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
|
||||
|
||||
def test_path_tensor_disk_device_bug(self):
|
||||
test_file = pathlib.Path(self.temp_dir.name) / "disk_device_bug"
|
||||
with open(test_file, "wb") as f: f.write(bytes(range(10)))
|
||||
os.chmod(test_file, 0o000)
|
||||
with self.assertRaises(PermissionError):
|
||||
Tensor(pathlib.Path(test_file)).tolist()
|
||||
os.chmod(test_file, 0o644)
|
||||
assert Tensor(pathlib.Path(test_file)).tolist(), list(range(10))
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -16,10 +16,11 @@ class DiskDevice(Compiled):
|
||||
self.fd: Optional[int] = None
|
||||
self.count = 0
|
||||
super().__init__(device, DiskAllocator(self), None, None, None)
|
||||
def _might_open(self, size):
|
||||
self.count += 1
|
||||
def _might_open(self, size:int):
|
||||
assert self.size is None or size <= self.size, f"can't reopen Disk tensor with larger size, opened with {self.size}, tried to open with {size}"
|
||||
if self.size is not None: return
|
||||
if self.size is not None and hasattr(self.device, "mem"):
|
||||
self.count += 1
|
||||
return
|
||||
filename = self.device[len("disk:"):]
|
||||
self.size = size
|
||||
|
||||
@@ -34,6 +35,7 @@ class DiskDevice(Compiled):
|
||||
self.mem = mmap.mmap(self.fd, self.size)
|
||||
if hasattr(self.mem, 'madvise') and (hp := getattr(mmap, "MADV_HUGEPAGE", None)) is not None:
|
||||
with contextlib.suppress(OSError): self.mem.madvise(hp) # some systems have transparent_hugepage disabled
|
||||
self.count += 1
|
||||
def _might_close(self):
|
||||
self.count -= 1
|
||||
if self.count == 0:
|
||||
|
||||
Reference in New Issue
Block a user