mirror of
https://github.com/acon96/home-llm.git
synced 2026-01-09 21:58:00 -05:00
local translation + training update
This commit is contained in:
@@ -5,8 +5,10 @@ import time
|
||||
import re
|
||||
|
||||
from deep_translator import GoogleTranslator
|
||||
from deep_translator.base import BaseTranslator
|
||||
from deep_translator.exceptions import TooManyRequests
|
||||
from tqdm import tqdm
|
||||
from transformers import pipeline
|
||||
|
||||
SUPPORTED_DEVICES = [
|
||||
"light",
|
||||
@@ -25,15 +27,25 @@ SUPPORTED_DEVICES = [
|
||||
def format_device_name(input_str):
    """Return a normalized device name.

    Every hyphen and space in ``input_str`` is replaced by an underscore,
    and the result is lower-cased.
    """
    normalized = input_str
    # Both separators map to the same replacement character.
    for separator in ("-", " "):
        normalized = normalized.replace(separator, "_")
    return normalized.lower()
|
||||
|
||||
class Seq2SeqTranslator(BaseTranslator):
    """Translator that runs a local ``transformers`` translation pipeline.

    Subclasses ``BaseTranslator`` so it can be used where a deep_translator
    translator object is expected, but delegates the actual translation to a
    ``pipeline("translation", ...)`` instance instead of a remote service.
    """

    def __init__(self, model_name: str, *, device=0, **kwargs):
        """Build the translation pipeline.

        Args:
            model_name: Identifier passed to ``pipeline`` as both the model
                and the tokenizer.
            device: Value forwarded to ``pipeline(device=...)``. Defaults to
                ``0``, the value that was previously hard-coded, so existing
                callers are unaffected; pass a different value to run on
                another device.
            **kwargs: Forwarded unchanged to ``BaseTranslator.__init__``.
        """
        super().__init__(**kwargs)

        self.translator = pipeline(
            "translation",
            model=model_name,
            tokenizer=model_name,
            device=device,
        )

    def translate(self, text: str, **kwargs):
        """Translate ``text`` and return the translated string.

        Extra keyword arguments are accepted for signature compatibility but
        are not used. The return value is the ``"translation_text"`` field of
        the first result produced by the pipeline.
        """
        return self.translator(text)[0]["translation_text"]
|
||||
|
||||
class DatasetTranslator:
|
||||
translator: GoogleTranslator
|
||||
translator: Seq2SeqTranslator
|
||||
source_language: str
|
||||
target_language: str
|
||||
|
||||
def __init__(self, source_language, target_language):
|
||||
def __init__(self, source_language, target_language, model_name):
|
||||
self.source_language = source_language
|
||||
self.target_language = target_language
|
||||
self.translator = GoogleTranslator(source=source_language, target=target_language)
|
||||
self.translator = Seq2SeqTranslator(model_name=model_name)
|
||||
|
||||
def translate_all_piles(self):
|
||||
os.makedirs(f"./piles/{self.target_language}", exist_ok=True)
|
||||
@@ -323,6 +335,6 @@ class DatasetTranslator:
|
||||
|
||||
|
||||
# TODO: cmd line args
|
||||
DatasetTranslator("english", "german").translate_all_piles()
|
||||
DatasetTranslator("english", "spanish").translate_all_piles()
|
||||
DatasetTranslator("english", "french").translate_all_piles()
|
||||
DatasetTranslator("english", "german", "Helsinki-NLP/opus-mt-en-de").translate_all_piles()
|
||||
# DatasetTranslator("english", "spanish", "Helsinki-NLP/opus-mt-en-es").translate_all_piles()
|
||||
# DatasetTranslator("english", "french", "Helsinki-NLP/opus-mt-en-fr").translate_all_piles()
|
||||
Reference in New Issue
Block a user