Add command-line arguments to the translate script + add defaults for Command R

This commit is contained in:
Alex O'Connell
2024-05-08 20:50:25 -04:00
parent 7404b6b36c
commit 179e794283
4 changed files with 45 additions and 10 deletions

View File

@@ -287,7 +287,12 @@ OPTIONS_OVERRIDES = {
CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR,
},
"phi-3": {
CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR3
CONF_PROMPT: DEFAULT_PROMPT_BASE + ICL_EXTRAS,
CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR3,
},
"command-r": {
CONF_PROMPT: DEFAULT_PROMPT_BASE + ICL_EXTRAS,
CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_COMMAND_R,
}
}

View File

@@ -166,6 +166,7 @@
"zephyr2": "Zephyr ('</s>')",
"zephyr3": "Zephyr (<|end|>)",
"llama3": "Llama 3",
"command-r": "Command R",
"no_prompt_template": "None"
}
},

View File

@@ -1,14 +1,15 @@
"""Original script by @BramNH on GitHub"""
import argparse
import csv
import os
import time
import re
from deep_translator import GoogleTranslator
from deep_translator import GoogleTranslator, DeeplTranslator
from deep_translator.base import BaseTranslator
from deep_translator.exceptions import TooManyRequests
from tqdm import tqdm
from transformers import pipeline
import langcodes
SUPPORTED_DEVICES = [
"light",
@@ -32,20 +33,22 @@ class Seq2SeqTranslator(BaseTranslator):
def __init__(self, model_name: str, **kwargs):
    """Create a translator backed by a `transformers` translation pipeline.

    Args:
        model_name: Identifier used for both the model and the tokenizer
            of the pipeline.
        **kwargs: Forwarded unchanged to the base translator constructor.
    """
    super().__init__(**kwargs)
    # Imported here rather than at module level; presumably so the other
    # translator backends do not pay for loading `transformers` -- confirm.
    from transformers import pipeline
    translation_pipeline = pipeline("translation", model=model_name, tokenizer=model_name, device=0)
    self.translator = translation_pipeline
def translate(self, text: str, **kwargs):
    """Translate `text` with the wrapped pipeline.

    Args:
        text: The string handed to the pipeline.
        **kwargs: Accepted but not used by this method.

    Returns:
        The ``"translation_text"`` entry of the first result the pipeline
        returns.
    """
    results = self.translator(text)
    first_result = results[0]
    return first_result["translation_text"]
class DatasetTranslator:
translator: Seq2SeqTranslator
translator: BaseTranslator
source_language: str
target_language: str
def __init__(self, source_language, target_language, model_name):
def __init__(self, source_language, target_language, translator):
self.source_language = source_language
self.target_language = target_language
self.translator = Seq2SeqTranslator(model_name=model_name)
self.translator = translator
def translate_all_piles(self):
os.makedirs(f"./piles/{self.target_language}", exist_ok=True)
@@ -334,7 +337,32 @@ class DatasetTranslator:
f.writelines(pile_of_todo_items_target)
# TODO: cmd line args
DatasetTranslator("english", "german", "Helsinki-NLP/opus-mt-en-de").translate_all_piles()
# DatasetTranslator("english", "spanish", "Helsinki-NLP/opus-mt-en-es").translate_all_piles()
# DatasetTranslator("english", "french", "Helsinki-NLP/opus-mt-en-fr").translate_all_piles()
if __name__ == "__main__":
    # Command-line entry point: build the requested translator backend and
    # translate every pile from the source language into the destination one.
    parser = argparse.ArgumentParser()
    parser.add_argument("destination_language", help="Destination language for translation")
    parser.add_argument("--source_language", default="english", help="Source language for translation (default: english)")
    parser.add_argument("--translator_type", choices=["google", "transformers", "deepl"], required=True, help="Translator type (choose from: google, transformers, deepl)")
    parser.add_argument("--model_name", help="Model name (optional)")
    parser.add_argument("--api_key", help="API key (optional)")
    args = parser.parse_args()

    # The Google and DeepL backends are constructed with language codes, so
    # resolve the language names given on the command line via langcodes.
    source_code = langcodes.find(args.source_language).language
    dest_code = langcodes.find(args.destination_language).language

    if args.translator_type == "google":
        translator = GoogleTranslator(source=source_code, target=dest_code)
    elif args.translator_type == "transformers":
        # The local pipeline backend cannot be built without a model name.
        if not args.model_name:
            print("No model name was provided!")
            parser.print_usage()
            raise SystemExit(-1)
        translator = Seq2SeqTranslator(model_name=args.model_name)
    elif args.translator_type == "deepl":
        # Prefer the explicit flag and fall back to the environment.
        # NOTE(review): "DEEPL_API_KEY" is assumed to be the variable
        # deep-translator itself reads -- confirm against its constants.
        # The previous check called os.getenv() without a variable name,
        # which raised TypeError whenever --api_key was omitted.
        api_key = args.api_key or os.getenv("DEEPL_API_KEY")
        if not api_key:
            print("No api key was provided!")
            parser.print_usage()
            raise SystemExit(-1)
        translator = DeeplTranslator(source=source_code, target=dest_code, api_key=api_key)

    # DatasetTranslator receives the language names as typed on the command
    # line, not the resolved codes.
    DatasetTranslator(args.source_language, args.destination_language, translator).translate_all_piles()

View File

@@ -9,6 +9,7 @@ pandas
# flash-attn
sentencepiece
deep-translator
langcodes
homeassistant
hassil