changes for TYPED=1

George Hotz
2025-12-20 04:35:44 +00:00
parent 86cd1e9e81
commit ca9d05efb7
10 changed files with 35 additions and 35 deletions
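Taken together, the diffs below widen or loosen annotations and relax test expectations so the code base still passes when its type hints are enforced at runtime. Going by the `TypeCheckError` note in the TestTinygrad hunk, TYPED=1 most likely switches on a typeguard-style checker that validates arguments against annotations at call time; the env-var gate and the `full` helper in this sketch are illustrative assumptions, not tinygrad's actual wiring:

# minimal sketch of the TYPED=1 idea, assuming a typeguard-style checker
import os
os.environ["TYPED"] = "1"          # pretend the user ran with TYPED=1
from typeguard import typechecked, TypeCheckError

def maybe_typechecked(fn):
  # hypothetical gate: only enforce annotations when TYPED=1 is set
  return typechecked(fn) if os.getenv("TYPED", "0") == "1" else fn

@maybe_typechecked
def full(shape: tuple[int, ...] | int, fill_value: float) -> list[float]:
  # annotation widened to accept a bare int, mirroring the Tensor.full change below
  if isinstance(shape, int): shape = (shape,)
  n = 1
  for d in shape: n *= d
  return [fill_value] * n

print(full(3, 0.5))                # ok: a bare int is now allowed by the annotation
try:
  full([2, 3], 0.5)                # a list matches neither tuple[int, ...] nor int
except TypeCheckError as e:
  # the checker raises before the body runs, so callers see a TypeCheckError
  # rather than the RuntimeError/ValueError the tests used to expect
  print("rejected:", e)

The sketch shows both halves of the pattern in this commit: annotations widened to match real call sites, and error expectations broadened because the checker's exception is not the one the op itself would raise.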

View File

@@ -267,7 +267,7 @@ class TestOps(unittest.TestCase):
for tor_i, ten_i in zip(tor, ten):
helper_test_op([], lambda: tor_i, lambda: ten_i)
-self.helper_test_exception([], lambda: torch.meshgrid(x, indexing="bad"), lambda: xt.meshgrid(indexing="bad"), expected=RuntimeError)
+self.helper_test_exception([], lambda: torch.meshgrid(x, indexing="bad"), lambda: xt.meshgrid(indexing="bad"), expected=Exception)
def test_arange(self):
helper_test_op([], lambda: torch.arange(10, dtype=torch.int32), lambda: Tensor.arange(10), forward_only=True)
@@ -587,7 +587,7 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x,y: x.div(y, rounding_mode="trunc"), forward_only=True, vals=[[numerator], [denominator]])
helper_test_op(None, lambda x,y: x.div(y, rounding_mode="floor"), forward_only=True, vals=[[numerator], [denominator]])
-self.helper_test_exception(None, lambda x,y: x.div(y, rounding_mode="typo"), forward_only=True, vals=[[5], [0]], expected=RuntimeError)
+self.helper_test_exception(None, lambda x,y: x.div(y, rounding_mode="typo"), forward_only=True, vals=[[5], [0]], expected=Exception)
def test_div_int(self):
helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=[[5, 6, 7],[1, 2, 3]])
@@ -2989,7 +2989,7 @@ class TestOps(unittest.TestCase):
self.helper_test_exception([(4,5,6), (4,5,6)],
lambda x,src: x.scatter_reduce(dim=0, index=b, src=src, reduce="INVALID"),
lambda x,src: x.scatter_reduce(dim=0, index=a, src=src, reduce="INVALID"),
-RuntimeError)
+Exception)
# dtype mismatch
self.helper_test_exception([(4,5,6), (4,5,6)],
lambda x,src: x.half().scatter_reduce(dim=0, index=b, src=src, reduce="sum"),
@@ -3068,7 +3068,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction=r),
lambda x,y: x.cross_entropy(y, reduction=r))
self.helper_test_exception([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction="typo"),
-lambda x,y: x.cross_entropy(y, reduction="typo"), expected=ValueError)
+lambda x,y: x.cross_entropy(y, reduction="typo"), expected=Exception)
def test_cross_entropy_smoothing(self):
for ls in (0., 0.3, 0.7, 1.):
@@ -3131,7 +3131,7 @@ class TestOps(unittest.TestCase):
lambda x: x.log_softmax(axis=1).nll_loss(Tensor(target), reduction=r))
self.helper_test_exception([(32,10)],
lambda x: torch.nn.functional.nll_loss(x, torch.tensor(target), reduction="typo"),
-lambda x: x.nll_loss(Tensor(target), reduction="typo"), expected=ValueError)
+lambda x: x.nll_loss(Tensor(target), reduction="typo"), expected=Exception)
def test_nll_loss_weight(self):
target = np.random.randint(0, 10, (32,), dtype=np.int32).tolist()
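The common thread in the TestOps hunks above: with runtime checking on, an invalid argument can be rejected by the checker (as a `TypeCheckError`) before the op's own `RuntimeError`/`ValueError` ever fires, so the expectations fall back to the shared base class. A small illustration of that test-side pattern, with a hypothetical `bad_call` standing in for the real op:

import unittest
from typeguard import typechecked

@typechecked
def bad_call(rounding_mode: str) -> None:
  # the op's own validation, reached only if the annotation check passes first
  if rounding_mode not in ("trunc", "floor"): raise RuntimeError(f"unknown rounding_mode {rounding_mode!r}")

class TestWidenedExpectation(unittest.TestCase):
  def test_expect_exception(self):
    # RuntimeError and TypeCheckError are both Exceptions, so this assertion holds
    # whether the checker intercepts the call or the op raises its own error
    with self.assertRaises(Exception): bad_call("typo")
    with self.assertRaises(Exception): bad_call(123)

if __name__ == "__main__":
  unittest.main()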

View File

@@ -456,7 +456,7 @@ class TestTinygrad(unittest.TestCase):
def test_tensor_dtype_errors(self):
with self.assertRaises(AttributeError): Tensor([3], dtype="typo")
-with self.assertRaises(AttributeError): Tensor([3], dtype=(dtypes.int,))
+with self.assertRaises(Exception): Tensor([3], dtype=(dtypes.int,)) # AttributeError or TypeCheckError with TYPED=1
def test_tensor_bytes(self):
data = b"abc123"

View File

@@ -340,7 +340,8 @@ class Compiled:
# override this in your device implementation
# TODO: move this to each Device
-def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
+def is_dtype_supported(dtype:DType|None, device:str|None=None) -> bool:
+if dtype is None: return True
if dtype == dtypes.index: return False
if device is None: device = Device.DEFAULT
if dtype == dtypes.bfloat16:

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
-from typing import Final, ClassVar, Callable, Literal
+from typing import Final, ClassVar, Callable, Literal, TYPE_CHECKING, Any
+if TYPE_CHECKING: import numpy as np
import math, struct, ctypes, functools
from dataclasses import dataclass, fields
from tinygrad.helpers import getenv, prod
@@ -123,7 +124,7 @@ class dtypes:
if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
@staticmethod
-def as_const(val: tuple[ConstType|InvalidType, ...]|ConstType|InvalidType, dtype:DType):
+def as_const(val:Any, dtype:DType):
if isinstance(val, tuple):
assert len(val) == dtype.count, f"mismatch {val} {dtype}"
return tuple(dtypes.as_const(x, dtype) for x in val)

View File

@@ -13,7 +13,7 @@ from tinygrad.codegen.opt import Opt
# **************** Program Creation ****************
@track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
-def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> ProgramSpec:
+def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|tuple[Opt, ...]|None=None) -> ProgramSpec:
"""
Transform an AST into a ProgramSpec. May trigger BEAM search.

View File

@@ -249,7 +249,7 @@ class LayerNorm:
print(t.mean().item(), t.std().item())
```
"""
-def __init__(self, normalized_shape:int|tuple[int, ...], eps:float=1e-5, elementwise_affine:bool=True):
+def __init__(self, normalized_shape:int|tuple[int, ...]|list[int], eps:float=1e-5, elementwise_affine:bool=True):
self.normalized_shape: tuple[int, ...] = make_tuple(normalized_shape, 1)
self.axis, self.eps = tuple(-1-i for i in range(len(self.normalized_shape))), eps
self.weight: Tensor|None = Tensor.ones(*self.normalized_shape) if elementwise_affine else None

View File

@@ -84,7 +84,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=No
# state dict
-def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
+def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Any]:
"""
Returns a `state_dict` of the object, with optional prefix.
@@ -203,7 +203,7 @@ def tar_extract(t: Tensor) -> dict[str, Tensor]:
# TODO: this should use tar_extract and zip_extract
@accept_filename
-def torch_load(t:Tensor) -> dict[str, Tensor]:
+def torch_load(t:Tensor) -> dict[str, Any]:
"""
```python
torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
-from typing import cast, Callable, Type, TypeVar, Generic, Any
+from typing import cast, Callable, Type, TypeVar, Generic, Any, Sequence
import contextlib, decimal, statistics, time, ctypes, array, os, struct, collections, functools
try: import fcntl # windows misses that
except ImportError: fcntl = None #type:ignore[assignment]
@@ -279,14 +279,14 @@ def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Callable[[], HWQueue]
if enabled and PROFILE: dev.sig_prof_records.append((unwrap(st), unwrap(en), desc, (queue_type or type(queue)) is dev.hw_copy_queue_t))
class HCQArgsState(Generic[ProgramType]):
-def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=()):
+def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:Sequence[HCQBuffer], vals:Sequence[sint]=()):
self.buf, self.prg, self.bufs, self.vals = buf, prg, bufs, vals
self.bind_data:list[tuple[tuple[sint, ...], MMIOInterface, str]] = []
def bind_sints_to_buf(self, *vals:sint, buf:HCQBuffer, fmt, offset=0): self.bind_data.append((vals, buf.cpu_view().view(offset=offset), fmt))
class CLikeArgsState(HCQArgsState[ProgramType]):
-def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=(), prefix:list[int]|None=None):
+def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:Sequence[HCQBuffer], vals:Sequence[sint]=(), prefix:list[int]|None=None):
super().__init__(buf, prg, bufs, vals=vals)
if prefix is not None: self.buf.cpu_view().view(size=len(prefix) * 4, fmt='I')[:] = array.array('I', prefix)
@@ -302,7 +302,7 @@ class HCQProgram(Generic[HCQDeviceType]):
@staticmethod
def _fini(dev, buf, spec): dev.allocator.free(buf, buf.size, spec)
-def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=(), kernargs:HCQBuffer|None=None) -> HCQArgsState:
+def fill_kernargs(self, bufs:Sequence[HCQBuffer], vals:Sequence[sint]=(), kernargs:HCQBuffer|None=None) -> HCQArgsState:
"""
Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
Args:

View File

@@ -326,9 +326,7 @@ class Tensor(OpMixin):
assert self.numel() == 1, "must have one element for item"
return self.data()[(0,) * len(self.shape)]
-# TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The list is Sequence because mypy expects memoryview.tolist() -> list[int]
-# src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
-def tolist(self) -> Sequence[ConstType]|ConstType:
+def tolist(self) -> list|ConstType:
"""
Returns the value of this tensor as a nested list.
Returns single value for const tensor.
@@ -612,7 +610,7 @@ class Tensor(OpMixin):
# ***** creation helper functions *****
@staticmethod
-def full(shape:tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
+def full(shape:tuple[sint, ...]|int, fill_value:ConstType, **kwargs) -> Tensor:
"""
Creates a tensor with the given shape, filled with the given value.
@@ -1249,7 +1247,7 @@ class Tensor(OpMixin):
"""
return self._getitem(indices)
-def __setitem__(self, indices, v:Tensor|ConstType) -> None:
+def __setitem__(self, indices, v:Tensor|ConstType|list) -> None:
if isinstance(self.device, str) and self.device.startswith("DISK"):
self.realize()._getitem(indices).assign(v)
return
@@ -1289,7 +1287,7 @@ class Tensor(OpMixin):
x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype)
-def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
+def cat(self:Tensor|tuple[Tensor, ...]|list[Tensor], *args:Tensor, dim:int=0) -> Tensor:
"""
Concatenates self with other `Tensor` in `args` along an axis specified by `dim`.
All tensors must have the same shape except in the concatenating dimension.
@@ -1309,7 +1307,7 @@ class Tensor(OpMixin):
for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
return functools.reduce(Tensor.add, tensors)
-def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
+def stack(self:Tensor|tuple[Tensor, ...]|list[Tensor], *args:Tensor, dim:int=0) -> Tensor:
"""
Concatenates self with other `Tensor` in `args` along a new dimension specified by `dim`.
@@ -2114,7 +2112,7 @@ class Tensor(OpMixin):
return pads
# NOTE: these work for more than 2D
-def avg_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
+def avg_pool2d(self, kernel_size:int|tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]|list[int]=0,
ceil_mode=False, count_include_pad=True) -> Tensor:
"""
Applies average pooling over a tensor.
@@ -2160,7 +2158,7 @@ class Tensor(OpMixin):
if not ceil_mode: return pool(self, reg_pads).mean(axis)
return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)
-def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
+def max_pool2d(self, kernel_size:int|tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]|list[int]=0,
ceil_mode=False, return_indices=False) -> Tensor | tuple[Tensor, Tensor]:
"""
Applies max pooling over a tensor.
@@ -2204,7 +2202,7 @@ class Tensor(OpMixin):
idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation)
return pooled.max(axis), spatial_sz - idx.max(axis)
-def max_unpool2d(self, indices:Tensor, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, output_size=None):
+def max_unpool2d(self, indices:Tensor, kernel_size:int|tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]|list[int]=0, output_size=None):
"""
Performs a partial inverse of `max_pool2d` using the indices from the argmax.
@@ -2235,7 +2233,7 @@ class Tensor(OpMixin):
ret = (indices.reshape(bs,c,1,-1)._one_hot_along_dim(prod(output_size), 2).where(self.reshape(bs,c,1,-1), 0)).sum(3)
return ret.reshape(bs,c,*output_size)
-def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
+def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]|list[int]=0,
dtype:DTypeLike|None=None) -> Tensor:
"""
Applies a convolution over a tensor with a given `weight` and optional `bias`.
@@ -3614,7 +3612,7 @@ class Tensor(OpMixin):
nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)
-def newton_schulz(self, steps:int, params:tuple[int, ...], eps:float=1.0e-7) -> Tensor:
+def newton_schulz(self, steps:int, params:tuple[float|int, ...], eps:float=1.0e-7) -> Tensor:
"""
Performs the newton-schulz algorithm for odd polynomials. The degree of the odd polynomial depends on the number of params.
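The Tensor hunks above all move one way: annotations are widened to the types callers actually pass (bare ints for shapes and kernel sizes, lists where tuples were declared, lists of Tensors for cat/stack), since a runtime checker enforces declared types literally rather than treating them as documentation. A sketch of the failure mode being avoided, using stand-in functions rather than the real Tensor methods:

from typeguard import typechecked, TypeCheckError

@typechecked
def avg_pool_narrow(kernel_size: tuple[int, ...] = (2, 2), padding: int | tuple[int, ...] = 0) -> None:
  pass  # stand-in body

@typechecked
def avg_pool_widened(kernel_size: int | tuple[int, ...] = (2, 2),
                     padding: int | tuple[int, ...] | list[int] = 0) -> None:
  pass  # stand-in body

try:
  avg_pool_narrow(kernel_size=3, padding=[1, 1])  # a perfectly normal call site, rejected under strict annotations
except TypeCheckError as e:
  print("narrow annotation rejected:", e)

avg_pool_widened(kernel_size=3, padding=[1, 1])   # fine once the annotations match what callers pass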

View File

@@ -46,7 +46,7 @@ def smin(*lst) -> sint: return _suop(argfix(*lst), UOp.minimum, min)
def srender(x:sint) -> str: return x.render() if isinstance(x, UOp) else str(x)
def ssimplify(uop:sint): return uop.ssimplify() if isinstance(uop, UOp) else uop
-def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
+def sym_infer(uop: UOp|int|float, var_vals: dict[str, int]) -> int|float: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
def range_str(u:UOp, color=False) -> str:
ret = '_'.join([str(x) if x >= 0 else "m"+str(-x) for x in u.arg[0:-1]])
@@ -699,11 +699,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
uval = self.const_like(val) if isinstance(val, int) else val
assert self.arg[1] <= uval.vmin and uval.vmax <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
return UOp(Ops.BIND, self.dtype, (self, uval))
-def unbind(self) -> tuple[Variable, int]:
+def unbind(self) -> tuple[UOp, int]:
assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
return self.src[0], self.src[1].arg
-def unbind_all(self) -> tuple[UOp, dict[Variable, int]]:
-ret:dict[Variable, int] = {}
+def unbind_all(self) -> tuple[UOp, dict[UOp, int]]:
+ret:dict[UOp, int] = {}
return graph_rewrite(self, pm_unbind, ctx=ret), ret
@property
def val(self) -> int: return self.unbind()[1]
@@ -712,7 +712,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
bound_var_base = set(x.src[0] for x in bound_vars)
all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR])
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
-def variables(self) -> list[Variable]:
+def variables(self) -> list[UOp]:
return sorted(set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)
# *** uop symbolic stuff ***
@@ -1320,7 +1320,7 @@ def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
-def do_unbind(ctx:dict[Variable, int], x:UOp):
+def do_unbind(ctx:dict[UOp, int], x:UOp):
v,i = x.unbind()
ctx[v] = i
return v