Optimizations in tensor.py (#974)

* optimizations in tensor.py

* make mypy happy

* revert split of Function class
Author: Rayan Hatout
Date: 2023-06-14 16:44:35 +01:00
Committed by: GitHub
Parent: 0629791cbd
Commit: 2d567ef688


@@ -1,9 +1,13 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import math, functools, itertools, operator, time
import time
from functools import partialmethod, reduce
from itertools import accumulate, filterfalse
import operator
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence
from tinygrad.helpers import prod, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, ImageDType
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
from math import ceil, pi, prod, sqrt
from tinygrad.lazy import Device, LazyBuffer
from tinygrad.ops import LoadOps
@@ -12,7 +16,7 @@ class Function:
def __init__(self, device:str, *tensors:Tensor):
self.device, self.parents = device, tensors
self.needs_input_grad = [t.requires_grad for t in self.parents]
self.requires_grad = True if any(self.needs_input_grad) else (None if any(x is None for x in self.needs_input_grad) else False)
self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")
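For reference, the requires_grad propagation rewritten above, in isolation: a Function requires grad if any parent does, stays undecided (None) if any parent is undecided, and is False otherwise. A minimal standalone sketch (plain Python, outside tinygrad):

    def propagate_requires_grad(needs_input_grad):
        return True if any(needs_input_grad) else None if None in needs_input_grad else False

    assert propagate_requires_grad([False, True]) is True
    assert propagate_requires_grad([False, None]) is None
    assert propagate_requires_grad([False, False]) is False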
@@ -29,6 +33,7 @@ import tinygrad.mlops as mlops
# **** start with two base classes, Tensor and Function ****
class Tensor:
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
__deletable__ = ('_ctx',)
training: ClassVar[bool] = False
no_grad: ClassVar[bool] = False
@@ -37,22 +42,6 @@ class Tensor:
def __init__(self, data:Union[int, float, list, tuple, LazyBuffer, np.ndarray], device=Device.DEFAULT, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None):
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
device = Device.canonicalize(device)
if isinstance(data, (list, tuple)):
data = np.array(data, dtype=(dtype if dtype is not None else Tensor.default_type).np)
if isinstance(data, np.ndarray):
data = LazyBuffer.fromCPU(data)
if isinstance(data, LazyBuffer):
assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
elif isinstance(data, (int, float)):
lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype if dtype is not None else Tensor.default_type, device, data)
else:
raise RuntimeError(f"can't create Tensor from {data}")
# this is set once we are here
self.lazydata: LazyBuffer = lazydata
# tensors have gradients, buffers do not
self.grad: Optional[Tensor] = None
@@ -62,6 +51,26 @@ class Tensor:
# internal variables used for autograd graph construction
self._ctx: Optional[Function] = None
if data.__class__ is LazyBuffer:
data = cast(LazyBuffer, data) # NOTE: this is a noop, it makes mypy happy
assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported"
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
return
if isinstance(data, (int, float)):
self.lazydata = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data)
return
if data.__class__ is list:
data = np.array(data, dtype=(dtype or Tensor.default_type).np)
if data.__class__ is np.ndarray:
data = cast(np.ndarray, data)
data = LazyBuffer.fromCPU(data)
self.lazydata = data if data.device == device else LazyBuffer.loadop(LoadOps.FROM, data.shape, data.dtype, device, src=data)
return
raise RuntimeError(f"can't create Tensor from {data}")
def __repr__(self):
return f"<Tensor {self.lazydata!r} on {self.device} with grad {(self.grad.lazydata if self.grad else None)!r}>"
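The rewritten constructor above replaces isinstance checks with exact-type checks (x.__class__ is T) plus early returns. A rough standalone illustration of that micro-optimization (timings are machine-dependent; the trade-off is that an exact-type check no longer accepts subclasses):

    import timeit

    class Foo: pass
    x = Foo()
    print(timeit.timeit(lambda: isinstance(x, Foo)))  # subclass-aware check
    print(timeit.timeit(lambda: x.__class__ is Foo))  # exact-type check, usually slightly faster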
@@ -87,11 +96,10 @@ class Tensor:
def assign(self, x) -> Tensor:
# TODO: this is a hack for writing to DISK
if self.device.startswith("DISK"):
if not isinstance(x, Tensor): x = Tensor(x, device="CPU", dtype=self.dtype)
if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
self.lazydata.realize().realized._copyin(x.numpy()) # type: ignore
return self
if not isinstance(x, Tensor): x = Tensor(x, device=self.device, dtype=self.dtype)
# NOTE: we are currently allowing assignments from different dtypes
if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
assert not x.requires_grad # self requires_grad is okay?
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
@@ -168,7 +176,7 @@ class Tensor:
def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor:
# https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
src = Tensor.rand(2, *shape, **kwargs)
return src[0].mul(2*math.pi).cos().mul(src[1].log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
return src[0].mul(2*pi).cos().mul(src[1].log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype)
@staticmethod
def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor: return ((high-low) * Tensor.rand(*shape, **kwargs)) + low
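randn above is the Box-Muller transform: two uniform samples u0, u1 become one standard-normal sample via cos(2*pi*u0) * sqrt(-2*ln(u1)). A standalone NumPy sketch of the same formula:

    import numpy as np

    u0 = np.random.rand(100_000)        # uniform in [0, 1)
    u1 = 1.0 - np.random.rand(100_000)  # uniform in (0, 1], keeps log() finite
    z = np.cos(2 * np.pi * u0) * np.sqrt(-2 * np.log(u1))
    print(z.mean(), z.std())            # roughly 0 and 1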
@@ -183,7 +191,7 @@ class Tensor:
# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
bound = sqrt(3.0) * sqrt(2.0 / (1 + a ** 2)) / sqrt(prod(shape[1:]))
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)
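A worked example of the kaiming_uniform bound above for a (64, 128, 3, 3) convolution weight with the default slope a = 0.01, where fan_in = prod(shape[1:]):

    from math import prod, sqrt

    shape, a = (64, 128, 3, 3), 0.01
    fan_in = prod(shape[1:])                               # 128*3*3 = 1152
    bound = sqrt(3.0) * sqrt(2.0 / (1 + a ** 2)) / sqrt(fan_in)
    print(bound)                                           # ~0.0722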
# ***** toposort and backward pass *****
@@ -219,24 +227,23 @@ class Tensor:
del t0._ctx
# ***** movement mlops *****
def reshape(self, shape, *args) -> Tensor:
new_shape = argfix(shape, *args)
assert 0 not in new_shape, f"zeros not allowed in shape {new_shape}"
return mlops.Reshape.apply(self, shape=tuple(-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape))
def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]))
def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Pad.apply(self, arg=arg) if any(x != (0,0) for x in arg) else self
def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
def pad(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Pad.apply(self, arg=arg) if any([x != (0,0) for x in arg]) else self
def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any([x != (0,s) for x,s in zip(arg, self.shape)]) else self
# ***** movement hlops *****
# NOTE: using slice is discouraged and things should migrate to pad and shrink
def slice(self, arg:Sequence[Optional[Tuple[int, int]]]) -> Tensor:
arg_ = tuple(a if a is not None else (0,s) for s,a in zip(self.shape, arg))
padding = tuple((max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_))
return self.pad(padding).shrink(tuple((p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)))
arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
return self.pad(padding).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
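Several of the rewrites above change tuple(generator) into tuple([list comprehension]). A rough standalone illustration of why (timings are machine-dependent): building the list first lets tuple() copy a buffer of known length instead of pulling items from a generator one at a time.

    import timeit

    print(timeit.timeit(lambda: tuple(i * 2 for i in range(100))))
    print(timeit.timeit(lambda: tuple([i * 2 for i in range(100)])))  # usually faster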
# - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element
# - A slice i:j returns the elements with indices in [i, j)
@@ -271,7 +278,7 @@ class Tensor:
orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
else:
orig_slices += [slice(None)] * (len(self.shape) - num_slices)
valid_slices = list(itertools.filterfalse(lambda x: x is None, orig_slices))
valid_slices = list(filterfalse(lambda x: x is None, orig_slices))
valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
new_slice = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in zip(start, stop, strides))
@@ -289,11 +296,11 @@ class Tensor:
paddings = tuple((0, num_zeros(s, dim_sz)) for s, dim_sz in zip(strides, sliced_tensor.shape))
padded_tensor = sliced_tensor.pad(paddings)
# Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
new_shape = functools.reduce(operator.add, [[sh // s, s] for sh, s in zip(padded_tensor.shape, strides)], []) # type: ignore
new_shape = reduce(operator.add, [[sh // s, s] for sh, s in zip(padded_tensor.shape, strides)], []) # type: ignore
reshaped_tensor = padded_tensor.reshape(new_shape)
# Shrink: do [:, 0]
new_shape = new_shape[::2]
final_slice = functools.reduce(operator.add, (((0, sh), (0, 1)) for sh in new_shape), ())
final_slice = reduce(operator.add, (((0, sh), (0, 1)) for sh in new_shape), ())
sliced_tensor = reshaped_tensor.shrink(final_slice)
final_shape = []
it_shape = iter(new_shape)
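The pad/reshape/shrink sequence above implements strided slicing without a dedicated stride op: pad each axis up to a multiple of the stride, reshape that axis to (padded_size // stride, stride), then keep column 0. A standalone NumPy sketch of the idea:

    import numpy as np

    x, s = np.arange(10), 3
    pad = (-len(x)) % s                 # pad length 10 up to 12, a multiple of 3
    xp = np.pad(x, (0, pad))
    print(xp.reshape(-1, s)[:, 0])      # [0 3 6 9], i.e. x[::3]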
@@ -307,15 +314,14 @@ class Tensor:
def cat(self, *args, dim=0):
dim = (dim + len(self.shape)) if dim < 0 else dim
for y in args:
assert len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim)
assert all([len(y.shape) == len(self.shape) and all([y.shape[i] == s for i,s in enumerate(self.shape) if i != dim]) for y in args])
catargs = [self] + list(args)
assert all(len(t.shape) != 0 for t in catargs), "zero-dimensional tensor cannot be concatenated"
shape_cumsum = [0, *itertools.accumulate([y.shape[dim] for y in catargs])]
shape_cumsum = [0, *accumulate([y.shape[dim] for y in catargs])]
slc = [[(0, s) for s in self.shape] for _ in catargs]
for s,k in zip(slc, shape_cumsum):
s[dim] = (-k, shape_cumsum[-1]-k)
return functools.reduce(Tensor.__add__, [arg.slice(s) for arg,s in zip(catargs, slc)])
return reduce(Tensor.__add__, [arg.slice(s) for arg,s in zip(catargs, slc)])
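In cat above, shape_cumsum gives each tensor's starting offset along dim, and the (-k, total-k) entry passed to slice() pads every tensor out to the full concatenated size at its offset, so the padded tensors can simply be summed. A standalone sketch of those offsets:

    from itertools import accumulate

    dim_sizes = [2, 3, 4]                          # sizes along dim for self + args
    shape_cumsum = [0, *accumulate(dim_sizes)]     # [0, 2, 5, 9]
    total = shape_cumsum[-1]
    print([(-k, total - k) for k in shape_cumsum[:-1]])  # [(0, 9), (-2, 7), (-5, 4)]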
@staticmethod
def stack(tensors, dim=0):
@@ -360,10 +366,10 @@ class Tensor:
# ***** reduce ops *****
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False):
axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis))
axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if axis.__class__ is int else list(axis)) # type: ignore
axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
shape = [self.shape[i] for i in range(len(self.shape)) if i not in axis_]
ret = fxn.apply(self, new_shape=tuple(1 if i in axis_ else self.shape[i] for i in range(len(self.shape))))
ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else self.shape[i] for i in range(len(self.shape))]))
return ret if keepdim else ret.reshape(shape=shape)
def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)
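The axis handling in _reduce above, in isolation: negative axes are normalized, the reduced output keeps size-1 entries for the reduced axes, and keepdim=False then reshapes them away. A standalone sketch:

    shape, axis_ = (2, 3, 4), [-1]
    axis_ = [x if x >= 0 else x + len(shape) for x in axis_]                  # [2]
    print(tuple([1 if i in axis_ else shape[i] for i in range(len(shape))]))  # (2, 3, 1)
    print([shape[i] for i in range(len(shape)) if i not in axis_])            # [2, 3]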
@@ -398,7 +404,7 @@ class Tensor:
slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_):
o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)]
e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
e_ = [ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding
xup = self.reshape(*prefix, *([1]*len(_insert_dims)), *flatten((1,i) for i in i_)).expand(*prefix, *_insert_dims, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *_insert_dims, *[e*i for e,i in zip(e_, i_)])
# NOTE: _insert_dims is required because reduces can't be merged (yet)
prefix += _insert_dims
@@ -432,7 +438,7 @@ class Tensor:
HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1))
x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing)
stride = make_pair(stride, len(HW))
if any(s>1 for s in stride):
if any([s>1 for s in stride]):
x = x.reshape(*x.shape[:2], *flatten((k,1) for k in x.shape[2:]))
x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride)))
x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)])
@@ -479,7 +485,7 @@ class Tensor:
def exp(self): return mlops.Exp.apply(self)
def relu(self): return mlops.Relu.apply(self)
def sin(self): return mlops.Sin.apply(self)
def cos(self): return ((math.pi/2)-self).sin()
def cos(self): return ((pi/2)-self).sin()
def tan(self): return self.sin() / self.cos()
@staticmethod
@@ -517,23 +523,36 @@ class Tensor:
def softsign(self): return self / (1 + self.abs())
# ***** broadcasted binary mlops *****
def _broadcasted(self, fxn:Type[Function], other:Union[Tensor, float], reverse:bool=False) -> Tensor:
dtype = self.dtype if self.dtype != dtypes.bool and not isinstance(self.dtype,ImageDType) else dtypes.float32
x,y = [Tensor(t, device=self.device, requires_grad=False, dtype=dtype) if not isinstance(t, Tensor) else t for t in ([other,self] if reverse else [self,other])]
x,y = [t.reshape([1]*(max(len(x.shape), len(y.shape))-len(t.shape)) + list(t.shape)) for t in [x,y]]
shape_ret = tuple(max(sx, sy) for sx,sy in zip(x.shape, y.shape))
return fxn.apply(x.expand(shape_ret), y.expand(shape_ret))
def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if isinstance(x, Tensor) or x != 0.0 else self
def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if isinstance(x, Tensor) or x != 0.0 or reverse else self
def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if isinstance(x, Tensor) or x != 1.0 else self
def _broadcasted(self, fxn:Type[Function], other:Union[Tensor, float], reverse:bool=False) -> Tensor:
dtype = self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32
x: Tensor = self
y: Tensor = Tensor(cast(float, other), device=self.device, requires_grad=False, dtype=dtype) if other.__class__ is not Tensor else cast(Tensor, other)
if reverse: x, y = y, x
if x.shape == y.shape: return fxn.apply(x, y)
len_x_shape, len_y_shape = len(x.shape), len(y.shape)
max_shape = max(len_x_shape, len_y_shape)
if len_x_shape != max_shape: x = x.reshape((1,) * (max_shape - len_x_shape) + x.shape)
if len_y_shape != max_shape: y = y.reshape((1,) * (max_shape - len_y_shape) + y.shape)
shape_ret = tuple([max(x, y) for x, y in zip(x.shape, y.shape)])
if x.shape != shape_ret: x = x.expand(shape_ret)
if y.shape != shape_ret: y = y.expand(shape_ret)
return fxn.apply(x, y)
def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Add, x, reverse) if x.__class__ is Tensor or x else self
def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Sub, x, reverse) if x.__class__ is Tensor or x or reverse else self
def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Mul, x, reverse) if x.__class__ is Tensor or x != 1.0 else self
def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
if not isinstance(x, Tensor) and not reverse:
if x.__class__ is not Tensor and not reverse:
# simple pow identities
if x == 2.0: return self*self
if x == -1.0: return 1/self
return self._broadcasted(mlops.Pow, x, reverse) if isinstance(x, Tensor) or x != 1.0 or reverse else self
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if isinstance(x, Tensor) or reverse or x == 0.0 else self.mul(1/x)
return self._broadcasted(mlops.Pow, x, reverse) if x.__class__ is Tensor or x != 1.0 or reverse else self
def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return self._broadcasted(mlops.Div, x, reverse) if x.__class__ is Tensor or reverse or not x else self.mul(1/x)
def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x)
def maximum(self, x:Union[Tensor, float]) -> Tensor: return self._broadcasted(mlops.Maximum, x)
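The broadcasting logic in the rewritten _broadcasted above boils down to: left-pad the shorter shape with 1s, then take the elementwise maximum, expanding only when a shape actually differs. A standalone sketch of the shape computation:

    def broadcast_shape(xs, ys):
        n = max(len(xs), len(ys))
        xs, ys = (1,) * (n - len(xs)) + tuple(xs), (1,) * (n - len(ys)) + tuple(ys)
        return tuple([max(a, b) for a, b in zip(xs, ys)])

    assert broadcast_shape((3, 1, 5), (4, 5)) == (3, 4, 5)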
@@ -577,7 +596,7 @@ class Tensor:
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
return x.add(bias) if bias is not None else x
def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return functools.reduce(lambda x,f: f(x), ll, self)
def sequential(self, ll:List[Callable[[Tensor], Tensor]]): return reduce(lambda x,f: f(x), ll, self)
def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor:
y = (self - self.mean(axis, keepdim=True))
@@ -603,15 +622,15 @@ class Tensor:
# ***** Convenience stuff *****
@property
def ndim(self) -> int: return len(self.shape)
def numel(self) -> int: return math.prod(self.shape)
def numel(self) -> int: return prod(self.shape)
def element_size(self) -> int: return self.dtype.itemsize
def nbytes(self) -> int: return self.numel() * self.element_size()
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
# register functions to move between devices
for device in Device._buffers:
setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))
setattr(Tensor, f"{device.lower()}_", functools.partialmethod(Tensor.to_, device))
setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
setattr(Tensor, f"{device.lower()}_", partialmethod(Tensor.to_, device))
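The loop above registers one pair of convenience methods per registered device; for a device named "CPU" it adds Tensor.cpu() and Tensor.cpu_() bound to Tensor.to / Tensor.to_. A standalone sketch of the partialmethod mechanism:

    from functools import partialmethod

    class T:
        def to(self, device): return f"moved to {device}"

    setattr(T, "cpu", partialmethod(T.to, "CPU"))
    print(T().cpu())   # moved to CPU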
# if IMAGE>0 we install these replacement functions in Tensor (hack!)
from tinygrad.nn.image import image_conv2d, image_dot