Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-29 03:00:14 -04:00

Commit: fix assign

@@ -153,7 +153,7 @@ CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), Co
 ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0)
 FUSE_ATTENTION = ContextVar("FUSE_ATTENTION", 0)
 EMULATE = ContextVar("EMULATE", "")
-CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(aff(0)) if (aff:=getattr(os, "sched_getaffinity", None)) else (os.cpu_count() or 1)))
+CPU_COUNT = ContextVar("CPU_COUNT", max(1, len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else (os.cpu_count() or 1)))
 CPU_LLVM, AMD_LLVM = ContextVar("CPU_LLVM", 0), ContextVar("AMD_LLVM", 1)
 VIZ = PROFILE = ContextVar("VIZ", 0)
 SPEC = ContextVar("SPEC", 0)
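
Note: a minimal standalone sketch of the new CPU_COUNT default above, using only the standard library; the helper name is illustrative, not tinygrad API.

import os

# Prefer the process affinity mask where the platform provides it (e.g. Linux);
# otherwise fall back to os.cpu_count(), and never report fewer than 1 CPU.
def default_cpu_count() -> int:  # hypothetical helper, for illustration only
  if hasattr(os, "sched_getaffinity"):
    return max(1, len(os.sched_getaffinity(0)))
  return max(1, os.cpu_count() or 1)

print(default_cpu_count())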

@@ -9,8 +9,8 @@ from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_u
 from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap, DEBUG, is_numpy_ndarray, FUSE_ATTENTION
 from tinygrad.helpers import suppress_finalizing
 from tinygrad.gradient import compute_gradient
-from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, MathTrait, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, \
-  srender
+from tinygrad.uop.mathtraits import MathTrait
+from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, srender
 from tinygrad.uop.spec import tensor_uop_spec, type_verify
 from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule

@@ -2685,7 +2685,8 @@ class Tensor(MathTrait):
     base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
     base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
     def fix(x: Tensor) -> Tensor: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
-    return {Ops.ADD: Tensor.__add__, Ops.MAX: Tensor.maximum, Ops.MUL: Tensor.__mul__}[op](fix(ret), fix(base))
+    reduce_fxns: dict[Ops, Callable[[Tensor, Tensor], Tensor]] = {Ops.ADD: Tensor.__add__, Ops.MAX: Tensor.maximum, Ops.MUL: Tensor.__mul__}
+    return reduce_fxns[op](fix(ret), fix(base))
 
   def cumsum(self, axis:int=0) -> Tensor:
     """
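
Note: the hunk above replaces an inline dict lookup with a named, explicitly typed operator table. A hedged sketch of the same pattern, with plain ints standing in for Tensors and string keys standing in for Ops:

from typing import Callable

# An explicit Callable value type lets a type checker verify the dispatch;
# the keys and int operands here are illustrative only.
reduce_fxns: dict[str, Callable[[int, int], int]] = {
  "add": lambda a, b: a + b,
  "max": max,
  "mul": lambda a, b: a * b,
}
print(reduce_fxns["add"](2, 3))  # prints 5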

@@ -3852,18 +3853,20 @@ class Tensor(MathTrait):
   def __rpow__(self, x) -> Tensor: return self.pow(x, True)
   def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True)
 
-  def __iadd__(self, x) -> Tensor: return self.assign(self.add(x))
-  def __isub__(self, x) -> Tensor: return self.assign(self.sub(x))
-  def __imul__(self, x) -> Tensor: return self.assign(self.mul(x))
-  def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
-  def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x))
   def __ifloordiv__(self, x) -> Tensor: return self.assign(self.__floordiv__(x))
+  def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x))
   def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
-  def __iand__(self, x) -> Tensor: return self.assign(self.bitwise_and(x))
-  def __ior__(self, x) -> Tensor: return self.assign(self.bitwise_or(x))
-  def __ixor__(self, x) -> Tensor: return self.assign(self.bitwise_xor(x))
-  def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
-  def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
+
+  # unlike Tensors, UOps are immutable, so these don't go in MathTraits
+  def __iadd__(self, x) -> Tensor: return self.assign(self.add(x)) # type: ignore[misc]
+  def __isub__(self, x) -> Tensor: return self.assign(self.sub(x)) # type: ignore[misc]
+  def __imul__(self, x) -> Tensor: return self.assign(self.mul(x)) # type: ignore[misc]
+  def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x)) # type: ignore[misc]
+  def __iand__(self, x) -> Tensor: return self.assign(self.bitwise_and(x)) # type: ignore[misc]
+  def __ior__(self, x) -> Tensor: return self.assign(self.bitwise_or(x)) # type: ignore[misc]
+  def __ixor__(self, x) -> Tensor: return self.assign(self.bitwise_xor(x)) # type: ignore[misc]
+  def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x)) # type: ignore[misc]
+  def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x)) # type: ignore[misc]
 
   def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False)
   def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True)
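
Note: every in-place operator above funnels through Tensor.assign, so t += 1 computes a new value and writes it back rather than mutating storage directly. A toy sketch of that shape, with Box as a stand-in for Tensor:

# Box stands in for Tensor; assign copies the freshly computed value into self.
class Box:
  def __init__(self, v): self.v = v
  def add(self, x): return Box(self.v + x)
  def assign(self, other):
    self.v = other.v
    return self
  def __iadd__(self, x): return self.assign(self.add(x))

b = Box(3)
b += 4        # runs b.assign(b.add(4))
print(b.v)    # prints 7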

@@ -145,8 +145,7 @@ class MathTrait:
 
   def ne(self:TMT, x:TMT|ConstType): return self.alu(Ops.CMPNE, self.ufix(x))
   def eq(self:TMT, x:TMT|ConstType): return self.ne(x).logical_not()
-  # TODO: make typing of __ne__ work
-  def __ne__(self, x): return self.ne(x)
+  def __ne__(self:TMT, x:TMT|ConstType): return self.ne(x) # type: ignore[override]
   # NOTE: __eq__ isn't overridden, and means the same thing as is by default
 
   def lshift(self:TMT, x:TMT|int, reverse:bool=False): return self._binop(Ops.SHL, x, reverse)
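
Note: object.__ne__ is typed as returning bool, while a symbolic type returns a new node, so the annotated override above needs the type: ignore. A small self-contained sketch of the same situation; Expr is illustrative, not tinygrad's class:

class Expr:
  def __init__(self, name: str): self.name = name
  def ne(self, x: "Expr") -> "Expr": return Expr(f"({self.name} != {x.name})")
  # returning Expr instead of bool is what mypy flags without the ignore
  def __ne__(self, x): return self.ne(x)  # type: ignore[override]

print((Expr("a") != Expr("b")).name)  # prints (a != b)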

@@ -34,8 +34,8 @@ def resolve(x:UOp|bool, default:bool=True):
 def _suop(lst, uop_fxn, python_fxn):
   uops, nums = partition(lst, lambda x: isinstance(x, UOp))
   return ssimplify(functools.reduce(uop_fxn, uops + ([python_fxn(nums)] if nums else [])))
-def smax(*lst:sint): return _suop(argfix(*lst), UOp.maximum, max)
-def smin(*lst:sint): return _suop(argfix(*lst), UOp.minimum, min)
+def smax(*lst) -> sint: return _suop(argfix(*lst), UOp.maximum, max)
+def smin(*lst) -> sint: return _suop(argfix(*lst), UOp.minimum, min)
 def srender(x:sint) -> str: return x.render() if isinstance(x, UOp) else str(x)
 
 def ssimplify(uop:sint): return uop.ssimplify() if isinstance(uop, UOp) else uop
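
Note: smax/smin fold a mix of symbolic UOps and plain ints: _suop partitions the arguments, collapses the plain numbers with Python's max/min, then reduces what remains symbolically. A self-contained sketch of that shape, with strings playing the symbolic role and a local partition (tinygrad's lives in tinygrad.helpers):

import functools

def partition(lst, pred):
  a, b = [], []
  for x in lst: (a if pred(x) else b).append(x)
  return a, b

# strings stand in for symbolic UOps; the real code reduces with UOp.maximum
def smax_sketch(*lst):
  syms, nums = partition(lst, lambda x: isinstance(x, str))
  return functools.reduce(lambda a, b: f"max({a},{b})", syms + ([max(nums)] if nums else []))

print(smax_sketch("N", 3, 7))  # prints max(N,7)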