diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 45dc28f06a..6b5d8720a7 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -241,7 +241,7 @@ jobs:
         python -m mypy --lineprecision-report .
         cat lineprecision.txt
     - name: Run TYPED=1
-      run: TYPED=1 python -c "import tinygrad"
+      run: CHECK_OOB=0 DEV=CPU TYPED=1 python test/test_tiny.py

   unittest:
     name: Unit Tests
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index a97f7b6487..8ddbe36f30 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -33,7 +33,8 @@ class InvalidType:
 Invalid = InvalidType()

-ConstType = float|int|bool
+PyConst = float|int|bool
+ConstType = PyConst|InvalidType

 FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd']
@@ -148,7 +149,7 @@ class dtypes:
     if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
     raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
   @staticmethod
-  def as_const(val: tuple[ConstType|InvalidType, ...]|ConstType|InvalidType, dtype:DType):
+  def as_const(val: tuple[ConstType, ...]|ConstType, dtype:DType):
     if isinstance(val, tuple):
       assert len(val) == dtype.count, f"mismatch {val} {dtype}"
       return tuple(dtypes.as_const(x, dtype) for x in val)
diff --git a/tinygrad/mixin/math.py b/tinygrad/mixin/math.py
index ef30d883d7..0f4af537a4 100644
--- a/tinygrad/mixin/math.py
+++ b/tinygrad/mixin/math.py
@@ -136,7 +136,7 @@ class MathMixin:
     return self._binop(Ops.MOD, x, reverse)

   def sub(self, x: Self | ConstType, reverse: bool = False):
-    return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
+    return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, -self.ufix(x))

   def div(self, x: Self | ConstType, reverse: bool = False):
     return (self.ufix(x) * self.alu(Ops.RECIPROCAL)) if reverse else (self * self.ufix(x).alu(Ops.RECIPROCAL))
@@ -243,7 +243,7 @@ class MathMixin:
     return self.alu(Ops.MAX, self.ufix(x))

   def minimum(self, x: Self | ConstType):
-    return -(-self).maximum(-x)
+    return -(-self).maximum(-self.ufix(x))

   def where(self, x: Self | ConstType, y: Self | ConstType):
     if isinstance(x, type(self)):
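Aside on the `sub`/`minimum` hunks above: moving the negation from the raw constant to the `ufix`-wrapped value means `-` is always applied to a graph node. Presumably this is what lets the new TYPED=1 mypy run pass, since `ConstType` now includes `InvalidType`, which has no `__neg__`. A minimal sketch with toy stand-ins, not tinygrad classes:

```python
from dataclasses import dataclass

@dataclass
class Node:
  val: float
  def __neg__(self) -> "Node":
    return Node(-self.val)

def ufix(x: "Node | float") -> Node:
  # wrap a plain Python scalar as a constant node; pass graph nodes through
  return x if isinstance(x, Node) else Node(x)

base = Node(10.0)
# old shape: base + Node(-x)   -> needs -x on the bare constant, which stops
#   type-checking once the constant union admits InvalidType (no __neg__)
# new shape: base + (-ufix(x)) -> the negation always lands on a Node
print(Node(base.val + (-ufix(3.0)).val))  # Node(val=7.0)
```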
diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py
index 8be30035ab..96da3a1dcc 100644
--- a/tinygrad/runtime/ops_amd.py
+++ b/tinygrad/runtime/ops_amd.py
@@ -593,7 +593,7 @@ class AMDProgram(HCQProgram):
                                                  base=self.lib_gpu.va_addr)
     weakref.finalize(self, self._fini, self.dev, self.lib_gpu, buf_spec)

-  def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
+  def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False):
     if self.dev.sqtt_enabled: cast(AMDComputeQueue, self.dev.hw_compute_queue_t()).sqtt_start(self.dev.sqtt_buffers).submit(self.dev)
     res = super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait)
     if self.dev.pmc_enabled:
diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py
index c0f2d2e8c3..3bf02c7c22 100644
--- a/tinygrad/runtime/ops_cpu.py
+++ b/tinygrad/runtime/ops_cpu.py
@@ -73,7 +73,7 @@ class CPUProgram(HCQProgram):
   try: rt_lib = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or WIN else 'libgcc_s.so.1')
   except OSError: pass

-  def __init__(self, dev, name:str, lib:bytes, runtimevars:dict[str, tuple[int, int]]|None=None, **kwargs):
+  def __init__(self, dev, name:str, lib:bytes, runtimevars:dict[str, int]|None=None, **kwargs):
     self.runtimevars = runtimevars or {}
     LVP = isinstance(dev.renderer, LVPRenderer)
diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py
index a971ee9720..69864f689e 100644
--- a/tinygrad/runtime/ops_nv.py
+++ b/tinygrad/runtime/ops_nv.py
@@ -307,7 +307,7 @@ class NVProgram(HCQProgram):
         yield typ, param, sh.content[start_off+4:start_off+sz+4] if typ == 0x4 else sz
         start_off += (sz if typ == 0x4 else 0) + 4

-  def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
+  def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False):
     if prod(local_size) > 1024 or self.max_threads < prod(local_size) or self.lcmem_usage > cast(NVDevice, self.dev).slm_per_thread:
       raise RuntimeError(f"Too many resources requested for launch, {prod(local_size)=}, {self.max_threads=}")
     if any(cur > mx for cur,mx in zip(global_size, [2147483647, 65535, 65535])) or any(cur > mx for cur,mx in zip(local_size, [1024, 1024, 64])):
diff --git a/tinygrad/runtime/ops_qcom.py b/tinygrad/runtime/ops_qcom.py
index 49ee54ae75..916a8f2546 100644
--- a/tinygrad/runtime/ops_qcom.py
+++ b/tinygrad/runtime/ops_qcom.py
@@ -257,7 +257,7 @@ class QCOMProgram(HCQProgram):
     super().__init__(QCOMArgsState, self.dev, self.name, kernargs_alloc_size=kernargs_alloc_size)
     weakref.finalize(self, self._fini, self.dev, self.lib_gpu, buf_spec)

-  def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
+  def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False):
     if self.max_threads < prod(local_size): raise RuntimeError("Too many resources requested for launch")
     if any(g*l>mx for g,l,mx in zip(global_size, local_size, [65536, 65536, 65536])) and any(l>mx for l,mx in zip(local_size, [1024, 1024, 1024])):
       raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}")
diff --git a/tinygrad/runtime/support/hcq.py b/tinygrad/runtime/support/hcq.py
index b941208f99..b99270fd09 100644
--- a/tinygrad/runtime/support/hcq.py
+++ b/tinygrad/runtime/support/hcq.py
@@ -283,20 +283,21 @@ def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Callable[[], HWQueue]
     if enabled and PROFILE: dev.sig_prof_records.append((unwrap(st), unwrap(en), desc, (queue_type or type(queue)) is dev.hw_copy_queue_t))

 class HCQArgsState(Generic[ProgramType]):
-  def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=()):
+  def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint|None, ...]=()):
     self.buf, self.prg, self.bufs, self.vals = buf, prg, bufs, vals
     self.bind_data:list[tuple[tuple[sint, ...], MMIOInterface, str]] = []

   def bind_sints_to_buf(self, *vals:sint, buf:HCQBuffer, fmt, offset=0): self.bind_data.append((vals, buf.cpu_view().view(offset=offset), fmt))

 class CLikeArgsState(HCQArgsState[ProgramType]):
-  def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=(), prefix:list[int]|None=None):
+  def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint|None, ...]=(), prefix:list[int]|None=None):
     super().__init__(buf, prg, bufs, vals=vals)

     if prefix is not None: self.buf.cpu_view().view(size=len(prefix) * 4, fmt='I')[:] = array.array('I', prefix)

     self.bind_sints_to_buf(*[b.va_addr for b in bufs], buf=self.buf, fmt='Q', offset=len(prefix or []) * 4)
-    self.bind_sints_to_buf(*vals, buf=self.buf, fmt='I', offset=len(prefix or []) * 4 + len(bufs) * 8)
+    assert None not in vals
+    self.bind_sints_to_buf(*cast(tuple[sint, ...], vals), buf=self.buf, fmt='I', offset=len(prefix or []) * 4 + len(bufs) * 8)

 class HCQProgram(Generic[HCQDeviceType]):
   def __init__(self, args_state_t:Type[HCQArgsState], dev:HCQDeviceType, name:str, kernargs_alloc_size:int, lib:bytes|None=None, base:int|None=None):
@@ -307,7 +308,7 @@ class HCQProgram(Generic[HCQDeviceType]):
   @staticmethod
   def _fini(dev, buf, spec): dev.allocator.free(buf, buf.size, spec)

-  def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=(), kernargs:HCQBuffer|None=None) -> HCQArgsState:
+  def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int|None, ...]=(), kernargs:HCQBuffer|None=None) -> HCQArgsState:
     """
     Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.

     Args:
@@ -322,7 +323,7 @@ class HCQProgram(Generic[HCQDeviceType]):
     return self.args_state_t(argsbuf, self, bufs, vals=vals)

   def __call__(self, *bufs:HCQBuffer, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
-               vals:tuple[int, ...]=(), wait:bool=False) -> float|None:
+               vals:tuple[int|None, ...]=(), wait:bool=False) -> float|None:
     """
     Enqueues the program for execution with the given arguments and dimensions.

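The `assert` plus `cast` pair added to `CLikeArgsState` is the standard idiom when a runtime check and the type checker have to be satisfied separately: `assert None not in vals` rejects stray `None`s at runtime, but mypy does not narrow `tuple[sint|None, ...]` from a containment test, so the `cast` does it for the checker. The same pattern in a self-contained form (`pack_vals` is a hypothetical helper, not tinygrad API):

```python
import array
from typing import cast

def pack_vals(vals: tuple[int | None, ...]) -> bytes:
  # runtime guard: a stray None would otherwise fail deep inside array.array
  assert None not in vals
  # mypy does not narrow the tuple's element type from the assert above,
  # so cast it explicitly, mirroring the patch
  return array.array('I', cast("tuple[int, ...]", vals)).tobytes()

print(pack_vals((16, 32, 48)).hex())  # 100000002000000030000000
```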
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index e23918abf9..aa86d54409 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -5,7 +5,7 @@ from contextlib import ContextDecorator
 from typing import Any, Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic, TYPE_CHECKING
 if TYPE_CHECKING: import numpy
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
-from tinygrad.dtype import _from_np_dtype, _to_np_dtype
+from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst
 from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
 from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
 from tinygrad.helpers import suppress_finalizing, disable_gc
@@ -112,7 +112,7 @@ class Tensor(OpMixin):
   __slots__ = "uop", "requires_grad", "grad"
   training: ClassVar[bool] = False

-  def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
+  def __init__(self, data:PyConst|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
                device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False):
     if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}"  # keep it on the disk if device is None
     _dtype:DType|None = to_dtype(dtype) if dtype is not None else None
@@ -138,7 +138,7 @@ class Tensor(OpMixin):
       data = data.replace(src=(var.replace(src=const.src), const))
     elif data is None: data = Tensor(0, device=_device, dtype=_dtype or dtypes.default_float, requires_grad=requires_grad).uop
-    elif isinstance(data, get_args(ConstType)):
+    elif isinstance(data, get_args(PyConst)):
       data = (UOp.unique_const if _force_unique or requires_grad else UOp.const)(_dtype or dtypes.from_py(data), data, _device)
     elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if _dtype is None else _dtype)
     elif isinstance(data, (list, tuple)):
@@ -321,7 +321,7 @@ class Tensor(OpMixin):
     assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
     return self._buffer().as_typed_buffer(self.shape)

-  def item(self) -> ConstType:
+  def item(self) -> PyConst:
     """
     Returns the value of this tensor as a standard Python number.

@@ -334,7 +334,7 @@ class Tensor(OpMixin):
     return self.data()[(0,) * len(self.shape)]

   # NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions)
-  def tolist(self) -> ConstType|list[Any]:
+  def tolist(self) -> PyConst|list[Any]:
     """
     Returns the value of this tensor as a nested list. Returns single value for const tensor.

@@ -618,7 +618,7 @@ class Tensor(OpMixin):
   # ***** creation helper functions *****

   @staticmethod
-  def full(shape:tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
+  def full(shape:tuple[sint, ...], fill_value:PyConst, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with the given value.

@@ -748,7 +748,7 @@ class Tensor(OpMixin):
       stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(sharded_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
       return Tensor(UOp.multi(stacked, axis=self.uop.axis), device=self.device, dtype=dtype)

-  def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
+  def full_like(self, fill_value:PyConst, **kwargs) -> Tensor:
     """
     Creates a tensor with the same shape as `self`, filled with the given value.
     If `dtype` is not specified, the dtype of `self` is used.
@@ -1268,12 +1268,12 @@ class Tensor(OpMixin):
     """
     return self._getitem(indices)

-  def __setitem__(self, indices, v:Tensor|ConstType) -> None:
+  def __setitem__(self, indices, v:Tensor|PyConst) -> None:
     if isinstance(self.device, str) and self.device.startswith("DISK"):
       self.realize()._getitem(indices).assign(v)
       return
     # NOTE: check that setitem target is valid first
-    if isinstance(v, get_args(ConstType)): v = Tensor(v, device=self.device, dtype=self.dtype)
+    if isinstance(v, get_args(PyConst)): v = Tensor(v, device=self.device, dtype=self.dtype)
     if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
     if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
     self.realize()
@@ -1544,7 +1544,7 @@ class Tensor(OpMixin):
                            for i, s in enumerate(self.shape)], dim=-1)
     return indices.masked_select(mask.unsqueeze(-1).expand(*mask.shape, self.ndim)).reshape(-1, self.ndim)

-  def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType) -> Tensor:
+  def masked_fill(self:Tensor, mask:Tensor, value:Tensor|PyConst) -> Tensor:
     """
     Replaces `self` with `value` wherever the elements of `mask` are True.

@@ -1864,7 +1864,7 @@ class Tensor(OpMixin):
     # https://keccak.team/keccak_specs_summary.html

-    def ctensor(l: Sequence[ConstType], dtype: DType = dtypes.uint64):
+    def ctensor(l: Sequence[PyConst], dtype: DType = dtypes.uint64):
       # TODO: contiguous is here for compile speed
       return Tensor.stack(*(Tensor(v, dtype=dtype, device=self.device) for v in l)).contiguous()
     rot_offsets = [44, 43, 21, 14, 28, 20, 3, 45, 61, 1, 6, 25, 8, 18, 27, 36, 10, 15, 56, 62, 55, 39, 41, 2]
@@ -2621,7 +2621,7 @@ class Tensor(OpMixin):
     src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
     return src, mask

-  def scatter(self, dim:int, index:Tensor, src:Tensor|ConstType, reduce:Literal['multiply', 'add']|None=None) -> Tensor:
+  def scatter(self, dim:int, index:Tensor, src:Tensor|PyConst, reduce:Literal['multiply', 'add']|None=None) -> Tensor:
     """
     Scatters `src` values along an axis specified by `dim`.
     Apply `add` or `multiply` reduction operation with `reduce`.
@@ -2686,7 +2686,7 @@ class Tensor(OpMixin):
     ```
     """
     src, mask = self._pre_scatter(dim, index, src)
-    def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any(-1).logical_not().where(a, b)
+    def _inv_mask(a:Tensor|PyConst, b:Tensor|PyConst) -> Tensor: return mask.any(-1).logical_not().where(a, b)
     if reduce == "sum": return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0))
     if reduce == "prod": return mask.where(src, 1).prod(-1).mul(self if include_self else _inv_mask(self, 1))
     if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))
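Throughout tensor.py the user-facing annotations move from `ConstType` to the new `PyConst`: tensors are built from and unpacked into plain Python scalars, while the `Invalid` sentinel stays internal to the uop graph. The alias also works at runtime, since `isinstance(data, get_args(PyConst))` is exactly how the constructor dispatches; a quick standalone check:

```python
# The PyConst alias as introduced in tinygrad/dtype.py by this patch,
# and the get_args/isinstance dispatch pattern tensor.py uses with it.
from typing import get_args

PyConst = float | int | bool

print(get_args(PyConst))                    # (<class 'float'>, <class 'int'>, <class 'bool'>)
print(isinstance(3.5, get_args(PyConst)))   # True: handled as a scalar const
print(isinstance(b"x", get_args(PyConst)))  # False: falls through to the bytes branch
```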
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 7d07ab762a..fca3ea2ad8 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -4,7 +4,7 @@ import sys, time, functools, itertools, math, operator, hashlib, os, types, pick
 from dataclasses import dataclass
 from enum import Enum, auto
 from tinygrad.uop import Ops, GroupOp
-from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace, ConstFloat
+from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, AddrSpace, ConstFloat, PyConst
 from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
 from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC
 from tinygrad.helpers import strip_parens, colored, ansilen, printable, panic
@@ -29,7 +29,7 @@ axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisTy
 range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}

 # https://en.wikipedia.org/wiki/Identity_element
-def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
+def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)

 # With True as the default, this matches the old symbolic behavior
 def resolve(x:UOp|bool, default:bool=True):
@@ -720,7 +720,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
       if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
       if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
     return None # generic None if we aren't sure
-  def pop_const(self, op=Ops.ADD) -> tuple[UOp, ConstType]:
+  def pop_const(self, op=Ops.ADD) -> tuple[UOp, PyConst]:
     # NOTE: assume Invalid ALU is resolved
     return (self.src[0], self.src[1].arg) if self.op is op and self.src[1].op is Ops.CONST else (self, identity_element(op, self.dtype))
   @staticmethod
@@ -740,11 +740,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   def sum(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.or_ if self.dtype is dtypes.bool else operator.add, uops, self)
   def prod(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.and_ if self.dtype is dtypes.bool else operator.mul, uops, self)
   @property
-  def vmin(self) -> ConstType: return self._min_max[0]
+  def vmin(self) -> PyConst: return self._min_max[0]
   @property
-  def vmax(self) -> ConstType: return self._min_max[1]
+  def vmax(self) -> PyConst: return self._min_max[1]
   @functools.cached_property
-  def _min_max(self) -> tuple[ConstType, ConstType]:
+  def _min_max(self) -> tuple[PyConst, PyConst]:
     if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
       (s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
       if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
@@ -943,7 +943,7 @@ class UPat(OpMixin):
   def cvar(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None, vec=True, arg=None):
     return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name, arg=arg)
   @staticmethod
-  def const(dtype:DType|tuple[DType, ...]|None, b:ConstType|InvalidType): return UPat(Ops.CONST, dtype=dtype, arg=b)
+  def const(dtype:DType|tuple[DType, ...]|None, b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)

   # lil helper
   def f(self, op, **kwargs): return UPat(op, src=(self,), **kwargs)
@@ -1451,4 +1451,4 @@ def pyrender(ast:UOp) -> str:
 sint = int|UOp
 Variable = UOp

-ConstLike = ConstType|InvalidType|Variable|tuple[ConstType|InvalidType, ...]
+ConstLike = ConstType|Variable|tuple[ConstType, ...]
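Taken together, the aliases now nest as follows: `PyConst` for plain Python scalars, `ConstType = PyConst|InvalidType` for graph constants, and `ConstLike` on top of that (plus `Variable`, omitted below). A sketch of the resulting hierarchy with a hypothetical checker, not a tinygrad function:

```python
# How the aliases compose after this patch; is_const_like is a
# hypothetical illustration, not tinygrad API.
from typing import get_args

class InvalidType: ...
Invalid = InvalidType()

PyConst = float | int | bool        # plain Python scalars (user-facing)
ConstType = PyConst | InvalidType   # graph constants may also be Invalid

def is_const_like(v) -> bool:
  # accept a ConstType scalar, or a tuple of them (a vectorized const)
  scalars = get_args(PyConst) + (InvalidType,)
  return isinstance(v, scalars) or (isinstance(v, tuple) and all(isinstance(x, scalars) for x in v))

print(is_const_like(1.0), is_const_like(Invalid), is_const_like((1, Invalid)))  # True True True
print(is_const_like("x"))  # False
```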