update type annotation of _broadcasted (#2753)

The input can be a Tensor, float, or int.
Also updated scaled_dot_product_attention, which could otherwise add None to a Tensor when no attention mask is given.
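
For reference, a minimal sketch of the scalar-operand calls the widened annotation now describes (assumes the usual tinygrad.tensor import; not part of this commit):

from tinygrad.tensor import Tensor

t = Tensor([1.0, 2.0, 3.0])
print((t + 2).numpy())  # plain int operand flows through add -> _broadcasted
print((t * 2).numpy())  # same path for mul
print((2 - t).numpy())  # reverse=True path in sub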
Author: chenyu
Date: 2023-12-13 19:03:14 -05:00 (committed by GitHub)
Parent: bf4165ccac
Commit: fc6bca7ba8


@@ -727,8 +727,7 @@ class Tensor:
  # ***** broadcasted binary mlops *****
-  # TODO: y can be int here
-  def _broadcasted(self, y:Union[Tensor, float], reverse:bool=False) -> Tuple[Tensor, Tensor]:
+  def _broadcasted(self, y:Union[Tensor, float, int], reverse:bool=False) -> Tuple[Tensor, Tensor]:
    if not isinstance(y, Tensor):
      # make y a Tensor
      if 0 in self.shape: return self, self.full_like(y)
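
A small sketch of the wrapping/broadcast behavior this method provides (shapes are illustrative; not part of the diff):

from tinygrad.tensor import Tensor

x = Tensor.ones(3, 1)
y = Tensor.ones(1, 4)
print((x + y).shape)  # both sides expand to the elementwise max shape, (3, 4)
print((x + 5).shape)  # a Python scalar is wrapped in a Tensor first; result stays (3, 1)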
@@ -745,26 +744,25 @@ class Tensor:
    broadcasted_shape = tuple(max(xi, yi) for xi, yi in zip(x.shape, y.shape))
    return x.expand(broadcasted_shape), y.expand(broadcasted_shape)
-  # TODO: x can be int here
-  def _to_float(self, x:Union[Tensor, float]):
+  def _to_float(self, x:Union[Tensor, float, int]):
    return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_contiguous_const() \
      and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
-  def add(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+  def add(self, x:Union[Tensor, float, int], reverse=False) -> Tensor:
    x = self._to_float(x)
    return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self
-  def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+  def sub(self, x:Union[Tensor, float, int], reverse=False) -> Tensor:
    x = self._to_float(x)
    return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else (-self if reverse else self)
-  def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+  def mul(self, x:Union[Tensor, float, int], reverse=False) -> Tensor:
    x = self._to_float(x)
    if x.__class__ is not Tensor and x == 0.0: return mlops.Zero.apply(self)
    if x.__class__ is not Tensor and x == -1.0: return -self
    return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
-  def div(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+  def div(self, x:Union[Tensor, float, int], reverse=False) -> Tensor:
    x = self._to_float(x)
    return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x) # noqa: E501
-  def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+  def pow(self, x:Union[Tensor, float, int], reverse=False) -> Tensor:
    x = self._to_float(x)
    if not isinstance(x, Tensor) and not reverse:
      # simple pow identities
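
The int operands covered by the new annotation also reach the scalar fast paths above; a hedged sketch of what those shortcuts do for a float Tensor (not part of the diff):

from tinygrad.tensor import Tensor

t = Tensor([1.0, 2.0, 3.0])
print((t * 1).numpy())   # x == 1 short-circuits and returns self
print((t * 0).numpy())   # x == 0 goes through mlops.Zero
print((t * -1).numpy())  # x == -1 becomes plain negation
print((t / 2).numpy())   # float dtype, non-reverse div is rewritten as t.mul(1/2)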
@@ -857,7 +855,8 @@ class Tensor:
    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
    if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
    if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
-    return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
+    qk = self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1])
+    return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value
  def binary_crossentropy(self, y:Tensor) -> Tensor:
    return (-y*self.log() - (1-y)*(1-self).log()).mean()
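
A minimal usage sketch of the changed attention path (shapes are illustrative; not part of this commit):

from tinygrad.tensor import Tensor

q, k, v = [Tensor.rand(1, 4, 8, 16) for _ in range(3)]  # (batch, heads, seq, head_dim)
out = q.scaled_dot_product_attention(k, v)               # attn_mask stays None, so qk is used directly
out_causal = q.scaled_dot_product_attention(k, v, is_causal=True)  # causal mask built internally via tril
print(out.shape, out_causal.shape)                       # both (1, 4, 8, 16)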