Add mlperf RNN-T model (#782)

* feat: initial rnn-t

* feat: working with BS>1

* feat: add lstm test

* feat: test passing hidden

* clean: cleanup

* feat: specify start

* feat: way faster lstm & model

* fix: default batch size

* feat: optimization

* fix: fix metrics

* fix: fix feature splicing

* feat: cleaner stacktime

* clean: remove unused import

* clean: remove extra prints

* fix: fix tests and happy llvm

* feat: have the librispeech dataset in its own dir

* clean: unused variable

* feat: no longer need numpy for the embedding + slightly more memory efficient lstm

* fix: forgot to remove something that broke tests

* feat: use relative paths

* feat: even faster

* feat: remove pointless transposes in StackTime

* fix: correct forward

* feat: switch to soundfile for loading and fix some leaks

* feat: add comment about initial dataset setup

* feat: jit more things

* feat: default batch size back to 1

larger than 1 is broken again :(
and even in the reference implementation it gives worse results
wozeparrot committed 2023-05-25 03:41:21 -04:00 (committed by GitHub)
parent b258af117a, commit 01ae45a43c
7 changed files with 405 additions and 0 deletions

.gitignore (1 addition)

@@ -21,3 +21,4 @@ disassemblers/applegpu
disassemblers/cuda_ioctl_sniffer
*.prof
datasets/cifar-10-python.tar.gz
datasets/librispeech/

datasets/librispeech.py (new file, 82 lines)

@@ -0,0 +1,82 @@
import json
import pathlib
import numpy as np
import librosa
import soundfile

"""
The dataset has to be downloaded manually from https://www.openslr.org/12/ and put in `datasets/librispeech`.
For mlperf validation the dev-clean dataset is used.
Then all the flacs have to be converted to wav using something like:
```fish
for file in **/*.flac; ffmpeg -i $file -ar 16k "$(dirname $file)/$(basename $file .flac).wav"; end
```
Then this [file](https://github.com/mlcommons/inference/blob/master/speech_recognition/rnnt/dev-clean-wav.json) also has to be put in `datasets/librispeech`.
"""

BASEDIR = pathlib.Path(__file__).parent.parent / "datasets/librispeech"
with open(BASEDIR / "dev-clean-wav.json") as f:
  ci = json.load(f)

FILTER_BANK = np.expand_dims(librosa.filters.mel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000), 0)
WINDOW = librosa.filters.get_window("hann", 320)

def feature_extract(x, x_lens):
  x_lens = np.ceil((x_lens / 160) / 3).astype(np.int32)

  # pre-emphasis
  x = np.concatenate((np.expand_dims(x[:, 0], 1), x[:, 1:] - 0.97 * x[:, :-1]), axis=1)

  # stft
  x = librosa.stft(x, n_fft=512, window=WINDOW, hop_length=160, win_length=320, center=True, pad_mode="reflect")
  x = np.stack((x.real, x.imag), axis=-1)

  # power spectrum
  x = (x**2).sum(-1)

  # mel filter bank
  x = np.matmul(FILTER_BANK, x)

  # log
  x = np.log(x + 1e-20)

  # feature splice
  seq = [x]
  for i in range(1, 3):
    tmp = np.zeros_like(x)
    tmp[:, :, :-i] = x[:, :, i:]
    seq.append(tmp)
  features = np.concatenate(seq, axis=1)[:, :, ::3]

  # normalize
  features_mean = np.zeros((features.shape[0], features.shape[1]), dtype=np.float32)
  features_std = np.zeros((features.shape[0], features.shape[1]), dtype=np.float32)
  for i in range(features.shape[0]):
    features_mean[i, :] = features[i, :, :x_lens[i]].mean(axis=1)
    features_std[i, :] = features[i, :, :x_lens[i]].std(axis=1, ddof=1)
  features_std += 1e-5
  features = (features - np.expand_dims(features_mean, 2)) / np.expand_dims(features_std, 2)

  return features.transpose(2, 0, 1), x_lens.astype(np.float32)

def load_wav(file):
  sample = soundfile.read(file)[0].astype(np.float32)
  return sample, sample.shape[0]

def iterate(bs=1, start=0):
  print(f"there are {len(ci)} samples in the dataset")
  for i in range(start, len(ci), bs):
    samples, sample_lens = zip(*[load_wav(BASEDIR / v["files"][0]["fname"]) for v in ci[i : i + bs]])
    samples = list(samples)

    # pad to same length
    max_len = max(sample_lens)
    for j in range(len(samples)):
      samples[j] = np.pad(samples[j], (0, max_len - sample_lens[j]), "constant")
    samples, sample_lens = np.array(samples), np.array(sample_lens)

    yield feature_extract(samples, sample_lens), np.array([v["transcript"] for v in ci[i : i + bs]])

if __name__ == "__main__":
  X, Y = next(iterate())
  print(X[0].shape, Y.shape)
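
For intuition about what `feature_extract` produces, here is a quick shape check (a sketch, not part of the commit; it assumes the dataset is already set up so the module imports cleanly):

```python
import numpy as np
from datasets.librispeech import feature_extract

# one 1 s clip at 16 kHz -> 16000 samples
x = np.random.randn(1, 16000).astype(np.float32)
feats, lens = feature_extract(x, np.array([16000]))
# hop_length=160 gives ~100 STFT frames; splicing keeps every 3rd frame,
# each carrying 3 * 80 = 240 mel features, so T = ceil(16000 / 160 / 3) = 34
print(feats.shape, lens)  # (34, 1, 240) [34.]
```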

examples/mlperf/metrics.py (new file, 25 lines)

@@ -0,0 +1,25 @@
def levenshtein(a, b):
  n, m = len(a), len(b)
  if n > m:
    a, b, n, m = b, a, m, n

  current = list(range(n + 1))
  for i in range(1, m + 1):
    previous, current = current, [i] + [0] * n
    for j in range(1, n + 1):
      add, delete = previous[j] + 1, current[j - 1] + 1
      change = previous[j - 1]
      if a[j - 1] != b[i - 1]:
        change = change + 1
      current[j] = min(add, delete, change)

  return current[n]

def word_error_rate(x, y):
  scores = words = 0
  for h, r in zip(x, y):
    h_list = h.split()
    r_list = r.split()
    words += len(r_list)
    scores += levenshtein(h_list, r_list)
  return float(scores) / words, float(scores), words
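
A quick illustration of the return values (not part of the commit):

```python
# one substitution ("sat" vs "sit") against a 3-word reference
wer, errors, words = word_error_rate(["the cat sat"], ["the cat sit"])
print(wer, errors, words)  # 0.3333... 1.0 3
```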

examples/mlperf/model_eval.py (26 additions)

@@ -43,4 +43,30 @@ if __name__ == "__main__":
print(f"****** {n}/{d} {n*100.0/d:.2f}%")
st = time.perf_counter()
# RNN-T
from models.rnnt import RNNT
mdl = RNNT()
mdl.load_from_pretrained()
from datasets.librispeech import iterate
from examples.mlperf.metrics import word_error_rate
LABELS = [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"]
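  # ids 0-27 index into LABELS; the model's id 28 is the blank and is never emitted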
  c = 0
  scores = 0
  words = 0
  st = time.perf_counter()
  for X, Y in iterate():
    mt = time.perf_counter()
    tt = mdl.decode(Tensor(X[0]), Tensor([X[1]]))
    et = time.perf_counter()
    print(f"{(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model")
    for n, t in enumerate(tt):
      tnp = np.array(t)
      _, scores_, words_ = word_error_rate(["".join([LABELS[int(tnp[i])] for i in range(tnp.shape[0])])], [Y[n]])
      scores += scores_
      words += words_
    c += len(tt)
    print(f"WER: {scores/words}, {words} words, raw scores: {scores}, c: {c}")
    st = time.perf_counter()

examples/mlperf/model_spec.py (6 additions)

@@ -29,5 +29,11 @@ if __name__ == "__main__":
  test_model(mdl, img)

  # RNNT
  from models.rnnt import RNNT
  mdl = RNNT()
  mdl.load_from_pretrained()
  x = Tensor.randn(220, 1, 240)
  y = Tensor.randn(1, 220)
  test_model(mdl, x, y)

  # BERT-large

models/rnnt.py (new file, 218 lines)

@@ -0,0 +1,218 @@
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.nn import Linear
import numpy as np
from extra.utils import download_file
from pathlib import Path

class RNNT:
  def __init__(self, input_features=240, vocab_size=29, enc_hidden_size=1024, pred_hidden_size=320, joint_hidden_size=512, pre_enc_layers=2, post_enc_layers=3, pred_layers=2, stack_time_factor=2, dropout=0.32):
    self.encoder = Encoder(input_features, enc_hidden_size, pre_enc_layers, post_enc_layers, stack_time_factor, dropout)
    self.prediction = Prediction(vocab_size, pred_hidden_size, pred_layers, dropout)
    self.joint = Joint(vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout)

  @TinyJit
  def __call__(self, x, y, hc=None):
    f, _ = self.encoder(x, None)
    g, _ = self.prediction(y, hc, Tensor.ones(1, requires_grad=False))
    out = self.joint(f, g)
    return out.realize()

  def decode(self, x, x_lens):
    logits, logit_lens = self.encoder(x, x_lens)
    outputs = []
    for b in range(logits.shape[0]):
      inseq = logits[b, :, :].unsqueeze(1)
      logit_len = logit_lens[b]
      seq = self._greedy_decode(inseq, int(np.ceil(logit_len.numpy()).item()))
      outputs.append(seq)
    return outputs

  def _greedy_decode(self, logits, logit_len):
    hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size, requires_grad=False)
    labels = []
    label = Tensor.zeros(1, 1, requires_grad=False)
    mask = Tensor.zeros(1, requires_grad=False)
    for time_idx in range(logit_len):
      logit = logits[time_idx, :, :].unsqueeze(0)
      not_blank = True
      added = 0
      while not_blank and added < 30:
        if len(labels) > 0:
          mask = (mask + 1).clip(0, 1)
          label = Tensor([[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]], requires_grad=False) + 1 - 1
        jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask)
        k = np.argmax(jhc[0, 0, :29].numpy(), axis=0)
        not_blank = k != 28
        if not_blank:
          labels.append(k)
        hc = jhc[:, :, 29:] + 1 - 1
        added += 1
    return labels
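  # _pred_joint packs both of its results into a single tensor so one TinyJit
  # call can return them together: [:, :, :29] are the joint logits (28 labels
  # + blank at index 28), [:, :, 29:] is the prediction net's new (h, c) state.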
  @TinyJit
  def _pred_joint(self, logit, label, hc, mask):
    g, hc = self.prediction(label, hc, mask)
    j = self.joint(logit, g)[0]
    j = j.pad(((0, 1), (0, 1), (0, 0)))
    out = j.cat(hc, dim=2)
    return out.realize()

  def load_from_pretrained(self):
    fn = Path(__file__).parent.parent / "weights/rnnt.pt"
    download_file("https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1", fn)

    import torch
    with open(fn, "rb") as f:
      state_dict = torch.load(f, map_location="cpu")["state_dict"]

    # encoder
    for i in range(2):
      self.encoder.pre_rnn.cells[i].weights_ih.assign(state_dict[f"encoder.pre_rnn.lstm.weight_ih_l{i}"].numpy())
      self.encoder.pre_rnn.cells[i].weights_hh.assign(state_dict[f"encoder.pre_rnn.lstm.weight_hh_l{i}"].numpy())
      self.encoder.pre_rnn.cells[i].bias_ih.assign(state_dict[f"encoder.pre_rnn.lstm.bias_ih_l{i}"].numpy())
      self.encoder.pre_rnn.cells[i].bias_hh.assign(state_dict[f"encoder.pre_rnn.lstm.bias_hh_l{i}"].numpy())
    for i in range(3):
      self.encoder.post_rnn.cells[i].weights_ih.assign(state_dict[f"encoder.post_rnn.lstm.weight_ih_l{i}"].numpy())
      self.encoder.post_rnn.cells[i].weights_hh.assign(state_dict[f"encoder.post_rnn.lstm.weight_hh_l{i}"].numpy())
      self.encoder.post_rnn.cells[i].bias_ih.assign(state_dict[f"encoder.post_rnn.lstm.bias_ih_l{i}"].numpy())
      self.encoder.post_rnn.cells[i].bias_hh.assign(state_dict[f"encoder.post_rnn.lstm.bias_hh_l{i}"].numpy())

    # prediction
    self.prediction.emb.weight.assign(state_dict["prediction.embed.weight"].numpy())
    for i in range(2):
      self.prediction.rnn.cells[i].weights_ih.assign(state_dict[f"prediction.dec_rnn.lstm.weight_ih_l{i}"].numpy())
      self.prediction.rnn.cells[i].weights_hh.assign(state_dict[f"prediction.dec_rnn.lstm.weight_hh_l{i}"].numpy())
      self.prediction.rnn.cells[i].bias_ih.assign(state_dict[f"prediction.dec_rnn.lstm.bias_ih_l{i}"].numpy())
      self.prediction.rnn.cells[i].bias_hh.assign(state_dict[f"prediction.dec_rnn.lstm.bias_hh_l{i}"].numpy())

    # joint
    self.joint.l1.weight.assign(state_dict["joint_net.0.weight"].numpy())
    self.joint.l1.bias.assign(state_dict["joint_net.0.bias"].numpy())
    self.joint.l2.weight.assign(state_dict["joint_net.3.weight"].numpy())
    self.joint.l2.bias.assign(state_dict["joint_net.3.bias"].numpy())
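
# LSTMCell keeps h and c stacked in one tensor along dim 0:
# hc[:batch] is the hidden state h, hc[batch:] is the cell state c.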
class LSTMCell:
  def __init__(self, input_size, hidden_size, dropout):
    self.dropout = dropout

    self.weights_ih = Tensor.uniform(hidden_size * 4, input_size)
    self.bias_ih = Tensor.uniform(hidden_size * 4)
    self.weights_hh = Tensor.uniform(hidden_size * 4, hidden_size)
    self.bias_hh = Tensor.uniform(hidden_size * 4)

  def __call__(self, x, hc):
    gates = x.linear(self.weights_ih.T, self.bias_ih) + hc[:x.shape[0]].linear(self.weights_hh.T, self.bias_hh)

    i, f, g, o = gates.chunk(4, 1)
    i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()

    c = (f * hc[x.shape[0]:]) + (i * g)
    h = (o * c.tanh()).dropout(self.dropout)

    return Tensor.cat(h, c).realize()
class LSTM:
  def __init__(self, input_size, hidden_size, layers, dropout):
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.layers = layers

    self.cells = [LSTMCell(input_size, hidden_size, dropout) if i == 0 else LSTMCell(hidden_size, hidden_size, dropout if i != layers - 1 else 0) for i in range(layers)]

  def __call__(self, x, hc):
    @TinyJit
    def _do_step(x_, hc_):
      return self.do_step(x_, hc_)

    if hc is None:
      hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False)

    output = None
    for t in range(x.shape[0]):
      hc = _do_step(x[t] + 1 - 1, hc) # TODO: why do we need to do this?
      if output is None:
        output = hc[-1:, :x.shape[1]]
      else:
        output = output.cat(hc[-1:, :x.shape[1]], dim=0).realize()

    return output, hc

  def do_step(self, x, hc):
    new_hc = [x]
    for i, cell in enumerate(self.cells):
      new_hc.append(cell(new_hc[i][:x.shape[0]], hc[i]))
    return Tensor.stack(new_hc[1:]).realize()
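
# StackTime folds `factor` consecutive time steps into the feature dimension,
# shortening the sequence (and scaling x_lens down) by `factor`; the pad
# ensures the time dimension divides evenly.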
class StackTime:
  def __init__(self, factor):
    self.factor = factor

  def __call__(self, x, x_lens):
    x = x.pad(((0, (-x.shape[0]) % self.factor), (0, 0), (0, 0)))
    x = x.reshape(x.shape[0] // self.factor, x.shape[1], x.shape[2] * self.factor)
    return x, x_lens / self.factor if x_lens is not None else None

class Encoder:
  def __init__(self, input_size, hidden_size, pre_layers, post_layers, stack_time_factor, dropout):
    self.pre_rnn = LSTM(input_size, hidden_size, pre_layers, dropout)
    self.stack_time = StackTime(stack_time_factor)
    self.post_rnn = LSTM(stack_time_factor * hidden_size, hidden_size, post_layers, dropout)

  def __call__(self, x, x_lens):
    x, _ = self.pre_rnn(x, None)
    x, x_lens = self.stack_time(x, x_lens)
    x, _ = self.post_rnn(x, None)
    return x.transpose(0, 1), x_lens
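
# gather-free embedding: comparing indices against an arange builds a one-hot
# matrix that is matmul'd with the weights, so no numpy indexing is needed.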
class Embedding:
  def __init__(self, vocab_size: int, embed_size: int):
    self.vocab_size = vocab_size
    self.vocab_counter = Tensor(np.arange(vocab_size, dtype=np.float32), requires_grad=False)
    self.weight = Tensor.scaled_uniform(vocab_size, embed_size)

  def __call__(self, idx: Tensor) -> Tensor:
    oha = []
    for i in range(idx.shape[0]):
      ohba = []
      for j in range(idx.shape[1]):
        ohba.append((self.vocab_counter == idx[i, j]).realize())
      oha.append(Tensor.stack(ohba).realize())
    return Tensor.stack(oha) @ self.weight

class Prediction:
  def __init__(self, vocab_size, hidden_size, layers, dropout):
    self.hidden_size = hidden_size

    self.emb = Embedding(vocab_size - 1, hidden_size)
    self.rnn = LSTM(hidden_size, hidden_size, layers, dropout)

  def __call__(self, x, hc, m):
    emb = self.emb(x) * m
    x_, hc = self.rnn(emb.transpose(0, 1), hc)
    return x_.transpose(0, 1), hc
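
# the joint network broadcasts encoder frames f (B, T, H) against prediction
# frames g (B, U, H2) to (B, T, U, H + H2) before the two-layer MLP.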
class Joint:
  def __init__(self, vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout):
    self.dropout = dropout

    self.l1 = Linear(pred_hidden_size + enc_hidden_size, joint_hidden_size)
    self.l2 = Linear(joint_hidden_size, vocab_size)

  def __call__(self, f, g):
    (_, T, H), (B, U, H2) = f.shape, g.shape
    f = f.unsqueeze(2).expand(B, T, U, H)
    g = g.unsqueeze(1).expand(B, T, U, H2)

    inp = f.cat(g, dim=3)

    t = self.l1(inp).relu()
    t = t.dropout(self.dropout)
    return self.l2(t)
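
A minimal decoder smoke test (a sketch, not part of the commit; real inputs should come from `datasets.librispeech.iterate`, and `load_from_pretrained` downloads the checkpoint on first use):

```python
from tinygrad.tensor import Tensor
from models.rnnt import RNNT

mdl = RNNT()
mdl.load_from_pretrained()

# 220 time steps, batch of 1, 240 spliced features (the same shape the spec test uses)
x = Tensor.randn(220, 1, 240)
out = mdl.decode(x, Tensor([[220.0]]))  # lengths in pre-StackTime frames
print(out)  # one list of label ids per batch element
```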

test/models/test_rnnt.py (new file, 47 lines)

@@ -0,0 +1,47 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.tensor import Tensor
from models.rnnt import LSTM
import torch

class TestRNNT(unittest.TestCase):
  def test_lstm(self):
    BS, SQ, IS, HS, L = 4, 220, 240, 1024, 2

    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.LSTM(IS, HS, L)

    # create in tinygrad
    layer = LSTM(IS, HS, L, 0.0)

    # copy weights
    with torch.no_grad():
      layer.cells[0].weights_ih.assign(Tensor(torch_layer.weight_ih_l0.numpy()))
      layer.cells[0].weights_hh.assign(Tensor(torch_layer.weight_hh_l0.numpy()))
      layer.cells[0].bias_ih.assign(Tensor(torch_layer.bias_ih_l0.numpy()))
      layer.cells[0].bias_hh.assign(Tensor(torch_layer.bias_hh_l0.numpy()))
      layer.cells[1].weights_ih.assign(Tensor(torch_layer.weight_ih_l1.numpy()))
      layer.cells[1].weights_hh.assign(Tensor(torch_layer.weight_hh_l1.numpy()))
      layer.cells[1].bias_ih.assign(Tensor(torch_layer.bias_ih_l1.numpy()))
      layer.cells[1].bias_hh.assign(Tensor(torch_layer.bias_hh_l1.numpy()))

    # test initial hidden
    for _ in range(3):
      x = Tensor.randn(SQ, BS, IS)
      z, hc = layer(x, None)
      torch_x = torch.tensor(x.cpu().numpy())
      torch_z, torch_hc = torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)

    # test passing hidden
    for _ in range(3):
      x = Tensor.randn(SQ, BS, IS)
      z, hc = layer(x, hc)
      torch_x = torch.tensor(x.cpu().numpy())
      torch_z, torch_hc = torch_layer(torch_x, torch_hc)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)

if __name__ == '__main__':
  unittest.main()