Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-07 03:00:26 -04:00
Test nllloss (#958)
* works, but slow
* works with NC and NCd1, still slow
* refactor
* support for k dimensions
* without numpy
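For context, the ONNX NegativeLogLikelihoodLoss operator picks -input at the target class index for every batch/spatial position, optionally scaled by a per-class weight. A naive loop-based restatement of the unreduced NCd1 case (an illustrative sketch, not code from this commit):

def nll_reference(input, target, weight=None):
  # loss[n, d] = -input[n, target[n, d], d], optionally scaled by weight[target[n, d]]
  N, D = len(target), len(target[0])
  loss = [[0.0] * D for _ in range(N)]
  for n in range(N):
    for d in range(D):
      c = target[n][d]
      loss[n][d] = -input[n][c][d] * (1.0 if weight is None else weight[c])
  return loss

# N=1, C=2, d1=2: first position targets class 0, second targets class 1
print(nll_reference([[[0.1, 0.2], [0.3, 0.4]]], [[0, 1]]))  # [[-0.1, -0.4]]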
@@ -230,3 +230,21 @@ def MeanVarianceNormalization(input, axis=(0, 2, 3)):
  data_mean = input.mean(axis=axis, keepdim=True)
  std = ((input**2).mean(axis=axis, keepdim=True) - data_mean**2).sqrt()
  return (input - data_mean) / (std + 1e-9)

def NegativeLogLikelihoodLoss(input, target, weight=None, ignore_index=None, reduction="mean"):
  N, C, i_shape = input.shape[0], input.shape[1], input.shape
  t_shape = target.shape
  if len(input.shape) != 3:
    input = input.reshape((N, C, -1))
    target = target.reshape((N, -1))
  if weight is not None:
    mask = target.unsqueeze(-1) == Tensor.arange(C, dtype=dtypes.int64).repeat((N, 1, 1))
    weight = (mask * weight).sum(axis=-1)
  if ignore_index is not None:
    cond = (target == ignore_index)
    weight = cond.where(0, weight) if weight is not None else cond.where(Tensor.zeros(*target.shape), 1)
  mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(input.shape) - 2))
  loss = (-mask * input).sum(axis=1) * (1 if weight is None else weight)
  if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum()
  elif reduction == "sum": return loss.sum()
  return loss.reshape(t_shape) if len(i_shape) != 3 else loss
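The new handler avoids numpy fancy indexing entirely: instead of gathering input[n, target[n, d], d], it builds a one-hot mask over the class axis with Tensor.arange and reduces that axis away, which is what lets a single code path cover NC, NCd1, and flattened k-dimensional inputs. A minimal sketch of that masking trick in isolation (my illustration, not part of the commit):

from tinygrad.tensor import Tensor

N, C, D = 2, 4, 3
scores = Tensor.randn(N, C, D)           # stand-in log-probabilities, shape (N, C, d1)
target = Tensor([[0, 2, 3], [1, 1, 0]])  # class index per (n, d1) position

# one-hot mask over the class axis: 1 where class == target, shape (N, C, D)
mask = target[:, None, :] == Tensor.arange(C).reshape(1, C, 1)
# selecting -scores at the target class becomes a masked sum over axis 1
loss = (-mask * scores).sum(axis=1)      # per-position NLL, shape (N, D)
print(loss.mean().numpy())               # "mean" reduction with no class weights

The weight and ignore_index branches in the diff reuse the same idea: a comparison against Tensor.arange(C) turns the per-class weight lookup into a multiply-and-sum, and cond.where(...) zeroes out positions whose target equals ignore_index.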