mirror of
https://github.com/DrewThomasson/ebook2audiobook.git
synced 2026-01-10 06:18:02 -05:00
Merge pull request #1196 from ROBERT-MCDOWELL/v25
v25.11.25 pre-release 1
This commit is contained in:
@@ -1 +1 @@
|
||||
25.11.22
|
||||
25.11.25
|
||||
285
app.py
285
app.py
@@ -12,6 +12,7 @@ if not os.path.exists(dst_pyfile) or os.path.getmtime(dst_pyfile) < os.path.getm
|
||||
shutil.copy2(src_pyfile, dst_pyfile)
|
||||
##############
|
||||
|
||||
import platform
|
||||
import argparse
|
||||
import filecmp
|
||||
import importlib.util
|
||||
@@ -21,7 +22,9 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
import warnings
|
||||
import re
|
||||
|
||||
from typing import Tuple
|
||||
from importlib.metadata import version, PackageNotFoundError
|
||||
from pathlib import Path
|
||||
|
||||
@@ -58,16 +61,226 @@ and run "./ebook2audiobook.sh" for Linux and Mac or "ebook2audiobook.cmd" for Wi
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def torch_version_is_leq(target):
|
||||
import torch
|
||||
|
||||
def detect_platform_tag()->str:
|
||||
if sys.platform.startswith('win'):
|
||||
return 'win'
|
||||
if sys.platform == 'darwin':
|
||||
return 'macosx'
|
||||
if sys.platform.startswith('linux'):
|
||||
return 'manylinux'
|
||||
return 'unknown'
|
||||
|
||||
def detect_arch_tag()->str:
|
||||
m=platform.machine().lower()
|
||||
if m in ('x86_64','amd64'):
|
||||
return m
|
||||
if m in ('aarch64','arm64'):
|
||||
return m
|
||||
return 'unknown'
|
||||
|
||||
def detect_gpu()->str:
|
||||
|
||||
def has_cmd(cmd:str)->bool:
|
||||
return shutil.which(cmd) is not None
|
||||
|
||||
def try_cmd(cmd:str)->str:
|
||||
try:
|
||||
out = subprocess.check_output(
|
||||
cmd,
|
||||
shell = True,
|
||||
stderr = subprocess.DEVNULL
|
||||
)
|
||||
return out.decode().lower()
|
||||
except Exception:
|
||||
return ''
|
||||
|
||||
def toolkit_version_parse(text:str)->str|None:
|
||||
m = re.findall(r'\d+(?:\.\d+)+', text)
|
||||
return m[0] if m else None
|
||||
|
||||
def is_toolkit_version_exceeds(version_str:str|None, max_tuple:Tuple[int,int])->bool:
|
||||
if max_tuple == (0, 0) or version_str is None:
|
||||
return False
|
||||
parts = version_str.split('.')
|
||||
major = int(parts[0])
|
||||
minor = int(parts[1]) if len(parts) > 1 else 0
|
||||
return (major, minor) > max_tuple
|
||||
|
||||
def tegra_version()->str:
|
||||
if os.path.exists('/etc/nv_tegra_release'):
|
||||
return try_cmd('cat /etc/nv_tegra_release')
|
||||
return ''
|
||||
|
||||
def jetpack_version(text:str)->str:
|
||||
m1 = re.search(r'r(\d+)', text) # R35, R36...
|
||||
m2 = re.search(r'revision:\s*([\d\.]+)', text) # X.Y
|
||||
if not m1 or not m2:
|
||||
warn(msg)
|
||||
return 'unknown'
|
||||
l4t_major = int(m1.group(1))
|
||||
rev = m2.group(1)
|
||||
parts = rev.split('.')
|
||||
rev_major = int(parts[0])
|
||||
rev_minor = int(parts[1]) if len(parts) > 1 else 0
|
||||
# -------------------------------------------------------
|
||||
# JetPack < 5.0 → CPU
|
||||
# -------------------------------------------------------
|
||||
if l4t_major < 35:
|
||||
msg = f'JetPack too old (L4T {l4t_major}). Please upgrade to JetPack 5.1+. Falling back to CPU.'
|
||||
warn(msg)
|
||||
return 'unsupported'
|
||||
# -------------------------------------------------------
|
||||
# JetPack 5.x (L4T 35)
|
||||
# -------------------------------------------------------
|
||||
if l4t_major == 35:
|
||||
# JetPack 5.0 / 5.0.1
|
||||
if rev_major == 0 and rev_minor <= 1:
|
||||
msg = 'JetPack 5.0/5.0.1 detected. Please upgrade to JetPack 5.1+ to use the GPU. Failing back to CPU'
|
||||
warn(msg)
|
||||
return 'cpu'
|
||||
# JetPack 5.0.2 / 5.0.x
|
||||
if rev_major == 0 and rev_minor >= 2:
|
||||
msg = 'JetPack 5.0.x detected. Please upgrade to JetPack 5.1+ to use the GPU. Failing back to CPU'
|
||||
warn(msg)
|
||||
return 'cpu'
|
||||
# JetPack 5.1.0
|
||||
if rev_major == 1 and rev_minor == 0:
|
||||
msg = 'JetPack 5.1.0 detected. Please upgrade to JetPack 5.1.2 or newer.'
|
||||
warn(msg)
|
||||
return 'jetpack5' # 51
|
||||
# JetPack 5.1.1
|
||||
if rev_major == 1 and rev_minor == 1:
|
||||
msg = 'JetPack 5.1.1 detected. Please upgrade to JetPack 5.1.2 or newer.'
|
||||
warn(msg)
|
||||
return 'jetpack5' # 511
|
||||
# JetPack >= 5.1.2 AND < 6 → ALWAYS == 512
|
||||
if (rev_major > 1) or (rev_major == 1 and rev_minor >= 2):
|
||||
return 'jetpack5' # 512
|
||||
msg = 'Unrecognized JetPack 5.x version. Falling back to CPU.'
|
||||
warn(msg)
|
||||
return 'unknown'
|
||||
# -------------------------------------------------------
|
||||
# JetPack 6.x (L4T 36)
|
||||
# -------------------------------------------------------
|
||||
if l4t_major == 36:
|
||||
if rev_major == 2:
|
||||
return '60'
|
||||
if rev_major == 3:
|
||||
return '61'
|
||||
msg = 'Unrecognized JetPack 6.x version. Falling back to CPU.'
|
||||
warn(msg)
|
||||
return 'unknown'
|
||||
|
||||
def warn(msg:str)->None:
|
||||
print(f'[WARNING] {msg}')
|
||||
|
||||
arch:str = platform.machine().lower()
|
||||
|
||||
# ============================================================
|
||||
# CUDA
|
||||
# ============================================================
|
||||
if has_cmd('nvidia-smi'):
|
||||
out = try_cmd('nvidia-smi')
|
||||
version_str:str|None = toolkit_version_parse(out)
|
||||
if is_toolkit_version_exceeds(version_str, max_cuda_version):
|
||||
msg = f'CUDA {version_str} > max {max_cuda_version}. Falling back to CPU.'
|
||||
warn(msg)
|
||||
return 'cpu'
|
||||
if version_str:
|
||||
devices['CUDA']['found'] = True
|
||||
major = version_str.split('.')[0]
|
||||
minor = version_str.split('.')[1]
|
||||
return f'cu{major}{minor}'
|
||||
msg = 'No CUDA version found. Falling back to CPU.'
|
||||
warn(msg)
|
||||
return 'cpu'
|
||||
|
||||
# ============================================================
|
||||
# ROCm
|
||||
# ============================================================
|
||||
if has_cmd('rocminfo') or os.path.exists('/opt/rocm'):
|
||||
out = try_cmd('rocminfo')
|
||||
version_str = toolkit_version_parse(out)
|
||||
if is_toolkit_version_exceeds(version_str, max_rocm_version):
|
||||
msg = f'ROCm {version_str} > max {max_rocm_version} → CPU'
|
||||
warn(msg)
|
||||
return 'cpu'
|
||||
if version_str:
|
||||
devices['ROCM']['found'] = True
|
||||
return f'rocm{version_str}'
|
||||
msg = 'No ROCm version found. Falling back to CPU.'
|
||||
warn(msg)
|
||||
return 'cpu'
|
||||
|
||||
# ============================================================
|
||||
# APPLE MPS
|
||||
# ============================================================
|
||||
if sys.platform == 'darwin' and arch in ('arm64', 'aarch64'):
|
||||
devices['MPS']['found'] = True
|
||||
return 'mps'
|
||||
|
||||
# ============================================================
|
||||
# INTEL XPU
|
||||
# ============================================================
|
||||
if os.path.exists('/dev/dri/renderD128'):
|
||||
out = try_cmd('lspci')
|
||||
if 'intel' in out:
|
||||
oneapi_out:str = try_cmd('sycl-ls') if has_cmd('sycl-ls') else ''
|
||||
version_str = toolkit_version_parse(oneapi_out)
|
||||
if is_toolkit_version_exceeds(version_str, max_xpu_version):
|
||||
msg = f'XPU {version_str} > max {max_xpu_version} → CPU'
|
||||
warn(msg)
|
||||
return 'cpu'
|
||||
if has_cmd('sycl-ls') or has_cmd('clinfo'):
|
||||
devices['XPU']['found']
|
||||
return 'xpu'
|
||||
msg = 'Intel GPU detected but oneAPI runtime missing → CPU'
|
||||
warn(msg)
|
||||
return 'cpu'
|
||||
if has_cmd('clinfo'):
|
||||
out = try_cmd('clinfo')
|
||||
if 'intel' in out:
|
||||
return 'xpu'
|
||||
|
||||
# ============================================================
|
||||
# JETSON
|
||||
# ============================================================
|
||||
if arch in ('aarch64','arm64') and (os.path.exists('/etc/nv_tegra_release') or 'tegra' in try_cmd('cat /proc/device-tree/compatible')):
|
||||
# Always read Tegra release if device looks like Jetson
|
||||
raw = tegra_version()
|
||||
# Detect JetPack version code
|
||||
jp_code = jetpack_version(raw)
|
||||
# Unsupported JetPack (<5.1)
|
||||
if jp_code in ['unsupported', 'unknown']:
|
||||
return 'cpu'
|
||||
# Direct Jetson detection mechanisms
|
||||
if os.path.exists('/etc/nv_tegra_release'):
|
||||
devices['CUDA']['found'] = True
|
||||
return f'jetson-{jp_code}'
|
||||
if os.path.exists('/proc/device-tree/compatible'):
|
||||
out = try_cmd('cat /proc/device-tree/compatible')
|
||||
if 'tegra' in out:
|
||||
devices['CUDA']['found'] = True
|
||||
return f'jetson-{jp_code}'
|
||||
out = try_cmd('uname -a')
|
||||
if 'tegra' in out:
|
||||
msg = 'Unknown Jetson device. Failing back to cpu'
|
||||
warn(msg)
|
||||
return 'cpu'
|
||||
|
||||
# ============================================================
|
||||
# CPU
|
||||
# ============================================================
|
||||
return 'cpu'
|
||||
|
||||
def parse_torch_version(current:str)->str:
|
||||
from packaging.version import Version, InvalidVersion
|
||||
v = torch.__version__
|
||||
try:
|
||||
parsed = Version(v)
|
||||
parsed = Version(current)
|
||||
except InvalidVersion:
|
||||
parsed = Version(v.split('+')[0])
|
||||
return parsed <= Version(target)
|
||||
parsed = Version(current.split('+')[0])
|
||||
return parsed
|
||||
|
||||
def check_and_install_requirements(file_path:str)->bool:
|
||||
if not os.path.exists(file_path):
|
||||
@@ -75,18 +288,19 @@ def check_and_install_requirements(file_path:str)->bool:
|
||||
print(error)
|
||||
return False
|
||||
try:
|
||||
backend_specs = {"os": detect_platform_tag(), "arch": detect_arch_tag(), "pyvenv": sys.version_info[:2], "gpu": detect_gpu()}
|
||||
print(f'--------------- {backend_specs} -------------')
|
||||
try:
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.version import Version
|
||||
from packaging.version import Version, InvalidVersion
|
||||
from tqdm import tqdm
|
||||
from packaging.markers import Marker
|
||||
except ImportError:
|
||||
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--no-cache-dir', 'packaging', 'tqdm'])
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from packaging.version import Version
|
||||
from packaging.version import Version, InvalidVersion
|
||||
from tqdm import tqdm
|
||||
from packaging.markers import Marker
|
||||
import re as regex
|
||||
flexible_packages = {"torch", "torchaudio", "numpy"}
|
||||
torch_version = False
|
||||
try:
|
||||
@@ -97,7 +311,7 @@ def check_and_install_requirements(file_path:str)->bool:
|
||||
cuda_only_packages = ('deepspeed')
|
||||
with open(file_path, 'r') as f:
|
||||
contents = f.read().replace('\r', '\n')
|
||||
packages = [pkg.strip() for pkg in contents.splitlines() if pkg.strip() and regex.search(r'[a-zA-Z0-9]', pkg)]
|
||||
packages = [pkg.strip() for pkg in contents.splitlines() if pkg.strip() and re.search(r'[a-zA-Z0-9]', pkg)]
|
||||
if sys.version_info >= (3, 11):
|
||||
packages.append("pymupdf-layout")
|
||||
missing_packages = []
|
||||
@@ -115,7 +329,7 @@ def check_and_install_requirements(file_path:str)->bool:
|
||||
print(error)
|
||||
package = pkg_part.strip()
|
||||
if 'git+' in package or '://' in package:
|
||||
pkg_name_match = regex.search(r'([\w\-]+)\s*@?\s*git\+', package)
|
||||
pkg_name_match = re.search(r'([\w\-]+)\s*@?\s*git\+', package)
|
||||
pkg_name = pkg_name_match.group(1) if pkg_name_match else None
|
||||
if pkg_name:
|
||||
spec = importlib.util.find_spec(pkg_name)
|
||||
@@ -128,8 +342,8 @@ def check_and_install_requirements(file_path:str)->bool:
|
||||
print(error)
|
||||
missing_packages.append(package)
|
||||
continue
|
||||
clean_pkg = regex.sub(r'\[.*?\]', '', package)
|
||||
pkg_name = regex.split(r'[<>=]', clean_pkg, maxsplit=1)[0].strip()
|
||||
clean_pkg = re.sub(r'\[.*?\]', '', package)
|
||||
pkg_name = re.split(r'[<>=]', clean_pkg, maxsplit=1)[0].strip()
|
||||
if pkg_name in cuda_only_packages:
|
||||
has_cuda_build = False
|
||||
if torch_version:
|
||||
@@ -151,13 +365,13 @@ def check_and_install_requirements(file_path:str)->bool:
|
||||
spec_str = clean_pkg[len(pkg_name):].strip()
|
||||
if spec_str:
|
||||
spec = SpecifierSet(spec_str)
|
||||
norm_match = regex.match(r'^(\d+\.\d+(?:\.\d+)?)', installed_version)
|
||||
norm_match = re.match(r'^(\d+\.\d+(?:\.\d+)?)', installed_version)
|
||||
short_version = norm_match.group(1) if norm_match else installed_version
|
||||
try:
|
||||
installed_v = Version(short_version)
|
||||
except Exception:
|
||||
installed_v = Version('0')
|
||||
req_match = regex.search(r'(\d+\.\d+(?:\.\d+)?)', spec_str)
|
||||
req_match = re.search(r'(\d+\.\d+(?:\.\d+)?)', spec_str)
|
||||
if req_match:
|
||||
req_v = Version(req_match.group(1))
|
||||
imajor, iminor = installed_v.major, installed_v.minor
|
||||
@@ -212,22 +426,47 @@ def check_and_install_requirements(file_path:str)->bool:
|
||||
return False
|
||||
msg = '\nAll required packages are installed.'
|
||||
print(msg)
|
||||
import torch
|
||||
import numpy as np
|
||||
torch_version = torch.__version__
|
||||
numpy_version = Version(np.__version__)
|
||||
if torch_version_is_leq('2.2.2') and numpy_version >= Version('2.0.0'):
|
||||
torch_version_parsed = parse_torch_version(torch_version)
|
||||
if backend_specs['gpu'] not in ['cpu', 'unknown', 'unsupported']:
|
||||
current_tag_pattern = re.search(r'\+(.+)$', torch_version)
|
||||
current_tag = current_tag_pattern.group(1)
|
||||
non_standard_tag = re.fullmatch(r'[0-9a-f]{7,40}', current_tag)
|
||||
if (
|
||||
non_standard_tag is None and current_tag != backend_specs['gpu'] or
|
||||
non_standard_tag is not None and backend_specs['gpu'] in ['jetson-jetpack5', 'jetson-60', 'jetson-61'] and non_standard_tag != torch_mapping[backend_specs['gpu']]['tag']
|
||||
):
|
||||
try:
|
||||
backend_tag = torch_mapping[backend_specs['gpu']]['tag']
|
||||
backend_os = backend_specs['os']
|
||||
backend_arch = backend_specs['arch']
|
||||
backend_url = torch_mapping[backend_specs['gpu']]['url']
|
||||
if backend-specs['gpu'] == 'jetson-jetpack5':
|
||||
torch_pkg = f''
|
||||
elif backend_specs['gpu'] in ['jetson-60', 'jetson-61']:
|
||||
jetson_torch_version = default_jetson60_torch if backend_specs['gpu'] == 'jetson-60' else default_jetson61_torch
|
||||
torch_pkg = f'{backend_url}/v{backend_tag}/pytorch/torch-{jetson_torch_version}-{default_py_tag}-linux_{backend_arch}.whl'
|
||||
else:
|
||||
torch_pkg = f'{gpu_url}/{backend_tag}/torch/torch-{torch_version_parsed}+{gpu_tag}-{default_py_tag}-{backend_os}_{backend_arch}.whl'
|
||||
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--no-cache-dir', torch_pkg])
|
||||
import torch
|
||||
torch_version = torch.__version__
|
||||
except subprocess.CalledProcessError as e:
|
||||
error = f'Failed to install {packages}: {e}'
|
||||
print(error)
|
||||
return False
|
||||
if torch_version_parsed <= Version('2.2.2') and numpy_version >= Version('2.0.0'):
|
||||
try:
|
||||
msg = 'torch version needs nump < 2. downgrading numpy to 1.26.4...'
|
||||
msg = 'torch version needs numpy < 2. downgrading numpy to 1.26.4...'
|
||||
print(msg)
|
||||
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--no-cache-dir', '--use-pep517', 'numpy<2'])
|
||||
except subprocess.CalledProcessError as e:
|
||||
error = f'Failed to downgrade to numpy < 2: {e}'
|
||||
print(error)
|
||||
return False
|
||||
import torch
|
||||
devices['CUDA']['found'] = getattr(torch, "cuda", None) is not None and torch.cuda.is_available() and not (hasattr(torch.version, "hip") and torch.version.hip is not None)
|
||||
devices['ROCM']['found'] = hasattr(torch.version, "hip") and torch.version.hip is not None and torch.cuda.is_available()
|
||||
devices['MPS']['found'] = getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available()
|
||||
devices['XPU']['found'] = getattr(torch, "xpu", None) is not None and torch.xpu.is_available()
|
||||
return True
|
||||
except Exception as e:
|
||||
error = f'check_and_install_requirements() error: {e}'
|
||||
|
||||
@@ -33,5 +33,4 @@ pyannote-audio<=3.4.0
|
||||
argostranslate<=1.10.0
|
||||
gradio==5.49.1
|
||||
torch<=2.7.1
|
||||
torchaudio<=2.7.1
|
||||
coqui-tts[languages]==0.27.2
|
||||
@@ -394,6 +394,40 @@ else
|
||||
if [[ ! -d "$SCRIPT_DIR/$PYTHON_ENV" ]]; then
|
||||
if [[ "$OSTYPE" = "darwin"* && "$ARCH" = "x86_64" ]]; then
|
||||
PYTHON_VERSION="3.11"
|
||||
elif [[ -r /proc/device-tree/model ]]; then
|
||||
# Detect Jetson and select correct Python version
|
||||
MODEL=$(tr -d '\0' </proc/device-tree/model | tr 'A-Z' 'a-z')
|
||||
if [[ "$MODEL" == *jetson* ]]; then
|
||||
# Detect JetPack (L4T version)
|
||||
JP=$(dpkg-query --showformat='${Version}' --show nvidia-l4t-release 2>/dev/null | cut -d. -f1-2)
|
||||
L4T=$(awk -F' ' '/# R/ {print $2}' /etc/nv_tegra_release)
|
||||
case "$JP" in
|
||||
32.*)
|
||||
# JetPack 4.x (Nano, TX2)
|
||||
PYTHON_VERSION="3.10"
|
||||
echo "[WARNING] JetPack installed is too old to use its GPU. Upgrade to version 5.1.x. CPU will be used instead"
|
||||
;;
|
||||
35.*)
|
||||
# JetPack 5.x (Xavier NX, AGX Xavier, Orin Nano/NX/AGX)
|
||||
if [ "$L4T" != "35.1" ]; then
|
||||
echo "[WARNING] JetPack must be updated to the last version 5.1.x allowing to use the GPU. CPU will be used instead"
|
||||
fi
|
||||
PYTHON_VERSION="3.10"
|
||||
;;
|
||||
36.*)
|
||||
# JetPack 6.x (Orin)
|
||||
if command -v python3.11 >/dev/null; then
|
||||
PYTHON_VERSION="3.11"
|
||||
else
|
||||
PYTHON_VERSION="3.10"
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
# Unknown JetPack -> safe default
|
||||
PYTHON_VERSION="3.10"
|
||||
;;
|
||||
esac
|
||||
fi
|
||||
else
|
||||
compare_versions "$PYTHON_VERSION" "$MIN_PYTHON_VERSION"
|
||||
case $? in
|
||||
|
||||
75
lib/conf.py
75
lib/conf.py
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import platform
|
||||
import tempfile
|
||||
import sys
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Global configuration
|
||||
@@ -8,7 +9,10 @@ import tempfile
|
||||
min_python_version = (3,10)
|
||||
max_python_version = (3,12)
|
||||
|
||||
components_dir = os.path.abspath('components')
|
||||
max_cuda_version = (12,8)
|
||||
max_rocm_version = (0,0)
|
||||
max_xpu_version = (0,0)
|
||||
|
||||
tmp_dir = os.path.abspath('tmp')
|
||||
tempfile.tempdir = tmp_dir
|
||||
tmp_expire = 7 # days
|
||||
@@ -17,6 +21,7 @@ models_dir = os.path.abspath('models')
|
||||
ebooks_dir = os.path.abspath('ebooks')
|
||||
voices_dir = os.path.abspath('voices')
|
||||
tts_dir = os.path.join(models_dir, 'tts')
|
||||
components_dir = os.path.abspath('components')
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Environment setup
|
||||
@@ -75,10 +80,76 @@ devices = {
|
||||
"ROCM": {"proc": "rocm", "found": False},
|
||||
"XPU": {"proc": "xpu", "found": False}
|
||||
}
|
||||
|
||||
default_device = devices['CPU']['proc']
|
||||
default_gpu_wiki = '<a href="https://github.com/DrewThomasson/ebook2audiobook/wiki/GPU-ISSUES">GPU howto wiki</a>'
|
||||
default_chapters_preview = False
|
||||
|
||||
default_gpu_wiki = '<a href="https://github.com/DrewThomasson/ebook2audiobook/wiki/GPU-ISSUES">GPU howto wiki</a>'
|
||||
|
||||
default_py_major = sys.version_info.major
|
||||
default_py_minor = sys.version_info.minor
|
||||
default_py_tag = f'cp{default_py_major}{default_py_minor}-cp{default_py_major}{default_py_minor}'
|
||||
|
||||
default_pytorch_url = 'https://download.pytorch.org/whl/'
|
||||
default_jetson_url = 'https://developer.download.nvidia.com/compute/redist/jp/'
|
||||
default_compiled_url = 'https://xxxxxxxxxx/compiled/xxxxxx.whl'
|
||||
|
||||
default_jetson5_torch = ''
|
||||
default_jetson60_torch = '2.4.0a0+3bcc3cddb5.nv24.07.16234504'
|
||||
default_jetson61_torch = '2.5.0a0+872d972e41.nv24.08.17622132'
|
||||
|
||||
torch_mapping = {
|
||||
|
||||
# CUDA
|
||||
"cu113": {"tag": "cu113", "url": default_pytorch_url},
|
||||
"cu114": {"tag": "cu114", "url": default_pytorch_url},
|
||||
"cu115": {"tag": "cu115", "url": default_pytorch_url},
|
||||
"cu116": {"tag": "cu116", "url": default_pytorch_url},
|
||||
"cu117": {"tag": "cu117", "url": default_pytorch_url},
|
||||
"cu118": {"tag": "cu118", "url": default_pytorch_url},
|
||||
"cu121": {"tag": "cu121", "url": default_pytorch_url},
|
||||
"cu124": {"tag": "cu124", "url": default_pytorch_url},
|
||||
"cu126": {"tag": "cu126", "url": default_pytorch_url},
|
||||
"cu128": {"tag": "cu128", "url": default_pytorch_url},
|
||||
"cu129": {"tag": "cu129", "url": default_pytorch_url},
|
||||
"cu130": {"tag": "cu130", "url": default_pytorch_url},
|
||||
|
||||
# ROCm
|
||||
"rocm3.10": {"tag": "rocm3.10", "url": default_pytorch_url},
|
||||
"rocm3.7": {"tag": "rocm3.7", "url": default_pytorch_url},
|
||||
"rocm3.8": {"tag": "rocm3.8", "url": default_pytorch_url},
|
||||
"rocm4.0.1": {"tag": "rocm4.0.1", "url": default_pytorch_url},
|
||||
"rocm4.1": {"tag": "rocm4.1", "url": default_pytorch_url},
|
||||
"rocm4.2": {"tag": "rocm4.2", "url": default_pytorch_url},
|
||||
"rocm4.3.1": {"tag": "rocm4.3.1", "url": default_pytorch_url},
|
||||
"rocm4.5.2": {"tag": "rocm4.5.2", "url": default_pytorch_url},
|
||||
"rocm5.0": {"tag": "rocm5.0", "url": default_pytorch_url},
|
||||
"rocm5.1.1": {"tag": "rocm5.1.1", "url": default_pytorch_url},
|
||||
"rocm5.2": {"tag": "rocm5.2", "url": default_pytorch_url},
|
||||
"rocm5.3": {"tag": "rocm5.3", "url": default_pytorch_url},
|
||||
"rocm5.4.2": {"tag": "rocm5.4.2", "url": default_pytorch_url},
|
||||
"rocm5.5": {"tag": "rocm5.5", "url": default_pytorch_url},
|
||||
"rocm5.6": {"tag": "rocm5.6", "url": default_pytorch_url},
|
||||
"rocm5.7": {"tag": "rocm5.7", "url": default_pytorch_url},
|
||||
"rocm6.0": {"tag": "rocm6.0", "url": default_pytorch_url},
|
||||
"rocm6.1": {"tag": "rocm6.1", "url": default_pytorch_url},
|
||||
"rocm6.2": {"tag": "rocm6.2", "url": default_pytorch_url},
|
||||
"rocm6.2.4": {"tag": "rocm6.2.4", "url": default_pytorch_url},
|
||||
"rocm6.3": {"tag": "rocm6.3", "url": default_pytorch_url},
|
||||
"rocm6.4": {"tag": "rocm6.4", "url": default_pytorch_url},
|
||||
|
||||
# MPS
|
||||
"mps": {"tag": "mps", "url": default_pytorch_url},
|
||||
|
||||
# XPU
|
||||
"xpu": {"tag": "xpu", "url": default_pytorch_url},
|
||||
|
||||
# JETSON
|
||||
"jetson-jetpack5": {"tag": "xxxxxxxxxxxx", "url": default_compiled_url},
|
||||
"jetson-60": {"tag": "v60", "url": default_jetson_url},
|
||||
"jetson-61": {"tag": "v61", "url": default_jetson_url}
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Python environment references
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
158
lib/functions.py
158
lib/functions.py
@@ -1460,87 +1460,87 @@ def is_latin(s: str) -> bool:
|
||||
return all((u'a' <= ch.lower() <= 'z') or ch.isdigit() or not ch.isalpha() for ch in s)
|
||||
|
||||
def foreign2latin(text, base_lang):
|
||||
def script_of(word):
|
||||
for ch in word:
|
||||
if ch.isalpha():
|
||||
name = unicodedata.name(ch, "")
|
||||
if "CYRILLIC" in name:
|
||||
return "cyrillic"
|
||||
if "LATIN" in name:
|
||||
return "latin"
|
||||
if "ARABIC" in name:
|
||||
return "arabic"
|
||||
if "HANGUL" in name:
|
||||
return "hangul"
|
||||
if "HIRAGANA" in name or "KATAKANA" in name:
|
||||
return "japanese"
|
||||
if "CJK" in name or "IDEOGRAPH" in name:
|
||||
return "chinese"
|
||||
return "unknown"
|
||||
def script_of(word):
|
||||
for ch in word:
|
||||
if ch.isalpha():
|
||||
name = unicodedata.name(ch, "")
|
||||
if "CYRILLIC" in name:
|
||||
return "cyrillic"
|
||||
if "LATIN" in name:
|
||||
return "latin"
|
||||
if "ARABIC" in name:
|
||||
return "arabic"
|
||||
if "HANGUL" in name:
|
||||
return "hangul"
|
||||
if "HIRAGANA" in name or "KATAKANA" in name:
|
||||
return "japanese"
|
||||
if "CJK" in name or "IDEOGRAPH" in name:
|
||||
return "chinese"
|
||||
return "unknown"
|
||||
|
||||
def romanize(word):
|
||||
scr = script_of(word)
|
||||
if scr == "latin":
|
||||
return word
|
||||
if base_lang in ["ru", "rus"] and scr == "cyrillic":
|
||||
return word
|
||||
if base_lang in ["ar", "ara"] and scr == "arabic":
|
||||
return word
|
||||
if base_lang in ["ko", "kor"] and scr == "hangul":
|
||||
return word
|
||||
if base_lang in ["ja", "jpn"] and scr == "japanese":
|
||||
return word
|
||||
if base_lang in ["zh", "zho"] and scr == "chinese":
|
||||
return word
|
||||
try:
|
||||
if scr == "chinese":
|
||||
from pypinyin import pinyin, Style
|
||||
return "".join(x[0] for x in pinyin(word, style=Style.NORMAL))
|
||||
if scr == "japanese":
|
||||
import pykakasi
|
||||
k = pykakasi.kakasi()
|
||||
k.setMode("H", "a")
|
||||
k.setMode("K", "a")
|
||||
k.setMode("J", "a")
|
||||
k.setMode("r", "Hepburn")
|
||||
return k.getConverter().do(word)
|
||||
if scr == "hangul":
|
||||
return unidecode(word)
|
||||
if scr == "arabic":
|
||||
return unidecode(phonemize(word, language="ar", backend="espeak"))
|
||||
if scr == "cyrillic":
|
||||
return unidecode(phonemize(word, language="ru", backend="espeak"))
|
||||
return unidecode(word)
|
||||
except:
|
||||
return unidecode(word)
|
||||
def romanize(word):
|
||||
scr = script_of(word)
|
||||
if scr == "latin":
|
||||
return word
|
||||
if base_lang in ["ru", "rus"] and scr == "cyrillic":
|
||||
return word
|
||||
if base_lang in ["ar", "ara"] and scr == "arabic":
|
||||
return word
|
||||
if base_lang in ["ko", "kor"] and scr == "hangul":
|
||||
return word
|
||||
if base_lang in ["ja", "jpn"] and scr == "japanese":
|
||||
return word
|
||||
if base_lang in ["zh", "zho"] and scr == "chinese":
|
||||
return word
|
||||
try:
|
||||
if scr == "chinese":
|
||||
from pypinyin import pinyin, Style
|
||||
return "".join(x[0] for x in pinyin(word, style=Style.NORMAL))
|
||||
if scr == "japanese":
|
||||
import pykakasi
|
||||
k = pykakasi.kakasi()
|
||||
k.setMode("H", "a")
|
||||
k.setMode("K", "a")
|
||||
k.setMode("J", "a")
|
||||
k.setMode("r", "Hepburn")
|
||||
return k.getConverter().do(word)
|
||||
if scr == "hangul":
|
||||
return unidecode(word)
|
||||
if scr == "arabic":
|
||||
return unidecode(phonemize(word, language="ar", backend="espeak"))
|
||||
if scr == "cyrillic":
|
||||
return unidecode(phonemize(word, language="ru", backend="espeak"))
|
||||
return unidecode(word)
|
||||
except:
|
||||
return unidecode(word)
|
||||
|
||||
tts_markers = set(TTS_SML.values())
|
||||
protected = {}
|
||||
for i, m in enumerate(tts_markers):
|
||||
key = f"__TTS_MARKER_{i}__"
|
||||
protected[key] = m
|
||||
text = text.replace(m, key)
|
||||
tokens = re.findall(r"\w+|[^\w\s]", text, re.UNICODE)
|
||||
buf = []
|
||||
for t in tokens:
|
||||
if t in protected:
|
||||
buf.append(t)
|
||||
elif re.match(r"^\w+$", t):
|
||||
buf.append(romanize(t))
|
||||
else:
|
||||
buf.append(t)
|
||||
out = ""
|
||||
for i, t in enumerate(buf):
|
||||
if i == 0:
|
||||
out += t
|
||||
else:
|
||||
if re.match(r"^\w+$", buf[i-1]) and re.match(r"^\w+$", t):
|
||||
out += " " + t
|
||||
else:
|
||||
out += t
|
||||
for k, v in protected.items():
|
||||
out = out.replace(k, v)
|
||||
return out
|
||||
tts_markers = set(TTS_SML.values())
|
||||
protected = {}
|
||||
for i, m in enumerate(tts_markers):
|
||||
key = f"__TTS_MARKER_{i}__"
|
||||
protected[key] = m
|
||||
text = text.replace(m, key)
|
||||
tokens = re.findall(r"\w+|[^\w\s]", text, re.UNICODE)
|
||||
buf = []
|
||||
for t in tokens:
|
||||
if t in protected:
|
||||
buf.append(t)
|
||||
elif re.match(r"^\w+$", t):
|
||||
buf.append(romanize(t))
|
||||
else:
|
||||
buf.append(t)
|
||||
out = ""
|
||||
for i, t in enumerate(buf):
|
||||
if i == 0:
|
||||
out += t
|
||||
else:
|
||||
if re.match(r"^\w+$", buf[i-1]) and re.match(r"^\w+$", t):
|
||||
out += " " + t
|
||||
else:
|
||||
out += t
|
||||
for k, v in protected.items():
|
||||
out = out.replace(k, v)
|
||||
return out
|
||||
|
||||
def filter_sml(text:str)->str:
|
||||
for key, value in TTS_SML.items():
|
||||
|
||||
102
lib/lang.py
102
lib/lang.py
@@ -50,14 +50,14 @@ punctuation_switch = {
|
||||
|
||||
# Dashes, underscores & Hyphens that might cause weird pauses
|
||||
'–': '.', # En dash (Unicode U+2013)
|
||||
"_": " ", # U+005F LOW LINE
|
||||
"‗": " ", # U+2017 DOUBLE LOW LINE
|
||||
"¯": " ", # U+00AF MACRON (technically an overline)
|
||||
"ˍ": " ", # U+02CD MODIFIER LETTER LOW MACRON
|
||||
"﹍": " ", # U+FE4D DASHED LOW LINE
|
||||
"﹎": " ", # U+FE4E CENTRELINE LOW LINE
|
||||
"﹏": " ", # U+FE4F WAVY LOW LINE
|
||||
"_": " ", # U+FF3F FULLWIDTH LOW LINE
|
||||
"_": " ", # U+005F LOW LINE
|
||||
"‗": " ", # U+2017 DOUBLE LOW LINE
|
||||
"¯": " ", # U+00AF MACRON (technically an overline)
|
||||
"ˍ": " ", # U+02CD MODIFIER LETTER LOW MACRON
|
||||
"﹍": " ", # U+FE4D DASHED LOW LINE
|
||||
"﹎": " ", # U+FE4E CENTRELINE LOW LINE
|
||||
"﹏": " ", # U+FE4F WAVY LOW LINE
|
||||
"_": " ", # U+FF3F FULLWIDTH LOW LINE
|
||||
|
||||
# Ellipsis (causes extreme long pauses in TTS)
|
||||
'...': '…', # standard triple dots replaced with Unicode ellipsis (U+2026)
|
||||
@@ -123,41 +123,41 @@ punctuation_list = [
|
||||
punctuation_list_set = set(punctuation_list)
|
||||
|
||||
punctuation_split_hard = [
|
||||
# Western
|
||||
'.', '!', '?', '…', '‽', '—', # sentence terminators
|
||||
# Arabic–Persian
|
||||
'؟', # Arabic question mark (hard)
|
||||
# CJK (Chinese/Japanese/Korean)
|
||||
'。', # full stop
|
||||
'!', '?', # full-width exclamation/question (hard for zho/jpn/kor)
|
||||
# Indic
|
||||
'।', '॥', # danda, double danda
|
||||
# Ethiopic
|
||||
'።', '፧', # full stop, question mark
|
||||
# Tibetan
|
||||
'།', # shad (end of verse/sentence)
|
||||
# Khmer
|
||||
'។', '៕' # full stop, end sign
|
||||
# Western
|
||||
'.', '!', '?', '…', '‽', '—', # sentence terminators
|
||||
# Arabic–Persian
|
||||
'؟', # Arabic question mark (hard)
|
||||
# CJK (Chinese/Japanese/Korean)
|
||||
'。', # full stop
|
||||
'!', '?', # full-width exclamation/question (hard for zho/jpn/kor)
|
||||
# Indic
|
||||
'।', '॥', # danda, double danda
|
||||
# Ethiopic
|
||||
'።', '፧', # full stop, question mark
|
||||
# Tibetan
|
||||
'།', # shad (end of verse/sentence)
|
||||
# Khmer
|
||||
'។', '៕' # full stop, end sign
|
||||
]
|
||||
punctuation_split_hard_set = set(punctuation_split_hard)
|
||||
|
||||
punctuation_split_soft = [
|
||||
# Western
|
||||
',', ':', ';',
|
||||
# Arabic–Persian
|
||||
'،',
|
||||
# CJK
|
||||
',', '、', '·',
|
||||
# Thai
|
||||
'ฯ',
|
||||
# Ethiopic
|
||||
'፡', '፣', '፤', '፥', '፦',
|
||||
# Hebrew
|
||||
'״',
|
||||
# Tibetan
|
||||
'༎',
|
||||
# Lao
|
||||
'໌', 'ໍ'
|
||||
# Western
|
||||
',', ':', ';',
|
||||
# Arabic–Persian
|
||||
'،',
|
||||
# CJK
|
||||
',', '、', '·',
|
||||
# Thai
|
||||
'ฯ',
|
||||
# Ethiopic
|
||||
'፡', '፣', '፤', '፥', '፦',
|
||||
# Hebrew
|
||||
'״',
|
||||
# Tibetan
|
||||
'༎',
|
||||
# Lao
|
||||
'໌', 'ໍ'
|
||||
]
|
||||
punctuation_split_soft_set = set(punctuation_split_soft)
|
||||
|
||||
@@ -173,18 +173,18 @@ roman_numbers_tuples = [
|
||||
]
|
||||
|
||||
emojis_list = [
|
||||
r"\U0001F600-\U0001F64F", # Emoticons
|
||||
r"\U0001F300-\U0001F5FF", # Symbols & pictographs
|
||||
r"\U0001F680-\U0001F6FF", # Transport & map symbols
|
||||
r"\U0001F1E0-\U0001F1FF", # Flags
|
||||
r"\U00002700-\U000027BF", # Dingbats
|
||||
r"\U0001F900-\U0001F9FF", # Supplemental symbols
|
||||
r"\U00002600-\U000026FF", # Misc symbols
|
||||
r"\U0001FA70-\U0001FAFF", # Extended pictographs
|
||||
r"\U00002480-\U00002BEF", # Box drawing, etc.
|
||||
r"\U0001F018-\U0001F270",
|
||||
r"\U0001F650-\U0001F67F",
|
||||
r"\U0001F700-\U0001F77F"
|
||||
r"\U0001F600-\U0001F64F", # Emoticons
|
||||
r"\U0001F300-\U0001F5FF", # Symbols & pictographs
|
||||
r"\U0001F680-\U0001F6FF", # Transport & map symbols
|
||||
r"\U0001F1E0-\U0001F1FF", # Flags
|
||||
r"\U00002700-\U000027BF", # Dingbats
|
||||
r"\U0001F900-\U0001F9FF", # Supplemental symbols
|
||||
r"\U00002600-\U000026FF", # Misc symbols
|
||||
r"\U0001FA70-\U0001FAFF", # Extended pictographs
|
||||
r"\U00002480-\U00002BEF", # Box drawing, etc.
|
||||
r"\U0001F018-\U0001F270",
|
||||
r"\U0001F650-\U0001F67F",
|
||||
r"\U0001F700-\U0001F77F"
|
||||
]
|
||||
|
||||
language_math_phonemes = {
|
||||
|
||||
@@ -46,8 +46,8 @@ default_engine_settings = {
|
||||
"top_k": 50,
|
||||
"top_p": 0.85,
|
||||
"speed": 1.0,
|
||||
#"gpt_cond_len": 512,
|
||||
#"gpt_batch_size": 1,
|
||||
#"gpt_cond_len": 512,
|
||||
#"gpt_batch_size": 1,
|
||||
"enable_text_splitting": False,
|
||||
"use_deepspeed": False,
|
||||
"files": ['config.json', 'model.pth', 'vocab.json', 'ref.wav'],
|
||||
|
||||
@@ -52,7 +52,6 @@ dependencies = [
|
||||
"argostranslate<=1.10.0",
|
||||
"gradio==5.49.1",
|
||||
"torch<=2.7.1",
|
||||
"torchaudio<=2.7.1",
|
||||
"coqui-tts[languages]==0.27.2"
|
||||
]
|
||||
readme = "README.md"
|
||||
|
||||
@@ -33,5 +33,4 @@ pyannote-audio==3.4.0
|
||||
argostranslate==1.10.0
|
||||
gradio==5.49.1
|
||||
torch<=2.7.1
|
||||
torchaudio<=2.7.1
|
||||
coqui-tts[languages]==0.27.2
|
||||
@@ -1,61 +0,0 @@
|
||||
import os
|
||||
import platform
|
||||
import argparse
|
||||
|
||||
# Working directories shared by every tool invoked below (calibre, HF hub,
# TTS caches). Paths are relative to this script's parent directory.
tmp_dir = os.path.abspath(os.path.join('..', 'tmp'))
models_dir = os.path.abspath(os.path.join('..', 'models'))
tts_dir = os.path.join(models_dir, 'tts')

# Force UTF-8 I/O regardless of the host locale.
os.environ['PYTHONUTF8'] = '1'
os.environ['PYTHONIOENCODING'] = 'utf-8'  # set once (was assigned twice)
# Auto-accept the Coqui TTS terms of service for non-interactive runs.
os.environ['COQUI_TOS_AGREED'] = '1'
os.environ['CALIBRE_NO_NATIVE_FILEDIALOGS'] = '1'
os.environ['DO_NOT_TRACK'] = 'true'
os.environ['CALIBRE_TEMP_DIR'] = tmp_dir
os.environ['CALIBRE_CACHE_DIRECTORY'] = tmp_dir
# Route every model/cache download into the local models tree.
os.environ['HUGGINGFACE_HUB_CACHE'] = tts_dir
os.environ['HF_HOME'] = tts_dir
os.environ['HF_DATASETS_CACHE'] = tts_dir
os.environ['BARK_CACHE_DIR'] = tts_dir
os.environ['TTS_CACHE'] = tts_dir
os.environ['TORCH_HOME'] = tts_dir
os.environ['TTS_HOME'] = models_dir
os.environ['XDG_CACHE_HOME'] = models_dir
os.environ['ARGOS_TRANSLATE_PACKAGE_PATH'] = os.path.join(models_dir, 'argostranslate')
os.environ['HF_TOKEN_PATH'] = os.path.join(os.path.expanduser('~'), '.huggingface_token')
os.environ['TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD'] = '1'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
os.environ['SUNO_OFFLOAD_CPU'] = 'False'  # BARK option: False needs A GPU
os.environ['SUNO_USE_SMALL_MODELS'] = 'False'  # BARK option: False needs a GPU with VRAM > 4GB
if platform.system() == 'Windows':
    # Scoop's default espeak-ng install layout.
    os.environ['ESPEAK_DATA_PATH'] = os.path.expandvars(r"%USERPROFILE%\scoop\apps\espeak-ng\current\eSpeak NG\espeak-ng-data")
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from bark import SAMPLE_RATE, preload_models
|
||||
from bark.generation import codec_decode
|
||||
|
||||
def npz_to_wav(npz_path, output_path):
    """Decode the 'fine_prompt' codec tokens of a Bark NPZ voice into a WAV file."""
    preload_models()
    prompt = np.load(npz_path)["fine_prompt"]
    waveform = codec_decode(prompt)
    # torchaudio expects a (channels, samples) tensor; add the channel axis.
    torchaudio.save(output_path, torch.tensor(waveform).unsqueeze(0), SAMPLE_RATE)
    print(f"✅ Saved: {output_path}")
|
||||
|
||||
def process_all_npz_in_folder(folder_path):
    """Convert every ``*.npz`` under folder_path (recursively) to a sibling ``.wav``."""
    preload_models()
    for npz_path in Path(folder_path).rglob("*.npz"):
        npz_to_wav(str(npz_path), str(npz_path.with_suffix(".wav")))
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Process all NPZ files in a folder.")
|
||||
parser.add_argument("--folder_path", type=str, required=True, help="Path to the folder containing NPZ files")
|
||||
args = parser.parse_args()
|
||||
folder_path = os.path.abspath(args.folder_path)
|
||||
process_all_npz_in_folder(folder_path)
|
||||
@@ -1,141 +0,0 @@
|
||||
# NOTE: to run this script you must move it to the root of ebook2audiobook
|
||||
|
||||
import os
|
||||
|
||||
# NOTE(review): the original referenced tts_dir/models_dir before any
# definition, raising NameError at import time. Define them here; the header
# says the script runs from the ebook2audiobook root, so paths are relative
# to the CWD — TODO confirm against the main app's layout.
models_dir = os.path.abspath('models')
tts_dir = os.path.join(models_dir, 'tts')

# Force UTF-8 I/O regardless of the host locale.
os.environ['PYTHONUTF8'] = '1'
os.environ['PYTHONIOENCODING'] = 'utf-8'  # set once (was assigned twice)
# Auto-accept the Coqui TTS terms of service for non-interactive runs.
os.environ['COQUI_TOS_AGREED'] = '1'
os.environ['DO_NOT_TRACK'] = 'true'
# Route every model/cache download into the local models tree.
os.environ['HUGGINGFACE_HUB_CACHE'] = tts_dir
os.environ['HF_HOME'] = tts_dir
os.environ['TRANSFORMERS_CACHE'] = tts_dir
os.environ['HF_DATASETS_CACHE'] = tts_dir
os.environ['BARK_CACHE_DIR'] = tts_dir
os.environ['TTS_CACHE'] = tts_dir
os.environ['TORCH_HOME'] = tts_dir
os.environ['TTS_HOME'] = models_dir
os.environ['XDG_CACHE_HOME'] = models_dir
os.environ['HF_TOKEN_PATH'] = os.path.join(os.path.expanduser('~'), '.huggingface_token')
os.environ['TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD'] = '1'
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
os.environ['SUNO_OFFLOAD_CPU'] = 'False'
os.environ['SUNO_USE_SMALL_MODELS'] = 'False'
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import numpy as np
|
||||
import regex as re
|
||||
import shutil
|
||||
import soundfile as sf
|
||||
import subprocess
|
||||
import tempfile
|
||||
import torch
|
||||
import torchaudio
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
from iso639 import languages
|
||||
from huggingface_hub import hf_hub_download
|
||||
from pathlib import Path
|
||||
from scipy.io import wavfile as wav
|
||||
from scipy.signal import find_peaks
|
||||
from TTS.tts.configs.bark_config import BarkConfig
|
||||
from TTS.tts.models.bark import Bark
|
||||
|
||||
from lib import *
|
||||
|
||||
import logging
|
||||
# Verbose logging for this one-off conversion script.
logging.basicConfig(level=logging.DEBUG)

# Keep torch-hub downloads inside the project models tree
# (models_dir comes from `from lib import *` above).
torch.hub.set_dir(models_dir)

# Cache of initialized TTS engines, keyed by the engine's internal key.
loaded_tts = {}
|
||||
|
||||
def load_checkpoint(**kwargs):
    """Initialize a Bark TTS engine from a local checkpoint directory.

    Keyword Args:
        key: cache key under which the engine is stored in ``loaded_tts``.
        tts_engine: engine name, used for the success log message only.
        device: 'cuda' or any torch device string the model is moved to.
        checkpoint_dir: directory holding the downloaded Bark checkpoints.

    Returns:
        The loaded Bark model on success, False on any failure.
    """
    try:
        key = kwargs.get('key')
        tts_engine = kwargs.get('tts_engine')
        device = kwargs.get('device')
        checkpoint_dir = kwargs.get('checkpoint_dir')
        config = BarkConfig()
        config.CACHE_DIR = tts_dir
        # Honor the SUNO_USE_SMALL_MODELS toggle exported at startup.
        config.USE_SMALLER_MODELS = os.environ.get('SUNO_USE_SMALL_MODELS', '').lower() == 'true'
        tts = Bark.init_from_config(config)
        tts.load_checkpoint(
            config,
            checkpoint_dir=checkpoint_dir,
            eval=True
        )
        if tts:
            if device == 'cuda':
                tts.cuda()
            else:
                tts.to(device)
            loaded_tts[key] = {"engine": tts, "config": config}
            msg = f'{tts_engine} Loaded!'
            print(msg)
            return tts
        else:
            error = 'TTS engine could not be created!'
            print(error)
            return False  # was an implicit None; normalize the failure value
    except Exception as e:
        # Was built but never printed; also referenced a nonexistent
        # '_load_checkpoint' in the message.
        error = f'load_checkpoint() error: {e}'
        print(error)
        return False
|
||||
|
||||
def wav_to_npz(bark_dir, wav_dir):
    """Render Bark NPZ voice prompts for every language-prefixed WAV under wav_dir.

    Walks wav_dir for files named like ``<iso1>_*.wav`` (e.g. ``en_sample.wav``),
    looks up each language's default sample text, and runs Bark synthesis with a
    fixed seed so Bark writes the speaker's NPZ prompt into bark_dir as a side
    effect. Errors are caught and printed; the function returns None.
    """
    try:
        # NOTE(review): f-string has no placeholders — the key is the literal
        # text "TTS_ENGINES['BARK']-internal", not the engine value. It is used
        # consistently below, so behavior is self-consistent; confirm intent.
        tts_internal_key = f"TTS_ENGINES['BARK']-internal"
        hf_repo = models[TTS_ENGINES['BARK']]['internal']['repo']
        hf_sub = models[TTS_ENGINES['BARK']]['internal']['sub']
        # Download the three Bark checkpoints (text, coarse, fine) into the cache.
        text_model_path = hf_hub_download(repo_id=hf_repo, filename=f"{hf_sub}{models[TTS_ENGINES['BARK']]['internal']['files'][0]}", cache_dir=tts_dir)
        coarse_model_path = hf_hub_download(repo_id=hf_repo, filename=f"{hf_sub}{models[TTS_ENGINES['BARK']]['internal']['files'][1]}", cache_dir=tts_dir)
        fine_model_path = hf_hub_download(repo_id=hf_repo, filename=f"{hf_sub}{models[TTS_ENGINES['BARK']]['internal']['files'][2]}", cache_dir=tts_dir)
        # All three land in the same directory; any of them locates it.
        checkpoint_dir = os.path.dirname(text_model_path)
        tts = load_checkpoint(tts_engine=TTS_ENGINES['BARK'], key=tts_internal_key, checkpoint_dir=checkpoint_dir, device='cpu')
        if tts:
            fine_tuned_params = {
                "text_temp": default_engine_settings[TTS_ENGINES['BARK']]['text_temp'],
                "waveform_temp": default_engine_settings[TTS_ENGINES['BARK']]['waveform_temp']
            }
            for root, dirs, files in os.walk(wav_dir):
                for file in files:
                    if file.lower().endswith('.wav'):
                        # Filename must start with a 2-letter ISO-639-1 code.
                        match = re.match(r"^([a-z]{2})_", file)
                        if match:
                            speaker = os.path.splitext(file)[0]
                            npz_file = f'{speaker}.npz'
                            iso1_lang = match.group(1)
                            lang_array = languages.get(part1=iso1_lang)
                            if lang_array:
                                # Voice sample texts are stored under the ISO-639-3 code.
                                iso3_lang = lang_array.part3
                                default_text_file = os.path.join(voices_dir, iso3_lang, 'default.txt')
                                default_text = Path(default_text_file).read_text(encoding="utf-8")
                                with torch.no_grad():
                                    # Fixed seed for reproducible prompt generation.
                                    torch.manual_seed(67878789)
                                    audio_data = tts.synthesize(
                                        default_text,
                                        loaded_tts[tts_internal_key]['config'],
                                        speaker_id=speaker,
                                        voice_dirs=bark_dir,
                                        silent=True,
                                        **fine_tuned_params
                                    )
                                # The audio itself is discarded; synthesize() writing
                                # the NPZ into voice_dirs is the desired side effect.
                                del audio_data
                                msg = f"Saved NPZ file: {npz_file}"
                                print(msg)
        else:
            print('tts bark not loaded')
    except Exception as e:
        print(f'wav_to_npz() error: {e}')
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert WAV files to Bark NPZ format.")
|
||||
parser.add_argument("--bark_dir", type=str, required=True, help="Path to the Bark asset directory")
|
||||
parser.add_argument("--wav_dir", type=str, required=True, help="Path to the output WAV directory")
|
||||
args = parser.parse_args()
|
||||
bark_dir = os.path.abspath(args.bark_dir)
|
||||
wav_dir = os.path.abspath(args.wav_dir)
|
||||
wav_to_npz(bark_dir, wav_dir)
|
||||
|
||||
Reference in New Issue
Block a user