From c1e330f302c917b0c783e3c5238d008158f94c05 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 7 Jul 2024 11:52:58 -0400 Subject: [PATCH] Tensor.int and Tensor.bool (#5317) --- docs/tensor/ops.md | 2 ++ test/test_ops.py | 7 +++++++ tinygrad/tensor.py | 50 ++++++++++++++++++++++++++++++++++++++-------- 3 files changed, 51 insertions(+), 8 deletions(-) diff --git a/docs/tensor/ops.md b/docs/tensor/ops.md index c23a29727f..933600fc2b 100644 --- a/docs/tensor/ops.md +++ b/docs/tensor/ops.md @@ -106,3 +106,5 @@ ::: tinygrad.Tensor.bitcast ::: tinygrad.Tensor.float ::: tinygrad.Tensor.half +::: tinygrad.Tensor.int +::: tinygrad.Tensor.bool diff --git a/test/test_ops.py b/test/test_ops.py index 45c2fa4044..5dd7f1c114 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1866,6 +1866,13 @@ class TestOps(unittest.TestCase): helper_test_op([(32,10)], lambda x: x.masked_fill((x>0.1).detach(), -math.inf)) helper_test_op([(32,10)], lambda x: x.masked_fill((x<0.1).detach(), -math.inf)) + def test_cast(self): + helper_test_op([(3, 3)], lambda x: x.float()) + helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True) + helper_test_op(None, lambda x: x.float(), vals=[[True, False]], forward_only=True) + helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True) + helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True) + if __name__ == '__main__': np.random.seed(1337) unittest.main(verbosity=2) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 340f41f7ae..3fef38be65 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2837,12 +2837,23 @@ class Tensor: smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum() return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum() + # ***** convenience stuff ***** + + @property + def ndim(self) -> int: return len(self.shape) + def numel(self) -> sint: return prod(self.shape) + def element_size(self) -> int: return self.dtype.itemsize + def nbytes(self) -> int: return self.numel() * self.element_size() + def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype) + def size(self, dim=None) -> Union[sint, Tuple[sint, ...]]: return self.shape if dim is None else self.shape[dim] + # ***** cast ops ***** def llvm_bf16_cast(self, dtype:DType): # hack for devices that don't support bfloat16 assert self.dtype == dtypes.bfloat16 return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype) + def cast(self, dtype:DType) -> Tensor: """ Casts `self` to the given `dtype`. @@ -2857,6 +2868,7 @@ class Tensor: ``` """ return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype) + def bitcast(self, dtype:DType) -> Tensor: """ Bitcasts `self` to the given `dtype` of the same itemsize. @@ -2874,6 +2886,7 @@ class Tensor: """ if self.requires_grad: raise RuntimeError("can't backprop through bitcast") return F.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self + def float(self) -> Tensor: """ Convenience method to cast `self` to a `float32` Tensor. @@ -2888,6 +2901,7 @@ class Tensor: ``` """ return self.cast(dtypes.float32) + def half(self) -> Tensor: """ Convenience method to cast `self` to a `float16` Tensor. @@ -2903,15 +2917,35 @@ class Tensor: """ return self.cast(dtypes.float16) - # ***** convenience stuff ***** + def int(self) -> Tensor: + """ + Convenience method to cast `self` to a `int32` Tensor. - @property - def ndim(self) -> int: return len(self.shape) - def numel(self) -> sint: return prod(self.shape) - def element_size(self) -> int: return self.dtype.itemsize - def nbytes(self) -> int: return self.numel() * self.element_size() - def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype) - def size(self, dim=None) -> Union[sint, Tuple[sint, ...]]: return self.shape if dim is None else self.shape[dim] + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5]) + print(t.dtype, t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + t = t.int() + print(t.dtype, t.numpy()) + ``` + """ + return self.cast(dtypes.int32) + + def bool(self) -> Tensor: + """ + Convenience method to cast `self` to a `bool` Tensor. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([-1, 0, 1]) + print(t.dtype, t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + t = t.bool() + print(t.dtype, t.numpy()) + ``` + """ + return self.cast(dtypes.bool) # *** image Tensor function replacements ***