diff --git a/test/test_ops.py b/test/test_ops.py
index 6d4349a35c..0074402c10 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -267,7 +267,7 @@ class TestOps(unittest.TestCase):
       for tor_i, ten_i in zip(tor, ten): helper_test_op([], lambda: tor_i, lambda: ten_i)

-    self.helper_test_exception([], lambda: torch.meshgrid(x, indexing="bad"), lambda: xt.meshgrid(indexing="bad"), expected=RuntimeError)
+    self.helper_test_exception([], lambda: torch.meshgrid(x, indexing="bad"), lambda: xt.meshgrid(indexing="bad"), expected=Exception)

   def test_arange(self):
     helper_test_op([], lambda: torch.arange(10, dtype=torch.int32), lambda: Tensor.arange(10), forward_only=True)
@@ -587,7 +587,7 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, lambda x,y: x.div(y, rounding_mode="trunc"), forward_only=True, vals=[[numerator], [denominator]])
     helper_test_op(None, lambda x,y: x.div(y, rounding_mode="floor"), forward_only=True, vals=[[numerator], [denominator]])

-    self.helper_test_exception(None, lambda x,y: x.div(y, rounding_mode="typo"), forward_only=True, vals=[[5], [0]], expected=RuntimeError)
+    self.helper_test_exception(None, lambda x,y: x.div(y, rounding_mode="typo"), forward_only=True, vals=[[5], [0]], expected=Exception)

   def test_div_int(self):
     helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=[[5, 6, 7],[1, 2, 3]])
@@ -2989,7 +2989,7 @@ class TestOps(unittest.TestCase):
     self.helper_test_exception([(4,5,6), (4,5,6)],
       lambda x,src: x.scatter_reduce(dim=0, index=b, src=src, reduce="INVALID"),
       lambda x,src: x.scatter_reduce(dim=0, index=a, src=src, reduce="INVALID"),
-      RuntimeError)
+      Exception)
     # dtype mismatch
     self.helper_test_exception([(4,5,6), (4,5,6)],
       lambda x,src: x.half().scatter_reduce(dim=0, index=b, src=src, reduce="sum"),
@@ -3068,7 +3068,7 @@ class TestOps(unittest.TestCase):
       helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction=r),
                                          lambda x,y: x.cross_entropy(y, reduction=r))
     self.helper_test_exception([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction="typo"),
-                                                   lambda x,y: x.cross_entropy(y, reduction="typo"), expected=ValueError)
+                                                   lambda x,y: x.cross_entropy(y, reduction="typo"), expected=Exception)

   def test_cross_entropy_smoothing(self):
     for ls in (0., 0.3, 0.7, 1.):
@@ -3131,7 +3131,7 @@ class TestOps(unittest.TestCase):
                      lambda x: x.log_softmax(axis=1).nll_loss(Tensor(target), reduction=r))
     self.helper_test_exception([(32,10)], lambda x: torch.nn.functional.nll_loss(x, torch.tensor(target), reduction="typo"),
-                               lambda x: x.nll_loss(Tensor(target), reduction="typo"), expected=ValueError)
+                               lambda x: x.nll_loss(Tensor(target), reduction="typo"), expected=Exception)

   def test_nll_loss_weight(self):
     target = np.random.randint(0, 10, (32,), dtype=np.int32).tolist()
diff --git a/test/test_tensor.py b/test/test_tensor.py
index 8b995d568f..4f2680dac8 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -456,7 +456,7 @@ class TestTinygrad(unittest.TestCase):
   def test_tensor_dtype_errors(self):
     with self.assertRaises(AttributeError): Tensor([3], dtype="typo")
-    with self.assertRaises(AttributeError): Tensor([3], dtype=(dtypes.int,))
+    with self.assertRaises(Exception): Tensor([3], dtype=(dtypes.int,)) # AttributeError or TypeCheckError with TYPED=1

   def test_tensor_bytes(self):
     data = b"abc123"
diff --git a/tinygrad/device.py b/tinygrad/device.py
index f8e5b5e915..4e9508ded0 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -340,7 +340,8 @@ class Compiled:
     # override this in your device implementation

 # TODO: move this to each Device
-def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
+def is_dtype_supported(dtype:DType|None, device:str|None=None) -> bool:
+  if dtype is None: return True
   if dtype == dtypes.index: return False
   if device is None: device = Device.DEFAULT
   if dtype == dtypes.bfloat16:
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index 26cbb7fe12..dd568e170c 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
-from typing import Final, ClassVar, Callable, Literal
+from typing import Final, ClassVar, Callable, Literal, TYPE_CHECKING, Any
+if TYPE_CHECKING: import numpy as np
 import math, struct, ctypes, functools
 from dataclasses import dataclass, fields
 from tinygrad.helpers import getenv, prod
@@ -123,7 +124,7 @@ class dtypes:
     if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
     raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
   @staticmethod
-  def as_const(val: tuple[ConstType|InvalidType, ...]|ConstType|InvalidType, dtype:DType):
+  def as_const(val:Any, dtype:DType):
     if isinstance(val, tuple):
       assert len(val) == dtype.count, f"mismatch {val} {dtype}"
       return tuple(dtypes.as_const(x, dtype) for x in val)
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index d2451e1c8b..dcdb9c3965 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -13,7 +13,7 @@ from tinygrad.codegen.opt import Opt
 # **************** Program Creation ****************

 @track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
-def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> ProgramSpec:
+def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|tuple[Opt, ...]|None=None) -> ProgramSpec:
   """
   Transform an AST into a ProgramSpec. May trigger BEAM search.

diff --git a/tinygrad/nn/__init__.py b/tinygrad/nn/__init__.py
index 5d5ced5c32..3234ed9a1f 100644
--- a/tinygrad/nn/__init__.py
+++ b/tinygrad/nn/__init__.py
@@ -249,7 +249,7 @@ class LayerNorm:
   print(t.mean().item(), t.std().item())
   ```
   """
-  def __init__(self, normalized_shape:int|tuple[int, ...], eps:float=1e-5, elementwise_affine:bool=True):
+  def __init__(self, normalized_shape:int|tuple[int, ...]|list[int], eps:float=1e-5, elementwise_affine:bool=True):
     self.normalized_shape: tuple[int, ...] = make_tuple(normalized_shape, 1)
     self.axis, self.eps = tuple(-1-i for i in range(len(self.normalized_shape))), eps
     self.weight: Tensor|None = Tensor.ones(*self.normalized_shape) if elementwise_affine else None
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index 454cb2b9b3..7206ee1853 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -84,7 +84,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=No
 # state dict

-def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
+def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Any]:
   """
   Returns a `state_dict` of the object, with optional prefix.

@@ -203,7 +203,7 @@ def tar_extract(t: Tensor) -> dict[str, Tensor]:
 # TODO: this should use tar_extract and zip_extract
 @accept_filename
-def torch_load(t:Tensor) -> dict[str, Tensor]:
+def torch_load(t:Tensor) -> dict[str, Any]:
   """
   ```python
   torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]
diff --git a/tinygrad/runtime/support/hcq.py b/tinygrad/runtime/support/hcq.py
index 43c73ddf7c..c4b6223d11 100644
--- a/tinygrad/runtime/support/hcq.py
+++ b/tinygrad/runtime/support/hcq.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import cast, Callable, Type, TypeVar, Generic, Any
+from typing import cast, Callable, Type, TypeVar, Generic, Any, Sequence
 import contextlib, decimal, statistics, time, ctypes, array, os, struct, collections, functools
 try: import fcntl # windows misses that
 except ImportError: fcntl = None #type:ignore[assignment]
@@ -279,14 +279,14 @@ def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Callable[[], HWQueue]
   if enabled and PROFILE: dev.sig_prof_records.append((unwrap(st), unwrap(en), desc, (queue_type or type(queue)) is dev.hw_copy_queue_t))

 class HCQArgsState(Generic[ProgramType]):
-  def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=()):
+  def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:Sequence[HCQBuffer], vals:Sequence[sint]=()):
     self.buf, self.prg, self.bufs, self.vals = buf, prg, bufs, vals
     self.bind_data:list[tuple[tuple[sint, ...], MMIOInterface, str]] = []

   def bind_sints_to_buf(self, *vals:sint, buf:HCQBuffer, fmt, offset=0): self.bind_data.append((vals, buf.cpu_view().view(offset=offset), fmt))

 class CLikeArgsState(HCQArgsState[ProgramType]):
-  def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=(), prefix:list[int]|None=None):
+  def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:Sequence[HCQBuffer], vals:Sequence[sint]=(), prefix:list[int]|None=None):
     super().__init__(buf, prg, bufs, vals=vals)

     if prefix is not None: self.buf.cpu_view().view(size=len(prefix) * 4, fmt='I')[:] = array.array('I', prefix)
@@ -302,7 +302,7 @@ class HCQProgram(Generic[HCQDeviceType]):
   @staticmethod
   def _fini(dev, buf, spec): dev.allocator.free(buf, buf.size, spec)
-  def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=(), kernargs:HCQBuffer|None=None) -> HCQArgsState:
+  def fill_kernargs(self, bufs:Sequence[HCQBuffer], vals:Sequence[sint]=(), kernargs:HCQBuffer|None=None) -> HCQArgsState:
     """
     Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.

     Args:
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 290a36e568..b17fe54d22 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -326,9 +326,7 @@ class Tensor(OpMixin):
     assert self.numel() == 1, "must have one element for item"
     return self.data()[(0,) * len(self.shape)]

-  # TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The list is Sequence because mypy expects memoryview.tolist() -> list[int]
-  # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
-  def tolist(self) -> Sequence[ConstType]|ConstType:
+  def tolist(self) -> list|ConstType:
     """
     Returns the value of this tensor as a nested list. Returns single value for const tensor.

@@ -612,7 +610,7 @@ class Tensor(OpMixin):
   # ***** creation helper functions *****

   @staticmethod
-  def full(shape:tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
+  def full(shape:tuple[sint, ...]|int, fill_value:ConstType, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with the given value.

@@ -1249,7 +1247,7 @@ class Tensor(OpMixin):
     """
     return self._getitem(indices)

-  def __setitem__(self, indices, v:Tensor|ConstType) -> None:
+  def __setitem__(self, indices, v:Tensor|ConstType|list) -> None:
     if isinstance(self.device, str) and self.device.startswith("DISK"):
       self.realize()._getitem(indices).assign(v)
       return
@@ -1289,7 +1287,7 @@ class Tensor(OpMixin):
     x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
     return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype)

-  def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
+  def cat(self:Tensor|tuple[Tensor, ...]|list[Tensor], *args:Tensor, dim:int=0) -> Tensor:
     """
     Concatenates self with other `Tensor` in `args` along an axis specified by `dim`.
     All tensors must have the same shape except in the concatenating dimension.
@@ -1309,7 +1307,7 @@ class Tensor(OpMixin):
     for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
     return functools.reduce(Tensor.add, tensors)

-  def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
+  def stack(self:Tensor|tuple[Tensor, ...]|list[Tensor], *args:Tensor, dim:int=0) -> Tensor:
     """
     Concatenates self with other `Tensor` in `args` along a new dimension specified by `dim`.

@@ -2114,7 +2112,7 @@ class Tensor(OpMixin):
     return pads

   # NOTE: these work for more than 2D
-  def avg_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
+  def avg_pool2d(self, kernel_size:int|tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]|list[int]=0,
                  ceil_mode=False, count_include_pad=True) -> Tensor:
     """
     Applies average pooling over a tensor.
@@ -2160,7 +2158,7 @@ class Tensor(OpMixin):
     if not ceil_mode: return pool(self, reg_pads).mean(axis)
     return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)

-  def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
+  def max_pool2d(self, kernel_size:int|tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]|list[int]=0,
                  ceil_mode=False, return_indices=False) -> Tensor | tuple[Tensor, Tensor]:
     """
     Applies max pooling over a tensor.
@@ -2204,7 +2202,7 @@ class Tensor(OpMixin):
     idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation)
     return pooled.max(axis), spatial_sz - idx.max(axis)

-  def max_unpool2d(self, indices:Tensor, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, output_size=None):
+  def max_unpool2d(self, indices:Tensor, kernel_size:int|tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]|list[int]=0, output_size=None):
     """
     Performs a partial inverse of `max_pool2d` using the indices from the argmax.

@@ -2235,7 +2233,7 @@ class Tensor(OpMixin):
     ret = (indices.reshape(bs,c,1,-1)._one_hot_along_dim(prod(output_size), 2).where(self.reshape(bs,c,1,-1), 0)).sum(3)
     return ret.reshape(bs,c,*output_size)

-  def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
+  def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]|list[int]=0,
              dtype:DTypeLike|None=None) -> Tensor:
     """
     Applies a convolution over a tensor with a given `weight` and optional `bias`.
@@ -3614,7 +3612,7 @@ class Tensor(OpMixin):
     nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
     return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)

-  def newton_schulz(self, steps:int, params:tuple[int, ...], eps:float=1.0e-7) -> Tensor:
+  def newton_schulz(self, steps:int, params:tuple[float|int, ...], eps:float=1.0e-7) -> Tensor:
     """
     Performs the newton-schulz algorithm for odd polynomials. The degree of the odd polynomial depends on the number of params.
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index b939fc6fa8..65749855ad 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -46,7 +46,7 @@ def smin(*lst) -> sint: return _suop(argfix(*lst), UOp.minimum, min)
 def srender(x:sint) -> str: return x.render() if isinstance(x, UOp) else str(x)
 def ssimplify(uop:sint): return uop.ssimplify() if isinstance(uop, UOp) else uop
-def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
+def sym_infer(uop: UOp|int|float, var_vals: dict[str, int]) -> int|float: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop

 def range_str(u:UOp, color=False) -> str:
   ret = '_'.join([str(x) if x >= 0 else "m"+str(-x) for x in u.arg[0:-1]])
@@ -699,11 +699,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     uval = self.const_like(val) if isinstance(val, int) else val
     assert self.arg[1] <= uval.vmin and uval.vmax <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
     return UOp(Ops.BIND, self.dtype, (self, uval))
-  def unbind(self) -> tuple[Variable, int]:
+  def unbind(self) -> tuple[UOp, int]:
     assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
     return self.src[0], self.src[1].arg
-  def unbind_all(self) -> tuple[UOp, dict[Variable, int]]:
-    ret:dict[Variable, int] = {}
+  def unbind_all(self) -> tuple[UOp, dict[UOp, int]]:
+    ret:dict[UOp, int] = {}
     return graph_rewrite(self, pm_unbind, ctx=ret), ret
   @property
   def val(self) -> int: return self.unbind()[1]
@@ -712,7 +712,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     bound_var_base = set(x.src[0] for x in bound_vars)
     all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR])
     return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
-  def variables(self) -> list[Variable]:
+  def variables(self) -> list[UOp]:
     return sorted(set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)

 # *** uop symbolic stuff ***
@@ -1320,7 +1320,7 @@ def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index
 _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
 _remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])

-def do_unbind(ctx:dict[Variable, int], x:UOp):
+def do_unbind(ctx:dict[UOp, int], x:UOp):
   v,i = x.unbind()
   ctx[v] = i
   return v
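Note: the hunks above mostly widen type annotations (the one behavioral addition is the `dtype=None` fast path in `is_dtype_supported`). Below is a minimal sketch, not part of the patch, of a few of the relaxed signatures in use; it assumes this diff has been applied on top of a stock tinygrad checkout with any working default device.

```python
# Illustrative sketch only (not part of the patch): exercise a few of the
# widened signatures, assuming this diff is applied to a stock tinygrad checkout.
from tinygrad import Tensor
from tinygrad.device import is_dtype_supported
import tinygrad.nn as nn

print(is_dtype_supported(None))      # new dtype=None branch returns True
t = Tensor.full(3, 1.0)              # shape may be a bare int, per the widened annotation
print(t.tolist())                    # tolist() is now typed as list|ConstType -> [1.0, 1.0, 1.0]
ln = nn.LayerNorm([4])               # normalized_shape may be a list[int]
print(ln(Tensor.ones(2, 4)).shape)   # (2, 4)
```

The `Sequence`-typed HCQ kernarg helpers and the `tuple[Opt, ...]` overload on `get_program` follow the same pattern: callers that already pass lists or tuples at runtime now type-check without casts.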