Mirror of https://github.com/dangvansam/deepxi-flask-server.git, synced 2026-01-10 06:37:55 -05:00
20042020
BIN
__pycache__/infer_file.cpython-36.pyc
Normal file
Binary file not shown.
BIN
__pycache__/model.cpython-36.pyc
Normal file
Binary file not shown.
BIN
__pycache__/predict_img2spec.cpython-36.pyc
Normal file
Binary file not shown.
BIN
__pycache__/utils.cpython-36.pyc
Normal file
Binary file not shown.
16
convert.py
Normal file
@@ -0,0 +1,16 @@
from os import path
from pydub import AudioSegment
import soundfile as sf
import librosa
import numpy as np
# files
src = "Tap4_baLan_23m47_23m54_tucgian_nhieu.mp3"
dst = "test.wav"

# convert mp3 to wav
#sound = AudioSegment.from_mp3(src)
#sound.export(dst, format="wav")

s, r = librosa.load(src, None, dtype='int16')
print(len(s), r)
print(np.mean(s))
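The commented-out pydub lines above are the conversion path this scratch script was testing. A minimal sketch of a conversion helper (hypothetical, not part of the repo; pydub needs ffmpeg installed to decode mp3) that produces the mono 16 kHz WAV the model expects:

from pydub import AudioSegment

def to_wav_16k_mono(src, dst):
    # Hypothetical helper: decode any input (mp3/flac/wav) and export
    # mono 16 kHz 16-bit WAV, matching the --f_s default of 16000 Hz.
    sound = AudioSegment.from_file(src)
    sound = sound.set_channels(1).set_frame_rate(16000)
    sound.export(dst, format="wav")
    return dst

to_wav_16k_mono("Tap4_baLan_23m47_23m54_tucgian_nhieu.mp3", "test.wav")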
1
data/info.txt
Normal file
@@ -0,0 +1 @@
Statistics are stored here, as well as the training lists.
BIN
data/stats.m
Normal file
Binary file not shown.
BIN
data/stats.p
Normal file
Binary file not shown.
BIN
data/test_x_list_trung-desktop.p
Normal file
Binary file not shown.
46
deepxi.py
Normal file
@@ -0,0 +1,46 @@
## FILE: deepxi.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University
## BRIEF: 'DeepXi' training and testing.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

import os, sys
sys.path.insert(0, 'lib')
from dev.args import add_args, get_args
from dev.infer import infer
from dev.sample_stats import get_stats
from dev.train import train
import dev.deepxi_net as deepxi_net
import numpy as np
import tensorflow as tf
import dev.utils as utils
np.set_printoptions(threshold=1e6)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

if __name__ == '__main__':

    ## GET COMMAND LINE ARGUMENTS
    args = get_args()

    ## TRAINING AND TESTING SET ARGUMENTS
    args = add_args(args)

    ## GPU CONFIGURATION
    config = utils.gpu_config(args.gpu)

    ## GET STATISTICS
    args = get_stats(args.data_path, args, config)

    # print(args)
    # exit()  # leftover debug calls, commented out so the network below is reachable.

    ## MAKE DEEP XI NNET
    net = deepxi_net.deepxi_net(args)

    with tf.Session(config=config) as sess:
        if args.train: train(sess, net, args)
        if args.infer: infer(sess, net, args)
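Illustrative invocations of this entry point (a sketch only: the flag names come from get_args()/add_args() in lib/dev/args.py below; version '3f' and epoch 175 are taken from the checkpoint path used by infer_file.py):

python3 deepxi.py --ver 3f --epoch 175 --train 1 --gpu 0
python3 deepxi.py --ver 3f --epoch 175 --infer 1 --out_type y --gain mmse-lsa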
46
flask_server.py
Normal file
@@ -0,0 +1,46 @@
import os
import requests
from flask import Flask, escape, request, render_template, jsonify, make_response, session
#from utils import cvtToWavMono16, split

import random
from infer_file import get_model
from flask_cors import CORS

app = Flask(__name__)
CORS(app)
infer = get_model()

@app.route('/')
def index():
    return render_template('index.html')

@app.route('/upload', methods=['POST','GET'])
def upload():
    if request.method == 'POST':
        file = request.files['file']
        filename = file.filename
        print(filename)
        if os.path.splitext(filename)[1][1:].strip() not in ['mp3','wav','flac']:
            return render_template('index.html', filename='{} file not supported. Select mp3, wav or flac!'.format(filename))
        file_path = 'static/upload/' + filename
        file.save(file_path)
        print('saved file: {}'.format(file_path))
        res = make_response(jsonify({"file_path": file_path, "message": "Saved: {} to server".format(filename)}))
        return res
    return render_template('index.html')

@app.route('/predict/<file_path>')
def predict(file_path):
    print(file_path)
    file_path = file_path.replace('=','/')
    out_file_path = infer(file_path)
    print('predict done!!')
    res = make_response(jsonify({"out_file_path": out_file_path, "message": "Predict success!"}))
    return res

if __name__ == "__main__":
    app.debug = True
    app.secret_key = 'dangvansam'
    #app.run(host='192.168.1.254', port='9002')
    app.run(port='8080')
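A minimal client sketch for the two routes above (assumes the server is running locally on port 8080; note that predict() expects '/' in the saved path encoded as '='):

import requests

# Upload a file; the server saves it under static/upload/ and echoes the path.
with open('test.wav', 'rb') as f:
    r = requests.post('http://127.0.0.1:8080/upload', files={'file': f})
file_path = r.json()['file_path']  # e.g. static/upload/test.wav

# The predict route takes the path as one URL segment, '/' encoded as '='.
r = requests.get('http://127.0.0.1:8080/predict/' + file_path.replace('/', '='))
print(r.json()['out_file_path'])   # path of the enhanced wav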
178
infer_file.py
Normal file
@@ -0,0 +1,178 @@
import os, sys
sys.path.insert(0, 'lib')
from dev.infer import infer
from dev.sample_stats import get_stats
from dev.train import train
import dev.deepxi_net as deepxi_net
import numpy as np
import tensorflow as tf
import dev.utils as utils
import argparse

from dev.utils import read_wav
from tqdm import tqdm
import dev.gain as gain
import dev.utils as utils
import dev.xi as xi
import numpy as np
import os
import scipy.io as spio
import librosa
import pickle
from scipy.io.wavfile import write as wav_write
np.set_printoptions(threshold=1e6)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

def str2bool(s): return s.lower() in ("yes", "true", "t", "1")

def create_args():
    with open('data/stats.p', 'rb') as f:
        stats = pickle.load(f)

    parser = argparse.ArgumentParser()
    ## OPTIONS (GENERAL)
    parser.add_argument('--gpu', default='0', type=str, help='GPU selection')
    parser.add_argument('--ver', type=str, help='Model version')
    parser.add_argument('--epoch', type=int, help='Epoch to use/retrain from')
    parser.add_argument('--train', default=False, type=str2bool, help='Training flag')
    parser.add_argument('--infer', default=True, type=str2bool, help='Inference flag')
    parser.add_argument('--verbose', default=False, type=str2bool, help='Verbose')
    parser.add_argument('--model', default='ResNet', type=str, help='Model type')

    ## OPTIONS (TRAIN)
    parser.add_argument('--cont', default=False, type=str2bool, help='Continue training from last epoch')
    parser.add_argument('--mbatch_size', default=10, type=int, help='Mini-batch size')
    parser.add_argument('--sample_size', default=1000, type=int, help='Sample size')
    parser.add_argument('--max_epochs', default=250, type=int, help='Maximum number of epochs')
    parser.add_argument('--grad_clip', default=True, type=str2bool, help='Gradient clipping')

    parser.add_argument('--out_type', default='y', type=str, help='Output type for testing')

    ## GAIN FUNCTION
    parser.add_argument('--gain', default='mmse-lsa', type=str, help='Gain function for testing')

    ## PATHS
    parser.add_argument('--model_path', default='model/3f/epoch-175', type=str, help='Model save path')
    parser.add_argument('--set_path', default='set', type=str, help='Path to datasets')
    parser.add_argument('--data_path', default='data', type=str, help='Save data path')
    parser.add_argument('--test_x_path', default='set/test_noisy_speech', type=str, help='Path to the noisy speech test set')
    parser.add_argument('--in_filepath', default='test.wav', type=str, help='Input path')
    parser.add_argument('--out_filepath', default='out.wav', type=str, help='Output path')

    ## FEATURES
    parser.add_argument('--min_snr', default=-10, type=int, help='Minimum trained SNR level')
    parser.add_argument('--max_snr', default=20, type=int, help='Maximum trained SNR level')
    parser.add_argument('--f_s', default=16000, type=int, help='Sampling frequency (Hz)')
    parser.add_argument('--T_w', default=32, type=int, help='Window length (ms)')
    parser.add_argument('--T_s', default=16, type=int, help='Window shift (ms)')
    parser.add_argument('--nconst', default=32768.0, type=float, help='Normalisation constant (see feat.addnoisepad())')
    parser.add_argument('--N_w', default=int(16000*32*0.001), type=int, help='Window length (samples)')
    parser.add_argument('--N_s', default=int(16000*16*0.001), type=int, help='Window shift (samples)')
    parser.add_argument('--NFFT', default=int(pow(2, np.ceil(np.log2(int(16000*32*0.001))))), type=int, help='Number of DFT components')
    parser.add_argument('--stats', default=stats)

    ## NETWORK PARAMETERS
    parser.add_argument('--d_in', default=257, type=int, help='Input dimensionality')
    parser.add_argument('--d_out', default=257, type=int, help='Output dimensionality')
    parser.add_argument('--d_model', default=256, type=int, help='Model dimensions')
    parser.add_argument('--n_blocks', default=40, type=int, help='Number of blocks')
    parser.add_argument('--d_f', default=64, type=int, help='Number of filters')
    parser.add_argument('--k_size', default=3, type=int, help='Kernel size')
    parser.add_argument('--max_d_rate', default=16, type=int, help='Maximum dilation rate')
    parser.add_argument('--norm_type', default='FrameLayerNorm', type=str, help='Normalisation type')
    parser.add_argument('--net_height', default=[4], type=list, help='RDL block height')

    args = parser.parse_args()
    return args

def build_restore_model(model_path, args, config):
    ## MAKE DEEP XI NNET
    print('Start: Build and Restore model!')
    sess = tf.Session(config=config)
    net = deepxi_net.deepxi_net(args)
    net.saver.restore(sess, args.model_path)
    print('Done: Build and Restore model!')
    return net, sess

# def infer(filename, net, sess):
#     print('Start infer file: {}'.format(filename))
#     #(wav, _) = read_wav(args.in_filepath) # read wav from given file path.
#     (wav, _) = librosa.load(filename, 16000, mono=True) # read wav from given file path.
#     wav = np.asarray(np.multiply(wav, 32768.0), dtype=np.int16)

#     print(max(wav), min(wav), np.mean(wav))
#     print(wav.shape)

#     input_feat = sess.run(net.infer_feat, feed_dict={net.s_ph: [wav], net.s_len_ph: [len(wav)]}) # sample of training set.
#     xi_bar_hat = sess.run(
#         net.infer_output, feed_dict={net.input_ph: input_feat[0],
#         net.nframes_ph: input_feat[1], net.training_ph: False}) # output of network.
#     xi_hat = xi.xi_hat(xi_bar_hat, args.stats['mu_hat'], args.stats['sigma_hat'])

#     #file_name = filename.split('/')[-1].split('.')

#     y_MAG = np.multiply(input_feat[0], gain.gfunc(xi_hat, xi_hat+1, gtype=args.gain))
#     y = np.squeeze(sess.run(net.y, feed_dict={net.y_MAG_ph: y_MAG,
#         net.x_PHA_ph: input_feat[2], net.nframes_ph: input_feat[1], net.training_ph: False})) # output of network.
#     if np.isnan(y).any(): raise ValueError('NaN values found in enhanced speech.')
#     if np.isinf(y).any(): raise ValueError('Inf values found in enhanced speech.')

#     y = np.asarray(np.multiply(y, 32768.0), dtype=np.int16)
#     out_filepath = filename.replace('.'+filename.split('.')[-1], '_pred.wav')
#     wav_write(out_filepath, args.f_s, y)
#     print('Infer out file: {} done'.format(out_filepath))
#     return out_filepath

def get_model():
    args = create_args()

    ## GPU CONFIGURATION
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = False

    net, sess = build_restore_model(args.model_path, args, config)

    def infer(filename):
        print('Start infer file: {}'.format(filename))
        #(wav, _) = read_wav(args.in_filepath) # read wav from given file path.
        (wav, _) = librosa.load(filename, 16000, mono=True) # read wav from given file path.
        wav = np.asarray(np.multiply(wav, 32768.0), dtype=np.int16)

        print(max(wav), min(wav), np.mean(wav))
        print(wav.shape)

        input_feat = sess.run(net.infer_feat, feed_dict={net.s_ph: [wav], net.s_len_ph: [len(wav)]}) # sample of training set.
        xi_bar_hat = sess.run(
            net.infer_output, feed_dict={net.input_ph: input_feat[0],
            net.nframes_ph: input_feat[1], net.training_ph: False}) # output of network.
        xi_hat = xi.xi_hat(xi_bar_hat, args.stats['mu_hat'], args.stats['sigma_hat'])

        #file_name = filename.split('/')[-1].split('.')

        y_MAG = np.multiply(input_feat[0], gain.gfunc(xi_hat, xi_hat+1, gtype=args.gain))
        y = np.squeeze(sess.run(net.y, feed_dict={net.y_MAG_ph: y_MAG,
            net.x_PHA_ph: input_feat[2], net.nframes_ph: input_feat[1], net.training_ph: False})) # output of network.
        if np.isnan(y).any(): raise ValueError('NaN values found in enhanced speech.')
        if np.isinf(y).any(): raise ValueError('Inf values found in enhanced speech.')

        y = np.asarray(np.multiply(y, 32768.0), dtype=np.int16)
        out_filepath = filename.replace('.'+filename.split('.')[-1], '_pred.wav')
        wav_write(out_filepath, args.f_s, y)
        print('Infer out file: {} done'.format(out_filepath))
        return out_filepath
    return infer

if __name__ == '__main__':

    infer = get_model()
    infer('Toàn cảnh phòng chống dịch COVID-19 ngày 18-4-2020 - VTV24.mp3')
    #infer('set/test_noisy_speech/198853.wav')
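A quick sanity check of the hard-coded feature constants above (pure arithmetic):

import numpy as np

f_s, T_w, T_s = 16000, 32, 16           # Hz, ms, ms
N_w = int(f_s * T_w * 0.001)            # 512-sample window
N_s = int(f_s * T_s * 0.001)            # 256-sample shift (50% overlap)
NFFT = int(2 ** np.ceil(np.log2(N_w)))  # 512 DFT points
print(N_w, N_s, NFFT)                   # 512 256 512
# d_in = d_out = 257 = NFFT/2 + 1, the number of non-redundant DFT bins.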
58
lib/dev/ResLSTM.py
Normal file
@@ -0,0 +1,58 @@
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

import tensorflow as tf
from tensorflow.python.training import moving_averages
from dev.normalisation import Normalisation
import numpy as np
import argparse, math, sys

def ResLSTMBlock(x, d_model, seq_len, block, parallel_iterations=1024):
    '''
    ResLSTM block.

    Input/s:
        x - input to block.
        d_model - cell size.
        seq_len - sequence length.
        block - block number.
        parallel_iterations - number of parallel iterations.

    Output/s:
        output of residual block.
    '''
    with tf.variable_scope('block_' + str(block)):
        cell = tf.contrib.rnn.LSTMCell(d_model)
        # activation, _ = tf.nn.static_rnn(cell, [tf.expand_dims(x[0,0,:],0)]*seq_len[0], dtype=tf.float32)
        activation, _ = tf.nn.dynamic_rnn(cell, x, seq_len, swap_memory=True,
            parallel_iterations=parallel_iterations, dtype=tf.float32)
        return tf.add(x, activation)

def ResLSTM(x, seq_len, norm_type, training=None, d_out=257,
    n_blocks=5, d_model=512, out_layer=True, boolean_mask=False):
    '''
    ResLSTM network.

    Input/s:
        x - input.
        seq_len - length of each sequence.
        norm_type - normalisation type.
        training - training flag.
        d_out - output dimensions.
        n_blocks - number of residual blocks.
        d_model - cell size.
        out_layer - add an output layer.
        boolean_mask - convert padded 3D output to unpadded and stacked 2D output.

    Output/s:
        unactivated output of ResLSTM.
    '''
    blocks = [tf.nn.relu(Normalisation(tf.layers.dense(x,
        d_model, use_bias=False), norm_type, seq_len))] # (W -> Norm -> ReLU).
    for i in range(n_blocks): blocks.append(ResLSTMBlock(blocks[-1], d_model, seq_len, i))
    if boolean_mask: blocks[-1] = tf.boolean_mask(blocks[-1], tf.sequence_mask(seq_len))
    return tf.layers.dense(blocks[-1], d_out, use_bias=True)
89
lib/dev/ResNet.py
Normal file
@@ -0,0 +1,89 @@
## FILE: ResNet.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: ResNet with bottleneck blocks and 1D causal dilated convolutional units.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

import tensorflow as tf
from tensorflow.python.training import moving_averages
from dev.normalisation import Normalisation
import numpy as np
import argparse, math, sys
from dev.add_noise import add_noise_batch

def CausalDilatedConv1d(x, d_f, k_size, d_rate=1, use_bias=True):
    '''
    1D causal dilated convolutional unit.

    Input/s:
        x - input.
        d_f - filter dimensions.
        k_size - kernel dimensions.
        d_rate - dilation rate.
        use_bias - include bias vector.

    Output/s:
        output of convolutional unit.
    '''

    if k_size > 1: # padding for causality.
        x_shape = tf.shape(x)
        x = tf.concat([tf.zeros([x_shape[0], (k_size - 1)*d_rate, x_shape[2]]), x], 1)
    return tf.layers.conv1d(x, d_f, k_size, dilation_rate=d_rate,
        activation=None, padding='valid', use_bias=use_bias)

def BottlekneckBlock(x, norm_type, seq_len, d_model, d_f, k_size, d_rate):
    '''
    Bottleneck block with causal dilated convolutional
    units, and normalisation.

    Input/s:
        x - input to block.
        norm_type - normalisation type.
        seq_len - length of each sequence.
        d_model - model (output) dimensions.
        d_f - filter dimensions.
        k_size - kernel dimensions.
        d_rate - dilation rate.

    Output/s:
        output of residual block.
    '''

    layer_1 = CausalDilatedConv1d(tf.nn.relu(Normalisation(x, norm_type, seq_len=seq_len)), d_f, 1, 1, False)
    layer_2 = CausalDilatedConv1d(tf.nn.relu(Normalisation(layer_1, norm_type, seq_len=seq_len)), d_f, k_size, d_rate, False)
    layer_3 = CausalDilatedConv1d(tf.nn.relu(Normalisation(layer_2, norm_type, seq_len=seq_len)), d_model, 1, 1, True)
    return tf.add(x, layer_3)

def ResNet(x, seq_len, norm_type, training=None, d_out=257,
    n_blocks=40, d_model=256, d_f=64, k_size=3, max_d_rate=16, out_layer=True, boolean_mask=False):
    '''
    ResNet with bottleneck blocks, causal dilated convolutional
    units, and normalisation. Dilation resets after
    exceeding 'max_d_rate'.

    Input/s:
        x - input to ResNet.
        norm_type - normalisation type.
        seq_len - length of each sequence.
        training - training flag.
        d_out - output dimensions.
        n_blocks - number of residual blocks.
        d_model - model dimensions.
        d_f - filter dimensions.
        k_size - kernel dimensions.
        max_d_rate - maximum dilation rate.

    Output/s:
        unactivated output of ResNet.
    '''
    # mask = tf.cast(tf.expand_dims(tf.sequence_mask(seq_len), 2), tf.float32) # convert mask to float.
    blocks = [tf.nn.relu(Normalisation(tf.layers.dense(x, d_model, use_bias=False), norm_type, seq_len=seq_len))] # (W -> Norm -> ReLU).
    for i in range(n_blocks): blocks.append(BottlekneckBlock(blocks[-1], norm_type, seq_len,
        d_model, d_f, k_size, int(2**(i%(np.log2(max_d_rate)+1)))))
    if boolean_mask: blocks[-1] = tf.boolean_mask(blocks[-1], tf.sequence_mask(seq_len))
    return tf.layers.dense(blocks[-1], d_out, use_bias=True)
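The dilation-rate expression int(2**(i % (np.log2(max_d_rate)+1))) cycles the rate through 1, 2, 4, 8, 16 and then resets. A two-line check of the schedule for the defaults (n_blocks=40, max_d_rate=16):

import numpy as np

rates = [int(2 ** (i % (np.log2(16) + 1))) for i in range(40)]
print(rates[:10])  # [1, 2, 4, 8, 16, 1, 2, 4, 8, 16] -- the 5-step cycle repeats 8 times.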
BIN
lib/dev/__pycache__/ResNet.cpython-34.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/ResNet.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/add_args.cpython-34.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/add_args.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/add_noise.cpython-34.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/add_noise.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/args.cpython-34.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/args.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/deepxi_net.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/encoder_decoder.cpython-34.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/encoder_decoder.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/gain.cpython-34.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/gain.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/get_args.cpython-34.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/get_args.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/infer.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/mha.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/mmse_normal_gamma.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/mod_as.cpython-34.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/mod_feat.cpython-34.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/normalisation.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/num_frames.cpython-34.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/num_frames.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/optimisation.cpython-34.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/optimisation.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/sample_stats.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/se_batch.cpython-34.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/se_batch.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/train.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/utils.cpython-34.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/utils.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/__pycache__/xi.cpython-36.pyc
Normal file
Binary file not shown.
Binary file not shown.
71
lib/dev/acoustic/analysis_synthesis/polar.py
Normal file
@@ -0,0 +1,71 @@
## FILE: polar.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: Analysis and synthesis with the polar representation in the acoustic-domain.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

import tensorflow as tf
import functools
from tensorflow.python.ops.signal import window_ops

def analysis(x, N_w, N_s, NFFT, legacy=False):
    '''
    Polar form acoustic-domain analysis.

    Input/s:
        x - noisy speech.
        N_w - time-domain window length (samples).
        N_s - time-domain window shift (samples).
        NFFT - acoustic-domain DFT components.

    Output/s:
        Magnitude and phase spectrums.
    '''

    if legacy:

        ## MAGNITUDE & PHASE SPECTRUMS (ACOUSTIC DOMAIN)
        x_DFT = tf.signal.stft(x, N_w, N_s, NFFT, pad_end=True)
        x_MAG = tf.abs(x_DFT); x_PHA = tf.angle(x_DFT)
        return x_MAG, x_PHA

    else:

        ## MAGNITUDE & PHASE SPECTRUMS (ACOUSTIC DOMAIN)
        W = functools.partial(window_ops.hamming_window, periodic=False)
        x_DFT = tf.signal.stft(x, N_w, N_s, NFFT, window_fn=W, pad_end=True)
        x_MAG = tf.abs(x_DFT); x_PHA = tf.angle(x_DFT)
        return x_MAG, x_PHA

def synthesis(y_MAG, x_PHA, N_w, N_s, NFFT, legacy=False):
    '''
    Polar form acoustic-domain synthesis.

    Input/s:
        y_MAG - modified magnitude spectrum.
        x_PHA - unmodified phase spectrum.
        N_w - time-domain window length (samples).
        N_s - time-domain window shift (samples).
        NFFT - acoustic-domain DFT components.

    Output/s:
        synthesised signal.
    '''

    if legacy:

        ## SYNTHESISED SIGNAL
        y_DFT = tf.cast(y_MAG, tf.complex64)*tf.exp(1j*tf.cast(x_PHA, tf.complex64))
        return tf.signal.inverse_stft(y_DFT, N_w, N_s, NFFT, tf.signal.inverse_stft_window_fn(N_s))

    else:

        ## SYNTHESISED SIGNAL
        W = functools.partial(window_ops.hamming_window, periodic=False)
        y_DFT = tf.cast(y_MAG, tf.complex64)*tf.exp(1j*tf.cast(x_PHA, tf.complex64))
        return tf.signal.inverse_stft(y_DFT, N_w, N_s, NFFT, tf.signal.inverse_stft_window_fn(N_s, W))
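A NumPy/SciPy analogue of the analysis/synthesis pair above, for intuition only (a sketch; the TF graph differs in detail, e.g. it uses a periodic=False Hamming window and pad_end=True):

import numpy as np
from scipy.signal import stft, istft

x = np.random.randn(16000).astype(np.float32)  # one second at 16 kHz

# Analysis: magnitude and phase spectrums (N_w=512, N_s=256, NFFT=512).
_, _, X = stft(x, fs=16000, window='hamming', nperseg=512, noverlap=256, nfft=512)
x_MAG, x_PHA = np.abs(X), np.angle(X)

# Synthesis: recombine a (possibly modified) magnitude with the original phase.
Y = x_MAG * np.exp(1j * x_PHA)
_, y = istft(Y, fs=16000, window='hamming', nperseg=512, noverlap=256, nfft=512)
print(np.max(np.abs(x - y[:len(x)])))  # near zero when the magnitude is unmodified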
BIN
lib/dev/acoustic/feat/__pycache__/filterbank.cpython-34.pyc
Normal file
Binary file not shown.
BIN
lib/dev/acoustic/feat/__pycache__/filterbank.cpython-36.pyc
Normal file
Binary file not shown.
BIN
lib/dev/acoustic/feat/__pycache__/polar.cpython-34.pyc
Normal file
Binary file not shown.
BIN
lib/dev/acoustic/feat/__pycache__/polar.cpython-36.pyc
Normal file
Binary file not shown.
125
lib/dev/acoustic/feat/polar.py
Normal file
@@ -0,0 +1,125 @@
## FILE: polar.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University
## BRIEF: Feature and target generation for polar representation in the acoustic-domain.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

from dev.acoustic.analysis_synthesis import polar
from dev.add_noise import add_noise_batch
from dev.num_frames import num_frames
from dev.utils import log10
import dev.xi as xi
import tensorflow as tf

def input(z, z_len, N_w, N_s, NFFT, f_s):
    '''
    Input features for polar form acoustic-domain.

    Input/s:
        z - speech (dtype=tf.int32).
        z_len - speech length without padding (samples).
        N_w - time-domain window length (samples).
        N_s - time-domain window shift (samples).
        NFFT - number of acoustic-domain DFT components.
        f_s - sampling frequency (Hz).

    Output/s:
        z_MAG - speech magnitude spectrum.
        z_PHA - speech phase spectrum.
        L - number of time-domain frames for each sequence.
    '''
    z = tf.truediv(tf.cast(z, tf.float32), 32768.0)
    L = num_frames(z_len, N_s)
    z_MAG, z_PHA = polar.analysis(z, N_w, N_s, NFFT)
    return z_MAG, L, z_PHA

def input_target_spec(s, d, s_len, d_len, SNR, N_w, N_s, NFFT, f_s):
    '''
    Input features and target (spectrum) for polar form acoustic-domain.

    Inputs:
        s - clean speech (dtype=tf.int32).
        d - noise (dtype=tf.int32).
        s_len - clean speech length without padding (samples).
        d_len - noise length without padding (samples).
        SNR - SNR level.
        N_w - time-domain window length (samples).
        N_s - time-domain window shift (samples).
        NFFT - number of acoustic-domain DFT components.
        f_s - sampling frequency (Hz).

    Outputs:
        x_MAG - noisy speech magnitude spectrum.
        s_MAG - clean speech magnitude spectrum (target).
        L - number of time-domain frames for each sequence.
    '''
    (x, s, _) = add_noise_batch(s, d, s_len, d_len, SNR)
    L = num_frames(s_len, N_s) # number of time-domain frames for each sequence (uppercase eta).
    x_MAG, _ = polar.analysis(x, N_w, N_s, NFFT)
    s_MAG, _ = polar.analysis(s, N_w, N_s, NFFT)
    s_MAG = tf.boolean_mask(s_MAG, tf.sequence_mask(L))
    return x_MAG, s_MAG, L

def input_target_xi(s, d, s_len, d_len, SNR, N_w, N_s, NFFT, f_s, mu, sigma):
    '''
    Input features and target (mapped a priori SNR) for polar form acoustic-domain.

    Inputs:
        s - clean speech (dtype=tf.int32).
        d - noise (dtype=tf.int32).
        s_len - clean speech length without padding (samples).
        d_len - noise length without padding (samples).
        SNR - SNR level.
        N_w - time-domain window length (samples).
        N_s - time-domain window shift (samples).
        NFFT - number of acoustic-domain DFT components.
        f_s - sampling frequency (Hz).
        mu - sample mean.
        sigma - sample standard deviation.

    Outputs:
        x_MAG - noisy speech magnitude spectrum.
        xi_mapped - mapped a priori SNR (target).
        L - number of time-domain frames for each sequence.
    '''
    (x, s, d) = add_noise_batch(s, d, s_len, d_len, SNR)
    L = num_frames(s_len, N_s) # number of acoustic-domain frames for each sequence (uppercase eta).
    x_MAG, _ = polar.analysis(x, N_w, N_s, NFFT)
    s_MAG, _ = polar.analysis(s, N_w, N_s, NFFT)
    s_MAG = tf.boolean_mask(s_MAG, tf.sequence_mask(L))
    d_MAG, _ = polar.analysis(d, N_w, N_s, NFFT)
    d_MAG = tf.boolean_mask(d_MAG, tf.sequence_mask(L))
    xi_bar = xi.xi_bar(s_MAG, d_MAG, mu, sigma)
    return x_MAG, xi_bar, L

def target_xi(s, d, s_len, d_len, SNR, N_w, N_s, NFFT, f_s):
    '''
    Target (a priori SNR) for polar form acoustic-domain.

    Inputs:
        s - clean speech (dtype=tf.int32).
        d - noise (dtype=tf.int32).
        s_len - clean speech length without padding (samples).
        d_len - noise length without padding (samples).
        SNR - SNR level.
        N_w - time-domain window length (samples).
        N_s - time-domain window shift (samples).
        NFFT - number of acoustic-domain DFT components.
        f_s - sampling frequency (Hz).

    Outputs:
        xi_dB - a priori SNR in dB (target).
        L - number of time-domain frames for each sequence.
    '''
    (_, s, d) = add_noise_batch(s, d, s_len, d_len, SNR)
    L = num_frames(s_len, N_s) # number of acoustic-domain frames for each sequence (uppercase eta).
    s_MAG, _ = polar.analysis(s, N_w, N_s, NFFT)
    d_MAG, _ = polar.analysis(d, N_w, N_s, NFFT)
    s_MAG = tf.boolean_mask(s_MAG, tf.sequence_mask(L))
    d_MAG = tf.boolean_mask(d_MAG, tf.sequence_mask(L))
    xi_dB = xi.xi_dB(s_MAG, d_MAG)
    return xi_dB, L
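xi.xi_bar and xi.xi_hat are not part of this commit, so the following is an assumption based on the published Deep Xi method (Nicolson & Paliwal, 2019): the instantaneous a priori SNR in dB is mapped to [0, 1] through the Gaussian CDF of its training statistics, and the mapping is inverted at inference.

import numpy as np
from scipy.special import erf, erfinv

# Assumed forms only; mu and sigma are the per-frequency-bin sample statistics
# stored in data/stats.p (names taken from the calls above).
def xi_bar(s_MAG, d_MAG, mu, sigma):
    xi_dB = 10.0 * np.log10((s_MAG ** 2) / (d_MAG ** 2))  # instantaneous a priori SNR (dB)
    return 0.5 * (1.0 + erf((xi_dB - mu) / (sigma * np.sqrt(2.0))))

def xi_hat(xi_bar_hat, mu, sigma):
    xi_dB = mu + sigma * np.sqrt(2.0) * erfinv(2.0 * xi_bar_hat - 1.0)  # invert the CDF
    return 10.0 ** (xi_dB / 10.0)  # back to a linear ratio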
79
lib/dev/add_noise.py
Normal file
@@ -0,0 +1,79 @@
## FILE: add_noise.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University
## BRIEF: Add noise to clean speech at a set SNR level.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

import tensorflow as tf

def add_noise_batch(s, d, s_len, d_len, SNR):
    '''
    Creates noisy speech batch from clean speech, noise, and SNR batches.

    Input/s:
        s - clean waveforms (dtype=tf.int32).
        d - noise waveforms (dtype=tf.int32).
        s_len - clean waveform lengths without padding (samples).
        d_len - noise waveform lengths without padding (samples).
        SNR - SNR levels.

    Output/s:
        tuple consisting of noisy speech, clean speech, and noise (x, s, d).
    '''
    return tf.map_fn(lambda z: add_noise_pad(z[0], z[1], z[2], z[3], z[4],
        tf.reduce_max(s_len)), (s, d, s_len, d_len, SNR), dtype=(tf.float32, tf.float32,
        tf.float32))

def add_noise_pad(s, d, s_len, d_len, SNR, P):
    '''
    Calls add_noise() and pads the waveforms to the length given by P.
    Also normalises the waveforms.

    Inputs:
        s - clean speech waveform.
        d - noise waveform.
        s_len - length of s.
        d_len - length of d.
        SNR - SNR level.
        P - padded length.

    Outputs:
        x - padded noisy speech waveform.
        s - padded clean speech waveform.
        d - truncated, scaled, and padded noise waveform.
    '''
    s = tf.truediv(tf.cast(tf.slice(s, [0], [s_len]), tf.float32), 32768.0)
    d = tf.truediv(tf.cast(tf.slice(d, [0], [d_len]), tf.float32), 32768.0)
    (x, d) = add_noise(s, d, SNR)
    total_zeros = tf.subtract(P, tf.shape(s)[0])
    x = tf.pad(x, [[0, total_zeros]], "CONSTANT")
    s = tf.pad(s, [[0, total_zeros]], "CONSTANT")
    d = tf.pad(d, [[0, total_zeros]], "CONSTANT")
    return (x, s, d)

def add_noise(s, d, SNR):
    '''
    Adds noise to the clean waveform at a specific SNR value. A random section
    of the noise waveform is used.

    Inputs:
        s - clean waveform.
        d - noise waveform.
        SNR - SNR level.

    Outputs:
        x - noisy speech waveform.
        d - truncated and scaled noise waveform.
    '''
    s_len = tf.shape(s)[0]
    d_len = tf.shape(d)[0]
    i = tf.random_uniform([1], 0, tf.add(1, tf.subtract(d_len, s_len)), tf.int32)
    d = tf.slice(d, [i[0]], [s_len])
    d = tf.multiply(tf.truediv(d, tf.norm(d)), tf.truediv(tf.norm(s),
        tf.pow(10.0, tf.multiply(0.05, SNR))))
    x = tf.add(s, d)
    return (x, d)
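The scaling in add_noise() solves ||s|| / ||d_scaled|| = 10^(SNR/20). A NumPy sketch verifying that the mixture then has the requested SNR:

import numpy as np

rng = np.random.default_rng(0)
s = rng.standard_normal(16000)  # stand-in for clean speech
d = rng.standard_normal(32000)  # stand-in for noise
SNR = 5.0                       # target SNR (dB)

i = rng.integers(0, len(d) - len(s) + 1)  # random noise section
d = d[i:i + len(s)]
d = d / np.linalg.norm(d) * np.linalg.norm(s) / 10 ** (0.05 * SNR)

print(20 * np.log10(np.linalg.norm(s) / np.linalg.norm(d)))  # 5.0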
109
lib/dev/args.py
Normal file
@@ -0,0 +1,109 @@
## FILE: args.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: Get command line arguments.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

import argparse
import numpy as np
import os
from dev.se_batch import Batch_list, Batch
from os.path import expanduser

## ADD ADDITIONAL ARGUMENTS
def add_args(args, modulation=False):

    ## DEPENDENT OPTIONS
    args.model_path = args.model_path + '/' + args.ver # model save path.
    args.train_s_path = args.set_path + '/train_clean_speech' # path to the clean speech training set.
    args.train_d_path = args.set_path + '/train_noise' # path to the noise training set.
    args.val_s_path = args.set_path + '/val_clean_speech' # path to the clean speech validation set.
    args.val_d_path = args.set_path + '/val_noise' # path to the noise validation set.
    args.out_path = args.out_path + '/' + args.ver + '/' + 'e' + str(args.epoch) # output path.
    args.N_w = int(args.f_s*args.T_w*0.001) # window length (samples).
    args.N_s = int(args.f_s*args.T_s*0.001) # window shift (samples).
    args.NFFT = int(pow(2, np.ceil(np.log2(args.N_w)))) # number of DFT components.

    ## DATASETS
    if args.train: ## TRAINING AND VALIDATION CLEAN SPEECH AND NOISE SET
        args.train_s_list = Batch_list(args.train_s_path, 'clean_speech_' + args.set_path.rsplit('/', 1)[-1], args.data_path) # clean speech training list.
        args.train_d_list = Batch_list(args.train_d_path, 'noise_' + args.set_path.rsplit('/', 1)[-1], args.data_path) # noise training list.
        if not os.path.exists(args.model_path): os.makedirs(args.model_path) # make model path directory.
        args.val_s, args.val_s_len, args.val_snr, _ = Batch(args.val_s_path,
            list(range(args.min_snr, args.max_snr + 1))) # clean validation waveforms and lengths.
        args.val_d, args.val_d_len, _, _ = Batch(args.val_d_path,
            list(range(args.min_snr, args.max_snr + 1))) # noise validation waveforms and lengths.
        args.train_steps = int(np.ceil(len(args.train_s_list)/args.mbatch_size))
        args.val_steps = int(np.ceil(args.val_s.shape[0]/args.mbatch_size))

    ## INFERENCE
    # if args.infer: args.test_x, args.test_x_len, args.test_snr, args.test_fnames = Batch(args.test_x_path, '*', []) # noisy speech test waveforms and lengths.
    if args.infer: args.test_x_list = Batch_list(args.test_x_path, 'test_x', args.data_path, make_new=True)
    return args

## STRING TO BOOLEAN
def str2bool(s): return s.lower() in ("yes", "true", "t", "1")

## GET COMMAND LINE ARGUMENTS
def get_args():
    parser = argparse.ArgumentParser()

    ## OPTIONS (GENERAL)
    parser.add_argument('--gpu', default='0', type=str, help='GPU selection')
    parser.add_argument('--ver', type=str, help='Model version')
    parser.add_argument('--epoch', type=int, help='Epoch to use/retrain from')
    parser.add_argument('--train', default=False, type=str2bool, help='Training flag')
    parser.add_argument('--infer', default=False, type=str2bool, help='Inference flag')
    parser.add_argument('--verbose', default=False, type=str2bool, help='Verbose')
    parser.add_argument('--model', default='ResNet', type=str, help='Model type')

    ## OPTIONS (TRAIN)
    parser.add_argument('--cont', default=False, type=str2bool, help='Continue training from last epoch')
    parser.add_argument('--mbatch_size', default=10, type=int, help='Mini-batch size')
    parser.add_argument('--sample_size', default=1000, type=int, help='Sample size')
    parser.add_argument('--max_epochs', default=250, type=int, help='Maximum number of epochs')
    parser.add_argument('--grad_clip', default=True, type=str2bool, help='Gradient clipping')

    # TEST OUTPUT TYPE
    # 'xi_hat' - a priori SNR estimate (.mat),
    # 'y' - enhanced speech (.wav).
    parser.add_argument('--out_type', default='y', type=str, help='Output type for testing')

    ## GAIN FUNCTION
    # 'ibm' - Ideal Binary Mask (IBM), 'wf' - Wiener Filter (WF), 'srwf' - Square-Root Wiener Filter (SRWF),
    # 'cwf' - Constrained Wiener Filter (cWF), 'mmse-stsa' - Minimum Mean-Square Error Short-Time Spectral Amplitude (MMSE-STSA) estimator,
    # 'mmse-lsa' - Minimum Mean-Square Error Log-Spectral Amplitude (MMSE-LSA) estimator.
    parser.add_argument('--gain', default='srwf', type=str, help='Gain function for testing')

    ## PATHS
    parser.add_argument('--model_path', default='model', type=str, help='Model save path')
    parser.add_argument('--set_path', default='set', type=str, help='Path to datasets')
    parser.add_argument('--data_path', default='data', type=str, help='Save data path')
    parser.add_argument('--test_x_path', default='set/test_noisy_speech', type=str, help='Path to the noisy speech test set')
    parser.add_argument('--out_path', default='out', type=str, help='Output path')

    ## FEATURES
    parser.add_argument('--min_snr', default=-10, type=int, help='Minimum trained SNR level')
    parser.add_argument('--max_snr', default=20, type=int, help='Maximum trained SNR level')
    parser.add_argument('--f_s', default=16000, type=int, help='Sampling frequency (Hz)')
    parser.add_argument('--T_w', default=32, type=int, help='Window length (ms)')
    parser.add_argument('--T_s', default=16, type=int, help='Window shift (ms)')
    parser.add_argument('--nconst', default=32768.0, type=float, help='Normalisation constant (see feat.addnoisepad())')

    ## NETWORK PARAMETERS
    parser.add_argument('--d_in', default=257, type=int, help='Input dimensionality')
    parser.add_argument('--d_out', default=257, type=int, help='Output dimensionality')
    parser.add_argument('--d_model', default=256, type=int, help='Model dimensions')
    parser.add_argument('--n_blocks', default=40, type=int, help='Number of blocks')
    parser.add_argument('--d_f', default=64, type=int, help='Number of filters')
    parser.add_argument('--k_size', default=3, type=int, help='Kernel size')
    parser.add_argument('--max_d_rate', default=16, type=int, help='Maximum dilation rate')
    parser.add_argument('--norm_type', default='FrameLayerNorm', type=str, help='Normalisation type')
    parser.add_argument('--net_height', default=[4], type=list, help='RDL block height')

    args = parser.parse_args()
    return args
77
lib/dev/deepxi_net.py
Normal file
@@ -0,0 +1,77 @@
## FILE: deepxi_net.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University
## BRIEF: Network employed within the Deep Xi framework.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

## VARIABLE DESCRIPTIONS
# s - clean speech.
# d - noise.
# x - noisy speech.

from dev.acoustic.analysis_synthesis.polar import synthesis
from dev.acoustic.feat import polar
from dev.ResNet import ResNet
import dev.optimisation as optimisation
import numpy as np
import tensorflow as tf

## ARTIFICIAL NEURAL NETWORK
class deepxi_net:
    def __init__(self, args):
        print('Preparing graph...')

        ## RESNET
        self.input_ph = tf.placeholder(tf.float32, shape=[None, None, args.d_in], name='input_ph') # noisy speech MS placeholder.
        self.nframes_ph = tf.placeholder(tf.int32, shape=[None], name='nframes_ph') # noisy speech MS sequence length placeholder.
        if args.model == 'ResNet':
            self.output = ResNet(self.input_ph, self.nframes_ph, args.norm_type, n_blocks=args.n_blocks, boolean_mask=True, d_out=args.d_out,
                d_model=args.d_model, d_f=args.d_f, k_size=args.k_size, max_d_rate=args.max_d_rate)
        elif args.model == 'RDLNet':
            from dev.RDLNet import RDLNet
            self.output = RDLNet(self.input_ph, self.nframes_ph, args.norm_type, n_blocks=args.n_blocks, boolean_mask=True, d_out=args.d_out,
                d_f=args.d_f, net_height=args.net_height)
        elif args.model == 'ResLSTM':
            from dev.ResLSTM import ResLSTM
            self.output = ResLSTM(self.input_ph, self.nframes_ph, args.norm_type, n_blocks=args.n_blocks, boolean_mask=True, d_out=args.d_out, d_model=args.d_model)

        ## TRAINING FEATURE EXTRACTION GRAPH
        self.s_ph = tf.placeholder(tf.int16, shape=[None, None], name='s_ph') # clean speech placeholder.
        self.d_ph = tf.placeholder(tf.int16, shape=[None, None], name='d_ph') # noise placeholder.
        self.s_len_ph = tf.placeholder(tf.int32, shape=[None], name='s_len_ph') # clean speech sequence length placeholder.
        self.d_len_ph = tf.placeholder(tf.int32, shape=[None], name='d_len_ph') # noise sequence length placeholder.
        self.snr_ph = tf.placeholder(tf.float32, shape=[None], name='snr_ph') # SNR placeholder.
        self.train_feat = polar.input_target_xi(self.s_ph, self.d_ph, self.s_len_ph,
            self.d_len_ph, self.snr_ph, args.N_w, args.N_s, args.NFFT, args.f_s, args.stats['mu_hat'], args.stats['sigma_hat'])

        ## INFERENCE FEATURE EXTRACTION GRAPH
        self.infer_feat = polar.input(self.s_ph, self.s_len_ph, args.N_w, args.N_s, args.NFFT, args.f_s)

        ## PLACEHOLDERS
        self.x_ph = tf.placeholder(tf.int16, shape=[None, None], name='x_ph') # noisy speech placeholder.
        self.x_len_ph = tf.placeholder(tf.int32, shape=[None], name='x_len_ph') # noisy speech sequence length placeholder.
        self.target_ph = tf.placeholder(tf.float32, shape=[None, args.d_out], name='target_ph') # training target placeholder.
        self.keep_prob_ph = tf.placeholder(tf.float32, name='keep_prob_ph') # keep probability placeholder.
        self.training_ph = tf.placeholder(tf.bool, name='training_ph') # training placeholder.

        ## SYNTHESIS GRAPH
        if args.infer:
            self.infer_output = tf.nn.sigmoid(self.output)
            self.y_MAG_ph = tf.placeholder(tf.float32, shape=[None, None, args.d_in], name='y_MAG_ph')
            self.x_PHA_ph = tf.placeholder(tf.float32, [None, None, args.d_in], name='x_PHA_ph')
            self.y = synthesis(self.y_MAG_ph, self.x_PHA_ph, args.N_w, args.N_s, args.NFFT)

        ## LOSS & OPTIMIZER
        self.loss = optimisation.loss(self.target_ph, self.output, 'mean_sigmoid_cross_entropy', axis=[1])
        self.total_loss = tf.reduce_mean(self.loss, axis=0)
        self.trainer, _ = optimisation.optimiser(self.total_loss, optimizer='adam', grad_clip=True)

        ## SAVE VARIABLES
        self.saver = tf.train.Saver(max_to_keep=256)

        ## NUMBER OF PARAMETERS
        args.params = np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])
129
lib/dev/gain.py
Normal file
@@ -0,0 +1,129 @@
## FILE: gain.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University
## BRIEF: Gain functions and masks for speech enhancement.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

import numpy as np
from scipy.special import exp1, i0, i1

def mmse_stsa(xi, gamma):
    '''
    Computes the MMSE-STSA gain function.

    Input/s:
        xi - a priori SNR.
        gamma - a posteriori SNR.

    Output/s:
        G - MMSE-STSA gain function.
    '''
    nu = np.multiply(xi, np.divide(gamma, np.add(1, xi)))
    G = np.multiply(np.multiply(np.multiply(np.divide(np.sqrt(np.pi), 2),
        np.divide(np.sqrt(nu), gamma)), np.exp(np.divide(-nu, 2))),
        np.add(np.multiply(np.add(1, nu), i0(np.divide(nu, 2))),
        np.multiply(nu, i1(np.divide(nu, 2))))) # MMSE-STSA gain function.
    idx = np.isnan(G) | np.isinf(G) # replace by Wiener gain.
    G[idx] = np.divide(xi[idx], np.add(1, xi[idx])) # Wiener gain.
    return G

def mmse_lsa(xi, gamma):
    '''
    Computes the MMSE-LSA gain function.

    Input/s:
        xi - a priori SNR.
        gamma - a posteriori SNR.

    Output/s:
        MMSE-LSA gain function.
    '''
    nu = np.multiply(np.divide(xi, np.add(1, xi)), gamma)
    return np.multiply(np.divide(xi, np.add(1, xi)), np.exp(np.multiply(0.5, exp1(nu)))) # MMSE-LSA gain function.

def wf(xi):
    '''
    Computes the Wiener filter (WF) gain function.

    Input/s:
        xi - a priori SNR.

    Output/s:
        WF gain function.
    '''
    return np.divide(xi, np.add(xi, 1.0)) # WF gain function.

def srwf(xi):
    '''
    Computes the square-root Wiener filter (SRWF) gain function.

    Input/s:
        xi - a priori SNR.

    Output/s:
        SRWF gain function.
    '''
    return np.sqrt(wf(xi)) # SRWF gain function.

def cwf(xi):
    '''
    Computes the constrained Wiener filter (cWF) gain function.

    Input/s:
        xi - a priori SNR.

    Output/s:
        cWF gain function.
    '''
    return wf(np.sqrt(xi)) # cWF gain function.

def irm(xi):
    '''
    Computes the ideal ratio mask (IRM).

    Input/s:
        xi - a priori SNR.

    Output/s:
        IRM.
    '''
    return srwf(xi) # IRM.

def ibm(xi):
    '''
    Computes the ideal binary mask (IBM) with a threshold of 0 dB.

    Input/s:
        xi - a priori SNR.

    Output/s:
        IBM.
    '''
    return np.greater(xi, 1).astype(np.float32) # IBM (1 corresponds to 0 dB); original referenced an undefined self.xi_hat_ph.

def gfunc(xi, gamma=None, gtype='mmse-lsa'):
    '''
    Computes the selected gain function.

    Input/s:
        xi - a priori SNR.
        gamma - a posteriori SNR.
        gtype - gain function type.

    Output/s:
        G - gain function.
    '''
    if gtype == 'mmse-lsa': G = mmse_lsa(xi, gamma)
    elif gtype == 'mmse-stsa': G = mmse_stsa(xi, gamma)
    elif gtype == 'wf': G = wf(xi)
    elif gtype == 'srwf': G = srwf(xi)
    elif gtype == 'cwf': G = cwf(xi)
    elif gtype == 'irm': G = irm(xi)
    elif gtype == 'ibm': G = ibm(xi)
    else: raise ValueError('Gain function not available.') # original dropped the raise.
    return G
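A quick numeric check of the gains at a 0 dB a priori SNR (xi = 1), using gamma = xi + 1 as the inference code in this commit does:

import numpy as np
from dev.gain import gfunc  # assumes 'lib' is on sys.path, as in deepxi.py

xi = np.array([1.0])                        # 0 dB a priori SNR
print(gfunc(xi, gtype='wf'))                # [0.5]
print(gfunc(xi, gtype='srwf'))              # [0.70710678]
print(gfunc(xi, xi + 1, gtype='mmse-lsa'))  # ~[0.558], attenuates less than the WF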
97
lib/dev/infer.py
Normal file
@@ -0,0 +1,97 @@
|
||||
## FILE: infer.py
|
||||
## DATE: 2019
|
||||
## AUTHOR: Aaron Nicolson
|
||||
## AFFILIATION: Signal Processing Laboratory, Griffith University.
|
||||
## BRIEF: Inference module.
|
||||
##
|
||||
## This Source Code Form is subject to the terms of the Mozilla Public
|
||||
## License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
## file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
from dev.utils import read_wav
|
||||
from tqdm import tqdm
|
||||
import dev.gain as gain
|
||||
import dev.utils as utils
|
||||
import dev.xi as xi
|
||||
import numpy as np
|
||||
import os
|
||||
import scipy.io as spio
|
||||
|
||||
## INFERENCE
|
||||
def infer(sess, net, args):
|
||||
print("Inference...", )
|
||||
print (args.test_x_list)
|
||||
net.saver.restore(sess, args.model_path + '/epoch-' + str(args.epoch)) # load model from epoch.
|
||||
|
||||
if args.out_type == 'xi_hat': args.out_path = args.out_path + '/xi_hat'
|
||||
elif args.out_type == 'y': args.out_path = args.out_path + '/' + args.gain + '/y'
|
||||
    elif args.out_type == 'ibm_hat': args.out_path = args.out_path + '/ibm_hat'
    else: raise ValueError('Incorrect output type.')

    if not os.path.exists(args.out_path): os.makedirs(args.out_path) # make output directory.

    for j in tqdm(args.test_x_list):
        (wav, _) = read_wav(j['file_path']) # read wav from given file path.
        input_feat = sess.run(net.infer_feat, feed_dict={net.s_ph: [wav], net.s_len_ph: [j['seq_len']]}) # input features.
        xi_bar_hat = sess.run(net.infer_output, feed_dict={net.input_ph: input_feat[0],
            net.nframes_ph: input_feat[1], net.training_ph: False}) # output of network.
        xi_hat = xi.xi_hat(xi_bar_hat, args.stats['mu_hat'], args.stats['sigma_hat'])

        file_name = j['file_path'].rsplit('/', 1)[1].split('.')[0]

        if args.out_type == 'xi_hat':
            spio.savemat(args.out_path + '/' + file_name + '.mat', {'xi_hat': xi_hat})

        elif args.out_type == 'y':
            y_MAG = np.multiply(input_feat[0], gain.gfunc(xi_hat, xi_hat + 1, gtype=args.gain))
            y = np.squeeze(sess.run(net.y, feed_dict={net.y_MAG_ph: y_MAG,
                net.x_PHA_ph: input_feat[2], net.nframes_ph: input_feat[1], net.training_ph: False})) # enhanced speech.
            if np.isnan(y).any(): raise ValueError('NaN values found in enhanced speech.')
            if np.isinf(y).any(): raise ValueError('Inf values found in enhanced speech.')
            print(args.out_path + '/' + file_name + '.wav')
            utils.save_wav(args.out_path + '/' + file_name + '.wav', args.f_s, y)

        elif args.out_type == 'ibm_hat':
            ibm_hat = np.greater(xi_hat, 1.0)
            spio.savemat(args.out_path + '/' + file_name + '.mat', {'ibm_hat': ibm_hat})

    print('Inference complete.')

def infer2(sess, net, args):
    print("Inference...")
    print(args.test_x_list)
    net.saver.restore(sess, args.model_path + '/epoch-' + str(args.epoch)) # load model from given epoch.

    if args.out_type == 'xi_hat': args.out_path = args.out_path + '/xi_hat'
    elif args.out_type == 'y': args.out_path = args.out_path + '/' + args.gain + '/y'
    elif args.out_type == 'ibm_hat': args.out_path = args.out_path + '/ibm_hat'
    else: raise ValueError('Incorrect output type.')

    if not os.path.exists(args.out_path): os.makedirs(args.out_path) # make output directory.

    for j in tqdm(args.test_x_list):
        (wav, _) = read_wav(j['file_path']) # read wav from given file path.
        input_feat = sess.run(net.infer_feat, feed_dict={net.s_ph: [wav], net.s_len_ph: [j['seq_len']]}) # input features.
        xi_bar_hat = sess.run(net.infer_output, feed_dict={net.input_ph: input_feat[0],
            net.nframes_ph: input_feat[1], net.training_ph: False}) # output of network.
        xi_hat = xi.xi_hat(xi_bar_hat, args.stats['mu_hat'], args.stats['sigma_hat'])

        file_name = j['file_path'].rsplit('/', 1)[1].split('.')[0]

        if args.out_type == 'xi_hat':
            spio.savemat(args.out_path + '/' + file_name + '.mat', {'xi_hat': xi_hat})

        elif args.out_type == 'y':
            y_MAG = np.multiply(input_feat[0], gain.gfunc(xi_hat, xi_hat + 1, gtype=args.gain))
            y = np.squeeze(sess.run(net.y, feed_dict={net.y_MAG_ph: y_MAG,
                net.x_PHA_ph: input_feat[2], net.nframes_ph: input_feat[1], net.training_ph: False})) # enhanced speech.
            if np.isnan(y).any(): raise ValueError('NaN values found in enhanced speech.')
            if np.isinf(y).any(): raise ValueError('Inf values found in enhanced speech.')
            print(args.out_path + '/' + file_name + '.wav')
            utils.save_wav(args.out_path + '/' + file_name + '.wav', args.f_s, y)

        elif args.out_type == 'ibm_hat':
            ibm_hat = np.greater(xi_hat, 1.0)
            spio.savemat(args.out_path + '/' + file_name + '.mat', {'ibm_hat': ibm_hat})

    print('Inference complete.')
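The ibm_hat branch above forms an ideal binary mask by thresholding the estimated a priori SNR at 1.0 (0 dB). A minimal numpy sketch of that decision rule, on a made-up xi_hat array:

import numpy as np

xi_hat = np.array([[0.2, 1.5, 3.0], [0.9, 1.1, 0.4]]) # hypothetical a priori SNR estimates.
ibm_hat = np.greater(xi_hat, 1.0) # True where speech is estimated to dominate the noise (> 0 dB).
print(ibm_hat) # [[False True True] [False True False]]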
96
lib/dev/normalisation.py
Normal file
@@ -0,0 +1,96 @@
## FILE: normalisation.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: Layer/instance/batch normalisation functions.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

from os.path import expanduser
import argparse, os, string
import numpy as np
import tensorflow as tf

def Normalisation(x, norm_type='FrameLayerNorm', seq_len=None, mask=None, training=False, centre=True, scale=True):
    '''
    Normalisation.

    Input/s:
        x - unnormalised input.
        norm_type - normalisation type.
        seq_len - length of each sequence.
        mask - sequence mask.
        training - training flag.

    Output/s:
        normalised input.
    '''
    if norm_type == 'SeqCausalLayerNorm': return SeqCausalLayerNorm(x, seq_len, centre=centre, scale=scale)
    elif norm_type == 'FrameLayerNorm': return FrameLayerNorm(x, centre=centre, scale=scale)
    elif norm_type == 'unnormalised': return x
    else: raise ValueError('Normalisation type does not exist: %s.' % (norm_type))

count = 0
def SeqCausalLayerNorm(x, seq_len, centre=True, scale=True):
    '''
    Sequence-wise causal layer normalisation with sequence masking (causal layer norm version of https://arxiv.org/pdf/1510.01378.pdf).

    Input/s:
        x - input.
        seq_len - length of each sequence.
        centre - centre parameter.
        scale - scale parameter.

    Output/s:
        normalised input.
    '''
    global count
    count += 1
    with tf.variable_scope('LayerNorm' + str(count)):
        input_size = x.get_shape().as_list()[-1]
        mask = tf.cast(tf.sequence_mask(seq_len), tf.float32) # convert mask to float.
        den = tf.multiply(tf.range(1.0, tf.add(tf.cast(tf.shape(mask)[-1], tf.float32), 1.0), dtype=tf.float32), input_size)
        mu = tf.expand_dims(tf.truediv(tf.cumsum(tf.reduce_sum(x, -1), -1), den), 2)
        sigma = tf.expand_dims(tf.truediv(tf.cumsum(tf.reduce_sum(tf.square(tf.subtract(x, mu)), -1), -1), den), 2)
        if centre: beta = tf.get_variable("beta", input_size, dtype=tf.float32,
            initializer=tf.constant_initializer(0.0), trainable=True)
        else: beta = tf.constant(np.zeros(input_size), name="beta", dtype=tf.float32)
        if scale: gamma = tf.get_variable("Gamma", input_size, dtype=tf.float32,
            initializer=tf.constant_initializer(1.0), trainable=True)
        else: gamma = tf.constant(np.ones(input_size), name="Gamma", dtype=tf.float32)
        return tf.multiply(tf.nn.batch_normalization(x, mu, sigma, offset=beta, scale=gamma,
            variance_epsilon=1e-12), tf.expand_dims(mask, 2))

count = 0
def FrameLayerNorm(x, centre=True, scale=True):
    '''
    Frame-wise layer normalisation (layer norm version of https://arxiv.org/pdf/1510.01378.pdf).

    Input/s:
        x - input.
        centre - centre parameter.
        scale - scale parameter.

    Output/s:
        normalised input.
    '''
    global count
    count += 1
    with tf.variable_scope('frm_wise_layer_norm' + str(count)):
        mu, sigma = tf.nn.moments(x, -1, keepdims=True)
        input_size = x.get_shape().as_list()[-1] # number of input dimensions.
        if centre:
            beta = tf.get_variable("beta", input_size, dtype=tf.float32,
                initializer=tf.constant_initializer(0.0), trainable=True)
        else: beta = tf.constant(np.zeros(input_size), name="beta", dtype=tf.float32)
        if scale:
            gamma = tf.get_variable("Gamma", input_size, dtype=tf.float32,
                initializer=tf.constant_initializer(1.0), trainable=True)
        else: gamma = tf.constant(np.ones(input_size), name="Gamma", dtype=tf.float32)
        return tf.nn.batch_normalization(x, mu, sigma, offset=beta, scale=gamma,
            variance_epsilon=1e-12) # normalise each frame.
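FrameLayerNorm above normalises each time frame over its feature dimension. A minimal numpy equivalent of the statistics it computes (ignoring the learned beta/gamma, which are initialised to 0 and 1):

import numpy as np

x = np.random.rand(4, 10, 257) # hypothetical (batch, frames, features) input.
mu = np.mean(x, axis=-1, keepdims=True) # per-frame mean.
var = np.var(x, axis=-1, keepdims=True) # per-frame variance.
x_norm = (x - mu) / np.sqrt(var + 1e-12) # matches variance_epsilon above.
print(np.mean(x_norm[0, 0]), np.var(x_norm[0, 0])) # approximately 0 and 1.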
43
lib/dev/num_frames.py
Normal file
@@ -0,0 +1,43 @@
## FILE: num_frames.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: Determines the number of frames in a signal.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

import tensorflow as tf

def num_frames(N, N_s):
    '''
    Returns the number of frames for a given sequence length and frame shift.

    Inputs:
        N - sequence length (samples).
        N_s - frame shift (samples).

    Output:
        number of frames.
    '''
    return tf.cast(tf.ceil(tf.truediv(tf.cast(N, tf.float32), tf.cast(N_s, tf.float32))), tf.int32) # number of frames.

def acou_num_frames(N, N_s, K_s):
    '''
    Returns the number of acoustic-domain frames for a given sequence length and frame shift.

    Inputs:
        N - time-domain sequence length (samples).
        N_s - time-domain frame shift (samples).
        K_s - acoustic-domain frame shift (samples).

    Output:
        number of acoustic-domain frames.
    '''
    N = tf.cast(N, tf.float32)
    N_s = tf.cast(N_s, tf.float32)
    K_s = tf.cast(K_s, tf.float32)
    return tf.cast(tf.ceil(tf.truediv(tf.truediv(N, N_s), K_s)), tf.int32) # number of frames.
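A worked example of num_frames: a one-second signal at 16 kHz (N = 16000) with a frame shift of N_s = 256 samples gives ceil(16000 / 256) = ceil(62.5) = 63 frames. The same arithmetic in plain Python:

import math

N, N_s = 16000, 256 # sequence length and frame shift (samples).
print(math.ceil(N / N_s)) # 63, matching tf.ceil(tf.truediv(N, N_s)).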
50
lib/dev/optimisation.py
Normal file
@@ -0,0 +1,50 @@
## FILE: optimisation.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: Loss functions and algorithms for gradient descent optimisation.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

import tensorflow as tf

## LOSS FUNCTIONS
def loss(target, estimate, loss_fnc, axis=1):
    'loss functions for gradient descent.'
    with tf.name_scope(loss_fnc + '_loss'):
        if loss_fnc == 'quadratic':
            loss = tf.reduce_sum(tf.square(tf.subtract(target, estimate)), axis=axis)
        elif loss_fnc == 'sigmoid_cross_entropy':
            loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=estimate), axis=axis)
        elif loss_fnc == 'mean_quadratic':
            loss = tf.reduce_mean(tf.square(tf.subtract(target, estimate)), axis=axis)
        elif loss_fnc == 'mean_sigmoid_cross_entropy':
            loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=estimate), axis=axis)
        else: raise ValueError('Loss function does not exist: %s.' % (loss_fnc))
        return loss

## GRADIENT DESCENT OPTIMISERS
def optimiser(loss, lr=None, epsilon=None, var_list=None, optimizer='adam', grad_clip=False):
    'optimisers for training.'
    with tf.name_scope(optimizer + '_opt'):
        if optimizer == 'adam':
            if lr is None: lr = 0.001 # default.
            if epsilon is None: epsilon = 1e-8 # default.
            optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=epsilon)
        elif optimizer == 'lazyadam':
            if lr is None: lr = 0.001 # default.
            if epsilon is None: epsilon = 1e-8 # default.
            optimizer = tf.contrib.opt.LazyAdamOptimizer(learning_rate=lr, epsilon=epsilon)
        elif optimizer == 'nadam':
            if lr is None: lr = 0.001 # default.
            if epsilon is None: epsilon = 1e-8 # default.
            optimizer = tf.contrib.opt.NadamOptimizer(learning_rate=lr, epsilon=epsilon)
        elif optimizer == 'sgd':
            if lr is None: lr = 0.5 # default.
            optimizer = tf.train.GradientDescentOptimizer(lr)
        grads_and_vars = optimizer.compute_gradients(loss, var_list=var_list)
        if grad_clip: grads_and_vars = [(tf.clip_by_value(gv[0], -1., 1.), gv[1]) for gv in grads_and_vars]
        trainer = optimizer.apply_gradients(grads_and_vars)
        return trainer, optimizer
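A minimal sketch of how loss and optimiser compose in a TF 1.x graph; the placeholder shapes are illustrative only, not taken from this repository:

import tensorflow as tf

target = tf.placeholder(tf.float32, [None, 257]) # hypothetical training targets.
estimate = tf.placeholder(tf.float32, [None, 257]) # hypothetical network output (logits).
frame_loss = loss(target, estimate, 'mean_sigmoid_cross_entropy', axis=1)
trainer, opt = optimiser(tf.reduce_mean(frame_loss), lr=0.001, optimizer='adam', grad_clip=True)
# sess.run(trainer, feed_dict={...}) would then perform one clipped Adam step.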
57
lib/dev/sample_stats.py
Normal file
@@ -0,0 +1,57 @@
## FILE: sample_stats.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University
## BRIEF: Get statistics from a sample of the training set.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

import numpy as np
import tensorflow as tf
import os, pickle, random
import dev.se_batch as batch
from dev.acoustic.feat import polar
from tqdm import tqdm
import scipy.io as spio

## GET STATISTICS OF SAMPLE
def get_stats(stats_path, args, config):
    if os.path.exists(stats_path + '/stats.p'):
        print('Loading sample statistics from pickle file...')
        with open(stats_path + '/stats.p', 'rb') as f:
            args.stats = pickle.load(f)
        return args
    elif args.infer:
        raise ValueError('You have not completed training (no stats.p file exists). In the Deep Xi github repository, data/stats.p is available.')
    else:
        print('Finding sample statistics...')
        random.shuffle(args.train_s_list) # shuffle list.
        s_sample, s_sample_seq_len = batch.Clean_mbatch(args.train_s_list,
            args.sample_size, 0, args.sample_size) # mini-batch of clean training waveforms.
        d_sample, d_sample_seq_len = batch.Noise_mbatch(args.train_d_list,
            args.sample_size, s_sample_seq_len) # mini-batch of noise training waveforms.
        snr_sample = np.random.randint(args.min_snr, args.max_snr + 1, args.sample_size) # mini-batch of SNR levels.
        s_ph = tf.placeholder(tf.int16, shape=[None, None], name='s_ph') # clean speech placeholder.
        d_ph = tf.placeholder(tf.int16, shape=[None, None], name='d_ph') # noise placeholder.
        s_len_ph = tf.placeholder(tf.int32, shape=[None], name='s_len_ph') # clean speech sequence length placeholder.
        d_len_ph = tf.placeholder(tf.int32, shape=[None], name='d_len_ph') # noise sequence length placeholder.
        snr_ph = tf.placeholder(tf.float32, shape=[None], name='snr_ph') # SNR placeholder.
        analysis = polar.target_xi(s_ph, d_ph, s_len_ph, d_len_ph, snr_ph, args.N_w, args.N_s, args.NFFT, args.f_s)
        sample_graph = analysis[0]
        samples = []
        with tf.Session(config=config) as sess:
            for i in tqdm(range(s_sample.shape[0])):
                sample = sess.run(sample_graph, feed_dict={s_ph: [s_sample[i]], d_ph: [d_sample[i]], s_len_ph: [s_sample_seq_len[i]],
                    d_len_ph: [d_sample_seq_len[i]], snr_ph: [snr_sample[i]]}) # sample of training set.
                samples.append(sample)
        samples = np.vstack(samples)
        if len(samples.shape) != 2: raise ValueError('Incorrect shape for sample.')
        args.stats = {'mu_hat': np.mean(samples, axis=0), 'sigma_hat': np.std(samples, axis=0)}
        if not os.path.exists(stats_path): os.makedirs(stats_path) # make directory.
        with open(stats_path + '/stats.p', 'wb') as f:
            pickle.dump(args.stats, f)
        spio.savemat(stats_path + '/stats.m', mdict={'mu_hat': args.stats['mu_hat'], 'sigma_hat': args.stats['sigma_hat']})
        print('Sample statistics saved to pickle file.')
        return args
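The mu_hat and sigma_hat vectors saved here are the per-frequency-bin statistics that xi.py (below) uses to map the a priori SNR in dB to a [0, 1] training target via a Gaussian CDF. A short sketch of loading and inspecting them, assuming the pickle written above:

import pickle

with open('data/stats.p', 'rb') as f: # path as written by get_stats above.
    stats = pickle.load(f)
print(stats['mu_hat'].shape, stats['sigma_hat'].shape) # one mean and standard deviation per frequency bin.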
155
lib/dev/se_batch.py
Normal file
@@ -0,0 +1,155 @@
## FILE: se_batch.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: Generates mini-batches, and creates training and test lists for speech enhancement.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

import contextlib, glob, os, pickle, platform, random, sys, wave
import numpy as np
from dev.utils import read_wav
from scipy.io.wavfile import read
from soundfile import SoundFile, SEEK_END

def Batch_list(file_dir, list_name, data_path=None, make_new=False):
    '''
    Places the file path and wav length of each audio file into a dictionary, which
    is then appended to a list. SPHERE format cannot be used. 'glob' is used to
    support Unix style pathname pattern expansions. Checks if the training list
    has already been pickled, and loads it. If a different dataset is to be
    used, delete the pickle file.

    Inputs:
        file_dir - directory containing the wavs.
        list_name - name for the list.
        data_path - path to store pickle files.
        make_new - re-create list.

    Outputs:
        batch_list - list of file paths and wav lengths.
    '''
    file_name = ['*.wav', '*.flac', '*.mp3']
    if data_path is None: data_path = 'data'
    if not make_new:
        if os.path.exists(data_path + '/' + list_name + '_list_' + platform.node() + '.p'):
            print('Loading ' + list_name + ' list from pickle file...')
            with open(data_path + '/' + list_name + '_list_' + platform.node() + '.p', 'rb') as f:
                batch_list = pickle.load(f)
            if batch_list[0]['file_path'].find(file_dir) != -1:
                print('The ' + list_name + ' list has a total of %i entries.' % (len(batch_list)))
                return batch_list
    print('Creating ' + list_name + ' list, as no pickle file exists...')
    batch_list = [] # list for wav paths and lengths.
    for fn in file_name:
        for file_path in glob.glob(os.path.join(file_dir, fn)):
            f = SoundFile(file_path)
            seq_len = f.seek(0, SEEK_END) # seek to the end of the file to find its length in samples.
            batch_list.append({'file_path': file_path, 'seq_len': seq_len}) # append dictionary.
    if not os.path.exists(data_path): os.makedirs(data_path) # make directory.
    with open(data_path + '/' + list_name + '_list_' + platform.node() + '.p', 'wb') as f:
        pickle.dump(batch_list, f)
    print('The ' + list_name + ' list has a total of %i entries.' % (len(batch_list)))
    return batch_list

def Clean_mbatch(clean_list, mbatch_size, start_idx, end_idx):
    '''
    Creates a padded mini-batch of clean speech wavs.

    Inputs:
        clean_list - training list for the clean speech files.
        mbatch_size - size of the mini-batch.
        start_idx - start index within the training list.
        end_idx - end index within the training list.

    Outputs:
        mbatch - matrix of padded wavs stored as a numpy array.
        seq_len - length of each wav stored as a numpy array.
    '''
    mbatch_list = clean_list[start_idx:end_idx] # get mini-batch list from training list.
    maxlen = max([dic['seq_len'] for dic in mbatch_list]) # find maximum length wav in mini-batch.
    seq_len = [] # list of the wav lengths.
    mbatch = np.zeros([len(mbatch_list), maxlen], np.int16) # numpy array for wav matrix.
    for i in range(len(mbatch_list)):
        (wav, _) = read_wav(mbatch_list[i]['file_path']) # read wav from given file path.
        mbatch[i, :mbatch_list[i]['seq_len']] = wav # add wav to numpy array.
        seq_len.append(mbatch_list[i]['seq_len']) # append length of wav to list.
    return mbatch, np.array(seq_len, np.int32)

def Noise_mbatch(noise_list, mbatch_size, clean_seq_len):
    '''
    Creates a padded mini-batch of noise wavs.

    Inputs:
        noise_list - training list for the noise files.
        mbatch_size - size of the mini-batch.
        clean_seq_len - sequence length of each clean speech file in the mini-batch.

    Outputs:
        mbatch - matrix of padded wavs stored as a numpy array.
        seq_len - length of each wav stored as a numpy array.
    '''
    mbatch_list = random.sample(noise_list, mbatch_size) # get mini-batch list from training list.
    for i in range(len(clean_seq_len)):
        flag = True
        while flag:
            if mbatch_list[i]['seq_len'] < clean_seq_len[i]:
                mbatch_list[i] = random.choice(noise_list) # resample until the noise is at least as long as the clean speech.
            else:
                flag = False
    maxlen = max([dic['seq_len'] for dic in mbatch_list]) # find maximum length wav in mini-batch.
    seq_len = [] # list of the wav lengths.
    mbatch = np.zeros([len(mbatch_list), maxlen], np.int16) # numpy array for wav matrix.
    for i in range(len(mbatch_list)):
        (wav, _) = read_wav(mbatch_list[i]['file_path']) # read wav from given file path.
        mbatch[i, :mbatch_list[i]['seq_len']] = wav # add wav to numpy array.
        seq_len.append(mbatch_list[i]['seq_len']) # append length of wav to list.
    return mbatch, np.array(seq_len, np.int32)

def Batch(fdir, snr_l):
    '''
    REQUIRES REWRITING.

    Places all of the test waveforms from the list into a numpy array.
    SPHERE format cannot be used. 'glob' is used to support Unix style pathname
    pattern expansions. Waveforms are padded to the maximum waveform length. The
    waveform lengths are recorded so that the correct lengths can be sliced
    for feature extraction. The SNR levels of each test file are placed into a
    numpy array. Also returns a list of the file names.

    Inputs:
        fdir - directory containing the waveforms.
        snr_l - list of the SNR levels used.

    Outputs:
        wav_np - matrix of padded waveforms stored as a numpy array.
        len_np - length of each waveform stored as a numpy array.
        snr_test_np - numpy array of all the SNR levels for the test set.
        fname_l - list of filenames.
    '''
    fname_l = [] # list of file names.
    wav_l = [] # list for waveforms.
    snr_test_l = [] # list of SNR levels for the test set.
    fnames = ['*.wav', '*.flac', '*.mp3']
    for fname in fnames:
        for fpath in glob.glob(os.path.join(fdir, fname)):
            for snr in snr_l:
                if fpath.find('_' + str(snr) + 'dB') != -1:
                    snr_test_l.append(snr) # append SNR level.
            (wav, _) = read_wav(fpath) # read waveform from given file path.
            if np.isnan(wav).any() or np.isinf(wav).any():
                raise ValueError('Error: NaN or Inf value. File path: %s.' % (fpath))
            wav_l.append(wav) # append.
            fname_l.append(os.path.basename(os.path.splitext(fpath)[0])) # append name.
    len_l = [] # list of the waveform lengths.
    maxlen = max(len(wav) for wav in wav_l) # maximum length of waveforms.
    wav_np = np.zeros([len(wav_l), maxlen], np.int16) # numpy array for waveform matrix.
    for (i, wav) in zip(range(len(wav_l)), wav_l):
        wav_np[i, :len(wav)] = wav # add waveform to numpy array.
        len_l.append(len(wav)) # append length of waveform to list.
    return wav_np, np.array(len_l, np.int32), np.array(snr_test_l, np.int32), fname_l
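A minimal usage sketch for Batch_list, assuming a directory of training wavs laid out as in set/; the pickle cache means the directory is only scanned once per host:

train_s_list = Batch_list('set/train_clean_speech', 'train_s', data_path='data')
print(len(train_s_list)) # number of training files.
print(train_s_list[0]) # e.g. {'file_path': 'set/train_clean_speech/....wav', 'seq_len': ...}, depending on the dataset.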
97
lib/dev/train.py
Normal file
@@ -0,0 +1,97 @@
## FILE: train.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: Training module.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

import dev.se_batch as batch
import numpy as np
import math, os, random
import tensorflow as tf
from datetime import datetime
from tqdm import tqdm

## TRAINING
def train(sess, net, args):
    print("Training...")
    random.shuffle(args.train_s_list) # shuffle training list.

    ## CONTINUE FROM LAST EPOCH
    if args.cont:
        epoch_size = len(args.train_s_list); epoch_comp = args.epoch; start_idx = 0
        end_idx = args.mbatch_size; val_error = float("inf") # create epoch parameters.
        net.saver.restore(sess, args.model_path + '/epoch-' + str(args.epoch)) # load model from last epoch.

    ## TRAIN RAW NETWORK
    else:
        epoch_size = len(args.train_s_list); epoch_comp = 0; start_idx = 0
        end_idx = args.mbatch_size; val_error = float("inf") # create epoch parameters.
        if args.mbatch_size > epoch_size: raise ValueError('Error: mini-batch size is greater than the epoch size.')
        sess.run(tf.global_variables_initializer()) # initialise model variables.
        net.saver.save(sess, args.model_path + '/epoch', global_step=epoch_comp) # save model.

    ## TRAINING LOG
    if not os.path.exists('log'): os.makedirs('log') # create log directory.
    with open("log/" + args.ver + ".csv", "a") as results:
        results.write("'Validation error', 'Training error', 'Epoch', 'D/T'\n")

    train_err = 0; mbatch_count = 0
    while args.train:

        print('Training E%d (ver=%s, gpu=%s, params=%g)...' % (epoch_comp + 1, args.ver, args.gpu, args.params))
        for _ in tqdm(range(args.train_steps)):

            ## MINI-BATCH GENERATION
            mbatch_size_iter = end_idx - start_idx # number of examples in mini-batch for the training iteration.
            s_mbatch, s_mbatch_seq_len = batch.Clean_mbatch(args.train_s_list,
                mbatch_size_iter, start_idx, end_idx) # generate mini-batch of clean training waveforms.
            d_mbatch, d_mbatch_seq_len = batch.Noise_mbatch(args.train_d_list,
                mbatch_size_iter, s_mbatch_seq_len) # generate mini-batch of noise training waveforms.
            snr_mbatch = np.random.randint(args.min_snr, args.max_snr + 1, end_idx - start_idx) # generate mini-batch of SNR levels.

            ## TRAINING ITERATION
            mbatch = sess.run(net.train_feat, feed_dict={net.s_ph: s_mbatch, net.d_ph: d_mbatch,
                net.s_len_ph: s_mbatch_seq_len, net.d_len_ph: d_mbatch_seq_len, net.snr_ph: snr_mbatch}) # mini-batch.
            [_, mbatch_err] = sess.run([net.trainer, net.total_loss], feed_dict={net.input_ph: mbatch[0], net.target_ph: mbatch[1],
                net.nframes_ph: mbatch[2], net.training_ph: True}) # training iteration.
            if not math.isnan(mbatch_err):
                train_err += mbatch_err; mbatch_count += 1

            ## UPDATE EPOCH PARAMETERS
            start_idx += args.mbatch_size; end_idx += args.mbatch_size # start and end index of mini-batch.
            if end_idx > epoch_size: end_idx = epoch_size # if fewer than a mini-batch of examples is left.

        ## VALIDATION SET ERROR
        random.shuffle(args.train_s_list) # shuffle list.
        start_idx = 0; end_idx = args.mbatch_size; frames = 0; val_error = 0 # reset indices and validation variables.
        print('Validation error for E%d...' % (epoch_comp + 1))
        for _ in tqdm(range(args.val_steps)):
            mbatch = sess.run(net.train_feat, feed_dict={net.s_ph: args.val_s[start_idx:end_idx],
                net.d_ph: args.val_d[start_idx:end_idx], net.s_len_ph: args.val_s_len[start_idx:end_idx],
                net.d_len_ph: args.val_d_len[start_idx:end_idx], net.snr_ph: args.val_snr[start_idx:end_idx]}) # mini-batch.
            val_error_mbatch = sess.run(net.loss, feed_dict={net.input_ph: mbatch[0],
                net.target_ph: mbatch[1], net.nframes_ph: mbatch[2], net.training_ph: False}) # validation error for each frame in mini-batch.
            val_error += np.sum(val_error_mbatch)
            frames += mbatch[1].shape[0] # total number of frames.
            print("Validation error for Epoch %d: %3.3f%% complete. " %
                (epoch_comp + 1, 100 * (end_idx / args.val_s_len.shape[0])), end="\r")
            start_idx += args.mbatch_size; end_idx += args.mbatch_size
            if end_idx > args.val_s_len.shape[0]: end_idx = args.val_s_len.shape[0]
        val_error /= frames # validation error.
        epoch_comp += 1 # an epoch has been completed.
        net.saver.save(sess, args.model_path + '/epoch', global_step=epoch_comp) # save model.
        print("E%d: train err=%3.3f, val err=%3.3f. " %
            (epoch_comp, train_err / mbatch_count, val_error))
        with open("log/" + args.ver + ".csv", "a") as results:
            results.write("%g, %g, %d, %s\n" % (val_error, train_err / mbatch_count,
                epoch_comp, datetime.now().strftime('%Y-%m-%d/%H:%M:%S')))
        train_err = 0; mbatch_count = 0; start_idx = 0; end_idx = args.mbatch_size

        if epoch_comp >= args.max_epochs:
            args.train = False
            print('\nTraining complete. Validation error for epoch %d: %g. ' %
                (epoch_comp, val_error))
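The start_idx/end_idx bookkeeping above slides a fixed-size window over the shuffled training list, shrinking the final window to fit the epoch. A standalone sketch of that windowing with toy sizes:

epoch_size, mbatch_size = 10, 4 # toy sizes for illustration.
start_idx, end_idx = 0, mbatch_size
while start_idx < epoch_size:
    print(start_idx, end_idx) # (0, 4), (4, 8), (8, 10): the last mini-batch is smaller.
    start_idx += mbatch_size; end_idx += mbatch_size
    if end_idx > epoch_size: end_idx = epoch_size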
49
lib/dev/utils.py
Normal file
@@ -0,0 +1,49 @@
## FILE: utils.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University.
## BRIEF: General utility functions/modules.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

from os.path import expanduser
import argparse, os, string
import numpy as np
from scipy.io.wavfile import write as wav_write
import tensorflow as tf
import soundfile as sf

def save_wav(save_path, f_s, wav):
    if isinstance(wav[0], np.float32): wav = np.asarray(np.multiply(wav, 32768.0), dtype=np.int16)
    wav_write(save_path, f_s, wav)

def read_wav(path):
    wav, f_s = sf.read(path, dtype='int16')
    return wav, f_s

def log10(x):
    numerator = tf.log(x)
    denominator = tf.constant(np.log(10), dtype=numerator.dtype)
    return tf.truediv(numerator, denominator)

## CHARACTER DICTIONARIES
def char_dict():
    chars = list(" " + string.ascii_lowercase + "'") # 26 alphabetic characters + space + apostrophe; a blank class is added elsewhere for CTC.
    char2idx = dict(zip(chars, [i for i in range(len(chars))]))
    idx2char = dict((y, x) for x, y in char2idx.items())
    return char2idx, idx2char

## NUMPY SIGMOID FUNCTION
def np_sigmoid(x): return np.divide(1, np.add(1, np.exp(np.negative(x))))

## GPU CONFIGURATION
def gpu_config(gpu_selection, log_device_placement=False):
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_selection)
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True
    config.log_device_placement = log_device_placement
    return config
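save_wav above rescales float waveforms to 16-bit PCM before writing; a tiny sketch of that conversion on synthetic data:

import numpy as np

y = np.array([0.0, 0.5, -0.5], dtype=np.float32) # enhanced speech in [-1, 1].
pcm = np.asarray(np.multiply(y, 32768.0), dtype=np.int16) # same scaling as save_wav.
print(pcm) # [ 0 16384 -16384]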
BIN
lib/dev/utils.pyc
Normal file
Binary file not shown.
33
lib/dev/xi.py
Normal file
@@ -0,0 +1,33 @@
## FILE: xi.py
## DATE: 2019
## AUTHOR: Aaron Nicolson
## AFFILIATION: Signal Processing Laboratory, Griffith University
## BRIEF: Functions for computing the a priori SNR.
##
## This Source Code Form is subject to the terms of the Mozilla Public
## License, v. 2.0. If a copy of the MPL was not distributed with this
## file, You can obtain one at http://mozilla.org/MPL/2.0/.

import numpy as np
import scipy.special as spsp
import tensorflow as tf

def log10(x):
    numerator = tf.log(x)
    denominator = tf.constant(np.log(10), dtype=numerator.dtype)
    return tf.truediv(numerator, denominator)

def xi(s_MAG, d_MAG):
    return tf.truediv(tf.square(s_MAG), tf.maximum(tf.square(d_MAG), 1e-12)) # a priori SNR.

def xi_dB(s_MAG, d_MAG):
    return tf.multiply(10.0, log10(tf.maximum(xi(s_MAG, d_MAG), 1e-12))) # a priori SNR in dB.

def xi_bar(s_MAG, d_MAG, mu, sigma):
    return tf.multiply(0.5, tf.add(1.0, tf.erf(tf.truediv(tf.subtract(xi_dB(s_MAG, d_MAG), mu),
        tf.multiply(sigma, tf.sqrt(2.0)))))) # mapped a priori SNR (Gaussian CDF).

def xi_hat(xi_bar_hat, mu, sigma):
    xi_dB_hat = np.add(np.multiply(np.multiply(sigma, np.sqrt(2.0)),
        spsp.erfinv(np.subtract(np.multiply(2.0, xi_bar_hat), 1))), mu) # inverse of the CDF mapping.
    return np.power(10.0, np.divide(xi_dB_hat, 10.0))
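xi_bar and xi_hat above are inverses: the Gaussian CDF maps the a priori SNR in dB into [0, 1], and erfinv maps the network's estimate back. A numerical round trip with made-up per-bin statistics:

import numpy as np
import scipy.special as spsp

mu, sigma = 5.0, 10.0 # hypothetical per-bin statistics.
xi_dB = 8.0 # instantaneous a priori SNR in dB.
xi_bar = 0.5 * (1.0 + spsp.erf((xi_dB - mu) / (sigma * np.sqrt(2.0)))) # forward (CDF) mapping.
xi_dB_hat = sigma * np.sqrt(2.0) * spsp.erfinv(2.0 * xi_bar - 1.0) + mu # inverse mapping.
print(xi_dB_hat, 10.0 ** (xi_dB_hat / 10.0)) # recovers 8.0 dB, i.e. xi_hat of about 6.31.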
202
model/3f/checkpoint
Normal file
@@ -0,0 +1,202 @@
model_checkpoint_path: "epoch-200"
all_model_checkpoint_paths: "epoch-0"
all_model_checkpoint_paths: "epoch-1"
all_model_checkpoint_paths: "epoch-2"
all_model_checkpoint_paths: "epoch-3"
all_model_checkpoint_paths: "epoch-4"
all_model_checkpoint_paths: "epoch-5"
all_model_checkpoint_paths: "epoch-6"
all_model_checkpoint_paths: "epoch-7"
all_model_checkpoint_paths: "epoch-8"
all_model_checkpoint_paths: "epoch-9"
all_model_checkpoint_paths: "epoch-10"
all_model_checkpoint_paths: "epoch-11"
all_model_checkpoint_paths: "epoch-12"
all_model_checkpoint_paths: "epoch-13"
all_model_checkpoint_paths: "epoch-14"
all_model_checkpoint_paths: "epoch-15"
all_model_checkpoint_paths: "epoch-16"
all_model_checkpoint_paths: "epoch-17"
all_model_checkpoint_paths: "epoch-18"
all_model_checkpoint_paths: "epoch-19"
all_model_checkpoint_paths: "epoch-20"
all_model_checkpoint_paths: "epoch-21"
all_model_checkpoint_paths: "epoch-22"
all_model_checkpoint_paths: "epoch-23"
all_model_checkpoint_paths: "epoch-24"
all_model_checkpoint_paths: "epoch-25"
all_model_checkpoint_paths: "epoch-26"
all_model_checkpoint_paths: "epoch-27"
all_model_checkpoint_paths: "epoch-28"
all_model_checkpoint_paths: "epoch-29"
all_model_checkpoint_paths: "epoch-30"
all_model_checkpoint_paths: "epoch-31"
all_model_checkpoint_paths: "epoch-32"
all_model_checkpoint_paths: "epoch-33"
all_model_checkpoint_paths: "epoch-34"
all_model_checkpoint_paths: "epoch-35"
all_model_checkpoint_paths: "epoch-36"
all_model_checkpoint_paths: "epoch-37"
all_model_checkpoint_paths: "epoch-38"
all_model_checkpoint_paths: "epoch-39"
all_model_checkpoint_paths: "epoch-40"
all_model_checkpoint_paths: "epoch-41"
all_model_checkpoint_paths: "epoch-42"
all_model_checkpoint_paths: "epoch-43"
all_model_checkpoint_paths: "epoch-44"
all_model_checkpoint_paths: "epoch-45"
all_model_checkpoint_paths: "epoch-46"
all_model_checkpoint_paths: "epoch-47"
all_model_checkpoint_paths: "epoch-48"
all_model_checkpoint_paths: "epoch-49"
all_model_checkpoint_paths: "epoch-50"
all_model_checkpoint_paths: "epoch-51"
all_model_checkpoint_paths: "epoch-52"
all_model_checkpoint_paths: "epoch-53"
all_model_checkpoint_paths: "epoch-54"
all_model_checkpoint_paths: "epoch-55"
all_model_checkpoint_paths: "epoch-56"
all_model_checkpoint_paths: "epoch-57"
all_model_checkpoint_paths: "epoch-58"
all_model_checkpoint_paths: "epoch-59"
all_model_checkpoint_paths: "epoch-60"
all_model_checkpoint_paths: "epoch-61"
all_model_checkpoint_paths: "epoch-62"
all_model_checkpoint_paths: "epoch-63"
all_model_checkpoint_paths: "epoch-64"
all_model_checkpoint_paths: "epoch-65"
all_model_checkpoint_paths: "epoch-66"
all_model_checkpoint_paths: "epoch-67"
all_model_checkpoint_paths: "epoch-68"
all_model_checkpoint_paths: "epoch-69"
all_model_checkpoint_paths: "epoch-70"
all_model_checkpoint_paths: "epoch-71"
all_model_checkpoint_paths: "epoch-72"
all_model_checkpoint_paths: "epoch-73"
all_model_checkpoint_paths: "epoch-74"
all_model_checkpoint_paths: "epoch-75"
all_model_checkpoint_paths: "epoch-76"
all_model_checkpoint_paths: "epoch-77"
all_model_checkpoint_paths: "epoch-78"
all_model_checkpoint_paths: "epoch-79"
all_model_checkpoint_paths: "epoch-80"
all_model_checkpoint_paths: "epoch-81"
all_model_checkpoint_paths: "epoch-82"
all_model_checkpoint_paths: "epoch-83"
all_model_checkpoint_paths: "epoch-84"
all_model_checkpoint_paths: "epoch-85"
all_model_checkpoint_paths: "epoch-86"
all_model_checkpoint_paths: "epoch-87"
all_model_checkpoint_paths: "epoch-88"
all_model_checkpoint_paths: "epoch-89"
all_model_checkpoint_paths: "epoch-90"
all_model_checkpoint_paths: "epoch-91"
all_model_checkpoint_paths: "epoch-92"
all_model_checkpoint_paths: "epoch-93"
all_model_checkpoint_paths: "epoch-94"
all_model_checkpoint_paths: "epoch-95"
all_model_checkpoint_paths: "epoch-96"
all_model_checkpoint_paths: "epoch-97"
all_model_checkpoint_paths: "epoch-98"
all_model_checkpoint_paths: "epoch-99"
all_model_checkpoint_paths: "epoch-100"
all_model_checkpoint_paths: "epoch-101"
all_model_checkpoint_paths: "epoch-102"
all_model_checkpoint_paths: "epoch-103"
all_model_checkpoint_paths: "epoch-104"
all_model_checkpoint_paths: "epoch-105"
all_model_checkpoint_paths: "epoch-106"
all_model_checkpoint_paths: "epoch-107"
all_model_checkpoint_paths: "epoch-108"
all_model_checkpoint_paths: "epoch-109"
all_model_checkpoint_paths: "epoch-110"
all_model_checkpoint_paths: "epoch-111"
all_model_checkpoint_paths: "epoch-112"
all_model_checkpoint_paths: "epoch-113"
all_model_checkpoint_paths: "epoch-114"
all_model_checkpoint_paths: "epoch-115"
all_model_checkpoint_paths: "epoch-116"
all_model_checkpoint_paths: "epoch-117"
all_model_checkpoint_paths: "epoch-118"
all_model_checkpoint_paths: "epoch-119"
all_model_checkpoint_paths: "epoch-120"
all_model_checkpoint_paths: "epoch-121"
all_model_checkpoint_paths: "epoch-122"
all_model_checkpoint_paths: "epoch-123"
all_model_checkpoint_paths: "epoch-124"
all_model_checkpoint_paths: "epoch-125"
all_model_checkpoint_paths: "epoch-126"
all_model_checkpoint_paths: "epoch-127"
all_model_checkpoint_paths: "epoch-128"
all_model_checkpoint_paths: "epoch-129"
all_model_checkpoint_paths: "epoch-130"
all_model_checkpoint_paths: "epoch-131"
all_model_checkpoint_paths: "epoch-132"
all_model_checkpoint_paths: "epoch-133"
all_model_checkpoint_paths: "epoch-134"
all_model_checkpoint_paths: "epoch-135"
all_model_checkpoint_paths: "epoch-136"
all_model_checkpoint_paths: "epoch-137"
all_model_checkpoint_paths: "epoch-138"
all_model_checkpoint_paths: "epoch-139"
all_model_checkpoint_paths: "epoch-140"
all_model_checkpoint_paths: "epoch-141"
all_model_checkpoint_paths: "epoch-142"
all_model_checkpoint_paths: "epoch-143"
all_model_checkpoint_paths: "epoch-144"
all_model_checkpoint_paths: "epoch-145"
all_model_checkpoint_paths: "epoch-146"
all_model_checkpoint_paths: "epoch-147"
all_model_checkpoint_paths: "epoch-148"
all_model_checkpoint_paths: "epoch-149"
all_model_checkpoint_paths: "epoch-150"
all_model_checkpoint_paths: "epoch-151"
all_model_checkpoint_paths: "epoch-152"
all_model_checkpoint_paths: "epoch-153"
all_model_checkpoint_paths: "epoch-154"
all_model_checkpoint_paths: "epoch-155"
all_model_checkpoint_paths: "epoch-156"
all_model_checkpoint_paths: "epoch-157"
all_model_checkpoint_paths: "epoch-158"
all_model_checkpoint_paths: "epoch-159"
all_model_checkpoint_paths: "epoch-160"
all_model_checkpoint_paths: "epoch-161"
all_model_checkpoint_paths: "epoch-162"
all_model_checkpoint_paths: "epoch-163"
all_model_checkpoint_paths: "epoch-164"
all_model_checkpoint_paths: "epoch-165"
all_model_checkpoint_paths: "epoch-166"
all_model_checkpoint_paths: "epoch-167"
all_model_checkpoint_paths: "epoch-168"
all_model_checkpoint_paths: "epoch-169"
all_model_checkpoint_paths: "epoch-170"
all_model_checkpoint_paths: "epoch-171"
all_model_checkpoint_paths: "epoch-172"
all_model_checkpoint_paths: "epoch-173"
all_model_checkpoint_paths: "epoch-174"
all_model_checkpoint_paths: "epoch-175"
all_model_checkpoint_paths: "epoch-176"
all_model_checkpoint_paths: "epoch-177"
all_model_checkpoint_paths: "epoch-178"
all_model_checkpoint_paths: "epoch-179"
all_model_checkpoint_paths: "epoch-180"
all_model_checkpoint_paths: "epoch-181"
all_model_checkpoint_paths: "epoch-182"
all_model_checkpoint_paths: "epoch-183"
all_model_checkpoint_paths: "epoch-184"
all_model_checkpoint_paths: "epoch-185"
all_model_checkpoint_paths: "epoch-186"
all_model_checkpoint_paths: "epoch-187"
all_model_checkpoint_paths: "epoch-188"
all_model_checkpoint_paths: "epoch-189"
all_model_checkpoint_paths: "epoch-190"
all_model_checkpoint_paths: "epoch-191"
all_model_checkpoint_paths: "epoch-192"
all_model_checkpoint_paths: "epoch-193"
all_model_checkpoint_paths: "epoch-194"
all_model_checkpoint_paths: "epoch-195"
all_model_checkpoint_paths: "epoch-196"
all_model_checkpoint_paths: "epoch-197"
all_model_checkpoint_paths: "epoch-198"
all_model_checkpoint_paths: "epoch-199"
all_model_checkpoint_paths: "epoch-200"
BIN
model/3f/epoch-175.data-00000-of-00001
Normal file
Binary file not shown.
BIN
model/3f/epoch-175.index
Normal file
Binary file not shown.
BIN
model/3f/epoch-175.meta
Normal file
Binary file not shown.
BIN
out/3f/e175/mmse-lsa/y/198853.wav
Normal file
Binary file not shown.
BIN
out/3f/e175/mmse-lsa/y/FB_FB10_07_voice-babble_5dB.wav
Normal file
Binary file not shown.
BIN
out/3f/e175/mmse-lsa/y/g1201-20190221-115827-1550725107.wav
Normal file
Binary file not shown.
BIN
out/3f/e175/xi_hat/FB_FB10_07_voice-babble_5dB.mat
Normal file
Binary file not shown.
25
server.py
Normal file
@@ -0,0 +1,25 @@
from flask import Flask, request, render_template, url_for, session, jsonify, make_response, g, flash, redirect, abort

app = Flask(__name__)

# NOTE: 'photos' (an upload set) and 'Photo' (a record class) are not defined in
# this file; they are assumed to be provided elsewhere (e.g. a Flask-Uploads
# style setup). POST must be enabled on this route for the form to be received.
@app.route('/', methods=['GET', 'POST'])
def upload():
    if request.method == 'POST' and 'photo' in request.files:
        filename = photos.save(request.files['photo'])
        rec = Photo(filename=filename, user=g.user.id)
        rec.store()
        flash("Photo saved.")
        return redirect(url_for('show', id=rec.id))
    return render_template('upload.html')

@app.route('/photo/<id>')
def show(id):
    photo = Photo.load(id)
    if photo is None:
        abort(404)
    url = photos.url(photo.filename)
    return render_template('show.html', url=url, photo=photo)

if __name__ == '__main__':
    app.run(host='192.168.1.254', port=9100, debug=True)
6
set/info.txt
Normal file
@@ -0,0 +1,6 @@
Directories used to store clean speech and noise for validation and training, as well as noisy speech for testing.

For the validation set only:

Identical filenames for the clean speech and noise must be placed in 'val_clean_speech' and 'val_noise', with the SNR level at the end of the filename. As an example:
'./val_clean_speech/198_19-198-0003_Machinery17_15dB.wav' contains the clean speech, and './val_noise/198_19-198-0003_Machinery17_15dB.wav' contains noise of the same length. They will be mixed together at the SNR level specified in the filename.
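The '_15dB' suffix convention above is what Batch in lib/dev/se_batch.py matches with fpath.find('_' + str(snr) + 'dB'). A quick sketch of that matching, using the example filename from this note:

fname = '198_19-198-0003_Machinery17_15dB.wav'
snr_l = [0, 5, 10, 15] # candidate SNR levels (dB).
matched = [snr for snr in snr_l if fname.find('_' + str(snr) + 'dB') != -1]
print(matched) # [15]; the leading underscore stops '_5dB' from matching inside '_15dB'.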
BIN
set/test_noisy_speech/198853._pred.wav
Normal file
Binary file not shown.
BIN
set/test_noisy_speech/198853.wav
Normal file
Binary file not shown.
BIN
set/test_noisy_speech/198853_pred.wav
Normal file
Binary file not shown.
BIN
set/test_noisy_speech/FB_FB10_07_voice-babble_5dB.wav
Normal file
Binary file not shown.
1
set/test_noisy_speech/info.txt
Normal file
@@ -0,0 +1 @@
Place all noisy speech .wav files for testing here.
1
set/train_clean_speech/info.txt
Normal file
@@ -0,0 +1 @@
Place all clean speech .wav files for training here.
1
set/train_noise/info.txt
Normal file
@@ -0,0 +1 @@
Place all noise .wav files for training here.
1
set/val_clean_speech/info.txt
Normal file
@@ -0,0 +1 @@
Place all clean speech .wav files for validation here.
1
set/val_noise/info.txt
Normal file
@@ -0,0 +1 @@
Place all noise .wav files for validation here.
214
templates/index.html
Normal file
@@ -0,0 +1,214 @@
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.2.1/css/bootstrap.min.css" integrity="sha384-GJzZqFGwb1QTTN6wy59ffF1BuGJpLSa9DkKMp0DgiMDm4iYMj70gZWKYbI706tWS" crossorigin="anonymous">
  <title>deep-xi</title>
</head>
<style>
  body {
    text-align: center;
  }
  .custom_btn {
    background-color: lightgreen;
    border: none;
    padding: 8px;
    border-radius: 3px;
    margin-top: 10px;
  }
  .container {
    padding-left: 20%;
    padding-right: 20%;
  }
</style>
<body>
  <div class="container">
    <div class="row">
      <div class="col">
        <div class="mb-3 mt-3">
          <h2 class="mb-3" style="font-weight: 300">DeepXi</h2>

          <div class="form-group mb-3">
            <div class="custom-file">
              <input type="file" class="custom-file-input" name="file_input" id="file_input" oninput="input_filename();">
              <label id="file_input_label" class="custom-file-label">Select file (wav, mp3, flac)</label>
            </div>
          </div>
          <div id="progress_wrapper" class="d-none">
            <label id="progress_status"></label>
            <div class="progress mb-3">
              <div id="progress" class="progress-bar" role="progressbar" aria-valuenow="25" aria-valuemin="0" aria-valuemax="100"></div>
            </div>
          </div>
          <div id="alert_wrapper" class="d-none"></div>
          <button onclick="upload('/upload');" id="upload_btn" class="btn btn-primary">Upload</button>
          <button class="btn btn-primary d-none" id="loading_btn" type="button" disabled>
            <span class="spinner-border spinner-border-sm" role="status" aria-hidden="true"></span> Uploading...
          </button>
          <button type="button" id="cancel_btn" class="btn btn-secondary d-none">Cancel</button>
        </div>

        <div id="original_play" style="margin: 10px;" class="d-none">
          <h4>original</h4>
          <audio id="original_audio" controls src="" type="audio/*"></audio>
        </div>
        <div id="predict_play" style="margin: 10px;" class="d-none">
          <h4>predict</h4>
          <audio id="predict_audio" controls src="" type="audio/*"></audio>
        </div>

      </div>
    </div>
  </div>

  <script>
    var input = document.getElementById("file_input");
    var file_input_label = document.getElementById("file_input_label");
    var progress = document.getElementById("progress");
    var progress_wrapper = document.getElementById("progress_wrapper");
    var progress_status = document.getElementById("progress_status");
    var upload_btn = document.getElementById("upload_btn");
    var loading_btn = document.getElementById("loading_btn");
    var loading_btn_text = document.getElementById("loading_btn_text");
    var cancel_btn = document.getElementById("cancel_btn");
    var alert_wrapper = document.getElementById("alert_wrapper");
    var original_play = document.getElementById("original_play");
    var predict_play = document.getElementById("predict_play");
    var original_audio = document.getElementById("original_audio");
    var predict_audio = document.getElementById("predict_audio");

    function input_filename() {
      file_input_label.innerText = input.files[0].name;
    }

    function show_alert(message, alert, autohide = false) {
      alert_wrapper.classList.remove("d-none");
      alert_wrapper.innerHTML = `
        <div id="alert" class="alert alert-${alert} alert-dismissible fade show" role="alert">
          <span>${message}</span>
          <button type="button" class="close" data-dismiss="alert" aria-label="Close">
            <span aria-hidden="true">×</span>
          </button>
        </div>
      `;
      // if (autohide){
      //   setTimeout(() => {alert_wrapper.classList.add("d-none")}, 5000)
      // }
    }

    function upload(url) {
      console.log(url);
      if (!input.value) {
        show_alert("No file selected", "warning");
        return;
      }
      var data = new FormData();
      var request = new XMLHttpRequest();
      request.responseType = "json";
      alert_wrapper.innerHTML = "";
      input.disabled = true;
      upload_btn.classList.add("d-none");
      loading_btn.classList.remove("d-none");
      cancel_btn.classList.remove("d-none");
      progress_wrapper.classList.remove("d-none");
      original_play.classList.add("d-none");
      predict_play.classList.add("d-none");
      var file = input.files[0];
      var filename = file.name;
      var filesize = file.size;
      document.cookie = `filesize=${filesize}`;
      data.append("file", file);
      // upload progress handler
      request.upload.addEventListener("progress", function (e) {
        var loaded = e.loaded;
        var total = e.total;
        var percent_complete = (loaded / total) * 100;
        progress.setAttribute("style", `width: ${Math.floor(percent_complete)}%`);
        progress_status.innerHTML = `${Math.floor(percent_complete)}% uploaded`;
        if (percent_complete == 100) {
          show_alert("Saving file to server...", "primary");
          loading_btn.innerHTML = "<span class='spinner-border spinner-border-sm' role='status' aria-hidden='true'></span>" + " Predict...";
        }
      });
      // request load handler (transfer complete)
      request.addEventListener("load", function (e) {
        if (request.status == 200) {
          show_alert(`${request.response.message}`, "success", true);
          original_play.classList.remove("d-none");
          original_audio.src = request.response.file_path;
          original_audio.autoplay = true;
          progress_wrapper.classList.add("d-none");
          predict_rq = predict('/predict/' + String(request.response.file_path).split('/').join('='));
          if (predict_rq.status == 200) {
            json = JSON.parse(predict_rq.responseText);
            show_alert(json.message, "success");
            console.log(predict_rq.responseText);
            console.log(json.out_file_path);
            predict_play.classList.remove("d-none");
            predict_audio.src = json.out_file_path;
          } else {
            show_alert(`Error predicting file: ` + filename, "danger");
          }
          loading_btn.classList.add("d-none");
          loading_btn.innerHTML = "<span class='spinner-border spinner-border-sm' role='status' aria-hidden='true'></span>" + " Uploading...";
        }
        reset();
      });
      // request error handler
      request.addEventListener("error", function (e) {
        reset();
        show_alert(`Error uploading file`, "danger");
      });
      // request abort handler
      request.addEventListener("abort", function (e) {
        reset();
        show_alert(`Upload cancelled`, "danger");
      });
      // open and send the request
      request.open("post", url);
      request.send(data);
      cancel_btn.addEventListener("click", function () {
        request.abort();
      });
    }

    function predict(url) {
      var xmlHttp = new XMLHttpRequest();
      xmlHttp.open("GET", url, false); // false for synchronous request
      xmlHttp.send(null);
      return xmlHttp;
    }

    function reset() {
      // Clear the input
      input.value = null;
      // Hide the cancel button
      cancel_btn.classList.add("d-none");
      // Re-enable the input element
      input.disabled = false;
      // Show the upload button
      upload_btn.classList.remove("d-none");
      // Hide the loading button
      loading_btn.classList.add("d-none");
      // Hide the progress bar
      progress_wrapper.classList.add("d-none");
      // Reset the progress bar state
      progress.setAttribute("style", `width: 0%`);
      // Reset the input placeholder
      file_input_label.innerText = "Select file";
    }
  </script>
</body>
</html>
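The page above uploads to /upload and then issues a synchronous GET to /predict/<path with '/' replaced by '='>. A hedged Python client sketch of the same two-step flow; the host, port, and JSON field names are inferred from the JavaScript above, not from a documented API:

import requests

base = 'http://localhost:9100' # assumed host and port.
with open('set/test_noisy_speech/198853.wav', 'rb') as f:
    r = requests.post(base + '/upload', files={'file': f}) # same form field name as the JS above.
file_path = r.json()['file_path'] # field name taken from the JS above.
pred = requests.get(base + '/predict/' + file_path.replace('/', '='))
print(pred.json().get('out_file_path')) # enhanced wav path, per the JS above.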
5
templates1/index.html
Normal file
@@ -0,0 +1,5 @@
<form method=POST enctype=multipart/form-data action="{{ url_for('upload') }}">
...
<input type=file name=photo>
...
</form>
5
templates1/upload.html
Normal file
@@ -0,0 +1,5 @@
<form method=POST enctype=multipart/form-data action="{{ url_for('upload') }}">
...
<input type=file name=photo>
...
</form>
9
testload.py
Normal file
@@ -0,0 +1,9 @@
import librosa
print(librosa.__version__)
import scipy
filename = 'music.mp3'
#sr, data = scipy.io.wavfile.read(filename)
data2, sr2 = librosa.load(filename, None)

#print(sr, sr2)
print(sr2, data2)
65
utils.py
Normal file
@@ -0,0 +1,65 @@
import wave
import requests
import json
import os
from time import sleep
import csv
import numpy as np
#from pydub import AudioSegment
import librosa
import audioop
import scipy
from datetime import datetime
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from inaSpeechSegmenter import Segmenter, seg2csv

def split(file_name, out_dir):
    print('\nREMOVE MUSIC AND CUT')
    seg = Segmenter()
    segmentation = seg(file_name)
    sample_rate, raw_audio = scipy.io.wavfile.read(file_name)
    #raw_audio, sr = librosa.load(file_name, sr=16000)
    speech = []
    print(segmentation)
    count = 1
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)
    list_file = []
    for s in segmentation:
        if s[0] != 'Music' and s[0] != 'NOACTIVITY':
            print(str(count), 'dur of sen:', s[2] - s[1])
            # pad each segment by a quarter of a second on either side.
            speech_data = raw_audio[int(s[1]*sample_rate) - int(sample_rate/4):int(s[2]*sample_rate + int(sample_rate/4))]
            speech_data = np.array(speech_data)

            print(len(speech_data), len(speech_data)/sample_rate)
            # skip segments shorter than 0.5 s or longer than 20 s.
            if len(speech_data)/sample_rate < 0.5 or len(speech_data)/sample_rate > 20:
                continue
            else:
                out_filename = out_dir + '/' + file_name.split('/')[-1].replace('.wav','') + '_' + str(count) + '.wav'
                list_file.append(out_filename)
                scipy.io.wavfile.write(out_filename, sample_rate, speech_data)
                count += 1
    return list_file

def cvtToWavMono16(filename):
    # convert any input audio (wav/mp3/...) to 16 kHz mono wav.
    print('start convert to wav 16000 mono')
    sig, rate = librosa.load(filename, sr=16000, mono=True)
    new_filename = '.'.join(filename.split('.')[:-1]) + '.wav' # keep the base name, swap the extension.
    librosa.output.write_wav(new_filename, sig, sr=rate)
    print('converted to wav 16000 mono')
    return new_filename
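A short usage sketch tying the two helpers together: convert an arbitrary recording to 16 kHz mono, then cut it into speech-only segments; the paths are illustrative:

wav_path = cvtToWavMono16('recording.mp3') # hypothetical input file.
segments = split(wav_path, 'segments_out') # speech-only wavs, 0.5-20 s each.
print(len(segments), 'segments written')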