add InvalidType to ConstType [pr] (#14373)

* add InvalidType to ConstType [pr]

TYPED=1 python test/test_tiny.py passes.
Added PyConst = float|int|bool for some Tensor-level input types.

* hcq
chenyu committed 2026-01-27 14:09:34 -05:00 (committed by GitHub)
parent 5b42a1357b, commit cd22ee9ed0
10 changed files with 37 additions and 35 deletions
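
In short: tinygrad.dtype now defines PyConst for plain Python scalars and folds the Invalid sentinel into ConstType itself, so UOp-level signatures drop their explicit |InvalidType unions while Tensor-level inputs switch to PyConst. A minimal sketch of how the aliases relate (the InvalidType body is not shown in the hunks below, so the class definition here is only a placeholder):

# sketch of the aliases after this commit (placeholder InvalidType body)
class InvalidType:
    def __repr__(self): return "Invalid"

Invalid = InvalidType()            # singleton sentinel for "no valid constant"

PyConst = float|int|bool           # plain Python scalars: Tensor-level inputs, vmin/vmax, ...
ConstType = PyConst|InvalidType    # UOp-level constants may additionally be Invalid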


@@ -241,7 +241,7 @@ jobs:
python -m mypy --lineprecision-report .
cat lineprecision.txt
- name: Run TYPED=1
-run: TYPED=1 python -c "import tinygrad"
+run: CHECK_OOB=0 DEV=CPU TYPED=1 python test/test_tiny.py
unittest:
name: Unit Tests


@@ -33,7 +33,8 @@ class InvalidType:
Invalid = InvalidType()
-ConstType = float|int|bool
+PyConst = float|int|bool
+ConstType = PyConst|InvalidType
FmtStr = Literal['?', 'b', 'B', 'h', 'H', 'i', 'I', 'q', 'Q', 'e', 'f', 'd']
@@ -148,7 +149,7 @@ class dtypes:
if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
@staticmethod
-def as_const(val: tuple[ConstType|InvalidType, ...]|ConstType|InvalidType, dtype:DType):
+def as_const(val: tuple[ConstType, ...]|ConstType, dtype:DType):
if isinstance(val, tuple):
assert len(val) == dtype.count, f"mismatch {val} {dtype}"
return tuple(dtypes.as_const(x, dtype) for x in val)
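
A hedged note on the narrowed as_const annotation: the hunk changes only the signature, since ConstType now already admits Invalid, callers no longer spell out ConstType|InvalidType by hand. Illustrative only; dtypes.int32.vec(2) below is just an example vector dtype:

from tinygrad.dtype import dtypes, Invalid, ConstType

x: ConstType = 3.0       # a plain PyConst scalar
y: ConstType = Invalid   # the Invalid sentinel, now admitted by the alias itself

# vector constants still have to match dtype.count (see the assert in as_const above)
vec_val = dtypes.as_const((1, 2), dtypes.int32.vec(2))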


@@ -136,7 +136,7 @@ class MathMixin:
return self._binop(Ops.MOD, x, reverse)
def sub(self, x: Self | ConstType, reverse: bool = False):
-return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
+return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, -self.ufix(x))
def div(self, x: Self | ConstType, reverse: bool = False):
return (self.ufix(x) * self.alu(Ops.RECIPROCAL)) if reverse else (self * self.ufix(x).alu(Ops.RECIPROCAL))
@@ -243,7 +243,7 @@ class MathMixin:
return self.alu(Ops.MAX, self.ufix(x))
def minimum(self, x: Self | ConstType):
-return -(-self).maximum(-x)
+return -(-self).maximum(-self.ufix(x))
def where(self, x: Self | ConstType, y: Self | ConstType):
if isinstance(x, type(self)):
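
Why sub and minimum move the negation after ufix: with ConstType now including InvalidType, a bare -x on a ConstType-typed argument no longer type-checks under TYPED=1, so the constant is lifted to a UOp first and the UOp is negated. A standalone sketch of the typing constraint (not the actual MathMixin code; lift() is a stand-in for ufix):

class InvalidType: ...
Invalid = InvalidType()
ConstType = float|int|bool|InvalidType

def lift(x: ConstType) -> float:
    # stand-in for ufix: turn a Python constant into something negatable
    assert not isinstance(x, InvalidType)
    return float(x)

def sub_const(base: float, x: ConstType) -> float:
    # return base + (-x)      # rejected by mypy: unary "-" is not defined for InvalidType
    return base + (-lift(x))  # lift first, then negate: type-checks and is numerically the same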


@@ -593,7 +593,7 @@ class AMDProgram(HCQProgram):
base=self.lib_gpu.va_addr)
weakref.finalize(self, self._fini, self.dev, self.lib_gpu, buf_spec)
-def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
+def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False):
if self.dev.sqtt_enabled: cast(AMDComputeQueue, self.dev.hw_compute_queue_t()).sqtt_start(self.dev.sqtt_buffers).submit(self.dev)
res = super().__call__(*bufs, global_size=global_size, local_size=local_size, vals=vals, wait=wait)
if self.dev.pmc_enabled:


@@ -73,7 +73,7 @@ class CPUProgram(HCQProgram):
try: rt_lib = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32') if OSX or WIN else 'libgcc_s.so.1')
except OSError: pass
-def __init__(self, dev, name:str, lib:bytes, runtimevars:dict[str, tuple[int, int]]|None=None, **kwargs):
+def __init__(self, dev, name:str, lib:bytes, runtimevars:dict[str, int]|None=None, **kwargs):
self.runtimevars = runtimevars or {}
LVP = isinstance(dev.renderer, LVPRenderer)


@@ -307,7 +307,7 @@ class NVProgram(HCQProgram):
yield typ, param, sh.content[start_off+4:start_off+sz+4] if typ == 0x4 else sz
start_off += (sz if typ == 0x4 else 0) + 4
-def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
+def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False):
if prod(local_size) > 1024 or self.max_threads < prod(local_size) or self.lcmem_usage > cast(NVDevice, self.dev).slm_per_thread:
raise RuntimeError(f"Too many resources requested for launch, {prod(local_size)=}, {self.max_threads=}")
if any(cur > mx for cur,mx in zip(global_size, [2147483647, 65535, 65535])) or any(cur > mx for cur,mx in zip(local_size, [1024, 1024, 64])):


@@ -257,7 +257,7 @@ class QCOMProgram(HCQProgram):
super().__init__(QCOMArgsState, self.dev, self.name, kernargs_alloc_size=kernargs_alloc_size)
weakref.finalize(self, self._fini, self.dev, self.lib_gpu, buf_spec)
-def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int, ...]=(), wait=False):
+def __call__(self, *bufs, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1), vals:tuple[int|None, ...]=(), wait=False):
if self.max_threads < prod(local_size): raise RuntimeError("Too many resources requested for launch")
if any(g*l>mx for g,l,mx in zip(global_size, local_size, [65536, 65536, 65536])) and any(l>mx for l,mx in zip(local_size, [1024, 1024, 1024])):
raise RuntimeError(f"Invalid global/local dims {global_size=}, {local_size=}")


@@ -283,20 +283,21 @@ def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Callable[[], HWQueue]
if enabled and PROFILE: dev.sig_prof_records.append((unwrap(st), unwrap(en), desc, (queue_type or type(queue)) is dev.hw_copy_queue_t))
class HCQArgsState(Generic[ProgramType]):
-def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=()):
+def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint|None, ...]=()):
self.buf, self.prg, self.bufs, self.vals = buf, prg, bufs, vals
self.bind_data:list[tuple[tuple[sint, ...], MMIOInterface, str]] = []
def bind_sints_to_buf(self, *vals:sint, buf:HCQBuffer, fmt, offset=0): self.bind_data.append((vals, buf.cpu_view().view(offset=offset), fmt))
class CLikeArgsState(HCQArgsState[ProgramType]):
-def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=(), prefix:list[int]|None=None):
+def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint|None, ...]=(), prefix:list[int]|None=None):
super().__init__(buf, prg, bufs, vals=vals)
if prefix is not None: self.buf.cpu_view().view(size=len(prefix) * 4, fmt='I')[:] = array.array('I', prefix)
self.bind_sints_to_buf(*[b.va_addr for b in bufs], buf=self.buf, fmt='Q', offset=len(prefix or []) * 4)
-self.bind_sints_to_buf(*vals, buf=self.buf, fmt='I', offset=len(prefix or []) * 4 + len(bufs) * 8)
+assert None not in vals
+self.bind_sints_to_buf(*cast(tuple[sint, ...], vals), buf=self.buf, fmt='I', offset=len(prefix or []) * 4 + len(bufs) * 8)
class HCQProgram(Generic[HCQDeviceType]):
def __init__(self, args_state_t:Type[HCQArgsState], dev:HCQDeviceType, name:str, kernargs_alloc_size:int, lib:bytes|None=None, base:int|None=None):
@@ -307,7 +308,7 @@ class HCQProgram(Generic[HCQDeviceType]):
@staticmethod
def _fini(dev, buf, spec): dev.allocator.free(buf, buf.size, spec)
-def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=(), kernargs:HCQBuffer|None=None) -> HCQArgsState:
+def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int|None, ...]=(), kernargs:HCQBuffer|None=None) -> HCQArgsState:
"""
Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
Args:
@@ -322,7 +323,7 @@ class HCQProgram(Generic[HCQDeviceType]):
return self.args_state_t(argsbuf, self, bufs, vals=vals)
def __call__(self, *bufs:HCQBuffer, global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]=(1,1,1),
-vals:tuple[int, ...]=(), wait:bool=False) -> float|None:
+vals:tuple[int|None, ...]=(), wait:bool=False) -> float|None:
"""
Enqueues the program for execution with the given arguments and dimensions.
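
The hcq-side changes mirror the int|None widening in the runtime __call__ signatures above, and add an explicit guard where vals are actually packed into the kernarg buffer: "None not in vals" does not narrow the tuple type for mypy, hence the cast. A standalone sketch of that pattern (struct-based packing here is illustrative, not the actual MMIOInterface write):

import struct
from typing import cast

def pack_vals(vals: tuple[int|None, ...]) -> bytes:
    # widen the accepted tuple to allow None placeholders, but require concrete ints at pack time
    assert None not in vals, "all launch vals must be concrete before packing"
    concrete = cast(tuple[int, ...], vals)  # the assert alone does not narrow the type for mypy
    return struct.pack(f"{len(concrete)}I", *concrete)

# pack_vals((1, 2, 3)) packs three uint32 values; pack_vals((1, None)) trips the assert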


@@ -5,7 +5,7 @@ from contextlib import ContextDecorator
from typing import Any, Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic, TYPE_CHECKING
if TYPE_CHECKING: import numpy
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
-from tinygrad.dtype import _from_np_dtype, _to_np_dtype
+from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, is_numpy_ndarray, TracingKey, cpu_profile
from tinygrad.helpers import suppress_finalizing, disable_gc
@@ -112,7 +112,7 @@ class Tensor(OpMixin):
__slots__ = "uop", "requires_grad", "grad"
training: ClassVar[bool] = False
-def __init__(self, data:ConstType|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
+def __init__(self, data:PyConst|bytes|list|tuple|UOp|'numpy.ndarray'|pathlib.Path|None,
device:str|tuple|list|None=None, dtype:DTypeLike|None=None, requires_grad:bool|None=None, _force_unique:bool=False):
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
_dtype:DType|None = to_dtype(dtype) if dtype is not None else None
@@ -138,7 +138,7 @@ class Tensor(OpMixin):
data = data.replace(src=(var.replace(src=const.src), const))
elif data is None:
data = Tensor(0, device=_device, dtype=_dtype or dtypes.default_float, requires_grad=requires_grad).uop
-elif isinstance(data, get_args(ConstType)):
+elif isinstance(data, get_args(PyConst)):
data = (UOp.unique_const if _force_unique or requires_grad else UOp.const)(_dtype or dtypes.from_py(data), data, _device)
elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if _dtype is None else _dtype)
elif isinstance(data, (list, tuple)):
@@ -321,7 +321,7 @@ class Tensor(OpMixin):
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
return self._buffer().as_typed_buffer(self.shape)
-def item(self) -> ConstType:
+def item(self) -> PyConst:
"""
Returns the value of this tensor as a standard Python number.
@@ -334,7 +334,7 @@ class Tensor(OpMixin):
return self.data()[(0,) * len(self.shape)]
# NOTE: list[Any] because return type is recursive (list[list[...]] for higher dimensions)
-def tolist(self) -> ConstType|list[Any]:
+def tolist(self) -> PyConst|list[Any]:
"""
Returns the value of this tensor as a nested list.
Returns single value for const tensor.
@@ -618,7 +618,7 @@ class Tensor(OpMixin):
# ***** creation helper functions *****
@staticmethod
-def full(shape:tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
+def full(shape:tuple[sint, ...], fill_value:PyConst, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with the given value.
@@ -748,7 +748,7 @@ class Tensor(OpMixin):
stacked = UOp(Ops.MSTACK, dtype=dtype, src=tuple([fxn(sharded_shape, *args, device=d, dtype=dtype, **kwargs).uop for d in self.device]))
return Tensor(UOp.multi(stacked, axis=self.uop.axis), device=self.device, dtype=dtype)
-def full_like(self, fill_value:ConstType, **kwargs) -> Tensor:
+def full_like(self, fill_value:PyConst, **kwargs) -> Tensor:
"""
Creates a tensor with the same shape as `self`, filled with the given value.
If `dtype` is not specified, the dtype of `self` is used.
@@ -1268,12 +1268,12 @@ class Tensor(OpMixin):
"""
return self._getitem(indices)
-def __setitem__(self, indices, v:Tensor|ConstType) -> None:
+def __setitem__(self, indices, v:Tensor|PyConst) -> None:
if isinstance(self.device, str) and self.device.startswith("DISK"):
self.realize()._getitem(indices).assign(v)
return
# NOTE: check that setitem target is valid first
-if isinstance(v, get_args(ConstType)): v = Tensor(v, device=self.device, dtype=self.dtype)
+if isinstance(v, get_args(PyConst)): v = Tensor(v, device=self.device, dtype=self.dtype)
if not isinstance(v, Tensor): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
self.realize()
@@ -1544,7 +1544,7 @@ class Tensor(OpMixin):
for i, s in enumerate(self.shape)], dim=-1)
return indices.masked_select(mask.unsqueeze(-1).expand(*mask.shape, self.ndim)).reshape(-1, self.ndim)
-def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType) -> Tensor:
+def masked_fill(self:Tensor, mask:Tensor, value:Tensor|PyConst) -> Tensor:
"""
Replaces `self` with `value` wherever the elements of `mask` are True.
@@ -1864,7 +1864,7 @@ class Tensor(OpMixin):
# https://keccak.team/keccak_specs_summary.html
-def ctensor(l: Sequence[ConstType], dtype: DType = dtypes.uint64):
+def ctensor(l: Sequence[PyConst], dtype: DType = dtypes.uint64):
# TODO: contiguous is here for compile speed
return Tensor.stack(*(Tensor(v, dtype=dtype, device=self.device) for v in l)).contiguous()
rot_offsets = [44, 43, 21, 14, 28, 20, 3, 45, 61, 1, 6, 25, 8, 18, 27, 36, 10, 15, 56, 62, 55, 39, 41, 2]
@@ -2621,7 +2621,7 @@ class Tensor(OpMixin):
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
return src, mask
-def scatter(self, dim:int, index:Tensor, src:Tensor|ConstType, reduce:Literal['multiply', 'add']|None=None) -> Tensor:
+def scatter(self, dim:int, index:Tensor, src:Tensor|PyConst, reduce:Literal['multiply', 'add']|None=None) -> Tensor:
"""
Scatters `src` values along an axis specified by `dim`.
Apply `add` or `multiply` reduction operation with `reduce`.
@@ -2686,7 +2686,7 @@ class Tensor(OpMixin):
```
"""
src, mask = self._pre_scatter(dim, index, src)
-def _inv_mask(a:Tensor|ConstType, b:Tensor|ConstType) -> Tensor: return mask.any(-1).logical_not().where(a, b)
+def _inv_mask(a:Tensor|PyConst, b:Tensor|PyConst) -> Tensor: return mask.any(-1).logical_not().where(a, b)
if reduce == "sum": return mask.where(src, 0).sum(-1).add(self if include_self else _inv_mask(self, 0))
if reduce == "prod": return mask.where(src, 1).prod(-1).mul(self if include_self else _inv_mask(self, 1))
if reduce == "amax": return mask.where(src, m := dtypes.min(src.dtype)).max(-1).maximum(self if include_self else _inv_mask(self, m))


@@ -4,7 +4,7 @@ import sys, time, functools, itertools, math, operator, hashlib, os, types, pick
from dataclasses import dataclass
from enum import Enum, auto
from tinygrad.uop import Ops, GroupOp
-from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, InvalidType, AddrSpace, ConstFloat
+from tinygrad.dtype import ConstType, ImageDType, dtypes, DType, truncate, PtrDType, least_upper_dtype, Invalid, AddrSpace, ConstFloat, PyConst
from tinygrad.helpers import ContextVar, all_int, prod, getenv, all_same, Context, partition, temp, unwrap, T, argfix, Metadata, flatten, TRACEMETA
from tinygrad.helpers import PROFILE, dedup, cdiv, cmod, diskcache_put, to_function_name, cpu_profile, TracingKey, VIZ, SPEC
from tinygrad.helpers import strip_parens, colored, ansilen, printable, panic
@@ -29,7 +29,7 @@ axis_to_pos = {AxisType.LOOP: -1, AxisType.THREAD: 0, AxisType.GLOBAL: 0, AxisTy
range_start = {Ops.BUFFERIZE: 1, Ops.REDUCE: 1, Ops.STORE: 2, Ops.WMMA: 3, Ops.END: 1}
# https://en.wikipedia.org/wiki/Identity_element
-def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
+def identity_element(op:Ops, dt:DType) -> PyConst: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
# With True as the default, this matches the old symbolic behavior
def resolve(x:UOp|bool, default:bool=True):
@@ -720,7 +720,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1]
if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
return None # generic None if we aren't sure
-def pop_const(self, op=Ops.ADD) -> tuple[UOp, ConstType]:
+def pop_const(self, op=Ops.ADD) -> tuple[UOp, PyConst]: # NOTE: assume Invalid ALU is resolved
return (self.src[0], self.src[1].arg) if self.op is op and self.src[1].op is Ops.CONST else (self, identity_element(op, self.dtype))
@staticmethod
def gcd(*uops: UOp) -> UOp:
@@ -740,11 +740,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def sum(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.or_ if self.dtype is dtypes.bool else operator.add, uops, self)
def prod(self:UOp, *uops:UOp) -> UOp: return functools.reduce(operator.and_ if self.dtype is dtypes.bool else operator.mul, uops, self)
@property
-def vmin(self) -> ConstType: return self._min_max[0]
+def vmin(self) -> PyConst: return self._min_max[0]
@property
-def vmax(self) -> ConstType: return self._min_max[1]
+def vmax(self) -> PyConst: return self._min_max[1]
@functools.cached_property
-def _min_max(self) -> tuple[ConstType, ConstType]:
+def _min_max(self) -> tuple[PyConst, PyConst]:
if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype):
(s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
@@ -943,7 +943,7 @@ class UPat(OpMixin):
def cvar(name:str|None=None, dtype:DType|tuple[DType, ...]|None=None, vec=True, arg=None):
return UPat((Ops.CONST,Ops.VCONST) if vec else Ops.CONST, dtype, name=name, arg=arg)
@staticmethod
-def const(dtype:DType|tuple[DType, ...]|None, b:ConstType|InvalidType): return UPat(Ops.CONST, dtype=dtype, arg=b)
+def const(dtype:DType|tuple[DType, ...]|None, b:ConstType): return UPat(Ops.CONST, dtype=dtype, arg=b)
# lil helper
def f(self, op, **kwargs): return UPat(op, src=(self,), **kwargs)
@@ -1451,4 +1451,4 @@ def pyrender(ast:UOp) -> str:
sint = int|UOp
Variable = UOp
-ConstLike = ConstType|InvalidType|Variable|tuple[ConstType|InvalidType, ...]
+ConstLike = ConstType|Variable|tuple[ConstType, ...]
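
In the UOp module the same rule plays out in both directions: unions that spelled ConstType|InvalidType collapse to ConstType (UPat.const, ConstLike), while helpers whose results can never be Invalid (identity_element, pop_const, vmin/vmax, _min_max) tighten to PyConst. A hedged check of the PyConst-returning helpers, assuming this file is tinygrad/uop/ops.py and using the tinygrad.uop import shown above:

from tinygrad.dtype import dtypes
from tinygrad.uop import Ops
from tinygrad.uop.ops import UOp, identity_element

x = UOp.const(dtypes.int32, 5)
print(x.vmin, x.vmax)                            # 5 5, plain PyConst values
print(identity_element(Ops.ADD, dtypes.int32))   # 0
print(identity_element(Ops.MAX, dtypes.int32))   # dtypes.min(int32) = -2147483648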