diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index ab3664a6c3..80d57f6d9b 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1,9 +1,13 @@
 # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
 from __future__ import annotations
-import math, functools, itertools, operator, time
+import time
+from functools import partialmethod, reduce
+from itertools import accumulate, filterfalse
+import operator
 import numpy as np
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
-from tinygrad.helpers import prod, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, ImageDType
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
+from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
+from math import ceil, pi, prod, sqrt
 from tinygrad.lazy import Device, LazyBuffer
 from tinygrad.ops import LoadOps

@@ -12,7 +16,7 @@ class Function:
   def __init__(self, device:str, *tensors:Tensor):
     self.device, self.parents = device, tensors
     self.needs_input_grad = [t.requires_grad for t in self.parents]
-    self.requires_grad = True if any(self.needs_input_grad) else (None if any(x is None for x in self.needs_input_grad) else False)
+    self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False

   def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
   def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")
@@ -29,6 +33,7 @@ import tinygrad.mlops as mlops
 # **** start with two base classes, Tensor and Function ****

 class Tensor:
+  __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
   __deletable__ = ('_ctx',)
   training: ClassVar[bool] = False
   no_grad: ClassVar[bool] = False
@@ -37,22 +42,6 @@ class Tensor:
   def __init__(self, data:Union[int, float, list, tuple, LazyBuffer, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
     assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
     device = Device.canonicalize(device)
-    if isinstance(data, (list, tuple)):
-      data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
-    if isinstance(data, np.ndarray):
-      data = LazyBuffer.fromCPU(data)
-
-    if isinstance(data, LazyBuffer):
-      assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
-      lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
-    elif isinstance(data, (int, float)):
-      lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype if dtype is not None else Tensor.default_type, device, data)
-    else:
-      raise RuntimeError(f"can't create Tensor from {data}")
-
-    # this is set once we are here
-    self.lazydata: LazyBuffer = lazydata
-
     # tensors have gradients, buffers do not
     self.grad: Optional[Tensor] = None

@@ -62,6 +51,26 @@ class Tensor:
     # internal variables used for autograd graph construction
     self._ctx: Optional[Function] = None
+    if data.__class__ is LazyBuffer:
+      data = cast(LazyBuffer, data)  # NOTE: this is a noop, it makes mypy happy
+      assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
+      self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
+      return
+
+    if isinstance(data, (int, float)):
+      self.lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
+      return
+
+    if data.__class__ is list:
+      data = np.array(data, dtype=(dtype or Tensor.default_type).np)
+
+    if data.__class__ is np.ndarray:
+      data = cast(np.ndarray, data)
+      data = LazyBuffer.fromCPU(data)
+      self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
+      return
+
+    raise RuntimeError(f"can't create Tensor from {data}")

   def __repr__(self):
     return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
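
The rewritten constructor replaces the old if/elif chain with typed early returns: a LazyBuffer passes straight through (with a FROM load when changing device), Python scalars become zero-dimensional CONST buffers, and lists round-trip through numpy onto the ndarray path. Note the exact-type checks (`data.__class__ is list`) deliberately skip subclasses, and tuple inputs now fall through to the RuntimeError even though the annotation still allows them. A minimal standalone sketch of the dispatch, where make_lazydata and its tuple return values are illustrative stand-ins for the real LazyBuffer loadops:

import numpy as np

def make_lazydata(data):
  # scalars become 0-d constant buffers
  if isinstance(data, (int, float)): return ("CONST", (), data)
  # lists go through numpy first, then share the ndarray path
  if data.__class__ is list: data = np.array(data)
  if data.__class__ is np.ndarray: return ("CPU", data.shape, data)
  raise RuntimeError(f"can't create Tensor from {data}")

print(make_lazydata(3.0))        # ('CONST', (), 3.0)
print(make_lazydata([1, 2, 3]))  # ('CPU', (3,), array([1, 2, 3]))
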
@@ -87,11 +96,10 @@ class Tensor:
   def assign(self, x) -> Tensor:
     # TODO: this is a hack for writing to DISK
     if self.device.startswith("DISK"):
-      if not isinstance(x, Tensor): x = Tensor(x, device="CPU", dtype=self.dtype)
+      if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
       self.lazydata.realize().realized._copyin(x.numpy())  # type: ignore
       return self
-    if not isinstance(x, Tensor): x = Tensor(x, device=self.device, dtype=self.dtype)
-    # NOTE: we are currently allowing assignments from different dtypes
+    if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
     assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
     assert not x.requires_grad  # self requires_grad is okay?
     if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
@@ -168,7 +176,7 @@ class Tensor:
   def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
     # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
     src = Tensor.rand(2, *shape, **kwargs)
-    return src[0].mul(2*math.pi).cos().mul(src[1].log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
+    return src[0].mul(2*pi).cos().mul(src[1].log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)

   @staticmethod
   def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low
@@ -183,7 +191,7 @@ class Tensor:
   # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
   @staticmethod
   def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
-    bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
+    bound = sqrt(3.0) * sqrt(2.0 / (1 + a ** 2)) / sqrt(prod(shape[1:]))
     return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)

   # ***** toposort and backward pass *****
@@ -219,24 +227,23 @@ class Tensor:
       del t0._ctx

   # ***** movement mlops *****
-
   def reshape(self, shape, *args) -> Tensor:
     new_shape = argfix(shape, *args)
     assert 0 not in new_shape, f"zeros not allowed in shape {new_shape}"
-    return mlops.Reshape.apply(self, shape=tuple(-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape))
-  def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
+    return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]))
+  def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
-  def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Pad.apply(self, arg=arg) if any(x != (0,0) for x in arg) else self
-  def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
+  def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Pad.apply(self, arg=arg) if any([x != (0,0) for x in arg]) else self
+  def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any([x != (0,s) for x,s in zip(arg, self.shape)]) else self

   # ***** movement hlops *****

   # NOTE: using slice is discouraged and things should migrate to pad and shrink
   def slice(self, arg:Sequence[Optional[Tuple[int, int]]]) -> Tensor:
-    arg_ = tuple(a if a is not None else (0,s) for s,a in zip(self.shape, arg))
-    padding = tuple((max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_))
-    return self.pad(padding).shrink(tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)))
+    arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
+    padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
+    return self.pad(padding).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))

   # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
   # - A slice i:j returns the elements with indices in [i, j)
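
slice above accepts out-of-range bounds by first padding the tensor so that every requested index exists, then shrinking to the requested window shifted by the left padding. The same two-step in plain Python on a 1-D list, with slice1d as a hypothetical helper:

def slice1d(xs, start, stop):
  # pad so that [start, stop) is always in range...
  pad_left, pad_right = max(0, -start), max(0, stop - len(xs))
  padded = [0]*pad_left + xs + [0]*pad_right
  # ...then shrink to the requested window, shifted by the left padding
  return padded[start + pad_left : stop + pad_left]

print(slice1d([1, 2, 3], -2, 2))  # [0, 0, 1, 2]
print(slice1d([1, 2, 3], 1, 5))   # [2, 3, 0, 0]
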
@@ -271,7 +278,7 @@ class Tensor:
       orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
     else:
       orig_slices += [slice(None)] * (len(self.shape) - num_slices)
-    valid_slices = list(itertools.filterfalse(lambda x: x is None, orig_slices))
+    valid_slices = list(filterfalse(lambda x: x is None, orig_slices))
     valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
     start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
     new_slice = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in zip(start, stop, strides))
@@ -289,11 +296,11 @@ class Tensor:
       paddings = tuple((0, num_zeros(s, dim_sz)) for s, dim_sz in zip(strides, sliced_tensor.shape))
       padded_tensor = sliced_tensor.pad(paddings)
       # Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
-      new_shape = functools.reduce(operator.add, [[sh // s, s] for sh, s in zip(padded_tensor.shape, strides)], [])  # type: ignore
+      new_shape = reduce(operator.add, [[sh // s, s] for sh, s in zip(padded_tensor.shape, strides)], [])  # type: ignore
       reshaped_tensor = padded_tensor.reshape(new_shape)
       # Shrink: do [:, 0]
       new_shape = new_shape[::2]
-      final_slice = functools.reduce(operator.add, (((0, sh), (0, 1)) for sh in new_shape), ())
+      final_slice = reduce(operator.add, (((0, sh), (0, 1)) for sh in new_shape), ())
       sliced_tensor = reshaped_tensor.shrink(final_slice)
       final_shape = []
       it_shape = iter(new_shape)
@@ -307,15 +314,14 @@ class Tensor:

   def cat(self, *args, dim=0):
     dim = (dim + len(self.shape)) if dim < 0 else dim
-    for y in args:
-      assert len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim)
+    assert all([len(y.shape) == len(self.shape) and all([y.shape[i] == s for i,s in enumerate(self.shape) if i != dim]) for y in args])
     catargs = [self] + list(args)
     assert all(len(t.shape) != 0 for t in catargs), "zero-dimensional tensor cannot be concatenated"
-    shape_cumsum = [0, *itertools.accumulate([y.shape[dim] for y in catargs])]
+    shape_cumsum = [0, *accumulate([y.shape[dim] for y in catargs])]
     slc = [[(0, s) for s in self.shape] for _ in catargs]
     for s,k in zip(slc, shape_cumsum): s[dim] = (-k, shape_cumsum[-1]-k)
-    return functools.reduce(Tensor.__add__, [arg.slice(s) for arg,s in zip(catargs, slc)])
+    return reduce(Tensor.__add__, [arg.slice(s) for arg,s in zip(catargs, slc)])

   @staticmethod
   def stack(tensors, dim=0):
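
cat above places each argument into its slot by slicing with a negative start (which pads with zeros on the left) out to the full output length, then sums the pieces with reduce(Tensor.__add__, ...); since every position is nonzero in exactly one piece, the sum is a concatenation. The offsets come from shape_cumsum. The same idea with plain lists:

from itertools import accumulate

xs = [[1, 2], [3, 4, 5], [6]]
sizes = [len(x) for x in xs]
cumsum = [0, *accumulate(sizes)]          # [0, 2, 5, 6]
total = cumsum[-1]
# pad each piece out to the full length at its cumulative offset
padded = [[0]*off + x + [0]*(total - off - len(x)) for x, off in zip(xs, cumsum)]
out = [sum(col) for col in zip(*padded)]  # elementwise sum = concatenation
print(out)                                # [1, 2, 3, 4, 5, 6]
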
@@ -360,10 +366,10 @@ class Tensor:
   # ***** reduce ops *****

   def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False):
-    axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
+    axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if axis.__class__ is int else list(axis))  # type: ignore
     axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
     shape = [self.shape[i] for i in range(len(self.shape)) if i not in axis_]
-    ret = fxn.apply(self, new_shape=tuple(1 if i in axis_ else self.shape[i] for i in range(len(self.shape))))
+    ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else self.shape[i] for i in range(len(self.shape))]))
     return ret if keepdim else ret.reshape(shape=shape)

   def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)
@@ -398,7 +404,7 @@ class Tensor:
     slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
     if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
       o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
-      e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)]  # expands such that we don't need padding
+      e_ = [ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)]  # expands such that we don't need padding
       xup = self.reshape(*prefix, *([1]*len(_insert_dims)), *flatten((1,i) for i in i_)).expand(*prefix, *_insert_dims, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *_insert_dims, *[e*i for e,i in zip(e_, i_)])
       # NOTE: _insert_dims is required because reduces can't be merged (yet)
       prefix += _insert_dims
@@ -432,7 +438,7 @@ class Tensor:
     HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
     x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing)
     stride = make_pair(stride, len(HW))
-    if any(s>1 for s in stride):
+    if any([s>1 for s in stride]):
       x = x.reshape(*x.shape[:2], *flatten((k,1) for k in x.shape[2:]))
       x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
       x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
@@ -479,7 +485,7 @@ class Tensor:
   def exp(self): return mlops.Exp.apply(self)
   def relu(self): return mlops.Relu.apply(self)
   def sin(self): return mlops.Sin.apply(self)
-  def cos(self): return ((math.pi/2)-self).sin()
+  def cos(self): return ((pi/2)-self).sin()
   def tan(self): return self.sin() / self.cos()

   @staticmethod
@@ -517,23 +523,36 @@ class Tensor:
   def softsign(self): return self / (1 + self.abs())

   # ***** broadcasted binary mlops *****
-  def _broadcasted(self, fxn:Type[Function], other:Union[Tensor, float], reverse:bool=False) -> Tensor:
-    dtype = self.dtype if self.dtype != dtypes.bool and not isinstance(self.dtype,ImageDType) else dtypes.float32
-    x,y = [Tensor(t, device=self.device, requires_grad=False, dtype=dtype) if not isinstance(t, Tensor) else t for t in ([other,self] if reverse else [self,other])]
-    x,y = [t.reshape([1]*(max(len(x.shape), len(y.shape))-len(t.shape)) + list(t.shape)) for t in [x,y]]
-    shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
-    return fxn.apply(x.expand(shape_ret), y.expand(shape_ret))
-  def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if isinstance(x, Tensor) or x != 0.0 else self
-  def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if isinstance(x, Tensor) or x != 0.0 or reverse else self
-  def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if isinstance(x, Tensor) or x != 1.0 else self
+  def _broadcasted(self, fxn:Type[Function], other:Union[Tensor, float], reverse:bool=False) -> Tensor:
+    dtype = self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32
+    x: Tensor = self
+    y: Tensor = Tensor(cast(float, other), device=self.device, requires_grad=False, dtype=dtype) if other.__class__ is not Tensor else cast(Tensor, other)
+    if reverse: x, y = y, x
+    if x.shape == y.shape: return fxn.apply(x, y)
+
+    len_x_shape, len_y_shape = len(x.shape), len(y.shape)
+    max_shape = max(len_x_shape, len_y_shape)
+
+    if len_x_shape != max_shape: x = x.reshape((1,) * (max_shape - len_x_shape) + x.shape)
+    if len_y_shape != max_shape: y = y.reshape((1,) * (max_shape - len_y_shape) + y.shape)
+
+    shape_ret = tuple([max(x, y) for x, y in zip(x.shape, y.shape)])
+    if x.shape != shape_ret: x = x.expand(shape_ret)
+    if y.shape != shape_ret: y = y.expand(shape_ret)
+
+    return fxn.apply(x, y)
+
+  def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if x.__class__ is Tensor or x else self
+  def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if x.__class__ is Tensor or x or reverse else self
+  def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if x.__class__ is Tensor or x != 1.0 else self
   def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
-    if not isinstance(x, Tensor) and not reverse:
+    if x.__class__ is not Tensor and not reverse:
       # simple pow identities
       if x == 2.0: return self*self
       if x == -1.0: return 1/self
-    return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
-  def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or reverse or x == 0.0 else self.mul(1/x)
+    return self._broadcasted(mlops.Pow, x, reverse) if x.__class__ is Tensor or x != 1.0 or reverse else self
+  def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if x.__class__ is Tensor or reverse or not x else self.mul(1/x)
   def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)

   def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x)
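
The rewritten _broadcasted returns early when the shapes already match and only reshapes/expands the side that needs it. The shape rule is numpy-style broadcasting: right-align the two shapes, pad the shorter with leading 1s, and take the per-axis max (mismatched non-1 sizes are not validated here, just as the max() above leaves them for expand to reject). As a standalone helper:

def broadcast_shape(a: tuple, b: tuple) -> tuple:
  # right-align the shapes by padding the shorter one with leading 1s
  n = max(len(a), len(b))
  a = (1,) * (n - len(a)) + a
  b = (1,) * (n - len(b)) + b
  # per-axis max: a 1 stretches to match the other side
  return tuple(max(x, y) for x, y in zip(a, b))

print(broadcast_shape((3, 1, 5), (4, 5)))  # (3, 4, 5)
print(broadcast_shape((), (2, 2)))         # (2, 2): scalars broadcast to anything
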
@@ -577,7 +596,7 @@ class Tensor:
     x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
     return x.add(bias) if bias is not None else x

-  def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return functools.reduce(lambda x,f: f(x), ll, self)
+  def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return reduce(lambda x,f: f(x), ll, self)

   def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
     y = (self - self.mean(axis, keepdim=True))
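
sequential is a left fold: reduce(lambda x,f: f(x), ll, self) threads the tensor through each callable in order, which is how tinygrad models simple layer stacks. The same fold with plain functions:

from functools import reduce

ll = [lambda x: x + 1, lambda x: x * 2, lambda x: x - 3]
print(reduce(lambda x, f: f(x), ll, 10))  # ((10+1)*2)-3 = 19
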
@@ -603,15 +622,15 @@ class Tensor:
   # ***** Convenience stuff *****

   @property
   def ndim(self) -> int: return len(self.shape)
-  def numel(self) -> int: return math.prod(self.shape)
+  def numel(self) -> int: return prod(self.shape)
   def element_size(self) -> int: return self.dtype.itemsize
   def nbytes(self) -> int: return self.numel() * self.element_size()
   def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)

 # register functions to move between devices
 for device in Device._buffers:
-  setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))
-  setattr(Tensor, f"{device.lower()}_", functools.partialmethod(Tensor.to_, device))
+  setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
+  setattr(Tensor, f"{device.lower()}_", partialmethod(Tensor.to_, device))

 # if IMAGE>0 we install these replacement functions in Tensor (hack!)
 from tinygrad.nn.image import image_conv2d, image_dot
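
The loop at the bottom installs one method per registered backend: partialmethod(Tensor.to, device) becomes a bound method with the device argument pre-filled, so t.gpu() is t.to("GPU"). A self-contained sketch of the mechanism, with T as a stand-in class:

from functools import partialmethod

class T:
  def to(self, device):
    return f"moved to {device}"

# register a lowercase forwarding method for each device name
for device in ("CPU", "GPU"):
  setattr(T, device.lower(), partialmethod(T.to, device))

print(T().gpu())  # "moved to GPU"
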