Files
deepxi-flask-server/infer_file.py
2020-04-20 22:05:35 +07:00

178 lines
8.4 KiB
Python

import os, sys
sys.path.insert(0, 'lib')
from dev.infer import infer
from dev.sample_stats import get_stats
from dev.train import train
import dev.deepxi_net as deepxi_net
import numpy as np
import tensorflow as tf
import dev.utils as utils
import argparse
from dev.utils import read_wav
from tqdm import tqdm
import dev.gain as gain
import dev.utils as utils
import dev.xi as xi
import numpy as np
import os
import scipy.io as spio
import librosa
import pickle
from scipy.io.wavfile import write as wav_write
np.set_printoptions(threshold=1e6)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
def str2bool(s): return s.lower() in ("yes", "true", "t", "1")
def create_args():
with open('data/stats.p', 'rb') as f:
stats = pickle.load(f)
parser = argparse.ArgumentParser()
## OPTIONS (GENERAL)
parser.add_argument('--gpu', default='0', type=str, help='GPU selection')
parser.add_argument('--ver', type=str, help='Model version')
parser.add_argument('--epoch', type=int, help='Epoch to use/retrain from')
parser.add_argument('--train', default=False, type=str2bool, help='Training flag')
parser.add_argument('--infer', default=True, type=str2bool, help='Inference flag')
parser.add_argument('--verbose', default=False, type=str2bool, help='Verbose')
parser.add_argument('--model', default='ResNet', type=str, help='Model type')
## OPTIONS (TRAIN)
parser.add_argument('--cont', default=False, type=str2bool, help='Continue testing from last epoch')
parser.add_argument('--mbatch_size', default=10, type=int, help='Mini-batch size')
parser.add_argument('--sample_size', default=1000, type=int, help='Sample size')
parser.add_argument('--max_epochs', default=250, type=int, help='Maximum number of epochs')
parser.add_argument('--grad_clip', default=True, type=str2bool, help='Gradient clipping')
parser.add_argument('--out_type', default='y', type=str, help='Output type for testing')
## GAIN FUNCTION
parser.add_argument('--gain', default='mmse-lsa', type=str, help='Gain function for testing')
## PATHS
parser.add_argument('--model_path', default='model/3f/epoch-175', type=str, help='Model save path')
parser.add_argument('--set_path', default='set', type=str, help='Path to datasets')
parser.add_argument('--data_path', default='data', type=str, help='Save data path')
parser.add_argument('--test_x_path', default='set/test_noisy_speech', type=str, help='Path to the noisy speech test set')
parser.add_argument('--in_filepath', default='test.wav', type=str, help='Output path')
parser.add_argument('--out_filepath', default='out.wav', type=str, help='Output path')
## FEATURES
parser.add_argument('--min_snr', default=-10, type=int, help='Minimum trained SNR level')
parser.add_argument('--max_snr', default=20, type=int, help='Maximum trained SNR level')
parser.add_argument('--f_s', default=16000, type=int, help='Sampling frequency (Hz)')
parser.add_argument('--T_w', default=32, type=int, help='Window length (ms)')
parser.add_argument('--T_s', default=16, type=int, help='Window shift (ms)')
parser.add_argument('--nconst', default=32768.0, type=float, help='Normalisation constant (see feat.addnoisepad())')
parser.add_argument('--N_w', default=int(16000*32*0.001), type=int, help='window length (samples)')
parser.add_argument('--N_s', default=int(16000*16*0.001), type=int, help='window shift (samples)')
parser.add_argument('--NFFT', default=int(pow(2, np.ceil(np.log2(int(16000*32*0.001))))), type=float, help='number of DFT components')
parser.add_argument('--stats', default=stats)
## NETWORK PARAMETERS
parser.add_argument('--d_in', default=257, type=int, help='Input dimensionality')
parser.add_argument('--d_out', default=257, type=int, help='Ouput dimensionality')
parser.add_argument('--d_model', default=256, type=int, help='Model dimensions')
parser.add_argument('--n_blocks', default=40, type=int, help='Number of blocks')
parser.add_argument('--d_f', default=64, type=int, help='Number of filters')
parser.add_argument('--k_size', default=3, type=int, help='Kernel size')
parser.add_argument('--max_d_rate', default=16, type=int, help='Maximum dilation rate')
parser.add_argument('--norm_type', default='FrameLayerNorm', type=str, help='Normalisation type')
parser.add_argument('--net_height', default=[4], type=list, help='RDL block height')
args = parser.parse_args()
return args
def build_restore_model(model_path, args, config):
## MAKE DEEP XI NNET
print('Start: Build and Restore model!')
sess = tf.Session(config=config)
net = deepxi_net.deepxi_net(args)
net.saver.restore(sess, args.model_path)
print('Done: Build and Restore model!')
return net, sess
# def infer(filename ,net, sess):
# print('Start infer file: {}'.format(filename))
# #(wav, _) = read_wav(args.in_filepath) # read wav from given file path.
# (wav, _) = librosa.load(filename, 16000, mono=True) # read wav from given file path.
# wav = np.asarray(np.multiply(wav, 32768.0), dtype=np.int16)
# print(max(wav), min(wav), np.mean(wav))
# print(wav.shape)
# input_feat = sess.run(net.infer_feat, feed_dict={net.s_ph: [wav], net.s_len_ph: [len(wav)]}) # sample of training set.
# xi_bar_hat = sess.run(
# net.infer_output, feed_dict={net.input_ph: input_feat[0],
# net.nframes_ph: input_feat[1], net.training_ph: False}) # output of network.
# xi_hat = xi.xi_hat(xi_bar_hat, args.stats['mu_hat'], args.stats['sigma_hat'])
# #file_name = filename.split('/')[-1].split('.')
# y_MAG = np.multiply(input_feat[0], gain.gfunc(xi_hat, xi_hat+1, gtype=args.gain))
# y = np.squeeze(sess.run(net.y, feed_dict={net.y_MAG_ph: y_MAG,
# net.x_PHA_ph: input_feat[2], net.nframes_ph: input_feat[1], net.training_ph: False})) # output of network.
# if np.isnan(y).any(): ValueError('NaN values found in enhanced speech.')
# if np.isinf(y).any(): ValueError('Inf values found in enhanced speech.')
# y = np.asarray(np.multiply(y, 32768.0), dtype=np.int16)
# out_filepath = filename.replace('.'+filename.split('.')[-1], '_pred.wav')
# wav_write(out_filepath, args.f_s, y)
# print('Infer out file: {} done'.format(out_filepath))
# return out_filepath
def get_model():
args = create_args()
## GPU CONFIGURATION
config = tf.ConfigProto()
config.allow_soft_placement=True
config.gpu_options.allow_growth=True
config.log_device_placement=False
net, sess = build_restore_model(args.model_path, args, config)
def infer(filename):
print('Start infer file: {}'.format(filename))
#(wav, _) = read_wav(args.in_filepath) # read wav from given file path.
(wav, _) = librosa.load(filename, 16000, mono=True) # read wav from given file path.
wav = np.asarray(np.multiply(wav, 32768.0), dtype=np.int16)
print(max(wav), min(wav), np.mean(wav))
print(wav.shape)
input_feat = sess.run(net.infer_feat, feed_dict={net.s_ph: [wav], net.s_len_ph: [len(wav)]}) # sample of training set.
xi_bar_hat = sess.run(
net.infer_output, feed_dict={net.input_ph: input_feat[0],
net.nframes_ph: input_feat[1], net.training_ph: False}) # output of network.
xi_hat = xi.xi_hat(xi_bar_hat, args.stats['mu_hat'], args.stats['sigma_hat'])
#file_name = filename.split('/')[-1].split('.')
y_MAG = np.multiply(input_feat[0], gain.gfunc(xi_hat, xi_hat+1, gtype=args.gain))
y = np.squeeze(sess.run(net.y, feed_dict={net.y_MAG_ph: y_MAG,
net.x_PHA_ph: input_feat[2], net.nframes_ph: input_feat[1], net.training_ph: False})) # output of network.
if np.isnan(y).any(): ValueError('NaN values found in enhanced speech.')
if np.isinf(y).any(): ValueError('Inf values found in enhanced speech.')
y = np.asarray(np.multiply(y, 32768.0), dtype=np.int16)
out_filepath = filename.replace('.'+filename.split('.')[-1], '_pred.wav')
wav_write(out_filepath, args.f_s, y)
print('Infer out file: {} done'.format(out_filepath))
return out_filepath
return infer
if __name__ == '__main__':
infer = get_model()
infer('Toàn cảnh phòng chống dịch COVID-19 ngày 18-4-2020 - VTV24.mp3')
#infer('set/test_noisy_speech/198853.wav')