diff --git a/custom_components/llama_conversation/const.py b/custom_components/llama_conversation/const.py index 003c869..ebb8451 100644 --- a/custom_components/llama_conversation/const.py +++ b/custom_components/llama_conversation/const.py @@ -287,7 +287,12 @@ OPTIONS_OVERRIDES = { CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR, }, "phi-3": { - CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR3 + CONF_PROMPT: DEFAULT_PROMPT_BASE + ICL_EXTRAS, + CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_ZEPHYR3, + }, + "command-r": { + CONF_PROMPT: DEFAULT_PROMPT_BASE + ICL_EXTRAS, + CONF_PROMPT_TEMPLATE: PROMPT_TEMPLATE_COMMAND_R, } } diff --git a/custom_components/llama_conversation/translations/en.json b/custom_components/llama_conversation/translations/en.json index 7c5d7ea..2119e06 100644 --- a/custom_components/llama_conversation/translations/en.json +++ b/custom_components/llama_conversation/translations/en.json @@ -166,6 +166,7 @@ "zephyr2": "Zephyr ('')", "zephyr3": "Zephyr (<|end|>)", "llama3": "Llama 3", + "command-r": "Command R", "no_prompt_template": "None" } }, diff --git a/data/translate_data.py b/data/translate_data.py index f745a45..a3c3475 100644 --- a/data/translate_data.py +++ b/data/translate_data.py @@ -1,14 +1,15 @@ """Original script by @BramNH on GitHub""" +import argparse import csv import os import time import re -from deep_translator import GoogleTranslator +from deep_translator import GoogleTranslator, DeeplTranslator from deep_translator.base import BaseTranslator from deep_translator.exceptions import TooManyRequests from tqdm import tqdm -from transformers import pipeline +import langcodes SUPPORTED_DEVICES = [ "light", @@ -32,20 +33,22 @@ class Seq2SeqTranslator(BaseTranslator): def __init__(self, model_name: str, **kwargs): super().__init__(**kwargs) + from transformers import pipeline + self.translator = pipeline("translation", model=model_name, tokenizer=model_name, device=0) def translate(self, text: str, **kwargs): return self.translator(text)[0]["translation_text"] class DatasetTranslator: - translator: Seq2SeqTranslator + translator: BaseTranslator source_language: str target_language: str - def __init__(self, source_language, target_language, model_name): + def __init__(self, source_language, target_language, translator): self.source_language = source_language self.target_language = target_language - self.translator = Seq2SeqTranslator(model_name=model_name) + self.translator = translator def translate_all_piles(self): os.makedirs(f"./piles/{self.target_language}", exist_ok=True) @@ -334,7 +337,32 @@ class DatasetTranslator: f.writelines(pile_of_todo_items_target) -# TODO: cmd line args -DatasetTranslator("english", "german", "Helsinki-NLP/opus-mt-en-de").translate_all_piles() -# DatasetTranslator("english", "spanish", "Helsinki-NLP/opus-mt-en-es").translate_all_piles() -# DatasetTranslator("english", "french", "Helsinki-NLP/opus-mt-en-fr").translate_all_piles() \ No newline at end of file +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("destination_language", help="Destination language for translation") + parser.add_argument("--source_language", default="english", help="Source language for translation (default: english)") + parser.add_argument("--translator_type", choices=["google", "transformers", "deepl"], required=True, help="Translator type (choose from: google, transformers, deepl)") + parser.add_argument("--model_name", help="Model name (optional)") + parser.add_argument("--api_key", help="API key (optional)") + + args = parser.parse_args() + + source_code = langcodes.find(args.source_language).language + dest_code = langcodes.find(args.destination_language).language + + if args.translator_type == "google": + translator = GoogleTranslator(source=source_code, target=dest_code) + elif args.translator_type == "transformers": + if not args.model_name: + print("No model name was provided!") + parser.print_usage() + exit(-1) + translator = Seq2SeqTranslator(model_name=args.model_name) + elif args.translator_type == "deepl": + if not args.api_key and os.getenv(): + print("No api key was provided!") + parser.print_usage() + exit(-1) + translator = DeeplTranslator(source=source_code, target=dest_code, api_key=args.api_key) + + DatasetTranslator(args.source_language, args.destination_language, translator).translate_all_piles() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index a31965b..0b299a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ pandas # flash-attn sentencepiece deep-translator +langcodes homeassistant hassil