mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
properly fix DiskDevice reuse (#13961)
This commit is contained in:
@@ -360,6 +360,29 @@ class TestDiskTensor(unittest.TestCase):
|
||||
x = Tensor.empty(size + len(test), dtype=dtypes.uint8, device=f"disk:{fn}").to("CPU").realize()
|
||||
assert x[size:].data().tobytes() == test
|
||||
|
||||
def test_disk_device_reuse(self):
|
||||
from tinygrad.runtime.ops_disk import DiskDevice
|
||||
fn = pathlib.Path(temp("dt_device_reuse"))
|
||||
fn.unlink(missing_ok=True)
|
||||
fn.write_bytes(bytes(range(256)))
|
||||
# create first tensor and realize it
|
||||
t1 = Tensor.empty(128, device=f"disk:{fn}", dtype=dtypes.uint8)
|
||||
t1.to("CPU").realize()
|
||||
# get the DiskDevice and check internal state
|
||||
disk_device = Device[f"DISK:{fn}"]
|
||||
assert isinstance(disk_device, DiskDevice)
|
||||
assert disk_device.count == 1
|
||||
assert hasattr(disk_device, "mem")
|
||||
first_fd = disk_device.fd
|
||||
# create second tensor on same file - should reuse the device, not re-open
|
||||
t2 = Tensor.empty(64, device=f"disk:{fn}", dtype=dtypes.uint8)
|
||||
t2.to("CPU").realize()
|
||||
assert disk_device.count == 2
|
||||
assert disk_device.fd == first_fd, "file descriptor changed - file was unnecessarily re-opened"
|
||||
# verify data is correct
|
||||
np.testing.assert_equal(t1.numpy(), np.arange(128, dtype=np.uint8))
|
||||
np.testing.assert_equal(t2.numpy(), np.arange(64, dtype=np.uint8))
|
||||
|
||||
class TestPathTensor(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
@@ -18,7 +18,7 @@ class DiskDevice(Compiled):
|
||||
super().__init__(device, DiskAllocator(self), None, None)
|
||||
def _might_open(self, size:int):
|
||||
assert self.size is None or size <= self.size, f"can't reopen Disk tensor with larger size, opened with {self.size}, tried to open with {size}"
|
||||
if self.size is not None and hasattr(self.device, "mem"):
|
||||
if self.size is not None and hasattr(self, "mem"):
|
||||
self.count += 1
|
||||
return
|
||||
filename = self.device[len("disk:"):]
|
||||
|
||||
Reference in New Issue
Block a user