mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
don't cast before view on shape changing bitcast (#4833)
* don't cast before view on shape changing bitcast
* make sure cast before view triggers
This commit is contained in:
@@ -284,6 +284,12 @@ class TestDiskTensor(unittest.TestCase):
|
||||
ret = t.to("CLANG").bitcast(dtypes.uint16) + 1
|
||||
assert ret.tolist() == [2827, 3341, 3855, 4369, 4883]
|
||||
|
||||
def test_bitcast_view(self):
|
||||
with open(temp('range_1020'), "wb") as f: f.write(bytes(range(10, 24)))
|
||||
t = Tensor.empty(3, dtype=dtypes.uint, device=f"disk:{temp('range_1020')}").shrink([(0, 2)])
|
||||
ret = t.bitcast(dtypes.uint16).to("CLANG") + 1
|
||||
assert ret.tolist() == [2827, 3341, 3855, 4369]
|
||||
|
||||
def test_bf16_disk_write_read(self):
|
||||
t = Tensor([10000, -1, -1000, -10000, 20], dtype=dtypes.float32)
|
||||
t.to(f"disk:{temp('f32')}").realize()
|
||||
|
||||
@@ -97,9 +97,6 @@ class LazyBuffer:
|
||||
if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)")
|
||||
if self.is_unrealized_unmasked_const() and not bitcast:
|
||||
return create_lazybuffer(self.device, self.st, dtype, LoadOps.CONST, dtypes.as_const(self.base.arg, dtype))
|
||||
# TODO: applying this makes gpt2 slower
|
||||
if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
|
||||
return self.base.cast(dtype, bitcast)._view(self.st)
|
||||
new_shape = self.shape
|
||||
if bitcast and self.dtype.itemsize != dtype.itemsize:
|
||||
if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now")
|
||||
@@ -107,6 +104,9 @@ class LazyBuffer:
|
||||
# https://pytorch.org/docs/stable/generated/torch.Tensor.view.html
|
||||
if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
|
||||
new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
|
||||
elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self != self.base:
|
||||
# TODO: applying this makes gpt2 slower
|
||||
return self.base.cast(dtype, bitcast)._view(self.st)
|
||||
cast_op = UnaryOps.BITCAST if bitcast else UnaryOps.CAST
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user