Filter out non_speech_tokens in suppressed tokens (#898)

* Filter out non_speech_tokens in suppressed tokens
This commit is contained in:
Jordi Mas
2024-07-05 09:43:11 +02:00
committed by GitHub
parent c22db5125d
commit 1195359984
3 changed files with 159 additions and 11 deletions

View File

@@ -105,6 +105,42 @@ class Tokenizer:
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
)
@cached_property
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
- ♪♪♪
- ( SPEAKING FOREIGN LANGUAGE )
- [DAVID] Hey there,
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
symbols += (
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
)
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
miscellaneous = set("♩♪♫♬♭♮♯")
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.encode(" -")[0], self.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [
self.encode(symbol),
self.encode(" " + symbol),
]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
return tuple(sorted(result))
def split_to_word_tokens(
self, tokens: List[int]
) -> Tuple[List[str], List[List[int]]]:

View File

@@ -277,7 +277,7 @@ class WhisperModel:
prefix: Optional text to provide as a prefix for the first window.
suppress_blank: Suppress blank outputs at the beginning of the sampling.
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
of symbols as defined in the model config.json file.
of symbols as defined in `tokenizer.non_speech_tokens()`
without_timestamps: Only sample text tokens.
max_initial_timestamp: The initial timestamp cannot be later than this.
word_timestamps: Extract word-level timestamps using the cross-attention pattern
@@ -462,7 +462,11 @@ class WhisperModel:
initial_prompt=initial_prompt,
prefix=prefix,
suppress_blank=suppress_blank,
suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
suppress_tokens=(
get_suppressed_tokens(tokenizer, suppress_tokens)
if suppress_tokens
else suppress_tokens
),
without_timestamps=without_timestamps,
max_initial_timestamp=max_initial_timestamp,
word_timestamps=word_timestamps,
@@ -488,7 +492,6 @@ class WhisperModel:
vad_options=vad_parameters,
all_language_probs=all_language_probs,
)
return segments, info
def generate_segments(
@@ -1227,15 +1230,16 @@ def get_compression_ratio(text: str) -> float:
def get_suppressed_tokens(
tokenizer: Tokenizer,
suppress_tokens: Optional[List[int]],
suppress_tokens: Tuple[int],
) -> Optional[List[int]]:
if not suppress_tokens or -1 in suppress_tokens:
return suppress_tokens
if -1 in suppress_tokens:
suppress_tokens = [t for t in suppress_tokens if t >= 0]
suppress_tokens.extend(tokenizer.non_speech_tokens)
elif suppress_tokens is None or len(suppress_tokens) == 0:
suppress_tokens = [] # interpret empty string as an empty list
else:
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
suppress_tokens = list(suppress_tokens)
# Ensure the following special tokens are suppressed when the user does
# not use the default set (-1).
suppress_tokens.extend(
[
tokenizer.transcribe,
@@ -1246,7 +1250,7 @@ def get_suppressed_tokens(
]
)
return sorted(set(suppress_tokens))
return tuple(sorted(set(suppress_tokens)))
def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None: