commit 484708da87 (parent b57c374164)
Author: geohotstan
Date:   2023-08-23 11:51:05 -07:00
Committed by: GitHub


@@ -155,8 +155,8 @@ class Tensor:
return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)
@staticmethod
- def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
-   return Tensor.full(tensor.shape, fill_value=fill_value, dtype=tensor.dtype if dtype is None else dtype, **kwargs)
+ def full_like(tensor, fill_value, **kwargs):
+   return Tensor.full(tensor.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", tensor.dtype), device=kwargs.pop("device", tensor.device), **kwargs)
@staticmethod
def zeros_like(tensor, **kwargs): return Tensor.full_like(tensor, 0, **kwargs)
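This hunk changes `full_like` so that both `dtype` and `device` default to the source tensor (previously only `dtype` did), while explicit keyword arguments still override via `kwargs.pop`. A minimal sketch of the resulting behavior, assuming a `"CPU"` device string is available on this machine:

```python
from tinygrad.tensor import Tensor

t = Tensor.rand(2, 3, device="CPU")        # source tensor on an explicit device

a = Tensor.full_like(t, 7)                 # inherits t.dtype and, with this patch, t.device
b = Tensor.zeros_like(t)                   # goes through full_like, so it follows t.device too
c = Tensor.full_like(t, 7, device="CPU")   # an explicit device/dtype kwarg still wins
print(a.device, b.device, c.device)        # expected: CPU CPU CPU
```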
@@ -327,11 +327,11 @@ class Tensor:
idx = [i.reshape(*[1]*(max_dim-i.ndim), *i.shape) for i in idx]
sum_dim = [d+max_dim-n for n,d in enumerate(dim)]
new_idx = idx[0].reshape(*[1]*dim[0], 1,*idx[0].shape, *[1]*(ret.ndim-dim[0]-1))
- arange = Tensor.arange(ret.shape[dim[0]], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*dim[0], ret.shape[dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-dim[0]-1))
+ arange = Tensor.arange(ret.shape[dim[0]], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*dim[0], ret.shape[dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-dim[0]-1))
ret = (ret.reshape(*ret.shape[:dim[0]+1], *[1]*idx[0].ndim, *ret.shape[dim[0]+1:]) * (arange == new_idx)).sum(dim[0])
for idx_,d in zip(idx[1:],sum_dim[1:]):
new_idx = idx_.reshape(*[1]*dim[0], *idx_.shape, *[1]*(ret.ndim-dim[0]-idx_.ndim))
- arange = Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*(d), ret.shape[d], *[1]*(ret.ndim-d-1))
+ arange = Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*(d), ret.shape[d], *[1]*(ret.ndim-d-1))
ret = ((new_idx == arange) * ret).sum(d)
if dim[0] != 0 and dim != list(range(dim[0], dim[-1]+1)) and len(dim) != 1: # special permute case
order = list(range(ret.ndim))
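These two `arange` calls sit in the tensor-indexing path of `__getitem__`; creating them on `self.device` keeps the internal one-hot masks on the same device as the tensor being indexed. A rough illustration of the call this path serves, again assuming a `"CPU"` device string:

```python
from tinygrad.tensor import Tensor

t = Tensor.rand(4, 4, device="CPU")
idx = Tensor([0, 2, 3], device="CPU")   # index tensor holding integral values

rows = t[idx]                           # internal aranges now live on t.device as well
print(rows.shape, rows.device)          # expected: (3, 4) CPU
```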
@@ -352,7 +352,7 @@ class Tensor:
idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
permarg = list(range(self.ndim))
permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
- return ((idx == Tensor.arange(self.shape[dim])) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)
+ return ((idx == Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)
def cat(self, *args, dim=0):
dim = (dim + len(self.shape)) if dim < 0 else dim
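The `gather` hunk applies the same fix: the `arange` used for the one-hot comparison is now pinned to `dtypes.int32`, `requires_grad=False`, and `self.device`. A small usage sketch; keyword arguments are used because only the parameter names `idx` and `dim` are visible in this hunk, not their order:

```python
from tinygrad.tensor import Tensor

t = Tensor([[1.0, 2.0], [3.0, 4.0]], device="CPU")
idx = Tensor([[0, 0], [1, 0]], device="CPU")

out = t.gather(idx=idx, dim=1)   # out[i][j] = t[i][idx[i][j]]
print(out.numpy())               # gathers along dim=1: [[1, 1], [4, 3]]
```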
@@ -443,11 +443,11 @@ class Tensor:
def argmax(self, axis=None, keepdim=False):
if axis is None:
- idx = (self == self.max(axis)) * Tensor.arange(math.prod(self.shape)-1,-1,-1).reshape(self.shape)
+ idx = (self == self.max(axis)) * Tensor.arange(math.prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
return math.prod(self.shape) - idx.max() - 1
axis = axis + len(self.shape) if axis < 0 else axis
m = self == self.max(axis=axis, keepdim=True)
- idx = m * Tensor.arange(self.shape[axis]-1,-1,-1).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
+ idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1
def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
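`argmax` breaks ties by multiplying the max mask with a reversed `arange`; both the flattened and per-axis variants now build that `arange` as `dtypes.int32`, `requires_grad=False`, on `self.device`. The user-facing behavior is unchanged; a quick sketch:

```python
from tinygrad.tensor import Tensor

x = Tensor([[3.0, 9.0, 1.0],
            [7.0, 2.0, 8.0]], device="CPU")

print(x.argmax().numpy())         # flattened index of the global max (9.0), i.e. 1
print(x.argmax(axis=1).numpy())   # per-row argmax: [1, 2]
print(x.argmin(axis=1).numpy())   # per-row argmin: [2, 1]
```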
@@ -535,8 +535,8 @@ class Tensor:
@staticmethod
def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
- def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype).where(self, Tensor.zeros_like(self))
- def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self)
+ def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
+ def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)
# ***** math functions (unary) *****
def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).contiguous().cast(self.dtype)
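`triu`/`tril` build their masks through `_tri`, which is itself a pair of `arange` calls, so forwarding `device=self.device` keeps the mask on the input's device. A short sketch, still assuming a `"CPU"` device string:

```python
from tinygrad.tensor import Tensor

m = Tensor.ones(3, 3, device="CPU")
print(m.triu(1).numpy())   # strictly upper-triangular part of m
print(m.tril().numpy())    # lower triangle of m, diagonal included
```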
@@ -669,17 +669,17 @@ class Tensor:
def dropout(self, p=0.5) -> Tensor:
if not Tensor.training: return self
- mask = (Tensor.rand(*self.shape, requires_grad=False) >= p).cast(dtypes.bool)
+ mask = (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p).cast(dtypes.bool)
return self * mask * (1/(1.0 - p))
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
- if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False).tril(0).cast(dtypes.bool)
+ if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
loss_mask = Y != ignore_index
- y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
+ y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
return self.log_softmax().mul(y).sum() / loss_mask.sum()
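The final hunk pins the remaining helper tensors to the input's device: the dropout mask, the causal attention mask, and the `y_counter` arange in `sparse_categorical_crossentropy` (which also gains an explicit `dtypes.int32`). A rough end-to-end sketch of the calls touched here, again assuming a `"CPU"` device string:

```python
from tinygrad.tensor import Tensor

q, k, v = [Tensor.rand(1, 4, 8, device="CPU") for _ in range(3)]
# the causal mask (and, under Tensor.training, the dropout mask) now follows q.device
out = q.scaled_dot_product_attention(k, v, is_causal=True)
print(out.shape)                  # expected: (1, 4, 8)

logits = Tensor.rand(4, 10, device="CPU")
labels = Tensor([1, 0, 3, -1], device="CPU")   # -1 entries are dropped via ignore_index
print(logits.sparse_categorical_crossentropy(labels).numpy())
```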