mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-08 05:14:02 -05:00
add cmdline arguments to translate script + add defaults for command r
This commit is contained in:
@@ -287,7 +287,12 @@ OPTIONS_OVERRIDES = {
|
||||
CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR,
|
||||
},
|
||||
"phi-3": {
|
||||
CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR3
|
||||
CONF_PROMPT: DEFAULT_PROMPT_BASE + ICL_EXTRAS,
|
||||
CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR3,
|
||||
},
|
||||
"command-r": {
|
||||
CONF_PROMPT: DEFAULT_PROMPT_BASE + ICL_EXTRAS,
|
||||
CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_COMMAND_R,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -166,6 +166,7 @@
|
||||
"zephyr2": "Zephyr ('</s>')",
|
||||
"zephyr3": "Zephyr (<|end|>)",
|
||||
"llama3": "Llama 3",
|
||||
"command-r": "Command R",
|
||||
"no_prompt_template": "None"
|
||||
}
|
||||
},
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
"""Original script by @BramNH on GitHub"""
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
|
||||
from deep_translator import GoogleTranslator
|
||||
from deep_translator import GoogleTranslator, DeeplTranslator
|
||||
from deep_translator.base import BaseTranslator
|
||||
from deep_translator.exceptions import TooManyRequests
|
||||
from tqdm import tqdm
|
||||
from transformers import pipeline
|
||||
import langcodes
|
||||
|
||||
SUPPORTED_DEVICES = [
|
||||
"light",
|
||||
@@ -32,20 +33,22 @@ class Seq2SeqTranslator(BaseTranslator):
|
||||
def __init__(self, model_name: str, **kwargs):
    """Create a translator backed by a Hugging Face translation pipeline.

    Args:
        model_name: Model identifier used for both the model and its
            tokenizer.
        **kwargs: Forwarded unchanged to the base translator constructor.
    """
    super().__init__(**kwargs)

    # Imported here rather than at module level, so transformers is only
    # loaded when this translator is actually constructed.
    from transformers import pipeline

    # NOTE(review): device=0 presumably selects the first accelerator;
    # confirm behavior on CPU-only hosts.
    self.translator = pipeline(
        "translation",
        model=model_name,
        tokenizer=model_name,
        device=0,
    )
|
||||
|
||||
def translate(self, text: str, **kwargs):
    """Translate *text* and return the translated string.

    Calls the pipeline stored on ``self.translator`` and returns the
    ``"translation_text"`` field of its first result. Extra keyword
    arguments are accepted but not used.
    """
    results = self.translator(text)
    first_result = results[0]
    return first_result["translation_text"]
|
||||
|
||||
class DatasetTranslator:
|
||||
translator: Seq2SeqTranslator
|
||||
translator: BaseTranslator
|
||||
source_language: str
|
||||
target_language: str
|
||||
|
||||
def __init__(self, source_language, target_language, model_name):
|
||||
def __init__(self, source_language, target_language, translator):
|
||||
self.source_language = source_language
|
||||
self.target_language = target_language
|
||||
self.translator = Seq2SeqTranslator(model_name=model_name)
|
||||
self.translator = translator
|
||||
|
||||
def translate_all_piles(self):
|
||||
os.makedirs(f"./piles/{self.target_language}", exist_ok=True)
|
||||
@@ -334,7 +337,32 @@ class DatasetTranslator:
|
||||
f.writelines(pile_of_todo_items_target)
|
||||
|
||||
|
||||
# TODO: cmd line args
|
||||
DatasetTranslator("english", "german", "Helsinki-NLP/opus-mt-en-de").translate_all_piles()
|
||||
# DatasetTranslator("english", "spanish", "Helsinki-NLP/opus-mt-en-es").translate_all_piles()
|
||||
# DatasetTranslator("english", "french", "Helsinki-NLP/opus-mt-en-fr").translate_all_piles()
|
||||
if __name__ == "__main__":
    # Command-line entry point: choose a translation backend from the
    # arguments, then translate every pile into the destination language.
    parser = argparse.ArgumentParser()
    parser.add_argument("destination_language", help="Destination language for translation")
    parser.add_argument("--source_language", default="english", help="Source language for translation (default: english)")
    parser.add_argument("--translator_type", choices=["google", "transformers", "deepl"], required=True, help="Translator type (choose from: google, transformers, deepl)")
    parser.add_argument("--model_name", help="Model name (optional)")
    parser.add_argument("--api_key", help="API key (optional)")

    args = parser.parse_args()

    # The translation services expect language codes (e.g. "en", "de"),
    # while the CLI and DatasetTranslator work with language names.
    source_code = langcodes.find(args.source_language).language
    dest_code = langcodes.find(args.destination_language).language

    if args.translator_type == "google":
        translator = GoogleTranslator(source=source_code, target=dest_code)
    elif args.translator_type == "transformers":
        if not args.model_name:
            print("No model name was provided!")
            parser.print_usage()
            exit(-1)
        translator = Seq2SeqTranslator(model_name=args.model_name)
    elif args.translator_type == "deepl":
        # The original check called os.getenv() with no argument, which
        # raises TypeError whenever --api_key is omitted. Resolve the key
        # from the flag first, then from the environment, and only fail
        # when neither is set.
        # NOTE(review): DEEPL_API_KEY is assumed to be the variable
        # deep-translator itself reads; confirm against the installed version.
        api_key = args.api_key or os.getenv("DEEPL_API_KEY")
        if not api_key:
            print("No api key was provided!")
            parser.print_usage()
            exit(-1)
        translator = DeeplTranslator(source=source_code, target=dest_code, api_key=api_key)

    DatasetTranslator(args.source_language, args.destination_language, translator).translate_all_piles()
|
||||
@@ -9,6 +9,7 @@ pandas
|
||||
# flash-attn
|
||||
sentencepiece
|
||||
deep-translator
|
||||
langcodes
|
||||
|
||||
homeassistant
|
||||
hassil
|
||||
|
||||
Reference in New Issue
Block a user