log_perplexity metrics (#10912)

chenyu
2025-06-21 10:44:47 -04:00
committed by GitHub
parent 0e7bd9fd03
commit 0480139def
2 changed files with 34 additions and 4 deletions

examples/mlperf/metrics.py

@@ -1,6 +1,6 @@
-import re
-import string
+import re, string
 from collections import Counter
+from tinygrad import Tensor

 def levenshtein(a, b):
   n, m = len(a), len(b)
@@ -59,3 +59,11 @@ def f1_score(x, y):
   p = ns / len(xt)
   r = ns / len(yt)
   return 2 * p * r / (p + r)
+
+def log_perplexity(logit:Tensor, target:Tensor, ignore_index:int|None=None):
+  # logit has shape (n_samples, seq_len, vocab_size), target has shape (n_samples, seq_len)
+  assert logit.ndim == 3, logit.ndim
+  assert target.ndim == 2, target.ndim
+  assert logit.shape[:2] == target.shape, f"{logit.shape[:2]=}, {target.shape=}"
+  log_prob = logit.log_softmax(axis=-1)
+  return log_prob.transpose(1, 2).nll_loss(target, ignore_index=ignore_index)
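
For reference (not part of the diff): nll_loss(log_softmax(x)) is exactly cross-entropy, so log_perplexity should agree with torch's mean cross-entropy over the same logits. A minimal cross-check sketch, using arbitrary random inputs for illustration:

import numpy as np
import torch
from tinygrad import Tensor
from examples.mlperf.metrics import log_perplexity

logits = np.random.rand(2, 5, 16).astype(np.float32)  # (n_samples, seq_len, vocab_size)
targets = np.random.randint(0, 16, size=(2, 5))       # (n_samples, seq_len)

ours = log_perplexity(Tensor(logits), Tensor(targets)).item()
# torch's cross_entropy expects the class axis at dim 1, hence the transpose
ref = torch.nn.functional.cross_entropy(torch.tensor(logits).transpose(1, 2), torch.tensor(targets)).item()
np.testing.assert_allclose(ours, ref, rtol=1e-5)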

test/external/external_test_metrics.py

@@ -1,10 +1,10 @@
 from tinygrad import Tensor
 from test.external.mlperf_unet3d.dice import DiceScore
-from examples.mlperf.metrics import dice_score
+from examples.mlperf.metrics import dice_score, log_perplexity

 import numpy as np
 import torch
-import unittest
+import unittest, math

 class ExternalTestMetrics(unittest.TestCase):
   def _test_metrics(self, tinygrad_metrics, orig_metrics, pred, label, atol=1e-8, rtol=1e-7):
@@ -16,5 +16,27 @@ class ExternalTestMetrics(unittest.TestCase):
     pred, label = np.random.rand(1, 3, 128, 128, 128).astype(np.float32), np.ones((1, 1, 128, 128, 128)).astype(np.uint8)
     self._test_metrics(dice_score, DiceScore(), pred, label)
+
+  def test_log_perplexity(self):
+    # all classes equally likely -> loss is log(vocab_size)
+    np.testing.assert_allclose(log_perplexity(Tensor([[[1.0, 1, 1, 1]]]), Tensor([[2]])).numpy(), math.log(4))
+    np.testing.assert_allclose(log_perplexity(Tensor([[[1.0]*256]*32]), Tensor([[2]*32])).numpy(), math.log(256), rtol=1e-6)
+    # confidently correct and confidently incorrect predictions
+    np.testing.assert_allclose(log_perplexity(Tensor([[[10000., 0, 0, 0]]]), Tensor([[0]])).numpy(), 0)
+    np.testing.assert_allclose(log_perplexity(Tensor([[[0.0, 10000, 10000, 10000]]]), Tensor([[0]])).numpy(), 10000, rtol=1e-3)
+    # higher logit on the target class -> lower loss
+    x = Tensor([[[4.0, 3, 2, 1]]])
+    for i in range(x.numel()-1): self.assertLess(log_perplexity(x, Tensor([[i]])).item(), log_perplexity(x, Tensor([[i+1]])).item())
+    # torcheval examples
+    np.testing.assert_allclose(
+      log_perplexity(Tensor([[[0.3659, 0.7025, 0.3104], [0.0097, 0.6577, 0.1947]]]), Tensor([[2, 1]])).exp().numpy(),
+      2.7593, rtol=1e-5)
+    np.testing.assert_allclose(
+      log_perplexity(Tensor([[[0.3, 0.7, 0.3, 0.1], [0.5, 0.4, 0.1, 0.4], [0.1, 0.1, 0.2, 0.5]],
+                             [[0.1, 0.6, 0.1, 0.5], [0.3, 0.7, 0.3, 0.4], [0.3, 0.7, 0.3, 0.4]]]), Tensor([[2, 1, 3], [1, 0, 1]])).exp().numpy(),
+      3.6216, rtol=1e-5)
+    np.testing.assert_allclose(
+      log_perplexity(Tensor([[[0.3659, 0.7025, 0.3104], [0.0097, 0.6577, 0.1947]]]), Tensor([[2, 1]]), ignore_index=1).exp().numpy(),
+      3.5372, rtol=1e-4)

 if __name__ == '__main__':
   unittest.main()
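
The reference constants in the test (2.7593, 3.6216, 3.5372) appear to come from torcheval's Perplexity examples; torcheval reports exp of the mean negative log-likelihood, which is why the test applies .exp() before comparing. A quick reproduction sketch, assuming torcheval is installed:

import torch
from torcheval.metrics.functional import perplexity

logits = torch.tensor([[[0.3659, 0.7025, 0.3104], [0.0097, 0.6577, 0.1947]]])
target = torch.tensor([[2, 1]])
print(perplexity(logits, target))                  # ~2.7593
print(perplexity(logits, target, ignore_index=1))  # ~3.5372, second position masked out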