changes for TYPED=1
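Context for the terse commit message: TYPED=1 appears to enable runtime checking of type annotations (the tests below now tolerate typeguard's TypeCheckError), so expected exception types are loosened and several annotations are widened to match what the code already accepts. A minimal sketch of how such a mode can be wired up with typeguard; the hook call and env-var handling here are illustrative, not tinygrad's actual wiring:

```python
# Illustrative sketch of a TYPED=1 mode using typeguard (not tinygrad's actual wiring).
# With the import hook installed, annotated functions raise TypeCheckError on bad
# arguments, which is why the tests below now expect the broader Exception.
import os
from typeguard import typechecked, TypeCheckError

if os.getenv("TYPED", "0") == "1":
  from typeguard import install_import_hook
  install_import_hook("tinygrad")  # must run before the tinygrad modules are imported

@typechecked
def scale(x: float, k: int) -> float:
  return x * k

try:
  scale("oops", 2)  # wrong argument type
except TypeCheckError as e:
  print("rejected by the type checker:", e)
```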
@@ -267,7 +267,7 @@ class TestOps(unittest.TestCase):
     for tor_i, ten_i in zip(tor, ten):
       helper_test_op([], lambda: tor_i, lambda: ten_i)

-    self.helper_test_exception([], lambda: torch.meshgrid(x, indexing="bad"), lambda: xt.meshgrid(indexing="bad"), expected=RuntimeError)
+    self.helper_test_exception([], lambda: torch.meshgrid(x, indexing="bad"), lambda: xt.meshgrid(indexing="bad"), expected=Exception)

   def test_arange(self):
     helper_test_op([], lambda: torch.arange(10, dtype=torch.int32), lambda: Tensor.arange(10), forward_only=True)
@@ -587,7 +587,7 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, lambda x,y: x.div(y, rounding_mode="trunc"), forward_only=True, vals=[[numerator], [denominator]])
     helper_test_op(None, lambda x,y: x.div(y, rounding_mode="floor"), forward_only=True, vals=[[numerator], [denominator]])

-    self.helper_test_exception(None, lambda x,y: x.div(y, rounding_mode="typo"), forward_only=True, vals=[[5], [0]], expected=RuntimeError)
+    self.helper_test_exception(None, lambda x,y: x.div(y, rounding_mode="typo"), forward_only=True, vals=[[5], [0]], expected=Exception)

   def test_div_int(self):
     helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=[[5, 6, 7],[1, 2, 3]])
@@ -2989,7 +2989,7 @@ class TestOps(unittest.TestCase):
     self.helper_test_exception([(4,5,6), (4,5,6)],
       lambda x,src: x.scatter_reduce(dim=0, index=b, src=src, reduce="INVALID"),
       lambda x,src: x.scatter_reduce(dim=0, index=a, src=src, reduce="INVALID"),
-      RuntimeError)
+      Exception)
     # dtype mismatch
     self.helper_test_exception([(4,5,6), (4,5,6)],
       lambda x,src: x.half().scatter_reduce(dim=0, index=b, src=src, reduce="sum"),
@@ -3068,7 +3068,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction=r),
       lambda x,y: x.cross_entropy(y, reduction=r))
     self.helper_test_exception([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, reduction="typo"),
-      lambda x,y: x.cross_entropy(y, reduction="typo"), expected=ValueError)
+      lambda x,y: x.cross_entropy(y, reduction="typo"), expected=Exception)

   def test_cross_entropy_smoothing(self):
     for ls in (0., 0.3, 0.7, 1.):
@@ -3131,7 +3131,7 @@ class TestOps(unittest.TestCase):
       lambda x: x.log_softmax(axis=1).nll_loss(Tensor(target), reduction=r))
     self.helper_test_exception([(32,10)],
       lambda x: torch.nn.functional.nll_loss(x, torch.tensor(target), reduction="typo"),
-      lambda x: x.nll_loss(Tensor(target), reduction="typo"), expected=ValueError)
+      lambda x: x.nll_loss(Tensor(target), reduction="typo"), expected=Exception)

   def test_nll_loss_weight(self):
     target = np.random.randint(0, 10, (32,), dtype=np.int32).tolist()

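A hedged note on why the reduction tests above now accept any Exception: tinygrad annotates `reduction` with a Literal of the valid strings (assumption from current source), so with runtime type checking an invalid string can be rejected as TypeCheckError before the usual ValueError path is reached.

```python
# Hedged sketch: "typo" either hits tinygrad's own ValueError or, under TYPED=1,
# a TypeCheckError from the Literal annotation. Both are caught by Exception.
from tinygrad import Tensor

x, y = Tensor.randn(4, 10), Tensor.randn(4, 10)
try:
  x.cross_entropy(y, reduction="typo")
except Exception as e:
  print(type(e).__name__)  # ValueError, or TypeCheckError with TYPED=1
```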
@@ -456,7 +456,7 @@ class TestTinygrad(unittest.TestCase):

   def test_tensor_dtype_errors(self):
     with self.assertRaises(AttributeError): Tensor([3], dtype="typo")
-    with self.assertRaises(AttributeError): Tensor([3], dtype=(dtypes.int,))
+    with self.assertRaises(Exception): Tensor([3], dtype=(dtypes.int,))  # AttributeError or TypeCheckError with TYPED=1

   def test_tensor_bytes(self):
     data = b"abc123"

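As the new comment in the hunk notes, the failure mode depends on whether the type hook is active; a small sketch of the behavior the test now tolerates (assuming TypeCheckError comes from typeguard):

```python
# Without TYPED=1 a non-DType dtype fails inside Tensor construction with AttributeError;
# with the type hook active it can be rejected earlier as TypeCheckError.
from tinygrad import Tensor, dtypes

try:
  Tensor([3], dtype=(dtypes.int,))  # a tuple is not a DType
except Exception as e:
  print(type(e).__name__)  # AttributeError, or TypeCheckError with TYPED=1
```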
@@ -340,7 +340,8 @@ class Compiled:
   # override this in your device implementation

 # TODO: move this to each Device
-def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
+def is_dtype_supported(dtype:DType|None, device:str|None=None) -> bool:
+  if dtype is None: return True
   if dtype == dtypes.index: return False
   if device is None: device = Device.DEFAULT
   if dtype == dtypes.bfloat16:

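A hedged usage sketch of the widened signature; the import path is assumed, since the hunk only shows that the helper sits next to `class Compiled`:

```python
# Hedged sketch: None now means "no dtype constraint" and is reported as supported,
# so callers can pass an optional dtype straight through. Import path assumed.
from tinygrad import dtypes
from tinygrad.device import is_dtype_supported

print(is_dtype_supported(dtypes.float32))  # True on most devices
print(is_dtype_supported(None))            # True, via the new early return
```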
@@ -1,5 +1,6 @@
 from __future__ import annotations
-from typing import Final, ClassVar, Callable, Literal
+from typing import Final, ClassVar, Callable, Literal, TYPE_CHECKING, Any
+if TYPE_CHECKING: import numpy as np
 import math, struct, ctypes, functools
 from dataclasses import dataclass, fields
 from tinygrad.helpers import getenv, prod
@@ -123,7 +124,7 @@ class dtypes:
     if x.__class__ is list or x.__class__ is tuple: return max(dtypes.from_py(xi) for xi in x) if x else dtypes.default_float
     raise RuntimeError(f"Could not infer dtype of {x} with type {type(x)}")
   @staticmethod
-  def as_const(val: tuple[ConstType|InvalidType, ...]|ConstType|InvalidType, dtype:DType):
+  def as_const(val:Any, dtype:DType):
     if isinstance(val, tuple):
       assert len(val) == dtype.count, f"mismatch {val} {dtype}"
       return tuple(dtypes.as_const(x, dtype) for x in val)

@@ -13,7 +13,7 @@ from tinygrad.codegen.opt import Opt
 # **************** Program Creation ****************

 @track_rewrites(name=lambda *args,ret,**kwargs: TracingKey(ret.name, (ret.function_name, ret.ast), ret=ret), replay=True)
-def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> ProgramSpec:
+def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|tuple[Opt, ...]|None=None) -> ProgramSpec:
   """
   Transform an AST into a ProgramSpec. May trigger BEAM search.

@@ -249,7 +249,7 @@ class LayerNorm:
   print(t.mean().item(), t.std().item())
   ```
   """
-  def __init__(self, normalized_shape:int|tuple[int, ...], eps:float=1e-5, elementwise_affine:bool=True):
+  def __init__(self, normalized_shape:int|tuple[int, ...]|list[int], eps:float=1e-5, elementwise_affine:bool=True):
     self.normalized_shape: tuple[int, ...] = make_tuple(normalized_shape, 1)
     self.axis, self.eps = tuple(-1-i for i in range(len(self.normalized_shape))), eps
     self.weight: Tensor|None = Tensor.ones(*self.normalized_shape) if elementwise_affine else None

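A hedged usage sketch of what the widened `normalized_shape` hint allows (a list, e.g. straight from a config); `make_tuple` normalizes it either way:

```python
# Hedged sketch: int, tuple, and now (per the annotation) list all normalize to a tuple.
from tinygrad import Tensor
from tinygrad.nn import LayerNorm

ln = LayerNorm([8])   # list form, as a config loader might produce
x = Tensor.randn(2, 8)
print(ln(x).shape)    # (2, 8)
```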
@@ -84,7 +84,7 @@ def safe_save(tensors:dict[str, Tensor], fn:str, metadata:dict[str, Any]|None=No

 # state dict

-def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Tensor]:
+def get_state_dict(obj, prefix:str='', tensor_type=Tensor) -> dict[str, Any]:
   """
   Returns a `state_dict` of the object, with optional prefix.

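A hedged sketch of the call this return type describes; for a plain module it is still a name-to-Tensor mapping, `dict[str, Any]` just accommodates a non-default `tensor_type`:

```python
# Hedged sketch: the usual call still yields Tensors keyed by attribute path.
from tinygrad.nn import Linear
from tinygrad.nn.state import get_state_dict

model = Linear(4, 2)
print(sorted(get_state_dict(model)))  # ['bias', 'weight']
```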
@@ -203,7 +203,7 @@ def tar_extract(t: Tensor) -> dict[str, Tensor]:

 # TODO: this should use tar_extract and zip_extract
 @accept_filename
-def torch_load(t:Tensor) -> dict[str, Tensor]:
+def torch_load(t:Tensor) -> dict[str, Any]:
   """
   ```python
   torch_load(fn: Tensor | str | Path) -> dict[str, Tensor]

@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import cast, Callable, Type, TypeVar, Generic, Any
+from typing import cast, Callable, Type, TypeVar, Generic, Any, Sequence
 import contextlib, decimal, statistics, time, ctypes, array, os, struct, collections, functools
 try: import fcntl # windows misses that
 except ImportError: fcntl = None #type:ignore[assignment]
@@ -279,14 +279,14 @@ def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Callable[[], HWQueue]
   if enabled and PROFILE: dev.sig_prof_records.append((unwrap(st), unwrap(en), desc, (queue_type or type(queue)) is dev.hw_copy_queue_t))

 class HCQArgsState(Generic[ProgramType]):
-  def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=()):
+  def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:Sequence[HCQBuffer], vals:Sequence[sint]=()):
     self.buf, self.prg, self.bufs, self.vals = buf, prg, bufs, vals
     self.bind_data:list[tuple[tuple[sint, ...], MMIOInterface, str]] = []

   def bind_sints_to_buf(self, *vals:sint, buf:HCQBuffer, fmt, offset=0): self.bind_data.append((vals, buf.cpu_view().view(offset=offset), fmt))

 class CLikeArgsState(HCQArgsState[ProgramType]):
-  def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:tuple[HCQBuffer, ...], vals:tuple[sint, ...]=(), prefix:list[int]|None=None):
+  def __init__(self, buf:HCQBuffer, prg:ProgramType, bufs:Sequence[HCQBuffer], vals:Sequence[sint]=(), prefix:list[int]|None=None):
     super().__init__(buf, prg, bufs, vals=vals)

     if prefix is not None: self.buf.cpu_view().view(size=len(prefix) * 4, fmt='I')[:] = array.array('I', prefix)
@@ -302,7 +302,7 @@ class HCQProgram(Generic[HCQDeviceType]):
   @staticmethod
   def _fini(dev, buf, spec): dev.allocator.free(buf, buf.size, spec)

-  def fill_kernargs(self, bufs:tuple[HCQBuffer, ...], vals:tuple[int, ...]=(), kernargs:HCQBuffer|None=None) -> HCQArgsState:
+  def fill_kernargs(self, bufs:Sequence[HCQBuffer], vals:Sequence[sint]=(), kernargs:HCQBuffer|None=None) -> HCQArgsState:
     """
     Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
     Args:

@@ -326,9 +326,7 @@ class Tensor(OpMixin):
     assert self.numel() == 1, "must have one element for item"
     return self.data()[(0,) * len(self.shape)]

-  # TODO: should be Tensor.tolist() -> Union[list[ConstType], ConstType]. The list is Sequence because mypy expects memoryview.tolist() -> list[int]
-  # src: https://github.com/python/mypy/blob/release-1.6/mypy/typeshed/stdlib/builtins.pyi#L803
-  def tolist(self) -> Sequence[ConstType]|ConstType:
+  def tolist(self) -> list|ConstType:
     """
     Returns the value of this tensor as a nested list.
     Returns single value for const tensor.
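A hedged usage sketch of the behavior the simpler annotation describes:

```python
# Hedged sketch: nested lists for multi-dimensional tensors, a bare scalar for 0-d ones.
from tinygrad import Tensor

print(Tensor([[1, 2], [3, 4]]).tolist())  # [[1, 2], [3, 4]]
print(Tensor(3.0).tolist())               # 3.0
```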
@@ -612,7 +610,7 @@ class Tensor(OpMixin):
   # ***** creation helper functions *****

   @staticmethod
-  def full(shape:tuple[sint, ...], fill_value:ConstType, **kwargs) -> Tensor:
+  def full(shape:tuple[sint, ...]|int, fill_value:ConstType, **kwargs) -> Tensor:
     """
     Creates a tensor with the given shape, filled with the given value.

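A hedged usage sketch of the widened shape hint; a bare int already worked at runtime (it is normalized like other shape arguments), the annotation now admits it:

```python
# Hedged sketch: both spellings produce the same tensor.
from tinygrad import Tensor

print(Tensor.full((3,), 7).tolist())  # [7, 7, 7]
print(Tensor.full(3, 7).tolist())     # [7, 7, 7]
```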
@@ -1249,7 +1247,7 @@ class Tensor(OpMixin):
     """
     return self._getitem(indices)

-  def __setitem__(self, indices, v:Tensor|ConstType) -> None:
+  def __setitem__(self, indices, v:Tensor|ConstType|list) -> None:
     if isinstance(self.device, str) and self.device.startswith("DISK"):
       self.realize()._getitem(indices).assign(v)
       return
@@ -1289,7 +1287,7 @@ class Tensor(OpMixin):
     x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
     return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype)

-  def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
+  def cat(self:Tensor|tuple[Tensor, ...]|list[Tensor], *args:Tensor, dim:int=0) -> Tensor:
     """
     Concatenates self with other `Tensor` in `args` along an axis specified by `dim`.
     All tensors must have the same shape except in the concatenating dimension.
@@ -1309,7 +1307,7 @@ class Tensor(OpMixin):
     for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
     return functools.reduce(Tensor.add, tensors)

-  def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
+  def stack(self:Tensor|tuple[Tensor, ...]|list[Tensor], *args:Tensor, dim:int=0) -> Tensor:
     """
     Concatenates self with other `Tensor` in `args` along a new dimension specified by `dim`.

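A hedged usage sketch; the widened `self` hints mainly reflect how typed callers invoke these, the normal method form is unchanged:

```python
# Hedged sketch: cat joins along an existing axis, stack adds a new one.
from tinygrad import Tensor

a, b = Tensor([1, 2]), Tensor([3, 4])
print(a.cat(b).tolist())        # [1, 2, 3, 4]
print(a.stack(b, dim=0).shape)  # (2, 2)
```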
@@ -2114,7 +2112,7 @@ class Tensor(OpMixin):
     return pads

   # NOTE: these work for more than 2D
-  def avg_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
+  def avg_pool2d(self, kernel_size:int|tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]|list[int]=0,
                  ceil_mode=False, count_include_pad=True) -> Tensor:
     """
     Applies average pooling over a tensor.
@@ -2160,7 +2158,7 @@ class Tensor(OpMixin):
     if not ceil_mode: return pool(self, reg_pads).mean(axis)
     return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)

-  def max_pool2d(self, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0,
+  def max_pool2d(self, kernel_size:int|tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]|list[int]=0,
                  ceil_mode=False, return_indices=False) -> Tensor | tuple[Tensor, Tensor]:
     """
     Applies max pooling over a tensor.
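A hedged usage sketch of the widened `kernel_size` hint (a single int, as in torch, already worked at runtime):

```python
# Hedged sketch: an int kernel_size is broadcast over the spatial dims.
from tinygrad import Tensor

x = Tensor.randn(1, 3, 8, 8)
print(x.avg_pool2d(kernel_size=2).shape)            # (1, 3, 4, 4)
print(x.max_pool2d(kernel_size=2, stride=2).shape)  # (1, 3, 4, 4)
```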
@@ -2204,7 +2202,7 @@ class Tensor(OpMixin):
     idx = m * idx.pad(pads, value=dtypes.min(idx.dtype))._pool(k_, stride if stride is not None else k_, dilation)
     return pooled.max(axis), spatial_sz - idx.max(axis)

-  def max_unpool2d(self, indices:Tensor, kernel_size:tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]=0, output_size=None):
+  def max_unpool2d(self, indices:Tensor, kernel_size:int|tuple[int, ...]=(2,2), stride=None, dilation=1, padding:int|tuple[int, ...]|list[int]=0, output_size=None):
     """
     Performs a partial inverse of `max_pool2d` using the indices from the argmax.

@@ -2235,7 +2233,7 @@ class Tensor(OpMixin):
     ret = (indices.reshape(bs,c,1,-1)._one_hot_along_dim(prod(output_size), 2).where(self.reshape(bs,c,1,-1), 0)).sum(3)
     return ret.reshape(bs,c,*output_size)

-  def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]=0,
+  def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|tuple[int, ...]|list[int]=0,
              dtype:DTypeLike|None=None) -> Tensor:
     """
     Applies a convolution over a tensor with a given `weight` and optional `bias`.
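A hedged usage sketch of the widened `padding` hint (a list of per-dimension pads, alongside the existing int and tuple forms):

```python
# Hedged sketch: list padding behaves like the equivalent tuple.
from tinygrad import Tensor

x, w = Tensor.randn(1, 3, 8, 8), Tensor.randn(4, 3, 3, 3)
print(x.conv2d(w, padding=[1, 1]).shape)  # (1, 4, 8, 8)
```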
@@ -3614,7 +3612,7 @@ class Tensor(OpMixin):
     nll = -self.gather(1, Y.unsqueeze(1)).squeeze(1) * masked_weight
     return nll.sum() / masked_weight.sum() if reduction == "mean" else nll._do_reduction(reduction)

-  def newton_schulz(self, steps:int, params:tuple[int, ...], eps:float=1.0e-7) -> Tensor:
+  def newton_schulz(self, steps:int, params:tuple[float|int, ...], eps:float=1.0e-7) -> Tensor:
     """
     Performs the newton-schulz algorithm for odd polynomials. The degree of the odd polynomial depends on the number of params.

@@ -46,7 +46,7 @@ def smin(*lst) -> sint: return _suop(argfix(*lst), UOp.minimum, min)
 def srender(x:sint) -> str: return x.render() if isinstance(x, UOp) else str(x)

 def ssimplify(uop:sint): return uop.ssimplify() if isinstance(uop, UOp) else uop
-def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
+def sym_infer(uop: UOp|int|float, var_vals: dict[str, int]) -> int|float: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop

 def range_str(u:UOp, color=False) -> str:
   ret = '_'.join([str(x) if x >= 0 else "m"+str(-x) for x in u.arg[0:-1]])
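A hedged sketch of the pass-through behavior the wider hint covers; the import path is assumed from where these symbolic helpers currently live:

```python
# Hedged sketch: plain numbers (now including floats) are returned unchanged,
# only UOps are actually inferred against var_vals. Import path assumed.
from tinygrad.uop.ops import sym_infer

print(sym_infer(3, {}))    # 3
print(sym_infer(2.5, {}))  # 2.5
```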
@@ -699,11 +699,11 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     uval = self.const_like(val) if isinstance(val, int) else val
     assert self.arg[1] <= uval.vmin and uval.vmax <= self.arg[2], f"bind {val} not in range [{self.arg[1]}, {self.arg[2]}]"
     return UOp(Ops.BIND, self.dtype, (self, uval))
-  def unbind(self) -> tuple[Variable, int]:
+  def unbind(self) -> tuple[UOp, int]:
     assert self.op is Ops.BIND and self.src[0].op is Ops.DEFINE_VAR and self.src[1].op is Ops.CONST, f"can't unbind {self}"
     return self.src[0], self.src[1].arg
-  def unbind_all(self) -> tuple[UOp, dict[Variable, int]]:
-    ret:dict[Variable, int] = {}
+  def unbind_all(self) -> tuple[UOp, dict[UOp, int]]:
+    ret:dict[UOp, int] = {}
     return graph_rewrite(self, pm_unbind, ctx=ret), ret
   @property
   def val(self) -> int: return self.unbind()[1]
@@ -712,7 +712,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     bound_var_base = set(x.src[0] for x in bound_vars)
     all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR])
     return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
-  def variables(self) -> list[Variable]:
+  def variables(self) -> list[UOp]:
     return sorted(set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)

   # *** uop symbolic stuff ***
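A hedged sketch of why loosening `Variable` to `UOp` in these signatures is behavior-preserving: a `Variable` is itself a `DEFINE_VAR` UOp, so the returned objects are the same as before:

```python
# Hedged sketch: bind a Variable to a value, then recover it with unbind().
from tinygrad import Variable

v = Variable("i", 0, 10)   # a DEFINE_VAR UOp under the hood
bound = v.bind(3)
var, val = bound.unbind()
print(var is v, val)       # True 3
```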
@@ -1320,7 +1320,7 @@ def _index_to_concrete_int(u:UOp): return graph_rewrite(u.sink(), pm_lower_index
 _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
 _remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])

-def do_unbind(ctx:dict[Variable, int], x:UOp):
+def do_unbind(ctx:dict[UOp, int], x:UOp):
   v,i = x.unbind()
   ctx[v] = i
   return v