whisper: fix oob, explicit dtype (#13144)

* fix dtype depending on numpy version

with numpy v2, np.array returns int64, which Tensor passed through on
the first decode call, swallowing the <|notimestamps|> token and
corrupting the sequence
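
A minimal sketch of the failure mode and the fix (the array values are
illustrative, not from the commit): on the first decode call the prompt
comes straight from np.array, which defaults to int64 under numpy v2,
and the Tensor built from it inherits that dtype instead of matching
the int32 arrays produced on later steps; passing dtypes.int32
explicitly removes the mismatch.

  import numpy as np
  from tinygrad import Tensor, dtypes

  start_tokens = [50258, 50259, 50359, 50363]   # illustrative prompt ids
  ctx = np.array(start_tokens).reshape(1, -1)
  print(ctx.dtype)                              # int64 under numpy v2
  print(Tensor(ctx).dtype)                      # Tensor passes the int64 through
  print(Tensor(ctx, dtype=dtypes.int32).dtype)  # the fix: force int32 explicitly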

* fix whisper OOB

add a global limit at whisper's self-attention context length so decoding can't index past the cache

* enforce whisper max_tokens_to_sample (match openai)

a local limit on the number of tokens decoded per call
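
A control-flow sketch of how the two caps in the diff below interact
(inferloop_sketch and step are hypothetical names; nsample, nctx, and
eot follow the diff): the loop samples at most nsample tokens per call,
and the break also fires once pos reaches the cache length nctx, so the
self-attention cache is never indexed out of bounds.

  import numpy as np

  def inferloop_sketch(ctx, step, nsample, nctx, eot):
    pos = 0
    for _ in range(nsample):                      # local cap: max_tokens_to_sample
      next_tokens = step(ctx, pos)                # hypothetical one-step decode
      next_tokens[ctx[:, -1] == eot] = eot        # freeze finished sequences
      ctx = np.concatenate((ctx, next_tokens), axis=1)
      pos = ctx.shape[-1] - 1
      if (next_tokens == eot).all() or pos == nctx: break  # global cap: cache length
    return ctx
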
Author: C T
Date: 2025-11-07 19:55:01 +02:00
Committed by: GitHub
Parent: 3ecff3a8da
Commit: 0f9d7f650d


@@ -3,7 +3,7 @@
 import sys, base64, multiprocessing, itertools, collections
 from typing import Optional, Union, Literal, List
-from tinygrad import Tensor, TinyJit, Variable, nn
+from tinygrad import Tensor, TinyJit, Variable, nn, dtypes
 from tinygrad.nn.state import torch_load, load_state_dict
 from tinygrad.helpers import getenv, fetch
@@ -244,15 +244,16 @@ def transcribe_waveform(model: Whisper, enc, waveforms, truncate=False):
   log_spec = prep_audio(waveforms, model.batch_size, truncate)
   nsample = model.decoder.max_tokens_to_sample
+  nctx = model.decoder.max_self_attn_cache_len
   def inferloop(ctx: Union[np.ndarray, List[np.ndarray]], encoded_audio):
     pos, next_tokens = 0, ctx
-    for i in range((nsample-len(start_tokens))*2):
-      next_tokens = model.decoder(Tensor(next_tokens), pos, encoded_audio)[:, -1].argmax(axis=-1).numpy().astype(np.int32).reshape(-1, 1)
+    for i in range(nsample):
+      next_tokens = model.decoder(Tensor(next_tokens, dtype=dtypes.int32), pos, encoded_audio)[:, -1].argmax(axis=-1).numpy().astype(np.int32).reshape(-1, 1)
       next_tokens[ctx[:, -1] == eot] = eot
       ctx = np.concatenate((ctx, next_tokens), axis=1)
       pos = ctx.shape[-1] - 1
-      if (next_tokens == eot).all(): break
+      if (next_tokens == eot).all() or pos == nctx: break
     return ctx
   def gettexttoks(line): return [tok for tok in line if tok < eot or tok > enc._special_tokens["<|notimestamps|>"]][-nsample+len(start_tokens):]
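
For context, a toy run of the slicing in gettexttoks (the token values
are made up): it keeps text tokens (ids below eot) and timestamp tokens
(ids above <|notimestamps|>), drops the special tokens in between, and
then keeps only the most recent tokens that fit the sampling budget.

  eot, notimestamps, nsample, n_start = 3, 5, 4, 2  # toy values
  line = [1, 2, 4, 5, 6, 1, 2]                      # 4 and 5 play the special tokens
  kept = [t for t in line if t < eot or t > notimestamps]
  print(kept)                     # [1, 2, 6, 1, 2] -- special tokens 4, 5 dropped
  print(kept[-nsample+n_start:])  # [1, 2] -- last nsample-n_start tokens kept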