Mirror of https://github.com/tinygrad/tinygrad.git
update type annotation of _broadcasted (#2753)
The input can be a Tensor, float, or int. Also updated scaled_dot_product_attention, which could otherwise add None to a Tensor.
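
A minimal usage sketch (not from the diff) of what the widened annotations cover, assuming a tinygrad checkout from around this commit is importable:

from tinygrad.tensor import Tensor

t = Tensor([1.0, 2.0, 3.0])
# int arguments were already accepted at runtime; the annotations now document it
print(t.add(2).numpy())  # [3. 4. 5.]
print(t.div(2).numpy())  # [0.5 1.  1.5]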
@@ -727,8 +727,7 @@ class Tensor:
   # ***** broadcasted binary mlops *****
 
-  # TODO: y can be int here
-  def _broadcasted(self, y:Union[Tensor, float], reverse:bool=False) -> Tuple[Tensor, Tensor]:
+  def _broadcasted(self, y:Union[Tensor, float, int], reverse:bool=False) -> Tuple[Tensor, Tensor]:
     if not isinstance(y, Tensor):
       # make y a Tensor
       if 0 in self.shape: return self, self.full_like(y)
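
As an aside, a hedged sketch of what _broadcasted does with a non-Tensor operand; it is a private helper, so this is for illustration only and the return shapes are assumed from the code above:

from tinygrad.tensor import Tensor

a = Tensor([[1.0, 2.0, 3.0]])   # shape (1, 3)
x, y = a._broadcasted(2)        # the int 2 is wrapped in a Tensor, then both sides are expanded
print(x.shape, y.shape)         # expected: (1, 3) (1, 3)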
@@ -745,26 +744,25 @@ class Tensor:
     broadcasted_shape = tuple(max(xi, yi) for xi, yi in zip(x.shape, y.shape))
     return x.expand(broadcasted_shape), y.expand(broadcasted_shape)
 
-  # TODO: x can be int here
-  def _to_float(self, x:Union[Tensor, float]):
+  def _to_float(self, x:Union[Tensor, float, int]):
     return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_contiguous_const() \
       and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
 
-  def add(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+  def add(self, x:Union[Tensor, float, int], reverse=False) -> Tensor:
     x = self._to_float(x)
     return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self
-  def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+  def sub(self, x:Union[Tensor, float, int], reverse=False) -> Tensor:
     x = self._to_float(x)
     return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else (-self if reverse else self)
-  def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+  def mul(self, x:Union[Tensor, float, int], reverse=False) -> Tensor:
     x = self._to_float(x)
     if x.__class__ is not Tensor and x == 0.0: return mlops.Zero.apply(self)
     if x.__class__ is not Tensor and x == -1.0: return -self
     return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
-  def div(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+  def div(self, x:Union[Tensor, float, int], reverse=False) -> Tensor:
     x = self._to_float(x)
     return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x) # noqa: E501
-  def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
+  def pow(self, x:Union[Tensor, float, int], reverse=False) -> Tensor:
     x = self._to_float(x)
     if not isinstance(x, Tensor) and not reverse:
       # simple pow identities
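
The mul fast paths above can be exercised directly; a small sketch, assuming the identity check (mul by 1 returning self unchanged) behaves as in the code shown here:

from tinygrad.tensor import Tensor

t = Tensor([1.0, 2.0, 3.0])
print(t.mul(0).numpy())   # x == 0.0 path -> mlops.Zero: [0. 0. 0.]
print(t.mul(-1).numpy())  # x == -1.0 path -> negation: [-1. -2. -3.]
print(t.mul(1) is t)      # x != 1.0 is False -> returns self: True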
@@ -857,7 +855,8 @@ class Tensor:
     assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
     if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
     if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0)
-    return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
+    qk = self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1])
+    return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value
 
   def binary_crossentropy(self, y:Tensor) -> Tensor:
     return (-y*self.log() - (1-y)*(1-self).log()).mean()
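
A short sketch of the scaled_dot_product_attention change from the caller's side, assuming the signature (key, value, attn_mask=None, dropout_p=0.0, is_causal=False) and that Tensor.randn is available; with attn_mask=None the qk + attn_mask addition is now skipped instead of adding None to a Tensor:

from tinygrad.tensor import Tensor

q, k, v = [Tensor.randn(1, 4, 8) for _ in range(3)]
out = q.scaled_dot_product_attention(k, v)                          # no mask: the add is skipped
out_causal = q.scaled_dot_product_attention(k, v, is_causal=True)   # causal mask built internally
print(out.shape, out_causal.shape)                                  # (1, 4, 8) (1, 4, 8)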