don't cast before view on shape changing bitcast (#4833)

* don't cast before view on shape changing bitcast

* make sure cast before view triggers
This commit is contained in:
David Hou
2024-06-04 13:04:52 -07:00
committed by GitHub
parent 0c3a996e64
commit cddce0e168
2 changed files with 9 additions and 3 deletions

View File

@@ -284,6 +284,12 @@ class TestDiskTensor(unittest.TestCase):
ret = t.to("CLANG").bitcast(dtypes.uint16) + 1
assert ret.tolist() == [2827, 3341, 3855, 4369, 4883]
def test_bitcast_view(self):
    """Shape-changing bitcast applied to a view (shrink) of a disk tensor.

    Backing file holds 14 raw bytes (values 10..23); the tensor reads the
    first 3 uint32 words, is shrunk to 2, then bitcast to 4 uint16 values.
    Expected values assume little-endian byte order (e.g. bytes 10,11 ->
    0x0b0a = 2826, +1 = 2827).
    """
    # Write the raw backing bytes for the disk-device tensor.
    with open(temp('range_1020'), "wb") as f:
        f.write(bytes(range(10, 24)))
    disk_t = Tensor.empty(3, dtype=dtypes.uint, device=f"disk:{temp('range_1020')}")
    viewed = disk_t.shrink([(0, 2)])
    # Bitcast on the view changes element count (uint32 -> uint16 doubles it).
    result = (viewed.bitcast(dtypes.uint16).to("CLANG") + 1).tolist()
    assert result == [2827, 3341, 3855, 4369]
def test_bf16_disk_write_read(self):
t = Tensor([10000, -1, -1000, -10000, 20], dtype=dtypes.float32)
t.to(f"disk:{temp('f32')}").realize()

View File

@@ -97,9 +97,6 @@ class LazyBuffer:
if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
if self.is_unrealized_unmasked_const() and not bitcast:
return create_lazybuffer(self.device, self.st, dtype, LoadOps.CONST, dtypes.as_const(self.base.arg, dtype))
# TODO: applying this makes gpt2 slower
if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
return self.base.cast(dtype, bitcast)._view(self.st)
new_shape = self.shape
if bitcast and self.dtype.itemsize != dtype.itemsize:
if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
@@ -107,6 +104,9 @@ class LazyBuffer:
# https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
# TODO: applying this makes gpt2 slower
return self.base.cast(dtype, bitcast)._view(self.st)
cast_op = UnaryOps.BITCAST if bitcast else UnaryOps.CAST
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))