sparse cat cross entropy (#1591)

* add sparse cat cross entropy

* minor fix

* add log_softmax into loss function

* add test

* update docs
This commit is contained in:
Yixiang Gao
2023-08-21 11:56:41 -05:00
committed by GitHub
parent 8d6662a741
commit f0ee850e98
5 changed files with 35 additions and 28 deletions

View File

@@ -708,6 +708,12 @@ class Tensor:
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
return (self @ key.transpose(-2,-1) / sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
    """Cross-entropy loss of `self` against the sparse labels `Y`.

    `self` is fed through `log_softmax()`, so it is treated as unnormalized
    scores with one entry per class along its last axis. Each label in `Y`
    is matched against `arange(self.shape[-1])` to pick the class whose
    log-probability contributes to the loss.

    Labels equal to `ignore_index` contribute nothing to the sum and are not
    counted in the divisor. There is no guard for the case where every label
    is ignored (the divisor is then zero).

    NOTE(review): assumes `Y.shape` matches `self.shape[:-1]` and that
    `log_softmax()` normalizes over the last axis -- confirm against callers.
    """
    num_classes = self.shape[-1]
    # True where the label takes part in the loss.
    keep = Y != ignore_index
    # One row of class ids 0..num_classes-1 per label.
    class_ids = Tensor.arange(num_classes, requires_grad=False).unsqueeze(0).expand(Y.numel(), num_classes)
    labels_col = Y.flatten().reshape(-1, 1)
    # -1.0 at the labelled class, 0 elsewhere: multiplying by log-probs negates them.
    neg_one_hot = (class_ids == labels_col).where(-1.0, 0)
    # Zero out whole rows that belong to ignored labels, then restore Y's layout.
    weights = (neg_one_hot * keep.reshape(-1, 1)).reshape(*Y.shape, num_classes)
    return self.log_softmax().mul(weights).sum() / keep.sum()
# ***** cast ops *****
def cast(self, dtype:DType) -> Tensor:
    """Return this tensor converted to `dtype`.

    When the tensor already has `dtype`, `self` is returned as-is and no
    cast op is applied; otherwise the result of `mlops.Cast.apply` is returned.
    """
    if self.dtype != dtype:
        return mlops.Cast.apply(self, dtype=dtype)
    return self