Tensor.int and Tensor.bool (#5317)

commit c1e330f302
parent 778d1cdbee
Author: chenyu
Date: 2024-07-07 11:52:58 -04:00 (committed by GitHub)
3 changed files with 51 additions and 8 deletions

@@ -106,3 +106,5 @@
::: tinygrad.Tensor.bitcast
::: tinygrad.Tensor.float
::: tinygrad.Tensor.half
::: tinygrad.Tensor.int
::: tinygrad.Tensor.bool

@@ -1866,6 +1866,13 @@ class TestOps(unittest.TestCase):
    helper_test_op([(32,10)], lambda x: x.masked_fill((x>0.1).detach(), -math.inf))
    helper_test_op([(32,10)], lambda x: x.masked_fill((x<0.1).detach(), -math.inf))

  def test_cast(self):
    helper_test_op([(3, 3)], lambda x: x.float())
    helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
    helper_test_op(None, lambda x: x.float(), vals=[[True, False]], forward_only=True)
    helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
    helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)

if __name__ == '__main__':
  np.random.seed(1337)
  unittest.main(verbosity=2)
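The new `int` and `bool` cases use `forward_only=True` because integer and boolean casts have no gradient, so only the forward pass is compared against the torch reference. A minimal sketch of what these tests exercise, assuming a tinygrad install:

```python
from tinygrad import Tensor

t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5])
print(t.int().numpy())   # C-style float->int32 cast truncates toward zero: [-1  0  0  0  1]
print(t.bool().numpy())  # nonzero -> True: [ True  True False  True  True]
```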

@@ -2837,12 +2837,23 @@ class Tensor:
    smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum()
    return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum()

  # ***** convenience stuff *****

  @property
  def ndim(self) -> int: return len(self.shape)
  def numel(self) -> sint: return prod(self.shape)
  def element_size(self) -> int: return self.dtype.itemsize
  def nbytes(self) -> int: return self.numel() * self.element_size()
  def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
  def size(self, dim=None) -> Union[sint, Tuple[sint, ...]]: return self.shape if dim is None else self.shape[dim]
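These one-liners derive everything from `shape` and `dtype`. A quick sketch of what they return, assuming a tinygrad install (a 2x3 float32 tensor has 6 elements of 4 bytes each):

```python
from tinygrad import Tensor

t = Tensor.ones(2, 3)         # default dtype is float32
print(t.ndim)                 # 2
print(t.numel())              # 6
print(t.element_size())       # 4 (bytes per float32 element)
print(t.nbytes())             # 24
print(t.is_floating_point())  # True
print(t.size(1))              # 3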
  # ***** cast ops *****

  def llvm_bf16_cast(self, dtype:DType):
    # hack for devices that don't support bfloat16
    assert self.dtype == dtypes.bfloat16
    return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
  def cast(self, dtype:DType) -> Tensor:
    """
    Casts `self` to the given `dtype`.
@@ -2857,6 +2868,7 @@ class Tensor:
    ```
    """
    return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype)
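A sketch of `cast` round-tripping values through a narrower dtype, assuming a tinygrad install; on typical backends the float->int conversion truncates toward zero, so the round trip is lossy:

```python
from tinygrad import Tensor, dtypes

t = Tensor([-1.5, 0.5, 2.5])
print(t.cast(dtypes.int32).numpy())                       # truncates toward zero: [-1  0  2]
print(t.cast(dtypes.int32).cast(dtypes.float32).numpy())  # fractional parts lost: [-1.  0.  2.]
```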
  def bitcast(self, dtype:DType) -> Tensor:
    """
    Bitcasts `self` to the given `dtype` of the same itemsize.
@@ -2874,6 +2886,7 @@ class Tensor:
    """
    if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
    return F.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
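Unlike `cast`, which converts the value, `bitcast` reinterprets the same bytes, which is why it refuses to participate in backprop. A short sketch, assuming a tinygrad install (float32 `1.0` has the IEEE 754 bit pattern `0x3F800000`):

```python
from tinygrad import Tensor, dtypes

t = Tensor([1.0])
print(hex(t.bitcast(dtypes.uint32).item()))  # 0x3f800000: same bytes, new dtype
print(t.cast(dtypes.uint32).item())          # 1: value-preserving conversion
```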
  def float(self) -> Tensor:
    """
    Convenience method to cast `self` to a `float32` Tensor.
@@ -2888,6 +2901,7 @@ class Tensor:
    ```
    """
    return self.cast(dtypes.float32)

  def half(self) -> Tensor:
    """
    Convenience method to cast `self` to a `float16` Tensor.
@@ -2903,15 +2917,35 @@ class Tensor:
    """
    return self.cast(dtypes.float16)
  def int(self) -> Tensor:
    """
    Convenience method to cast `self` to an `int32` Tensor.
    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5])
    print(t.dtype, t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    t = t.int()
    print(t.dtype, t.numpy())
    ```
    """
    return self.cast(dtypes.int32)
  def bool(self) -> Tensor:
    """
    Convenience method to cast `self` to a `bool` Tensor.
    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor([-1, 0, 1])
    print(t.dtype, t.numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    t = t.bool()
    print(t.dtype, t.numpy())
    ```
    """
    return self.cast(dtypes.bool)
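The new convenience casts compose naturally with comparisons, e.g. counting positive entries by going bool -> int -> sum. A sketch, assuming a tinygrad install:

```python
from tinygrad import Tensor

t = Tensor([-2.0, -0.5, 0.0, 1.0, 3.0])
mask = t > 0                    # comparison yields a bool Tensor
print(mask.int().sum().item())  # 2
```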
  # *** image Tensor function replacements ***