mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Add mlperf RNN-T model (#782)
* feat: initial rnn-t * feat: working with BS>1 * feat: add lstm test * feat: test passing hidden * clean: cleanup * feat: specify start * feat: way faster lstm & model * fix: default batch size * feat: optimization * fix: fix metrics * fix: fix feature splicing * feat: cleaner stacktime * clean: remove unused import * clean: remove extra prints * fix: fix tests and happy llvm * feat: have the librispeech dataset in its own dir * clean: unused variable * feat: no longer need numpy for the embedding + slightly more memory efficient lstm * fix: forgot to remove something that broke tests * feat: use relative paths * feat: even faster * feat: remove pointless transposes in StackTime * fix: correct forward * feat: switch to soundfile for loading and fix some leaks * feat: add comment about initial dataset setup * feat: jit more things * feat: default batch size back to 1 larger than 1 is broken again :( and even in the reference implementation it gives worse results
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -21,3 +21,4 @@ disassemblers/applegpu
|
||||
disassemblers/cuda_ioctl_sniffer
|
||||
*.prof
|
||||
datasets/cifar-10-python.tar.gz
|
||||
datasets/librispeech/
|
||||
|
||||
82
datasets/librispeech.py
Normal file
82
datasets/librispeech.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import json
|
||||
import pathlib
|
||||
import numpy as np
|
||||
import librosa
|
||||
import soundfile
|
||||
|
||||
"""
|
||||
The dataset has to be downloaded manually from https://www.openslr.org/12/ and put in `datasets/librispeech`.
|
||||
For mlperf validation the dev-clean dataset is used.
|
||||
|
||||
Then all the flacs have to be converted to wav using something like:
|
||||
```fish
|
||||
for file in **/*.flac; ffmpeg -i $file -ar 16k "$(dirname $file)/$(basename $file .flac).wav"; end
|
||||
```
|
||||
|
||||
Then this [file](https://github.com/mlcommons/inference/blob/master/speech_recognition/rnnt/dev-clean-wav.json) has to also be put in `datasets/librispeech`.
|
||||
"""
|
||||
# Dataset directory: "datasets/librispeech" under the parent of this file's directory.
BASEDIR = pathlib.Path(__file__).parent.parent / "datasets/librispeech"
# The manifest is read at import time, so importing this module fails if the file is missing.
with open(BASEDIR / "dev-clean-wav.json") as f:
  ci = json.load(f)

# Mel filter bank (80 mel bands, 0-8000 Hz, n_fft=512, 16 kHz sample rate) with a leading
# axis added so it broadcasts in np.matmul inside feature_extract.
FILTER_BANK = np.expand_dims(librosa.filters.mel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000), 0)
# 320-sample Hann window used for the STFT in feature_extract.
WINDOW = librosa.filters.get_window("hann", 320)
|
||||
|
||||
def feature_extract(x, x_lens):
  """Turn a batch of padded waveforms into normalized, spliced log-mel features.

  x is indexed as a 2-D array (one waveform per row); x_lens holds the number of
  valid samples per row. Returns a tuple of the features transposed with
  (2, 0, 1) -- i.e. the frame axis first -- and the per-row frame counts as float32.
  """
  # number of output frames per row: hop length 160, then every 3rd frame is kept
  frame_lens = np.ceil((x_lens / 160) / 3).astype(np.int32)

  # pre-emphasis: keep the first sample, then y[t] = x[t] - 0.97 * x[t - 1]
  emphasized = np.concatenate((x[:, :1], x[:, 1:] - 0.97 * x[:, :-1]), axis=1)

  # short-time Fourier transform
  spectrum = librosa.stft(emphasized, n_fft=512, window=WINDOW, hop_length=160, win_length=320, center=True, pad_mode="reflect")

  # power spectrum: real^2 + imag^2
  power = spectrum.real ** 2 + spectrum.imag ** 2

  # mel filter bank followed by log (offset avoids log(0))
  log_mel = np.log(np.matmul(FILTER_BANK, power) + 1e-20)

  # feature splice: stack each frame with its next two frames (zero filled past the end)
  spliced = [log_mel]
  for shift in (1, 2):
    shifted = np.zeros_like(log_mel)
    shifted[:, :, :-shift] = log_mel[:, :, shift:]
    spliced.append(shifted)
  # keep every third frame
  features = np.concatenate(spliced, axis=1)[:, :, ::3]

  # normalize every feature row using statistics of the valid frames only
  batch, n_feats = features.shape[0], features.shape[1]
  mean = np.zeros((batch, n_feats), dtype=np.float32)
  std = np.zeros((batch, n_feats), dtype=np.float32)
  for b in range(batch):
    valid = features[b, :, :frame_lens[b]]
    mean[b, :] = valid.mean(axis=1)
    std[b, :] = valid.std(axis=1, ddof=1)
  std += 1e-5
  features = (features - mean[:, :, np.newaxis]) / std[:, :, np.newaxis]

  return features.transpose(2, 0, 1), frame_lens.astype(np.float32)
|
||||
|
||||
def load_wav(file):
  """Read an audio file with soundfile and return (float32 samples, length of the first axis)."""
  audio, _samplerate = soundfile.read(file)
  audio = audio.astype(np.float32)
  return audio, audio.shape[0]
|
||||
|
||||
def iterate(bs=1, start=0):
  """Yield ((features, feature_lens), transcripts) batches from the manifest.

  bs is the batch size and start the manifest index of the first sample.
  Waveforms in a batch are zero padded at the end to the longest one.
  """
  print(f"there are {len(ci)} samples in the dataset")
  for offset in range(start, len(ci), bs):
    batch = ci[offset : offset + bs]
    loaded = [load_wav(BASEDIR / entry["files"][0]["fname"]) for entry in batch]
    waves, lengths = zip(*loaded)

    # pad to same length
    longest = max(lengths)
    padded = [np.pad(wave, (0, longest - n), "constant") for wave, n in zip(waves, lengths)]

    features = feature_extract(np.array(padded), np.array(lengths))
    transcripts = np.array([entry["transcript"] for entry in batch])
    yield features, transcripts
|
||||
|
||||
if __name__ == "__main__":
  # smoke test: pull the first batch and show the feature and transcript shapes
  (features, _feature_lens), transcripts = next(iterate())
  print(features.shape, transcripts.shape)
|
||||
25
examples/mlperf/metrics.py
Normal file
25
examples/mlperf/metrics.py
Normal file
@@ -0,0 +1,25 @@
|
||||
def levenshtein(a, b):
  """Return the edit distance (insert / delete / substitute, each cost 1) between sequences a and b."""
  # keep the shorter sequence along the row so only len(shorter) + 1 cells are stored
  if len(a) > len(b):
    short, long_ = b, a
  else:
    short, long_ = a, b
  width = len(short)

  row = list(range(width + 1))
  for i, long_item in enumerate(long_, start=1):
    prev_row, row = row, [i] + [0] * width
    for j, short_item in enumerate(short, start=1):
      substitute = prev_row[j - 1] + 1 if short_item != long_item else prev_row[j - 1]
      row[j] = min(prev_row[j] + 1, row[j - 1] + 1, substitute)

  return row[width]
|
||||
|
||||
def word_error_rate(x, y):
  """Compute the word error rate of hypotheses x against references y.

  x and y are iterables of strings paired up with zip (extra items in the longer
  one are ignored). Each string is split on whitespace and the word-level edit
  distance is accumulated.

  Returns (wer, total edit distance as float, total reference words). When the
  references contain no words at all, wer is float("inf") instead of raising
  ZeroDivisionError.
  """
  scores = words = 0
  for h, r in zip(x, y):
    h_list = h.split()
    r_list = r.split()
    words += len(r_list)
    scores += levenshtein(h_list, r_list)
  # an empty reference set has no defined error rate; report infinity rather than crash
  wer = float(scores) / words if words != 0 else float("inf")
  return wer, float(scores), words
|
||||
@@ -43,4 +43,30 @@ if __name__ == "__main__":
|
||||
print(f"****** {n}/{d} {n*100.0/d:.2f}%")
|
||||
st = time.perf_counter()
|
||||
|
||||
# RNN-T
# Evaluate the pretrained RNN-T on the librispeech iterator and print a running WER.
from models.rnnt import RNNT
mdl = RNNT()
mdl.load_from_pretrained()

from datasets.librispeech import iterate
from examples.mlperf.metrics import word_error_rate

# index -> character table used to turn decoded label indices into text
LABELS = [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"]

c = 0
scores = 0
words = 0
st = time.perf_counter()
for X, Y in iterate():
  load_done = time.perf_counter()
  # NOTE(review): X[1] is wrapped in a list, adding a leading axis -- looks like this only works for batch size 1, confirm
  decoded = mdl.decode(Tensor(X[0]), Tensor([X[1]]))
  run_done = time.perf_counter()
  load_ms, run_ms = (load_done - st) * 1000, (run_done - load_done) * 1000
  print(f"{load_ms:.2f} ms loading data, {run_ms:.2f} ms to run model")
  for idx, label_seq in enumerate(decoded):
    hypothesis = "".join(LABELS[int(k)] for k in np.array(label_seq))
    _, sample_scores, sample_words = word_error_rate([hypothesis], [Y[idx]])
    scores += sample_scores
    words += sample_words
  c += len(decoded)
  print(f"WER: {scores/words}, {words} words, raw scores: {scores}, c: {c}")
  st = time.perf_counter()
|
||||
|
||||
@@ -29,5 +29,11 @@ if __name__ == "__main__":
|
||||
test_model(mdl, img)
|
||||
|
||||
# RNNT
# Run the pretrained RNN-T once through test_model with random inputs.
from models.rnnt import RNNT
mdl = RNNT()
mdl.load_from_pretrained()
# 240 matches RNNT's default input_features; presumably (time, batch, features) -- TODO confirm
x = Tensor.randn(220, 1, 240)
# second positional input of RNNT.__call__ (fed to the prediction network); random values, not real labels
y = Tensor.randn(1, 220)
test_model(mdl, x, y)
|
||||
|
||||
# BERT-large
|
||||
|
||||
218
models/rnnt.py
Normal file
218
models/rnnt.py
Normal file
@@ -0,0 +1,218 @@
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.nn import Linear
|
||||
import numpy as np
|
||||
from extra.utils import download_file
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class RNNT:
  """RNN-Transducer model made of an encoder, a prediction network and a joint network."""

  def __init__(self, input_features=240, vocab_size=29, enc_hidden_size=1024, pred_hidden_size=320, joint_hidden_size=512, pre_enc_layers=2, post_enc_layers=3, pred_layers=2, stack_time_factor=2, dropout=0.32):
    self.encoder = Encoder(input_features, enc_hidden_size, pre_enc_layers, post_enc_layers, stack_time_factor, dropout)
    self.prediction = Prediction(vocab_size, pred_hidden_size, pred_layers, dropout)
    self.joint = Joint(vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout)

  @TinyJit
  def __call__(self, x, y, hc=None):
    """JIT-compiled full forward pass: encode x, run the prediction network on y, join both.

    hc is the optional initial state of the prediction LSTM. Returns the realized joint output.
    """
    f, _ = self.encoder(x, None)
    # mask of ones: no embedding is zeroed out
    g, _ = self.prediction(y, hc, Tensor.ones(1, requires_grad=False))
    out = self.joint(f, g)
    return out.realize()

  def decode(self, x, x_lens):
    """Encode x and greedily decode every batch entry.

    Returns a list with one list of label indices per batch entry.
    """
    logits, logit_lens = self.encoder(x, x_lens)
    outputs = []
    for b in range(logits.shape[0]):
      inseq = logits[b, :, :].unsqueeze(1)
      logit_len = logit_lens[b]
      # lengths are fractional after StackTime's division, so round up to whole frames
      seq = self._greedy_decode(inseq, int(np.ceil(logit_len.numpy()).item()))
      outputs.append(seq)
    return outputs

  def _greedy_decode(self, logits, logit_len):
    """Greedy transducer decoding of one sequence of encoder outputs.

    For each of the first logit_len frames, labels are emitted until index 28
    (treated as blank) wins the argmax or 30 labels were emitted for that frame.
    Returns the list of emitted label indices.
    """
    # prediction LSTM state; 2 rows per layer: h followed by c for a batch of one
    hc = Tensor.zeros(self.prediction.rnn.layers, 2, self.prediction.hidden_size, requires_grad=False)
    labels = []
    label = Tensor.zeros(1, 1, requires_grad=False)
    # stays 0 until the first label exists, which zeroes the embedding in Prediction
    mask = Tensor.zeros(1, requires_grad=False)
    for time_idx in range(logit_len):
      logit = logits[time_idx, :, :].unsqueeze(0)
      not_blank = True
      added = 0
      while not_blank and added < 30:
        if len(labels) > 0:
          mask = (mask + 1).clip(0, 1)
          # NOTE(review): the "+ 1 - 1" looks like a workaround to get a fresh computed tensor for the JIT -- confirm
          label = Tensor([[labels[-1] if labels[-1] <= 28 else labels[-1] - 1]], requires_grad=False) + 1 - 1
        jhc = self._pred_joint(Tensor(logit.numpy()), label, hc, mask)
        # first 29 entries of the last axis are the joint logits, the remainder is the new LSTM state
        k = np.argmax(jhc[0, 0, :29].numpy(), axis=0)
        not_blank = k != 28
        if not_blank:
          labels.append(k)
          # the prediction state only advances when a label was emitted
          hc = jhc[:, :, 29:] + 1 - 1
        added += 1
    return labels

  @TinyJit
  def _pred_joint(self, logit, label, hc, mask):
    """JIT-compiled single decoding step.

    Runs the prediction network and the joint network, then packs the joint
    output and the new prediction state into one realized tensor by
    concatenating them along the last axis.
    """
    g, hc = self.prediction(label, hc, mask)
    j = self.joint(logit, g)[0]
    # grow the first two axes by one so they line up with hc for the concatenation
    # (presumably (2, 2, hidden) for the default 2 layers and batch 1 -- TODO confirm)
    j = j.pad(((0, 1), (0, 1), (0, 0)))
    out = j.cat(hc, dim=2)
    return out.realize()

  def load_from_pretrained(self):
    """Download the pretrained checkpoint to weights/rnnt.pt and copy its tensors into the model.

    Every LSTM cell present in the model is loaded; a KeyError is raised if the
    checkpoint has no entry for one of them.
    """
    fn = Path(__file__).parent.parent / "weights/rnnt.pt"
    download_file("https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1", fn)

    import torch
    # NOTE(security): torch.load unpickles the downloaded file, which can execute
    # arbitrary code; only use this with a checkpoint source you trust.
    with open(fn, "rb") as f:
      state_dict = torch.load(f, map_location="cpu")["state_dict"]

    def _load_lstm(lstm, prefix):
      # copy the four parameter tensors of every layer of one LSTM
      for i, cell in enumerate(lstm.cells):
        cell.weights_ih.assign(state_dict[f"{prefix}.weight_ih_l{i}"].numpy())
        cell.weights_hh.assign(state_dict[f"{prefix}.weight_hh_l{i}"].numpy())
        cell.bias_ih.assign(state_dict[f"{prefix}.bias_ih_l{i}"].numpy())
        cell.bias_hh.assign(state_dict[f"{prefix}.bias_hh_l{i}"].numpy())

    # encoder
    _load_lstm(self.encoder.pre_rnn, "encoder.pre_rnn.lstm")
    _load_lstm(self.encoder.post_rnn, "encoder.post_rnn.lstm")

    # prediction
    self.prediction.emb.weight.assign(state_dict["prediction.embed.weight"].numpy())
    _load_lstm(self.prediction.rnn, "prediction.dec_rnn.lstm")

    # joint
    self.joint.l1.weight.assign(state_dict["joint_net.0.weight"].numpy())
    self.joint.l1.bias.assign(state_dict["joint_net.0.bias"].numpy())
    self.joint.l2.weight.assign(state_dict["joint_net.3.weight"].numpy())
    self.joint.l2.bias.assign(state_dict["joint_net.3.bias"].numpy())
|
||||
|
||||
|
||||
class LSTMCell:
  """Single LSTM layer step with the hidden and cell state packed into one tensor."""

  def __init__(self, input_size, hidden_size, dropout):
    self.dropout = dropout

    # the four gates are stored stacked, hence hidden_size * 4 rows
    gate_rows = hidden_size * 4
    self.weights_ih = Tensor.uniform(gate_rows, input_size)
    self.bias_ih = Tensor.uniform(gate_rows)
    self.weights_hh = Tensor.uniform(gate_rows, hidden_size)
    self.bias_hh = Tensor.uniform(gate_rows)

  def __call__(self, x, hc):
    """Advance one step; hc holds h in its first x.shape[0] rows and c in the rest.

    Returns the new h and c concatenated in the same layout, realized.
    """
    batch = x.shape[0]
    h_prev, c_prev = hc[:batch], hc[batch:]

    gates = x.linear(self.weights_ih.T, self.bias_ih) + h_prev.linear(self.weights_hh.T, self.bias_hh)

    in_gate, forget_gate, cand, out_gate = gates.chunk(4, 1)
    in_gate = in_gate.sigmoid()
    forget_gate = forget_gate.sigmoid()
    cand = cand.tanh()
    out_gate = out_gate.sigmoid()

    c_new = (forget_gate * c_prev) + (in_gate * cand)
    h_new = (out_gate * c_new.tanh()).dropout(self.dropout)

    return Tensor.cat(h_new, c_new).realize()
|
||||
|
||||
|
||||
class LSTM:
  """Stack of LSTMCell layers applied step by step over the first axis of the input."""

  def __init__(self, input_size, hidden_size, layers, dropout):
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.layers = layers

    # first cell maps input_size -> hidden_size; later cells are hidden -> hidden,
    # and the last of those later cells gets dropout 0
    self.cells = [LSTMCell(input_size, hidden_size, dropout) if i == 0 else LSTMCell(hidden_size, hidden_size, dropout if i != layers - 1 else 0) for i in range(layers)]

  def __call__(self, x, hc):
    """Run the stack over every step x[t].

    x is indexed as (steps, batch, ...). hc is the packed state of all layers, or
    None to start from zeros of shape (layers, 2 * batch, hidden_size).
    Returns (output, hc): the last layer's h for every step concatenated along
    axis 0, and the final packed state.
    """
    # a fresh JIT wrapper is created on every call
    @TinyJit
    def _do_step(x_, hc_):
      return self.do_step(x_, hc_)

    if hc is None:
      hc = Tensor.zeros(self.layers, 2 * x.shape[1], self.hidden_size, requires_grad=False)

    output = None
    for t in range(x.shape[0]):
      hc = _do_step(x[t] + 1 - 1, hc) # TODO: why do we need to do this?
      # last layer, first `batch` rows: the h half of the packed state
      if output is None:
        output = hc[-1:, :x.shape[1]]
      else:
        output = output.cat(hc[-1:, :x.shape[1]], dim=0).realize()

    return output, hc

  def do_step(self, x, hc):
    """Apply every cell once; returns the stacked, realized packed states of all layers."""
    # new_hc[0] is the input; new_hc[i + 1] is the packed (h, c) output of cell i
    new_hc = [x]
    for i, cell in enumerate(self.cells):
      # each cell gets the first x.shape[0] rows of the previous entry (its h part) and its own old state
      new_hc.append(cell(new_hc[i][:x.shape[0]], hc[i]))
    return Tensor.stack(new_hc[1:]).realize()
|
||||
|
||||
|
||||
class StackTime:
  """Shortens the first (time) axis by stacking `factor` consecutive frames along the last axis."""

  def __init__(self, factor):
    self.factor = factor

  def __call__(self, x, x_lens):
    """Stack frames of x, indexed as (time, batch, features).

    The time axis is zero padded to a multiple of factor first. Returns the
    stacked tensor of shape (time / factor, batch, features * factor) and
    x_lens divided by factor (or None when x_lens is None).
    """
    x = x.pad(((0, (-x.shape[0]) % self.factor), (0, 0), (0, 0)))
    if x.shape[1] == 1:
      # with a single batch entry a plain reshape already groups consecutive frames
      x = x.reshape(x.shape[0] // self.factor, x.shape[1], x.shape[2] * self.factor)
    else:
      # with batch > 1 a direct reshape would mix frames of different batch entries,
      # so go batch-major, stack along time per entry, then restore time-major layout
      x = x.transpose(0, 1)
      x = x.reshape(x.shape[0], x.shape[1] // self.factor, x.shape[2] * self.factor)
      x = x.transpose(0, 1)
    return x, x_lens / self.factor if x_lens is not None else None
|
||||
|
||||
|
||||
class Encoder:
  """Encoder: an LSTM stack, time stacking, then a second LSTM stack."""

  def __init__(self, input_size, hidden_size, pre_layers, post_layers, stack_time_factor, dropout):
    self.pre_rnn = LSTM(input_size, hidden_size, pre_layers, dropout)
    self.stack_time = StackTime(stack_time_factor)
    # stacking multiplies the feature width by the stack factor
    post_input_size = stack_time_factor * hidden_size
    self.post_rnn = LSTM(post_input_size, hidden_size, post_layers, dropout)

  def __call__(self, x, x_lens):
    """Encode x; returns the output with its first two axes swapped and the adjusted lengths."""
    pre_out = self.pre_rnn(x, None)[0]
    stacked, stacked_lens = self.stack_time(pre_out, x_lens)
    post_out = self.post_rnn(stacked, None)[0]
    return post_out.transpose(0, 1), stacked_lens
|
||||
|
||||
|
||||
class Embedding:
  """Embedding lookup implemented as a one-hot matrix product."""

  def __init__(self, vocab_size: int, embed_size: int):
    self.vocab_size = vocab_size
    # 0 .. vocab_size - 1, compared against an index to build its one-hot row
    self.vocab_counter = Tensor(np.arange(vocab_size, dtype=np.float32), requires_grad=False)
    self.weight = Tensor.scaled_uniform(vocab_size, embed_size)

  def __call__(self, idx: Tensor) -> Tensor:
    """Look up the embeddings for a 2-D index tensor."""
    one_hot_rows = [
      Tensor.stack([(self.vocab_counter == idx[i, j]).realize() for j in range(idx.shape[1])]).realize()
      for i in range(idx.shape[0])
    ]
    return Tensor.stack(one_hot_rows) @ self.weight
|
||||
|
||||
|
||||
class Prediction:
  """Prediction network: label embedding followed by an LSTM stack."""

  def __init__(self, vocab_size, hidden_size, layers, dropout):
    self.hidden_size = hidden_size

    # the embedding table has one row fewer than the vocabulary
    self.emb = Embedding(vocab_size - 1, hidden_size)
    self.rnn = LSTM(hidden_size, hidden_size, layers, dropout)

  def __call__(self, x, hc, m):
    """Embed labels x, scale the embedding by mask m and run the LSTM from state hc.

    Returns the LSTM output with its first two axes swapped back, and the new state.
    """
    masked_emb = self.emb(x) * m
    # the LSTM iterates over its first axis, so swap the first two axes going in and out
    rnn_out, new_hc = self.rnn(masked_emb.transpose(0, 1), hc)
    return rnn_out.transpose(0, 1), new_hc
|
||||
|
||||
|
||||
class Joint:
  """Joint network: combines encoder and prediction outputs through two linear layers."""

  def __init__(self, vocab_size, pred_hidden_size, enc_hidden_size, joint_hidden_size, dropout):
    self.dropout = dropout

    self.l1 = Linear(pred_hidden_size + enc_hidden_size, joint_hidden_size)
    self.l2 = Linear(joint_hidden_size, vocab_size)

  def __call__(self, f, g):
    """Join f (3-D, encoder side) and g (3-D, prediction side) into a 4-D output.

    Both inputs are expanded to (B, T, U, hidden), concatenated on the last
    axis and passed through l1 -> relu -> dropout -> l2.
    """
    _, T, H = f.shape
    B, U, H2 = g.shape

    f_exp = f.unsqueeze(2).expand(B, T, U, H)
    g_exp = g.unsqueeze(1).expand(B, T, U, H2)

    hidden = self.l1(f_exp.cat(g_exp, dim=3)).relu()
    hidden = hidden.dropout(self.dropout)
    return self.l2(hidden)
|
||||
47
test/models/test_rnnt.py
Normal file
47
test/models/test_rnnt.py
Normal file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from models.rnnt import LSTM
|
||||
import torch
|
||||
|
||||
class TestRNNT(unittest.TestCase):
  """Checks the tinygrad LSTM against torch.nn.LSTM with identical weights."""

  def test_lstm(self):
    BS, SQ, IS, HS, L = 4, 220, 240, 1024, 2

    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.LSTM(IS, HS, L)

    # create in tinygrad
    layer = LSTM(IS, HS, L, 0.0)

    # copy weights layer by layer
    with torch.no_grad():
      for i, cell in enumerate(layer.cells):
        cell.weights_ih.assign(Tensor(getattr(torch_layer, f"weight_ih_l{i}").numpy()))
        cell.weights_hh.assign(Tensor(getattr(torch_layer, f"weight_hh_l{i}").numpy()))
        cell.bias_ih.assign(Tensor(getattr(torch_layer, f"bias_ih_l{i}").numpy()))
        cell.bias_hh.assign(Tensor(getattr(torch_layer, f"bias_hh_l{i}").numpy()))

    def assert_close(tiny_out, torch_out):
      np.testing.assert_allclose(tiny_out.numpy(), torch_out.detach().numpy(), atol=5e-3, rtol=5e-3)

    # test initial hidden
    for _ in range(3):
      x = Tensor.randn(SQ, BS, IS)
      z, hc = layer(x, None)
      torch_x = torch.tensor(x.cpu().numpy())
      torch_z, torch_hc = torch_layer(torch_x)
      assert_close(z, torch_z)

    # test passing hidden: continue from the states left by the previous run
    for _ in range(3):
      x = Tensor.randn(SQ, BS, IS)
      z, hc = layer(x, hc)
      torch_x = torch.tensor(x.cpu().numpy())
      torch_z, torch_hc = torch_layer(torch_x, torch_hc)
      assert_close(z, torch_z)
|
||||
|
||||
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  unittest.main()
|
||||
Reference in New Issue
Block a user