mirror of https://github.com/tinygrad/tinygrad.git
whisper: fix oob, explicit dtype (#13144)
* fix dtype: depending on numpy version (v2), np.array returns int64, which Tensor passed through for the first decode call, swallowing the <|notimestamps|> token and corrupting the sequence
* fix whisper OOB: global limit on whisper's context length
* enforce whisper max_tokens_to_sample (match openai): local limit on the number of tokens decoded
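A minimal sketch of the dtype issue, assuming numpy 2.x (where arrays built from python ints default to int64 on all platforms); the token ids are illustrative, not taken from the diff:

import numpy as np
from tinygrad import Tensor, dtypes

ctx = np.array([[50258, 50259, 50359]])  # illustrative start tokens
print(ctx.dtype)                         # int64 under numpy 2.x

# Tensor(ctx) would inherit int64 and pass it through to the first decode
# call; pinning the dtype keeps every decoder input at int32
tokens = Tensor(ctx, dtype=dtypes.int32)
print(tokens.dtype)                      # int32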
@@ -3,7 +3,7 @@
 import sys, base64, multiprocessing, itertools, collections
 from typing import Optional, Union, Literal, List
-from tinygrad import Tensor, TinyJit, Variable, nn
+from tinygrad import Tensor, TinyJit, Variable, nn, dtypes
 from tinygrad.nn.state import torch_load, load_state_dict
 from tinygrad.helpers import getenv, fetch
@@ -244,15 +244,16 @@ def transcribe_waveform(model: Whisper, enc, waveforms, truncate=False):
   log_spec = prep_audio(waveforms, model.batch_size, truncate)
   nsample = model.decoder.max_tokens_to_sample
+  nctx = model.decoder.max_self_attn_cache_len

   def inferloop(ctx: Union[np.ndarray, List[np.ndarray]], encoded_audio):
     pos, next_tokens = 0, ctx
-    for i in range((nsample-len(start_tokens))*2):
-      next_tokens = model.decoder(Tensor(next_tokens), pos, encoded_audio)[:, -1].argmax(axis=-1).numpy().astype(np.int32).reshape(-1, 1)
+    for i in range(nsample):
+      next_tokens = model.decoder(Tensor(next_tokens, dtype=dtypes.int32), pos, encoded_audio)[:, -1].argmax(axis=-1).numpy().astype(np.int32).reshape(-1, 1)
       next_tokens[ctx[:, -1] == eot] = eot
       ctx = np.concatenate((ctx, next_tokens), axis=1)
       pos = ctx.shape[-1] - 1
-      if (next_tokens == eot).all(): break
+      if (next_tokens == eot).all() or pos == nctx: break
     return ctx

   def gettexttoks(line): return [tok for tok in line if tok < eot or tok > enc._special_tokens["<|notimestamps|>"]][-nsample+len(start_tokens):]
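For reference, a self-contained sketch of the gettexttoks truncation in the context line above. The special-token ids assume the multilingual whisper vocab (<|endoftext|> = 50257, <|notimestamps|> = 50363), and nsample/start_tokens are illustrative stand-ins for the decoder settings, not values from this diff:

eot, notimestamps = 50257, 50363                            # assumed vocab ids
nsample, start_tokens = 224, [50258, 50259, 50359, 50363]   # illustrative limits

def gettexttoks(line):
  # keep plain text tokens (< eot) and timestamp tokens (> notimestamps),
  # then keep only the last nsample-len(start_tokens) tokens so the prompt
  # carried into the next window stays within the decode budget
  return [tok for tok in line if tok < eot or tok > notimestamps][-nsample+len(start_tokens):]

window = [50258, 50259, 50359, 50363, 440, 7563, 291, 50257]
print(gettexttoks(window))  # [440, 7563, 291]: specials and eot dropped, text kept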