diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index cdc36181d8..435620b1a8 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -155,8 +155,8 @@ class Tensor:
     return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step)

   @staticmethod
-  def full_like(tensor, fill_value, dtype:Optional[DType]=None, **kwargs):
-    return Tensor.full(tensor.shape, fill_value=fill_value, dtype=tensor.dtype if dtype is None else dtype, **kwargs)
+  def full_like(tensor, fill_value, **kwargs):
+    return Tensor.full(tensor.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", tensor.dtype), device=kwargs.pop("device", tensor.device), **kwargs)
   @staticmethod
   def zeros_like(tensor, **kwargs): return Tensor.full_like(tensor, 0, **kwargs)
@@ -327,11 +327,11 @@ class Tensor:
       idx = [i.reshape(*[1]*(max_dim-i.ndim), *i.shape) for i in idx]
       sum_dim = [d+max_dim-n for n,d in enumerate(dim)]
       new_idx = idx[0].reshape(*[1]*dim[0], 1,*idx[0].shape, *[1]*(ret.ndim-dim[0]-1))
-      arange = Tensor.arange(ret.shape[dim[0]], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*dim[0], ret.shape[dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-dim[0]-1))
+      arange = Tensor.arange(ret.shape[dim[0]], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*dim[0], ret.shape[dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-dim[0]-1))
       ret = (ret.reshape(*ret.shape[:dim[0]+1], *[1]*idx[0].ndim, *ret.shape[dim[0]+1:]) * (arange == new_idx)).sum(dim[0])
       for idx_,d in zip(idx[1:],sum_dim[1:]):
         new_idx = idx_.reshape(*[1]*dim[0], *idx_.shape, *[1]*(ret.ndim-dim[0]-idx_.ndim))
-        arange = Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*(d), ret.shape[d], *[1]*(ret.ndim-d-1))
+        arange = Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*(d), ret.shape[d], *[1]*(ret.ndim-d-1))
         ret = ((new_idx == arange) * ret).sum(d)
       if dim[0] != 0 and dim != list(range(dim[0], dim[-1]+1)) and len(dim) != 1: # special permute case
         order = list(range(ret.ndim))
@@ -352,7 +352,7 @@ class Tensor:
     idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
     permarg = list(range(self.ndim))
     permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
-    return ((idx == Tensor.arange(self.shape[dim])) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)
+    return ((idx == Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim)

   def cat(self, *args, dim=0):
     dim = (dim + len(self.shape)) if dim < 0 else dim
@@ -443,11 +443,11 @@ class Tensor:
   def argmax(self, axis=None, keepdim=False):
     if axis is None:
-      idx = (self == self.max(axis)) * Tensor.arange(math.prod(self.shape)-1,-1,-1).reshape(self.shape)
+      idx = (self == self.max(axis)) * Tensor.arange(math.prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
       return math.prod(self.shape) - idx.max() - 1
     axis = axis + len(self.shape) if axis < 0 else axis
     m = self == self.max(axis=axis, keepdim=True)
-    idx = m * Tensor.arange(self.shape[axis]-1,-1,-1).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
+    idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
     return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1
   def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
@@ -535,8 +535,8 @@ class Tensor:
   @staticmethod
   def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
-  def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype).where(self, Tensor.zeros_like(self))
-  def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype).where(Tensor.zeros_like(self), self)
+  def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
+  def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)

   # ***** math functions (unary) *****

   def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).contiguous().cast(self.dtype)
@@ -669,17 +669,17 @@ class Tensor:
   def dropout(self, p=0.5) -> Tensor:
     if not Tensor.training: return self
-    mask = (Tensor.rand(*self.shape, requires_grad=False) >= p).cast(dtypes.bool)
+    mask = (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p).cast(dtypes.bool)
     return self * mask * (1/(1.0 - p))

   def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
-    if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False).tril(0).cast(dtypes.bool)
+    if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
     if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
     return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value

   def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
     loss_mask = Y != ignore_index
-    y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
+    y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
     y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
     return self.log_softmax().mul(y).sum() / loss_mask.sum()
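Not part of the patch itself: a minimal sketch of the behavior these changes target, namely that helper tensors created inside ops (arange, rand, ones, _tri) now inherit the source tensor's device instead of defaulting to the global default device. The "GPU" device string is an assumption for illustration; any available tinygrad backend behaves the same way.

```python
# Sketch only; assumes a "GPU" backend is available in this environment.
from tinygrad.tensor import Tensor

x = Tensor.rand(3, 4, device="GPU")

# full_like/zeros_like/ones_like now take dtype and device from `tensor`
# unless explicitly overridden via kwargs.
y = Tensor.zeros_like(x)
assert y.device == x.device and y.dtype == x.dtype

# Ops that build internal intermediates (argmax, triu/tril, dropout,
# scaled_dot_product_attention, ...) keep them on x.device, so no
# implicit cross-device transfers are introduced.
print(x.argmax(axis=1).device)  # expected: same device as x
```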