diff --git a/examples/whisper.py b/examples/whisper.py index 2df3122628..5a189ad0a9 100644 --- a/examples/whisper.py +++ b/examples/whisper.py @@ -3,7 +3,7 @@ import sys, base64, multiprocessing, itertools, collections from typing import Optional, Union, Literal, List -from tinygrad import Tensor, TinyJit, Variable, nn +from tinygrad import Tensor, TinyJit, Variable, nn, dtypes from tinygrad.nn.state import torch_load, load_state_dict from tinygrad.helpers import getenv, fetch @@ -244,15 +244,16 @@ def transcribe_waveform(model: Whisper, enc, waveforms, truncate=False): log_spec = prep_audio(waveforms, model.batch_size, truncate) nsample = model.decoder.max_tokens_to_sample + nctx = model.decoder.max_self_attn_cache_len def inferloop(ctx: Union[np.ndarray, List[np.ndarray]], encoded_audio): pos, next_tokens = 0, ctx - for i in range((nsample-len(start_tokens))*2): - next_tokens = model.decoder(Tensor(next_tokens), pos, encoded_audio)[:, -1].argmax(axis=-1).numpy().astype(np.int32).reshape(-1, 1) + for i in range(nsample): + next_tokens = model.decoder(Tensor(next_tokens, dtype=dtypes.int32), pos, encoded_audio)[:, -1].argmax(axis=-1).numpy().astype(np.int32).reshape(-1, 1) next_tokens[ctx[:, -1] == eot] = eot ctx = np.concatenate((ctx, next_tokens), axis=1) pos = ctx.shape[-1] - 1 - if (next_tokens == eot).all(): break + if (next_tokens == eot).all() or pos == nctx: break return ctx def gettexttoks(line): return [tok for tok in line if tok < eot or tok > enc._special_tokens["<|notimestamps|>"]][-nsample+len(start_tokens):]