delete unused cast/bitcast lines from ops.py [pr] (#8651)

* move cast and bitcast out

* more deletion of bitcast arg

* fix test_bitcast_fuses

* update tests

* work
This commit is contained in:
qazal
2025-01-17 03:04:18 -05:00
committed by GitHub
parent 4f0d1b4759
commit 2b7db9b45d
7 changed files with 14 additions and 16 deletions

View File

@@ -1041,5 +1041,9 @@ class TestTensorOps(unittest.TestCase):
def test_interpolate(self):
helper_test_shard_op([(4,16,16),(4,24,24)], lambda x: Tensor.interpolate(x, (19,19)))
@unittest.expectedFailure # 'MultiLazyBuffer' object has no attribute 'bitcast'
def test_bitcast(self):
helper_test_shard_op([(256,), (256,)], lambda x: x.bitcast(dtypes.int))
if __name__ == '__main__':
unittest.main()

View File

@@ -1344,8 +1344,8 @@ class TestSchedule(unittest.TestCase):
def test_bitcast_fuses(self):
x = cast(UOp, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
a = x.alu(Ops.EXP2).cast(dtypes.int32, True)
b = x.cast(dtypes.int32, True)
a = x.alu(Ops.EXP2).bitcast(dtypes.int32)
b = x.bitcast(dtypes.int32)
b = a.alu(Ops.ADD, b)
check_schedule(b, 1) # this should fuse when it makes sense

View File

@@ -68,8 +68,8 @@ class TestRawDiskBuffer(unittest.TestCase):
_test_bitcasted(t, dtypes.float32, 3.1415927)
_test_bitcasted(t, dtypes.uint32, 0x40490FDB)
# doesn't support normal cast
with self.assertRaises(RuntimeError):
Tensor.empty((4,), dtype=dtypes.int16, device=f"disk:{tmp}").cast(dtypes.float16)
with self.assertRaises(NotImplementedError):
Tensor.empty((4,), dtype=dtypes.int16, device=f"disk:{tmp}").cast(dtypes.float16).realize()
# Those two should be moved to test_dtype.py:test_shape_change_bitcast after bitcast works on non-disk
with self.assertRaises(RuntimeError):

View File

@@ -13,7 +13,7 @@ from tinygrad.device import Buffer
# creation can recurse a lot
sys.setrecursionlimit(10000)
# **** big graph spec
# **** Tensor UOp spec
tensor_uop_spec = PatternMatcher([
(UPat(Ops.DEVICE, dtypes.void, (), name="device"), lambda device: isinstance(device.arg, str)),

View File

@@ -76,7 +76,7 @@ class MultiLazyBuffer(MathTrait):
# passthroughs
@property
def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs)
def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
def cast(self, dtype:DType): return MultiLazyBuffer([x.cast(dtype) for x in self.lbs], self.axis, self.real)
def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)

View File

@@ -358,14 +358,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
assert self.dtype.count == 1
if count == 1: return self
return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
def cast(self, dtype:DType, bitcast=False):
if bitcast: return self.bitcast(dtype)
if self._device is not None and self._device.startswith("DISK"): raise RuntimeError("CAST isn't supported on DISK")
return UOp(Ops.CAST, dtype, (self,))
def bitcast(self, dtype:DType):
if self.st is not None and self.shape and ((self.shape[-1]*self.dtype.itemsize)%dtype.itemsize != 0):
raise RuntimeError(f"unsupported size in bitcast {dtype}")
return UOp(Ops.BITCAST, dtype, (self,))
def cast(self, dtype:DType): return UOp(Ops.CAST, dtype, (self,))
def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
def gep(self, i:Union[tuple[int, ...], int]):
if isinstance(i, int):
# NOTE: these are just shortcuts to not have to create and fold later

View File

@@ -3813,8 +3813,8 @@ class Tensor(SimpleMathTrait):
"""
if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
dt = to_dtype(dtype)
if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and (ns:=dt.itemsize) != (os:=self.dtype.itemsize):
if (self.shape[-1]*os) % ns != 0: raise RuntimeError("unsupported size in bitcast")
if (ns:=dt.itemsize) != (os:=self.dtype.itemsize) and (self.shape[-1]*os) % ns != 0: raise RuntimeError("unsupported size in bitcast")
if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and ns != os:
new_uint, old_uint = to_dtype(f"uint{8*ns}"), to_dtype(f"uint{8*os}")
tmp = self.bitcast(old_uint)
if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)