mirror of https://github.com/tinygrad/tinygrad.git
Commit: less lines, and typing found a bug
@@ -148,11 +148,9 @@ class GPUBuffer:
     }"""
     elif ret.shape != bufs[0][1].shape:   # this is a reduce
       # reverse operation of expand, this validates inputs
-      st = ShapeTracker(ret.shape).movement_op(MovementOps.EXPAND, bufs[0][1].shape)
-
       # generate loops with combined adjacent reduce axis
       acc = 1
-      for shp,stride in st.views[-1].shape_strides[::-1]:
+      for shp,stride in ShapeTracker(ret.shape).movement_op(MovementOps.EXPAND, bufs[0][1].shape).views[-1].shape_strides[::-1]:
         if stride == 0: loop.append((f"for (int axis_{len(loop)} = 0; axis_{len(loop)} < {shp}; axis_{len(loop)}++) {{", f"idx += {acc}; }} idx -= {shp*acc};"))
         acc *= shp
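For context, this hunk builds the reduce loops of the generated GPU kernel: expanding the (smaller) output shape back to the input shape yields stride-0 axes, and each stride-0 axis becomes a C-style for loop that walks idx across the reduced dimension. A rough standalone sketch of that loop-generation idea (plain Python, not tinygrad code; shape_strides stands in for ShapeTracker(...).views[-1].shape_strides):

from typing import List, Tuple

def reduce_loops(shape_strides: List[Tuple[int, int]]) -> List[Tuple[str, str]]:
  loop: List[Tuple[str, str]] = []
  acc = 1
  for shp, stride in shape_strides[::-1]:   # walk axes innermost-first
    if stride == 0:                         # stride 0 => this axis was reduced
      loop.append((f"for (int axis_{len(loop)} = 0; axis_{len(loop)} < {shp}; axis_{len(loop)}++) {{",
                   f"idx += {acc}; }} idx -= {shp*acc};"))
    acc *= shp
  return loop

# e.g. summing a (4,8) buffer over its last axis: output (4,1) expanded back to
# (4,8) gives roughly shape_strides [(4,1), (8,0)], so one loop over the 0-stride axis
print(reduce_loops([(4, 1), (8, 0)]))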
@@ -61,22 +61,20 @@ class Tensor:
     self.lazydata = x.lazydata
     return x

-  def detach(self):
-    return Tensor(self.lazydata, device=self.device, requires_grad=False)
-
-  def numpy(self):
-    return np.array(self.lazydata.toCPU())
+  def detach(self): return Tensor(self.lazydata, device=self.device, requires_grad=False)
+  def numpy(self): return np.array(self.lazydata.toCPU())

   # TOOD: this keeps the legacy behavior working, remove it after refactor
   @property
-  def data(self):
-    return self.numpy()
+  def data(self): return self.numpy()

-  def to_(self, device):
-    self.device = device
-    if self.grad: self.grad.device = device
+  # TODO: if things are realized this won't work
+  def to_(self, device:str):
+    assert self.lazydata.realized is None
+    self.lazydata.device = device
+    if self.grad: self.grad.lazydata.device = device

-  def to(self, device):
+  def to(self, device:str):
     ret = Tensor(self.lazydata, device)
     if self.grad: ret.grad = self.grad.to(device)
     return ret
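A usage sketch of the two movement methods after this change, assuming a "GPU" backend is available and the buffer has not been realized yet (otherwise the new assert in to_ fires):

from tinygrad.tensor import Tensor

t = Tensor.randn(2, 2)      # created on the default device
g = t.to("GPU")             # new Tensor wrapping the same lazydata on another device
t.to_("GPU")                # in-place move of the lazydata (and of .grad, if any)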
@@ -86,28 +84,22 @@ class Tensor:
   # TODO: remove use of numpy here

   @classmethod
-  def zeros(cls, *shape, **kwargs):
-    return cls(np.zeros(shape, dtype=np.float32), **kwargs)
+  def zeros(cls, *shape, **kwargs): return cls(np.zeros(shape, dtype=np.float32), **kwargs)

   @classmethod
-  def ones(cls, *shape, **kwargs):
-    return cls(np.ones(shape, dtype=np.float32), **kwargs)
+  def ones(cls, *shape, **kwargs): return cls(np.ones(shape, dtype=np.float32), **kwargs)

   @classmethod
-  def randn(cls, *shape, **kwargs):
-    return cls(np.random.randn(*shape).astype(np.float32), **kwargs)
+  def randn(cls, *shape, **kwargs): return cls(np.random.randn(*shape).astype(np.float32), **kwargs)

   @classmethod
-  def arange(cls, stop, start=0, **kwargs):
-    return cls(np.arange(start=start, stop=stop).astype(np.float32), **kwargs)
+  def arange(cls, stop, start=0, **kwargs): return cls(np.arange(start=start, stop=stop).astype(np.float32), **kwargs)

   @classmethod
-  def uniform(cls, *shape, **kwargs):
-    return cls((np.random.uniform(-1., 1., size=shape)/np.sqrt(prod(shape))).astype(np.float32), **kwargs)
+  def uniform(cls, *shape, **kwargs): return cls((np.random.uniform(-1., 1., size=shape)/np.sqrt(prod(shape))).astype(np.float32), **kwargs)

   @classmethod
-  def eye(cls, dim, **kwargs):
-    return cls(np.eye(dim).astype(np.float32), **kwargs)
+  def eye(cls, dim, **kwargs): return cls(np.eye(dim).astype(np.float32), **kwargs)

   # ***** toposort and backward pass *****
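A quick usage sketch of the factory constructors above (all of them currently go through numpy and produce float32 data):

from tinygrad.tensor import Tensor

a = Tensor.zeros(2, 3)       # float32 zeros, shape (2, 3)
b = Tensor.ones(2, 3)
w = Tensor.uniform(3, 4)     # uniform in [-1, 1) scaled by 1/sqrt(3*4)
i = Tensor.eye(4)
r = Tensor.arange(10)        # 0..9 as float32
print(a.shape, r.numpy())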
@@ -159,6 +151,7 @@ class Tensor:
     ret = self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))])
     return ret.reshape(shape=new_shape) if tuple(ret.shape) != tuple(new_shape) else ret

+  # TODO: there has to be a cleaner way to write this
   def cat(self, *args, dim=0):
     dim = (dim + len(self.shape)) if dim < 0 else dim
     for y in args: assert len(self.shape) == len(y.shape)
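A usage sketch of cat as shown in this hunk; negative dims are canonicalized by adding len(self.shape):

from tinygrad.tensor import Tensor

x = Tensor.ones(2, 3)
y = Tensor.zeros(2, 3)
print(x.cat(y, dim=0).shape)    # (4, 3)
print(x.cat(y, dim=-1).shape)   # dim=-1 canonicalized to 1 -> (2, 6)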
@@ -181,10 +174,6 @@ class Tensor:
       ret += y.slice(arg=ts)
     return ret

-  def pad2d(self, padding:Tuple[int, ...]):
-    # (padding_left, padding_right, padding_top, padding_bottom)
-    return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
-
   def matmul(x:Tensor, w:Tensor):
     # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
     bs, groups = prod(x.shape[0:-2]), prod(w.shape[0:-2])
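The NOTE above can be checked numerically: an (m,k) @ (k,n) matmul is the same computation as a 1x1 conv2d with the input laid out as (1,k,m,1) and the weight as (n,k,1,1). A small numpy check (not tinygrad code):

import numpy as np

m, k, n = 3, 4, 5
x = np.random.randn(m, k).astype(np.float32)
w = np.random.randn(k, n).astype(np.float32)

cin = x.T.reshape(1, k, m, 1)        # (batch, in_channels, H, W)
cw  = w.T.reshape(n, k, 1, 1)        # (out_channels, in_channels, 1, 1)
# 1x1 conv: out[b,o,h,w] = sum_c cin[b,c,h,w] * cw[o,c,0,0]
out = np.einsum("bchw,oc->bohw", cin, cw[:, :, 0, 0])
assert np.allclose(out[0, :, :, 0].T, x @ w, atol=1e-5)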
@@ -204,11 +193,10 @@ class Tensor:
   # TODO: what's the difference between dot and matmul?
   dot = matmul

-  def transpose(self, order=(1,0)):
-    return self.permute(order=order)
-
-  def flatten(self, start_dim=0):
-    return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))
+  # (padding_left, padding_right, padding_top, padding_bottom)
+  def pad2d(self, padding:Tuple[int, ...]): return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
+  def transpose(self, order=(1,0)): return self.permute(order=order)
+  def flatten(self, start_dim=0): return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))

   def _canonicalize_reduce_axis(self, axis):
     if axis is None: axis = range(len(self.shape))
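A usage sketch of the relocated pad2d one-liner, assuming out-of-range slicing zero-fills the way Tensor.slice does; padding is (left, right, top, bottom) on the last two axes:

from tinygrad.tensor import Tensor

x = Tensor.ones(1, 1, 4, 4)
y = x.pad2d(padding=(1, 1, 2, 2))   # (left, right, top, bottom)
print(y.shape)                      # height 4+2+2=8, width 4+1+1=6 -> (1, 1, 8, 6)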
@@ -255,11 +243,8 @@ class Tensor:
     xup = self[:, :, :self.shape[2]-self.shape[2]%py, :self.shape[3]-self.shape[3]%px] if (self.shape[2]%py != 0) or (self.shape[3]%px != 0) else self
     return xup.reshape(shape=(xup.shape[0], xup.shape[1], xup.shape[2]//py, py, xup.shape[3]//px, px))

-  def avg_pool2d(self, kernel_size=(2,2)):
-    return self._pool2d(*kernel_size).mean(axis=(3,5))
-
-  def max_pool2d(self, kernel_size=(2,2)):
-    return self._pool2d(*kernel_size).max(axis=(3,5))
+  def avg_pool2d(self, kernel_size=(2,2)): return self._pool2d(*kernel_size).mean(axis=(3,5))
+  def max_pool2d(self, kernel_size=(2,2)): return self._pool2d(*kernel_size).max(axis=(3,5))

   def conv2d(self, weight, bias=None, **kwargs):
     ret = self._conv2d(weight, **kwargs)
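The _pool2d context lines show the pooling trick: crop H and W to multiples of the pool size, reshape to expose the py and px tile axes, then reduce over axes 3 and 5. A numpy illustration of the same reshape-and-reduce (not tinygrad code):

import numpy as np

x = np.arange(1*1*4*4, dtype=np.float32).reshape(1, 1, 4, 4)
py, px = 2, 2
tiles = x.reshape(1, 1, 4//py, py, 4//px, px)   # (N, C, H//py, py, W//px, px)
print(tiles.mean(axis=(3, 5)))                  # 2x2 average pooling
print(tiles.max(axis=(3, 5)))                   # 2x2 max pooling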
@@ -331,25 +316,23 @@ class Tensor:

 # An instantiation of the Function is the Context
 class Function:
-  def __init__(self, device, *tensors:Tensor):
-    self.device = device
-    self.parents = tensors
-    self.needs_input_grad = [t.requires_grad for t in tensors]
-    self.requires_grad = any(self.needs_input_grad) and not Tensor.no_grad
+  def __init__(self, device:str, *tensors:Tensor):
+    self.device, self.parents = device, tensors
+    self.needs_input_grad = [t.requires_grad for t in self.parents]
+    self.requires_grad = any(self.needs_input_grad)
     self.saved_tensors : List[Tensor] = []

   def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
   def backward(self, *args, **kwargs): raise NotImplementedError(f"backward not implemented for {type(self)}")

-  def save_for_backward(self, *x):
-    # NOTE: it doesn't hurt to save this since the ctx will be freed fast without grad
-    self.saved_tensors.extend(x)
+  # NOTE: it doesn't hurt to save this since the ctx will be freed fast without grad
+  def save_for_backward(self, *x): self.saved_tensors.extend(x)

   @classmethod
   def apply(cls, *x:Tensor, **kwargs):
     ctx = cls(x[0].device, *x)
     ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad)
-    if ctx.requires_grad: ret._ctx = ctx  # used by autograd engine
+    if ctx.requires_grad and not Tensor.no_grad: ret._ctx = ctx  # used by autograd engine
     return ret

 # register functions to move between devices
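The Function class is the autograd Context: __init__ records parents and grad requirements, forward/backward are overridden by each op, and apply wraps the result and attaches the ctx when gradients are needed. A standalone toy mirroring that pattern (not tinygrad's mlops API):

# ctx records parents and saved values in forward; backward reads them back
class Mul:
  def __init__(self, *parents): self.parents, self.saved = parents, []
  def forward(self, x, y):
    self.saved.extend([x, y])
    return x * y
  def backward(self, grad):
    x, y = self.saved
    return grad * y, grad * x   # dL/dx, dL/dy

ctx = Mul(3.0, 4.0)
out = ctx.forward(3.0, 4.0)     # 12.0
print(ctx.backward(1.0))        # (4.0, 3.0)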
@@ -358,7 +341,7 @@ for device in [device for device in Device.__dict__.keys() if device[0] != "_"]:
   setattr(Tensor, f"{device.lower()}_", functools.partialmethod(Tensor.to_, Device.__dict__[device]))

 # register all the mlops "math" operations
-def register(name, fxn):
+def register(name:str, fxn:Function):
   def dispatch(*x, **kwargs): return fxn.apply(*x, **kwargs)  # TODO: there's probably a very pythonic thing to replace this with
   setattr(Tensor, "_"+name if (getattr(Tensor, name, None) is not None) else name, dispatch)
 for name, cls in inspect.getmembers(importlib.import_module('tinygrad.mlops'), inspect.isclass):
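register attaches a dispatch wrapper for every class found in tinygrad.mlops onto Tensor, prefixing the name with "_" when Tensor already defines a method of that name. A standalone illustration of the same registration pattern (hypothetical Num/Relu stand-ins, not tinygrad code):

class Num:
  def __init__(self, v): self.v = v

class Relu:
  @classmethod
  def apply(cls, x): return Num(max(x.v, 0))   # dispatch passes the instance through

def register(name, fxn):
  def dispatch(*x, **kwargs): return fxn.apply(*x, **kwargs)
  setattr(Num, "_"+name if getattr(Num, name, None) is not None else name, dispatch)

register("relu", Relu)                 # no clash yet -> lands on Num.relu
register("relu", Relu)                 # name now taken -> lands on Num._relu
print(Num(-3).relu().v, Num(5)._relu().v)   # 0 5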