fix assign

This commit is contained in:
George Hotz
2025-10-14 11:15:12 +08:00
parent 1ecb99480e
commit 147fd0e2c6
4 changed files with 20 additions and 18 deletions

View File

@@ -153,7 +153,7 @@ CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), Co
ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0)
FUSE_ATTENTION = ContextVar("FUSE_ATTENTION", 0)
EMULATE = ContextVar("EMULATE", "")
CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(aff(0)) if (aff:=getattr(os, "sched_getaffinity", None)) else (os.cpu_count() or 1)))
CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1)))
CPU_LLVM, AMD_LLVM = ContextVar("CPU_LLVM", 0), ContextVar("AMD_LLVM", 1)
VIZ = PROFILE = ContextVar("VIZ", 0)
SPEC = ContextVar("SPEC", 0)

View File

@@ -9,8 +9,8 @@ from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_u
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, FUSE_ATTENTION
from tinygrad.helpers import suppress_finalizing
from tinygrad.gradient import compute_gradient
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, \
srender
from tinygrad.uop.mathtraits import MathTrait
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, srender
from tinygrad.uop.spec import tensor_uop_spec, type_verify
from tinygrad.device import Device, Buffer
from tinygrad.engine.realize import run_schedule
@@ -2685,7 +2685,8 @@ class Tensor(MathTrait):
base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
def fix(x: Tensor) -> Tensor: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
return {Ops.ADD: Tensor.__add__, Ops.MAX: Tensor.maximum, Ops.MUL: Tensor.__mul__}[op](fix(ret), fix(base))
reduce_fxns: dict[Ops, Callable[[Tensor, Tensor], Tensor]] = {Ops.ADD: Tensor.__add__, Ops.MAX: Tensor.maximum, Ops.MUL: Tensor.__mul__}
return reduce_fxns[op](fix(ret), fix(base))
def cumsum(self, axis:int=0) -> Tensor:
"""
@@ -3852,18 +3853,20 @@ class Tensor(MathTrait):
def __rpow__(self, x) -> Tensor: return self.pow(x, True)
def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
def __imul__(self, x) -> Tensor: return self.assign(self.mul(x))
def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
def __ifloordiv__(self, x) -> Tensor: return self.assign(self.__floordiv__(x))
def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
def __iand__(self, x) -> Tensor: return self.assign(self.bitwise_and(x))
def __ior__(self, x) -> Tensor: return self.assign(self.bitwise_or(x))
def __ixor__(self, x) -> Tensor: return self.assign(self.bitwise_xor(x))
def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
# unlike Tensors, UOps are immutable, so these don't go in MathTraits
def __iadd__(self, x) -> Tensor: return self.assign(self.add(x)) # type: ignore[misc]
def __isub__(self, x) -> Tensor: return self.assign(self.sub(x)) # type: ignore[misc]
def __imul__(self, x) -> Tensor: return self.assign(self.mul(x)) # type: ignore[misc]
def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x)) # type: ignore[misc]
def __iand__(self, x) -> Tensor: return self.assign(self.bitwise_and(x)) # type: ignore[misc]
def __ior__(self, x) -> Tensor: return self.assign(self.bitwise_or(x)) # type: ignore[misc]
def __ixor__(self, x) -> Tensor: return self.assign(self.bitwise_xor(x)) # type: ignore[misc]
def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x)) # type: ignore[misc]
def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x)) # type: ignore[misc]
def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False)
def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True)

View File

@@ -145,8 +145,7 @@ class MathTrait:
def ne(self:TMT, x:TMT|ConstType): return self.alu(Ops.CMPNE, self.ufix(x))
def eq(self:TMT, x:TMT|ConstType): return self.ne(x).logical_not()
# TODO: make typing of __ne__ work
def __ne__(self, x): return self.ne(x)
def __ne__(self:TMT, x:TMT|ConstType): return self.ne(x) # type: ignore[override]
# NOTE: __eq__ isn't overridden, and means the same thing as is by default
def lshift(self:TMT, x:TMT|int, reverse:bool=False): return self._binop(Ops.SHL, x, reverse)

View File

@@ -34,8 +34,8 @@ def resolve(x:UOp|bool, default:bool=True):
def _suop(lst, uop_fxn, python_fxn):
uops, nums = partition(lst, lambda x: isinstance(x, UOp))
return ssimplify(functools.reduce(uop_fxn, uops + ([python_fxn(nums)] if nums else [])))
def smax(*lst:sint): return _suop(argfix(*lst), UOp.maximum, max)
def smin(*lst:sint): return _suop(argfix(*lst), UOp.minimum, min)
def smax(*lst) -> sint: return _suop(argfix(*lst), UOp.maximum, max)
def smin(*lst) -> sint: return _suop(argfix(*lst), UOp.minimum, min)
def srender(x:sint) -> str: return x.render() if isinstance(x, UOp) else str(x)
def ssimplify(uop:sint): return uop.ssimplify() if isinstance(uop, UOp) else uop