more mixins pt 2 (#14765)

* more mixins pt 2

* lil cleanups
This commit is contained in:
George Hotz
2026-02-15 17:57:04 +08:00
committed by GitHub
parent 9da7f5e733
commit 713143a46a
3 changed files with 36 additions and 53 deletions

View File

@@ -8,11 +8,26 @@ class DTypeMixin:
def cast(self, dtype:DType) -> Self:
  """Cast to `dtype`. Not implemented on the mixin; the concrete class must provide it."""
  raise NotImplementedError
def element_size(self) -> int:
  """
  Returns the size in bytes of an individual element in the tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([5], dtype=dtypes.int16)
  print(t.element_size())
  ```
  """
  return self.dtype.itemsize
def is_floating_point(self) -> bool:
  """
  Returns `True` if the tensor contains floating point types, i.e. is one of `dtypes.float64`, `dtypes.float32`,
  `dtypes.float16`, `dtypes.bfloat16`.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([8, 9], dtype=dtypes.float32)
  print(t.is_floating_point())
  ```
  """
  return dtypes.is_float(self.dtype)
def float(self) -> Self:

View File

@@ -195,10 +195,10 @@ class MathMixin(DTypeMixin):
return self.mod(x, True)
def __lt__(self, x: Self | ConstType) -> Self:
  # Elementwise `self < x`; `_binop` broadcasts `x` against `self` (reverse=False keeps operand order).
  return self._binop(Ops.CMPLT, x, False)
def __gt__(self, x: Self | ConstType) -> Self:
  # Elementwise `self > x`, computed as `x < self`: reverse=True swaps the operands inside `_binop`.
  return self._binop(Ops.CMPLT, x, True)
def __ge__(self, x: Self | ConstType) -> Self:
  # `self >= x` is the logical negation of `self < x`.
  lt = self < x
  return lt.logical_not()
@@ -207,7 +207,7 @@ class MathMixin(DTypeMixin):
return (self > x).logical_not()
def ne(self, x: Self | ConstType) -> Self:
  # Elementwise not-equal against `x` (broadcast via `_binop`).
  return self._binop(Ops.CMPNE, x, False)
def eq(self, x: Self | ConstType) -> Self:
  # Elementwise equality, as the negation of `ne`.
  neq = self.ne(x)
  return neq.logical_not()
@@ -236,7 +236,17 @@ class MathMixin(DTypeMixin):
return self.rshift(x, True)
def maximum(self, x: Self | ConstType) -> Self:
  """
  Computes element-wise maximum of `self` and `x`.
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1, 2, 3]).maximum(1).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
  ```
  """
  return self._binop(Ops.MAX, x, False)
def minimum(self, x: Self | ConstType) -> Self:
  # Elementwise minimum via the identity min(a, b) == -max(-a, -b).
  negated_max = (-self).maximum(-self.ufix(x))
  return -negated_max

View File

@@ -189,12 +189,10 @@ class Tensor(OpMixin):
all_tensors[weakref.ref(ret)] = None
return ret
def _apply_broadcasted_uop(self, fxn:Callable, x:Tensor|ConstType, reverse=False) -> Tensor:
  """Broadcasts `self` and `x` together (operand order swapped if `reverse`), then applies `fxn` to the pair."""
  a, b = self._broadcasted(x, reverse)
  return a._apply_uop(fxn, b)
# _binop and alu are used by MathMixin
def _binop(self, op, x, reverse):
  """Broadcasts `x` against `self` (operands swapped if `reverse`), then applies ALU op `op`. Used by MathMixin."""
  lhs,rhs = self._broadcasted(x, reverse)
  return lhs._apply_uop(lambda *u: u[0].alu(op, *u[1:]), rhs)
def alu(self, op: Ops, *src: Tensor) -> Tensor:
  # Applies ALU op `op` across `self` and `src`; unlike `_binop`, no broadcasting happens here.
  return self._apply_uop(lambda *operands: operands[0].alu(op, *operands[1:]), *src)
def requires_grad_(self, requires_grad=True) -> Tensor:
@@ -2822,7 +2820,7 @@ class Tensor(OpMixin):
print(Tensor([False, True]).logical_not().numpy())
```
"""
return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)
return self.cast(dtypes.bool).ne(True)
def neg(self) -> Tensor:
"""
@@ -3197,7 +3195,7 @@ class Tensor(OpMixin):
numerator, denominator = numerator.cast(dt), denominator.cast(dt)
if rounding_mode == "trunc": return numerator.idiv(denominator)
if rounding_mode == "floor":
truncate_div, truncate_mod = numerator.idiv(denominator), numerator._apply_broadcasted_uop(UOp.mod, denominator)
truncate_div, truncate_mod = numerator.idiv(denominator), numerator._binop(Ops.MOD, denominator, False)
opposite_sign = ((numerator>0)&(denominator<0)) | ((numerator<0)&(denominator>0))
return (opposite_sign&(truncate_mod!=0)).where(truncate_div-1, truncate_div)
if rounding_mode == "trunc": return d.trunc().cast(output_dtype)
@@ -3279,19 +3277,6 @@ class Tensor(OpMixin):
# NOTE: pow(int, float) -> int
return ret.round().cast(self.dtype) if not reverse and not dtypes.is_float(self.dtype) and dtypes.is_float(exponent.dtype) else ret
def maximum(self, x:Tensor|ConstType) -> Tensor:
  """
  Returns the element-wise maximum of `self` and `x`.
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1, 2, 3]).maximum(1).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(Tensor([-1, 2, 3]).maximum(Tensor([-4, -2, 9])).numpy())
  ```
  """
  out = self._apply_broadcasted_uop(UOp.maximum, x)
  return out
def minimum(self, x:Tensor|ConstType) -> Tensor:
"""
Computes element-wise minimum of `self` and `x`.
@@ -3376,10 +3361,6 @@ class Tensor(OpMixin):
def __ilshift__(self, x) -> Tensor: # type: ignore[misc]
  # in-place `<<=`: assigns the left-shift result back into this tensor
  return self.assign(self.lshift(x))
def __irshift__(self, x) -> Tensor: # type: ignore[misc]
  # in-place `>>=`: assigns the right-shift result back into this tensor
  return self.assign(self.rshift(x))
def __lt__(self, x) -> Tensor:
  # elementwise strict less-than; broadcasts `x` against `self`
  return self._apply_broadcasted_uop(UOp.__lt__, x, False)
def __gt__(self, x) -> Tensor:
  # elementwise greater-than: reverse=True swaps the operands into `x < self`
  return self._apply_broadcasted_uop(UOp.__lt__, x, True)
def ne(self, x) -> Tensor:
  # elementwise not-equal; broadcasts `x` against `self`
  return self._apply_broadcasted_uop(UOp.ne, x, False)
def __eq__(self, x) -> Tensor: # type: ignore[override]
  # elementwise equality; delegates to `eq`
  return self.eq(x)
# ***** encoding/decoding ops *****
@@ -3743,17 +3724,6 @@ class Tensor(OpMixin):
# ***** Tensor Properties *****
def element_size(self) -> int:
  """
  Returns the number of bytes occupied by a single element of this tensor.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([5], dtype=dtypes.int16)
  print(t.element_size())
  ```
  """
  return self.dtype.itemsize
def nbytes(self) -> int:
"""
Returns the total number of bytes of all elements in the tensor.
@@ -3765,18 +3735,6 @@ class Tensor(OpMixin):
"""
return int(self.numel()) * self.element_size()
def is_floating_point(self) -> bool:
  """
  Returns `True` if the tensor's dtype is a floating point type, i.e. one of `dtypes.float64`, `dtypes.float32`,
  `dtypes.float16`, `dtypes.bfloat16`.
  ```python exec="true" source="above" session="tensor" result="python"
  t = Tensor([8, 9], dtype=dtypes.float32)
  print(t.is_floating_point())
  ```
  """
  return dtypes.is_float(self.dtype)
def size(self, dim:int|None=None) -> sint|tuple[sint, ...]:
"""
Returns the size of the tensor. If `dim` is specified, return the length along dimension `dim`. Otherwise return the shape of the tensor.