mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
delete unused cast/bitcast lines from ops.py [pr] (#8651)
* move cast and bitcast out
* more deletion of bitcast arg
* fix test_bitcast_fuses
* update tests
* work
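A minimal sketch of the distinction this PR keeps cleaning up, using only the public Tensor API: cast converts values, bitcast reinterprets the same bytes at the same size.

from tinygrad import Tensor, dtypes

x = Tensor([1.0], dtype=dtypes.float32)
print(x.cast(dtypes.int32).item())     # 1          (value conversion)
print(x.bitcast(dtypes.int32).item())  # 1065353216 (0x3F800000, the bit pattern of 1.0f)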
@@ -1041,5 +1041,9 @@ class TestTensorOps(unittest.TestCase):
   def test_interpolate(self):
     helper_test_shard_op([(4,16,16),(4,24,24)], lambda x: Tensor.interpolate(x, (19,19)))
 
+  @unittest.expectedFailure # 'MultiLazyBuffer' object has no attribute 'bitcast'
+  def test_bitcast(self):
+    helper_test_shard_op([(256,), (256,)], lambda x: x.bitcast(dtypes.int))
+
 if __name__ == '__main__':
   unittest.main()
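The new test is marked expectedFailure because the sharded path has no bitcast yet. A rough sketch of what it exercises, with hypothetical device names (helper_test_shard_op wraps roughly this pattern):

from tinygrad import Tensor, dtypes

t = Tensor.rand(256).shard(("CPU:0", "CPU:1"), axis=0)  # sharding routes ops through MultiLazyBuffer
out = t.bitcast(dtypes.int)  # currently fails: 'MultiLazyBuffer' object has no attribute 'bitcast'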
@@ -1344,8 +1344,8 @@ class TestSchedule(unittest.TestCase):
 
   def test_bitcast_fuses(self):
     x = cast(UOp, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata)
-    a = x.alu(Ops.EXP2).cast(dtypes.int32, True)
-    b = x.cast(dtypes.int32, True)
+    a = x.alu(Ops.EXP2).bitcast(dtypes.int32)
+    b = x.bitcast(dtypes.int32)
     b = a.alu(Ops.ADD, b)
     check_schedule(b, 1) # this should fuse when it makes sense
 
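A Tensor-level analogue of test_bitcast_fuses, assuming the public Tensor.schedule() API; it builds the same graph and counts the scheduled kernels the UOp-level test asserts on:

from tinygrad import Tensor, dtypes

x = Tensor.empty(1, dtype=dtypes.float32).realize()
out = x.exp2().bitcast(dtypes.int32) + x.bitcast(dtypes.int32)
print(len(out.schedule()))  # the UOp-level check above expects a single fused kernel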
@@ -68,8 +68,8 @@ class TestRawDiskBuffer(unittest.TestCase):
     _test_bitcasted(t, dtypes.float32, 3.1415927)
     _test_bitcasted(t, dtypes.uint32, 0x40490FDB)
     # doesn't support normal cast
-    with self.assertRaises(RuntimeError):
-      Tensor.empty((4,), dtype=dtypes.int16, device=f"disk:{tmp}").cast(dtypes.float16)
+    with self.assertRaises(NotImplementedError):
+      Tensor.empty((4,), dtype=dtypes.int16, device=f"disk:{tmp}").cast(dtypes.float16).realize()
 
     # Those two should be moved to test_dtype.py:test_shape_change_bitcast after bitcast works on non-disk
     with self.assertRaises(RuntimeError):
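Sketch of the behavior this hunk encodes, assuming a throwaway temp file for the disk device: bitcast is fine on a disk tensor, while a value-converting cast now only fails once it is realized (NotImplementedError) instead of raising eagerly at graph build:

import tempfile
from tinygrad import Tensor, dtypes

with tempfile.NamedTemporaryFile() as f:
  f.write(bytes(8)); f.flush()
  t = Tensor.empty((4,), dtype=dtypes.int16, device=f"disk:{f.name}")
  t.bitcast(dtypes.uint16)      # same itemsize, just a reinterpretation
  bad = t.cast(dtypes.float16)  # no longer raises here...
  # bad.realize()               # ...this is where NotImplementedError would surface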
@@ -13,7 +13,7 @@ from tinygrad.device import Buffer
 # creation can recurse a lot
 sys.setrecursionlimit(10000)
 
-# **** big graph spec
+# **** Tensor UOp spec
 
 tensor_uop_spec = PatternMatcher([
   (UPat(Ops.DEVICE, dtypes.void, (), name="device"), lambda device: isinstance(device.arg, str)),
@@ -76,7 +76,7 @@ class MultiLazyBuffer(MathTrait):
   # passthroughs
   @property
   def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs)
-  def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
+  def cast(self, dtype:DType): return MultiLazyBuffer([x.cast(dtype) for x in self.lbs], self.axis, self.real)
   def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
   def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
   def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
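A hypothetical one-liner in the style of the passthroughs above; it is NOT part of this PR, and whether the underlying buffers expose bitcast is an assumption here. Its absence is exactly what the expectedFailure in test_bitcast documents:

  def bitcast(self, dtype:DType): return MultiLazyBuffer([x.bitcast(dtype) for x in self.lbs], self.axis, self.real)  # hypothetical, not in this PR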
@@ -358,14 +358,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     assert self.dtype.count == 1
     if count == 1: return self
     return UOp(Ops.VECTORIZE, self.dtype.vec(count), (self,)*count)
-  def cast(self, dtype:DType, bitcast=False):
-    if bitcast: return self.bitcast(dtype)
-    if self._device is not None and self._device.startswith("DISK"): raise RuntimeError("CAST isn't supported on DISK")
-    return UOp(Ops.CAST, dtype, (self,))
-  def bitcast(self, dtype:DType):
-    if self.st is not None and self.shape and ((self.shape[-1]*self.dtype.itemsize)%dtype.itemsize != 0):
-      raise RuntimeError(f"unsupported size in bitcast {dtype}")
-    return UOp(Ops.BITCAST, dtype, (self,))
+  def cast(self, dtype:DType): return UOp(Ops.CAST, dtype, (self,))
+  def bitcast(self, dtype:DType): return UOp(Ops.BITCAST, dtype, (self,))
   def gep(self, i:Union[tuple[int, ...], int]):
     if isinstance(i, int):
       # NOTE: these are just shortcuts to not have to create and fold later
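The validation that used to live in UOp.bitcast moves out to the Tensor level (next hunk). The rule itself: a bitcast is legal only when the last axis holds a whole number of new-dtype elements. A small worked check, with bitcast_ok as an illustrative helper rather than a library function:

def bitcast_ok(last_dim:int, old_itemsize:int, new_itemsize:int) -> bool:
  # mirrors (shape[-1]*old_itemsize) % new_itemsize == 0 from the deleted UOp.bitcast check
  return (last_dim * old_itemsize) % new_itemsize == 0

assert bitcast_ok(256, 4, 8)    # (256,) float32 -> int64: 1024 bytes, a whole number of 8-byte elements
assert not bitcast_ok(3, 4, 8)  # (3,) float32 -> int64: 12 bytes does not split into 8-byte elements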
@@ -3813,8 +3813,8 @@ class Tensor(SimpleMathTrait):
     """
     if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
     dt = to_dtype(dtype)
-    if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and (ns:=dt.itemsize) != (os:=self.dtype.itemsize):
-      if (self.shape[-1]*os) % ns != 0: raise RuntimeError("unsupported size in bitcast")
+    if (ns:=dt.itemsize) != (os:=self.dtype.itemsize) and (self.shape[-1]*os) % ns != 0: raise RuntimeError("unsupported size in bitcast")
+    if (not isinstance(self.device, str) or not self.device.startswith("DISK")) and ns != os:
       new_uint, old_uint = to_dtype(f"uint{8*ns}"), to_dtype(f"uint{8*os}")
       tmp = self.bitcast(old_uint)
       if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)
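A minimal arithmetic sketch of the widening branch above (ns > os): each of the ns//os narrow lanes is shifted by 8*i*os bits and the lanes are summed, which is little-endian byte packing. pack_le_u8_to_u32 is an illustrative stand-in, not tinygrad code:

def pack_le_u8_to_u32(lanes):
  # lane i lands at byte offset i, matching the << 8*i*os shift with os == 1
  return sum(b << (8 * i) for i, b in enumerate(lanes))

assert pack_le_u8_to_u32([0xDB, 0x0F, 0x49, 0x40]) == 0x40490FDB  # the float32 bit pattern of pi used in the disk test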