mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
only set DiskDevice.size if it opens successfully (#13962)
This commit is contained in:
@@ -383,6 +383,38 @@ class TestDiskTensor(unittest.TestCase):
|
|||||||
np.testing.assert_equal(t1.numpy(), np.arange(128, dtype=np.uint8))
|
np.testing.assert_equal(t1.numpy(), np.arange(128, dtype=np.uint8))
|
||||||
np.testing.assert_equal(t2.numpy(), np.arange(64, dtype=np.uint8))
|
np.testing.assert_equal(t2.numpy(), np.arange(64, dtype=np.uint8))
|
||||||
|
|
||||||
|
def test_disk_open_failure_state(self):
|
||||||
|
from tinygrad.runtime.ops_disk import DiskDevice
|
||||||
|
fn = pathlib.Path(temp("dt_open_failure"))
|
||||||
|
fn.unlink(missing_ok=True)
|
||||||
|
fn.write_bytes(bytes(range(256)))
|
||||||
|
os.chmod(fn, 0o000)
|
||||||
|
try:
|
||||||
|
t = Tensor.empty(100, device=f"disk:{fn}", dtype=dtypes.uint8)
|
||||||
|
t.numpy()
|
||||||
|
except PermissionError: pass
|
||||||
|
# device state should be clean after failed open
|
||||||
|
disk_device = Device[f"DISK:{fn}"]
|
||||||
|
assert isinstance(disk_device, DiskDevice)
|
||||||
|
assert disk_device.size is None, "size should be None after failed open"
|
||||||
|
assert not hasattr(disk_device, "mem"), "mem should not exist after failed open"
|
||||||
|
# should be able to open with any size after failure
|
||||||
|
os.chmod(fn, 0o644)
|
||||||
|
t2 = Tensor.empty(200, device=f"disk:{fn}", dtype=dtypes.uint8)
|
||||||
|
t2.to("CPU").realize()
|
||||||
|
assert disk_device.size == 200
|
||||||
|
|
||||||
|
def test_disk_permission_error(self):
|
||||||
|
fn = pathlib.Path(temp("dt_permission"))
|
||||||
|
fn.unlink(missing_ok=True)
|
||||||
|
fn.write_bytes(bytes(range(256)))
|
||||||
|
os.chmod(fn, 0o000)
|
||||||
|
try:
|
||||||
|
with self.assertRaises(PermissionError):
|
||||||
|
Tensor.empty(100, device=f"disk:{fn}", dtype=dtypes.uint8).numpy()
|
||||||
|
finally:
|
||||||
|
os.chmod(fn, 0o644)
|
||||||
|
|
||||||
class TestPathTensor(unittest.TestCase):
|
class TestPathTensor(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.temp_dir = tempfile.TemporaryDirectory()
|
self.temp_dir = tempfile.TemporaryDirectory()
|
||||||
|
|||||||
@@ -22,17 +22,17 @@ class DiskDevice(Compiled):
|
|||||||
self.count += 1
|
self.count += 1
|
||||||
return
|
return
|
||||||
filename = self.device[len("disk:"):]
|
filename = self.device[len("disk:"):]
|
||||||
self.size = size
|
|
||||||
|
|
||||||
if sys.platform != "win32" and filename.startswith("shm:"):
|
if sys.platform != "win32" and filename.startswith("shm:"):
|
||||||
fd = _posixshmem.shm_open("/"+filename[4:].lstrip("/"), os.O_RDWR, 0o600)
|
fd = _posixshmem.shm_open("/"+filename[4:].lstrip("/"), os.O_RDWR, 0o600)
|
||||||
self.mem = mmap.mmap(fd, self.size, mmap.MAP_SHARED | MAP_POPULATE | MAP_LOCKED)
|
self.mem = mmap.mmap(fd, size, mmap.MAP_SHARED | MAP_POPULATE | MAP_LOCKED)
|
||||||
os.close(fd)
|
os.close(fd)
|
||||||
else:
|
else:
|
||||||
try: self.fd = os.open(filename, os.O_RDWR|os.O_CREAT|getattr(os, "O_DIRECT", 0))
|
try: self.fd = os.open(filename, os.O_RDWR|os.O_CREAT|getattr(os, "O_DIRECT", 0))
|
||||||
except OSError: self.fd = os.open(filename, os.O_RDWR|os.O_CREAT)
|
except OSError: self.fd = os.open(filename, os.O_RDWR|os.O_CREAT)
|
||||||
if not pathlib.Path(filename).is_block_device() and os.fstat(self.fd).st_size < self.size: os.ftruncate(self.fd, self.size)
|
if not pathlib.Path(filename).is_block_device() and os.fstat(self.fd).st_size < size: os.ftruncate(self.fd, size)
|
||||||
self.mem = mmap.mmap(self.fd, self.size)
|
self.mem = mmap.mmap(self.fd, size)
|
||||||
|
self.size = size
|
||||||
if hasattr(self.mem, 'madvise') and (hp := getattr(mmap, "MADV_HUGEPAGE", None)) is not None:
|
if hasattr(self.mem, 'madvise') and (hp := getattr(mmap, "MADV_HUGEPAGE", None)) is not None:
|
||||||
with contextlib.suppress(OSError): self.mem.madvise(hp) # some systems have transparent_hugepage disabled
|
with contextlib.suppress(OSError): self.mem.madvise(hp) # some systems have transparent_hugepage disabled
|
||||||
self.count += 1
|
self.count += 1
|
||||||
|
|||||||
Reference in New Issue
Block a user