less lines, and typing found a bug

George Hotz
2022-07-08 08:57:12 -07:00
parent 2035b89e54
commit ca9532ce29
2 changed files with 31 additions and 50 deletions


@@ -148,11 +148,9 @@ class GPUBuffer:
}"""
elif ret.shape != bufs[0][1].shape: # this is a reduce
# reverse operation of expand, this validates inputs
st = ShapeTracker(ret.shape).movement_op(MovementOps.EXPAND, bufs[0][1].shape)
# generate loops with combined adjacent reduce axis
acc = 1
for shp,stride in st.views[-1].shape_strides[::-1]:
for shp,stride in ShapeTracker(ret.shape).movement_op(MovementOps.EXPAND, bufs[0][1].shape).views[-1].shape_strides[::-1]:
if stride == 0: loop.append((f"for (int axis_{len(loop)} = 0; axis_{len(loop)} < {shp}; axis_{len(loop)}++) {{", f"idx += {acc}; }} idx -= {shp*acc};"))
acc *= shp
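
The `stride == 0` test works because expanding the reduced output shape back to the input shape gives the reduced axes a stride of zero, so walking them revisits the same output element. A rough numpy analogue of that zero-stride view (illustration only, not tinygrad's ShapeTracker):

import numpy as np

# A (4,1) reduce result expanded back to its (4,3) input shape: the expanded
# axis becomes a zero-stride broadcast axis, which is what marks it as a reduce axis.
out = np.zeros((4, 1), dtype=np.float32)
expanded = np.broadcast_to(out, (4, 3))
print(expanded.strides)   # (4, 0) -- the stride-0 axis is the one being reduced over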


@@ -61,22 +61,20 @@ class Tensor:
     self.lazydata = x.lazydata
     return x
-  def detach(self):
-    return Tensor(self.lazydata, device=self.device, requires_grad=False)
-  def numpy(self):
-    return np.array(self.lazydata.toCPU())
+  def detach(self): return Tensor(self.lazydata, device=self.device, requires_grad=False)
+  def numpy(self): return np.array(self.lazydata.toCPU())
   # TOOD: this keeps the legacy behavior working, remove it after refactor
   @property
-  def data(self):
-    return self.numpy()
+  def data(self): return self.numpy()
-  def to_(self, device):
-    self.device = device
-    if self.grad: self.grad.device = device
+  # TODO: if things are realized this won't work
+  def to_(self, device:str):
+    assert self.lazydata.realized is None
+    self.lazydata.device = device
+    if self.grad: self.grad.lazydata.device = device
-  def to(self, device):
+  def to(self, device:str):
     ret = Tensor(self.lazydata, device)
     if self.grad: ret.grad = self.grad.to(device)
     return ret
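
This hunk also contains the bug the commit title refers to: the old `to_` reassigned `self.device` on the Tensor wrapper, while the buffer that actually gets scheduled lives on `self.lazydata`, so its device never changed. A minimal sketch of that failure mode (mock classes, not tinygrad code):

class FakeLazyBuffer:
  def __init__(self, device: str): self.device, self.realized = device, None

class FakeTensor:
  def __init__(self, device: str): self.device, self.lazydata = device, FakeLazyBuffer(device)

t = FakeTensor("CPU")
t.device = "GPU"                     # old to_(): only the wrapper attribute moves
assert t.lazydata.device == "CPU"    # ...the underlying buffer still targets CPU

assert t.lazydata.realized is None   # new to_(): guard against realized buffers,
t.lazydata.device = "GPU"            # then retarget the lazy buffer itself
assert t.lazydata.device == "GPU"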
@@ -86,28 +84,22 @@ class Tensor:
   # TODO: remove use of numpy here
   @classmethod
-  def zeros(cls, *shape, **kwargs):
-    return cls(np.zeros(shape, dtype=np.float32), **kwargs)
+  def zeros(cls, *shape, **kwargs): return cls(np.zeros(shape, dtype=np.float32), **kwargs)
   @classmethod
-  def ones(cls, *shape, **kwargs):
-    return cls(np.ones(shape, dtype=np.float32), **kwargs)
+  def ones(cls, *shape, **kwargs): return cls(np.ones(shape, dtype=np.float32), **kwargs)
   @classmethod
-  def randn(cls, *shape, **kwargs):
-    return cls(np.random.randn(*shape).astype(np.float32), **kwargs)
+  def randn(cls, *shape, **kwargs): return cls(np.random.randn(*shape).astype(np.float32), **kwargs)
   @classmethod
-  def arange(cls, stop, start=0, **kwargs):
-    return cls(np.arange(start=start, stop=stop).astype(np.float32), **kwargs)
+  def arange(cls, stop, start=0, **kwargs): return cls(np.arange(start=start, stop=stop).astype(np.float32), **kwargs)
   @classmethod
-  def uniform(cls, *shape, **kwargs):
-    return cls((np.random.uniform(-1., 1., size=shape)/np.sqrt(prod(shape))).astype(np.float32), **kwargs)
+  def uniform(cls, *shape, **kwargs): return cls((np.random.uniform(-1., 1., size=shape)/np.sqrt(prod(shape))).astype(np.float32), **kwargs)
   @classmethod
-  def eye(cls, dim, **kwargs):
-    return cls(np.eye(dim).astype(np.float32), **kwargs)
+  def eye(cls, dim, **kwargs): return cls(np.eye(dim).astype(np.float32), **kwargs)
   # ***** toposort and backward pass *****
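
For reference, these wrappers take the shape as positional arguments and pass any remaining kwargs straight to the Tensor constructor, e.g.:

a = Tensor.zeros(2, 3)
b = Tensor.randn(3, 4, requires_grad=True)
c = Tensor.arange(10)        # 0..9 as float32
d = Tensor.uniform(64, 64)   # scaled uniform init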
@@ -159,6 +151,7 @@ class Tensor:
     ret = self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))])
     return ret.reshape(shape=new_shape) if tuple(ret.shape) != tuple(new_shape) else ret
+  # TODO: there has to be a cleaner way to write this
   def cat(self, *args, dim=0):
     dim = (dim + len(self.shape)) if dim < 0 else dim
     for y in args: assert len(self.shape) == len(y.shape)
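
The `ret += y.slice(arg=ts)` at the top of the next hunk is the tail of this method: each input is zero-padded out to the full output shape along `dim`, and the padded pieces are summed. A numpy sketch of that pad-and-add strategy (illustration only):

import numpy as np

a, b = np.array([[1., 2.]]), np.array([[3., 4., 5.]])
out_w = a.shape[1] + b.shape[1]
a_pad = np.pad(a, ((0, 0), (0, out_w - a.shape[1])))   # [1 2 0 0 0]
b_pad = np.pad(b, ((0, 0), (a.shape[1], 0)))           # [0 0 3 4 5]
assert np.array_equal(a_pad + b_pad, np.concatenate([a, b], axis=1))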
@@ -181,10 +174,6 @@ class Tensor:
       ret += y.slice(arg=ts)
     return ret
-  def pad2d(self, padding:Tuple[int, ...]):
-    # (padding_left, padding_right, padding_top, padding_bottom)
-    return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
   def matmul(x:Tensor, w:Tensor):
     # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
     bs, groups = prod(x.shape[0:-2]), prod(w.shape[0:-2])
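
A quick numerical check of the NOTE above, using PyTorch's conv2d rather than tinygrad (the reshapes and transposes here are only for the check, not part of this diff):

import numpy as np
import torch
import torch.nn.functional as F

m, k, n = 3, 4, 5
X, W = np.random.randn(m, k).astype(np.float32), np.random.randn(k, n).astype(np.float32)
inp = torch.from_numpy(np.ascontiguousarray(X.T.reshape(1, k, m, 1)))   # (1, k, m, 1)
wgt = torch.from_numpy(np.ascontiguousarray(W.T.reshape(n, k, 1, 1)))   # (n, k, 1, 1)
out = F.conv2d(inp, wgt)                                                # (1, n, m, 1)
np.testing.assert_allclose(out.numpy()[0, :, :, 0].T, X @ W, rtol=1e-5)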
@@ -204,11 +193,10 @@ class Tensor:
   # TODO: what's the difference between dot and matmul?
   dot = matmul
-  def transpose(self, order=(1,0)):
-    return self.permute(order=order)
-  def flatten(self, start_dim=0):
-    return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))
+  # (padding_left, padding_right, padding_top, padding_bottom)
+  def pad2d(self, padding:Tuple[int, ...]): return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
+  def transpose(self, order=(1,0)): return self.permute(order=order)
+  def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))
   def _canonicalize_reduce_axis(self, axis):
     if axis is None: axis = range(len(self.shape))
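
The negative and past-the-end slice bounds in the relocated `pad2d` rely on tinygrad's slicing treating out-of-range indices as zero padding (behavior assumed here, not shown in this diff). In numpy terms the intent is:

import numpy as np

x = np.arange(1*1*2*2, dtype=np.float32).reshape(1, 1, 2, 2)   # NCHW
pl, pr, pt, pb = 1, 2, 3, 4                                     # (left, right, top, bottom)
padded = np.pad(x, ((0, 0), (0, 0), (pt, pb), (pl, pr)))
print(padded.shape)   # (1, 1, 2+3+4, 2+1+2) == (1, 1, 9, 5)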
@@ -255,11 +243,8 @@ class Tensor:
     xup = self[:, :, :self.shape[2]-self.shape[2]%py, :self.shape[3]-self.shape[3]%px] if (self.shape[2]%py != 0) or (self.shape[3]%px != 0) else self
     return xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//py, py, xup.shape[3]//px, px))
-  def avg_pool2d(self, kernel_size=(2,2)):
-    return self._pool2d(*kernel_size).mean(axis=(3,5))
-  def max_pool2d(self, kernel_size=(2,2)):
-    return self._pool2d(*kernel_size).max(axis=(3,5))
+  def avg_pool2d(self, kernel_size=(2,2)): return self._pool2d(*kernel_size).mean(axis=(3,5))
+  def max_pool2d(self, kernel_size=(2,2)): return self._pool2d(*kernel_size).max(axis=(3,5))
   def conv2d(self, weight, bias=None, **kwargs):
     ret = self._conv2d(weight, **kwargs)
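
A numpy sketch of the reshape trick behind `_pool2d` and these two poolers: each (py, px) window is carved out as its own pair of axes, then reduced over (illustration only, not tinygrad code):

import numpy as np

x = np.arange(1*1*4*6, dtype=np.float32).reshape(1, 1, 4, 6)
py, px = 2, 2
tiled = x.reshape(1, 1, 4//py, py, 6//px, px)   # (N, C, H//py, py, W//px, px)
avg = tiled.mean(axis=(3, 5))                   # avg_pool2d -> shape (1, 1, 2, 3)
mx  = tiled.max(axis=(3, 5))                    # max_pool2d -> shape (1, 1, 2, 3)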
@@ -331,25 +316,23 @@ class Tensor:
 # An instantiation of the Function is the Context
 class Function:
-  def __init__(self, device, *tensors:Tensor):
-    self.device = device
-    self.parents = tensors
-    self.needs_input_grad = [t.requires_grad for t in tensors]
-    self.requires_grad = any(self.needs_input_grad) and not Tensor.no_grad
+  def __init__(self, device:str, *tensors:Tensor):
+    self.device, self.parents = device, tensors
+    self.needs_input_grad = [t.requires_grad for t in self.parents]
+    self.requires_grad = any(self.needs_input_grad)
     self.saved_tensors : List[Tensor] = []
   def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
   def backward(self, *args, **kwargs): raise NotImplementedError(f"backward not implemented for {type(self)}")
-  def save_for_backward(self, *x):
-    # NOTE: it doesn't hurt to save this since the ctx will be freed fast without grad
-    self.saved_tensors.extend(x)
+  # NOTE: it doesn't hurt to save this since the ctx will be freed fast without grad
+  def save_for_backward(self, *x): self.saved_tensors.extend(x)
   @classmethod
   def apply(cls, *x:Tensor, **kwargs):
     ctx = cls(x[0].device, *x)
     ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad)
-    if ctx.requires_grad: ret._ctx = ctx    # used by autograd engine
+    if ctx.requires_grad and not Tensor.no_grad: ret._ctx = ctx    # used by autograd engine
     return ret
 # register functions to move between devices
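
For orientation, a hypothetical mlops-style subclass showing how `apply` drives it: `forward` receives the inputs' lazydata buffers, its return value is wrapped in a new Tensor, and the ctx is attached only when gradients are needed (the class name and the `binary_op` call are assumptions, not part of this diff):

class Mul(Function):
  def forward(self, x, y):
    self.save_for_backward(x, y)            # keep the lazy buffers for the backward pass
    return x.binary_op(BinaryOps.MUL, y)    # assumed lazy-buffer op, for illustration
  def backward(self, grad_output):
    x, y = self.saved_tensors
    return y.binary_op(BinaryOps.MUL, grad_output), x.binary_op(BinaryOps.MUL, grad_output)

# register() below then exposes this as Tensor.mul, which calls Mul.apply(...)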
@@ -358,7 +341,7 @@ for device in [device for device in Device.__dict__.keys() if device[0] != "_"]:
   setattr(Tensor, f"{device.lower()}_", functools.partialmethod(Tensor.to_, Device.__dict__[device]))
 # register all the mlops "math" operations
-def register(name, fxn):
+def register(name:str, fxn:Function):
   def dispatch(*x, **kwargs): return fxn.apply(*x, **kwargs)   # TODO: there's probably a very pythonic thing to replace this with
   setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, dispatch)
 for name, cls in inspect.getmembers(importlib.import_module('tinygrad.mlops'), inspect.isclass):