From fc6bca7ba8fc2c8cfc6c7aecd3e7d0a6f72a785a Mon Sep 17 00:00:00 2001
From: chenyu
Date: Wed, 13 Dec 2023 19:03:14 -0500
Subject: [PATCH] update type annotation of _broadcasted (#2753)

input can be Tensor, float, int.
also updated scaled_dot_product_attention that might add a None to a Tensor
---
 tinygrad/tensor.py | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 2ed79b8c57..762b64067e 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -727,8 +727,7 @@ class Tensor:

   # ***** broadcasted binary mlops *****

-  # TODO: y can be int here
-  def _broadcasted(self, y:Union[Tensor, float], reverse:bool=False) -> Tuple[Tensor, Tensor]:
+  def _broadcasted(self, y:Union[Tensor, float, int], reverse:bool=False) -> Tuple[Tensor, Tensor]:
     if not isinstance(y, Tensor):
       # make y a Tensor
       if 0 in self.shape: return self, self.full_like(y)
@@ -745,26 +744,25 @@ class Tensor:
     broadcasted_shape = tuple(max(xi, yi) for xi, yi in zip(x.shape, y.shape))
     return x.expand(broadcasted_shape), y.expand(broadcasted_shape)

-  # TODO: x can be int here
-  def _to_float(self, x:Union[Tensor, float]):
+  def _to_float(self, x:Union[Tensor, float, int]):
     return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_contiguous_const() \
       and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x

-  def add(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+  def add(self, x:Union[Tensor, float, int], reverse=False) -> Tensor:
     x = self._to_float(x)
     return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self
-  def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+  def sub(self, x:Union[Tensor, float, int], reverse=False) -> Tensor:
     x = self._to_float(x)
     return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else (-self if reverse else self)
-  def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+  def mul(self, x:Union[Tensor, float, int], reverse=False) -> Tensor:
     x = self._to_float(x)
     if x.__class__ is not Tensor and x == 0.0: return mlops.Zero.apply(self)
     if x.__class__ is not Tensor and x == -1.0: return -self
     return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
-  def div(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+  def div(self, x:Union[Tensor, float, int], reverse=False) -> Tensor:
     x = self._to_float(x)
     return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x)  # noqa: E501
-  def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+  def pow(self, x:Union[Tensor, float, int], reverse=False) -> Tensor:
     x = self._to_float(x)
     if not isinstance(x, Tensor) and not reverse:
       # simple pow identities
@@ -857,7 +855,8 @@ class Tensor:
     assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
     if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
     if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
-    return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
+    qk = self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1])
+    return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value

   def binary_crossentropy(self, y:Tensor) -> Tensor:
     return (-y*self.log() - (1-y)*(1-self).log()).mean()
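
A minimal usage sketch of the code paths this patch touches, assuming tinygrad is installed; the tensor shapes and values below are illustrative only, not taken from the patch.

# exercise the widened Union[Tensor, float, int] annotations on the binary ops
from tinygrad.tensor import Tensor

t = Tensor([[1.0, 2.0], [3.0, 4.0]])
print((t + 1).numpy())     # int scalar, broadcast against t via _broadcasted
print(t.mul(2.0).numpy())  # float scalar path is unchanged

# scaled_dot_product_attention: with the patch, attn_mask is only added to
# q @ k^T when a mask actually exists (is_causal=True or an explicit attn_mask)
q, k, v = [Tensor.randn(1, 4, 8) for _ in range(3)]
out = q.scaled_dot_product_attention(k, v)                        # attn_mask=None path
out_causal = q.scaled_dot_product_attention(k, v, is_causal=True) # causal mask path
print(out.shape, out_causal.shape)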