From f5dd25d3767f67b70a3c33daa8255e64b380a590 Mon Sep 17 00:00:00 2001 From: kormann <49917710+DKormann@users.noreply.github.com> Date: Tue, 17 Sep 2024 06:42:10 +0200 Subject: [PATCH] enable whisper batch for long sequences (#6458) * long batch +test * long batch +test * cleanup * rollback syntactic changes --------- Co-authored-by: chenyu --- examples/whisper.py | 57 ++++++++++++++++--------------------- test/models/test_whisper.py | 7 +++-- 2 files changed, 29 insertions(+), 35 deletions(-) diff --git a/examples/whisper.py b/examples/whisper.py index d087660986..b44f764cf9 100644 --- a/examples/whisper.py +++ b/examples/whisper.py @@ -162,7 +162,7 @@ def prep_audio(waveforms: List[np.ndarray], batch_size: int, truncate=False) -> mel_spec = librosa.filters.mel(sr=RATE, n_fft=N_FFT, n_mels=N_MELS) @ magnitudes log_spec = np.log10(np.clip(mel_spec, 1e-10, None)) - log_spec = np.maximum(log_spec, log_spec.max() - 8.0) + log_spec = np.maximum(log_spec, log_spec.max((1,2), keepdims=True) - 8.0) log_spec = (log_spec + 4.0) / 4.0 return log_spec @@ -241,14 +241,21 @@ def transcribe_waveform(model: Whisper, enc, waveforms, truncate=False): Expects an array of shape (N,S) where N is the number waveforms to transcribe in parallel and S is number of 16000Hz samples Returns the transcribed text if a single waveform is provided, or an array of transcriptions if multiple are provided """ - N_audio = len(waveforms) + log_spec = prep_audio(waveforms, model.batch_size, truncate) + nsample = model.decoder.max_tokens_to_sample - if log_spec.shape[-1] > FRAMES_PER_SEGMENT and N_audio > 1: - # we don't support multi-segment batching because the size of the prompt tokens would be different for each item in the batch - # if we really want this feature, we can consider padding or trimming prompt tokens of varying lengths to make them consistent - raise Exception("Multi-segment transcription not supported with batch audio input") + def inferloop(ctx: Union[np.ndarray, List[np.ndarray]], encoded_audio): + pos, next_tokens = 0, ctx + for i in range((nsample-len(start_tokens))*2): + next_tokens = model.decoder(Tensor(next_tokens), pos, encoded_audio)[:, -1].argmax(axis=-1).numpy().astype(np.int32).reshape(-1, 1) + next_tokens[ctx[:, -1] == eot] = eot + ctx = np.concatenate((ctx, next_tokens), axis=1) + pos = ctx.shape[-1] - 1 + if (next_tokens == eot).all(): break + return ctx + def gettexttoks(line): return [tok for tok in line if tok < eot or tok > enc._special_tokens["<|notimestamps|>"]][-nsample+len(start_tokens):] start_tokens = [enc._special_tokens["<|startoftranscript|>"]] if model.is_multilingual: # TODO detect language @@ -256,40 +263,24 @@ def transcribe_waveform(model: Whisper, enc, waveforms, truncate=False): start_tokens.append(language_token) start_tokens.append(enc._special_tokens["<|transcribe|>"]) start_tokens.append(enc._special_tokens["<|notimestamps|>"]) - transcription_start_index = len(start_tokens) + eot = enc._special_tokens["<|endoftext|>"] - transcription_tokens = [np.array([], dtype=np.int32)] * log_spec.shape[0] + + ctx = np.tile(start_tokens, (model.batch_size,1)) + transcriptions = [[] for _ in waveforms] for curr_frame in range(0, log_spec.shape[-1], FRAMES_PER_SEGMENT): encoded_audio = model.encoder.encode(Tensor(log_spec[:, :, curr_frame:curr_frame + FRAMES_PER_SEGMENT])) - pos = 0 - curr_segment_tokens = np.tile(start_tokens, (log_spec.shape[0], 1)) - if curr_frame > 0: - # pass the previously inferred tokens as 'prompt' - https://github.com/openai/whisper/discussions/117#discussioncomment-3727051 - prompt = np.concatenate(( - [enc._special_tokens["<|startofprev|>"]], - transcription_tokens[0][-model.decoder.max_tokens_to_sample+1:], - start_tokens)) - curr_segment_tokens = np.tile(prompt, (log_spec.shape[0], 1)) - transcription_start_index = len(curr_segment_tokens[0]) - for i in range(model.decoder.max_tokens_to_sample): - out = model.decoder.forward(Tensor(curr_segment_tokens if i == 0 else curr_segment_tokens[:, -1:]), pos, None if i else encoded_audio) - next_tokens = out[:, -1].argmax(axis=-1).numpy().astype(np.int32) - next_tokens[curr_segment_tokens[:, -1] == eot] = eot - curr_segment_tokens = np.concatenate((curr_segment_tokens, next_tokens.reshape(-1, 1)), axis=1) - pos = curr_segment_tokens.shape[-1] - 1 - if DEBUG >= 1: print(i, list(map(lambda tokens: enc.decode(tokens), curr_segment_tokens))) - if (curr_segment_tokens[:, -1] == eot).all(): - break + if all(len(c) == len(ctx[0]) for c in ctx): ctx = inferloop(np.array(ctx), encoded_audio) + else: ctx = [inferloop((np.array([c]*model.batch_size)), encoded_audio)[i] for i,c in enumerate(ctx)] - for i, t in enumerate(curr_segment_tokens): - eot_index = np.where(t == eot)[0] - eot_index = None if len(eot_index) == 0 else eot_index[0] - transcription_tokens[i] = np.concatenate((transcription_tokens[i], t[transcription_start_index:eot_index])) + for i, (res, arr) in enumerate(zip(transcriptions, ctx)): + if curr_frame*HOP_LENGTH <= len(waveforms[i]):res.extend(arr[np.where(arr == start_tokens[-1])[0][0]+1:eoti[0] if len (eoti:=np.where(arr == eot)[0]) else None]) + ctx = [[enc._special_tokens['<|startofprev|>']]+gettexttoks(cs)+start_tokens for cs in ctx] - transcriptions = list(map(lambda tokens: enc.decode(tokens).strip(), transcription_tokens)) - return transcriptions[:N_audio] if N_audio > 1 else transcriptions[0] + transcriptions = list(map(lambda tokens: enc.decode(tokens).strip(), transcriptions)) + return transcriptions if len(transcriptions) > 1 else transcriptions[0] CHUNK = 1600 RECORD_SECONDS = 10 diff --git a/test/models/test_whisper.py b/test/models/test_whisper.py index 7dd87bca32..16bded0eee 100644 --- a/test/models/test_whisper.py +++ b/test/models/test_whisper.py @@ -61,8 +61,11 @@ class TestWhisper(unittest.TestCase): @unittest.skipIf(CI, "too long for CI") def test_transcribe_long_no_batch(self): waveforms = [load_file_waveform(fetch(TEST_FILE_3_URL)), load_file_waveform(TEST_FILE_1)] - with self.assertRaises(Exception): - transcribe_waveform(self.model, self.enc, waveforms) + + trancriptions = transcribe_waveform(self.model, self.enc, waveforms) + self.assertEqual(2, len(trancriptions)) + self.assertEqual(TRANSCRIPTION_3, trancriptions[0]) + self.assertEqual(TRANSCRIPTION_1, trancriptions[1]) if __name__ == '__main__': unittest.main()