diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index 36f087221c..41b91299ba 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -383,6 +383,38 @@ class TestDiskTensor(unittest.TestCase): np.testing.assert_equal(t1.numpy(), np.arange(128, dtype=np.uint8)) np.testing.assert_equal(t2.numpy(), np.arange(64, dtype=np.uint8)) + def test_disk_open_failure_state(self): + from tinygrad.runtime.ops_disk import DiskDevice + fn = pathlib.Path(temp("dt_open_failure")) + fn.unlink(missing_ok=True) + fn.write_bytes(bytes(range(256))) + os.chmod(fn, 0o000) + try: + t = Tensor.empty(100, device=f"disk:{fn}", dtype=dtypes.uint8) + t.numpy() + except PermissionError: pass + # device state should be clean after failed open + disk_device = Device[f"DISK:{fn}"] + assert isinstance(disk_device, DiskDevice) + assert disk_device.size is None, "size should be None after failed open" + assert not hasattr(disk_device, "mem"), "mem should not exist after failed open" + # should be able to open with any size after failure + os.chmod(fn, 0o644) + t2 = Tensor.empty(200, device=f"disk:{fn}", dtype=dtypes.uint8) + t2.to("CPU").realize() + assert disk_device.size == 200 + + def test_disk_permission_error(self): + fn = pathlib.Path(temp("dt_permission")) + fn.unlink(missing_ok=True) + fn.write_bytes(bytes(range(256))) + os.chmod(fn, 0o000) + try: + with self.assertRaises(PermissionError): + Tensor.empty(100, device=f"disk:{fn}", dtype=dtypes.uint8).numpy() + finally: + os.chmod(fn, 0o644) + class TestPathTensor(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py index 2725b6d103..89475a1957 100644 --- a/tinygrad/runtime/ops_disk.py +++ b/tinygrad/runtime/ops_disk.py @@ -22,17 +22,17 @@ class DiskDevice(Compiled): self.count += 1 return filename = self.device[len("disk:"):] - self.size = size if sys.platform != "win32" and filename.startswith("shm:"): fd = _posixshmem.shm_open("/"+filename[4:].lstrip("/"), os.O_RDWR, 0o600) - self.mem = mmap.mmap(fd, self.size, mmap.MAP_SHARED | MAP_POPULATE | MAP_LOCKED) + self.mem = mmap.mmap(fd, size, mmap.MAP_SHARED | MAP_POPULATE | MAP_LOCKED) os.close(fd) else: try: self.fd = os.open(filename, os.O_RDWR|os.O_CREAT|getattr(os, "O_DIRECT", 0)) except OSError: self.fd = os.open(filename, os.O_RDWR|os.O_CREAT) - if not pathlib.Path(filename).is_block_device() and os.fstat(self.fd).st_size < self.size: os.ftruncate(self.fd, self.size) - self.mem = mmap.mmap(self.fd, self.size) + if not pathlib.Path(filename).is_block_device() and os.fstat(self.fd).st_size < size: os.ftruncate(self.fd, size) + self.mem = mmap.mmap(self.fd, size) + self.size = size if hasattr(self.mem, 'madvise') and (hp := getattr(mmap, "MADV_HUGEPAGE", None)) is not None: with contextlib.suppress(OSError): self.mem.madvise(hp) # some systems have transparent_hugepage disabled self.count += 1