Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-25 14:58:46 -05:00)
@@ -155,8 +155,8 @@ class Tensor:
     return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)

   @staticmethod
-  def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
-    return Tensor.full(tensor.shape, fill_value=fill_value, dtype=tensor.dtype if dtype is None else dtype, **kwargs)
+  def full_like(tensor, fill_value, **kwargs):
+    return Tensor.full(tensor.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", tensor.dtype), device=kwargs.pop("device", tensor.device), **kwargs)

   @staticmethod
   def zeros_like(tensor, **kwargs): return Tensor.full_like(tensor, 0, **kwargs)
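The rewritten full_like drops the explicit dtype parameter and instead defaults both dtype and device from the source tensor via kwargs.pop, so an explicit keyword still wins. A minimal sketch of that defaulting pattern, using hypothetical names (Src, full_like_sketch) that are not tinygrad code:

# Sketch only: an explicit dtype=/device= keyword overrides the source's value,
# otherwise the value is inherited from the source object.
class Src:
  def __init__(self, dtype, device): self.dtype, self.device = dtype, device

def full_like_sketch(src, fill_value, **kwargs):
  dtype = kwargs.pop("dtype", src.dtype)      # explicit dtype= wins over src.dtype
  device = kwargs.pop("device", src.device)   # explicit device= wins over src.device
  return {"fill_value": fill_value, "dtype": dtype, "device": device, **kwargs}

print(full_like_sketch(Src("float32", "GPU"), 0))                # inherits float32 / GPU
print(full_like_sketch(Src("float32", "GPU"), 0, device="CPU"))  # explicit device wins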
@@ -327,11 +327,11 @@ class Tensor:
       idx = [i.reshape(*[1]*(max_dim-i.ndim), *i.shape) for i in idx]
       sum_dim = [d+max_dim-n for n,d in enumerate(dim)]
       new_idx = idx[0].reshape(*[1]*dim[0], 1,*idx[0].shape, *[1]*(ret.ndim-dim[0]-1))
-      arange = Tensor.arange(ret.shape[dim[0]], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*dim[0], ret.shape[dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-dim[0]-1))
+      arange = Tensor.arange(ret.shape[dim[0]], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*dim[0], ret.shape[dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-dim[0]-1))
       ret = (ret.reshape(*ret.shape[:dim[0]+1], *[1]*idx[0].ndim, *ret.shape[dim[0]+1:]) * (arange == new_idx)).sum(dim[0])
       for idx_,d in zip(idx[1:],sum_dim[1:]):
         new_idx = idx_.reshape(*[1]*dim[0], *idx_.shape, *[1]*(ret.ndim-dim[0]-idx_.ndim))
-        arange = Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*(d), ret.shape[d], *[1]*(ret.ndim-d-1))
+        arange = Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*(d), ret.shape[d], *[1]*(ret.ndim-d-1))
         ret = ((new_idx == arange) * ret).sum(d)
       if dim[0] != 0 and dim != list(range(dim[0], dim[-1]+1)) and len(dim) != 1: # special permute case
         order = list(range(ret.ndim))
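Context for this hunk: the fancy-indexing path gathers along a dimension as a one-hot select, comparing the index tensor against an arange over that dimension (broadcast to the right rank), multiplying, and summing the dimension away; the gather in the next hunk uses the same arange-comparison trick along a single dim. A small numpy sketch of the underlying trick (numpy only for illustration, not the tinygrad code):

import numpy as np

x = np.array([[10., 11., 12.],
              [20., 21., 22.]])
idx = np.array([2, 0])             # pick x[0, 2] and x[1, 0]

arange = np.arange(x.shape[1])     # shape (3,)
onehot = (idx[:, None] == arange)  # shape (2, 3): one-hot row per index
picked = (x * onehot).sum(axis=1)  # [12., 20.]
print(picked)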
@@ -352,7 +352,7 @@ class Tensor:
     idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
     permarg = list(range(self.ndim))
     permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
-    return ((idx == Tensor.arange(self.shape[dim])) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)
+    return ((idx == Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)

   def cat(self, *args, dim=0):
     dim = (dim + len(self.shape)) if dim < 0 else dim
@@ -443,11 +443,11 @@ class Tensor:

   def argmax(self, axis=None, keepdim=False):
     if axis is None:
-      idx = (self == self.max(axis)) * Tensor.arange(math.prod(self.shape)-1,-1,-1).reshape(self.shape)
+      idx = (self == self.max(axis)) * Tensor.arange(math.prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
       return math.prod(self.shape) - idx.max() - 1
     axis = axis + len(self.shape) if axis < 0 else axis
     m = self == self.max(axis=axis, keepdim=True)
-    idx = m * Tensor.arange(self.shape[axis]-1,-1,-1).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
+    idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
     return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1
   def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)

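Context for this hunk: argmax is built from existing ops by masking the positions of the maximum, multiplying by a reversed arange, and inverting, which yields the index of the first maximum. A small numpy sketch of why the formula works:

import numpy as np

x = np.array([3., 7., 7., 1.])
n = x.size
mask = (x == x.max())          # True at every maximum: [F, T, T, F]
rev = np.arange(n - 1, -1, -1) # reversed indices: [3, 2, 1, 0]
idx = (mask * rev).max()       # largest reversed index among the maxima -> 2
print(n - idx - 1)             # 1: index of the first maximum, matching the diff's formula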
@@ -535,8 +535,8 @@ class Tensor:

   @staticmethod
   def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
-  def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype).where(self, Tensor.zeros_like(self))
-  def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self)
+  def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
+  def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)

   # ***** math functions (unary) *****
   def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).contiguous().cast(self.dtype)
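Context for this hunk: _tri builds a boolean triangular mask by comparing a column of row indices against a row of column indices shifted by k, and triu/tril then where() between self and zeros_like(self). A small numpy sketch of the same mask construction:

import numpy as np

r, c, k = 3, 4, 0
rows = np.arange(r)[:, None]          # shape (r, 1), like arange(r).unsqueeze(1)
cols = np.arange(-k, c - k)[None, :]  # shape (1, c), like arange(-k, c-k).unsqueeze(0)
mask = rows <= cols                   # True on and above the k-th diagonal
x = np.arange(12, dtype=np.float64).reshape(r, c)
print(np.where(mask, x, 0.0))         # equivalent to triu(x, k)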
@@ -669,17 +669,17 @@ class Tensor:

   def dropout(self, p=0.5) -> Tensor:
     if not Tensor.training: return self
-    mask = (Tensor.rand(*self.shape, requires_grad=False) >= p).cast(dtypes.bool)
+    mask = (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p).cast(dtypes.bool)
     return self * mask * (1/(1.0 - p))

   def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
-    if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False).tril(0).cast(dtypes.bool)
+    if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
     if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
     return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value

   def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
     loss_mask = Y != ignore_index
-    y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
+    y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
     y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
     return self.log_softmax().mul(y).sum() / loss_mask.sum()

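The common thread across these hunks: every Tensor.arange, Tensor.rand, Tensor.ones and Tensor.full created inside these methods now passes device=self.device (plus an explicit int32 dtype for index helpers), so intermediates are allocated on the same backend as the input rather than the process-wide default device. A hedged usage sketch, assuming the tinygrad.tensor import path of this era and that a "GPU" backend is available:

from tinygrad.tensor import Tensor  # import path assumed for this version of tinygrad

Tensor.training = True
x = Tensor.rand(4, 4, device="GPU")  # input lives on a non-default device (assumes a "GPU" backend)

# Before this change, the arange/rand/ones helpers built inside argmax, dropout, etc.
# were created on the default device and could mismatch x; with device=self.device
# they follow x's device, so these calls stay on one backend end to end.
print(x.argmax(axis=1).numpy())
print(x.dropout(0.5).numpy())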