mirror of
https://github.com/dangvansam/deepxi-flask-server.git
synced 2026-01-09 22:27:56 -05:00
178 lines
8.4 KiB
Python
178 lines
8.4 KiB
Python
import os, sys
|
|
sys.path.insert(0, 'lib')
|
|
from dev.infer import infer
|
|
from dev.sample_stats import get_stats
|
|
from dev.train import train
|
|
import dev.deepxi_net as deepxi_net
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import dev.utils as utils
|
|
import argparse
|
|
|
|
from dev.utils import read_wav
|
|
from tqdm import tqdm
|
|
import dev.gain as gain
|
|
import dev.utils as utils
|
|
import dev.xi as xi
|
|
import numpy as np
|
|
import os
|
|
import scipy.io as spio
|
|
import librosa
|
|
import pickle
|
|
from scipy.io.wavfile import write as wav_write
|
|
# Set TensorFlow's C++ log-level environment variable to '3' to cut native log noise.
# NOTE(review): tensorflow has already been imported by this point, so messages
# emitted during that import are probably not affected; consider setting this
# variable before `import tensorflow` — confirm against the TF version in use.
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '3'})
# Only surface ERROR-level (or worse) messages from TensorFlow's Python logger.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# Print numpy arrays of up to one million elements in full before summarising.
np.set_printoptions(threshold=1e6)
|
|
|
|
def str2bool(s):
    """Return True if the string ``s`` spells an affirmative value.

    Matching is case-insensitive against "yes", "true", "t" and "1"; every
    other string (including the empty string) yields False. Used as an
    argparse ``type=`` converter for boolean flags.
    """
    normalized = s.lower()
    return normalized in {"yes", "true", "t", "1"}
|
|
|
|
def create_args(argv=None):
    """Build the argument namespace used for Deep Xi training and inference.

    The sample statistics pickled in ``data/stats.p`` are loaded first and
    exposed as the default value of ``--stats`` (i.e. ``args.stats``).

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]`` exactly as before; pass ``[]``
            when embedding this module in another program (e.g. a web server)
            so that program's own command-line flags are not parsed here.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # a trusted, locally produced stats file.
    with open('data/stats.p', 'rb') as f:
        stats = pickle.load(f)

    parser = argparse.ArgumentParser()

    ## OPTIONS (GENERAL)
    parser.add_argument('--gpu', default='0', type=str, help='GPU selection')
    parser.add_argument('--ver', type=str, help='Model version')
    parser.add_argument('--epoch', type=int, help='Epoch to use/retrain from')
    parser.add_argument('--train', default=False, type=str2bool, help='Training flag')
    parser.add_argument('--infer', default=True, type=str2bool, help='Inference flag')
    parser.add_argument('--verbose', default=False, type=str2bool, help='Verbose')
    parser.add_argument('--model', default='ResNet', type=str, help='Model type')

    ## OPTIONS (TRAIN)
    parser.add_argument('--cont', default=False, type=str2bool, help='Continue testing from last epoch')
    parser.add_argument('--mbatch_size', default=10, type=int, help='Mini-batch size')
    parser.add_argument('--sample_size', default=1000, type=int, help='Sample size')
    parser.add_argument('--max_epochs', default=250, type=int, help='Maximum number of epochs')
    parser.add_argument('--grad_clip', default=True, type=str2bool, help='Gradient clipping')

    parser.add_argument('--out_type', default='y', type=str, help='Output type for testing')

    ## GAIN FUNCTION
    parser.add_argument('--gain', default='mmse-lsa', type=str, help='Gain function for testing')

    ## PATHS
    parser.add_argument('--model_path', default='model/3f/epoch-175', type=str, help='Model save path')
    parser.add_argument('--set_path', default='set', type=str, help='Path to datasets')
    parser.add_argument('--data_path', default='data', type=str, help='Save data path')
    parser.add_argument('--test_x_path', default='set/test_noisy_speech', type=str, help='Path to the noisy speech test set')
    parser.add_argument('--in_filepath', default='test.wav', type=str, help='Input path')
    parser.add_argument('--out_filepath', default='out.wav', type=str, help='Output path')

    ## FEATURES
    parser.add_argument('--min_snr', default=-10, type=int, help='Minimum trained SNR level')
    parser.add_argument('--max_snr', default=20, type=int, help='Maximum trained SNR level')
    parser.add_argument('--f_s', default=16000, type=int, help='Sampling frequency (Hz)')
    parser.add_argument('--T_w', default=32, type=int, help='Window length (ms)')
    parser.add_argument('--T_s', default=16, type=int, help='Window shift (ms)')
    parser.add_argument('--nconst', default=32768.0, type=float, help='Normalisation constant (see feat.addnoisepad())')
    # The three derived defaults below are computed from the literal defaults
    # 16000 Hz / 32 ms / 16 ms; they do not follow --f_s, --T_w or --T_s when
    # those flags are overridden.
    parser.add_argument('--N_w', default=int(16000*32*0.001), type=int, help='window length (samples)')
    parser.add_argument('--N_s', default=int(16000*16*0.001), type=int, help='window shift (samples)')
    # Smallest power of two >= N_w. It is a count of DFT components, so a value
    # given on the command line is parsed as an int (the default already is one).
    parser.add_argument('--NFFT', default=int(pow(2, np.ceil(np.log2(int(16000*32*0.001))))), type=int, help='number of DFT components')
    parser.add_argument('--stats', default=stats)

    ## NETWORK PARAMETERS
    parser.add_argument('--d_in', default=257, type=int, help='Input dimensionality')
    parser.add_argument('--d_out', default=257, type=int, help='Output dimensionality')
    parser.add_argument('--d_model', default=256, type=int, help='Model dimensions')
    parser.add_argument('--n_blocks', default=40, type=int, help='Number of blocks')
    parser.add_argument('--d_f', default=64, type=int, help='Number of filters')
    parser.add_argument('--k_size', default=3, type=int, help='Kernel size')
    parser.add_argument('--max_d_rate', default=16, type=int, help='Maximum dilation rate')
    parser.add_argument('--norm_type', default='FrameLayerNorm', type=str, help='Normalisation type')
    # nargs='+' with type=int parses "--net_height 4" as [4] (matching the
    # default). type=list would have split the string into characters (['4']).
    parser.add_argument('--net_height', default=[4], nargs='+', type=int, help='RDL block height')

    return parser.parse_args(argv)
|
|
|
|
def build_restore_model(model_path, args, config):
    """Build the Deep Xi network and restore its weights into a new session.

    Args:
        model_path: Checkpoint path handed to ``net.saver.restore``.
        args: Option namespace; forwarded to ``deepxi_net.deepxi_net``.
        config: Session configuration passed to ``tf.Session``.

    Returns:
        ``(net, sess)``: the constructed network and the session holding the
        restored variables. The caller owns ``sess`` and should close it.

    Raises:
        Whatever network construction or ``Saver.restore`` raises. In that
        case the freshly created session is closed before re-raising.
    """
    ## MAKE DEEP XI NNET
    print('Start: Build and Restore model!')
    sess = tf.Session(config=config)
    try:
        net = deepxi_net.deepxi_net(args)
        # Restore from the path the caller supplied (not args.model_path), so
        # the explicit parameter is honoured.
        net.saver.restore(sess, model_path)
    except Exception:
        # Don't leak the session (and any device memory it holds) on failure.
        sess.close()
        raise
    print('Done: Build and Restore model!')
    return net, sess
|
|
|
|
# def infer(filename ,net, sess):
|
|
# print('Start infer file: {}'.format(filename))
|
|
# #(wav, _) = read_wav(args.in_filepath) # read wav from given file path.
|
|
# (wav, _) = librosa.load(filename, 16000, mono=True) # read wav from given file path.
|
|
# wav = np.asarray(np.multiply(wav, 32768.0), dtype=np.int16)
|
|
|
|
# print(max(wav), min(wav), np.mean(wav))
|
|
# print(wav.shape)
|
|
|
|
# input_feat = sess.run(net.infer_feat, feed_dict={net.s_ph: [wav], net.s_len_ph: [len(wav)]}) # sample of training set.
|
|
# xi_bar_hat = sess.run(
|
|
# net.infer_output, feed_dict={net.input_ph: input_feat[0],
|
|
# net.nframes_ph: input_feat[1], net.training_ph: False}) # output of network.
|
|
# xi_hat = xi.xi_hat(xi_bar_hat, args.stats['mu_hat'], args.stats['sigma_hat'])
|
|
|
|
# #file_name = filename.split('/')[-1].split('.')
|
|
|
|
# y_MAG = np.multiply(input_feat[0], gain.gfunc(xi_hat, xi_hat+1, gtype=args.gain))
|
|
# y = np.squeeze(sess.run(net.y, feed_dict={net.y_MAG_ph: y_MAG,
|
|
# net.x_PHA_ph: input_feat[2], net.nframes_ph: input_feat[1], net.training_ph: False})) # output of network.
|
|
# if np.isnan(y).any(): ValueError('NaN values found in enhanced speech.')
|
|
# if np.isinf(y).any(): ValueError('Inf values found in enhanced speech.')
|
|
|
|
# y = np.asarray(np.multiply(y, 32768.0), dtype=np.int16)
|
|
# out_filepath = filename.replace('.'+filename.split('.')[-1], '_pred.wav')
|
|
# wav_write(out_filepath, args.f_s, y)
|
|
# print('Infer out file: {} done'.format(out_filepath))
|
|
# return out_filepath
|
|
|
|
def _float_to_int16(signal):
    """Convert a floating-point signal to 16-bit PCM samples.

    The signal is scaled by 32768.0 and clipped to the int16 range before the
    cast. A full-scale sample of +1.0 therefore saturates at 32767 instead of
    wrapping around to -32768. The cast truncates toward zero, as before.

    Args:
        signal: Array-like of floats, nominally in [-1.0, 1.0].

    Returns:
        numpy.ndarray of dtype int16.
    """
    scaled = np.multiply(signal, 32768.0)
    return np.asarray(np.clip(scaled, -32768, 32767), dtype=np.int16)


def _pred_filepath(filename):
    """Return ``filename`` with its final extension replaced by ``_pred.wav``.

    ``os.path.splitext`` strips only the last extension. Dots elsewhere in the
    path (directory names, multi-dot stems) are left untouched, and an
    extension-less input gets ``_pred.wav`` appended rather than being mangled.
    """
    root, _ = os.path.splitext(filename)
    return root + '_pred.wav'


def get_model():
    """Parse options, restore the Deep Xi model, and return an inference callable.

    The TensorFlow session and network are created once here and captured by
    the returned ``infer`` closure, so repeated calls reuse the loaded weights.

    Returns:
        ``infer(filename) -> str``: enhances one audio file and returns the
        path of the written ``*_pred.wav`` file.
    """
    args = create_args()

    ## GPU CONFIGURATION
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True  # allocate device memory on demand
    config.log_device_placement = False

    net, sess = build_restore_model(args.model_path, args, config)

    def infer(filename):
        """Enhance one audio file and write the result as a 16-bit WAV.

        Args:
            filename: Path to an audio file readable by ``librosa.load``. It is
                resampled to ``args.f_s`` and down-mixed to mono.

        Returns:
            Path of the written file: ``filename`` with its extension replaced
            by ``_pred.wav``.

        Raises:
            ValueError: If the enhanced signal contains NaN or Inf values.
        """
        print('Start infer file: {}'.format(filename))

        # Load as mono at the configured sampling rate (16000 Hz by default).
        wav, _ = librosa.load(filename, sr=args.f_s, mono=True)
        wav = _float_to_int16(wav)

        print(wav.max(), wav.min(), np.mean(wav))
        print(wav.shape)

        # Feature extraction runs inside the graph. From how the results are
        # used below: [0] feeds the network input and is scaled by the gain,
        # [1] feeds the frame-count placeholder, and [2] feeds the phase placeholder.
        input_feat = sess.run(net.infer_feat, feed_dict={net.s_ph: [wav], net.s_len_ph: [len(wav)]})
        xi_bar_hat = sess.run(
            net.infer_output,
            feed_dict={net.input_ph: input_feat[0],
                       net.nframes_ph: input_feat[1],
                       net.training_ph: False})  # output of network.
        # Map the network output back using the stored training statistics
        # (presumably undoing the normalisation of the SNR target — confirm in dev.xi).
        xi_hat = xi.xi_hat(xi_bar_hat, args.stats['mu_hat'], args.stats['sigma_hat'])

        # Apply the selected gain function to input_feat[0] (presumably the
        # noisy magnitude spectrum), then resynthesise with the noisy phase.
        y_MAG = np.multiply(input_feat[0], gain.gfunc(xi_hat, xi_hat + 1, gtype=args.gain))
        y = np.squeeze(sess.run(
            net.y,
            feed_dict={net.y_MAG_ph: y_MAG,
                       net.x_PHA_ph: input_feat[2],
                       net.nframes_ph: input_feat[1],
                       net.training_ph: False}))  # output of network.

        # Refuse to write corrupted PCM: non-finite samples have no int16 value.
        if np.isnan(y).any():
            raise ValueError('NaN values found in enhanced speech.')
        if np.isinf(y).any():
            raise ValueError('Inf values found in enhanced speech.')

        y = _float_to_int16(y)
        out_filepath = _pred_filepath(filename)
        wav_write(out_filepath, args.f_s, y)
        print('Infer out file: {} done'.format(out_filepath))
        return out_filepath

    return infer
|
|
|
|
if __name__ == '__main__':
    # Restore the model once, then enhance a single hard-coded recording.
    enhance_file = get_model()
    enhance_file('Toàn cảnh phòng chống dịch COVID-19 ngày 18-4-2020 - VTV24.mp3')
    # enhance_file('set/test_noisy_speech/198853.wav')
|
|
|
|
|
|
|
|
|
|
|
|
|