mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
Tensor.int and Tensor.bool (#5317)
This commit is contained in:
@@ -106,3 +106,5 @@
|
||||
::: tinygrad.Tensor.bitcast
|
||||
::: tinygrad.Tensor.float
|
||||
::: tinygrad.Tensor.half
|
||||
::: tinygrad.Tensor.int
|
||||
::: tinygrad.Tensor.bool
|
||||
|
||||
@@ -1866,6 +1866,13 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(32,10)], lambda x: x.masked_fill((x>0.1).detach(), -math.inf))
|
||||
helper_test_op([(32,10)], lambda x: x.masked_fill((x<0.1).detach(), -math.inf))
|
||||
|
||||
def test_cast(self):
|
||||
helper_test_op([(3, 3)], lambda x: x.float())
|
||||
helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
|
||||
helper_test_op(None, lambda x: x.float(), vals=[[True, False]], forward_only=True)
|
||||
helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
|
||||
helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
np.random.seed(1337)
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -2837,12 +2837,23 @@ class Tensor:
|
||||
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask).sum()
|
||||
return -((1 - label_smoothing) * (log_probs * y).sum() + smoothing) / loss_mask.sum()
|
||||
|
||||
# ***** convenience stuff *****
|
||||
|
||||
@property
|
||||
def ndim(self) -> int: return len(self.shape)
|
||||
def numel(self) -> sint: return prod(self.shape)
|
||||
def element_size(self) -> int: return self.dtype.itemsize
|
||||
def nbytes(self) -> int: return self.numel() * self.element_size()
|
||||
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
|
||||
def size(self, dim=None) -> Union[sint, Tuple[sint, ...]]: return self.shape if dim is None else self.shape[dim]
|
||||
|
||||
# ***** cast ops *****
|
||||
|
||||
def llvm_bf16_cast(self, dtype:DType):
|
||||
# hack for devices that don't support bfloat16
|
||||
assert self.dtype == dtypes.bfloat16
|
||||
return self.to("LLVM").bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype)
|
||||
|
||||
def cast(self, dtype:DType) -> Tensor:
|
||||
"""
|
||||
Casts `self` to the given `dtype`.
|
||||
@@ -2857,6 +2868,7 @@ class Tensor:
|
||||
```
|
||||
"""
|
||||
return self if self.dtype == dtype else F.Cast.apply(self, dtype=dtype)
|
||||
|
||||
def bitcast(self, dtype:DType) -> Tensor:
|
||||
"""
|
||||
Bitcasts `self` to the given `dtype` of the same itemsize.
|
||||
@@ -2874,6 +2886,7 @@ class Tensor:
|
||||
"""
|
||||
if self.requires_grad: raise RuntimeError("can't backprop through bitcast")
|
||||
return F.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self
|
||||
|
||||
def float(self) -> Tensor:
|
||||
"""
|
||||
Convenience method to cast `self` to a `float32` Tensor.
|
||||
@@ -2888,6 +2901,7 @@ class Tensor:
|
||||
```
|
||||
"""
|
||||
return self.cast(dtypes.float32)
|
||||
|
||||
def half(self) -> Tensor:
|
||||
"""
|
||||
Convenience method to cast `self` to a `float16` Tensor.
|
||||
@@ -2903,15 +2917,35 @@ class Tensor:
|
||||
"""
|
||||
return self.cast(dtypes.float16)
|
||||
|
||||
# ***** convenience stuff *****
|
||||
def int(self) -> Tensor:
|
||||
"""
|
||||
Convenience method to cast `self` to a `int32` Tensor.
|
||||
|
||||
@property
|
||||
def ndim(self) -> int: return len(self.shape)
|
||||
def numel(self) -> sint: return prod(self.shape)
|
||||
def element_size(self) -> int: return self.dtype.itemsize
|
||||
def nbytes(self) -> int: return self.numel() * self.element_size()
|
||||
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
|
||||
def size(self, dim=None) -> Union[sint, Tuple[sint, ...]]: return self.shape if dim is None else self.shape[dim]
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([-1.5, -0.5, 0.0, 0.5, 1.5])
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = t.int()
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
"""
|
||||
return self.cast(dtypes.int32)
|
||||
|
||||
def bool(self) -> Tensor:
|
||||
"""
|
||||
Convenience method to cast `self` to a `bool` Tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([-1, 0, 1])
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = t.bool()
|
||||
print(t.dtype, t.numpy())
|
||||
```
|
||||
"""
|
||||
return self.cast(dtypes.bool)
|
||||
|
||||
# *** image Tensor function replacements ***
|
||||
|
||||
|
||||
Reference in New Issue
Block a user