Test nllloss (#958)

* works but slow

* work with NC and NCd1 it still slow

* refactor

* support for k dimensions

* without numpy
This commit is contained in:
Steven Anderson
2023-06-09 09:00:29 -07:00
committed by GitHub
parent 6b1280f01c
commit c0e558b77c

View File

@@ -230,3 +230,21 @@ def MeanVarianceNormalization(input, axis=(0, 2, 3)):
data_mean = input.mean(axis=axis, keepdim=True)
std = ((input**2).mean(axis=axis, keepdim=True) - data_mean**2).sqrt()
return (input - data_mean) / (std + 1e-9)
def NegativeLogLikelihoodLoss(input, target, weight=None, ignore_index=None, reduction="mean"):
N, C, i_shape = input.shape[0], input.shape[1], input.shape
t_shape = target.shape
if len(input.shape) != 3:
input = input.reshape((N, C, -1))
target = target.reshape((N, -1))
if weight is not None:
mask = target.unsqueeze(-1) == Tensor.arange(C,dtype=dtypes.int64).repeat((N, 1, 1))
weight = (mask * weight).sum(axis=-1)
if ignore_index is not None:
cond = (target == ignore_index)
weight = cond.where(0, weight) if weight is not None else cond.where(Tensor.zeros(*target.shape), 1)
mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(input.shape) -2))
loss = (-mask * input).sum(axis=1) * (1 if weight is None else weight)
if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum()
elif reduction == "sum": return loss.sum()
return loss.reshape(t_shape) if len(i_shape) != 3 else loss