From ca9532ce2924fdae1d12ea0646e831205b46c839 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Fri, 8 Jul 2022 08:57:12 -0700
Subject: [PATCH] less lines, and typing found a bug

---
 tinygrad/llops/ops_gpu.py |  4 +-
 tinygrad/tensor.py        | 77 +++++++++++++++------------------------
 2 files changed, 31 insertions(+), 50 deletions(-)

diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index a8c182e4ed..4c8c0a4968 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -148,11 +148,9 @@ class GPUBuffer:
       }"""
     elif ret.shape != bufs[0][1].shape: # this is a reduce
       # reverse operation of expand, this validates inputs
-      st = ShapeTracker(ret.shape).movement_op(MovementOps.EXPAND, bufs[0][1].shape)
-
       # generate loops with combined adjacent reduce axis
       acc = 1
-      for shp,stride in st.views[-1].shape_strides[::-1]:
+      for shp,stride in ShapeTracker(ret.shape).movement_op(MovementOps.EXPAND, bufs[0][1].shape).views[-1].shape_strides[::-1]:
         if stride == 0:
           loop.append((f"for (int axis_{len(loop)} = 0; axis_{len(loop)} < {shp}; axis_{len(loop)}++) {{", f"idx += {acc}; }} idx -= {shp*acc};"))
           acc *= shp
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 0294dbbc8e..1e61f1d581 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -61,22 +61,20 @@ class Tensor:
     self.lazydata = x.lazydata
     return x

-  def detach(self):
-    return Tensor(self.lazydata, device=self.device, requires_grad=False)
-
-  def numpy(self):
-    return np.array(self.lazydata.toCPU())
+  def detach(self): return Tensor(self.lazydata, device=self.device, requires_grad=False)
+  def numpy(self): return np.array(self.lazydata.toCPU())

   # TOOD: this keeps the legacy behavior working, remove it after refactor
   @property
-  def data(self):
-    return self.numpy()
+  def data(self): return self.numpy()

-  def to_(self, device):
-    self.device = device
-    if self.grad: self.grad.device = device
+  # TODO: if things are realized this won't work
+  def to_(self, device:str):
+    assert self.lazydata.realized is None
+    self.lazydata.device = device
+    if self.grad: self.grad.lazydata.device = device

-  def to(self, device):
+  def to(self, device:str):
     ret = Tensor(self.lazydata, device)
     if self.grad: ret.grad = self.grad.to(device)
     return ret
@@ -86,28 +84,22 @@ class Tensor:

   # TODO: remove use of numpy here
   @classmethod
-  def zeros(cls, *shape, **kwargs):
-    return cls(np.zeros(shape, dtype=np.float32), **kwargs)
+  def zeros(cls, *shape, **kwargs): return cls(np.zeros(shape, dtype=np.float32), **kwargs)

   @classmethod
-  def ones(cls, *shape, **kwargs):
-    return cls(np.ones(shape, dtype=np.float32), **kwargs)
+  def ones(cls, *shape, **kwargs): return cls(np.ones(shape, dtype=np.float32), **kwargs)

   @classmethod
-  def randn(cls, *shape, **kwargs):
-    return cls(np.random.randn(*shape).astype(np.float32), **kwargs)
+  def randn(cls, *shape, **kwargs): return cls(np.random.randn(*shape).astype(np.float32), **kwargs)

   @classmethod
-  def arange(cls, stop, start=0, **kwargs):
-    return cls(np.arange(start=start, stop=stop).astype(np.float32), **kwargs)
+  def arange(cls, stop, start=0, **kwargs): return cls(np.arange(start=start, stop=stop).astype(np.float32), **kwargs)

   @classmethod
-  def uniform(cls, *shape, **kwargs):
-    return cls((np.random.uniform(-1., 1., size=shape)/np.sqrt(prod(shape))).astype(np.float32), **kwargs)
+  def uniform(cls, *shape, **kwargs): return cls((np.random.uniform(-1., 1., size=shape)/np.sqrt(prod(shape))).astype(np.float32), **kwargs)

   @classmethod
-  def eye(cls, dim, **kwargs):
-    return cls(np.eye(dim).astype(np.float32), **kwargs)
+  def eye(cls, dim, **kwargs): return cls(np.eye(dim).astype(np.float32), **kwargs)

   # ***** toposort and backward pass *****

@@ -159,6 +151,7 @@ class Tensor:
     ret = self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))])
     return ret.reshape(shape=new_shape) if tuple(ret.shape) != tuple(new_shape) else ret

+  # TODO: there has to be a cleaner way to write this
   def cat(self, *args, dim=0):
     dim = (dim + len(self.shape)) if dim < 0 else dim
     for y in args: assert len(self.shape) == len(y.shape)
@@ -181,10 +174,6 @@ class Tensor:
       ret += y.slice(arg=ts)
     return ret

-  def pad2d(self, padding:Tuple[int, ...]):
-    # (padding_left, padding_right, padding_top, padding_bottom)
-    return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
-
   def matmul(x:Tensor, w:Tensor):
     # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
     bs, groups = prod(x.shape[0:-2]), prod(w.shape[0:-2])
@@ -204,11 +193,10 @@ class Tensor:
   # TODO: what's the difference between dot and matmul?
   dot = matmul

-  def transpose(self, order=(1,0)):
-    return self.permute(order=order)
-
-  def flatten(self, start_dim=0):
-    return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))
+  # (padding_left, padding_right, padding_top, padding_bottom)
+  def pad2d(self, padding:Tuple[int, ...]): return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
+  def transpose(self, order=(1,0)): return self.permute(order=order)
+  def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))

   def _canonicalize_reduce_axis(self, axis):
     if axis is None: axis = range(len(self.shape))
@@ -255,11 +243,8 @@ class Tensor:
     xup = self[:, :, :self.shape[2]-self.shape[2]%py, :self.shape[3]-self.shape[3]%px] if (self.shape[2]%py != 0) or (self.shape[3]%px != 0) else self
     return xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//py, py, xup.shape[3]//px, px))

-  def avg_pool2d(self, kernel_size=(2,2)):
-    return self._pool2d(*kernel_size).mean(axis=(3,5))
-
-  def max_pool2d(self, kernel_size=(2,2)):
-    return self._pool2d(*kernel_size).max(axis=(3,5))
+  def avg_pool2d(self, kernel_size=(2,2)): return self._pool2d(*kernel_size).mean(axis=(3,5))
+  def max_pool2d(self, kernel_size=(2,2)): return self._pool2d(*kernel_size).max(axis=(3,5))

   def conv2d(self, weight, bias=None, **kwargs):
     ret = self._conv2d(weight, **kwargs)
@@ -331,25 +316,23 @@ class Tensor:

 # An instantiation of the Function is the Context
 class Function:
-  def __init__(self, device, *tensors:Tensor):
-    self.device = device
-    self.parents = tensors
-    self.needs_input_grad = [t.requires_grad for t in tensors]
-    self.requires_grad = any(self.needs_input_grad) and not Tensor.no_grad
+  def __init__(self, device:str, *tensors:Tensor):
+    self.device, self.parents = device, tensors
+    self.needs_input_grad = [t.requires_grad for t in self.parents]
+    self.requires_grad = any(self.needs_input_grad)
     self.saved_tensors : List[Tensor] = []

   def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
   def backward(self, *args, **kwargs): raise NotImplementedError(f"backward not implemented for {type(self)}")

-  def save_for_backward(self, *x):
-    # NOTE: it doesn't hurt to save this since the ctx will be freed fast without grad
-    self.saved_tensors.extend(x)
+  # NOTE: it doesn't hurt to save this since the ctx will be freed fast without grad
+  def save_for_backward(self, *x): self.saved_tensors.extend(x)

   @classmethod
   def apply(cls, *x:Tensor, **kwargs):
     ctx = cls(x[0].device, *x)
     ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad)
-    if ctx.requires_grad: ret._ctx = ctx # used by autograd engine
+    if ctx.requires_grad and not Tensor.no_grad: ret._ctx = ctx # used by autograd engine
     return ret

 # register functions to move between devices
@@ -358,7 +341,7 @@ for device in [device for device in Device.__dict__.keys() if device[0] != "_"]:
   setattr(Tensor, f"{device.lower()}_", functools.partialmethod(Tensor.to_, Device.__dict__[device]))

 # register all the mlops "math" operations
-def register(name, fxn):
+def register(name:str, fxn:Function):
   def dispatch(*x, **kwargs): return fxn.apply(*x, **kwargs)
   # TODO: there's probably a very pythonic thing to replace this
   setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, dispatch)
 for name, cls in inspect.getmembers(importlib.import_module('tinygrad.mlops'), inspect.isclass):
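
Note on the Function change above: the Tensor.no_grad check moves out of __init__, where it used to be folded into requires_grad, into apply, where it now only gates attaching _ctx; separately, to_ now asserts the buffer is unrealized and retargets lazydata.device instead of reassigning Tensor.device. The following is a minimal, self-contained sketch of that apply pattern, not tinygrad code: the toy Value, Fn and Mul classes and the example values are assumptions made purely for illustration, while the names no_grad, requires_grad, needs_input_grad, save_for_backward and _ctx follow the patch.

# minimal sketch of the patched Function.apply pattern (toy classes, not tinygrad)
from typing import List

class Value:
  # stand-in for Tensor, assumed here for illustration only
  no_grad = False  # class-level switch, mirrors Tensor.no_grad
  def __init__(self, data, requires_grad=False):
    self.data, self.requires_grad, self._ctx = data, requires_grad, None

class Fn:
  # stand-in for Function; same shape as the patched __init__/apply, minus devices
  def __init__(self, *tensors):
    self.parents = tensors
    self.needs_input_grad = [t.requires_grad for t in self.parents]
    self.requires_grad = any(self.needs_input_grad)  # no no_grad check here anymore
    self.saved_tensors : List = []
  def save_for_backward(self, *x): self.saved_tensors.extend(x)
  @classmethod
  def apply(cls, *x):
    ctx = cls(*x)
    ret = Value(ctx.forward(*[t.data for t in x]), requires_grad=ctx.requires_grad)
    if ctx.requires_grad and not Value.no_grad: ret._ctx = ctx  # graph node only outside no_grad
    return ret

class Mul(Fn):
  def forward(self, a, b): return a * b

a, b = Value(3.0, requires_grad=True), Value(4.0)
print(Mul.apply(a, b)._ctx is not None)  # True: context attached, backward can walk it
Value.no_grad = True
print(Mul.apply(a, b)._ctx is not None)  # False: no graph is built under no_grad

With this split, an output produced under no_grad still reports requires_grad derived from its inputs but carries no _ctx, so the backward pass has nothing to walk; the old code baked the no_grad state into requires_grad at Function construction time.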