mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -8,11 +8,26 @@ class DTypeMixin:
|
||||
def cast(self, dtype:DType) -> Self: raise NotImplementedError
|
||||
|
||||
def element_size(self) -> int:
|
||||
"""Returns the number of bytes of a single element in the tensor."""
|
||||
"""
|
||||
Returns the size in bytes of an individual element in the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([5], dtype=dtypes.int16)
|
||||
print(t.element_size())
|
||||
```
|
||||
"""
|
||||
return self.dtype.itemsize
|
||||
|
||||
def is_floating_point(self) -> bool:
|
||||
"""Returns `True` if the tensor contains floating point types, i.e. is one of `bool`, `float16`, `bfloat16`, `float32`, `float64`."""
|
||||
"""
|
||||
Returns `True` if the tensor contains floating point types, i.e. is one of `dtypes.float64`, `dtypes.float32`,
|
||||
`dtypes.float16`, `dtypes.bfloat16`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([8, 9], dtype=dtypes.float32)
|
||||
print(t.is_floating_point())
|
||||
```
|
||||
"""
|
||||
return dtypes.is_float(self.dtype)
|
||||
|
||||
def float(self) -> Self:
|
||||
|
||||
@@ -195,10 +195,10 @@ class MathMixin(DTypeMixin):
|
||||
return self.mod(x, True)
|
||||
|
||||
def __lt__(self, x: Self | ConstType) -> Self:
|
||||
return self.alu(Ops.CMPLT, self.ufix(x))
|
||||
return self._binop(Ops.CMPLT, x, False)
|
||||
|
||||
def __gt__(self, x: Self | ConstType) -> Self:
|
||||
return self.ufix(x).alu(Ops.CMPLT, self)
|
||||
return self._binop(Ops.CMPLT, x, True)
|
||||
|
||||
def __ge__(self, x: Self | ConstType) -> Self:
|
||||
return (self < x).logical_not()
|
||||
@@ -207,7 +207,7 @@ class MathMixin(DTypeMixin):
|
||||
return (self > x).logical_not()
|
||||
|
||||
def ne(self, x: Self | ConstType) -> Self:
|
||||
return self.alu(Ops.CMPNE, self.ufix(x))
|
||||
return self._binop(Ops.CMPNE, x, False)
|
||||
|
||||
def eq(self, x: Self | ConstType) -> Self:
|
||||
return self.ne(x).logical_not()
|
||||
@@ -236,7 +236,17 @@ class MathMixin(DTypeMixin):
|
||||
return self.rshift(x, True)
|
||||
|
||||
def maximum(self, x: Self | ConstType) -> Self:
|
||||
return self.alu(Ops.MAX, self.ufix(x))
|
||||
"""
|
||||
Computes element-wise maximum of `self` and `x`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([-1, 2, 3]).maximum(1).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
|
||||
```
|
||||
"""
|
||||
return self._binop(Ops.MAX, x, False)
|
||||
|
||||
def minimum(self, x: Self | ConstType) -> Self:
|
||||
return -(-self).maximum(-self.ufix(x))
|
||||
|
||||
@@ -189,12 +189,10 @@ class Tensor(OpMixin):
|
||||
all_tensors[weakref.ref(ret)] = None
|
||||
return ret
|
||||
|
||||
def _apply_broadcasted_uop(self, fxn:Callable, x:Tensor|ConstType, reverse=False) -> Tensor:
|
||||
lhs,rhs = self._broadcasted(x, reverse)
|
||||
return lhs._apply_uop(fxn, rhs)
|
||||
|
||||
# _binop and alu are used by MathMixin
|
||||
def _binop(self, op, x, reverse): return self._apply_broadcasted_uop(lambda *u: UOp.alu(u[0], op, *u[1:]), x, reverse)
|
||||
def _binop(self, op, x, reverse):
|
||||
lhs,rhs = self._broadcasted(x, reverse)
|
||||
return lhs._apply_uop(lambda *u: u[0].alu(op, *u[1:]), rhs)
|
||||
def alu(self, op: Ops, *src: Tensor) -> Tensor: return self._apply_uop(lambda *u: u[0].alu(op, *u[1:]), *src)
|
||||
|
||||
def requires_grad_(self, requires_grad=True) -> Tensor:
|
||||
@@ -2822,7 +2820,7 @@ class Tensor(OpMixin):
|
||||
print(Tensor([False, True]).logical_not().numpy())
|
||||
```
|
||||
"""
|
||||
return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)
|
||||
return self.cast(dtypes.bool).ne(True)
|
||||
|
||||
def neg(self) -> Tensor:
|
||||
"""
|
||||
@@ -3197,7 +3195,7 @@ class Tensor(OpMixin):
|
||||
numerator, denominator = numerator.cast(dt), denominator.cast(dt)
|
||||
if rounding_mode == "trunc": return numerator.idiv(denominator)
|
||||
if rounding_mode == "floor":
|
||||
truncate_div, truncate_mod = numerator.idiv(denominator), numerator._apply_broadcasted_uop(UOp.mod, denominator)
|
||||
truncate_div, truncate_mod = numerator.idiv(denominator), numerator._binop(Ops.MOD, denominator, False)
|
||||
opposite_sign = ((numerator>0)&(denominator<0)) | ((numerator<0)&(denominator>0))
|
||||
return (opposite_sign&(truncate_mod!=0)).where(truncate_div-1, truncate_div)
|
||||
if rounding_mode == "trunc": return d.trunc().cast(output_dtype)
|
||||
@@ -3279,19 +3277,6 @@ class Tensor(OpMixin):
|
||||
# NOTE: pow(int, float) -> int
|
||||
return ret.round().cast(self.dtype) if not reverse and not dtypes.is_float(self.dtype) and dtypes.is_float(exponent.dtype) else ret
|
||||
|
||||
def maximum(self, x:Tensor|ConstType) -> Tensor:
|
||||
"""
|
||||
Computes element-wise maximum of `self` and `x`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([-1, 2, 3]).maximum(1).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
|
||||
```
|
||||
"""
|
||||
return self._apply_broadcasted_uop(UOp.maximum, x)
|
||||
|
||||
def minimum(self, x:Tensor|ConstType) -> Tensor:
|
||||
"""
|
||||
Computes element-wise minimum of `self` and `x`.
|
||||
@@ -3376,10 +3361,6 @@ class Tensor(OpMixin):
|
||||
def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x)) # type: ignore[misc]
|
||||
def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x)) # type: ignore[misc]
|
||||
|
||||
def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False)
|
||||
def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True)
|
||||
def ne(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.ne, x, False)
|
||||
|
||||
def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override]
|
||||
|
||||
# ***** encoding/decoding ops *****
|
||||
@@ -3743,17 +3724,6 @@ class Tensor(OpMixin):
|
||||
|
||||
# ***** Tensor Properties *****
|
||||
|
||||
def element_size(self) -> int:
|
||||
"""
|
||||
Returns the size in bytes of an individual element in the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([5], dtype=dtypes.int16)
|
||||
print(t.element_size())
|
||||
```
|
||||
"""
|
||||
return self.dtype.itemsize
|
||||
|
||||
def nbytes(self) -> int:
|
||||
"""
|
||||
Returns the total number of bytes of all elements in the tensor.
|
||||
@@ -3765,18 +3735,6 @@ class Tensor(OpMixin):
|
||||
"""
|
||||
return int(self.numel()) * self.element_size()
|
||||
|
||||
def is_floating_point(self) -> bool:
|
||||
"""
|
||||
Returns `True` if the tensor contains floating point types, i.e. is one of `dtypes.float64`, `dtypes.float32`,
|
||||
`dtypes.float16`, `dtypes.bfloat16`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t = Tensor([8, 9], dtype=dtypes.float32)
|
||||
print(t.is_floating_point())
|
||||
```
|
||||
"""
|
||||
return dtypes.is_float(self.dtype)
|
||||
|
||||
def size(self, dim:int|None=None) -> sint|tuple[sint, ...]:
|
||||
"""
|
||||
Returns the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.
|
||||
|
||||
Reference in New Issue
Block a user